
本文详细解析 numpy `einsum` 在处理多张量求和时的内部机制。通过逐步分解求和过程和提供等效的显式循环实现,帮助读者理解 `einsum` 如何根据索引字符串高效地执行元素乘法、重排和特定维度上的求和操作,从而掌握其在复杂张量运算中的应用细节。
NumPy 的 einsum 函数提供了一种极其灵活且高效的方式来执行张量运算,包括点积、转置、求和、矩阵乘法等。其核心在于通过一个简洁的字符串表达式来定义输入张量的索引关系以及输出张量的索引顺序。然而,当涉及到多个张量的复杂求和(收缩)操作时,理解其内部元素的组合和求和过程可能会变得有些抽象。本文将深入探讨 np.einsum('ijk,jil->kl', a, b) 这一特定操作的细节,帮助读者透彻理解其背后的机制。
首先,我们来解析 np.einsum('ijk,jil->kl', a, b) 中的索引字符串:
理解操作规则:
简而言之,np.einsum('ijk,jil->kl', a, b) 的数学表达式等价于: $$ \text{output}_{kl} = \sum_i \sumj \text{a}{ijk} \cdot \text{b}_{jil} $$
为了更直观地理解 einsum 的求和细节,我们可以通过一个技巧来逐步分解它。这个技巧是先执行所有元素的乘法而不进行任何求和,然后手动执行求和步骤。
假设我们有以下两个 NumPy 张量:
import numpy as np
a = np.arange(8.).reshape(4, 2, 1)
b = np.arange(16.).reshape(2, 4, 2)
print("张量 a 的形状:", a.shape) # (4, 2, 1)
print("张量 b 的形状:", b.shape) # (2, 4, 2)步骤一:生成所有未求和的乘积
我们可以通过在输出索引中包含所有输入索引来阻止 einsum 进行求和。对于 ijk,jil->kl,如果我们将输出定义为 ijkl,则 einsum 将返回所有 a[i,j,k] * b[j,i,l] 的乘积,但不会进行任何求和。
# 生成所有元素的乘积,不进行求和
intermediate_products = np.einsum('ijk,jil->ijkl', a, b)
print("\n所有未求和的乘积 (形状: i, j, k, l):")
print(intermediate_products)
print("形状:", intermediate_products.shape) # (4, 2, 1, 2)在这个 intermediate_products 张量中,每个元素 [i, j, k, l] 都对应着 a[i, j, k] * b[j, i, l] 的乘积。例如,intermediate_products[0, 0, 0, 0] 对应 a[0, 0, 0] * b[0, 0, 0]。
步骤二:逐步执行求和
现在,我们知道 i 和 j 是需要被求和的维度。在 intermediate_products 张量中,i 对应轴 0,j 对应轴 1。我们可以逐个对这些轴进行求和。
首先,对 j 轴(轴 1)进行求和:
# 对 j 轴 (轴 1) 进行求和
sum_over_j = intermediate_products.sum(axis=1)
print("\n对 j 轴求和后的结果 (形状: i, k, l):")
print(sum_over_j)
print("形状:", sum_over_j.shape) # (4, 1, 2)接下来,对 i 轴(轴 0)进行求和:
# 对 i 轴 (轴 0) 进行求和
final_result = sum_over_j.sum(axis=0)
print("\n对 i 轴求和后的最终结果 (形状: k, l):")
print(final_result)
print("形状:", final_result.shape) # (1, 2)为了验证,我们可以直接运行原始的 einsum 操作:
original_einsum_result = np.einsum('ijk,jil->kl', a, b)
print("\n原始 einsum 结果 (形状: k, l):")
print(original_einsum_result)
print("形状:", original_einsum_result.shape) # (1, 2)
# 验证结果是否一致
print("\n逐步求和结果与原始 einsum 结果是否一致:", np.allclose(final_result, original_einsum_result))通过这种逐步分解的方式,我们清晰地看到了 einsum 如何先进行元素乘法,然后对指定维度进行求和,最终得到结果。
另一种理解 einsum 细节的方式是将其转换为等效的显式循环。这有助于我们从最基本的元素层面观察操作。
def sum_array_explicit_loop(A, B):
    # 获取张量 A 的形状 (i_len, j_len, k_len)
    i_len_a, j_len_a, k_len_a = A.shape
    # 获取张量 B 的形状,这里我们只关心与输出相关的维度 (j_len, i_len, l_len)
    # 实际上,B 的形状是 (j_len_b, i_len_b, l_len_b)
    # 为了匹配 einsum 的索引,B 的实际形状是 (j_len_from_B, i_len_from_B, l_len_from_B)
    # 我们需要确保 A 和 B 的匹配维度长度一致
    j_len_b, i_len_b, l_len_b = B.shape
    # 检查维度兼容性(einsum 会自动处理)
    if not (j_len_a == j_len_b and i_len_a == i_len_b):
        raise ValueError("张量维度不兼容")
    # 初始化结果张量,其形状为 (k_len, l_len)
    ret = np.zeros((k_len_a, l_len_b))
    # 遍历所有可能的 i, j, k, l 组合
    # i 和 j 是将被求和的维度
    # k 和 l 是输出张量的维度
    for i in range(i_len_a): # 遍历 A 的第一个维度 (i)
        for j in range(j_len_a): # 遍历 A 的第二个维度 (j)
            for k in range(k_len_a): # 遍历 A 的第三个维度 (k)
                for l in range(l_len_b): # 遍历 B 的第三个维度 (l)
                    # 执行元素乘法并累加到 ret[k, l]
                    # 注意 B 的索引是 j, i, l,与 einsum 字符串 'jil' 对应
                    ret[k, l] += A[i, j, k] * B[j, i, l]
    return ret
# 使用显式循环计算结果
explicit_loop_result = sum_array_explicit_loop(a, b)
print("\n显式循环计算结果:")
print(explicit_loop_result)
# 验证结果是否与原始 einsum 一致
print("显式循环结果与原始 einsum 结果是否一致:", np.allclose(explicit_loop_result, original_einsum_result))通过这个显式循环,我们可以清晰地看到:
通过本文的详细解析,相信读者对 np.einsum 在处理多张量求和时的内部工作机制有了更深入的理解。掌握 einsum 将使您能够更高效、更灵活地处理各种张量计算任务。
以上就是深入理解 NumPy einsum:多张量求和与索引机制详解的详细内容,更多请关注php中文网其它相关文章!
                        
                        每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
                Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号