Skip to content

Commit 6bae13e

Browse files
tgymnichmaleadt
andauthored
Generalize check_ir_values (#630)
Use that to have Metal support Float64 in the context of logging. Co-authored-by: Tim Besard <[email protected]>
1 parent bbd6124 commit 6bae13e

File tree

3 files changed

+81
-8
lines changed

3 files changed

+81
-8
lines changed

src/metal.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,32 @@ end
8989
function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
9090
errors = IRError[]
9191

92-
# Metal never supports double precision
93-
append!(errors, check_ir_values(mod, LLVM.DoubleType()))
92+
# Metal does not support double precision, except for logging
93+
function is_illegal_double(val)
94+
T_bad = LLVM.DoubleType()
95+
if value_type(val) != T_bad
96+
return false
97+
end
98+
99+
function used_for_logging(use::LLVM.Use)
100+
usr = user(use)
101+
if usr isa LLVM.CallInst
102+
callee = called_operand(usr)
103+
if callee isa LLVM.Function && startswith(name(callee), "metal_os_log")
104+
return true
105+
end
106+
end
107+
return false
108+
end
109+
if all(used_for_logging, uses(val))
110+
return false
111+
end
112+
113+
return true
114+
end
115+
append!(errors, check_ir_values(mod, is_illegal_double, "use of double value"))
116+
117+
# Metal never supports 128-bit integers
94118
append!(errors, check_ir_values(mod, LLVM.IntType(128)))
95119

96120
errors

src/validation.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,16 +317,18 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
317317
return errors
318318
end
319319

320-
# helper function to check if a LLVM module uses values of a certain type
321-
function check_ir_values(mod::LLVM.Module, T_bad::LLVMType)
320+
# helper function to check for illegal values in an LLVM module
321+
function check_ir_values(mod::LLVM.Module, predicate, msg="value")
322322
errors = IRError[]
323-
324323
for fun in functions(mod), bb in blocks(fun), inst in instructions(bb)
325-
if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst))
324+
if predicate(inst) || any(predicate, operands(inst))
326325
bt = backtrace(inst)
327-
push!(errors, ("use of $(string(T_bad)) value", bt, inst))
326+
push!(errors, (msg, bt, inst))
328327
end
329328
end
330-
331329
return errors
332330
end
331+
## shorthand to check for illegal value types
332+
function check_ir_values(mod::LLVM.Module, T_bad::LLVMType)
333+
check_ir_values(mod, val -> value_type(val) == T_bad, "use of $(string(T_bad)) value")
334+
end

test/metal_tests.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,53 @@ end
7979
@test occursin("air.max.s.v2i64", ir)
8080
end
8181

82+
@testset "unsupported type detection" begin
83+
function kernel1(ptr)
84+
buf = reinterpret(Ptr{Float32}, ptr)
85+
val = unsafe_load(buf)
86+
dval = Cdouble(val)
87+
# ccall("extern metal_os_log", llvmcall, Nothing, (Float64,), dval)
88+
Base.llvmcall(("""
89+
declare void @llvm.va_start(i8*)
90+
declare void @llvm.va_end(i8*)
91+
declare void @air.os_log(i8*, i64)
92+
93+
define void @metal_os_log(...) {
94+
%1 = alloca i8*
95+
%2 = bitcast i8** %1 to i8*
96+
call void @llvm.va_start(i8* %2)
97+
%3 = load i8*, i8** %1
98+
call void @air.os_log(i8* %3, i64 8)
99+
call void @llvm.va_end(i8* %2)
100+
ret void
101+
}
102+
103+
define void @entry(double %val) #0 {
104+
call void (...) @metal_os_log(double %val)
105+
ret void
106+
}
107+
108+
attributes #0 = { alwaysinline }""", "entry"),
109+
Nothing, Tuple{Float64}, dval)
110+
return
111+
end
112+
113+
114+
ir = sprint(io->Metal.code_llvm(io, kernel1, Tuple{Core.LLVMPtr{Float32,1}}; validate=true))
115+
@test occursin("@metal_os_log", ir)
116+
117+
function kernel2(ptr)
118+
val = unsafe_load(ptr)
119+
res = val * val
120+
unsafe_store!(ptr, res)
121+
return
122+
end
123+
124+
@test_throws_message(InvalidIRError, Metal.code_llvm(devnull, kernel2, Tuple{Core.LLVMPtr{Float64,1}}; validate=true)) do msg
125+
occursin("unsupported use of double value", msg)
126+
end
127+
end
128+
82129
end
83130

84131
end

0 commit comments

Comments
 (0)