I have some code that wraps selected nodes in any given PyTree. I implemented it by traversing the tree: flatten one level, wrap the selected node, and unflatten again with the wrapped node in place. For the implementation I relied on the private https://github.com/adonath/clouseau/blob/main/clouseau/jax_utils.py#L59 In JAX 0.6 the So the concrete questions are:
I would be grateful for any hints or suggestions for alternative solutions!
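A minimal sketch of the traversal described above, using only the public jax.tree.flatten and jax.tree.unflatten instead of a private helper. It assumes the one-level flatten is done with an is_leaf predicate that excludes the root, the same trick as in the reply. The names Wrapped and should_wrap are hypothetical, chosen for illustration; they are not from the original code.

```python
import dataclasses
from typing import Any, Callable

import jax


@dataclasses.dataclass
class Wrapped:
    """Hypothetical wrapper; not a registered pytree, so JAX treats it as a leaf."""
    value: Any


def wrap_nodes(tree: Any, should_wrap: Callable[[Any], bool]) -> Any:
    """Recursively wrap every node for which should_wrap(node) is True.

    should_wrap is a hypothetical user-supplied predicate.
    """
    if should_wrap(tree):
        # Stop descending: the whole subtree goes inside the wrapper.
        return Wrapped(tree)

    # Flatten exactly one level: every node except the root counts as a leaf.
    children, treedef = jax.tree.flatten(tree, is_leaf=lambda node: node is not tree)

    # A true leaf flattens to itself, so there is nothing to descend into.
    if len(children) == 1 and children[0] is tree:
        return tree

    # Recurse into each child, then rebuild this level with the new children.
    new_children = [wrap_nodes(child, should_wrap) for child in children]
    return jax.tree.unflatten(treedef, new_children)


tree = {"a": 1.0, "b": [2.0, 3.0], "c": (4.0, {"d": 5.0})}
wrapped = wrap_nodes(tree, lambda node: isinstance(node, list))
print(wrapped)
```

Because Wrapped is not registered as a pytree node, any later jax.tree call treats each wrapped subtree as a single leaf; registering it (for example as a dataclass pytree) would change that behaviour.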
Reply:
You can flatten just one level by treating every node except the root x itself as a leaf:

jax.tree.flatten(x, is_leaf=lambda node: node is not x)
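A small illustration of what that call returns and how the result can be put back together with jax.tree.unflatten. The example tree and the ("wrapped", ...) tuple used as a stand-in wrapper are made up for this sketch.

```python
import jax

# Example pytree: a dict whose values are a leaf, a list and a nested tuple.
x = {"a": 1, "b": [2, 3], "c": (4, {"d": 5})}

# Every node except the root x counts as a leaf, so flattening stops
# after the root's immediate children (dict keys come out in sorted order).
children, treedef = jax.tree.flatten(x, is_leaf=lambda node: node is not x)

print(children)            # [1, [2, 3], (4, {'d': 5})]
print(treedef.num_leaves)  # 3

# Replace one child with a (hypothetical) wrapped version and rebuild the root.
children[1] = ("wrapped", children[1])
rebuilt = jax.tree.unflatten(treedef, children)
print(rebuilt)
```

The treedef only describes the root container, so unflatten accepts arbitrary objects, including a different subtree structure, in each child slot.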