深入理解 NumPy einsum 的张量运算细节

花韻仙語
发布: 2025-10-23 09:27:10
原创
889人浏览过

深入理解 NumPy einsum 的张量运算细节

numpy的`einsum`提供了一种简洁高效的张量运算方式,通过爱因斯坦求和约定实现元素乘法与求和。本文将深入解析`np.einsum('ijk,jil->kl', a, b)`这类复杂表达式的内部机制,通过中间索引输出和等效循环两种方法,详细阐述其如何基于共享和非共享索引完成张量元素的组合与累加,帮助读者透彻理解其工作原理,从而更有效地利用`einsum`处理复杂的张量操作。

1. np.einsum 简介与核心机制

np.einsum(Einstein Summation Convention,爱因斯坦求和约定)是NumPy中一个强大而灵活的函数,用于执行各种张量运算,包括点积、外积、转置、求和、矩阵乘法等。其核心在于通过字符串形式的索引标记来定义输入张量的维度关系和期望的输出张量结构。

一个典型的einsum表达式形如 '输入下标,输入下标->输出下标'。其中,重复出现的下标表示对该维度进行求和(隐式求和),未在输出下标中出现的输入下标也会被求和。本文将以一个具体的例子np.einsum('ijk,jil->kl', a, b)来深入剖析其内部的元素乘法与求和过程。

假设我们有两个张量 a 和 b:

import numpy as np

# 张量 a 的形状为 (4, 2, 1)
a = np.arange(8.).reshape(4, 2, 1)
print("张量 a:\n", a)
# 张量 b 的形状为 (2, 4, 2)
b = np.arange(16.).reshape(2, 4, 2)
print("\n张量 b:\n", b)
登录后复制

对于表达式 np.einsum('ijk,jil->kl', a, b):

  • 张量 a 的维度由 i, j, k 表示。
  • 张量 b 的维度由 j, i, l 表示。
  • 输出张量的维度由 k, l 表示。

这意味着:

  1. a 的第一个维度 i 与 b 的第二个维度 i 相乘。
  2. a 的第二个维度 j 与 b 的第一个维度 j 相乘。
  3. k 和 l 是输出张量的维度。
  4. 由于 i 和 j 在输出下标 kl 中没有出现,因此将对 i 和 j 维度进行求和。

2. 方法一:通过中间输出分解求和过程

为了更好地理解 einsum 如何组合并求和元素,我们可以首先生成一个不进行任何求和的中间结果。通过在输出下标中包含所有输入下标,我们可以观察到每个元素乘法的具体结果。

表达式 np.einsum('ijk,jil->ijkl', a, b) 告诉 einsum 保持所有 i, j, k, l 维度,仅执行元素乘法而不进行任何求和。这样,输出张量 ijkl 的每个元素 output[i,j,k,l] 都将是 a[i,j,k] * b[j,i,l] 的结果。

商汤商量
商汤商量

商汤科技研发的AI对话工具,商量商量,都能解决。

商汤商量36
查看详情 商汤商量
# 步骤 1: 执行元素乘法,不进行任何求和
# 通过在输出下标中包含所有输入下标 (i, j, k, l),可以查看每个元素的乘积。
# 此时输出张量的形状为 (i_len, j_len, k_len, l_len) = (4, 2, 1, 2)
intermediate_product = np.einsum('ijk,jil->ijkl', a, b)
print("中间乘积 (ijkl):\n", intermediate_product)

# 步骤 2: 对 'j' 维度进行求和
# 原始表达式 'ijk,jil->kl' 中,j 是一个被求和的维度。
# 对应到 intermediate_product,j 是其第二个轴 (axis=1)。
sum_over_j = intermediate_product.sum(axis=1)
print("\n对 'j' 维度求和后的结果:\n", sum_over_j)
# 此时张量形状变为 (i_len, k_len, l_len) = (4, 1, 2)

# 步骤 3: 对 'i' 维度进行求和
# 原始表达式 'ijk,jil->kl' 中,i 也是一个被求和的维度。
# 对应到 sum_over_j,i 是其第一个轴 (axis=0)。
final_result_method1 = sum_over_j.sum(axis=0)
print("\n对 'i' 维度求和后的最终结果:\n", final_result_method1)
# 最终张量形状为 (k_len, l_len) = (1, 2),与 'kl' 匹配。
登录后复制

