Skip to content

Commit 9e744c7

Browse files
committed
fix sbml jax tests
1 parent ddc68fa commit 9e744c7

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

python/sdist/amici/jax/jaxcodeprinter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import sympy as sp
88
from sympy.printing.numpy import NumPyPrinter
9+
from sympy.core.function import UndefinedFunction
910

1011

1112
def _jnp_array_str(array) -> str:
@@ -42,8 +43,11 @@ def _print_Mul(self, expr: sp.Expr) -> str:
4243
return super()._print_Mul(expr)
4344
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"
4445

45-
def _print_Function(self, expr):
46-
return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]"
46+
def _print_Function(self, expr: sp.Expr) -> str:
47+
if isinstance(expr.func, UndefinedFunction):
48+
return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]"
49+
else:
50+
return super()._print_Function(expr)
4751

4852
def _print_Max(self, expr: sp.Expr) -> str:
4953
"""

0 commit comments

Comments
 (0)