@@ -19,21 +19,27 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1919 contains (fn, " __spirv_" )
2020
# Report the kernel-state type for OpenCL compiler jobs: always `KernelState`,
# independent of the contents of `job`.
function GPUCompiler.kernel_state_type(job::OpenCLCompilerJob)
    return KernelState
end
22- GPUCompiler. additional_arg_types (job:: OpenCLCompilerJob ) = (; random_keys = LLVMPtr{UInt32, AS. Workgroup}, random_counters = LLVMPtr{UInt32, AS. Workgroup})
2322
2423function GPUCompiler. finish_module! (@nospecialize (job:: OpenCLCompilerJob ),
2524 mod:: LLVM.Module , entry:: LLVM.Function )
2625 entry = invoke (GPUCompiler. finish_module!,
2726 Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM. Module, LLVM. Function},
2827 job, mod, entry)
2928
29+ kernel_intrinsics = Dict (
30+ " julia.spirv.random_keys" => (; name = " random_keys" , typ = LLVMPtr{UInt32, AS. Workgroup}),
31+ " julia.spirv.random_counters" => (; name = " random_counters" , typ = LLVMPtr{UInt32, AS. Workgroup}),
32+ )
33+ entry′ = GPUCompiler. add_input_arguments! (job, mod, entry, kernel_intrinsics)
34+
3035 # if this kernel uses our RNG, we should prime the shared state.
3136 # XXX : these transformations should really happen at the Julia IR level...
32- if callconv (entry) == LLVM. API. LLVMSPIRKERNELCallConv && haskey (functions (mod), " julia.gpu.additional_arg_getter" )
37+ if job. config. kernel && entry != = entry′
38+ entry = entry′
3339 # insert call to `initialize_rng_state`
3440 f = initialize_rng_state
3541 ft = typeof (f)
36- tt = Tuple{ }
42+ tt = NTuple{ 2 , LLVMPtr{UInt32, AS . Workgroup} }
3743
3844 # create a deferred compilation job for `initialize_rng_state`
3945 src = methodinstance (ft, tt, GPUCompiler. tls_world_age ())
@@ -74,16 +80,18 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
7480 # call the `initialize_rng_state` function
7581 rt = Core. Compiler. return_type (f, tt)
7682 llvm_rt = convert (LLVMType, rt)
77- llvm_ft = LLVM. FunctionType (llvm_rt)
83+ llvm_ft = LLVM. FunctionType (llvm_rt, [ convert (LLVMType, LLVMPtr{UInt32, AS . Workgroup}) for _ in 1 : 2 ] )
7884 fptr = inttoptr! (builder, fptr, LLVM. PointerType (llvm_ft))
79- call! (builder, llvm_ft, fptr)
85+ random_keys = findfirst (arg -> name (arg) == " random_keys" , parameters (entry))
86+ random_counters = findfirst (arg -> name (arg) == " random_counters" , parameters (entry))
87+ call! (builder, llvm_ft, fptr, parameters (entry)[[random_keys, random_counters]])
8088 br! (builder, top_bb)
8189 end
8290
8391 # XXX : put some of the above behind GPUCompiler abstractions
8492 # (e.g., a compile-time version of `deferred_codegen`)
8593 end
86- return entry
94+ return entry′
8795end
8896
8997## compiler implementation (cache, configure, compile, and link)
0 commit comments