
在许多数据处理和机器学习任务中,我们可能需要计算两组向量集 A 和 B 之间的所有成对距离。然而,在某些特定场景下,我们仅对其中一小部分成对距离感兴趣,例如,当一个掩码矩阵 M 指定了需要保留的距离对时。
传统的NumPy方法通常涉及计算所有可能的成对距离,然后通过掩码矩阵进行筛选。以下是一个示例:
import numpy as np
A = np.array([[1, 2], [2, 3], [3, 4]]) # (3, 2)
B = np.array([[4, 5], [5, 6], [6, 7], [7, 8], [8, 9]]) # (5, 2)
M = np.array([[0, 0, 0, 1, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 1]]) # (3, 5)
# 计算所有向量对的差值
diff = A[:, None] - B[None, :] # (3, 5, 2)
# 计算所有成对距离(L2范数)
distances = np.linalg.norm(diff, ord=2, axis=2) # (3, 5)
# 应用掩码,保留所需距离
masked_distances = distances * M # (3, 5)
print("计算的距离矩阵:\n", distances)
print("掩码后的距离矩阵:\n", masked_distances)这种方法虽然简洁,但当 A 和 B 的行数非常大时(例如数千行),diff 和 distances 矩阵会变得非常庞大,导致计算大量不必要的距离,从而消耗大量的计算资源和内存。即使通过 np.vectorize 尝试创建条件函数,也可能因为Python循环的开销而导致性能不佳,甚至更慢。
为了解决上述性能瓶颈,我们引入一种结合 Numba 即时编译和 SciPy 稀疏矩阵(特别是 Compressed Sparse Row, CSR 格式)的优化方案。该方案的核心思想是:
立即学习“Python免费学习笔记(深入)”;
首先,我们定义一个 Numba 加速的欧几里得距离函数。在 Numba 环境下,自定义的循环计算通常比调用 np.linalg.norm 更快。
import numba as nb
import numpy as np
import scipy.sparse
import math
@nb.njit()
def euclidean_distance(vec_a, vec_b):
"""
计算两个向量之间的欧几里得距离。
使用 Numba 加速,避免 np.linalg.norm 的开销。
"""
acc = 0.0
for i in range(vec_a.shape[0]):
acc += (vec_a[i] - vec_b[i]) ** 2
return math.sqrt(acc)这里,@nb.njit() 装饰器指示 Numba 在函数首次调用时将其编译为优化的机器码。
接下来,我们创建 masked_distance_inner 函数。这是一个 Numba 加速的核心函数,负责遍历掩码矩阵,只计算所需的距离,并将结果填充到 CSR 矩阵所需的 data、indicies 和 indptr 数组中。
@nb.njit()
def masked_distance_inner(data, indicies, indptr, matrix_a, matrix_b, mask):
"""
Numba 加速的核心函数,根据掩码计算距离并填充 CSR 矩阵的内部数组。
参数:
data (np.ndarray): 存储非零距离值的数组。
indicies (np.ndarray): 存储非零距离值对应列索引的数组。
indptr (np.ndarray): 存储每行在 data/indicies 中起始位置的数组。
matrix_a (np.ndarray): 第一个向量集。
matrix_b (np.ndarray): 第二个向量集。
mask (np.ndarray): 布尔掩码矩阵,指示哪些距离需要计算。
"""
write_pos = 0
N, M = matrix_a.shape[0], matrix_b.shape[0]
# 遍历所有可能的向量对
for i in range(N):
for j in range(M):
# 只有当掩码为 True 时才计算距离
if mask[i, j]:
# 记录距离值
data[write_pos] = euclidean_distance(matrix_a[i], matrix_b[j])
# 记录该距离值对应的列索引
indicies[write_pos] = j
write_pos += 1
# 记录当前行结束后,data/indicies 中元素的总数,作为下一行的起始位置
indptr[i + 1] = write_pos
# 确保所有预分配的空间都被使用
assert write_pos == data.shape[0]
assert write_pos == indicies.shape[0]
# data, indicies, indptr 会在函数外部被修改并用于构建 CSR 矩阵最后,我们定义 masked_distance 函数,它负责设置算法的参数、预分配内存,并调用 masked_distance_inner 来执行计算,最终返回一个 scipy.sparse.csr_matrix 对象。
def masked_distance(matrix_a, matrix_b, mask):
"""
计算两组向量之间掩码指定的稀疏成对距离。
参数:
matrix_a (np.ndarray): 第一个向量集。
matrix_b (np.ndarray): 第二个向量集。
mask (np.ndarray): 布尔掩码矩阵。
返回:
scipy.sparse.csr_matrix: 包含指定成对距离的稀疏矩阵。
"""
N, M = matrix_a.shape[0], matrix_b.shape[0]
assert mask.shape == (N, M), "掩码矩阵的形状必须与向量集兼容。"
# 确保掩码是布尔类型
mask = mask != 0
# 计算稀疏矩阵中非零元素的总数
sparse_length = mask.sum()
# 为 CSR 矩阵预分配内存。这些数组不需要初始化为零,直接分配内存更高效。
data = np.empty(sparse_length, dtype='float64') # 存储非零数据值
indicies = np.empty(sparse_length, dtype='int64') # 存储列索引
indptr = np.zeros(N + 1, dtype='int64') # 存储行指针
# 调用 Numba 加速的核心函数进行计算和填充
masked_distance_inner(data, indicies, indptr, matrix_a, matrix_b, mask)
# 使用填充好的数据构建 CSR 稀疏矩阵
return scipy.sparse.csr_matrix((data, indicies, indptr), shape=(N, M))为了演示和评估其性能,我们使用更大的随机生成数据集进行测试。
# 准备大型测试数据
A_big = np.random.rand(2000, 10)
B_big = np.random.rand(4000, 10)
# 创建一个高度稀疏的掩码(0.1% 的元素为 True)
M_big = np.random.rand(A_big.shape[0], B_big.shape[0]) < 0.001
# 使用优化的方法计算稀疏距离
sparse_distances = masked_distance(A_big, B_big, M_big)
print(f"稀疏距离矩阵的形状: {sparse_distances.shape}")
print(f"稀疏距离矩阵的非零元素数量: {sparse_distances.nnz}")
print(f"稀疏距离矩阵的密度: {sparse_distances.nnz / (sparse_distances.shape[0] * sparse_distances.shape[1]):.6f}")
# 性能基准测试 (在Jupyter/IPython环境中运行)
# %timeit masked_distance(A_big, B_big, M_big)
#
# 原始方法的性能基准测试 (仅供参考,不推荐在生产环境运行大型矩阵)
# %timeit np.linalg.norm(A_big[:,None] - B_big[None,:], ord=2, axis=2) * M_big在上述 A_big (2000x10) 和 B_big (4000x10) 的测试场景中,当掩码 M_big 只有约 0.1% 的元素为 True 时,此优化方案相比原始的 NumPy 全量计算方法,可以实现显著的性能提升(例如,40倍甚至更高)。具体的加速效果会随着矩阵大小和掩码稀疏度的增加而更加明显。
通过结合 Numba 的即时编译能力和 SciPy 的 CSR 稀疏矩阵格式,我们能够高效地计算两组向量之间指定的一小部分成对距离。这种方法通过避免不必要的计算和优化内存使用,为处理大规模稀疏距离计算问题提供了一个强大且高性能的解决方案。在面临大量数据且仅需少量成对距离的场景时,采用此教程介绍的方案将显著提升应用程序的性能和资源利用率。
以上就是高效计算Python中的稀疏成对距离的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号