Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 43 additions & 27 deletions ai-core/src/main/scala/wvlet/ai/core/design/Binder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,26 @@
package wvlet.ai.core.design

import wvlet.ai.core.log.LogSupport
import wvlet.ai.core.surface.Surface
import wvlet.ai.core.typeshape.TypeShape
import LifeCycleHookType.*
import wvlet.ai.core.util.{LazyF0, SourceCode}

object Binder:
sealed trait Binding extends Serializable:
def forSingleton: Boolean = false
def from: Surface
def from: TypeShape
def sourceCode: SourceCode

case class ClassBinding(from: Surface, to: Surface, sourceCode: SourceCode) extends Binding:
case class ClassBinding(from: TypeShape, to: TypeShape, sourceCode: SourceCode) extends Binding:
if from == to then
throw DesignException.cyclicDependency(List(to), sourceCode)

case class SingletonBinding(from: Surface, to: Surface, isEager: Boolean, sourceCode: SourceCode)
extends Binding:
case class SingletonBinding(
from: TypeShape,
to: TypeShape,
isEager: Boolean,
sourceCode: SourceCode
) extends Binding:
override def forSingleton: Boolean = true

case class ProviderBinding(
Expand All @@ -39,7 +43,7 @@ object Binder:
sourceCode: SourceCode
) extends Binding:
assert(!eager || (eager && provideSingleton))
def from: Surface = factory.from
def from: TypeShape = factory.from
override def forSingleton: Boolean = provideSingleton

private val objectId = new Object().hashCode()
Expand All @@ -54,7 +58,7 @@ object Binder:
case _ =>
false

case class DependencyFactory(from: Surface, dependencyTypes: Seq[Surface], factory: Any):
case class DependencyFactory(from: TypeShape, dependencyTypes: Seq[TypeShape], factory: Any):
override def toString: String =
val deps =
if dependencyTypes.isEmpty then
Expand Down Expand Up @@ -96,7 +100,7 @@ import Binder.*

/**
*/
class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCode)
class Binder[A](val design: Design, val from: TypeShape, val sourceCode: SourceCode)
extends LogSupport:
/**
* Bind the type to a given instance. The instance will be instantiated as an eager singleton
Expand Down Expand Up @@ -130,22 +134,22 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
* @tparam B
*/
inline def to[B <: A]: DesignWithContext[B] =
val to = Surface.of[B]
val to = TypeShape.of[B]
if from == to then
warn("Binding to the same type is not allowed: " + to.toString)
throw DesignException.cyclicDependency(List(to), SourceCode())
design.addBinding[B](SingletonBinding(from, to, false, sourceCode))

inline def toEagerSingletonOf[B <: A]: DesignWithContext[B] =
val to = Surface.of[B]
val to = TypeShape.of[B]
if from == to then
warn("Binding to the same type is not allowed: " + to.toString)
throw DesignException.cyclicDependency(List(to), SourceCode())
design.addBinding[B](SingletonBinding(from, to, true, sourceCode))

