0

0

如何在 JAX 中正确计算批量矩阵指数(expm)

心靈之曲

心靈之曲

发布时间:2026-01-15 23:57:28

|

829人浏览过

|

来源于php中文网

原创

如何在 JAX 中正确计算批量矩阵指数(expm)

本文详解 jax 中 `jax.scipy.linalg.expm` 批量计算失败的常见原因与解决方案,涵盖新版原生支持、旧版兼容写法及关键形状调试技巧。

在使用 JAX 计算矩阵指数(如量子线路中的参数化幺正演化 $ e^{iA} $)时,一个典型错误是:

ValueError: expected A to be a square matrix

尽管你确认最后两维是方阵(如 (4, 4)),但报错仍发生——这往往源于 输入张量的维度结构不符合 expm 的隐式批处理规则

? 根本原因:expm 对输入形状有严格要求

jax.scipy.linalg.expm 自 JAX v0.4.7 起原生支持批量输入,但前提是:
✅ 输入数组的最后两个轴必须构成方阵(如 (..., n, n));
❌ 其余前导维度将被自动视为 batch 维度;
❌ 若中间存在非 batch 的冗余维度(如你的 A.shape = (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)),它仍能工作;
⚠️ 但若 A 的最后两维不满足 n == n(例如 (4, 5)),或 A.ndim

在你的代码中,问题出在 pauli_matrix(num_qubits) 的构造逻辑:

