Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ lazy val metals = project
"io.modelcontextprotocol.sdk" % "mcp" % "0.12.1",
"com.fasterxml.jackson.core" % "jackson-databind" % "2.20.0",
"io.undertow" % "undertow-servlet" % "2.3.12.Final",
// For Twirl
"org.playframework.twirl" %% "twirl-compiler" % "2.0.9",
),
buildInfoPackage := "scala.meta.internal.metals",
buildInfoKeys := Seq[BuildInfoKey](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ final class BuildTargets private (
def inverseSources(
source: AbsolutePath
): Option[BuildTargetIdentifier] = {

val buildTargets = sourceBuildTargets(source)
val orSbtBuildTarget =
buildTargets.getOrElse(sbtBuildScalaTarget(source).toIterable).toSeq
Expand Down
74 changes: 42 additions & 32 deletions metals/src/main/scala/scala/meta/internal/metals/Compilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,11 @@ class Compilers(
params: SemanticTokensParams,
token: CancelToken,
): Future[SemanticTokens] = {
val path = params.getTextDocument.getUri.toAbsolutePath
val emptyTokens = Collections.emptyList[Integer]();
if (!userConfig().enableSemanticHighlighting) {
if (!userConfig().enableSemanticHighlighting || path.isTwirlTemplate) {
Future { new SemanticTokens(emptyTokens) }
} else {
val path = params.getTextDocument.getUri.toAbsolutePath
loadCompiler(path)
.map { compiler =>
val (input, _, adjust) =
Expand Down Expand Up @@ -1002,21 +1002,27 @@ class Compilers(
codeActionId: String,
codeActionPayload: Option[Object],
): Future[ju.List[TextEdit]] = {
withPCAndAdjustLsp(params) { (pc, pos, adjust) =>
pc.codeAction(
CompilerOffsetParamsUtils.fromPos(
pos,
token,
outlineFilesProvider.getOutlineFiles(pc.buildTargetId()),
),
codeActionId,
codeActionPayload.asJava,
).asScala
.map { edits =>
adjust.adjustTextEdits(edits)
}
// disable code actions completely for Twirl templates
val isTwirl = params.getTextDocument.getUri.isTwirlTemplate
if (isTwirl) {
Future.successful(Nil.asJava)
} else {
withPCAndAdjustLsp(params) { (pc, pos, adjust) =>
pc.codeAction(
CompilerOffsetParamsUtils.fromPos(
pos,
token,
outlineFilesProvider.getOutlineFiles(pc.buildTargetId()),
),
codeActionId,
codeActionPayload.asJava,
).asScala
.map { edits =>
adjust.adjustTextEdits(edits)
}
}.getOrElse(Future.successful(Nil.asJava))
}
}.getOrElse(Future.successful(Nil.asJava))
}

def supportedCodeActions(path: AbsolutePath): ju.List[String] = {
loadCompiler(path).map { pc =>
Expand Down Expand Up @@ -1121,33 +1127,37 @@ class Compilers(
findTypeDef: Boolean,
): Future[DefinitionResult] =
withPCAndAdjustLsp(params) { (pc, pos, adjust) =>
val params = CompilerOffsetParamsUtils.fromPos(
val paramsWithOutline = CompilerOffsetParamsUtils.fromPos(
pos,
token,
outlineFilesProvider.getOutlineFiles(pc.buildTargetId()),
)

val defResult =
if (findTypeDef) pc.typeDefinition(params)
else
pc.definition(CompilerOffsetParamsUtils.fromPos(pos, token))
if (findTypeDef) pc.typeDefinition(paramsWithOutline)
else pc.definition(CompilerOffsetParamsUtils.fromPos(pos, token))

defResult.asScala
.map { c =>
adjust.adjustLocations(c.locations())
val definitionPaths = c
.locations()
.map { loc =>
loc.getUri().toAbsolutePath
}
val locations = c.locations()
val originalUri = paramsWithOutline.uri

val adjustable = locations.asScala.filter(loc =>
originalUri.toString == loc.getUri()
)
adjust.adjustLocations(adjustable.asJava)

val definitionPaths = locations
.map(_.getUri().toAbsolutePath)
.asScala
.toSet

val definitionPath = if (definitionPaths.size == 1) {
Some(definitionPaths.head)
} else {
None
}
val definitionPath =
if (definitionPaths.size == 1) Some(definitionPaths.head)
else None

DefinitionResult(
c.locations(),
locations,
c.symbol(),
definitionPath,
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ final case class SourceMapper(
path.isWorksheet && ScalaVersions.isScala3Version(scalaVersion)
) {
WorksheetProvider.worksheetScala3Adjustments(input)
} else if (path.isTwirlTemplate) {
Some(TwirlAdjustments(input, scalaVersion))
} else None

forScripts.getOrElse(default)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package scala.meta.internal.metals

import java.io.File

import scala.io.Codec

import scala.meta.inputs.Input.VirtualFile

import org.eclipse.lsp4j.Position
import play.twirl.compiler.GeneratedSourceVirtual
import play.twirl.compiler.TwirlCompiler

/**
* A utility object for adjusting and mapping positions between Twirl templates and their compiled Scala output.
*
* This is particularly useful for hover, completions and goto definition features between user-authored `.scala.html`
* templates and their generated `.template.scala` counterparts.
*/

object TwirlAdjustments {

def isPlayProject(implicit file: VirtualFile): Boolean =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AS an improvement, we could check in sbt project for play specific libraries. But this seems good enough for now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That does make sense.

file.path.contains("views/")

/**
* Probably not the best solution, as ideally one should also be able to take configuration from
* the client's build.sbt files as TwirlKeys.templateImports. But this works nonetheless.
*/
private def playImports(
originalImports: Seq[String]
)(implicit file: VirtualFile): Seq[String] = {
if (isPlayProject)
originalImports ++ Seq(
"models._", "controllers._", "play.api.i18n._", "views.html._",
"play.api.templates.PlayMagic._", "play.api.mvc._", "play.api.data._",
)
else originalImports
}

private def playDI(implicit file: VirtualFile): Seq[String] = {
if (isPlayProject) Seq("@javax.inject.Inject()")
else Nil
}

/**
* Compiles an in-memory Twirl template into a compiled representation using the Twirl compiler.
*
* This method uses a virtual file and a resolved Scala version to invoke `TwirlCompiler.compileVirtual`.
*
* @param the virtual file representing the template content
* @param the full Scala version string (used to resolve compatibility with Twirl)
* @return the result of compiling the Twirl template
*/
def getCompiledString(implicit
file: VirtualFile,
scalaVersion: String,
): GeneratedSourceVirtual =
TwirlCompiler
.compileVirtual(
content = file.value,
source = new File("foo/bar/example.scala.html"),
sourceDirectory = new File("foo/bar"),
resultType = "play.twirl.api.Html",
formatterType = "play.twirl.api.HtmlFormat.Appendable",
additionalImports =
playImports(TwirlCompiler.defaultImports(scalaVersion)),
constructorAnnotations = playDI,
codec = Codec(
scala.util.Properties.sourceEncoding
),
scalaVersion = Some(scalaVersion),
inclusiveDot = true,
)

/**
* Converts a character offset (index) in a string to an LSP `Position` (0 based - line number and character offset).
*
* @param The full text content. Can be either the Twirl Source or the Compiled Twirl File
* @param The character offset within the text (0-based).
* @return A `Position` object representing the line and column corresponding to the given index.
*/
private def getPositionFromIndex(text: String, index: Int): Position = {
val lines = text.substring(0, index).split('\n')
new Position(lines.length - 1, lines.last.length)
}

/**
* Converts an LSP `Position` (0 based - line number and character offset) into a character index.
*
* @param The full text content. Can be either the Twirl Source or the Compiled Twirl File
* @param The LSP `Position` to convert (line and character).
* @return The absolute character index in the string corresponding to the position.
*/
private def getIndexFromPosition(text: String, pos: Position): Int = {
val lines = text.split('\n')
lines
.take(pos.getLine)
.map(_.length + 1)
.sum + pos.getCharacter
}

val pattern = """(\d+)->(\d+)""".r

/**
* Extracts a positional mapping matrix from the compiled Twirl template.
*
* This method parses those mappings and builds a matrix of (original, generated) index pairs.
* The mapping is later used for position translation between source and compiled files.
*
* @param The compiled Twirl template content as a string
* @return An array of tuples representing (originalIndex, generatedIndex) pairs
*/
private def getMatrix(compiledTwirl: String): Array[(Int, Int)] = {
val number_matching =
pattern.findAllIn(compiledTwirl).toArray
val chars = number_matching.take(number_matching.length / 2)
chars.map { char =>
val parts = char.split("->")
val a = parts(0).toInt
val b = parts(1).toInt
(b, a)
}
}

/**
* Maps positions between original Twirl template and compiled Scala output.
*
* Returns a tuple of:
* - The compiled virtual file,
* - A function to map original Twirl positions -> compiled Scala positions,
* - An AdjustedLspData instance for reverse mapping compiled Scala -> Twirl positions.
*/
def apply(
twirlFile: VirtualFile,
rawScalaVersion: String,
): (VirtualFile, Position => Position, AdjustLspData) = {

val originalTwirl = twirlFile.value
val compiledSource = getCompiledString(twirlFile, rawScalaVersion)
val compiledTwirl = compiledSource.content
val newVirtualFile = twirlFile.copy(value = compiledTwirl)
val matrix: Array[(Int, Int)] = getMatrix(compiledTwirl)

/**
* Maps a Position in the original Twirl template to the corresponding
* Position in the compiled Scala output
*/
def mapPosition(originalPos: Position): Position = {
val originalIndex = getIndexFromPosition(originalTwirl, originalPos)
val idx = matrix.indexWhere(_._1 >= originalIndex)
val (origBase, genBase) = matrix(idx - 1)
val mappedIndex = genBase + (originalIndex - origBase)
getPositionFromIndex(compiledTwirl, mappedIndex)
}

/**
* Maps a Position in the compiled Scala output back to the original
*/
def reverseMapPosition(compiledPos: Position): Position = {
val compiledIndex = getIndexFromPosition(compiledTwirl, compiledPos)
val pos_tuple = compiledSource.mapPosition(compiledIndex)
getPositionFromIndex(originalTwirl, pos_tuple)
}

(
newVirtualFile,
mapPosition,
AdjustedLspData.create(reverseMapPosition),
)
}
}
11 changes: 7 additions & 4 deletions metals/src/main/scala/scala/meta/internal/parsing/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ final class Trees(
private val tokenized = TrieMap.empty[AbsolutePath, Tokens]

def get(path: AbsolutePath): Option[Tree] =
trees.get(path).orElse {
// Fallback to parse without caching result.
parse(path, scalaVersionSelector.getDialect(path)).flatMap(_.toOption)
if (path.isTwirlTemplate) None
else {
trees.get(path).orElse {
// Fallback to parse without caching result.
parse(path, scalaVersionSelector.getDialect(path)).flatMap(_.toOption)
}
}

def didClose(fileUri: AbsolutePath): Unit = {
Expand Down Expand Up @@ -171,7 +174,7 @@ final class Trees(
def didChange(path: AbsolutePath): List[Diagnostic] = {
val dialect = scalaVersionSelector.getDialect(path)
parse(path, dialect) match {
case Some(parsed) =>
case Some(parsed) if !path.isTwirlTemplate =>
parsed match {
case Parsed.Error(pos, message, _) =>
List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ trait CommonMtagsEnrichments {
def isWorksheet: Boolean =
doc.endsWith(".worksheet.sc")
def isScalaFilename: Boolean =
doc.isScala || isScalaScript || isSbt || isMill
doc.isScala || isScalaScript || isSbt || isMill || isTwirlTemplate
def isTwirlTemplate: Boolean =
doc.endsWith(".scala.html")
def isScalaOrJavaFilename: Boolean =
isScalaFilename || isJavaFilename
def isJavaFilename: Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ final class AutoImportsProvider(
context,
value
)

val nameEdit = new l.TextEdit(namePos, short)

if (short != name && shouldApplyNameEdit) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ trait Signatures { compiler: MetalsGlobal =>
val name =
if (isGroup) importNames.mkString("{", ", ", "}")
else importNames.mkString
s"${indent}import ${scope.fullname(owner)}.${name}"
val isTwirl =
if (pos.source.toString.endsWith("scala.html")) "\n@" else ""

s"${indent}${isTwirl}import ${scope.fullname(owner)}.${name}"
}
.mkString(topPadding, "\n", "\n")
val startPos = pos.withPoint(lineStart).focus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ trait ScalametaCommonEnrichments extends CommonMtagsEnrichments {
def isJavaFilename: Boolean = {
filename.isJavaFilename
}

def isTwirlTemplate: Boolean = {
filename.endsWith(".scala.html")
}

def isScalaFilename: Boolean = {
filename.isScalaFilename
}
Expand Down
Loading
Loading