Skip to content

Commit e893726

Browse files
committed
1 parent ee1207e commit e893726

File tree

6 files changed

+25
-15
lines changed

6 files changed

+25
-15
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ jobs:
136136
julia --project -e '
137137
using Pkg
138138
Pkg.develop(path="lib/intrinsics")
139-
Pkg.add(name="GPUCompiler", rev="sds/additional_args")'
139+
Pkg.add(name="GPUCompiler", rev="sds/add_input_args")'
140140
141141
- name: Test OpenCL.jl
142142
uses: julia-actions/julia-runtest@v1

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
2222
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2323

2424
[sources]
25-
GPUCompiler = {rev = "sds/additional_args", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
25+
GPUCompiler = {rev = "sds/add_input_args", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
2626
SPIRVIntrinsics = {path = "lib/intrinsics"}
2727

2828
[compat]

src/compiler/compilation.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,27 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1919
contains(fn, "__spirv_")
2020

2121
GPUCompiler.kernel_state_type(job::OpenCLCompilerJob) = KernelState
22-
GPUCompiler.additional_arg_types(job::OpenCLCompilerJob) = (; random_keys = LLVMPtr{UInt32, AS.Workgroup}, random_counters = LLVMPtr{UInt32, AS.Workgroup})
2322

2423
function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
2524
mod::LLVM.Module, entry::LLVM.Function)
2625
entry = invoke(GPUCompiler.finish_module!,
2726
Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
2827
job, mod, entry)
2928

29+
kernel_intrinsics = Dict(
30+
"julia.spirv.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
31+
"julia.spirv.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
32+
)
33+
entry′ = GPUCompiler.add_input_arguments!(job, mod, entry, kernel_intrinsics)
34+
3035
# if this kernel uses our RNG, we should prime the shared state.
3136
# XXX: these transformations should really happen at the Julia IR level...
32-
if callconv(entry) == LLVM.API.LLVMSPIRKERNELCallConv && haskey(functions(mod), "julia.gpu.additional_arg_getter")
37+
if job.config.kernel && entry !== entry′
38+
entry = entry′
3339
# insert call to `initialize_rng_state`
3440
f = initialize_rng_state
3541
ft = typeof(f)
36-
tt = Tuple{}
42+
tt = NTuple{2, LLVMPtr{UInt32, AS.Workgroup}}
3743

3844
# create a deferred compilation job for `initialize_rng_state`
3945
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
@@ -74,16 +80,18 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
7480
# call the `initialize_rng_state` function
7581
rt = Core.Compiler.return_type(f, tt)
7682
llvm_rt = convert(LLVMType, rt)
77-
llvm_ft = LLVM.FunctionType(llvm_rt)
83+
llvm_ft = LLVM.FunctionType(llvm_rt, [convert(LLVMType, LLVMPtr{UInt32, AS.Workgroup}) for _ in 1:2])
7884
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
79-
call!(builder, llvm_ft, fptr)
85+
random_keys = findfirst(arg -> name(arg) == "random_keys", parameters(entry))
86+
random_counters = findfirst(arg -> name(arg) == "random_counters", parameters(entry))
87+
call!(builder, llvm_ft, fptr, parameters(entry)[[random_keys, random_counters]])
8088
br!(builder, top_bb)
8189
end
8290

8391
# XXX: put some of the above behind GPUCompiler abstractions
8492
# (e.g., a compile-time version of `deferred_codegen`)
8593
end
86-
return entry
94+
return entry
8795
end
8896

8997
## compiler implementation (cache, configure, compile, and link)

src/device/random.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@ import RandomNumbers
1010
# local memory with the actual seed, per subgroup, set by `initialize_rng_state`` or overridden by calling `seed!`
1111
@inline function global_random_keys()
1212
n = get_num_sub_groups()
13-
ptr = additional_args(Val{1}())::LLVMPtr{UInt32, AS.Workgroup}
13+
ptr = random_keys()::LLVMPtr{UInt32, AS.Workgroup}
1414
return CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), ptr)
1515
end
1616

1717
# local memory with per-subgroup counters, incremented when generating numbers
1818
@inline function global_random_counters()
1919
n = get_num_sub_groups()
20-
ptr = additional_args(Val{2}())::LLVMPtr{UInt32, AS.Workgroup}
20+
ptr = random_counters()::LLVMPtr{UInt32, AS.Workgroup}
2121
return CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), ptr)
2222
end
2323

2424
# initialization function, called automatically at the start of each kernel
25-
function initialize_rng_state()
26-
random_keys = global_random_keys()
27-
random_counters = global_random_counters()
25+
function initialize_rng_state(random_keys_ptr::LLVMPtr{UInt32, AS.Workgroup}, random_counters_ptr::LLVMPtr{UInt32, AS.Workgroup})
26+
n = get_num_sub_groups()
27+
random_keys = CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), random_keys_ptr)
28+
random_counters = CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), random_counters_ptr)
2829

2930
subgroup_id = get_sub_group_id()
3031
@inbounds random_keys[subgroup_id] = kernel_state().random_seed

src/device/runtime.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ struct KernelState
2020
end
2121

2222
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)
23-
@inline @generated additional_args(::Val{i}) where {i} = GPUCompiler.additional_arg_value(LLVMPtr{UInt32, AS.Workgroup}, i)
23+
@inline @generated random_keys() = GPUCompiler.call_custom_intrinsic(LLVMPtr{UInt32, AS.Workgroup}, "julia.spirv.random_keys", "random_keys")
24+
@inline @generated random_counters() = GPUCompiler.call_custom_intrinsic(LLVMPtr{UInt32, AS.Workgroup}, "julia.spirv.random_counters", "random_counters")

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2424
pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"
2525

2626
[sources]
27-
GPUCompiler = {rev = "sds/additional_args", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
27+
GPUCompiler = {rev = "sds/add_input_args", url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"}
2828

2929
[compat]
3030
pocl_jll = "7.0"

0 commit comments

Comments
 (0)