FFI API tutorial #26602
Thanks for the tutorial and the revamped API! I'm in the process of refactoring a few (recent) kernel bindings that we had working since the XLA …, and the new API will probably help reduce boilerplate. Let me, however, suggest a few points of improvement that would really help:
Thanks a lot in advance, I think this might help future readers too! 🙏 (cc @dfm, hoping you have some bandwidth for this, apologies if not)
PS: on a side note,
Hi again @dfm, thanks a lot for the dispatch addition in the tutorial! 🙏 I'm only now finding the time to update my JAX bindings, and this is very useful. However, you might still want to fix the header with some kind of disclaimer, since with JAX 0.4.34 I still get the following error:

    >>> from jax import ffi
    ModuleNotFoundError: No module named jax.ffi

Since the kernels should ideally be usable across a range of JAX versions in downstream projects, I'm also wondering how far …
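Until the required minimum version is settled, one way to keep bindings importable across JAX releases is to try the newer module location first and fall back to the older one. This is a minimal sketch, assuming that newer releases expose the FFI helpers as `jax.ffi` while some older ones only provide them under `jax.extend.ffi`; the exact version boundary isn't stated in this thread, so treat the fallback order as an assumption. `import_first` and `load_jax_ffi` are hypothetical helper names, not part of JAX.

```python
import importlib
from types import ModuleType
from typing import Sequence


def import_first(candidates: Sequence[str]) -> ModuleType:
    """Import and return the first module in `candidates` that can be imported.

    Raises ImportError listing every failure if none of them can be imported.
    """
    errors = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # ModuleNotFoundError is a subclass
            errors.append(f"{name}: {exc}")
    raise ImportError(
        "none of the candidate modules could be imported:\n" + "\n".join(errors)
    )


def load_jax_ffi() -> ModuleType:
    # Assumption: `jax.ffi` on newer releases, `jax.extend.ffi` on older ones.
    return import_first(["jax.ffi", "jax.extend.ffi"])
```

In a bindings package you would call `ffi = load_jax_ffi()` once at import time and use `ffi.ffi_call` / `ffi.register_ffi_target` from there. One design caveat: catching `ImportError` broadly also hides a genuine failure *inside* `jax.ffi`, which is why the helper collects every error message into the final exception.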
The `rms_norm` example code now shows how I do dtype dispatching with the FFI. Here's the relevant PR: #26607
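The thread doesn't spell out how the dispatch in #26607 is implemented, so this is not a description of that PR. As a point of comparison, one common alternative to a single handler that switches on the element type at runtime is to register one FFI target per dtype and select the target name on the Python side. This is a minimal sketch, assuming hypothetical targets `rms_norm_f32` and `rms_norm_f64` that were already registered from C++:

```python
# Hypothetical per-dtype FFI target names; assumes each one was registered
# beforehand (e.g. with register_ffi_target) from a compiled extension.
_RMS_NORM_TARGETS = {
    "float32": "rms_norm_f32",
    "float64": "rms_norm_f64",
}


def rms_norm_target(dtype) -> str:
    """Map a dtype (a NumPy dtype such as `x.dtype`, or its name) to a target name."""
    name = str(dtype)  # str(np.dtype("float32")) == "float32"
    try:
        return _RMS_NORM_TARGETS[name]
    except KeyError:
        raise TypeError(
            f"rms_norm: unsupported dtype {name!r}; "
            f"expected one of {sorted(_RMS_NORM_TARGETS)}"
        ) from None
```

At call time the returned name would be passed as the first argument to `ffi_call`, together with the output `jax.ShapeDtypeStruct(x.shape, x.dtype)`. The trade-off is one registration per dtype, in exchange for unsupported dtypes being rejected with a clear Python error at trace time rather than inside the kernel.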