diff --git a/livekit-agents/livekit/agents/llm/_provider_format/__init__.py b/livekit-agents/livekit/agents/llm/_provider_format/__init__.py
index de676f43e3..fed1b34785 100644
--- a/livekit-agents/livekit/agents/llm/_provider_format/__init__.py
+++ b/livekit-agents/livekit/agents/llm/_provider_format/__init__.py
@@ -1,3 +1,3 @@
-from . import anthropic, aws, google, mistralai, openai
+from . import anthropic, aws, google, mistralai, openai, vllm
 
-__all__ = ["openai", "google", "aws", "anthropic", "mistralai"]
+__all__ = ["openai", "google", "aws", "anthropic", "mistralai", "vllm"]
diff --git a/livekit-agents/livekit/agents/llm/_provider_format/vllm.py b/livekit-agents/livekit/agents/llm/_provider_format/vllm.py
new file mode 100644
index 0000000000..ad583540af
--- /dev/null
+++ b/livekit-agents/livekit/agents/llm/_provider_format/vllm.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from livekit.agents import llm
+
+from .openai import to_chat_ctx as openai_to_chat_ctx
+
+
+def to_chat_ctx(
+    chat_ctx: llm.ChatContext,
+    *,
+    inject_dummy_user_message: bool = True,
+    dummy_user_message: str = "",
+) -> tuple[list[dict], Literal[None]]:
+    messages, _ = openai_to_chat_ctx(chat_ctx, inject_dummy_user_message=inject_dummy_user_message)
+
+    if len(messages) == 1 and messages[0]["role"] == "system":
+        messages.append({"role": "user", "content": dummy_user_message})
+    if len(messages) > 1 and messages[0]["role"] == "system" and messages[1]["role"] == "assistant":
+        messages.insert(1, {"role": "user", "content": dummy_user_message})
+    collated: list[dict] = []
+    for msg in messages:
+        if len(collated) > 0 and collated[-1]["role"] == msg["role"]:
+            collated[-1]["content"] += " " + msg["content"]
+        else:
+            collated.append(msg)
+    return collated, None
diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py
index 9a4a940327..054c3b66f7 100644
--- a/livekit-agents/livekit/agents/llm/chat_context.py
+++ b/livekit-agents/livekit/agents/llm/chat_context.py
@@ -407,12 +407,17 @@ def to_provider_format(
         self, format: Literal["mistralai"], *, inject_dummy_user_message: bool = True
     ) -> tuple[list[dict], Literal[None]]: ...
 
+    @overload
+    def to_provider_format(
+        self, format: Literal["vllm"], *, inject_dummy_user_message: bool = True
+    ) -> tuple[list[dict], Literal[None]]: ...
+
     @overload
     def to_provider_format(self, format: str, **kwargs: Any) -> tuple[list[dict], Any]: ...
 
     def to_provider_format(
         self,
-        format: Literal["openai", "google", "aws", "anthropic", "mistralai"] | str,
+        format: Literal["openai", "google", "aws", "anthropic", "mistralai", "vllm"] | str,
         *,
         inject_dummy_user_message: bool = True,
         **kwargs: Any,
@@ -437,6 +442,8 @@ def to_provider_format(
             return _provider_format.anthropic.to_chat_ctx(self, **kwargs)
         elif format == "mistralai":
             return _provider_format.mistralai.to_chat_ctx(self, **kwargs)
+        elif format == "vllm":
+            return _provider_format.vllm.to_chat_ctx(self, **kwargs)
         else:
             raise ValueError(f"Unsupported provider format: {format}")
 