diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index d3e5ac11f7..2d64edf3fd 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -29,6 +29,7 @@ Erfinv, GammaIncCInv, GammaIncInv, + Hyp2F1, Iv, Ive, Kve, @@ -341,3 +342,11 @@ def softplus(x): ) return softplus + + +@jax_funcify.register(Hyp2F1) +def jax_funcify_Hyp2F1(op, **kwargs): + def hyp2f1(a, b, c, x): + return jax.scipy.special.hyp2f1(a, b, c, x) + + return hyp2f1 diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 463405fff4..038c22e423 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -19,6 +19,7 @@ erfinv, gammainccinv, gammaincinv, + hyp2f1, iv, kve, log, @@ -324,3 +325,22 @@ def test_jax_logp(): value_test_value, ], ) + + +def test_jax_hyp2f1(): + a = vector("a") + b = vector("b") + c = vector("c") + x = vector("x") + out = hyp2f1(a, b, c, x) + + compare_jax_and_py( + [a, b, c, x], + [out], + [ + np.array([0.0, 0.0]), + np.array([0.0, 0.0]), + np.array([0.0, 0.0]), + np.array([0.0, 0.0]), + ], + )