File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change 66
77import sympy as sp
88from sympy .printing .numpy import NumPyPrinter
9+ from sympy .core .function import UndefinedFunction
910
1011
1112def _jnp_array_str (array ) -> str :
@@ -42,8 +43,11 @@ def _print_Mul(self, expr: sp.Expr) -> str:
4243 return super ()._print_Mul (expr )
4344 return f"safe_div({ self .doprint (numer )} , { self .doprint (denom )} )"
4445
45- def _print_Function (self , expr ):
46- return f"self.nns['{ expr .func .__name__ } '].forward(jnp.array([{ ', ' .join (self .doprint (a ) for a in expr .args [:- 1 ])} ]))[{ expr .args [- 1 ]} ]"
46+ def _print_Function (self , expr : sp .Expr ) -> str :
47+ if isinstance (expr .func , UndefinedFunction ):
48+ return f"self.nns['{ expr .func .__name__ } '].forward(jnp.array([{ ', ' .join (self .doprint (a ) for a in expr .args [:- 1 ])} ]))[{ expr .args [- 1 ]} ]"
49+ else :
50+ return super ()._print_Function (expr )
4751
4852 def _print_Max (self , expr : sp .Expr ) -> str :
4953 """
You can’t perform that action at this time.
0 commit comments