
在 sympy 中创建真正可用的自定义符号类需继承 `expr` 并重写 `add`/`mul`/`pow` 的核心方法(如 `flatten` 和 `__new__`),而非仅继承 `symbol`;否则幂运算等复合操作仍会退化为原生类型。
SymPy 的核心代数运算(如 +、*、**)并非完全通过操作符重载(如 __pow__)驱动,而是由顶层工厂类(如 Add、Mul、Pow)统一调度并执行表达式规范化(如合并同类项、指数相加)。因此,单纯让 CustomSymbol 重写 __pow__ 是无效的——当执行 a**2 * a**3 时,SymPy 内部会先将整个乘积识别为 Mul,再调用 Pow.flatten 或直接构造 Pow(a, 5),完全绕过 CustomSymbol.__pow__。
✅ 正确做法是构建一套协同工作的自定义表达式家族:
- 定义基类 CustomExpr:统一实现 __add__/__mul__/__pow__,强制返回自定义运算类;
- 子类化 Symbol + CustomExpr:获得符号语义与自定义行为双重能力;
-
重写 CustomAdd/CustomMul/CustomPow 的关键方法:
- flatten(cls, seq):复制 SymPy 源码,将所有 Add/Mul/Pow 替换为对应自定义类(注意处理系数排序、非交换乘法等细节);
- __new__(仅 CustomPow):缓存并确保返回 CustomPow 实例;
- _eval_subs(self, old, new):返回 None,防止默认替换逻辑将自定义类“降级”为原生类;
-
全局替换运算符别名(可选但推荐):
Add = CustomAdd Mul = CustomMul Pow = CustomPow
这能确保后续所有 SymPy 内部构造(如 expand()、simplify())也使用你的类——否则它们仍会生成原生类型。
⚠️ 注意事项:
- 不要试图只覆盖 __pow__ 或 __mul__:SymPy 的 Mul 构造器会跳过实例方法,直接调用 Mul.flatten;
- flatten 方法极其关键:它负责归并 a**2 * a**3 → a**5,若未重写,结果必为 Pow 而非 CustomPow;
- 所有自定义类必须继承 CustomExpr(而非仅 Expr),以保证运算符重载链完整;
- 若需 expand()、rewrite() 等高级功能,必须为每个自定义类实现对应的 _eval_expand_* 或 _eval_rewrite 方法,否则它们会回退到原生逻辑并破坏类型一致性。
下面是一个最小可行示例(已精简关键逻辑,生产环境请严格按 SymPy 源码补全 flatten):
from sympy import Expr, Symbol, Add, Mul, Pow, S
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort, _keep_coeff
class CustomExpr(Expr):
def __add__(self, other): return CustomAdd(self, other)
def __radd__(self, other): return CustomAdd(other, self)
def __mul__(self, other): return CustomMul(self, other)
def __rmul__(self, other): return CustomMul(other, self)
def __pow__(self, other): return CustomPow(self, other)
class CustomSymbol(CustomExpr, Symbol):
def __new__(cls, name, **kwargs):
return Symbol.__new__(cls, name, **kwargs)
class CustomAdd(CustomExpr, Add):
@classmethod
def flatten(cls, seq):
# 复制 sympy.core.add.Add.flatten 源码,将内部 Add→CustomAdd, Mul→CustomMul, Pow→CustomPow
from sympy.core.add import _addsort
terms = []
for x in seq:
if isinstance(x, cls):
terms.extend(x.args)
else:
terms.append(x)
# ...(省略归并常数、排序等完整逻辑)
coeff, nonnumber = S.Zero, []
for t in terms:
if t.is_Number:
coeff += t
else:
nonnumber.append(t)
if coeff is S.Zero and not nonnumber:
return [S.Zero], []
if coeff is S.Zero:
newseq = nonnumber
else:
newseq = [coeff] + nonnumber
_addsort(newseq) # 排序
return newseq, {}
def _eval_subs(self, old, new): return None
class CustomMul(CustomExpr, Mul):
@classmethod
def flatten(cls, seq):
# 同理:复制 Mul.flatten,替换为 CustomMul/CustomAdd/CustomPow
...
def _eval_subs(self, old, new): return None
class CustomPow(CustomExpr, Pow):
def __new__(cls, b, e, evaluate=None):
# 复制 Pow.__new__,确保返回 CustomPow
if evaluate is None:
evaluate = global_parameters.evaluate
if evaluate:
# ...(标准求值逻辑)
pass
return Expr.__new__(cls, b, e)
def _eval_subs(self, old, new): return None
# 全局启用(关键!)
Add = CustomAdd
Mul = CustomMul
Pow = CustomPow
# 验证
a = CustomSymbol('a')
x = a**2 * a**3
print(type(x)) #
print(x) # a**5 总结:SymPy 的设计决定了自定义符号必须是“生态级”扩展——不是单点重载,而是构建与 Symbol/Add/Mul/Pow 深度耦合的平行类体系。虽然工作量较大,但这是唯一能保证符号计算全程保持自定义语义(如附加元数据、特殊求值规则)的稳健方案。