def pauli_matrix(num_qubits):
    _pauli_matrices = jnp.array(
        [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
    )
    # ❌ 错误:对 _pauli_matrices 重复 kronecker 积,却未指定作用于哪一组 qubit
    # 且 [1:] 切片导致维度混乱,最终使 tensordot 结果 A 的 shape 不符合预期
    return reduce(jnp.kron, (_pauli_matrices for _ in range(num_qubits)))[1:]

该函数实际生成的是 (15, 4**num_qubits, 4**num_qubits) 形状的 Pauli 基(对 2-qubit 应为 (15, 4, 4)),但 reduce(jnp.kron, ...) 在 num_qubits=2 时会生成 (4^2, 4^2) = (16, 16) 矩阵,再 [1:] 切片得 (15, 16, 16) —— 而你 theta 是 (15, 2,2,2,2,2,2,2,2),tensordot 后 A 实际为 (2,2,2,2,2,2,2,2, 16, 16),并非你误以为的 (2,...,2,4,4)。因此 expm 接收的不是 (N, 4, 4),而是高维张量,但只要末两维是方阵,新版 JAX 就能处理。

✅ 正确做法:确保 A 的 shape 为 (..., d, d),其中 d = 2**num_qubits。

Copy Leaks
Copy Leaks

AI内容检测和分级,帮助创建和保护原创内容

下载

✅ 解决方案一:升级 JAX 并规范输入(推荐)

确保使用 JAX ≥ 0.4.7:

pip install --upgrade jax jaxlib

然后修正 pauli_matrix 和 SpecialUnitary:

import jax.numpy as jnp
import jax.scipy.linalg as linalg
from functools import reduce

def pauli_basis_1q():
    return jnp.array([
        [[1., 0.], [0., 1.]],   # I
        [[0., 1.], [1., 0.]],   # X
        [[0., -1j], [1j, 0.]],  # Y
        [[1., 0.], [0., -1.]],  # Z
    ])

def pauli_matrix(num_qubits):
    """返回 (4**num_qubits - 1) 个 traceless n-qubit Pauli 算符,shape (15, 4, 4) for n=2"""
    basis = pauli_basis_1q()
    # 构造所有非恒等的 n-qubit Pauli 张量积:共 4^n - 1 个
    from itertools import product
    ops = []
    for indices in product(range(4), repeat=num_qubits):
        if all(i == 0 for i in indices):  # skip identity
            continue
        op = basis[indices[0]]
        for i in indices[1:]:
            op = jnp.kron(op, basis[i])
        ops.append(op)
    return jnp.stack(ops)  # shape: (15, 4, 4) for num_qubits=2

num_qubits = 2
d = 2 ** num_qubits  # 4
theta = jnp.pi * jnp.random.uniform(shape=(15,))  # 简化:单组参数,shape (15,)

A = jnp.tensordot(theta, pauli_matrix(num_qubits), axes=[[0], [0]])  # -> (4, 4)
U = linalg.expm(1j * A / 2)  # ✅ works: (4, 4)

# 批量示例:theta shape (8, 15) → A shape (8, 4, 4) → U shape (8, 4, 4)
theta_batch = jnp.pi * jnp.random.uniform(shape=(8, 15))
A_batch = jnp.einsum('bi,ij->bjk', theta_batch, pauli_matrix(num_qubits))  # (8, 4, 4)
U_batch = linalg.expm(1j * A_batch / 2)  # ✅ native batch support
print(U_batch.shape)  # (8, 4, 4)

⚙️ 解决方案二:旧版 JAX 兼容写法(jnp.vectorize)

若受限于旧版 JAX(

expm_vec = jnp.vectorize(linalg.expm, signature='(n,n)->(n,n)')

# A_batch shape: (B, d, d)
U_batch = expm_vec(1j * A_batch / 2)  # returns (B, d, d)

⚠️ 注意:vectorize 在 JIT 下可能不如原生批量高效,仅作兼容之用。

? 关键检查清单

  • ✅ 使用 A.shape[-2] == A.shape[-1] 验证末两维是否为方阵;
  • ✅ 避免在 tensordot 或 einsum 中引入意外维度(如你的原始 theta 有 9 维,极易出错);
  • ✅ 优先用 einsum 替代嵌套 tensordot 提升可读性;
  • ✅ 调试时打印 A.shape 和 A.dtype,确认无 float64(JAX 默认 float32,expm 要求浮点)。

掌握这些要点,你就能稳健地在 JAX 中实现量子态演化、李群指数映射等核心计算。

相关专题

更多
go语言 数组和切片
go语言 数组和切片

本专题整合了go语言数组和切片的区别与含义,阅读专题下面的文章了解更多详细内容。

46

2025.09.03

Golang gRPC 服务开发与Protobuf实战
Golang gRPC 服务开发与Protobuf实战

本专题系统讲解 Golang 在 gRPC 服务开发中的完整实践,涵盖 Protobuf 定义与代码生成、gRPC 服务端与客户端实现、流式 RPC(Unary/Server/Client/Bidirectional)、错误处理、拦截器、中间件以及与 HTTP/REST 的对接方案。通过实际案例,帮助学习者掌握 使用 Go 构建高性能、强类型、可扩展的 RPC 服务体系,适用于微服务与内部系统通信场景。

8

2026.01.15

公务员递补名单公布时间 公务员递补要求
公务员递补名单公布时间 公务员递补要求

公务员递补名单公布时间不固定,通常在面试前,由招录单位(如国家知识产权局、海关等)发布,依据是原入围考生放弃资格,会按笔试成绩从高到低递补,递补考生需按公告要求限时确认并提交材料,及时参加面试/体检等后续环节。要求核心是按招录单位公告及时响应、提交材料(确认书、资格复审材料)并准时参加面试。

44

2026.01.15

公务员调剂条件 2026调剂公告时间
公务员调剂条件 2026调剂公告时间

(一)符合拟调剂职位所要求的资格条件。 (二)公共科目笔试成绩同时达到拟调剂职位和原报考职位的合格分数线,且考试类别相同。 拟调剂职位设置了专业科目笔试条件的,专业科目笔试成绩还须同时达到合格分数线,且考试类别相同。 (三)未进入原报考职位面试人员名单。

55

2026.01.15

国考成绩查询入口 国考分数公布时间2026
国考成绩查询入口 国考分数公布时间2026

笔试成绩查询入口已开通,考生可登录国家公务员局中央机关及其直属机构2026年度考试录用公务员专题网站http://bm.scs.gov.cn/pp/gkweb/core/web/ui/business/examResult/written_result.html,查询笔试成绩和合格分数线,点击“笔试成绩查询”按钮,凭借身份证及准考证进行查询。

11

2026.01.15

Java 桌面应用开发(JavaFX 实战)
Java 桌面应用开发(JavaFX 实战)

本专题系统讲解 Java 在桌面应用开发领域的实战应用,重点围绕 JavaFX 框架,涵盖界面布局、控件使用、事件处理、FXML、样式美化(CSS)、多线程与UI响应优化,以及桌面应用的打包与发布。通过完整示例项目,帮助学习者掌握 使用 Java 构建现代化、跨平台桌面应用程序的核心能力。

65

2026.01.14

php与html混编教程大全
php与html混编教程大全

本专题整合了php和html混编相关教程,阅读专题下面的文章了解更多详细内容。

36

2026.01.13

PHP 高性能
PHP 高性能

本专题整合了PHP高性能相关教程大全,阅读专题下面的文章了解更多详细内容。

75

2026.01.13

MySQL数据库报错常见问题及解决方法大全
MySQL数据库报错常见问题及解决方法大全

本专题整合了MySQL数据库报错常见问题及解决方法,阅读专题下面的文章了解更多详细内容。

21

2026.01.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Java 教程
Java 教程

共578课时 | 46.2万人学习

国外Web开发全栈课程全集
国外Web开发全栈课程全集

共12课时 | 1.0万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号