Skip to content

Commit eb462aa

Browse files
committed
Add helpers for iterating kernels in a module.
1 parent d440d01 commit eb462aa

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

src/driver.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,7 @@ const __llvm_initialized = Ref(false)
322322
@tracepoint "IR post-processing" begin
323323
# mark the kernel entry-point functions (optimization may need it)
324324
if job.config.kernel
325-
push!(metadata(ir)["julia.kernel"], MDNode([entry]))
326-
327-
# IDEA: save all jobs, not only kernels, and save other attributes
328-
# so that we can reconstruct the CompileJob instead of setting it globally
325+
mark_kernel!(entry)
329326
end
330327

331328
if job.config.toplevel

src/irgen.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -526,18 +526,12 @@ function add_kernel_state!(mod::LLVM.Module)
526526
state_intr = kernel_state_intr(mod, T_state)
527527
state_intr_ft = LLVM.FunctionType(T_state)
528528

529-
kernels = []
530-
kernels_md = metadata(mod)["julia.kernel"]
531-
for kernel_md in operands(kernels_md)
532-
push!(kernels, Value(operands(kernel_md)[1]))
533-
end
534-
535529
# determine which functions need a kernel state argument
536530
#
537531
# previously, we add the argument to every function and relied on unused arg elim to
538532
# clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
539533
# function pointers. such IR is hard to rewrite, so instead be more conservative.
540-
worklist = Set{LLVM.Function}([state_intr, kernels...])
534+
worklist = Set{LLVM.Function}([state_intr, kernels(mod)...])
541535
worklist_length = 0
542536
while worklist_length != length(worklist)
543537
# iteratively discover functions that use the intrinsic or any function calling it

src/metal.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,10 @@ isintrinsic(@nospecialize(job::CompilerJob{MetalCompilerTarget}), fn::String) =
4848
return startswith(fn, "air.")
4949

5050
function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module)
51-
if haskey(metadata(mod), "julia.kernel")
52-
kernels_md = metadata(mod)["julia.kernel"]
53-
for kernel_md in operands(kernels_md)
54-
f = LLVM.Value(operands(kernel_md)[1])::LLVM.Function
55-
56-
# update calling conventions
57-
f = pass_by_reference!(job, mod, f)
58-
f = add_input_arguments!(job, mod, f, kernel_intrinsics)
59-
end
51+
for f in kernels(mod)
52+
# update calling conventions
53+
f = pass_by_reference!(job, mod, f)
54+
f = add_input_arguments!(job, mod, f, kernel_intrinsics)
6055
end
6156

6257
# emit the AIR and Metal version numbers as constants in the module. this makes it

src/utils.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,30 @@ function prune_constexpr_uses!(root::LLVM.Value)
155155
end
156156
end
157157
end
158+
159+
160+
## kernel metadata handling
161+
162+
# kernels are encoded in the IR using the julia.kernel metadata.
163+
164+
# IDEA: don't only mark kernels, but all jobs, and save all attributes of the CompileJob
165+
# so that we can reconstruct the CompileJob instead of setting it globally
166+
167+
# mark a function as kernel
168+
function mark_kernel!(f::LLVM.Function)
169+
mod = LLVM.parent(f)
170+
push!(metadata(mod)["julia.kernel"], MDNode([f]))
171+
return f
172+
end
173+
174+
# iterate over all kernels in the module
175+
function kernels(mod::LLVM.Module)
176+
vals = LLVM.Function[]
177+
if haskey(metadata(mod), "julia.kernel")
178+
kernels_md = metadata(mod)["julia.kernel"]
179+
for kernel_md in operands(kernels_md)
180+
push!(vals, LLVM.Value(operands(kernel_md)[1]))
181+
end
182+
end
183+
return vals
184+
end

0 commit comments

Comments
 (0)