Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ case class ApolloSourceGenerator(
jsonCodeGen: JsonCodeGen
) extends Generator[List[Stat]] {

private val typesImport: Import = q"import types._"

/**
* Generates only the interfaces (fragments) that appear in the given
* document.
Expand All @@ -46,6 +48,13 @@ case class ApolloSourceGenerator(
.map(generateInterface(_, isSealed = false))
)

def generateFragments(document: TypedDocument.Api): Result[List[Stat]] =
Right(
jsonCodeGen.imports ++ additionalImports ++ List(typesImport, q"""object fragments {
..${document.fragments.flatMap(generateFragment)}
}""")
)

/**
* Generates only the types that appear in the given
* document.
Expand Down Expand Up @@ -80,13 +89,17 @@ case class ApolloSourceGenerator(
val inputParams = generateFieldParams(operation.variables, List.empty)
val dataParams =
generateFieldParams(operation.selection.fields, List.empty)

val data =
operation.selection.fields.flatMap(selectionStats(_, List.empty))

// render the document into the query object.
// replacing single $ with $$ for escaping
val escapedDocumentString =
operation.original.renderPretty.replaceAll("\\$", "\\$\\$")
operation.original
.copy(selections = operation.original.selections.map(removeCodeGen))
.renderPretty
.replaceAll("\\$", "\\$\\$")

// add the fragments to the query as well
val escapedFragmentString = Option(document.original.fragments)
Expand Down Expand Up @@ -133,24 +146,32 @@ case class ApolloSourceGenerator(
..$data
}"""
}
val types = document.types.flatMap(generateType)
val objectName = fileName.replaceAll("\\.graphql$|\\.gql$", "")

Right(
additionalImports ++
jsonCodeGen.imports ++
List(q"import sangria.macros._", q"import types._", q"""
List(q"import sangria.macros._", typesImport, q"""
object ${Term.Name(objectName)} {
..$operations
}
""")
)
}

private def selectionStats(field: TypedDocument.Field, typeQualifiers: List[String]): List[Stat] =
/**
*
* @param field field to generated code for
* @param typeQualifiers for nested case class structures in companion objects
* @return
*/
private def selectionStats(field: TypedDocument.Field, typeQualifiers: List[String]): List[Stat] = {
field match {
case TypedDocument.Field(name, tpe, _, _, _, Some(codeGen)) => List.empty
case TypedDocument.Field(name, tpe, _, _, Some(fragment), _) => List.empty

// render enumerations (union types)
case TypedDocument.Field(name, _, None, unionTypes) if unionTypes.nonEmpty =>
case TypedDocument.Field(name, _, None, unionTypes, _, _) if unionTypes.nonEmpty =>
// create the union types

val unionName = Type.Name(name.capitalize)
Expand Down Expand Up @@ -229,7 +250,7 @@ case class ApolloSourceGenerator(
)

// render a nested case class for a deeper selection
case TypedDocument.Field(name, tpe, Some(fieldSelection), _) =>
case TypedDocument.Field(name, tpe, Some(fieldSelection), _, _, _) =>
// Recursive call - create more case classes

val fieldName = Type.Name(name.capitalize)
Expand All @@ -253,15 +274,28 @@ case class ApolloSourceGenerator(
) ++ Option(innerStats).filter(_.nonEmpty).map { stats =>
q"object $termName { ..$stats }"
}
case TypedDocument.Field(_, _, _, _) =>
case TypedDocument.Field(_, _, _, _, _, _) =>
// scalar types, e.g. String, Option, List
List.empty
}
}

private def removeCodeGen(selection: sangria.ast.Selection): sangria.ast.Selection = selection match {
case f: sangria.ast.Field =>
f.copy(
selections = f.selections.map(removeCodeGen),
directives = f.directives.filter(_.name == "codeGen")
)
case s => s
}

private def generateFieldParams(
fields: List[TypedDocument.Field],
typeQualifiers: List[String]
): List[Term.Param] =
/**
*
* @param fields generate case class class members / values
* @param typeQualifiers for nested case class structures in companion objects
* @return
*/
private def generateFieldParams(fields: List[TypedDocument.Field], typeQualifiers: List[String]): List[Term.Param] =
fields.map { field =>
val tpe = parameterFieldType(field, typeQualifiers)
termParam(field.name, tpe)
Expand All @@ -273,18 +307,27 @@ case class ApolloSourceGenerator(
/**
* Turns a Type
* @param field
* @param typeQualifiers for nested case class structures in companion objects
* @return
*/
private def parameterFieldType(field: TypedDocument.Field, typeQualifiers: List[String]): Type =
generateFieldType(field) { tpe =>
if (field.isObjectLike || field.isUnion) {
// prepend the type qualifier for nested object/case class structures
ScalametaUtils.typeRefOf(typeQualifiers, field.name.capitalize)
} else {
// this branch handles non-enum or case class types, which means we don't need the
// typeQualifiers here.
Type.Name(tpe.namedType.name)
// custom directive has highest precedence
val codeGenType = field.codeGen.map(codeGen => Type.Name(codeGen.useType))
// use the fragment name as a type!
val fragmentType = field.fragment.map(fragment => Type.Name(fragment.name))

codeGenType.orElse(fragmentType).getOrElse {
if (field.isObjectLike || field.isUnion) {
// prepend the type qualifier for nested object/case class structures
ScalametaUtils.typeRefOf(typeQualifiers, field.name.capitalize)
} else {
// this branch handles non-enum or case class types, which means we don't need the
// typeQualifiers here.
Type.Name(tpe.namedType.name)
}
}

}

/**
Expand Down Expand Up @@ -364,6 +407,9 @@ case class ApolloSourceGenerator(
List[Stat](q"case class $className(..$params) extends $template") ++ objectStats
}

private def generateFragment(fragment: TypedDocument.Fragment): List[Stat] =
selectionStats(fragment.field, List.empty)

/**
* Generates the general types for this document.
*
Expand Down Expand Up @@ -408,6 +454,9 @@ case class ApolloSourceGenerator(
q"sealed trait $unionName",
q"object $objectName { ..$unionValues }"
)
// fragments are generated separately and thus are ignored here
case _: TypedDocument.Fragment =>
List.empty
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import sbt.Logger
* @param log output log
*/
case class CodeGenContext(
schema: Schema[_, _],
schema: Schema[Any, Any],
targetDirectory: File,
graphQLFiles: Seq[File],
packageName: String,
Expand Down
27 changes: 23 additions & 4 deletions src/main/scala/rocks/muki/graphql/codegen/DocumentLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,31 @@ import scala.io.Source

object DocumentLoader {

private val codeGenUseTypeArg: Argument[Option[String]] = Argument(
"useType",
OptionInputType(StringType),
"Specify another type to be used"
)

private val codeGenDirective = Directive(
name = "codeGen",
description = Some("Directs the executor to include this fragment definition only when the `if` argument is true."),
arguments = codeGenUseTypeArg :: Nil,
locations = Set(
DirectiveLocation.Field
)
)

/**
* Loads and parses all files and merge them into a single document
* @param schema used to validate parsed files
* @param files the files that should be loaded
* @return
*/
def merged(schema: Schema[_, _], files: List[File]): Result[Document] =
def merged(schema: Schema[Any, Any], files: List[File]): Result[Document] =
/*_*/
files
.traverse(file => single(schema, file))
.traverse(file => single(withDirectives(schema), file))
.map(documents => documents.combineAll)
/*_*/

Expand All @@ -48,17 +63,21 @@ object DocumentLoader {
* @param file
* @return
*/
def single(schema: Schema[_, _], file: File): Result[Document] =
def single(schema: Schema[Any, Any], file: File): Result[Document] =
for {
document <- parseDocument(file)
violations = QueryValidator.default.validateQuery(schema, document)
violations = QueryValidator.default.validateQuery(withDirectives(schema), document)
_ <- Either.cond(
violations.isEmpty,
document,
Failure(s"Invalid query in ${file.getAbsolutePath}:\n${violations.map(_.errorMessage).mkString(", ")}")
)
} yield document

private def withDirectives(schema: Schema[Any, Any]): Schema[Any, Any] = schema.copy(
directives = schema.directives :+ codeGenDirective
)

private def parseSchema(file: File): Result[Schema[_, _]] =
for {
document <- parseDocument(file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ case class ScalametaGenerator(moduleName: Term.Name, emitInterfaces: Boolean = f
def generateSelectionStats(prefix: String)(selection: TypedDocument.Selection): List[Stat] =
selection.fields.flatMap {
// render enumerations (union types)
case TypedDocument.Field(name, _, None, unionTypes) if unionTypes.nonEmpty =>
case TypedDocument.Field(name, _, None, unionTypes, _, _) if unionTypes.nonEmpty =>
val unionName = Type.Name(name.capitalize)
val objectName = Term.Name(unionName.value)
val template = generateTemplate(List(unionName.value), prefix)
Expand All @@ -121,7 +121,7 @@ case class ScalametaGenerator(moduleName: Term.Name, emitInterfaces: Boolean = f
)

// render a nested case class for a deeper selection
case TypedDocument.Field(name, _, Some(selection), _) =>
case TypedDocument.Field(name, _, Some(selection), _, _, _) =>
val stats =
generateSelectionStats(prefix + name.capitalize + ".")(selection)
val params =
Expand All @@ -142,7 +142,7 @@ case class ScalametaGenerator(moduleName: Term.Name, emitInterfaces: Boolean = f
}
.toList

case TypedDocument.Field(_, _, _, _) =>
case TypedDocument.Field(_, _, _, _, _, _) =>
List.empty
}

Expand Down Expand Up @@ -235,6 +235,8 @@ case class ScalametaGenerator(moduleName: Term.Name, emitInterfaces: Boolean = f
q"sealed trait $unionName",
q"object $objectName { ..$unionValues }"
)

case _ :TypedDocument.Fragment => List.empty
}

}
Expand Down
Loading