Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions tutorials/const-fold/ConstFold.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@ object ConstFold:
case Var(_) => false
case Add(e1, e2) => zeroExpr(e1) && zeroExpr(e2)
case Minus(e1, e2) => zeroExpr(e1) && zeroExpr(e2)
case Mul(e1, e2) => zeroExpr(e1) && zeroExpr(e2)
case Mul(e1, e2) => zeroExpr(e1) || zeroExpr(e2)

def lemma(ctx: Env, @induct e: Expr): Unit = {
def lemma(ctx: Env, e: Expr): Unit = {
require(zeroExpr(e))
()
e match
case Number(_) => ()
case Var(_) => ()
case Add(e1, e2) => lemma(ctx, e1); lemma(ctx, e2)
case Minus(e1, e2) => lemma(ctx, e1); lemma(ctx, e2)
case Mul(e1, e2) => if (zeroExpr(e1)) then lemma(ctx, e1) else lemma(ctx, e2)

}.ensuring(_ => evaluate(ctx, e) == 0)

def mirror(e: Expr)(anyCtx: Env = zeroEnv): Expr = {
Expand All @@ -39,7 +45,7 @@ object ConstFold:
case Var(name) => e
case Add(e1, e2) => Add(mirror(e2)(anyCtx), mirror(e1)(anyCtx))
case Minus(e1, e2) => Minus(mirror(e1)(anyCtx), mirror(e2)(anyCtx))
case Mul(e1, e2) => Mul(mirror(e1)(anyCtx), mirror(e2)(anyCtx))
case Mul(e1, e2) => Mul(mirror(e2)(anyCtx), mirror(e1)(anyCtx))
}.ensuring(evaluate(anyCtx, _) == evaluate(anyCtx,e))

abstract class SoundSimplifier:
Expand Down