43 changes: 0 additions & 43 deletions src/llama_stack/apis/inference/event_logger.py

This file was deleted.

182 changes: 6 additions & 176 deletions src/llama_stack/apis/inference/inference.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.

from collections.abc import AsyncIterator
-from enum import Enum
+from enum import Enum, StrEnum
from typing import (
Annotated,
Any,
@@ -15,28 +15,18 @@
)

from fastapi import Body
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
from typing_extensions import TypedDict

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
-from llama_stack.apis.common.responses import MetricResponseMixin, Order
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.responses import (
+    Order,
+)
from llama_stack.apis.common.tracing import telemetry_traceable
from llama_stack.apis.models import Model
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-)
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

-register_schema(ToolCall)
-register_schema(ToolDefinition)

-from enum import StrEnum


@json_schema_type
class GreedySamplingStrategy(BaseModel):
@@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel):
content: InterleavedContent


-@json_schema_type
-class CompletionMessage(BaseModel):
-    """A message containing the model's (assistant) response in a chat conversation.
-
-    :param role: Must be "assistant" to identify this as the model's response
-    :param content: The content of the model's response
-    :param stop_reason: Reason why the model stopped generating. Options are:
-        - `StopReason.end_of_turn`: The model finished generating the entire response.
-        - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
-        - `StopReason.out_of_tokens`: The model ran out of token budget.
-    :param tool_calls: List of tool calls. Each tool call is a ToolCall object.
-    """
-
-    role: Literal["assistant"] = "assistant"
-    content: InterleavedContent
-    stop_reason: StopReason
-    tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
-
-
-Message = Annotated[
-    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
-    Field(discriminator="role"),
-]
-register_schema(Message, name="Message")
-
-
-@json_schema_type
-class ToolResponse(BaseModel):
-    """Response from a tool invocation.
-
-    :param call_id: Unique identifier for the tool call this response is for
-    :param tool_name: Name of the tool that was invoked
-    :param content: The response content from the tool
-    :param metadata: (Optional) Additional metadata about the tool response
-    """
-
-    call_id: str
-    tool_name: BuiltinTool | str
-    content: InterleavedContent
-    metadata: dict[str, Any] | None = None
-
-    @field_validator("tool_name", mode="before")
-    @classmethod
-    def validate_field(cls, v):
-        if isinstance(v, str):
-            try:
-                return BuiltinTool(v)
-            except ValueError:
-                return v
-        return v


class ToolChoice(Enum):
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.

@@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
progress = "progress"


-@json_schema_type
-class ChatCompletionResponseEvent(BaseModel):
-    """An event during chat completion generation.
-
-    :param event_type: Type of the event
-    :param delta: Content generated since last event. This can be one or more tokens, or a tool call.
-    :param logprobs: Optional log probabilities for generated tokens
-    :param stop_reason: Optional reason why generation stopped, if complete
-    """
-
-    event_type: ChatCompletionResponseEventType
-    delta: ContentDelta
-    logprobs: list[TokenLogProbs] | None = None
-    stop_reason: StopReason | None = None


class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding.

@@ -357,34 +279,6 @@ class CompletionRequest(BaseModel):
logprobs: LogProbConfig | None = None


-@json_schema_type
-class CompletionResponse(MetricResponseMixin):
-    """Response from a completion request.
-
-    :param content: The generated completion text
-    :param stop_reason: Reason why generation stopped
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    content: str
-    stop_reason: StopReason
-    logprobs: list[TokenLogProbs] | None = None
-
-
-@json_schema_type
-class CompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed completion response.
-
-    :param delta: New content generated since last chunk. This can be one or more tokens.
-    :param stop_reason: Optional reason why generation stopped, if complete
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    delta: str
-    stop_reason: StopReason | None = None
-    logprobs: list[TokenLogProbs] | None = None


class SystemMessageBehavior(Enum):
"""Config for how to override the default system prompt.

@@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum):
replace = "replace"


-@json_schema_type
-class ToolConfig(BaseModel):
-    """Configuration for tool use.
-
-    :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
-    :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
-        - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
-        - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
-        - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
-    :param system_message_behavior: (Optional) Config for how to override the default system prompt.
-        - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
-        - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
-          '{{function_definitions}}' to indicate where the function definitions should be inserted.
-    """
-
-    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
-    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
-    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
-
-    def model_post_init(self, __context: Any) -> None:
-        if isinstance(self.tool_choice, str):
-            try:
-                self.tool_choice = ToolChoice[self.tool_choice]
-            except KeyError:
-                pass
-
-
-# This is an internally used class
-@json_schema_type
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: list[Message]
-    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
-
-    tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
-    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
-
-    response_format: ResponseFormat | None = None
-    stream: bool | None = False
-    logprobs: LogProbConfig | None = None
-
-
-@json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed chat completion response.
-
-    :param event: The event containing the new content
-    """
-
-    event: ChatCompletionResponseEvent
-
-
-@json_schema_type
-class ChatCompletionResponse(MetricResponseMixin):
-    """Response from a chat completion request.
-
-    :param completion_message: The complete response message
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    completion_message: CompletionMessage
-    logprobs: list[TokenLogProbs] | None = None


@json_schema_type
class EmbeddingsResponse(BaseModel):
"""Response containing generated embeddings.
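Note: the Message union removed above relied on Pydantic's discriminated-union pattern, where the literal role field tells Pydantic which concrete class to instantiate during validation. A minimal, self-contained sketch of that technique (the class names here are illustrative, not from this codebase):

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter

class UserMsg(BaseModel):
    role: Literal["user"] = "user"
    content: str

class AssistantMsg(BaseModel):
    role: Literal["assistant"] = "assistant"
    content: str

# Pydantic dispatches on the `role` field when validating against the union.
Msg = Annotated[UserMsg | AssistantMsg, Field(discriminator="role")]

msg = TypeAdapter(Msg).validate_python({"role": "assistant", "content": "hi"})
assert isinstance(msg, AssistantMsg)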
4 changes: 2 additions & 2 deletions src/llama_stack/core/routers/safety.py
@@ -6,7 +6,7 @@

from typing import Any

-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
@@ -52,7 +52,7 @@ async def unregister_shield(self, identifier: str) -> None:
async def run_shield(
self,
shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
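With this change, run_shield takes OpenAI-style message params instead of the legacy Message union. A hedged usage sketch, assuming OpenAIUserMessageParam is exported from llama_stack.apis.inference and that RunShieldResponse exposes an optional violation field; neither appears in this diff:

from llama_stack.apis.inference import OpenAIUserMessageParam  # assumed export
from llama_stack.apis.safety import Safety

async def is_safe(safety: Safety, shield_id: str, text: str) -> bool:
    # Wrap the user text in an OpenAI-style chat message, as the new signature expects.
    messages = [OpenAIUserMessageParam(content=text)]
    response = await safety.run_shield(
        shield_id=shield_id,
        messages=messages,
        params={},
    )
    # violation is None when no shield was triggered (assumed field on RunShieldResponse).
    return response.violation is None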
4 changes: 3 additions & 1 deletion src/llama_stack/models/llama/llama3/generation.py
@@ -26,8 +26,10 @@
)
from termcolor import cprint

+from llama_stack.models.llama.datatypes import ToolPromptFormat

from ..checkpoint import maybe_reshard_state_dict
-from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
+from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
7 changes: 2 additions & 5 deletions src/llama_stack/models/llama/llama3/interface.py
@@ -15,13 +15,10 @@

from termcolor import colored

+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat

from ..datatypes import (
-    BuiltinTool,
    RawMessage,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat
(file path not shown in this view)
@@ -15,7 +15,7 @@
from datetime import datetime
from typing import Any

-from llama_stack.apis.inference import (
+from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
)
3 changes: 2 additions & 1 deletion src/llama_stack/models/llama/llama3/tool_utils.py
@@ -8,8 +8,9 @@
import re

from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat

-from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
+from ..datatypes import RecursiveType

logger = get_logger(name=__name__, category="models::llama")

(file path not shown in this view)
@@ -13,7 +13,7 @@

import textwrap

-from llama_stack.apis.inference import ToolDefinition
+from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
PromptTemplate,
PromptTemplateGeneratorBase,