Skip to content

Commit bae4112

Browse files
dougalmlearned_optimization authors
authored andcommitted
Stackless yashful
PiperOrigin-RevId: 681582933
1 parent 36fd2e1 commit bae4112

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

learned_optimization/jax_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@ def body_fn(_, operand):
4646

4747
def in_jit() -> bool:
4848
"""Returns true if tracing jit."""
49-
return "DynamicJaxprTrace" in str(
50-
jax.core.thread_local_state.trace_state.trace_stack
51-
)
49+
if jax.__version_info__ <= (0, 4, 33):
50+
return "DynamicJaxprTrace" in str(
51+
jax.core.thread_local_state.trace_state.trace_stack # type: ignore
52+
)
53+
54+
return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()
5255

5356

5457
Carry = TypeVar("Carry")

0 commit comments

Comments
 (0)