Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import collection.mutable
/** A utility class offering methods for rewriting inlined code */
class InlineReducer(inliner: Inliner)(using Context):
import tpd.*
import Inliner.{isElideableExpr, DefBuffer}
import Inliner.{isElideableExpr, DefBuffer, inlinedConstToLiteral}
import inliner.{call, newSym, tryInlineArg, paramBindingDef}

extension (tp: Type)
Expand Down Expand Up @@ -201,7 +201,7 @@ class InlineReducer(inliner: Inliner)(using Context):
val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord,
flags = sym.flags &~ Case).asTerm
adjustErased(copied, rhs)
caseBindingMap += ((sym, ValDef(copied, constToLiteral(rhs)).withSpan(sym.span)))
caseBindingMap += ((sym, ValDef(copied, inlinedConstToLiteral(rhs)).withSpan(sym.span)))

def newTypeBinding(sym: TypeSymbol, alias: Type): Unit = {
val copied = sym.copy(info = TypeAlias(alias), coord = sym.coord).asType
Expand Down Expand Up @@ -321,7 +321,7 @@ class InlineReducer(inliner: Inliner)(using Context):
case (pat :: pats1, selector :: selectors1) =>
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenInlineScrutinee).asTerm
adjustErased(elem, selector)
val rhs = constToLiteral(selector)
val rhs = inlinedConstToLiteral(selector)
elem.defTree = rhs
caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span)))
reducePattern(caseBindingMap, elem.termRef, pat) &&
Expand All @@ -337,7 +337,7 @@ class InlineReducer(inliner: Inliner)(using Context):
else paramCls.asClass.paramAccessors
val selectors =
for (accessor <- caseAccessors)
yield constToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied))
yield inlinedConstToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied))
caseAccessors.length == pats.length && reduceSubPatterns(pats, selectors)
}
else false
Expand Down
35 changes: 28 additions & 7 deletions compiler/src/dotty/tools/dotc/inlines/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,27 @@ object Inliner:

end OpaqueProxy

/** A more powerful version of [[constToLiteral]] that also can "see through"
* [[Block]], [[Inlined]] and [[Typed]] trees that are elidable (see
* [[isElideableExpr]]).
*/
def inlinedConstToLiteral(rootTree: Tree)(using Context): Tree =
def rec(tree: Tree): Tree =
inline def recChild(subTree: Tree): Tree =
val res = rec(subTree)
if res eq subTree then tree else res

tree match
case Typed(expr, _) => recChild(expr)
case Inlined(_, _, expr) => recChild(expr)
case Block(_, expr) => recChild(expr)
case _ => constToLiteral(tree)

if isElideableExpr(rootTree) then
rec(rootTree)
else
constToLiteral(rootTree)

private[inlines] def newSym(name: Name, flags: FlagSet, info: Type, span: Span)(using Context): Symbol =
newSymbol(ctx.owner, name, flags, info, coord = span)
end Inliner
Expand Down Expand Up @@ -897,7 +918,7 @@ class Inliner(val call: tpd.Tree)(using Context):
//if the projection leads to a typed tree then we stop reduction
resNoReduce
else
val res = constToLiteral(reducedProjection)
val res = inlinedConstToLiteral(reducedProjection)
if resNoReduce ne res then
typed(res, pt) // redo typecheck if reduction changed something
else if res.symbol.isInlineMethod then
Expand Down Expand Up @@ -928,19 +949,19 @@ class Inliner(val call: tpd.Tree)(using Context):
override def typedValDef(vdef: untpd.ValDef, sym: Symbol)(using Context): Tree =
val vdef1 =
if sym.is(Inline) then
val rhs = typed(vdef.rhs)
val rhs = inlinedConstToLiteral(typed(vdef.rhs))
sym.info = rhs.tpe
untpd.cpy.ValDef(vdef)(vdef.name, untpd.TypeTree(rhs.tpe), untpd.TypedSplice(rhs))
else vdef
super.typedValDef(vdef1, sym)

override def typedApply(tree: untpd.Apply, pt: Type)(using Context): Tree =
val locked = ctx.typerState.ownedVars
specializeEq(inlineIfNeeded(constToLiteral(BetaReduce(super.typedApply(tree, pt))), pt, locked))
specializeEq(inlineIfNeeded(inlinedConstToLiteral(BetaReduce(super.typedApply(tree, pt))), pt, locked))

