Skip to content

Commit 301f389

Browse files
simeonschaub authored and maleadt committed
add_input_arguments! for other backends (#718)
Allows other backends to pass additional hidden arguments that can be accessed through intrinsics. Required for OpenCL device-side RNG support, where additional shared memory must be passed as arguments to the kernel.
1 parent 185d06d commit 301f389

File tree

2 files changed

+158
-161
lines changed

2 files changed

+158
-161
lines changed

src/irgen.jl

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,3 +921,160 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
921921
return new_f
922922
end
923923
end
924+
925+
"""
    add_input_arguments!(job, mod, entry, kernel_intrinsics) -> LLVM.Function

Turn the intrinsics named by the keys of `kernel_intrinsics` into additional trailing
function arguments, so that a backend can pass hidden inputs to a kernel.

Only intrinsics that are actually present in `mod` are considered. Every function that
(transitively) calls such an intrinsic, as well as `entry` and its callers, is replaced by
a clone that takes one extra argument per used intrinsic; call sites are rewritten to
forward those arguments, and calls to the intrinsics are replaced by the matching argument.

# Arguments
- `job`: the compiler job (not inspected here).
- `mod`: the LLVM module to rewrite in place.
- `entry`: the kernel entry point.
- `kernel_intrinsics`: maps an intrinsic function name to a description with a `typ`
  field (converted to an `LLVMType` for the new parameter) and a `name` field (used as the
  name of the new parameter).

# Returns
The function in `mod` that carries the original name of `entry`.

# Throws
An error if a function is used by something other than a `call` instruction or a
bitcast constant expression (`invoke` and `callbr` are not implemented).
"""
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function, kernel_intrinsics::Dict)
    entry_fn = LLVM.name(entry)

    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # without any used intrinsic every signature would stay the same, so there is nothing
    # to clone or rewrite
    nargs == 0 && return entry

    # resolve the intrinsic declarations and their argument types once; neither depends
    # on the function being rewritten. the declarations themselves are never cloned, so
    # these handles stay valid until we erase them at the very end.
    intr_funcs = LLVM.Function[functions(mod)[intr_fn] for intr_fn in used_intrinsics]
    intr_types = LLVMType[convert(LLVMType, kernel_intrinsics[intr_fn].typ)
                          for intr_fn in used_intrinsics]

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    union!(worklist, intr_funcs)
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            caller_f = LLVM.parent(LLVM.parent(inst))
            in(caller_f, worklist) || push!(additions, caller_f)
        end
        union!(worklist, additions)
    end
    # the intrinsics only seeded the search; they are not rewritten themselves
    setdiff!(worklist, intr_funcs)

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the original name for the replacement function
        LLVM.name!(f, fn * ".orig")

        # create a new function with the intrinsic arguments appended
        new_param_types = LLVMType[parameters(ft)..., intr_types...]
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (the new parameters have already been named above)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as we might still need to rewrite any called,
        # but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        # update uses
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing the call site; its trailing parameters are
                    # the intrinsic arguments that need to be forwarded
                    caller_f = LLVM.parent(LLVM.parent(val))
                    # forward the arguments
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        call!(builder, function_type(new_f), new_f,
                              [arguments(val)..., parameters(caller_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    rewrite_uses!(val, new_val)
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the input arguments)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr) in enumerate(intr_funcs)
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # only look up the enclosing function once we know `val` is an instruction,
                # so that other kinds of users reach the descriptive error below
                caller_f = LLVM.parent(LLVM.parent(val))
                replace_uses!(val, parameters(caller_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    return functions(mod)[entry_fn]
end

src/metal.jl

Lines changed: 1 addition & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
5353
# update calling conventions
5454
if job.config.kernel
5555
entry = pass_by_reference!(job, mod, entry)
56-
57-
add_input_arguments!(job, mod, entry)
58-
entry = LLVM.functions(mod)[entry_fn]
56+
entry = add_input_arguments!(job, mod, entry, kernel_intrinsics)
5957
end
6058

6159
# emit the AIR and Metal version numbers as constants in the module. this makes it
@@ -553,164 +551,6 @@ function argument_type_name(typ)
553551
end
554552
end
555553

556-
"""
    add_input_arguments!(job, mod, entry)

Rewrite `mod` in place so that the intrinsics listed in the global `kernel_intrinsics`
table are passed as extra trailing arguments instead of being called.

Every function that (transitively) calls a used intrinsic, together with `entry` and its
callers, is replaced by a clone with one extra parameter per used intrinsic. Call sites
forward these parameters, and calls to the intrinsics are replaced by the matching
parameter. Returns `nothing`; the rewritten entry point keeps the original name of `entry`.
"""
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function)
    entry_fn = LLVM.name(entry)

    # names of the intrinsics that are present in this module
    used_intrinsics = collect(filter(name -> haskey(functions(mod), name),
                                     keys(kernel_intrinsics)))
    nargs = length(used_intrinsics)

    # seed the set of affected functions with the entry point and the intrinsics, then
    # keep adding callers until no new function shows up
    worklist = Set{LLVM.Function}([entry])
    for name in used_intrinsics
        push!(worklist, functions(mod)[name])
    end
    while true
        discovered = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            caller = LLVM.parent(LLVM.parent(inst))
            caller in worklist || push!(discovered, caller)
        end
        isempty(discovered) && break
        union!(worklist, discovered)
    end
    # the intrinsics themselves are not rewritten
    foreach(name -> delete!(worklist, functions(mod)[name]), used_intrinsics)

    # create a replacement with extended signature for each affected function
    # NOTE: all intrinsics are added everywhere; unused args get removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for old_f in worklist
        old_name = LLVM.name(old_f)
        old_ft = function_type(old_f)
        # move the old function out of the way so the replacement can take its name
        LLVM.name!(old_f, old_name * ".orig")

        param_types = LLVMType[parameters(old_ft)...]
        append!(param_types,
                (convert(LLVMType, kernel_intrinsics[name].typ) for name in used_intrinsics))
        replacement = LLVM.Function(mod, old_name,
                                    LLVM.FunctionType(return_type(old_ft), param_types))
        linkage!(replacement, linkage(old_f))

        # carry over the existing parameter names, and name the appended ones
        new_params = parameters(replacement)
        for (old_param, new_param) in zip(parameters(old_f), new_params)
            LLVM.name!(new_param, LLVM.name(old_param))
        end
        for (name, new_param) in zip(used_intrinsics, new_params[end-nargs+1:end])
            LLVM.name!(new_param, kernel_intrinsics[name].name)
        end

        workmap[old_f] = replacement
    end

    # copy the bodies over; as the arguments are added last, only the existing
    # parameters and the function itself need to be remapped
    for (old_f, replacement) in workmap
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (old_param, new_param) in zip(parameters(old_f), parameters(replacement))
            LLVM.name!(new_param, LLVM.name(old_param))
            value_map[old_param] = new_param
        end
        value_map[old_f] = replacement

        clone_into!(replacement, old_f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # the old function may still be referenced by call sites that need rewriting,
        # so only drop its IR for now
        empty!(old_f)
    end

    # get rid of unused constants that may still refer to the old functions
    # XXX: can we do this differently?
    foreach(prune_constexpr_uses!, worklist)

    # redirect the remaining uses of `old` to `new`, passing the input arguments along
    function rewrite_uses!(old, new)
        @dispose builder=IRBuilder() begin
            for use in uses(old)
                val = user(use)
                if val isa LLVM.CallInst
                    # function containing this call; its last `nargs` parameters are
                    # the input arguments to forward
                    caller = LLVM.parent(LLVM.parent(val))
                    position!(builder, val)
                    forwarded = [arguments(val)..., parameters(caller)[end-nargs+1:end]...]
                    new_call = call!(builder, function_type(new), new, forwarded,
                                     operand_bundles(val))
                    callconv!(new_call, callconv(val))

                    replace_uses!(val, new_call)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # TODO: invoke and callbr
                    error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == old
                    # the constant can't simply be replaced: it may be used as the callee
                    # of a call, which then needs to pass the input arguments as well
                    rewrite_uses!(val, LLVM.const_bitcast(new, value_type(val)))

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    isempty(uses(val)) && LLVM.unsafe_destroy!(val)
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (old_f, replacement) in workmap
        rewrite_uses!(old_f, replacement)
        @assert isempty(uses(old_f))
        erase!(old_f)
    end

    # finally, substitute the input arguments for the calls to the intrinsics
    for (i, name) in enumerate(used_intrinsics)
        intrinsic = functions(mod)[name]
        for use in uses(intrinsic)
            val = user(use)
            caller = LLVM.parent(LLVM.parent(val))
            if !(val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst)
                error("Cannot rewrite unknown use of function: $val")
            end
            replace_uses!(val, parameters(caller)[end-nargs+i])

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intrinsic))
        erase!(intrinsic)
    end

    return nothing
end
712-
713-
714554
# argument metadata generation
715555
#
716556
# module metadata is used to identify buffers that are passed as kernel arguments.

0 commit comments

Comments
 (0)