12 changes: 8 additions & 4 deletions Project.toml
@@ -13,29 +13,33 @@ OpenCL_jll = "6cb37087-e8b6-5417-8430-1f242f1e46e4"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[sources]
SPIRVIntrinsics = {path = "lib/intrinsics"}

[compat]
Adapt = "4"
GPUArrays = "11.2.1"
GPUCompiler = "1.6"
GPUCompiler = "1.7.1"
KernelAbstractions = "0.9.2"
LLVM = "9.1"
LinearAlgebra = "1"
OpenCL_jll = "=2024.10.24"
Preferences = "1"
Printf = "1"
Random = "1"
Random123 = "1.7.1"
RandomNumbers = "1.6.0"
Reexport = "1"
SPIRVIntrinsics = "0.5"
SPIRV_LLVM_Backend_jll = "20"
SPIRV_Tools_jll = "2025.1"
StaticArrays = "1"
julia = "1.10"

[sources]
SPIRVIntrinsics = {path="lib/intrinsics"}
66 changes: 63 additions & 3 deletions lib/cl/kernel.jl
@@ -106,11 +106,59 @@ function Base.getproperty(ki::KernelWorkGroupInfo, s::Symbol)
end
end

struct KernelSubGroupInfo
kernel::Kernel
device::Device
local_work_size::Vector{Csize_t}
end
sub_group_info(k::Kernel, d::Device, l) = KernelSubGroupInfo(k, d, collect(Csize_t, l))

# Helper function for getting local size for a specific sub-group count
function local_size_for_sub_group_count(ki::KernelSubGroupInfo, sub_group_count::Integer)
k = getfield(ki, :kernel)
d = getfield(ki, :device)
input_value = Ref{Csize_t}(sub_group_count)
result = Ref{NTuple{3, Csize_t}}()
clGetKernelSubGroupInfo(k, d, CL_KERNEL_LOCAL_SIZE_FOR_SUB_GROUP_COUNT,
sizeof(Csize_t), input_value, sizeof(NTuple{3, Csize_t}), result, C_NULL)
return Int.(result[])
end

function Base.getproperty(ki::KernelSubGroupInfo, s::Symbol)
k = getfield(ki, :kernel)
d = getfield(ki, :device)
lws = getfield(ki, :local_work_size)

function get(val, typ)
result = Ref{typ}()
clGetKernelSubGroupInfo(k, d, val, sizeof(lws), lws, sizeof(typ), result, C_NULL)
return result[]
end

if s == :max_sub_group_size
Int(get(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, Csize_t))
elseif s == :sub_group_count
Int(get(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE, Csize_t))
elseif s == :local_size_for_sub_group_count
# This requires input_value to be the desired sub-group count
error("local_size_for_sub_group_count requires specifying desired sub-group count")
elseif s == :max_num_sub_groups
Int(get(CL_KERNEL_MAX_NUM_SUB_GROUPS, Csize_t))
elseif s == :compile_num_sub_groups
Int(get(CL_KERNEL_COMPILE_NUM_SUB_GROUPS, Csize_t))
elseif s == :compile_sub_group_size
Int(get(CL_KERNEL_COMPILE_SUB_GROUP_SIZE_INTEL, Csize_t))
else
getfield(ki, s)
end
end
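
For context, a hedged usage sketch of the new query API (an already-compiled `k::Kernel` and a `d::Device` are assumed to exist; the returned values are device-dependent):

```julia
# Hypothetical usage; `k::Kernel` and `d::Device` assumed to exist.
info = sub_group_info(k, d, Csize_t[64])   # query for a 64-wide work-group
info.max_sub_group_size                    # e.g. 32 on 32-wide SIMD hardware
info.sub_group_count                       # sub-groups per work-group, e.g. 2
local_size_for_sub_group_count(info, 2)    # (x, y, z) local size yielding 2 sub-groups
```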


## kernel calling

function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
global_work_offset=nothing, wait_on::Vector{Event}=Event[])
global_work_offset=nothing, wait_on::Vector{Event}=Event[],
device_rng=false)
max_work_dim = device().max_work_item_dims
work_dim = length(global_work_size)
if work_dim > max_work_dim
@@ -153,6 +201,17 @@ function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
# null local size means OpenCL decides
end

if device_rng
if local_work_size !== nothing
num_sub_groups = KernelSubGroupInfo(k, device(), lsize).sub_group_count
else
num_sub_groups = KernelSubGroupInfo(k, device(), Csize_t[]).max_num_sub_groups
end
rng_state_size = sizeof(UInt32) * num_sub_groups
set_arg!(k, k.num_args - 1, LocalMem(UInt32, rng_state_size))
set_arg!(k, k.num_args, LocalMem(UInt32, rng_state_size))
end
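
A worked sizing example, assuming a hypothetical device with 32-wide sub-groups:

```julia
# For local_work_size = (256,) and a sub-group size of 32:
num_sub_groups = 8                                 # 256 ÷ 32
rng_state_size = sizeof(UInt32) * num_sub_groups   # 32 bytes per buffer
# Two such local-memory buffers are bound as the two trailing (hidden)
# kernel arguments: one for the RNG keys, one for the counters.
```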

if !isempty(wait_on)
n_events = length(wait_on)
wait_event_ids = [evt.id for evt in wait_on]
@@ -189,7 +248,8 @@ end
function call(
k::Kernel, args...; global_size = (1,), local_size = nothing,
global_work_offset = nothing, wait_on::Vector{Event} = Event[],
indirect_memory::Vector{AbstractMemory} = AbstractMemory[]
indirect_memory::Vector{AbstractMemory} = AbstractMemory[],
device_rng=false,
)
set_args!(k, args...)
if !isempty(indirect_memory)
@@ -243,7 +303,7 @@ function call(
clSetKernelExecInfo(k, CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL, sizeof(usm_pointers), usm_pointers)
end
end
enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on)
enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on, device_rng)
end

# From `julia/base/reflection.jl`, adjusted to add specialization on `t`.
6 changes: 5 additions & 1 deletion lib/intrinsics/src/math.jl
@@ -1,7 +1,7 @@
# Math Functions

# TODO: vector types
const generic_types = [Float32,Float64]
const generic_types = [Float16, Float32, Float64]
const generic_types_float = [Float32]
const generic_types_double = [Float64]

@@ -151,11 +151,13 @@ end
# frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp)
# frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp)

@device_function ilogb(x::Float16) = @builtin_ccall("ilogb", Int32, (Float16,), x)
# ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x)
@device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x)
# ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x)
@device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x)

@device_override Base.ldexp(x::Float16, k::Int32) = @builtin_ccall("ldexp", Float16, (Float16, Int32), x, k)
# ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k)
# ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k)
@device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k)
@@ -168,11 +170,13 @@ end
# lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp)
# Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp)

@device_function nan(nancode::UInt16) = @builtin_ccall("nan", Float16, (UInt16,), nancode)
# nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode)
@device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode)
# nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode)
@device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode)

@device_override Base.:(^)(x::Float16, y::Int32) = @builtin_ccall("pown", Float16, (Float16, Int32), x, y)
# pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y)
@device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y)
# pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y)
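
A hedged device-code sketch of the new half-precision paths (kernel context assumed, with `x` and `y` as Float16 device arrays; `ldexp` and `^` dispatch to the Float16 overloads added above):

```julia
function half_math_kernel(y, x)
    i = get_global_id(1)
    @inbounds y[i] = ldexp(x[i], Int32(2)) ^ Int32(3)   # Float16 ldexp, then pown
    return
end
```
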
2 changes: 2 additions & 0 deletions lib/intrinsics/src/utils.jl
@@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...)
"c"
elseif T == UInt8
"h"
elseif T == Float16
"Dh"
elseif T == Float32
"f"
elseif T == Float64
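
For illustration: with this rule, `@builtin_ccall("cos", Float16, (Float16,), x)` mangles the callee to `_Z3cosDh` per the Itanium C++ ABI ("Dh" encodes IEEE half), matching the OpenCL C `half cos(half)` overload.
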
1 change: 1 addition & 0 deletions src/OpenCL.jl
@@ -24,6 +24,7 @@ Base.Experimental.@MethodTable(method_table)
include("device/runtime.jl")
include("device/array.jl")
include("device/quirks.jl")
include("device/random.jl")

# high level implementation
include("memory.jl")
77 changes: 77 additions & 0 deletions src/compiler/compilation.jl
@@ -18,6 +18,83 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
in(fn, known_intrinsics) ||
contains(fn, "__spirv_")

GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState
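
For reference, a minimal sketch of what the state type might look like (the actual definition lives in the package's device runtime; the field name here is an assumption, inferred from the `KernelState(Base.rand(UInt32))` call in src/compiler/execution.jl below):

```julia
struct KernelState
    random_seed::UInt32   # assumed field: host-provided per-launch seed
end
```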

function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
mod::LLVM.Module, entry::LLVM.Function)
entry = invoke(GPUCompiler.finish_module!,
Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
job, mod, entry)

# if this kernel uses our RNG, we should prime the shared state.
# XXX: these transformations should really happen at the Julia IR level...
if haskey(functions(mod), "julia.spirv.random_keys") && job.config.kernel
# insert call to `initialize_rng_state`
f = initialize_rng_state
ft = typeof(f)
tt = Tuple{}

# create a deferred compilation job for `initialize_rng_state`
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
job = CompilerJob(src, cfg, job.world)
id = length(GPUCompiler.deferred_codegen_jobs) + 1
GPUCompiler.deferred_codegen_jobs[id] = job

# generate IR for calls to `deferred_codegen` and the resulting function pointer
top_bb = first(blocks(entry))
bb = BasicBlock(top_bb, "initialize_rng")
@dispose builder=IRBuilder() begin
position!(builder, bb)
subprogram = LLVM.subprogram(entry)
if subprogram !== nothing
loc = DILocation(0, 0, subprogram)
debuglocation!(builder, loc)
end
debuglocation!(builder, first(instructions(top_bb)))

# call the `deferred_codegen` marker function
T_ptr = if LLVM.version() >= v"17"
LLVM.PointerType()
elseif VERSION >= v"1.12.0-DEV.225"
LLVM.PointerType(LLVM.Int8Type())
else
LLVM.Int64Type()
end
T_id = convert(LLVMType, Int)
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
functions(mod)["deferred_codegen"]
else
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
end
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])

# call the `initialize_rng_state` function
rt = Core.Compiler.return_type(f, tt)
llvm_rt = convert(LLVMType, rt)
llvm_ft = LLVM.FunctionType(llvm_rt)
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
call!(builder, llvm_ft, fptr)
br!(builder, top_bb)
end

# XXX: put some of the above behind GPUCompiler abstractions
# (e.g., a compile-time version of `deferred_codegen`)
end
return entry
end
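
In plain Julia terms, the effect of this pass is roughly the following (illustrative stand-ins only; the real `initialize_rng_state` comes from src/device/random.jl and is resolved at link time through the `deferred_codegen` marker):

```julia
initialize_rng_state() = nothing          # stand-in for the device function
original_kernel_body(args...) = nothing   # stand-in for the user's kernel
function kernel_entry(args...)
    initialize_rng_state()                # the prepended "initialize_rng" block
    return original_kernel_body(args...)  # branch to the original entry block
end
```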

function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
for f in GPUCompiler.kernels(mod)
kernel_intrinsics = Dict(
"julia.spirv.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
"julia.spirv.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
)
GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
end
return
end
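
With this rewrite, every kernel gains two trailing workgroup-local `UInt32` pointer arguments. Host-side, the `device_rng` branch in `enqueue_kernel` binds them to dynamically sized local memory, and the `num_args == length(call_args) + 2` check in src/compiler/execution.jl below is what detects such RNG-enabled kernels.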

## compiler implementation (cache, configure, compile, and link)

6 changes: 5 additions & 1 deletion src/compiler/execution.jl
@@ -151,12 +151,16 @@ abstract type AbstractKernel{F, TT} end
end
end

pushfirst!(call_t, KernelState)
pushfirst!(call_args, :(KernelState(Base.rand(UInt32))))

# finalize types
call_tt = Base.to_tuple_type(call_t)

quote
indirect_memory = cl.AbstractMemory[]
clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, call_kwargs...)
device_rng = kernel.fun.num_args == $(length(call_args) + 2)
clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, device_rng, call_kwargs...)
end
end
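
Taken together, a hedged end-to-end sketch of what this PR enables (launch syntax as in OpenCL.jl's existing `@opencl` macro; device-side `rand` is assumed to be provided by the new src/device/random.jl):

```julia
using OpenCL

function rand_kernel(a)
    i = get_global_id(1)
    @inbounds a[i] = rand(Float32)   # backed by the per-sub-group RNG state
    return
end

a = CLArray{Float32}(undef, 256)
@opencl global_size=256 local_size=64 rand_kernel(a)
```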
