Skip to content

Commit 66fbc1f

Browse files
ZeroExpr must check for a disjonction for multiplication (#132)
* zeroExpr must check for a disjonction for multiplication * mirror was not mirroring the mult
1 parent 5236f72 commit 66fbc1f

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tutorials/const-fold/ConstFold.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,17 @@ object ConstFold:
2626
case Var(_) => false
2727
case Add(e1, e2) => zeroExpr(e1) && zeroExpr(e2)
2828
case Minus(e1, e2) => zeroExpr(e1) && zeroExpr(e2)
29-
case Mul(e1, e2) => zeroExpr(e1) && zeroExpr(e2)
29+
case Mul(e1, e2) => zeroExpr(e1) || zeroExpr(e2)
3030

31-
def lemma(ctx: Env, @induct e: Expr): Unit = {
31+
def lemma(ctx: Env, e: Expr): Unit = {
3232
require(zeroExpr(e))
33-
()
33+
e match
34+
case Number(_) => ()
35+
case Var(_) => ()
36+
case Add(e1, e2) => lemma(ctx, e1); lemma(ctx, e2)
37+
case Minus(e1, e2) => lemma(ctx, e1); lemma(ctx, e2)
38+
case Mul(e1, e2) => if (zeroExpr(e1)) then lemma(ctx, e1) else lemma(ctx, e2)
39+
3440
}.ensuring(_ => evaluate(ctx, e) == 0)
3541

3642
def mirror(e: Expr)(anyCtx: Env = zeroEnv): Expr = {
@@ -39,7 +45,7 @@ object ConstFold:
3945
case Var(name) => e
4046
case Add(e1, e2) => Add(mirror(e2)(anyCtx), mirror(e1)(anyCtx))
4147
case Minus(e1, e2) => Minus(mirror(e1)(anyCtx), mirror(e2)(anyCtx))
42-
case Mul(e1, e2) => Mul(mirror(e1)(anyCtx), mirror(e2)(anyCtx))
48+
case Mul(e1, e2) => Mul(mirror(e2)(anyCtx), mirror(e1)(anyCtx))
4349
}.ensuring(evaluate(anyCtx, _) == evaluate(anyCtx,e))
4450

4551
abstract class SoundSimplifier:

0 commit comments

Comments
 (0)