Skip to content

Commit daedc21

Browse files
log prob should be a pytree node
1 parent 2756c6f commit daedc21

File tree

1 file changed

+0
-1
lines changed

1 file changed

+0
-1
lines changed

jetstream/engine/engine_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ class ResultTokens(abc.ABC):
9494
)
9595
# log probabilities of the tokens. Shape: [batch, tokens]
9696
log_prob: Union[jax.Array, np.ndarray] = struct.field(
97-
pytree_node=False,
9897
default=None,
9998
)
10099

0 commit comments

Comments
 (0)