Auto-convert from float to jax.numpy array #197
patel-zeel asked this question in Q&A
I have recently started working with `jaxopt`, and I found that `jaxopt` does not auto-convert a float to a `jax.numpy` array. For example, if I run the following code (Example 1):
It fails because `params` is just a float. However, if I replace the last two lines with the following:
The above code works fine. On the other hand, `optax` does not fail in this case (maybe because `jax.grad` auto-converts a float to a `jax.numpy` array). Consider the following equivalent example (Example 2):
Output
So I was wondering: should `jaxopt` auto-convert floats to `jax.numpy` arrays, or would that create new issues?
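Example 2 and its output are not reproduced here. As a rough check of the `jax.grad` guess (a sketch, not the original example), the loop below runs plain gradient descent on a bare Python float using only `jax.grad`. The quadratic `loss` and the step size `0.1` are illustrative assumptions. After the first update the parameter is already a `jax.Array`, because `jax.grad` accepts the float and returns an array.

```python
import jax

def loss(w):
    # Illustrative quadratic loss with its minimum at w = 3.0
    return (w - 3.0) ** 2

w = 1.0  # a bare Python float, no jnp.array wrapper
print(type(w).__name__)  # float

grad_fn = jax.grad(loss)
for _ in range(3):
    # jax.grad accepts the Python float and returns a jax array,
    # so float - 0.1 * array makes w an array after the first step.
    w = w - 0.1 * grad_fn(w)

print(isinstance(w, jax.Array))  # True
print(float(w))  # approximately 1.976
```

This only shows that `jax.grad` tolerates float inputs; it does not show which line inside `jaxopt` rejects them.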
Edit: The error trace shows the following lines:
Would it suffice to simply replace a few lines like this? (I tried it locally, and it works for Example 1.)
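The error-trace lines and the proposed replacement are not reproduced here. One generic way to do such a conversion, assuming the fix amounts to calling `jnp.asarray` on every leaf of the initial parameters before the solver uses them, is a `jax.tree_util.tree_map` over the parameter pytree. The helper name `to_jnp_pytree` is hypothetical and not part of `jaxopt`.

```python
import jax
import jax.numpy as jnp

def to_jnp_pytree(params):
    """Hypothetical helper: convert every leaf of a pytree to a jax array.

    Python floats and ints become 0-d arrays; the pytree structure
    (dicts, tuples, lists) is preserved.
    """
    return jax.tree_util.tree_map(jnp.asarray, params)

raw = {"w": 1.0, "b": (0.5, 2.0)}
converted = to_jnp_pytree(raw)

leaves = jax.tree_util.tree_leaves(converted)
print(all(isinstance(x, jax.Array) for x in leaves))  # True
print([x.shape for x in leaves])  # [(), (), ()]
# Same nesting as the input; only the leaf types changed.
print(jax.tree_util.tree_structure(converted) == jax.tree_util.tree_structure(raw))  # True

# A single float (as in Example 1) is itself a one-leaf pytree.
print(to_jnp_pytree(1.0).dtype)  # float32 unless 64-bit mode is enabled
```

One possible new issue with converting automatically: `jnp.asarray(1.0)` yields a `float32` array by default (unless 64-bit mode is enabled via `jax_enable_x64`), so a Python float, which is double precision, could silently lose precision.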