From 365c90614dd1c11b76233b7c92d3f650975dd43a Mon Sep 17 00:00:00 2001 From: osmanyasar05 Date: Wed, 8 Oct 2025 11:57:53 +0100 Subject: [PATCH] feat: add select_of_ext_trunc --- .../LLVMRiscV/Pipeline/Combiners.lean | 79 ++++++++++++++++++- SSA/Projects/LLVMRiscV/Tests/Combiners.lean | 37 +++++++++ 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/SSA/Projects/LLVMRiscV/Pipeline/Combiners.lean b/SSA/Projects/LLVMRiscV/Pipeline/Combiners.lean index 00140e935e..c380f54000 100644 --- a/SSA/Projects/LLVMRiscV/Pipeline/Combiners.lean +++ b/SSA/Projects/LLVMRiscV/Pipeline/Combiners.lean @@ -901,6 +901,77 @@ def mul_by_neg_one_const : LLVMPeepholeRewriteRefine 64 [Ty.llvm (.bitvec 64)] w def mul_by_neg_one : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) := [⟨_, mul_by_neg_one_const⟩] +/-! ### select_of_zext -/ + +/- +Test the rewrite: + fold zext(select(cond, true_val, false_val)) -> select(cond, zext(true_val), zext(false_val)) +-/ +def select_of_zext_rw : LLVMPeepholeRewriteRefine 64 [Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 1)] where + lhs := [LV| { + ^entry (%cond: i1, %true_val: i32, %false_val: i32): + %0 = llvm.select %cond, %true_val, %false_val : i32 + %1 = llvm.zext %0 : i32 to i64 + llvm.return %1 : i64 + }] + rhs := [LV| { + ^entry (%cond: i1, %true_val: i32, %false_val: i32): + %0 = llvm.zext %true_val : i32 to i64 + %1 = llvm.zext %false_val : i32 to i64 + %2 = llvm.select %cond, %0, %1 : i64 + llvm.return %2 : i64 + }] + +def select_of_zext : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) := + [⟨_, select_of_zext_rw⟩] + +/-! ### select_of_anyext -/ + +/- +Test the rewrite: + fold sext(select(cond, true_val, false_val)) -> select(cond, sext(true_val), sext(false_val)) +-/ +def select_of_anyext_rw : LLVMPeepholeRewriteRefine 64 [Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 1)] where + lhs := [LV| { + ^entry (%cond: i1, %true_val: i32, %false_val: i32): + %0 = llvm.select %cond, %true_val, %false_val : i32 + %1 = llvm.sext %0 : i32 to i64 + llvm.return %1 : i64 + }] + rhs := [LV| { + ^entry (%cond: i1, %true_val: i32, %false_val: i32): + %0 = llvm.sext %true_val : i32 to i64 + %1 = llvm.sext %false_val : i32 to i64 + %2 = llvm.select %cond, %0, %1 : i64 + llvm.return %2 : i64 + }] + +def select_of_anyext : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) := + [⟨_, select_of_anyext_rw⟩] + +/-! ### select_of_truncate -/ + +/- +Test the rewrite: + fold trunc(select(cond, true_val, false_val)) -> select(cond, trunc(true_val), trunc(false_val)) +-/ +def select_of_truncate_rw : LLVMPeepholeRewriteRefine 32 [Ty.llvm (.bitvec 64), Ty.llvm (.bitvec 64), Ty.llvm (.bitvec 1)] where + lhs := [LV| { + ^entry (%cond: i1, %true_val: i64, %false_val: i64): + %0 = llvm.select %cond, %true_val, %false_val : i64 + %1 = llvm.trunc %0 : i64 to i32 + llvm.return %1 : i32 + }] + rhs := [LV| { + ^entry (%cond: i1, %true_val: i64, %false_val: i64): + %0 = llvm.trunc %true_val : i64 to i32 + %1 = llvm.trunc %false_val : i64 to i32 + %2 = llvm.select %cond, %0, %1 : i32 + llvm.return %2 : i32 + }] + +def select_of_truncate : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) := + [⟨_, select_of_truncate_rw⟩] /-- ### commute_constant_to_rhs (C op x) → (x op C) @@ -1123,6 +1194,11 @@ def LLVMIR_identity_combines_64 : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) def LLVMIR_identity_combines_32 : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) := anyext_trunc_fold +def LLVMIR_cast_combines_64 : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) := + select_of_zext ++ select_of_anyext + +def LLVMIR_cast_combines_32 : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) := select_of_truncate + /-- Post-legalization combine pass for RISCV -/ def PostLegalizerCombiner_RISCV: List (Σ Γ,RISCVPeepholeRewrite Γ) := RISCV_identity_combines ++ @@ -1134,11 +1210,12 @@ def PostLegalizerCombiner_LLVMIR_64 : List (Σ Γ, LLVMPeepholeRewriteRefine 64 sub_to_add ++ redundant_and ++ select_same_val ++ + LLVMIR_cast_combines_64 ++ LLVMIR_identity_combines_64 /-- Post-legalization combine pass for LLVM specialized for i64 type -/ def PostLegalizerCombiner_LLVMIR_32 : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) := - LLVMIR_identity_combines_32 + LLVMIR_identity_combines_32 ++ LLVMIR_cast_combines_32 /-- We group all the rewrites that form the pre-legalization optimizations in GlobalISel-/ def GLobalISelO0PreLegalizerCombiner : diff --git a/SSA/Projects/LLVMRiscV/Tests/Combiners.lean b/SSA/Projects/LLVMRiscV/Tests/Combiners.lean index 0059b9ef07..d355961bb2 100644 --- a/SSA/Projects/LLVMRiscV/Tests/Combiners.lean +++ b/SSA/Projects/LLVMRiscV/Tests/Combiners.lean @@ -1,6 +1,43 @@ import SSA.Projects.InstCombine.LLVM.Opt import LeanMLIR.Framework.Print +/-- +info: { + ^bb0(%0 : i1, %1 : i32, %2 : i32): + %3 = "llvm.sext"(%1) : (i32) -> (i64) + %4 = "llvm.sext"(%2) : (i32) -> (i64) + %5 = "llvm.select"(%0, %3, %4) : (i1, i64, i64) -> (i64) + "llvm.return"(%5) : (i64) -> () +} +-/ +#guard_msgs in +#eval! Com.print (DCE.dce' (DCE.dce' (multiRewritePeephole 100 GLobalISelPostLegalizerCombiner select_of_anyext_rw.lhs)).val).val + +/-- +info: { + ^bb0(%0 : i1, %1 : i32, %2 : i32): + %3 = "llvm.zext"(%1) : (i32) -> (i64) + %4 = "llvm.zext"(%2) : (i32) -> (i64) + %5 = "llvm.select"(%0, %3, %4) : (i1, i64, i64) -> (i64) + "llvm.return"(%5) : (i64) -> () +} +-/ +#guard_msgs in +#eval! Com.print (DCE.dce' (DCE.dce' (multiRewritePeephole 100 GLobalISelPostLegalizerCombiner select_of_zext_rw.lhs)).val).val + +/-- +info: { + ^bb0(%0 : i1, %1 : i64, %2 : i64): + %3 = "llvm.trunc"(%1) : (i64) -> (i32) + %4 = "llvm.trunc"(%2) : (i64) -> (i32) + %5 = "llvm.select"(%0, %3, %4) : (i1, i32, i32) -> (i32) + "llvm.return"(%5) : (i32) -> () +} +-/ +#guard_msgs in +#eval! Com.print (DCE.dce' (DCE.dce' (multiRewritePeephole 100 GLobalISelPostLegalizerCombiner select_of_truncate_rw.lhs)).val).val + + /-- info: { ^bb0(%0 : i64, %1 : i64):