Conversation

@Aatman09
Contributor

Resolves #57

Please check issues for any pending model implementations. Consider opening issue if none exists.

This PR adds dynamic positional embedding interpolation to the ViTClassificationModel (in bonsai/models/vit/modeling.py).

Reference
https://github.com/google-research/vision_transformer/blob/c6de1e5378c9831a8477feb30994971bdc409e46/vit_jax/checkpoint.py#L209

Colab Notebook
N/A

Checklist

  • [x] I have read contribution guidelines.
  • [x] I have added all the necessary unit tests for my change. (run_model.py for usage, test_outputs.py and model_validation_colab.ipynb (if applicable) for quality).
  • [x] I have verified that my change does not break existing code and all unit tests pass.
  • [x] I have added all appropriate doc-strings/documentation.
  • [x] My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • [x] I have signed the Contributor License Agreement.

@chapman20j
Collaborator

Hi Aatman09. Thanks for the quick turnaround on this commit! I'm leaving a few comments on the files.

gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(num_tokens))

logging.info('interpolate_posembed: grid-size from %s to %s', gs_old, gs_new)
Collaborator

Can you remove the logging? It is a Python side effect, so under jit it runs only while the function is traced and is dropped from the compiled computation.
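To illustrate the point about side effects, here is a minimal sketch, not code from this PR. The names `trace_calls` and `scale` are hypothetical. A list append stands in for the logging call, since both are ordinary Python side effects.

```python
import jax
import jax.numpy as jnp

trace_calls = []  # hypothetical recorder: grows each time the Python body runs


@jax.jit
def scale(x):
    # Python side effects (logging, print, list.append) execute only while
    # JAX traces the function, not when the cached compiled code is reused.
    trace_calls.append(x.shape)
    return x * 2.0


x = jnp.ones((4,))
scale(x)  # first call: traces and compiles, so the side effect runs
scale(x)  # second call: same shape and dtype, cached executable is reused
print(len(trace_calls))
```

If runtime output were ever needed inside a jitted function, `jax.debug.print` is one option; the review comment here only asks for the logging to be removed.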

posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = scipy.ndimage.zoom(posemb_grid, zoom, order=3)
Collaborator

Using scipy.ndimage.zoom is a nice solution, but it gives a jit tracer error. For example, the code

def test_full_interpolation_jit(self):
    image_shape_384 = (self.batch_size, 384, 384, 3)
    jx = jax.random.normal(jax.random.key(1), image_shape_384, dtype=jnp.float32)
    g, s = nnx.split(self.bonsai_model)
    s = jax.tree.leaves(s)
    jy = model_lib.forward(g, s, jx)

gives the error: "The numpy.ndarray conversion method __array__() was called on traced array with shape float32[14,14,768]."
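The same class of error can be reproduced without the model. The sketch below is an assumption-labeled stand-in: it calls `np.asarray` on a traced array instead of `scipy.ndimage.zoom`, on the assumption that both need concrete NumPy values. The function name `to_numpy_inside_jit` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def to_numpy_inside_jit(grid):
    # Under jit, `grid` is a tracer with a shape and dtype but no values,
    # so converting it to a NumPy array cannot succeed.
    return np.asarray(grid)


grid = jnp.zeros((14, 14, 768), dtype=jnp.float32)
try:
    to_numpy_inside_jit(grid)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True
print(raised)
```

Outside of jit the same conversion works, which is consistent with the interpolation passing correctness checks while failing in the jitted forward pass.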

Collaborator

Could you update this to be compatible with jit for efficiency? In terms of correctness, the PR looks good.

Member

+1, using scipy.ndimage.zoom can be a performance bottleneck. Instead, we could use jax.image.resize with method="bicubic", which gives high-quality interpolation in native JAX.
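A minimal sketch of that suggestion, assuming a 14x14 grid of 768-dimensional embeddings resized to 24x24. The helper name `resize_grid` and the sizes are illustrative, not taken from the merged code.

```python
import jax
import jax.numpy as jnp


def resize_grid(posemb_grid, gs_new):
    # posemb_grid: (gs_old, gs_old, dim); only the two spatial axes change.
    dim = posemb_grid.shape[-1]
    return jax.image.resize(posemb_grid, (gs_new, gs_new, dim), method="bicubic")


# gs_new determines the output shape, so it is marked static for jit.
resize_grid_jit = jax.jit(resize_grid, static_argnums=1)

grid = jax.random.normal(jax.random.key(0), (14, 14, 768), dtype=jnp.float32)
resized = resize_grid_jit(grid, 24)
print(resized.shape)
```

The bicubic kernel in `jax.image.resize` is not guaranteed to match `scipy.ndimage.zoom(..., order=3)` value for value, so output comparisons against the old path may need a tolerance.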

import scipy.ndimage
from absl import logging

def interpolate_posembed(posemb, num_tokens: int, has_class_token: bool):
Member

@jenriver Oct 24, 2025

Nit: could you also add type hints for posemb and the return value?

  else:
    posemb_tok, posemb_grid = posemb[:, :0], posemb[0, 0:]

gs_old = int(np.sqrt(len(posemb_grid)))
Member

Could we use jnp equivalents (e.g., jnp.sqrt and jnp.concatenate) to be more JAX-native and jit-compatible?
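Putting the review comments together, one possible shape for the function is sketched below. This is not the merged implementation. It assumes posemb has shape (1, N, dim), that num_tokens counts the class token when one is present (my reading of the linked reference), and that both grids are square. One deviation from the suggestion above is this sketch's own choice: grid sizes are computed with Python's `math.isqrt` on static shapes, because values produced by jnp operations inside a jitted function are tracers and cannot be converted with `int()`.

```python
import math

import jax
import jax.numpy as jnp


def interpolate_posembed(
    posemb: jax.Array, num_tokens: int, has_class_token: bool
) -> jax.Array:
    """Resize a (1, N, dim) positional embedding to num_tokens tokens."""
    if has_class_token:
        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
        num_tokens -= 1  # assumption: num_tokens includes the class token
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0, 0:]

    # Shapes are static under jit, so Python integer math stays concrete.
    gs_old = math.isqrt(posemb_grid.shape[0])
    gs_new = math.isqrt(num_tokens)
    dim = posemb_grid.shape[-1]

    grid = posemb_grid.reshape(gs_old, gs_old, dim)
    grid = jax.image.resize(grid, (gs_new, gs_new, dim), method="bicubic")
    grid = grid.reshape(1, gs_new * gs_new, dim)
    return jnp.concatenate([posemb_tok, grid], axis=1)


# num_tokens and has_class_token control shapes and branching, so both are static.
interpolate_jit = jax.jit(interpolate_posembed, static_argnums=(1, 2))

posemb = jnp.zeros((1, 1 + 14 * 14, 768), dtype=jnp.float32)
out = interpolate_jit(posemb, 1 + 24 * 24, True)
print(out.shape)
```

Here a 224-pixel checkpoint layout (14x14 patches plus a class token) is stretched to the 24x24 grid that a 384-pixel input with 16-pixel patches would produce; those sizes are illustrative.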

@Aatman09
Contributor Author

Hello,

I wanted to confirm the changes I need to make:

  1. Remove all instances of logging. Should I replace them with something else, or simply remove them?
  2. Replace scipy.ndimage.zoom with jax.image.resize.
  3. Avoid using NumPy and switch to JAX equivalents.
  4. Ensure the implementation is JIT-compatible.

Please let me know if I’m missing anything.
Also, I wanted to point out that the Discord link is not opening.

* general maintenance

* Platform updates

Added interpolation change

Added interpolation change

fix(vit): Make positional embedding interpolation JIT-compatible
@Aatman09
Contributor Author

Hello, I've implemented the requested changes. This is ready for re-review when you have a moment. Thank you.

import jax.numpy as jnp
from flax import nnx

import scipy.ndimage
Member

Could you remove this since it's no longer used?

@jenriver
Member

Hello,

I wanted to confirm the changes I need to make:

  1. Remove all instances of logging. Should I replace them with something else, or simply remove them?
  2. Replace scipy.ndimage.zoom with jax.image.resize.
  3. Avoid using NumPy and switch to JAX equivalents.
  4. Ensure the implementation is JIT-compatible.

Please let me know if I’m missing anything. Also, I wanted to point out that the Discord link is not opening.

Thanks! We've also updated the Discord link in the guide.

Member

@jenriver left a comment

Thanks for the quick turnaround! :)

@jenriver merged commit d99af86 into jax-ml:main Oct 28, 2025
1 check passed
3 participants