Skip to content

Commit 535464a

Browse files
committed
Enhance constant-folding during inlining
1 parent 5ccea40 commit 535464a

File tree

8 files changed

+109
-11
lines changed

8 files changed

+109
-11
lines changed

compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import collection.mutable
1616
/** A utility class offering methods for rewriting inlined code */
1717
class InlineReducer(inliner: Inliner)(using Context):
1818
import tpd.*
19-
import Inliner.{isElideableExpr, DefBuffer}
19+
import Inliner.{isElideableExpr, DefBuffer, inlinedConstToLiteral}
2020
import inliner.{call, newSym, tryInlineArg, paramBindingDef}
2121

2222
extension (tp: Type)
@@ -201,7 +201,7 @@ class InlineReducer(inliner: Inliner)(using Context):
201201
val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord,
202202
flags = sym.flags &~ Case).asTerm
203203
adjustErased(copied, rhs)
204-
caseBindingMap += ((sym, ValDef(copied, constToLiteral(rhs)).withSpan(sym.span)))
204+
caseBindingMap += ((sym, ValDef(copied, inlinedConstToLiteral(rhs)).withSpan(sym.span)))
205205

206206
def newTypeBinding(sym: TypeSymbol, alias: Type): Unit = {
207207
val copied = sym.copy(info = TypeAlias(alias), coord = sym.coord).asType
@@ -321,7 +321,7 @@ class InlineReducer(inliner: Inliner)(using Context):
321321
case (pat :: pats1, selector :: selectors1) =>
322322
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenInlineScrutinee).asTerm
323323
adjustErased(elem, selector)
324-
val rhs = constToLiteral(selector)
324+
val rhs = inlinedConstToLiteral(selector)
325325
elem.defTree = rhs
326326
caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span)))
327327
reducePattern(caseBindingMap, elem.termRef, pat) &&
@@ -337,7 +337,7 @@ class InlineReducer(inliner: Inliner)(using Context):
337337
else paramCls.asClass.paramAccessors
338338
val selectors =
339339
for (accessor <- caseAccessors)
340-
yield constToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied))
340+
yield inlinedConstToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied))
341341
caseAccessors.length == pats.length && reduceSubPatterns(pats, selectors)
342342
}
343343
else false

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,25 @@ object Inliner:
183183

184184
end OpaqueProxy
185185

186+
/** A more powerful version of `constToLiteral` that also recurses through elidable (see [[isElideableExpr]]) [[Block]], [[Inlined]], and [[Typed]] nodes.
187+
*/
188+
def inlinedConstToLiteral(rootTree: Tree)(using Context): Tree =
189+
def rec(tree: Tree): Tree =
190+
inline def recChild(subTree: Tree): Tree =
191+
val res = rec(subTree)
192+
if res eq subTree then tree else res
193+
194+
tree match
195+
case Typed(expr, _) => recChild(expr)
196+
case Inlined(_, _, expr) => recChild(expr)
197+
case Block(_, expr) => recChild(expr)
198+
case _ => constToLiteral(tree)
199+
200+
if isElideableExpr(rootTree) then
201+
rec(rootTree)
202+
else
203+
constToLiteral(rootTree)
204+
186205
private[inlines] def newSym(name: Name, flags: FlagSet, info: Type, span: Span)(using Context): Symbol =
187206
newSymbol(ctx.owner, name, flags, info, coord = span)
188207
end Inliner
@@ -897,7 +916,7 @@ class Inliner(val call: tpd.Tree)(using Context):
897916
//if the projection leads to a typed tree then we stop reduction
898917
resNoReduce
899918
else
900-
val res = constToLiteral(reducedProjection)
919+
val res = inlinedConstToLiteral(reducedProjection)
901920
if resNoReduce ne res then
902921
typed(res, pt) // redo typecheck if reduction changed something
903922
else if res.symbol.isInlineMethod then
@@ -928,19 +947,19 @@ class Inliner(val call: tpd.Tree)(using Context):
928947
override def typedValDef(vdef: untpd.ValDef, sym: Symbol)(using Context): Tree =
929948
val vdef1 =
930949
if sym.is(Inline) then
931-
val rhs = typed(vdef.rhs)
950+
val rhs = inlinedConstToLiteral(typed(vdef.rhs))
932951
sym.info = rhs.tpe
933952
untpd.cpy.ValDef(vdef)(vdef.name, untpd.TypeTree(rhs.tpe), untpd.TypedSplice(rhs))
934953
else vdef
935954
super.typedValDef(vdef1, sym)
936955

