einsum字符串需确保输入维度标签与输出标签严格匹配,字母顺序须与张量ndim一致,重复字母表求和或对角线,跨输入重复触发求和,空输出表示标量,省略号要求前缀维度对齐。

einsum 字符串怎么写才不报错
核心是让输入张量的维度标签和输出标签严格匹配,einsum 不会自动广播或对齐轴,写错一个字母就直接抛 ValueError: operands could not be broadcast together 或 IndexError。
- 每个输入张量用一串字母(如
"ij"、"jk")表示其形状,字母顺序必须和ndim一致 - 输出标签必须是所有输入中出现过的字母子集;没出现在输出里的字母,就代表要沿该轴求和
- 重复字母只在同一个输入中允许(如
"ii"表示对角线),跨输入重复则触发求和(如"ij,jk"中的j) - 空输出(如
"ii->")表示标量结果,不能漏掉箭头后的空字符串
替代 dot / matmul 的常见写法
多数矩阵乘、转置、迹运算都能用 einsum 更清晰地表达,且避免临时数组分配。
- 矩阵乘
A @ B→np.einsum("ij,jk->ik", A, B)(比dot更显式控制哪维参与计算) - 批量矩阵乘
B @ C,其中B.shape = (b, i, j),C.shape = (b, j, k)→np.einsum("bij,bjk->bik", B, C) - 提取对角线
np.diag(A)→np.einsum("ii->i", A);求迹 →np.einsum("ii->", A) - 外积
np.outer(u, v)→np.einsum("i,j->ij", u, v)
为什么有时比内置函数还慢
einsum 默认走通用路径,对简单操作(如二维矩阵乘)不如高度优化的 BLAS 后端快;是否加速取决于操作复杂度和数据规模。
- 小矩阵(
):通常matmul或dot更快,einsum有解析字符串开销 - 高维或带求和+广播混合的操作(如
"ab,cd,be->acde"):einsum可能显著胜出,因避免多个中间数组 - 加
optimize=True(如np.einsum("...,...->...", A, B, optimize=True))可启用路径优化,对三阶及以上张量尤其重要 - 注意:
optimize="greedy"或"optimal"会增加预处理时间,仅当反复调用同结构时值得开启
容易被忽略的内存与 dtype 陷阱
einsum 默认按输入中最高精度 dtype 输出,但不会自动提升整数精度;同时,它不共享内存,结果总是新分配数组。
- 两个
int32矩阵相乘,结果仍是int32,可能溢出;需显式转成float64或用dtype=np.float64参数指定 - 传入视图(如切片)时,输出一定是副本,无法通过
out=参数复用内存(einsum不支持out参数) - 含省略号
...时(如"...ij,...jk->...ik"),要确保前缀维度完全对齐,否则运行时报错而非静默截断
einsum 版本验证逻辑,再测性能;别为了“看起来高级”硬套,尤其是二维场景下 @ 还是最稳的。










