diff --git a/octo/model/components/block_transformer.py b/octo/model/components/block_transformer.py index c4107e35..851195b6 100644 --- a/octo/model/components/block_transformer.py +++ b/octo/model/components/block_transformer.py @@ -290,7 +290,7 @@ def generate_attention_mask( self.verify_causality(prefix_groups, timestep_groups) def _get_position(i, tokens_per_elem): - return np.searchsorted(np.cumsum(tokens_per_elem), i) + return np.searchsorted(np.cumsum(tokens_per_elem), i, side='right') horizon = timestep_groups[0].tokens.shape[1] tokens_per_prefix_group = [group.tokens.shape[1] for group in prefix_groups]