@@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
53
53
# update calling conventions
54
54
if job. config. kernel
55
55
entry = pass_by_reference! (job, mod, entry)
56
-
57
- add_input_arguments! (job, mod, entry)
58
- entry = LLVM. functions (mod)[entry_fn]
56
+ entry = add_input_arguments! (job, mod, entry, kernel_intrinsics)
59
57
end
60
58
61
59
# emit the AIR and Metal version numbers as constants in the module. this makes it
@@ -553,164 +551,6 @@ function argument_type_name(typ)
553
551
end
554
552
end
555
553
556
"""
    add_input_arguments!(job, mod, entry, intrinsics=kernel_intrinsics) -> LLVM.Function

Turn calls to kernel intrinsics in `mod` into extra trailing function arguments.

Every function name in `keys(intrinsics)` that exists in `mod` is considered "used". The
entry function, and every function that (transitively) calls a used intrinsic or calls
such a function, is re-created with one additional trailing parameter per used intrinsic
(of type `intrinsics[name].typ`, named `intrinsics[name].name`). Call sites of the
rewritten functions are updated to forward those parameters, calls to the intrinsics are
replaced by the matching parameter, and the intrinsic declarations are erased.

# Arguments
- `job`: the compiler job; not inspected by this function.
- `mod`: the module to rewrite in place.
- `entry`: the kernel entry function; it is always rewritten.
- `intrinsics`: mapping from intrinsic function name to a descriptor with `typ` and `name`
  fields. Defaults to the file-level `kernel_intrinsics` table.

# Returns
The function in `mod` that now carries the entry function's original name. The function
object passed in as `entry` is erased, so callers must use the returned value.

# Throws
- An error if a rewritten function is used by an `invoke` or `callbr` instruction
  (forwarding arguments through those is not implemented).
- An error if a rewritten function or an intrinsic has a use that is neither a call-like
  instruction nor (for rewritten functions) a bitcast constant expression.
"""
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function, intrinsics=kernel_intrinsics)
    # remember the name: `entry` itself is replaced below and looked up again at the end
    entry_fn = LLVM.name(entry)

    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it;
        # stop once a full pass adds nothing new (fixed point)
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        for f in additions
            push!(worklist, f)
        end
    end
    # the intrinsics only seeded the search; they are replaced, not given extra arguments
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the original name so the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function whose signature is the old one plus one trailing
        # parameter per used intrinsic
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))

        # carry over the original parameter names, and name the added ones
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the original arguments onto the leading arguments of the replacement
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        # map the function onto its replacement as well
        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as its remaining uses still need to be
        # rewritten below, but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing this call site; its trailing `nargs`
                    # parameters are the values to forward
                    callee_f = LLVM.parent(LLVM.parent(val))
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        call!(builder, function_type(new_f), new_f,
                              [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments); recurse into the users of the bitcast instead
                    rewrite_uses!(val, new_val)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # only look up the enclosing function once we know `val` is an
                # instruction, so other users reach the error below
                callee_f = LLVM.parent(LLVM.parent(val))
                # the i-th used intrinsic maps onto the i-th of the trailing parameters
                replace_uses!(val, parameters(callee_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    # the original `entry` object has been erased; hand back its replacement
    return functions(mod)[entry_fn]
end
712
-
713
-
714
554
# argument metadata generation
715
555
#
716
556
# module metadata is used to identify buffers that are passed as kernel arguments.
0 commit comments