Conversation

@sguada commented Dec 16, 2024

This commit converts the core logic of the difflogic library from PyTorch to JAX. The CUDA implementation is rewritten using Pallas kernels. The Python implementation is also converted to JAX. The script is adapted to use JAX for training and evaluation. Basic tests are added to verify the JAX implementation. The and compiled model functionality are not yet ported to JAX and are left as placeholders for future work.

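For context, here is a minimal sketch (not the code from this PR) of how a differentiable logic-gate forward pass can be written in pure JAX. The names (`all_gates`, `logic_layer`), the softmax mixture over the 16 two-input gate relaxations, and the fixed random connectivity are illustrative assumptions based on the general difflogic formulation, not this repository's API.

```python
# Hypothetical sketch of a difflogic-style layer in pure JAX.
import jax
import jax.numpy as jnp


def all_gates(a, b):
    """Real-valued relaxations of the 16 two-input logic gates."""
    return jnp.stack([
        jnp.zeros_like(a),        # FALSE
        a * b,                    # AND
        a - a * b,                # A AND NOT B
        a,                        # A
        b - a * b,                # B AND NOT A
        b,                        # B
        a + b - 2 * a * b,        # XOR
        a + b - a * b,            # OR
        1 - (a + b - a * b),      # NOR
        1 - (a + b - 2 * a * b),  # XNOR
        1 - b,                    # NOT B
        1 - b + a * b,            # B IMPLIES A
        1 - a,                    # NOT A
        1 - a + a * b,            # A IMPLIES B
        1 - a * b,                # NAND
        jnp.ones_like(a),         # TRUE
    ], axis=0)                    # shape: (16, ...)


def logic_layer(x, weights, idx_a, idx_b):
    """x: (batch, in_dim); weights: (out_dim, 16); idx_*: (out_dim,)."""
    a = x[:, idx_a]                            # (batch, out_dim)
    b = x[:, idx_b]
    gates = all_gates(a, b)                    # (16, batch, out_dim)
    probs = jax.nn.softmax(weights, axis=-1)   # (out_dim, 16)
    # Mix the 16 relaxed gate outputs per neuron.
    return jnp.einsum('gbo,og->bo', gates, probs)


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    x = jax.random.uniform(k1, (8, 32))
    w = jax.random.normal(k2, (64, 16))
    ia = jax.random.randint(k3, (64,), 0, 32)
    ib = jax.random.randint(k4, (64,), 0, 32)
    y = jax.jit(logic_layer)(x, w, ia, ib)
    print(y.shape)  # (8, 64)
```

A Pallas version would express the same element-wise math as a kernel wrapped in `pl.pallas_call`. The sketch below shows only a single relaxed gate (soft XOR) and omits blocking, the full 16-gate mixture, and the custom backward pass that real kernels would need; it is an assumption about the general pattern, not the kernels in this commit.

```python
# Hypothetical single-gate Pallas kernel sketch.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def soft_xor_kernel(a_ref, b_ref, o_ref):
    a, b = a_ref[...], b_ref[...]
    o_ref[...] = a + b - 2 * a * b   # relaxed XOR


def soft_xor(a, b):
    # interpret=True lets the kernel run without a GPU/TPU for testing.
    return pl.pallas_call(
        soft_xor_kernel,
        out_shape=jax.ShapeDtypeStruct(a.shape, a.dtype),
        interpret=True,
    )(a, b)


print(soft_xor(jnp.array([0.0, 1.0, 1.0]), jnp.array([0.0, 0.0, 1.0])))
```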