
Commit 19fd99f

fix kernel & requirements
1 parent 7a7c6c1 commit 19fd99f

File tree

5 files changed: +33 -24 lines

  requirements.txt
  src/gpu_ops.cc
  src/kernels.cc.cu
  src/kernels.h
  src/sooki/ops.py

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 jax
 pybind11
-scikit-build-core
+scikit-build-core
+pytest
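
pytest joins the dependency list so the reworked Python wrapper can be exercised from a test suite. Below is a minimal sketch of such a test; the file layout, the sooki.ops import path, and the exact assertions are assumptions rather than part of this commit, and only shapes declared in the ops.py diff below are checked, not numerical results:

import jax.numpy as jnp

from sooki.ops import foo_fwd  # hypothetical import path


def test_foo_fwd_shapes():
    a = jnp.ones((4,), dtype=jnp.float32)
    b = jnp.ones((4,), dtype=jnp.float32)
    result, (a_saved, b_plus_1) = foo_fwd(a, b)
    assert result.shape == ()          # scalar output declared in ops.py
    assert b_plus_1.shape == a.shape   # intermediate matches the input shape


def test_foo_fwd_empty_input():
    a = jnp.zeros((0,), dtype=jnp.float32)
    b = jnp.zeros((0,), dtype=jnp.float32)
    result, _ = foo_fwd(a, b)
    assert result == 0.0               # early return added for empty inputs in ops.py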

src/gpu_ops.cc

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
         .Arg<ffi::Buffer<ffi::F32>>()   // b
         .Ret<ffi::Buffer<ffi::F32>>()   // result (scalar)
         .Ret<ffi::Buffer<ffi::F32>>()   // b_plus_1
-        .Attr<size_t>("n"),
+        .Attr<int64_t>("n"),
     {xla::ffi::Traits::kCmdBufferCompatible});  // cudaGraph enabled

 // Creates symbol FooBwd with C linkage that can be loaded using Python ctypes
@@ -25,7 +25,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
         .Arg<ffi::Buffer<ffi::F32>>()   // b_plus_1
         .Ret<ffi::Buffer<ffi::F32>>()   // a_grad
         .Ret<ffi::Buffer<ffi::F32>>()   // b_grad
-        .Attr<size_t>("n"),
+        .Attr<int64_t>("n"),
     {xla::ffi::Traits::kCmdBufferCompatible});  // cudaGraph enabled

 template <typename T>
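
Both handler declarations switch the "n" attribute from size_t to int64_t. The Python wrapper (src/sooki/ops.py below) passes n as a NumPy int64 keyword argument, and the FFI checks the declared attribute type against what the caller encodes, so a signed 64-bit declaration is the one that matches. A calling-side sketch, assuming the foo_fwd_cpu target is registered as it is elsewhere in this repo:

import jax
import jax.numpy as jnp
import numpy as np

a = jnp.ones((8,), dtype=jnp.float32)
b = jnp.ones((8,), dtype=jnp.float32)
n = np.prod(a.shape).astype(np.int64)  # sent to the handler as a 64-bit integer attribute

scalar_type = jax.ShapeDtypeStruct((), a.dtype)              # result
intermediate_type = jax.ShapeDtypeStruct(a.shape, a.dtype)   # b_plus_1

# With .Attr<int64_t>("n") on the C++ side this attribute is accepted; the
# unsigned size_t declaration did not line up with what JAX sends here.
result, b_plus_1 = jax.ffi.ffi_call(
    "foo_fwd_cpu", (scalar_type, intermediate_type), vmap_method="sequential"
)(a, b, n=n)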

src/kernels.cc.cu

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@

 namespace ffi = xla::ffi;
 __global__ void FooFwdKernel(const float *a, const float *b, float *result,
-                             float *b_plus_1, size_t n)
+                             float *b_plus_1, int64_t n)
 {
   size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t grid_stride = blockDim.x * gridDim.x;
@@ -22,7 +22,7 @@ __global__ void FooFwdKernel(const float *a, const float *b, float *result,

 ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::F32> a,
                       ffi::Buffer<ffi::F32> b, ffi::ResultBuffer<ffi::F32> result,
-                      ffi::ResultBuffer<ffi::F32> b_plus_1, size_t n)
+                      ffi::ResultBuffer<ffi::F32> b_plus_1, int64_t n)
 {
   const int block_dim = 128;
   const int grid_dim = std::min(32, (int)((n + block_dim - 1) / block_dim));
@@ -47,7 +47,7 @@ __global__ void FooBwdKernel(const float *scalar_grad,
                              const float *b_plus_1,
                              float *a_grad,
                              float *b_grad,
-                             size_t n)
+                             int64_t n)
 {
   size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t grid_stride = blockDim.x * gridDim.x;
@@ -67,7 +67,7 @@ ffi::Error FooBwdHost(cudaStream_t stream,
                       ffi::Buffer<ffi::F32> b_plus_1,
                       ffi::ResultBuffer<ffi::F32> a_grad,
                       ffi::ResultBuffer<ffi::F32> b_grad,
-                      size_t n)
+                      int64_t n)
 {
   const int block_dim = 128;
   const int grid_dim = std::min(32, (int)((n + block_dim - 1) / block_dim));
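
The host wrappers keep the same launch configuration while n becomes int64_t. A quick check of the grid-dimension formula in plain Python (not part of the commit), which also shows the n == 0 corner case that the Python wrapper below now guards against:

# Launch-dimension formula from FooFwdHost/FooBwdHost: at most 32 blocks of
# 128 threads, with the kernel's grid-stride loop covering remaining elements.
block_dim = 128
for n in (0, 1, 127, 128, 129, 1 << 20):
    grid_dim = min(32, (n + block_dim - 1) // block_dim)
    print(f"n={n:>8}  grid_dim={grid_dim}")
# n == 0 gives grid_dim == 0, i.e. no blocks at all, which is presumably why
# src/sooki/ops.py below now returns early for empty inputs instead of calling the FFI.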

src/kernels.h

Lines changed: 2 additions & 2 deletions
@@ -8,14 +8,14 @@ namespace ffi = xla::ffi;

 ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::F32> a,
                       ffi::Buffer<ffi::F32> b, ffi::ResultBuffer<ffi::F32> result,
-                      ffi::ResultBuffer<ffi::F32> b_plus_1, size_t n);
+                      ffi::ResultBuffer<ffi::F32> b_plus_1, int64_t n);

 ffi::Error FooBwdHost(cudaStream_t stream,
                       ffi::Buffer<ffi::F32> scalar_grad,
                       ffi::Buffer<ffi::F32> a,
                       ffi::Buffer<ffi::F32> b_plus_1,
                       ffi::ResultBuffer<ffi::F32> a_grad,
                       ffi::ResultBuffer<ffi::F32> b_grad,
-                      size_t n);
+                      int64_t n);

 #endif  // KERNELS_H_

src/sooki/ops.py

Lines changed: 23 additions & 15 deletions
@@ -4,8 +4,7 @@
 import jax.numpy as jnp
 import sooki

-# Note: Using float32 to match FFI expectations
-# jax.config.update("jax_enable_x64", True)
+jax.config.update("jax_enable_x64", True)

 gpu = False
 gpu_targets = {}
@@ -28,33 +27,42 @@
 def foo_fwd(a, b):
     assert a.shape == b.shape
     assert a.dtype == b.dtype
+
+    if a.size == 0:
+        return jnp.array(0.0, dtype=a.dtype), (a, b)
+
     n = np.prod(a.shape).astype(np.int64)
-    scalar_type = jax.ShapeDtypeStruct((), a.dtype)  # scalar output
-    intermediate_type = jax.ShapeDtypeStruct(a.shape, a.dtype)  # b_plus_1 shape
+    scalar_type = jax.ShapeDtypeStruct((), a.dtype)
+    intermediate_type = jax.ShapeDtypeStruct(a.shape, a.dtype)

-    # Use GPU if available, otherwise use CPU
-    ffi_name = "foo_fwd" if gpu else "foo_fwd_cpu"
+    def impl(target_name):
+        return lambda: jax.ffi.ffi_call(
+            target_name, (scalar_type, intermediate_type), vmap_method="sequential"
+        )(a, b, n=n)

-    result, b_plus_1 = jax.ffi.ffi_call(
-        ffi_name, (scalar_type, intermediate_type), vmap_method="sequential"
-    )(a, b, n=n)
+    result, b_plus_1 = jax.lax.platform_dependent(
+        cpu=impl("foo_fwd_cpu"), cuda=impl("foo_fwd")
+    )
     return result, (a, b_plus_1)


 def foo_bwd(res, c_grad):
     a, b_plus_1 = res
+
+    if a.size == 0:
+        return jnp.zeros_like(a), jnp.zeros_like(a)
+
     assert c_grad.dtype == a.dtype
     assert a.dtype == b_plus_1.dtype
     n = np.prod(a.shape).astype(np.int64)
     out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)

-    # Use GPU if available, otherwise use CPU
-    ffi_name = "foo_bwd" if gpu else "foo_bwd_cpu"
+    def impl(target_name):
+        return lambda: jax.ffi.ffi_call(
+            target_name, (out_type, out_type), vmap_method="sequential"
+        )(c_grad, a, b_plus_1, n=n)

-    # c_grad is now a scalar, pass it directly to the FFI function
-    return jax.ffi.ffi_call(ffi_name, (out_type, out_type), vmap_method="sequential")(
-        c_grad, a, b_plus_1, n=n
-    )
+    return jax.lax.platform_dependent(cpu=impl("foo_bwd_cpu"), cuda=impl("foo_bwd"))


 @jax.custom_vjp
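
The forward and backward wrappers above feed the @jax.custom_vjp definition that appears as the last context line of this hunk. A usage sketch follows; it assumes the decorated primal function is named foo and takes the same (a, b) arguments as foo_fwd (the name is not visible in this diff), and that the package imports as sooki.ops:

import jax
import jax.numpy as jnp

from sooki.ops import foo  # hypothetical name and import path

a = jnp.arange(8, dtype=jnp.float32)
b = jnp.arange(8, dtype=jnp.float32)

# jax.lax.platform_dependent picks "foo_fwd_cpu" or "foo_fwd" when the call is
# lowered for the current backend, so the same code runs with or without a GPU.
value = jax.jit(foo)(a, b)

# The backward FFI target is reached through the custom_vjp rule (foo_bwd).
a_grad, b_grad = jax.grad(foo, argnums=(0, 1))(a, b)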