937956
override def typedApply(tree: untpd.Apply, pt: Type)(using Context): Tree =
938957
val locked = ctx.typerState.ownedVars
939-
specializeEq(inlineIfNeeded(constToLiteral(BetaReduce(super.typedApply(tree, pt))), pt, locked))
958+
specializeEq(inlineIfNeeded(inlinedConstToLiteral(BetaReduce(super.typedApply(tree, pt))), pt, locked))
940959

941960
override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(using Context): Tree =
942961
val locked = ctx.typerState.ownedVars
943-
val tree1 = inlineIfNeeded(constToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked)
962+
val tree1 = inlineIfNeeded(inlinedConstToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked)
944963
if tree1.symbol == defn.QuotedTypeModule_of then
945964
ctx.compilationUnit.needsStaging = true
946965
tree1
@@ -1021,8 +1040,8 @@ class Inliner(val call: tpd.Tree)(using Context):
10211040
case _ => rhs0
10221041
}
10231042
val rhs2 = rhs1 match {
1024-
case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr)
1025-
case _ => constToLiteral(rhs1)
1043+
case Typed(expr, tpt) if rhs1.span.isSynthetic => inlinedConstToLiteral(expr)
1044+
case _ => inlinedConstToLiteral(rhs1)
10261045
}
10271046
val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2)
10281047
val rhs = seq(usedBindings, rhs3)
@@ -1056,7 +1075,7 @@ class Inliner(val call: tpd.Tree)(using Context):
10561075
val meth = tree.symbol
10571076
if meth.isAllOf(DeferredInline) then
10581077
errorTree(tree, em"Deferred inline ${meth.showLocated} cannot be invoked")
1059-
else if Inlines.needsInlining(tree) then Inlines.inlineCall(simplify(tree, pt, locked))
1078+
else if Inlines.needsInlining(tree) then inlinedConstToLiteral(Inlines.inlineCall(simplify(tree, pt, locked)))
10601079
else tree
10611080

10621081
override def typedUnadapted(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree =

tests/neg/i18123b.check

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
-- [E007] Type Mismatch Error: tests/neg/i18123b.scala:8:8 -------------------------------------------------------------
2+
8 |def z = y.rep().toUpperCase // error
3+
| ^^^^^^^
4+
| Found: (??? : => Nothing)
5+
| Required: ?{ toUpperCase: ? }
6+
| Note that implicit conversions were not tried because the result of an implicit conversion
7+
| must be more specific than ?{ toUpperCase: <?> }
8+
|
9+
| longer explanation available when compiling with `-explain`

tests/neg/i18123b.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Minimized version of `tests/pos/i18123.scala` to test #24425.
2+
3+
extension (x: String)
4+
transparent inline def rep(min: Int = 0): String = ???
5+
6+
def y: String = ???
7+
8+
def z = y.rep().toUpperCase // error

tests/pos/i24412.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
object test {
2+
import scala.compiletime.erasedValue
3+
4+
inline def contains[T <: Tuple, E]: Boolean = inline erasedValue[T] match {
5+
case _: EmptyTuple => false
6+
case _: (_ *: tail) => contains[tail, E]
7+
}
8+
inline def check[T <: Tuple]: Unit = {
9+
inline if contains[T, Long] && false then ???
10+
}
11+
12+
check[(String, Double)]
13+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
inline def f(): Long =
2+
1L
3+
4+
inline def g(): Long =
5+
inline val x = f()
6+
x
7+
8+
inline def h(): Long =
9+
inline if g() > 0L then 1L else 0L
10+
11+
@main def Test: Unit =
12+
assert(h() == 1L)

tests/run/i24420-inline-val.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
inline def f1(): Long =
2+
1L
3+
4+
inline def f2(): Long =
5+
inline val x = f1() + 1L
6+
x
7+
8+
inline def f3(): Long =
9+
inline val x = f1()
10+
x
11+
12+
inline def g1(): Boolean =
13+
true
14+
15+
inline def g2(): Long =
16+
inline if g1() then 1L else 2L
17+
18+
inline def g3(): Long =
19+
inline if f1() > 0L then 1L else 2L
20+
21+
@main def Test: Unit =
22+
assert(f2() == 2L)
23+
assert(f3() == 1L)
24+
assert(g2() == 1L)
25+
assert(g3() == 1L)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
transparent inline def f(): Long =
2+
1L
3+
4+
transparent inline def g(): Long =
5+
inline val x = f()
6+
x
7+
8+
transparent inline def h(): Long =
9+
inline if g() > 0L then 1L else 0L
10+
11+
@main def Test: Unit =
12+
assert(h() == 1L)

0 commit comments

Comments
 (0)