Skip to content

Commit 1ce8040

Browse files
authored
Merge pull request #2762 from GaloisInc/bh/term-var-types
Enforce consistent variable typing contexts for all SAWCore terms
2 parents 919d427 + 31cb81f commit 1ce8040

File tree

9 files changed

+227
-230
lines changed

9 files changed

+227
-230
lines changed

otherTests/saw-core/Tests/Functor.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ shared :: TermIndex -> TermF Term -> Term
5050
shared ix t = STApp {
5151
stAppIndex = ix,
5252
stAppHash = hash t,
53-
stAppFreeVars = mempty,
53+
stAppVarTypes = mempty,
5454
stAppTermF = t
5555
}
5656

saw-central/src/SAWCentral/Crucible/Common/ResolveSetupValue.hs

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
-- backends.
33
{-# Language DataKinds, TypeOperators, GADTs, TypeApplications #-}
44
{-# Language ImplicitParams #-}
5-
module SAWCentral.Crucible.Common.ResolveSetupValue (
5+
module SAWCentral.Crucible.Common.ResolveSetupValue (
66
resolveBoolTerm, resolveBoolTerm',
77
resolveBitvectorTerm, resolveBitvectorTerm',
88
ResolveRewrite(..),
99
) where
1010

11-
import qualified Data.Map as Map
1211
import Data.Set(Set)
1312
import qualified Data.BitVector.Sized as BV
1413
import Data.Parameterized.Some (Some(..))
@@ -19,7 +18,6 @@ import qualified What4.Interface as W4
1918

2019

2120
import SAWCore.SharedTerm
22-
import SAWCore.Name
2321
import qualified SAWCore.Prim as Prim
2422

2523
import qualified SAWCore.Simulator.Concrete as Concrete
@@ -32,7 +30,6 @@ import SAWCentral.Crucible.Common
3230

3331
import SAWCentral.Proof (TheoremNonce)
3432
import SAWCore.Rewriter (Simpset, rewriteSharedTerm)
35-
import qualified CryptolSAWCore.Simpset as Cryptol
3633
import SAWCoreWhat4.What4(w4EvalAny, valueToSymExpr)
3734

3835
import Cryptol.TypeCheck.Type (tIsBit, tIsSeq, tIsNum)
@@ -82,21 +79,13 @@ resolveTerm sym unint bt rr tm =
8279
_ -> fail "resolveTerm: expected `Bool` or bit-vector"
8380

8481
| rrWhat4Eval rr ->
85-
do -- Try to use rewrites to simplify the term
86-
cryptol_ss <- Cryptol.mkCryptolSimpset @TheoremNonce sc
87-
tm'' <- snd <$> rewriteSharedTerm sc cryptol_ss tm'
88-
tm''' <- basicRewrite sc tm''
89-
if all isPreludeName (Map.elems (getConstantSet tm''')) then
90-
do
91-
(_, _, _, p) <- w4EvalAny sym st sc mempty unint tm'''
92-
case valueToSymExpr p of
93-
Just (Some y)
94-
| Just Refl <- testEquality bt ty -> pure y
95-
| otherwise -> typeError (show ty)
96-
where ty = W4.exprType y
97-
_ -> fail ("resolveTerm: unexpected w4Eval result " ++ show p)
98-
else
99-
bindSAWTerm sym st bt tm'''
82+
do (_, _, _, p) <- w4EvalAny sym st sc mempty unint tm'
83+
case valueToSymExpr p of
84+
Just (Some y)
85+
| Just Refl <- testEquality bt ty -> pure y
86+
| otherwise -> typeError (show ty)
87+
where ty = W4.exprType y
88+
_ -> fail ("resolveTerm: unexpected w4Eval result " ++ show p)
10089

10190
-- Just bind the term
10291
| otherwise -> bindSAWTerm sym st bt tm'
@@ -107,11 +96,6 @@ resolveTerm sym unint bt rr tm =
10796
Nothing -> pure
10897
Just ss -> \t -> snd <$> rewriteSharedTerm sc ss t
10998

110-
isPreludeName nm =
111-
case nm of
112-
ModuleIdentifier ident -> identModule ident == preludeName
113-
_ -> False
114-
11599
checkType sc =
116100
do
117101
schema <- ttType <$> mkTypedTerm sc tm
@@ -146,9 +130,8 @@ resolveBoolTerm sym unint = resolveBoolTerm' sym unint noResolveRewrite
146130
resolveBitvectorTerm' ::
147131
(1 W4.<= w) => Sym -> Set VarIndex -> W4.NatRepr w -> ResolveRewrite -> Term -> IO (W4.SymBV Sym w)
148132
resolveBitvectorTerm' sym unint w = resolveTerm sym unint (W4.BaseBVRepr w)
149-
133+
150134
-- 'resolveTerm' specialized to bit-vectors, without rewriting.
151-
resolveBitvectorTerm ::
135+
resolveBitvectorTerm ::
152136
(1 W4.<= w) => Sym -> Set VarIndex -> W4.NatRepr w -> Term -> IO (W4.SymBV Sym w)
153137
resolveBitvectorTerm sym unint w = resolveBitvectorTerm' sym unint w noResolveRewrite
154-

saw-central/src/SAWCentral/SolverCache.hs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ import Control.Monad (when, forM_)
7979
import System.Timeout (timeout)
8080

8181
import GHC.Generics (Generic)
82+
import qualified Data.IntMap as IntMap
8283
import Data.IORef (IORef, newIORef, modifyIORef, readIORef)
8384
import Data.Time.Clock (UTCTime, getCurrentTime)
8485
import Data.Tuple.Extra (first, firstM)
@@ -116,6 +117,7 @@ import SAWCore.Name (VarName(..))
116117
import SAWCore.SATQuery
117118
import SAWCore.ExternalFormat
118119
import SAWCore.SharedTerm
120+
import SAWCore.Term.Raw (varTypes)
119121

120122
import SAWCentral.Options
121123
import SAWCentral.Proof
@@ -300,7 +302,8 @@ mkSolverCacheKey :: SharedContext -> SolverBackendVersions ->
300302
[SolverBackendOption] -> SATQuery -> IO SolverCacheKey
301303
mkSolverCacheKey sc vs opts satq = do
302304
body <- satQueryAsPropTerm sc satq
303-
satVars <- traverse (scFirstOrderType sc) (satVariables satq)
305+
let mkVar x _fot = IntMap.lookup (vnIndex x) (varTypes body)
306+
let satVars = Map.mapMaybeWithKey mkVar (satVariables satq)
304307
let vars = Map.toList satVars ++
305308
filter (\(x, _) -> vnIndex x `elem` satUninterp satq)
306309
(getAllVars body)

saw-core/src/SAWCore/Rewriter.hs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -719,9 +719,25 @@ rewriteSharedTerm sc ss t0 =
719719
rewriteAll STApp{ stAppIndex = tidx, stAppTermF = tf } =
720720
useCache ?cache tidx (traverseTF rewriteAll tf >>= scTermF sc >>= rewriteTop)
721721

722-
traverseTF :: forall b. (b -> IO b) -> TermF b -> IO (TermF b)
723-
traverseTF _ tf@(Constant {}) = pure tf
724-
traverseTF f tf = traverse f tf
722+
traverseTF :: (Term -> IO Term) -> TermF Term -> IO (TermF Term)
723+
traverseTF f tf =
724+
case tf of
725+
-- Maintain invariant that types on Lambda/Pi binders should
726+
-- exactly match types on the bound variables in the body.
727+
Variable {} -> pure tf
728+
Lambda x t1 t2 ->
729+
do t1' <- f t1
730+
var <- scVariable sc x t1'
731+
t2' <- scInstantiate sc (IntMap.singleton (vnIndex x) var) t2
732+
t2'' <- f t2'
733+
pure (Lambda x t1' t2'')
734+
Pi x t1 t2 ->
735+
do t1' <- f t1
736+
var <- scVariable sc x t1'
737+
t2' <- scInstantiate sc (IntMap.singleton (vnIndex x) var) t2
738+
t2'' <- f t2'
739+
pure (Pi x t1' t2'')
740+
_ -> traverse f tf
725741

726742
rewriteTop :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term
727743
rewriteTop t =

saw-core/src/SAWCore/SATQuery.hs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@ module SAWCore.SATQuery
66
, satQueryAsPropTerm
77
) where
88

9+
import qualified Data.IntMap as IntMap
910
import Data.Map (Map)
11+
import Data.Maybe (mapMaybe)
1012
import Data.Set (Set)
1113
import Data.Foldable (foldrM)
1214

1315
import SAWCore.Name
1416
import SAWCore.FiniteValue
1517
import SAWCore.SharedTerm
18+
import SAWCore.Term.Raw (varTypes)
1619

1720
-- | This datatype represents a satisfiability query that might
1821
-- be dispatched to a solver. It carries a series of assertions
@@ -96,6 +99,7 @@ satQueryAsPropTerm sc satq =
9699
scTupleType sc =<< mapM assertAsPropTerm (satAsserts satq)
97100
where assertAsPropTerm (BoolAssert b) = scEqTrue sc b
98101
assertAsPropTerm (UniversalAssert vars hs g) =
99-
do vars' <- traverse (traverse (scFirstOrderType sc)) vars
100-
scPiList sc vars' =<<
101-
scEqTrue sc =<< foldrM (scImplies sc) g hs
102+
do body <- scEqTrue sc =<< foldrM (scImplies sc) g hs
103+
let varType x = fmap ((,) x) $ IntMap.lookup (vnIndex x) (varTypes body)
104+
let vars' = mapMaybe (varType . fst) vars
105+
scPiList sc vars' body

0 commit comments

Comments
 (0)