Added the interpolation code #59
Hi Aatman09. Thanks for the quick turnaround on this commit! I'm leaving a few comments on the files.
bonsai/models/vit/modeling.py
Outdated
    gs_old = int(np.sqrt(len(posemb_grid)))
    gs_new = int(np.sqrt(num_tokens))

    logging.info('interpolate_posembed: grid-size from %s to %s', gs_old, gs_new)
Can you remove the logging call? It is a stateful side effect, so it only runs while the function is traced and is dropped from the compiled function under jit.
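For context, here is a minimal sketch of why Python side effects such as logging do not survive jit. The function `scale` and the list `trace_log` are hypothetical names used only for illustration; they are not part of the PR.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records every time the Python body of `scale` actually runs


@jax.jit
def scale(x):
    # Python side effect: executes while JAX traces the function,
    # not when the cached compiled function is called again.
    trace_log.append("traced")
    return x * 2.0


x = jnp.ones((4,))
scale(x)  # first call: traces and compiles, so the side effect runs
scale(x)  # same shape and dtype: compiled version is reused, body is not re-run
print(len(trace_log))
```

If runtime output is needed inside a jitted function, `jax.debug.print` is the usual alternative to plain logging.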
bonsai/models/vit/modeling.py
Outdated
    posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

    zoom = (gs_new / gs_old, gs_new / gs_old, 1)
    posemb_grid = scipy.ndimage.zoom(posemb_grid, zoom, order=3)
Using scipy.ndimage.zoom is a nice solution, but it gives a jit tracer error. For example, the code

    def test_full_interpolation_jit(self):
      image_shape_384 = (self.batch_size, 384, 384, 3)
      jx = jax.random.normal(jax.random.key(1), image_shape_384, dtype=jnp.float32)
      g, s = nnx.split(self.bonsai_model)
      s = jax.tree.leaves(s)
      jy = model_lib.forward(g, s, jx)

gives the error: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[14,14,768].
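The same class of error can be reproduced in isolation. The sketch below is not the PR's code: `to_numpy` is a hypothetical stand-in for any function that hands a traced array to NumPy or SciPy, which is what scipy.ndimage.zoom does internally when it converts its input.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def to_numpy(x):
    # np.asarray calls x.__array__(), which a traced array cannot satisfy
    # because it has no concrete values during tracing.
    return np.asarray(x)


try:
    to_numpy(jnp.ones((14, 14, 8)))
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

print(raised)
```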
Could you update this to be compatible with jit for efficiency? In terms of correctness, the PR looks good.
+1, scipy.ndimage.zoom can be a performance bottleneck. Instead, we could use jax.image.resize with the bicubic method, which gives high-quality interpolation in native JAX.
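A rough sketch of what a jax.image.resize version could look like, under some assumptions: the embedding has shape (1, tokens, dim), an optional class token sits at index 0, and `num_tokens` counts only grid tokens. The name `resize_posembed` is hypothetical and the final code in the PR may differ.

```python
import math

import jax
import jax.numpy as jnp


def resize_posembed(posemb: jax.Array, num_tokens: int, has_class_token: bool) -> jax.Array:
    """Resizes a (1, tokens, dim) positional embedding to `num_tokens` grid tokens."""
    if has_class_token:
        # Assumption: the class token is the first token and is passed through unchanged.
        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0, 0:]

    # Grid sizes come from static shapes and a static int, so they are Python ints.
    gs_old = math.isqrt(posemb_grid.shape[0])
    gs_new = math.isqrt(num_tokens)
    dim = posemb_grid.shape[-1]

    grid = posemb_grid.reshape(gs_old, gs_old, dim)
    # Bicubic resize of the two spatial axes; the embedding axis keeps its size.
    grid = jax.image.resize(grid, (gs_new, gs_new, dim), method="bicubic")
    grid = grid.reshape(1, gs_new * gs_new, dim)
    return jnp.concatenate([posemb_tok, grid], axis=1)


# num_tokens and has_class_token determine shapes, so they are marked static.
resize_jit = jax.jit(resize_posembed, static_argnums=(1, 2))

posemb = jax.random.normal(jax.random.key(0), (1, 1 + 14 * 14, 8))
resized = resize_jit(posemb, 24 * 24, True)
print(resized.shape)
```

Bicubic resizing is not numerically identical to scipy.ndimage.zoom with order=3, so outputs should be compared within a tolerance rather than exactly.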
bonsai/models/vit/modeling.py
Outdated
    import scipy.ndimage
    from absl import logging


    def interpolate_posembed(posemb, num_tokens: int, has_class_token: bool):
Nit: could you also add type hints for posemb and the return value?
bonsai/models/vit/modeling.py
Outdated
    else:
      posemb_tok, posemb_grid = posemb[:, :0], posemb[0, 0:]

    gs_old = int(np.sqrt(len(posemb_grid)))
Could we use jnp equivalents (e.g., jnp.sqrt and jnp.concatenate) to be more JAX-native and jit-compatible?
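One caveat may be worth checking here. As far as I know, values that feed into shapes (such as the grid size passed to reshape) need to be plain Python ints under jit, and a jnp operation inside a jitted function returns a traced value even for constant inputs. The sketch below therefore derives the grid size from the static shape with `math.isqrt` and keeps jnp for the array work. `split_and_rejoin` is a hypothetical helper, and the class-token-first layout is an assumption.

```python
import math

import jax
import jax.numpy as jnp


@jax.jit
def split_and_rejoin(posemb):
    # Assumed layout for illustration: (1, 1 + grid_size**2, dim), class token first.
    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
    # Array shapes are static under jit, so this is an ordinary Python int.
    gs_old = math.isqrt(posemb_grid.shape[0])
    grid = posemb_grid.reshape(gs_old, gs_old, -1)
    # Array-valued work stays in jnp so it can be traced and compiled.
    flat = grid.reshape(1, gs_old * gs_old, -1)
    return jnp.concatenate([posemb_tok, flat], axis=1)


posemb = jnp.arange(1 * 197 * 4, dtype=jnp.float32).reshape(1, 197, 4)
roundtrip = split_and_rejoin(posemb)
print(roundtrip.shape)
```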
Hello, I wanted to confirm the changes I need to make:
Please let me know if I’m missing anything.
* general maintenance
* Platform updates
* Added interpolation change
* Added interpolation change
* fix(vit): Make positional embedding interpolation JIT-compatible
Hello, I've implemented the requested changes. This is ready for re-review when you have a moment. Thank you.
bonsai/models/vit/modeling.py
Outdated
    import jax.numpy as jnp
    from flax import nnx

    import scipy.ndimage
Could you remove this since it's no longer used?
Thanks. Also, we've updated the Discord link in the guide.
Thanks for the quick turnaround! :)
Resolves #57
This PR adds dynamic positional embedding interpolation to the ViTClassificationModel (in bonsai/models/vit/modeling.py).
Reference
https://github.com/google-research/vision_transformer/blob/c6de1e5378c9831a8477feb30994971bdc409e46/vit_jax/checkpoint.py#L209
Colab Notebook
N/A
Checklist
run_model.py for usage, test_outputs.py and model_validation_colab.ipynb (if applicable) for quality).