
Conversation

@windmaple (Collaborator) commented Nov 5, 2025:

This PR enables KV cache for miniGPT inference (inference time goes from ~9s to ~3s on my Cloudtop). It also updates the model to use the simpler sharding annotation from the newer NNX API.

Correctness is validated in https://colab.research.google.com/drive/1Fw2IQjH-UcGReOXw6ykqJaXKWUv3HN_O?resourcekey=0-lNpYdIeKUxoOMfG_KpZfAw&usp=sharing
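For context on what a KV cache buys here: at each decode step the model only needs to project keys and values for the newest token, then attend against all previously cached positions, instead of recomputing K/V for the whole prefix. Below is a minimal sketch of one such step in plain JAX; the function name, the (batch, seq, heads, head_dim) cache layout, and the preallocated-buffer approach are illustrative assumptions, not the PR's actual miniGPT code.

```python
import jax
import jax.numpy as jnp

def attend_with_cache(q, new_k, new_v, cache_k, cache_v, idx):
    """One decode step with a KV cache (illustrative sketch).

    q, new_k, new_v: (B, 1, H, D) -- projections for the newest token only.
    cache_k, cache_v: (B, T, H, D) -- preallocated buffers for T positions.
    idx: scalar int, the position currently being decoded.
    """
    # Write this step's key/value into slot `idx` of the cache.
    cache_k = jax.lax.dynamic_update_slice(cache_k, new_k, (0, idx, 0, 0))
    cache_v = jax.lax.dynamic_update_slice(cache_v, new_v, (0, idx, 0, 0))

    # Attend the single query against every cached position.
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum('bqhd,bkhd->bhqk', q * scale, cache_k)

    # Mask cache slots that haven't been written yet (positions > idx).
    valid = jnp.arange(cache_k.shape[1]) <= idx            # (T,)
    logits = jnp.where(valid[None, None, None, :], logits, -jnp.inf)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum('bhqk,bkhd->bqhd', weights, cache_v)  # (B, 1, H, D)
    return out, cache_k, cache_v
```

The speedup reported above is what this pattern generally delivers: per-token work drops from reprocessing the entire prefix to projecting one token and attending a single query row against the cached K/V.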

@windmaple marked this pull request as ready for review November 5, 2025 06:58
@windmaple (Collaborator, Author) commented:

@emilyfertig This is a bit beyond the original revamping scope, but I think it's nice to have because 1) we need to adopt the newer sharding annotation at some point anyway, and 2) KV cache is very commonly used for LLM inference. This would be a good reference regardless of whether we can get vLLM TPU integration.

@salfaris left a comment:
I was just about to suggest the updated nnx.with_partitioning changes! This is great work :)
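For readers who haven't seen the newer annotation style, here is a minimal sketch of nnx.with_partitioning. The layer, dimension names, and the 'model' mesh axis are illustrative assumptions; the PR's actual diff is not reproduced in this thread.

```python
from flax import nnx

class FeedForward(nnx.Module):
    """Toy layer showing the newer NNX-style sharding annotation."""

    def __init__(self, d_in: int, d_out: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(
            d_in,
            d_out,
            # with_partitioning wraps the initializer so the kernel it creates
            # carries sharding metadata: replicate the input dim, shard the
            # output dim over a mesh axis named 'model' (name is illustrative).
            kernel_init=nnx.with_partitioning(
                nnx.initializers.lecun_normal(), (None, 'model')
            ),
            rngs=rngs,
        )

    def __call__(self, x):
        return self.linear(x)
```

Because the metadata is attached to the variable at creation time, it can later be recovered (e.g. via nnx.get_partition_spec) when placing the model state on a mesh, which is what makes this simpler than threading partitioning rules through the model separately.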

@emilyfertig (Collaborator) left a comment:

Thanks!

@emilyfertig merged commit bac0c90 into jax-ml:revamp-2025 on Nov 6, 2025
3 checks passed
@windmaple deleted the kvcache branch November 6, 2025 23:46
