From 809940de1c173d97dcc3681604c601c970164694 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 17 Sep 2024 15:49:33 +0200 Subject: [PATCH 1/8] add ir_check_ignore metadata to check ir type check --- src/validation.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/validation.jl b/src/validation.jl index 93e18950..a46fd19a 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -322,7 +322,8 @@ function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) errors = IRError[] for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst)) + if value_type(inst) == T_bad && !haskey(metadata(inst), "ir_check_ignore") || + any(op -> value_type(op) == T_bad && !(op isa Instruction && haskey(metadata(op), "ir_check_ignore")), operands(inst)) bt = backtrace(inst) push!(errors, ("use of $(string(T_bad)) value", bt, inst)) end From b5e13ac620a3baff37ee74d2bfbc574cbe31fa05 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Wed, 18 Sep 2024 17:40:11 +0200 Subject: [PATCH 2/8] allow fpext for metal --- src/metal.jl | 2 +- src/validation.jl | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/metal.jl b/src/metal.jl index d191ab47..b38e2c03 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -90,7 +90,7 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) errors = IRError[] # Metal never supports double precision - append!(errors, check_ir_values(mod, LLVM.DoubleType())) + append!(errors, check_ir_values(mod, LLVM.DoubleType(), allow=(LLVM.FPExtInst,))) append!(errors, check_ir_values(mod, LLVM.IntType(128))) errors diff --git a/src/validation.jl b/src/validation.jl index a46fd19a..915b8bca 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -318,15 +318,25 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) end # helper function to check if a LLVM module uses values of a certain type -function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) +function check_ir_values(mod::LLVM.Module, T_bad::LLVMType; allow=()) errors = IRError[] for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - if value_type(inst) == T_bad && !haskey(metadata(inst), "ir_check_ignore") || - any(op -> value_type(op) == T_bad && !(op isa Instruction && haskey(metadata(op), "ir_check_ignore")), operands(inst)) - bt = backtrace(inst) - push!(errors, ("use of $(string(T_bad)) value", bt, inst)) + if typeof(inst) in allow + continue end + + if haskey(metadata(inst), "ir_check_ignore") + continue + end + + if value_type(inst) != T_bad && + all(op -> value_type(op) != T_bad || (op isa Instruction && haskey(metadata(op), "ir_check_ignore")), operands(inst)) + continue + end + + bt = backtrace(inst) + push!(errors, ("use of $(string(T_bad)) value", bt, inst)) end return errors From fe80530e2364c3888a621e28be716f855c72e31f Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 20 Sep 2024 13:34:52 +0200 Subject: [PATCH 3/8] Revert "allow fpext for metal" This reverts commit b5e13ac620a3baff37ee74d2bfbc574cbe31fa05. --- src/metal.jl | 2 +- src/validation.jl | 20 +++++--------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/metal.jl b/src/metal.jl index b38e2c03..d191ab47 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -90,7 +90,7 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) errors = IRError[] # Metal never supports double precision - append!(errors, check_ir_values(mod, LLVM.DoubleType(), allow=(LLVM.FPExtInst,))) + append!(errors, check_ir_values(mod, LLVM.DoubleType())) append!(errors, check_ir_values(mod, LLVM.IntType(128))) errors diff --git a/src/validation.jl b/src/validation.jl index 915b8bca..a46fd19a 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -318,25 +318,15 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) end # helper function to check if a LLVM module uses values of a certain type -function check_ir_values(mod::LLVM.Module, T_bad::LLVMType; allow=()) +function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) errors = IRError[] for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - if typeof(inst) in allow - continue + if value_type(inst) == T_bad && !haskey(metadata(inst), "ir_check_ignore") || + any(op -> value_type(op) == T_bad && !(op isa Instruction && haskey(metadata(op), "ir_check_ignore")), operands(inst)) + bt = backtrace(inst) + push!(errors, ("use of $(string(T_bad)) value", bt, inst)) end - - if haskey(metadata(inst), "ir_check_ignore") - continue - end - - if value_type(inst) != T_bad && - all(op -> value_type(op) != T_bad || (op isa Instruction && haskey(metadata(op), "ir_check_ignore")), operands(inst)) - continue - end - - bt = backtrace(inst) - push!(errors, ("use of $(string(T_bad)) value", bt, inst)) end return errors From 9a2535eafcad1e946fd6468bbdc256f5301fb97d Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 20 Sep 2024 13:34:55 +0200 Subject: [PATCH 4/8] Revert "add ir_check_ignore metadata to check ir type check" This reverts commit 809940de1c173d97dcc3681604c601c970164694. --- src/validation.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/validation.jl b/src/validation.jl index a46fd19a..93e18950 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -322,8 +322,7 @@ function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) errors = IRError[] for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - if value_type(inst) == T_bad && !haskey(metadata(inst), "ir_check_ignore") || - any(op -> value_type(op) == T_bad && !(op isa Instruction && haskey(metadata(op), "ir_check_ignore")), operands(inst)) + if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst)) bt = backtrace(inst) push!(errors, ("use of $(string(T_bad)) value", bt, inst)) end From 9ed1583c19977590754fa0b48783a1d0bd0ec371 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 20 Sep 2024 14:57:09 +0200 Subject: [PATCH 5/8] use callback to check for doubles used in logging --- src/metal.jl | 26 +++++++++++++++++++++++++- src/validation.jl | 13 +++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/metal.jl b/src/metal.jl index d191ab47..17fac155 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -89,8 +89,32 @@ end function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) errors = IRError[] + function is_illegal_double_use(val) + if value_type(val) != LLVM.DoubleType() + return false + end + + function used_for_logging(use::LLVM.Use) + usr = user(use) + if usr isa LLVM.CallInst + callee = called_operand(usr) + if callee isa LLVM.Function && startswith(name(callee), "metal_os_log") + return true + end + end + + return false + end + + if all(used_for_logging, uses(val)) + return false + end + + return true + end + # Metal never supports double precision - append!(errors, check_ir_values(mod, LLVM.DoubleType())) + append!(errors, check_ir_values(mod, is_illegal_double_use)) append!(errors, check_ir_values(mod, LLVM.IntType(128))) errors diff --git a/src/validation.jl b/src/validation.jl index 93e18950..1705e9ce 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -330,3 +330,16 @@ function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) return errors end + +function check_ir_values(mod::LLVM.Module, T_bad) + errors = IRError[] + + for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) + if T_bad(inst) || any(T_bad, operands(inst)) + bt = backtrace(inst) + push!(errors, ("use of $(string(inst))", bt, inst)) + end + end + + return errors +end From 2cc6241740c14d77b19aa18dff2a996358c1f912 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 20 Sep 2024 19:44:41 +0200 Subject: [PATCH 6/8] improve callback --- src/metal.jl | 18 +++++++++++------- src/validation.jl | 21 ++++++++------------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/metal.jl b/src/metal.jl index 17fac155..e5fc17f2 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -89,9 +89,11 @@ end function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) errors = IRError[] - function is_illegal_double_use(val) - if value_type(val) != LLVM.DoubleType() - return false + function is_valid_double_use(inst::LLVM.Instruction, errors) + T_bad = LLVM.DoubleType() + + if value_type(inst) != T_bad || all(param->value_type(param) != T_bad, operands(inst)) + return end function used_for_logging(use::LLVM.Use) @@ -106,15 +108,17 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) return false end - if all(used_for_logging, uses(val)) - return false + if all(used_for_logging, uses(inst)) + return end - return true + bt = backtrace(inst) + err = ("use of double value", bt, inst) + push!(errors, err) end # Metal never supports double precision - append!(errors, check_ir_values(mod, is_illegal_double_use)) + append!(errors, check_ir_values(mod, is_valid_double_use)) append!(errors, check_ir_values(mod, LLVM.IntType(128))) errors diff --git a/src/validation.jl b/src/validation.jl index 1705e9ce..3d3986af 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -318,27 +318,22 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) end # helper function to check if a LLVM module uses values of a certain type -function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) - errors = IRError[] - for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst)) - bt = backtrace(inst) - push!(errors, ("use of $(string(T_bad)) value", bt, inst)) - end +function check_illegal_value_type(inst::LLVM.Instruction, errors, T_bad::LLVMType) + if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst)) + bt = backtrace(inst) + err = ("use of $(string(T_bad)) value", bt, inst) + push!(errors, err) end - - return errors end +check_ir_values(mod::LLVM.Module, T_bad::LLVMType) = check_ir_values(mod, (x,errs)->check_illegal_value_type(x, errs, T_bad)) + function check_ir_values(mod::LLVM.Module, T_bad) errors = IRError[] for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - if T_bad(inst) || any(T_bad, operands(inst)) - bt = backtrace(inst) - push!(errors, ("use of $(string(inst))", bt, inst)) - end + T_bad(inst, errors) end return errors From 8c6e38cb05e48675a2b44c2d00dda3a0fb76e281 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 20 Sep 2024 19:44:46 +0200 Subject: [PATCH 7/8] add tests --- test/metal_tests.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/test/metal_tests.jl b/test/metal_tests.jl index 6a6efd5b..de97d90b 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -79,6 +79,53 @@ end @test occursin("air.max.s.v2i64", ir) end +@testset "unsupported type detection" begin + function kernel1(ptr) + buf = reinterpret(Ptr{Float32}, ptr) + val = unsafe_load(buf) + dval = Cdouble(val) + # ccall("extern metal_os_log", llvmcall, Nothing, (Float64,), dval) + Base.llvmcall((""" + declare void @llvm.va_start(i8*) + declare void @llvm.va_end(i8*) + declare void @air.os_log(i8*, i64) + + define void @metal_os_log(...) { + %1 = alloca i8* + %2 = bitcast i8** %1 to i8* + call void @llvm.va_start(i8* %2) + %3 = load i8*, i8** %1 + call void @air.os_log(i8* %3, i64 8) + call void @llvm.va_end(i8* %2) + ret void + } + + define void @entry(double %val) #0 { + call void (...) @metal_os_log(double %val) + ret void + } + + attributes #0 = { alwaysinline }""", "entry"), + Nothing, Tuple{Float64}, dval) + return + end + + + ir = sprint(io->Metal.code_llvm(io, kernel1, Tuple{Core.LLVMPtr{Float32,1}}; validate=true)) + @test occursin("@metal_os_log", ir) + + function kernel2(ptr) + val = unsafe_load(ptr) + res = val * val + unsafe_store!(ptr, res) + return + end + + @test_throws_message(InvalidIRError, Metal.code_llvm(devnull, kernel2, Tuple{Core.LLVMPtr{Float64,1}}; validate=true)) do msg + occursin("unsupported use of double value", msg) + end +end + end end From b77553c502cf836b8bae0dee56c9291f8ee35cfa Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Sat, 21 Sep 2024 12:53:21 +0200 Subject: [PATCH 8/8] NFC clean-up. --- src/metal.jl | 22 +++++++++------------- src/validation.jl | 26 ++++++++++---------------- 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/src/metal.jl b/src/metal.jl index e5fc17f2..e3978071 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -89,11 +89,11 @@ end function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) errors = IRError[] - function is_valid_double_use(inst::LLVM.Instruction, errors) + # Metal does not support double precision, except for logging + function is_illegal_double(val) T_bad = LLVM.DoubleType() - - if value_type(inst) != T_bad || all(param->value_type(param) != T_bad, operands(inst)) - return + if value_type(val) != T_bad + return false end function used_for_logging(use::LLVM.Use) @@ -104,21 +104,17 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) return true end end - return false end - - if all(used_for_logging, uses(inst)) - return + if all(used_for_logging, uses(val)) + return false end - bt = backtrace(inst) - err = ("use of double value", bt, inst) - push!(errors, err) + return true end + append!(errors, check_ir_values(mod, is_illegal_double, "use of double value")) - # Metal never supports double precision - append!(errors, check_ir_values(mod, is_valid_double_use)) + # Metal never supports 128-bit integers append!(errors, check_ir_values(mod, LLVM.IntType(128))) errors diff --git a/src/validation.jl b/src/validation.jl index 3d3986af..104cf19b 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -317,24 +317,18 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst) return errors end -# helper function to check if a LLVM module uses values of a certain type - -function check_illegal_value_type(inst::LLVM.Instruction, errors, T_bad::LLVMType) - if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst)) - bt = backtrace(inst) - err = ("use of $(string(T_bad)) value", bt, inst) - push!(errors, err) - end -end - -check_ir_values(mod::LLVM.Module, T_bad::LLVMType) = check_ir_values(mod, (x,errs)->check_illegal_value_type(x, errs, T_bad)) - -function check_ir_values(mod::LLVM.Module, T_bad) +# helper function to check for illegal values in an LLVM module +function check_ir_values(mod::LLVM.Module, predicate, msg="value") errors = IRError[] - for fun in functions(mod), bb in blocks(fun), inst in instructions(bb) - T_bad(inst, errors) + if predicate(inst) || any(predicate, operands(inst)) + bt = backtrace(inst) + push!(errors, (msg, bt, inst)) + end end - return errors end +## shorthand to check for illegal value types +function check_ir_values(mod::LLVM.Module, T_bad::LLVMType) + check_ir_values(mod, val -> value_type(val) == T_bad, "use of $(string(T_bad)) value") +end