4 | 4 | import jax.numpy as jnp
5 | 5 | import sooki |
6 | 6 |
7 | | -# Note: Using float32 to match FFI expectations
8 | | -# jax.config.update("jax_enable_x64", True) |
| 7 | +jax.config.update("jax_enable_x64", True) |
9 | 8 |
10 | 9 | gpu = False |
11 | 10 | gpu_targets = {} |
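With `jax_enable_x64` switched on at import time a few lines up, `jnp` constructors default to float64, which is what the `a.dtype == b.dtype` assertions in `foo_fwd`/`foo_bwd` below will see. A minimal sanity check of that effect, for reference only (not part of this file):

```python
import jax
import jax.numpy as jnp

# Conventionally set at startup, before any arrays are created,
# since it changes the default dtypes.
jax.config.update("jax_enable_x64", True)

a = jnp.ones((2, 3))
b = jnp.ones((2, 3))
# With x64 enabled both default to float64, so the dtype assertions
# below see matching double-precision inputs.
assert a.dtype == jnp.float64 and b.dtype == jnp.float64
```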
28 | 27 | def foo_fwd(a, b): |
29 | 28 | assert a.shape == b.shape |
30 | 29 | assert a.dtype == b.dtype |
| 30 | + |
| 31 | + if a.size == 0: |
| 32 | + return jnp.array(0.0, dtype=a.dtype), (a, b) |
| 33 | + |
31 | 34 | n = np.prod(a.shape).astype(np.int64) |
32 | | - scalar_type = jax.ShapeDtypeStruct((), a.dtype) # scalar output |
33 | | - intermediate_type = jax.ShapeDtypeStruct(a.shape, a.dtype) # b_plus_1 shape
| 35 | + scalar_type = jax.ShapeDtypeStruct((), a.dtype) |
| 36 | + intermediate_type = jax.ShapeDtypeStruct(a.shape, a.dtype) |
34 | 37 |
35 | | - # Use GPU if available, otherwise use CPU |
36 | | - ffi_name = "foo_fwd" if gpu else "foo_fwd_cpu" |
| 38 | + def impl(target_name): |
| 39 | + return lambda: jax.ffi.ffi_call( |
| 40 | + target_name, (scalar_type, intermediate_type), vmap_method="sequential"
| 41 | + )(a, b, n=n) |
37 | 42 |
38 | | - result, b_plus_1 = jax.ffi.ffi_call( |
39 | | - ffi_name, (scalar_type, intermediate_type), vmap_method="sequential"
40 | | - )(a, b, n=n) |
| 43 | + result, b_plus_1 = jax.lax.platform_dependent( |
| 44 | + cpu=impl("foo_fwd_cpu"), cuda=impl("foo_fwd") |
| 45 | + ) |
41 | 46 | return result, (a, b_plus_1) |
42 | 47 |
43 | 48 |
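The dispatch above no longer branches on the module-level `gpu` flag: `jax.lax.platform_dependent` picks between the `impl(...)` thunks when the computation is lowered, so one jitted function serves both backends. A minimal, self-contained sketch of that mechanism, with ordinary `jnp` ops standing in for the FFI targets (everything here is illustrative, not taken from the extension):

```python
import jax
import jax.numpy as jnp

def total(x):
    # Zero-argument thunks, mirroring impl("foo_fwd_cpu") / impl("foo_fwd") above;
    # the branch is chosen per backend at lowering time, not by a Python flag.
    cpu_impl = lambda: jnp.sum(x)   # stand-in for the CPU FFI target
    cuda_impl = lambda: jnp.sum(x)  # stand-in for the CUDA FFI target
    return jax.lax.platform_dependent(cpu=cpu_impl, cuda=cuda_impl)

print(jax.jit(total)(jnp.arange(6.0)))  # 15.0 on either backend
```

Both branches must agree on output shapes and dtypes, which the shared `(scalar_type, intermediate_type)` spec guarantees for the two FFI targets above.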
44 | 49 | def foo_bwd(res, c_grad): |
45 | 50 | a, b_plus_1 = res |
| 51 | + |
| 52 | + if a.size == 0: |
| 53 | + return jnp.zeros_like(a), jnp.zeros_like(a) |
| 54 | + |
46 | 55 | assert c_grad.dtype == a.dtype |
47 | 56 | assert a.dtype == b_plus_1.dtype |
48 | 57 | n = np.prod(a.shape).astype(np.int64) |
49 | 58 | out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) |
50 | 59 |
51 | | - # Use GPU if available, otherwise use CPU |
52 | | - ffi_name = "foo_bwd" if gpu else "foo_bwd_cpu" |
| 60 | + def impl(target_name): |
| 61 | + return lambda: jax.ffi.ffi_call( |
| 62 | + target_name, (out_type, out_type), vmap_method="sequential"
| 63 | + )(c_grad, a, b_plus_1, n=n) |
53 | 64 |
54 | | - # c_grad is now a scalar, pass it directly to the FFI function |
55 | | - return jax.ffi.ffi_call(ffi_name, (out_type, out_type), vmap_method="sequential")(
56 | | - c_grad, a, b_plus_1, n=n |
57 | | - ) |
| 65 | + return jax.lax.platform_dependent(cpu=impl("foo_bwd_cpu"), cuda=impl("foo_bwd")) |
58 | 66 |
59 | 67 |
60 | 68 | @jax.custom_vjp |
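For context, the usual wiring this decorator implies: a primal `foo` plus `defvjp(foo_fwd, foo_bwd)`, followed by a gradient call. The sketch below assumes the primal simply reuses the forward rule; the real definition sits just below this hunk and may differ.

```python
@jax.custom_vjp
def foo(a, b):
    # Assumed primal body: reuse the forward rule and drop the residuals.
    result, _ = foo_fwd(a, b)
    return result

# Attach the FFI-backed forward and backward rules shown above.
foo.defvjp(foo_fwd, foo_bwd)

a = jnp.ones((2, 3))
b = jnp.ones((2, 3))
# foo returns a scalar, so jax.grad applies directly; gradients flow through foo_bwd.
da, db = jax.grad(foo, argnums=(0, 1))(a, b)
```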