Skip to content

Commit 0cbe25c

Browse files
committed
Add a tactic to genralize the widths of bitvectors
1 parent c2826a2 commit 0cbe25c

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

Blase/Blase/MultiWidth/Tactic.lean

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,123 @@ def evalBvMultiWidth : Tactic := fun
10931093
solveEntrypoint g cfg
10941094
| _ => throwUnsupportedSyntax
10951095

1096+
/-
1097+
A tactic to generalize the width of BitVectors
1098+
-/
1099+
1100+
structure State where
1101+
mapping : Std.HashMap Expr Expr
1102+
deriving Inhabited
1103+
1104+
abbrev GenM := StateT State TermElabM
1105+
1106+
def State.get? (e : Expr) : GenM (Option Expr) := do
1107+
let s ← get
1108+
for (e', x) in s.mapping do
1109+
if ← isDefEq e e' then
1110+
return x
1111+
pure none
1112+
1113+
def State.setMapping (e x : Expr) : GenM Unit := do
1114+
modify fun s =>
1115+
{s with mapping := s.mapping.insert e x}
1116+
1117+
def State.add? (e : Expr) : GenM Expr := do
1118+
match ← get? e with
1119+
| some x => pure x
1120+
| none =>
1121+
if e.isFVar || e.isBVar then pure e else
1122+
let x ← mkFreshExprMVar (some (.const ``Nat [])) (userName := `w)
1123+
setMapping e x
1124+
pure x
1125+
1126+
/--
1127+
This table determines which arguments of important functions are bitwidths and
1128+
should be generalized and which ones are normal parameters which should be
1129+
recursively visited.
1130+
-/
1131+
def genTable : Std.HashMap Name (Array Bool) := Id.run do
1132+
let mut table := .emptyWithCapacity 16
1133+
table := table.insert ``BitVec #[true]
1134+
table := table.insert ``BitVec.zeroExtend #[true, true, false]
1135+
table := table.insert ``BitVec.signExtend #[true, true, false]
1136+
table := table.insert ``BitVec.instAdd #[true]
1137+
table := table.insert ``BitVec.instSub #[true]
1138+
table := table.insert ``BitVec.instMul #[true]
1139+
table := table.insert ``BitVec.instDiv #[true]
1140+
table
1141+
1142+
partial def visit (t : Expr) : GenM Expr := do
1143+
let t ← instantiateMVars t
1144+
match t with
1145+
| .app _ _ =>
1146+
let f := t.getAppFn
1147+
let args := t.getAppArgs
1148+
let table :=
1149+
if let some (f, _) := f.const? then
1150+
genTable[f]?
1151+
else
1152+
none
1153+
let bv? (n : Nat) :=
1154+
match table with
1155+
| .some xs => xs.getD n false
1156+
| .none => false
1157+
/- let f ← visit f -/
1158+
args.zipIdx.foldlM (init := f) fun res (arg, i) => do
1159+
let arg ← if bv? i then State.add? arg else visit arg
1160+
pure <| .app res arg
1161+
| .forallE n e₁ e₂ info =>
1162+
pure <| .forallE n (← visit e₁) (← visit e₂) info
1163+
| e =>
1164+
pure e
1165+
1166+
def doBvGeneralize (g : MVarId) : GenM (Expr × MVarId) := do
1167+
let lctx ← getLCtx
1168+
let mut allFVars := #[]
1169+
for h in lctx do
1170+
if not h.isImplementationDetail then
1171+
allFVars := allFVars.push h.fvarId
1172+
let (_, g) ← g.revert allFVars
1173+
let e ← visit (← g.getType)
1174+
let mut newVars := #[]
1175+
for (_, x) in (←get).mapping do
1176+
newVars := newVars.push x
1177+
1178+
let e ← mkForallFVars newVars e (binderInfoForMVars := .default)
1179+
let e ← instantiateMVars e
1180+
pure (e, g)
1181+
1182+
/--
1183+
This tactic tries to generalize the bitvector widths, and only the bitvector
1184+
widths. See `genTable` if the tactic fails to generalize the right parameters
1185+
of a function over bitvectors.
1186+
-/
1187+
syntax (name := bvGeneralize) "bv_generalize" Lean.Parser.Tactic.optConfig : tactic
1188+
@[tactic bvGeneralize]
1189+
def evalBvGeneralize : Tactic := fun
1190+
| `(tactic| bv_generalize) => do
1191+
let g₀ ← getMainGoal
1192+
g₀.withContext do
1193+
let ((e, g), s) ← (doBvGeneralize g₀).run default
1194+
g.withContext do
1195+
let g' ← mkFreshExprMVar (some e)
1196+
-- TODO: instantiate the old goal
1197+
let mut newVals := #[]
1198+
for (e, x) in s.mapping do
1199+
newVals := newVals.push e
1200+
g.assign <| mkAppN g' newVals
1201+
replaceMainGoal [g'.mvarId!]
1202+
| _ => throwUnsupportedSyntax
1203+
1204+
theorem test_bv_generalize_simple (x y : BitVec 32) (zs : List (BitVec 44)) :
1205+
x = x := by
1206+
bv_generalize
1207+
bv_multi_width
1208+
1209+
theorem test_bv_generalize (x y : BitVec 32) (zs : List (BitVec 44)) (z : BitVec 10) (h : 52 + 10 = 42) (heq : x = y) :
1210+
x.zeroExtend 10 = y.zeroExtend 10 + 0 := by
1211+
bv_generalize
1212+
bv_multi_width
10961213

10971214
end Tactic
10981215
end MultiWidth

0 commit comments

Comments
 (0)