通过这种分解方式,我们清晰地看到了 einsum 如何首先进行所有可能的元素乘法,然后按照爱因斯坦求和约定对未出现在输出下标中的维度进行累加求和。

3. 方法二:等效循环实现

理解 einsum 的另一种有效方法是将其转换为等效的显式循环。这有助于我们追踪每个元素是如何被计算和累加的。

对于 np.einsum('ijk,jil->kl', a, b),我们可以将其转换为以下嵌套循环:

def sum_array_explicit_loop(A, B):
    # 获取张量 A 的维度长度
    i_len, j_len, k_len = A.shape
    # 获取张量 B 的最后一个维度长度(对应 l)
    # 注意:B 的实际形状是 (j_max, i_max, l_max),因此 l_len 对应 B.shape[2]
    l_len = B.shape[2]

    # 初始化结果张量,其形状应为 (k_len, l_len)
    ret = np.zeros((k_len, l_len))

    # 遍历所有可能的 i, j, k, l 组合
    # i 遍历 A 的第一个维度 (0-3)
    # j 遍历 A 的第二个维度 (0-1)
    # k 遍历 A 的第三个维度 (0-0)
    # l 遍历 B 的第三个维度 (0-1)
    for i in range(i_len):
        for j in range(j_len):
            for k in range(k_len):
                for l in range(l_len):
                    # 根据 einsum 表达式 'ijk,jil->kl',执行元素乘法
                    # A 的索引为 (i, j, k)
                    # B 的索引为 (j, i, l)
                    # 将乘积累加到结果张量 ret[k, l] 中
                    ret[k, l] += A[i, j, k] * B[j, i, l]
    return ret

final_result_method2 = sum_array_explicit_loop(a, b)
print("\n通过显式循环计算的最终结果:\n", final_result_method2)
登录后复制

通过运行上述代码,我们可以看到它与 np.einsum('ijk,jil->kl', a, b) 直接计算的结果以及方法一分解求和后的结果是完全一致的。

# 验证与直接使用 einsum 的结果是否一致
einsum_direct_result = np.einsum('ijk,jil->kl', a, b)
print("\n直接使用 einsum 的结果:\n", einsum_direct_result)

# 比较两种方法的结果
print("\n两种方法结果是否一致 (方法1 vs 直接einsum):", np.allclose(final_result_method1, einsum_direct_result))
print("两种方法结果是否一致 (方法2 vs 直接einsum):", np.allclose(final_result_method2, einsum_direct_result))
登录后复制

这个循环清晰地展示了 einsum 的内部逻辑:它遍历所有与输入张量形状兼容的索引组合,对每个组合执行元素乘法,并将结果累加到由输出下标定义的相应位置。

4. 注意事项与总结

  • 效率:尽管显式循环有助于理解,但在实际应用中,np.einsum 通常比手动编写的Python循环效率高得多,因为它在底层利用了优化的C或Fortran实现。
  • 索引的对应关系:理解输入张量和输出张量中索引的对应关系是使用 einsum 的关键。共享的输入索引表示这些维度必须匹配,并且如果它们未出现在输出中,则会进行求和。
  • 灵活性:einsum 可以表达非常广泛的张量操作,从简单的转置('ij->ji')到复杂的张量积。熟练掌握其索引规则能够极大地简化张量运算的代码。
  • 调试:当遇到复杂的 einsum 表达式时,使用本文介绍的“中间输出”方法(即在输出下标中包含所有输入下标)是一个非常有用的调试技巧,可以帮助你一步步追踪计算过程。

通过本文的详细解析,相信读者对 np.einsum 在处理复杂张量运算时的内部机制有了更深入的理解。掌握 einsum 不仅能提升代码的简洁性和效率,更能加深对张量代数的理解。

以上就是深入理解 NumPy einsum 的张量运算细节的详细内容,更多请关注php中文网其它相关文章!

相关标签:
最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号