inline def toProvider[D1](factory: D1 => A): DesignWithContext[A] = design.addBinding[A](
ProviderBinding(
DependencyFactory(from, Seq(Surface.of[D1]), factory),
DependencyFactory(from, Seq(TypeShape.of[D1]), factory),
true,
false,
SourceCode()
Expand All @@ -156,7 +160,7 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
A
](
ProviderBinding(
DependencyFactory(from, Seq(Surface.of[D1], Surface.of[D2]), factory),
DependencyFactory(from, Seq(TypeShape.of[D1], TypeShape.of[D2]), factory),
true,
false,
SourceCode()
Expand All @@ -166,7 +170,7 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
inline def toProvider[D1, D2, D3](factory: (D1, D2, D3) => A): DesignWithContext[A] = design
.addBinding[A](
ProviderBinding(
DependencyFactory(from, Seq(Surface.of[D1], Surface.of[D2], Surface.of[D3]), factory),
DependencyFactory(from, Seq(TypeShape.of[D1], TypeShape.of[D2], TypeShape.of[D3]), factory),
true,
false,
SourceCode()
Expand All @@ -178,7 +182,7 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
ProviderBinding(
DependencyFactory(
from,
Seq(Surface.of[D1], Surface.of[D2], Surface.of[D3], Surface.of[D4]),
Seq(TypeShape.of[D1], TypeShape.of[D2], TypeShape.of[D3], TypeShape.of[D4]),
factory
),
true,
Expand All @@ -193,7 +197,13 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
ProviderBinding(
DependencyFactory(
from,
Seq(Surface.of[D1], Surface.of[D2], Surface.of[D3], Surface.of[D4], Surface.of[D5]),
Seq(
TypeShape.of[D1],
TypeShape.of[D2],
TypeShape.of[D3],
TypeShape.of[D4],
TypeShape.of[D5]
),
factory
),
true,
Expand All @@ -205,7 +215,7 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
inline def toEagerSingletonProvider[D1](factory: D1 => A): DesignWithContext[A] = design
.addBinding[A](
ProviderBinding(
DependencyFactory(from, Seq(Surface.of[D1]), factory),
DependencyFactory(from, Seq(TypeShape.of[D1]), factory),
true,
true,
SourceCode()
Expand All @@ -215,7 +225,7 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
inline def toEagerSingletonProvider[D1, D2](factory: (D1, D2) => A): DesignWithContext[A] = design
.addBinding[A](
ProviderBinding(
DependencyFactory(from, Seq(Surface.of[D1], Surface.of[D2]), factory),
DependencyFactory(from, Seq(TypeShape.of[D1], TypeShape.of[D2]), factory),
true,
true,
SourceCode()
Expand All @@ -226,7 +236,7 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
factory: (D1, D2, D3) => A
): DesignWithContext[A] = design.addBinding[A](
ProviderBinding(
DependencyFactory(from, Seq(Surface.of[D1], Surface.of[D2], Surface.of[D3]), factory),
DependencyFactory(from, Seq(TypeShape.of[D1], TypeShape.of[D2], TypeShape.of[D3]), factory),
true,
true,
SourceCode()
Expand All @@ -239,7 +249,7 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
ProviderBinding(
DependencyFactory(
from,
Seq(Surface.of[D1], Surface.of[D2], Surface.of[D3], Surface.of[D4]),
Seq(TypeShape.of[D1], TypeShape.of[D2], TypeShape.of[D3], TypeShape.of[D4]),
factory
),
true,
Expand All @@ -254,7 +264,13 @@ class Binder[A](val design: Design, val from: Surface, val sourceCode: SourceCod
ProviderBinding(
DependencyFactory(
from,
Seq(Surface.of[D1], Surface.of[D2], Surface.of[D3], Surface.of[D4], Surface.of[D5]),
Seq(
TypeShape.of[D1],
TypeShape.of[D2],
TypeShape.of[D3],
TypeShape.of[D4],
TypeShape.of[D5]
),
factory
),
true,
Expand Down Expand Up @@ -293,28 +309,28 @@ end Binder
* DesignWithContext[A] is a wrapper of Design class for chaining lifecycle hooks for the same type
* A. This can be safely cast to just Design
*/
class DesignWithContext[A](design: Design, lastSurface: Surface)
class DesignWithContext[A](design: Design, lastTypeShape: TypeShape)
extends Design(design.designOptions, design.binding, design.hooks):
def onInit(body: A => Unit): DesignWithContext[A] = design.withLifeCycleHook[A](
LifeCycleHookDesign(ON_INIT, lastSurface, body.asInstanceOf[Any => Unit])
LifeCycleHookDesign(ON_INIT, lastTypeShape, body.asInstanceOf[Any => Unit])
)

def onInject(body: A => Unit): DesignWithContext[A] = design.withLifeCycleHook[A](
LifeCycleHookDesign(ON_INJECT, lastSurface, body.asInstanceOf[Any => Unit])
LifeCycleHookDesign(ON_INJECT, lastTypeShape, body.asInstanceOf[Any => Unit])
)

def onStart(body: A => Unit): DesignWithContext[A] = design.withLifeCycleHook[A](
LifeCycleHookDesign(ON_START, lastSurface, body.asInstanceOf[Any => Unit])
LifeCycleHookDesign(ON_START, lastTypeShape, body.asInstanceOf[Any => Unit])
)

def afterStart(body: A => Unit): DesignWithContext[A] = design.withLifeCycleHook[A](
LifeCycleHookDesign(AFTER_START, lastSurface, body.asInstanceOf[Any => Unit])
LifeCycleHookDesign(AFTER_START, lastTypeShape, body.asInstanceOf[Any => Unit])
)

def beforeShutdown(body: A => Unit): DesignWithContext[A] = design.withLifeCycleHook[A](
LifeCycleHookDesign(BEFORE_SHUTDOWN, lastSurface, body.asInstanceOf[Any => Unit])
LifeCycleHookDesign(BEFORE_SHUTDOWN, lastTypeShape, body.asInstanceOf[Any => Unit])
)

def onShutdown(body: A => Unit): DesignWithContext[A] = design.withLifeCycleHook[A](
LifeCycleHookDesign(ON_SHUTDOWN, lastSurface, body.asInstanceOf[Any => Unit])
LifeCycleHookDesign(ON_SHUTDOWN, lastTypeShape, body.asInstanceOf[Any => Unit])
)
27 changes: 14 additions & 13 deletions ai-core/src/main/scala/wvlet/ai/core/design/Design.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
package wvlet.ai.core.design

import wvlet.ai.core.log.LogSupport
import wvlet.ai.core.surface.Surface
import wvlet.ai.core.typeshape.TypeShape
import Binder.Binding
import DesignOptions.*
import wvlet.ai.core.util.SourceCode

/**
* Immutable airframe design.
*
* Design instance does not hold any duplicate bindings for the same Surface.
* Design instance does not hold any duplicate bindings for the same TypeShape.
*/
class Design(
private[design] val designOptions: DesignOptions,
Expand Down Expand Up @@ -51,10 +51,10 @@ class Design(
state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)

private inline def bind[A](using sourceCode: SourceCode): Binder[A] =
new Binder(this, Surface.of[A], sourceCode).asInstanceOf[Binder[A]]
new Binder(this, TypeShape.of[A], sourceCode).asInstanceOf[Binder[A]]

inline def remove[A]: Design =
val target = Surface.of[A]
val target = TypeShape.of[A]
new Design(designOptions, binding.filterNot(_.from == target), hooks)

inline def bindInstance[A](obj: A)(using sourceCode: SourceCode): DesignWithContext[A] = bind[A]
Expand Down Expand Up @@ -115,21 +115,21 @@ class Design(
* @return
*/
def minimize: Design =
var seenBindingSurrace = Set.empty[Surface]
var minimizedBindingList = List.empty[Binding]
var seenBindingTypeShapes = Set.empty[TypeShape]
var minimizedBindingList = List.empty[Binding]

// Later binding has higher precedence, so traverse bindings from the tail
for b <- binding.reverseIterator do
val surface = b.from
if !seenBindingSurrace.contains(surface) then
if !seenBindingTypeShapes.contains(surface) then
minimizedBindingList = b :: minimizedBindingList
seenBindingSurrace += surface
seenBindingTypeShapes += surface

var seenHooks = Set.empty[(LifeCycleHookType, Surface)]
var seenHooks = Set.empty[(LifeCycleHookType, TypeShape)]
var minimizedHooks = List.empty[LifeCycleHookDesign]
// Override hooks for the same surface and event type
for h <- hooks.reverseIterator do
val key: (LifeCycleHookType, Surface) = (h.lifeCycleHookType, h.surface)
val key: (LifeCycleHookType, TypeShape) = (h.lifeCycleHookType, h.typeShape)
if !seenHooks.contains(key) then
minimizedHooks = h :: minimizedHooks
seenHooks += key
Expand All @@ -141,7 +141,7 @@ class Design(

def +(other: Design): Design = add(other)

def bindSurface(t: Surface)(using sourceCode: SourceCode): Binder[Any] =
def bindTypeShape(t: TypeShape)(using sourceCode: SourceCode): Binder[Any] =
trace(s"bind($t) ${t.isAlias}")
val b = new Binder[Any](this, t, sourceCode)
b
Expand All @@ -154,10 +154,11 @@ class Design(
trace(s"withLifeCycleHook: ${hook}")
new DesignWithContext[A](
new Design(designOptions, binding, hooks = hooks :+ hook),
hook.surface
hook.typeShape
)

def remove(t: Surface): Design = new Design(designOptions, binding.filterNot(_.from == t), hooks)
def remove(t: TypeShape): Design =
new Design(designOptions, binding.filterNot(_.from == t), hooks)

def withLifeCycleLogging: Design = new Design(designOptions.withLifeCycleLogging, binding, hooks)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package wvlet.ai.core.design

import wvlet.ai.core.surface.Surface
import wvlet.ai.core.typeshape.TypeShape
import wvlet.ai.core.util.SourceCode

enum DesignErrorCode:
Expand All @@ -33,13 +33,13 @@ case class DesignException(code: DesignErrorCode, message: String, cause: Throwa
override def getMessage: String = s"[${code}] ${message}"

object DesignException:
def cyclicDependency(deps: List[Surface], sourceCode: SourceCode): DesignException =
def cyclicDependency(deps: List[TypeShape], sourceCode: SourceCode): DesignException =
DesignException(
DesignErrorCode.CYCLIC_DEPENDENCY,
s"${deps.reverse.mkString(" -> ")} at ${sourceCode}"
)

def missingDependency(stack: List[Surface], sourceCode: SourceCode): DesignException =
def missingDependency(stack: List[TypeShape], sourceCode: SourceCode): DesignException =
DesignException(
DesignErrorCode.MISSING_DEPENDENCY,
s"Binding for ${stack.head} at ${sourceCode} is not found: ${stack.mkString(" <- ")}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
package wvlet.ai.core.design

import wvlet.ai.core.log.LogSupport
import wvlet.ai.core.surface.Surface
import wvlet.ai.core.typeshape.TypeShape

/**
* Design configs
Expand Down Expand Up @@ -78,8 +78,8 @@ object DesignOptions:

case class LifeCycleHookDesign(
lifeCycleHookType: LifeCycleHookType,
surface: Surface,
typeShape: TypeShape,
hook: Any => Unit
):
// Override toString to protect calling the hook accidentally
override def toString: String = s"LifeCycleHookDesign[${lifeCycleHookType}](${surface})"
override def toString: String = s"LifeCycleHookDesign[${lifeCycleHookType}](${typeShape})"
Loading
Loading