override def typedTypeApply(tree: untpd.TypeApply, pt: Type)(using Context): Tree =
val locked = ctx.typerState.ownedVars
val tree1 = inlineIfNeeded(constToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked)
val tree1 = inlineIfNeeded(inlinedConstToLiteral(BetaReduce(super.typedTypeApply(tree, pt))), pt, locked)
if tree1.symbol == defn.QuotedTypeModule_of then
ctx.compilationUnit.needsStaging = true
tree1
Expand Down Expand Up @@ -1021,8 +1042,8 @@ class Inliner(val call: tpd.Tree)(using Context):
case _ => rhs0
}
val rhs2 = rhs1 match {
case Typed(expr, tpt) if rhs1.span.isSynthetic => constToLiteral(expr)
case _ => constToLiteral(rhs1)
case Typed(expr, tpt) if rhs1.span.isSynthetic => inlinedConstToLiteral(expr)
case _ => inlinedConstToLiteral(rhs1)
}
val (usedBindings, rhs3) = dropUnusedDefs(caseBindings, rhs2)
val rhs = seq(usedBindings, rhs3)
Expand Down Expand Up @@ -1056,7 +1077,7 @@ class Inliner(val call: tpd.Tree)(using Context):
val meth = tree.symbol
if meth.isAllOf(DeferredInline) then
errorTree(tree, em"Deferred inline ${meth.showLocated} cannot be invoked")
else if Inlines.needsInlining(tree) then Inlines.inlineCall(simplify(tree, pt, locked))
else if Inlines.needsInlining(tree) then inlinedConstToLiteral(Inlines.inlineCall(simplify(tree, pt, locked)))
else tree

override def typedUnadapted(tree: untpd.Tree, pt: Type, locked: TypeVars)(using Context): Tree =
Expand Down
13 changes: 12 additions & 1 deletion docs/_docs/reference/metaprogramming/inline.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,19 @@ trait InlineConstants:
inline val myShort: Short

object Constants extends InlineConstants:
inline val myShort/*: Short(4)*/ = 4
inline val myShort/*: (4 : Short)*/ = 4
```
<!-- Test case: tests/pos/inline-val-short.scala -->

Inline values that are inside inline methods are only required to be constant _after inlining_. Therefore, the following is valid:

```scala
inline def double(inline x: Int): Int = x * 2
inline def eight: Int =
inline val res = double(4)
res
```
<!-- Test case: tests/pos/inline-val-in-inline-method.scala -->

## Transparent Inline Methods

Expand Down
13 changes: 13 additions & 0 deletions tests/pos/i24412.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
object test {
import scala.compiletime.erasedValue

inline def contains[T <: Tuple, E]: Boolean = inline erasedValue[T] match {
case _: EmptyTuple => false
case _: (_ *: tail) => contains[tail, E]
}
inline def check[T <: Tuple]: Unit = {
inline if contains[T, Long] && false then ???
}

check[(String, Double)]
}
6 changes: 6 additions & 0 deletions tests/pos/inline-val-in-inline-method.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Example in docs/_docs/reference/metaprogramming/inline.md

inline def double(inline x: Int): Int = x * 2
inline def eight: Int =
inline val res = double(4)
res
7 changes: 7 additions & 0 deletions tests/pos/inline-val-short.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Example in docs/_docs/reference/metaprogramming/inline.md

trait InlineConstants:
inline val myShort: Short

object Constants extends InlineConstants:
inline val myShort/*: (4 : Short)*/ = 4
12 changes: 12 additions & 0 deletions tests/run/i24420-inline-local-ref.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
inline def f(): Long =
1L

inline def g(): Long =
inline val x = f()
x

inline def h(): Long =
inline if g() > 0L then 1L else 0L

@main def Test: Unit =
assert(h() == 1L)
25 changes: 25 additions & 0 deletions tests/run/i24420-inline-val.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
inline def f1(): Long =
1L

inline def f2(): Long =
inline val x = f1() + 1L
x

inline def f3(): Long =
inline val x = f1()
x

inline def g1(): Boolean =
true

inline def g2(): Long =
inline if g1() then 1L else 2L

inline def g3(): Long =
inline if f1() > 0L then 1L else 2L

@main def Test: Unit =
assert(f2() == 2L)
assert(f3() == 1L)
assert(g2() == 1L)
assert(g3() == 1L)
12 changes: 12 additions & 0 deletions tests/run/i24420-transparent-inline-local-ref.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
transparent inline def f(): Long =
1L

transparent inline def g(): Long =
inline val x = f()
x

transparent inline def h(): Long =
inline if g() > 0L then 1L else 0L

@main def Test: Unit =
assert(h() == 1L)
Loading