We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 36fd2e1 commit bae4112Copy full SHA for bae4112
learned_optimization/jax_utils.py
@@ -46,9 +46,12 @@ def body_fn(_, operand):
46
47
def in_jit() -> bool:
48
"""Returns true if tracing jit."""
49
- return "DynamicJaxprTrace" in str(
50
- jax.core.thread_local_state.trace_state.trace_stack
51
- )
+ if jax.__version_info__ <= (0, 4, 33):
+ return "DynamicJaxprTrace" in str(
+ jax.core.thread_local_state.trace_state.trace_stack # type: ignore
52
+ )
53
+
54
+ return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()
55
56
57
Carry = TypeVar("Carry")
0 commit comments