Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion SSA/Projects/LLVMRiscV/Pipeline/Combiners.lean
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,77 @@ def mul_by_neg_one_const : LLVMPeepholeRewriteRefine 64 [Ty.llvm (.bitvec 64)] w
def mul_by_neg_one : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) :=
[⟨_, mul_by_neg_one_const⟩]

/-! ### select_of_zext -/

/-
Test the rewrite:
fold zext(select(cond, true_val, false_val)) -> select(cond, zext(true_val), zext(false_val))
-/
def select_of_zext_rw : LLVMPeepholeRewriteRefine 64 [Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 1)] where
lhs := [LV| {
^entry (%cond: i1, %true_val: i32, %false_val: i32):
%0 = llvm.select %cond, %true_val, %false_val : i32
%1 = llvm.zext %0 : i32 to i64
llvm.return %1 : i64
}]
rhs := [LV| {
^entry (%cond: i1, %true_val: i32, %false_val: i32):
%0 = llvm.zext %true_val : i32 to i64
%1 = llvm.zext %false_val : i32 to i64
%2 = llvm.select %cond, %0, %1 : i64
llvm.return %2 : i64
}]

def select_of_zext : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) :=
[⟨_, select_of_zext_rw⟩]

/-! ### select_of_anyext -/

/-
Test the rewrite:
fold sext(select(cond, true_val, false_val)) -> select(cond, sext(true_val), sext(false_val))
-/
def select_of_anyext_rw : LLVMPeepholeRewriteRefine 64 [Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 32), Ty.llvm (.bitvec 1)] where
lhs := [LV| {
^entry (%cond: i1, %true_val: i32, %false_val: i32):
%0 = llvm.select %cond, %true_val, %false_val : i32
%1 = llvm.sext %0 : i32 to i64
llvm.return %1 : i64
}]
rhs := [LV| {
^entry (%cond: i1, %true_val: i32, %false_val: i32):
%0 = llvm.sext %true_val : i32 to i64
%1 = llvm.sext %false_val : i32 to i64
%2 = llvm.select %cond, %0, %1 : i64
llvm.return %2 : i64
}]

def select_of_anyext : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) :=
[⟨_, select_of_anyext_rw⟩]

/-! ### select_of_truncate -/

/-
Test the rewrite:
fold trunc(select(cond, true_val, false_val)) -> select(cond, trunc(true_val), trunc(false_val))
-/
def select_of_truncate_rw : LLVMPeepholeRewriteRefine 32 [Ty.llvm (.bitvec 64), Ty.llvm (.bitvec 64), Ty.llvm (.bitvec 1)] where
lhs := [LV| {
^entry (%cond: i1, %true_val: i64, %false_val: i64):
%0 = llvm.select %cond, %true_val, %false_val : i64
%1 = llvm.trunc %0 : i64 to i32
llvm.return %1 : i32
}]
rhs := [LV| {
^entry (%cond: i1, %true_val: i64, %false_val: i64):
%0 = llvm.trunc %true_val : i64 to i32
%1 = llvm.trunc %false_val : i64 to i32
%2 = llvm.select %cond, %0, %1 : i32
llvm.return %2 : i32
}]

def select_of_truncate : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) :=
[⟨_, select_of_truncate_rw⟩]

/-- ### commute_constant_to_rhs
(C op x) → (x op C)
Expand Down Expand Up @@ -1123,6 +1194,11 @@ def LLVMIR_identity_combines_64 : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ)

def LLVMIR_identity_combines_32 : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) := anyext_trunc_fold

def LLVMIR_cast_combines_64 : List (Σ Γ, LLVMPeepholeRewriteRefine 64 Γ) :=
select_of_zext ++ select_of_anyext

def LLVMIR_cast_combines_32 : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) := select_of_truncate

/-- Post-legalization combine pass for RISCV -/
def PostLegalizerCombiner_RISCV: List (Σ Γ,RISCVPeepholeRewrite Γ) :=
RISCV_identity_combines ++
Expand All @@ -1134,11 +1210,12 @@ def PostLegalizerCombiner_LLVMIR_64 : List (Σ Γ, LLVMPeepholeRewriteRefine 64
sub_to_add ++
redundant_and ++
select_same_val ++
LLVMIR_cast_combines_64 ++
LLVMIR_identity_combines_64

/-- Post-legalization combine pass for LLVM specialized for i64 type -/
def PostLegalizerCombiner_LLVMIR_32 : List (Σ Γ, LLVMPeepholeRewriteRefine 32 Γ) :=
LLVMIR_identity_combines_32
LLVMIR_identity_combines_32 ++ LLVMIR_cast_combines_32

/-- We group all the rewrites that form the pre-legalization optimizations in GlobalISel-/
def GLobalISelO0PreLegalizerCombiner :
Expand Down
37 changes: 37 additions & 0 deletions SSA/Projects/LLVMRiscV/Tests/Combiners.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,43 @@
import SSA.Projects.InstCombine.LLVM.Opt
import LeanMLIR.Framework.Print

/--
info: {
^bb0(%0 : i1, %1 : i32, %2 : i32):
%3 = "llvm.sext"(%1) : (i32) -> (i64)
%4 = "llvm.sext"(%2) : (i32) -> (i64)
%5 = "llvm.select"(%0, %3, %4) : (i1, i64, i64) -> (i64)
"llvm.return"(%5) : (i64) -> ()
}
-/
#guard_msgs in
#eval! Com.print (DCE.dce' (DCE.dce' (multiRewritePeephole 100 GLobalISelPostLegalizerCombiner select_of_anyext_rw.lhs)).val).val

/--
info: {
^bb0(%0 : i1, %1 : i32, %2 : i32):
%3 = "llvm.zext"(%1) : (i32) -> (i64)
%4 = "llvm.zext"(%2) : (i32) -> (i64)
%5 = "llvm.select"(%0, %3, %4) : (i1, i64, i64) -> (i64)
"llvm.return"(%5) : (i64) -> ()
}
-/
#guard_msgs in
#eval! Com.print (DCE.dce' (DCE.dce' (multiRewritePeephole 100 GLobalISelPostLegalizerCombiner select_of_zext_rw.lhs)).val).val

/--
info: {
^bb0(%0 : i1, %1 : i64, %2 : i64):
%3 = "llvm.trunc"(%1) : (i64) -> (i32)
%4 = "llvm.trunc"(%2) : (i64) -> (i32)
%5 = "llvm.select"(%0, %3, %4) : (i1, i32, i32) -> (i32)
"llvm.return"(%5) : (i32) -> ()
}
-/
#guard_msgs in
#eval! Com.print (DCE.dce' (DCE.dce' (multiRewritePeephole 100 GLobalISelPostLegalizerCombiner select_of_truncate_rw.lhs)).val).val


/--
info: {
^bb0(%0 : i64, %1 : i64):
Expand Down
Loading