1
1
package com .wavesplatform .lang .v1 .estimator .v3
2
2
3
- import cats .implicits .*
3
+ import cats .implicits .{ toBifunctorOps , toFoldableOps , toTraverseOps }
4
4
import cats .{Id , Monad }
5
5
import com .wavesplatform .lang .v1 .FunctionHeader
6
6
import com .wavesplatform .lang .v1 .FunctionHeader .User
@@ -13,7 +13,7 @@ import monix.eval.Coeval
13
13
14
14
import scala .util .Try
15
15
16
- case class ScriptEstimatorV3 (fixOverflow : Boolean , overhead : Boolean ) extends ScriptEstimator {
16
+ case class ScriptEstimatorV3 (fixOverflow : Boolean , overhead : Boolean , letFixes : Boolean ) extends ScriptEstimator {
17
17
private val overheadCost : Long = if (overhead) 1 else 0
18
18
19
19
override val version : Int = 3
@@ -39,55 +39,45 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
39
39
globalDeclarationsMode : Boolean
40
40
): (EstimatorContext , Either [EstimationError , Long ]) = {
41
41
val ctxFuncs = funcs.view.mapValues((_, Set [String ]())).toMap
42
- evalExpr(expr, globalDeclarationsMode).run(EstimatorContext (ctxFuncs)).value
42
+ evalExpr(expr, Set (), globalDeclarationsMode).run(EstimatorContext (ctxFuncs)).value
43
43
}
44
44
45
- private def evalExpr (t : EXPR , globalDeclarationsMode : Boolean = false ): EvalM [Long ] =
45
+ private def evalExpr (t : EXPR , activeFuncArgs : Set [ String ], globalDeclarationsMode : Boolean = false ): EvalM [Long ] =
46
46
if (Thread .currentThread().isInterrupted)
47
47
raiseError(" Script estimation was interrupted" )
48
48
else
49
49
t match {
50
- case LET_BLOCK (let, inner) => evalLetBlock(let, inner, globalDeclarationsMode)
51
- case BLOCK (let : LET , inner) => evalLetBlock(let, inner, globalDeclarationsMode)
52
- case BLOCK (f : FUNC , inner) => evalFuncBlock(f, inner, globalDeclarationsMode)
50
+ case LET_BLOCK (let, inner) => evalLetBlock(let, inner, activeFuncArgs, globalDeclarationsMode)
51
+ case BLOCK (let : LET , inner) => evalLetBlock(let, inner, activeFuncArgs, globalDeclarationsMode)
52
+ case BLOCK (f : FUNC , inner) => evalFuncBlock(f, inner, activeFuncArgs, globalDeclarationsMode)
53
53
case BLOCK (_ : FAILED_DEC , _) => zero
54
- case REF (str) => markRef (str)
54
+ case REF (str) => evalRef (str, activeFuncArgs )
55
55
case _ : EVALUATED => const(overheadCost)
56
- case IF (cond, t1, t2) => evalIF(cond, t1, t2)
57
- case GETTER (expr, _) => evalGetter(expr)
58
- case FUNCTION_CALL (header, args) => evalFuncCall(header, args)
56
+ case IF (cond, t1, t2) => evalIF(cond, t1, t2, activeFuncArgs )
57
+ case GETTER (expr, _) => evalGetter(expr, activeFuncArgs )
58
+ case FUNCTION_CALL (header, args) => evalFuncCall(header, args, activeFuncArgs )
59
59
case _ : FAILED_EXPR => zero
60
60
}
61
61
62
- private def evalHoldingFuncs ( expr : EXPR ): EvalM [Long ] =
62
+ private def evalLetBlock ( let : LET , nextExpr : EXPR , activeFuncArgs : Set [ String ], globalDeclarationsMode : Boolean ): EvalM [Long ] =
63
63
for {
64
+ _ <- if (globalDeclarationsMode) saveGlobalLetCost(let, activeFuncArgs) else doNothing
64
65
startCtx <- get[Id , EstimatorContext , EstimationError ]
65
- cost <- evalExpr(expr)
66
- _ <- update(funcs.set(_)(startCtx.funcs))
67
- } yield cost
68
-
69
- private def evalLetBlock (let : LET , inner : EXPR , globalDeclarationsMode : Boolean ): EvalM [Long ] =
70
- for {
71
- startCtx <- get[Id , EstimatorContext , EstimationError ]
72
- overlap = startCtx.usedRefs.contains(let.name)
73
- _ <- update(usedRefs.modify(_)(_ - let.name))
74
- letEval = evalHoldingFuncs(let.value)
75
- _ <- if (globalDeclarationsMode) saveGlobalLetCost(let) else doNothing
76
- nextCost <- evalExpr(inner, globalDeclarationsMode)
77
- ctx <- get[Id , EstimatorContext , EstimationError ]
78
- letCost <- if (ctx.usedRefs.contains(let.name)) letEval else zero
79
- _ <- update(usedRefs.modify(_)(r => if (overlap) r + let.name else r - let.name))
80
- result <- sum(nextCost, letCost)
66
+ letEval = evalHoldingFuncs(let.value, activeFuncArgs)
67
+ _ <- beforeNextExprEval(let, letEval)
68
+ nextExprCost <- evalExpr(nextExpr, activeFuncArgs, globalDeclarationsMode)
69
+ nextExprCtx <- get[Id , EstimatorContext , EstimationError ]
70
+ _ <- afterNextExprEval(let, startCtx)
71
+ letCost <- if (nextExprCtx.usedRefs.contains(let.name)) letEval else const(0L )
72
+ result <- sum(nextExprCost, letCost)
81
73
} yield result
82
74
83
- private def saveGlobalLetCost (let : LET ): EvalM [Unit ] = {
75
+ private def saveGlobalLetCost (let : LET , activeFuncArgs : Set [ String ] ): EvalM [Unit ] = {
84
76
val costEvaluation =
85
77
for {
86
- startCtx <- get[Id , EstimatorContext , EstimationError ]
87
- bodyCost <- evalExpr(let.value)
88
- bodyEvalCtx <- get[Id , EstimatorContext , EstimationError ]
89
- usedRefs = bodyEvalCtx.usedRefs diff startCtx.usedRefs
90
- letCosts <- usedRefs.toSeq.traverse(bodyEvalCtx.globalLetEvals.getOrElse(_, zero))
78
+ (bodyCost, usedRefs) <- withUsedRefs(evalExpr(let.value, activeFuncArgs))
79
+ ctx <- get[Id , EstimatorContext , EstimationError ]
80
+ letCosts <- usedRefs.toSeq.traverse(ctx.globalLetEvals.getOrElse(_, zero))
91
81
} yield bodyCost + letCosts.sum
92
82
for {
93
83
cost <- local(costEvaluation)
@@ -100,26 +90,47 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
100
90
} yield ()
101
91
}
102
92
103
- private def evalFuncBlock (func : FUNC , inner : EXPR , globalDeclarationsMode : Boolean ): EvalM [Long ] =
93
+ private def beforeNextExprEval (let : LET , eval : EvalM [Long ]): EvalM [Unit ] =
94
+ for {
95
+ cost <- local(eval)
96
+ _ <- update(ctx =>
97
+ usedRefs
98
+ .modify(ctx)(_ - let.name)
99
+ .copy(refsCosts = ctx.refsCosts + (let.name -> cost))
100
+ )
101
+ } yield ()
102
+
103
+ private def afterNextExprEval (let : LET , startCtx : EstimatorContext ): EvalM [Unit ] =
104
+ update(ctx =>
105
+ usedRefs
106
+ .modify(ctx)(r => if (startCtx.usedRefs.contains(let.name)) r + let.name else r - let.name)
107
+ .copy(refsCosts =
108
+ if (startCtx.refsCosts.contains(let.name))
109
+ ctx.refsCosts + (let.name -> startCtx.refsCosts(let.name))
110
+ else
111
+ ctx.refsCosts - let.name
112
+ )
113
+ )
114
+
115
+ private def evalFuncBlock (func : FUNC , nextExpr : EXPR , activeFuncArgs : Set [String ], globalDeclarationsMode : Boolean ): EvalM [Long ] =
104
116
for {
105
- startCtx <- get[Id , EstimatorContext , EstimationError ]
106
- _ <- checkShadowing(func, startCtx)
107
- funcCost <- evalHoldingFuncs(func.body)
108
- bodyEvalCtx <- get[Id , EstimatorContext , EstimationError ]
109
- refsUsedInBody = bodyEvalCtx.usedRefs diff startCtx.usedRefs
110
- _ <- if (globalDeclarationsMode) saveGlobalFuncCost(func.name, funcCost, bodyEvalCtx, refsUsedInBody) else doNothing
111
- _ <- handleUsedRefs(func.name, funcCost, startCtx, refsUsedInBody)
112
- nextCost <- evalExpr(inner, globalDeclarationsMode)
113
- } yield nextCost
117
+ startCtx <- get[Id , EstimatorContext , EstimationError ]
118
+ _ <- checkShadowing(func, startCtx)
119
+ (funcCost, refsUsedInBody) <- withUsedRefs(evalHoldingFuncs(func.body, activeFuncArgs ++ func.args))
120
+ _ <- if (globalDeclarationsMode) saveGlobalFuncCost(func.name, funcCost, refsUsedInBody) else doNothing
121
+ _ <- handleUsedRefs(func.name, funcCost, startCtx, refsUsedInBody)
122
+ nextExprCost <- evalExpr(nextExpr, activeFuncArgs, globalDeclarationsMode)
123
+ } yield nextExprCost
114
124
115
125
private def checkShadowing (func : FUNC , startCtx : EstimatorContext ): EvalM [Any ] =
116
126
if (fixOverflow && startCtx.funcs.contains(FunctionHeader .User (func.name)))
117
127
raiseError(s " Function ' ${func.name}${func.args.mkString(" (" , " , " , " )" )}' shadows preceding declaration " )
118
128
else
119
129
doNothing
120
130
121
- private def saveGlobalFuncCost (name : String , funcCost : Long , ctx : EstimatorContext , refsUsedInBody : Set [String ]): EvalM [Unit ] =
131
+ private def saveGlobalFuncCost (name : String , funcCost : Long , refsUsedInBody : Set [String ]): EvalM [Unit ] =
122
132
for {
133
+ ctx <- get[Id , EstimatorContext , EstimationError ]
123
134
letCosts <- local(refsUsedInBody.toSeq.traverse(ctx.globalLetEvals.getOrElse(_, zero)))
124
135
totalCost = math.max(1 , funcCost + letCosts.sum)
125
136
_ <- set[Id , EstimatorContext , EstimationError ](ctx.copy(globalFunctionsCosts = ctx.globalFunctionsCosts + (name -> totalCost)))
@@ -135,46 +146,75 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
135
146
}
136
147
)
137
148
138
- private def evalIF (cond : EXPR , ifTrue : EXPR , ifFalse : EXPR ): EvalM [Long ] =
149
+ private def evalIF (cond : EXPR , ifTrue : EXPR , ifFalse : EXPR , activeFuncArgs : Set [ String ] ): EvalM [Long ] =
139
150
for {
140
- cond <- evalHoldingFuncs(cond)
141
- right <- evalHoldingFuncs(ifTrue)
142
- left <- evalHoldingFuncs(ifFalse)
151
+ cond <- evalHoldingFuncs(cond, activeFuncArgs )
152
+ right <- evalHoldingFuncs(ifTrue, activeFuncArgs )
153
+ left <- evalHoldingFuncs(ifFalse, activeFuncArgs )
143
154
r1 <- sum(cond, Math .max(right, left))
144
155
r2 <- sum(r1, overheadCost)
145
156
} yield r2
146
157
147
- private def markRef (key : String ): EvalM [Long ] =
148
- update(usedRefs.modify(_)(_ + key)).map(_ => overheadCost)
158
+ private def evalRef (key : String , activeFuncArgs : Set [String ]): EvalM [Long ] =
159
+ if (activeFuncArgs.contains(key) && letFixes)
160
+ const(overheadCost)
161
+ else
162
+ update(usedRefs.modify(_)(_ + key)).map(_ => overheadCost)
149
163
150
- private def evalGetter (expr : EXPR ): EvalM [Long ] =
151
- evalExpr(expr).flatMap(sum(_, overheadCost))
164
+ private def evalGetter (expr : EXPR , activeFuncArgs : Set [ String ] ): EvalM [Long ] =
165
+ evalExpr(expr, activeFuncArgs ).flatMap(sum(_, overheadCost))
152
166
153
- private def evalFuncCall (header : FunctionHeader , args : List [EXPR ]): EvalM [Long ] =
167
+ private def evalFuncCall (header : FunctionHeader , args : List [EXPR ], activeFuncArgs : Set [ String ] ): EvalM [Long ] =
154
168
for {
155
- ctx <- get[Id , EstimatorContext , EstimationError ]
156
- (bodyCost, bodyUsedRefs) <- funcs
157
- .get(ctx)
158
- .get(header)
159
- .map(const)
160
- .getOrElse(
161
- raiseError[Id , EstimatorContext , EstimationError , (Coeval [Long ], Set [String ])](s " function ' $header' not found " )
162
- )
163
- _ <- update(
164
- (funcs ~ usedRefs).modify(_) { case (funcs, usedRefs) =>
165
- (
166
- funcs + ((header, (bodyCost, Set [String ]()))),
167
- usedRefs ++ bodyUsedRefs
168
- )
169
- }
170
- )
171
- argsCosts <- args.traverse(evalHoldingFuncs)
172
- argsCostsSum <- argsCosts.foldM(0L )(sum)
173
- bodyCostV = bodyCost.value()
174
- correctedBodyCost = if (! overhead && bodyCostV == 0 ) 1 else bodyCostV
169
+ ctx <- get[Id , EstimatorContext , EstimationError ]
170
+ (bodyCost, bodyUsedRefs) <- getFuncCost(header, ctx)
171
+ _ <- setFuncToCtx(header, bodyCost, bodyUsedRefs)
172
+ (argsCosts, argsUsedRefs) <- withUsedRefs(args.traverse(evalHoldingFuncs(_, activeFuncArgs)))
173
+ argsCostsSum <- argsCosts.foldM(0L )(sum)
174
+ bodyCostV = bodyCost.value()
175
+ correctedBodyCost =
176
+ if (! overhead && ! letFixes && bodyCostV == 0 ) 1
177
+ else if (letFixes && bodyCostV == 0 && isBlankFunc(bodyUsedRefs ++ argsUsedRefs, ctx.refsCosts)) 1
178
+ else bodyCostV
175
179
result <- sum(argsCostsSum, correctedBodyCost)
176
180
} yield result
177
181
182
+ private def setFuncToCtx (header : FunctionHeader , bodyCost : Coeval [Long ], bodyUsedRefs : Set [EstimationError ]): EvalM [Unit ] =
183
+ update(
184
+ (funcs ~ usedRefs).modify(_) { case (funcs, usedRefs) =>
185
+ (
186
+ funcs + (header -> (bodyCost, Set ())),
187
+ usedRefs ++ bodyUsedRefs
188
+ )
189
+ }
190
+ )
191
+
192
+ private def getFuncCost (header : FunctionHeader , ctx : EstimatorContext ): EvalM [(Coeval [Long ], Set [EstimationError ])] =
193
+ funcs
194
+ .get(ctx)
195
+ .get(header)
196
+ .map(const)
197
+ .getOrElse(
198
+ raiseError[Id , EstimatorContext , EstimationError , (Coeval [Long ], Set [EstimationError ])](s " function ' $header' not found " )
199
+ )
200
+
201
+ private def isBlankFunc (usedRefs : Set [String ], refsCosts : Map [String , Long ]): Boolean =
202
+ ! usedRefs.exists(refsCosts.get(_).exists(_ > 0 ))
203
+
204
+ private def evalHoldingFuncs (expr : EXPR , activeFuncArgs : Set [String ]): EvalM [Long ] =
205
+ for {
206
+ startCtx <- get[Id , EstimatorContext , EstimationError ]
207
+ cost <- evalExpr(expr, activeFuncArgs)
208
+ _ <- update(funcs.set(_)(startCtx.funcs))
209
+ } yield cost
210
+
211
+ private def withUsedRefs [A ](eval : EvalM [A ]): EvalM [(A , Set [String ])] =
212
+ for {
213
+ ctxBefore <- get[Id , EstimatorContext , EstimationError ]
214
+ result <- eval
215
+ ctxAfter <- get[Id , EstimatorContext , EstimationError ]
216
+ } yield (result, ctxAfter.usedRefs diff ctxBefore.usedRefs)
217
+
178
218
private def update (f : EstimatorContext => EstimatorContext ): EvalM [Unit ] =
179
219
modify[Id , EstimatorContext , EstimationError ](f)
180
220
@@ -192,3 +232,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean) extends Sc
192
232
liftEither(Try (r).toEither.leftMap(_ => " Illegal script" ))
193
233
}
194
234
}
235
+
236
+ object ScriptEstimatorV3 {
237
+ val latest = ScriptEstimatorV3 (fixOverflow = true , overhead = false , letFixes = true )
238
+ }
0 commit comments