
numpy的`einsum`提供了一种简洁高效的张量运算方式,通过爱因斯坦求和约定实现元素乘法与求和。本文将深入解析`np.einsum('ijk,jil->kl', a, b)`这类复杂表达式的内部机制,通过中间索引输出和等效循环两种方法,详细阐述其如何基于共享和非共享索引完成张量元素的组合与累加,帮助读者透彻理解其工作原理,从而更有效地利用`einsum`处理复杂的张量操作。
np.einsum(Einstein Summation Convention,爱因斯坦求和约定)是NumPy中一个强大而灵活的函数,用于执行各种张量运算,包括点积、外积、转置、求和、矩阵乘法等。其核心在于通过字符串形式的索引标记来定义输入张量的维度关系和期望的输出张量结构。
一个典型的einsum表达式形如 '输入下标,输入下标->输出下标'。其中,重复出现的下标表示对该维度进行求和(隐式求和),未在输出下标中出现的输入下标也会被求和。本文将以一个具体的例子np.einsum('ijk,jil->kl', a, b)来深入剖析其内部的元素乘法与求和过程。
假设我们有两个张量 a 和 b:
import numpy as np
# 张量 a 的形状为 (4, 2, 1)
a = np.arange(8.).reshape(4, 2, 1)
print("张量 a:\n", a)
# 张量 b 的形状为 (2, 4, 2)
b = np.arange(16.).reshape(2, 4, 2)
print("\n张量 b:\n", b)对于表达式 np.einsum('ijk,jil->kl', a, b):
这意味着:
为了更好地理解 einsum 如何组合并求和元素,我们可以首先生成一个不进行任何求和的中间结果。通过在输出下标中包含所有输入下标,我们可以观察到每个元素乘法的具体结果。
表达式 np.einsum('ijk,jil->ijkl', a, b) 告诉 einsum 保持所有 i, j, k, l 维度,仅执行元素乘法而不进行任何求和。这样,输出张量 ijkl 的每个元素 output[i,j,k,l] 都将是 a[i,j,k] * b[j,i,l] 的结果。
# 步骤 1: 执行元素乘法,不进行任何求和
# 通过在输出下标中包含所有输入下标 (i, j, k, l),可以查看每个元素的乘积。
# 此时输出张量的形状为 (i_len, j_len, k_len, l_len) = (4, 2, 1, 2)
intermediate_product = np.einsum('ijk,jil->ijkl', a, b)
print("中间乘积 (ijkl):\n", intermediate_product)
# 步骤 2: 对 'j' 维度进行求和
# 原始表达式 'ijk,jil->kl' 中,j 是一个被求和的维度。
# 对应到 intermediate_product,j 是其第二个轴 (axis=1)。
sum_over_j = intermediate_product.sum(axis=1)
print("\n对 'j' 维度求和后的结果:\n", sum_over_j)
# 此时张量形状变为 (i_len, k_len, l_len) = (4, 1, 2)
# 步骤 3: 对 'i' 维度进行求和
# 原始表达式 'ijk,jil->kl' 中,i 也是一个被求和的维度。
# 对应到 sum_over_j,i 是其第一个轴 (axis=0)。
final_result_method1 = sum_over_j.sum(axis=0)
print("\n对 'i' 维度求和后的最终结果:\n", final_result_method1)
# 最终张量形状为 (k_len, l_len) = (1, 2),与 'kl' 匹配。通过这种分解方式,我们清晰地看到了 einsum 如何首先进行所有可能的元素乘法,然后按照爱因斯坦求和约定对未出现在输出下标中的维度进行累加求和。
理解 einsum 的另一种有效方法是将其转换为等效的显式循环。这有助于我们追踪每个元素是如何被计算和累加的。
对于 np.einsum('ijk,jil->kl', a, b),我们可以将其转换为以下嵌套循环:
def sum_array_explicit_loop(A, B):
# 获取张量 A 的维度长度
i_len, j_len, k_len = A.shape
# 获取张量 B 的最后一个维度长度(对应 l)
# 注意:B 的实际形状是 (j_max, i_max, l_max),因此 l_len 对应 B.shape[2]
l_len = B.shape[2]
# 初始化结果张量,其形状应为 (k_len, l_len)
ret = np.zeros((k_len, l_len))
# 遍历所有可能的 i, j, k, l 组合
# i 遍历 A 的第一个维度 (0-3)
# j 遍历 A 的第二个维度 (0-1)
# k 遍历 A 的第三个维度 (0-0)
# l 遍历 B 的第三个维度 (0-1)
for i in range(i_len):
for j in range(j_len):
for k in range(k_len):
for l in range(l_len):
# 根据 einsum 表达式 'ijk,jil->kl',执行元素乘法
# A 的索引为 (i, j, k)
# B 的索引为 (j, i, l)
# 将乘积累加到结果张量 ret[k, l] 中
ret[k, l] += A[i, j, k] * B[j, i, l]
return ret
final_result_method2 = sum_array_explicit_loop(a, b)
print("\n通过显式循环计算的最终结果:\n", final_result_method2)通过运行上述代码,我们可以看到它与 np.einsum('ijk,jil->kl', a, b) 直接计算的结果以及方法一分解求和后的结果是完全一致的。
# 验证与直接使用 einsum 的结果是否一致
einsum_direct_result = np.einsum('ijk,jil->kl', a, b)
print("\n直接使用 einsum 的结果:\n", einsum_direct_result)
# 比较两种方法的结果
print("\n两种方法结果是否一致 (方法1 vs 直接einsum):", np.allclose(final_result_method1, einsum_direct_result))
print("两种方法结果是否一致 (方法2 vs 直接einsum):", np.allclose(final_result_method2, einsum_direct_result))这个循环清晰地展示了 einsum 的内部逻辑:它遍历所有与输入张量形状兼容的索引组合,对每个组合执行元素乘法,并将结果累加到由输出下标定义的相应位置。
通过本文的详细解析,相信读者对 np.einsum 在处理复杂张量运算时的内部机制有了更深入的理解。掌握 einsum 不仅能提升代码的简洁性和效率,更能加深对张量代数的理解。
以上就是深入理解 NumPy einsum 的张量运算细节的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号