diff --git a/src/metal.jl b/src/metal.jl index d191ab47..e3978071 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -89,8 +89,32 @@ end function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) errors = IRError[] - # Metal never supports double precision - append!(errors, check_ir_values(mod, LLVM.DoubleType())) + # Metal does not support double precision, except for logging + function is_illegal_double(val) + T_bad = LLVM.DoubleType() + if value_type(val) != T_bad + return false + end + + function used_for_logging(use::LLVM.Use) + usr = user(use) + if usr isa LLVM.CallInst + callee = called_operand(usr) + if callee isa LLVM.Function && startswith(name(callee), "metal_os_log") + return true + end + end + return false + end + if all(used_for_logging, uses(val)) + return false + end + + return true + end + append!(errors, check_ir_values(mod, is_illegal_double, "use of double value")) + + # Metal never supports 128-bit integers append!(errors, check_ir_values(mod, LLVM.IntType(128))) errors diff --git a/src/validation.jl b/src/validation.jl index 93e18950..104cf19b 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -317,16 +317,18 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) return errors end -# helper function to check if a LLVM module uses values of a certain type -function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) +# helper function to check for illegal values in an LLVM module +function check_ir_values(mod::LLVM.Module, predicate, msg="value") errors = IRError[] - for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst)) + if predicate(inst) || any(predicate, operands(inst)) bt = backtrace(inst) - push!(errors, ("use of $(string(T_bad)) value", bt, inst)) + push!(errors, (msg, bt, inst)) end end - return errors end +## shorthand to check for illegal value types +function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) + check_ir_values(mod, val -> value_type(val) == T_bad, "use of $(string(T_bad)) value") +end diff --git a/test/metal_tests.jl b/test/metal_tests.jl index 6a6efd5b..de97d90b 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -79,6 +79,53 @@ end @test occursin("air.max.s.v2i64", ir) end +@testset "unsupported type detection" begin + function kernel1(ptr) + buf = reinterpret(Ptr{Float32}, ptr) + val = unsafe_load(buf) + dval = Cdouble(val) + # ccall("extern metal_os_log", llvmcall, Nothing, (Float64,), dval) + Base.llvmcall((""" + declare void @llvm.va_start(i8*) + declare void @llvm.va_end(i8*) + declare void @air.os_log(i8*, i64) + + define void @metal_os_log(...) { + %1 = alloca i8* + %2 = bitcast i8** %1 to i8* + call void @llvm.va_start(i8* %2) + %3 = load i8*, i8** %1 + call void @air.os_log(i8* %3, i64 8) + call void @llvm.va_end(i8* %2) + ret void + } + + define void @entry(double %val) #0 { + call void (...) @metal_os_log(double %val) + ret void + } + + attributes #0 = { alwaysinline }""", "entry"), + Nothing, Tuple{Float64}, dval) + return + end + + + ir = sprint(io->Metal.code_llvm(io, kernel1, Tuple{Core.LLVMPtr{Float32,1}}; validate=true)) + @test occursin("@metal_os_log", ir) + + function kernel2(ptr) + val = unsafe_load(ptr) + res = val * val + unsafe_store!(ptr, res) + return + end + + @test_throws_message(InvalidIRError, Metal.code_llvm(devnull, kernel2, Tuple{Core.LLVMPtr{Float64,1}}; validate=true)) do msg + occursin("unsupported use of double value", msg) + end +end + end end