37 changes: 37 additions & 0 deletions aisuite/framework/chat_completion_stream_response.py
@@ -0,0 +1,37 @@
# aisuite/framework/chat_completion_stream_response.py

from typing import Optional

class ChatCompletionStreamResponseDelta:
"""
Mimics the 'delta' object returned by OpenAI streaming chunks.
Example usage in code:
chunk.choices[0].delta.content
"""
def __init__(self, role: Optional[str] = None, content: Optional[str] = None):
self.role = role
self.content = content


class ChatCompletionStreamResponseChoice:
"""
Holds the 'delta' for a single chunk choice.
Example usage in code:
chunk.choices[0].delta
"""
def __init__(self, delta: ChatCompletionStreamResponseDelta):
self.delta = delta


class ChatCompletionStreamResponse:
"""
Container for streaming response chunks.
Each chunk has a 'choices' list, each with a 'delta'.
Example usage in code:
chunk.choices[0].delta.content
"""
def __init__(self, choices: list[ChatCompletionStreamResponseChoice]):
self.choices = choices
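
A minimal sanity-check sketch for these wrappers (hypothetical: the chunk is built by hand here, not returned by any provider):

# Build one streaming chunk and read it back through the same
# access pattern OpenAI-style callers use.
delta = ChatCompletionStreamResponseDelta(role="assistant", content="Hel")
chunk = ChatCompletionStreamResponse(
    choices=[ChatCompletionStreamResponseChoice(delta=delta)]
)
assert chunk.choices[0].delta.content == "Hel"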



124 changes: 90 additions & 34 deletions aisuite/providers/anthropic_provider.py
@@ -1,13 +1,19 @@
# Anthropic provider
# Links:
# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
# aisuite/providers/anthropic_provider.py

import anthropic
import json

from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function

# Import our new streaming response classes:
from aisuite.framework.chat_completion_stream_response import (
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
ChatCompletionStreamResponseDelta,
)

# Define a constant for the default max_tokens value
DEFAULT_MAX_TOKENS = 4096

@@ -33,7 +39,7 @@ def convert_request(self, messages):
return system_message, converted_messages

def convert_response(self, response):
"""Normalize the response from the Anthropic API to match OpenAI's response format."""
"""Normalize a non-streaming response from the Anthropic API to match OpenAI's response format."""
normalized_response = ChatCompletionResponse()
normalized_response.choices[0].finish_reason = self._get_finish_reason(response)
normalized_response.usage = self._get_usage_stats(response)
@@ -57,7 +63,7 @@ def _convert_dict_message(self, msg):
return {"role": msg["role"], "content": msg["content"]}

def _convert_message_object(self, msg):
"""Convert a Message object to Anthropic format."""
"""Convert a `Message` object to Anthropic format."""
if msg.role == self.ROLE_TOOL:
return self._create_tool_result_message(msg.tool_call_id, msg.content)
elif msg.role == self.ROLE_ASSISTANT and msg.tool_calls:
@@ -107,22 +113,23 @@ def _create_assistant_tool_message(self, content, tool_calls):
return {"role": self.ROLE_ASSISTANT, "content": message_content}

def _extract_system_message(self, messages):
"""Extract system message if present, otherwise return empty list."""
# TODO: This is a temporary solution to extract the system message.
# User can pass multiple system messages, which can mingled with other messages.
# This needs to be fixed to handle this case.
"""
Extract system message if present, otherwise return empty string.
If there are multiple system messages, or the system message is not the first,
you may need to adapt this approach.
"""
if messages and messages[0]["role"] == "system":
system_message = messages[0]["content"]
messages.pop(0)
return system_message
return []
return ""

def _get_finish_reason(self, response):
"""Get the normalized finish reason."""
return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop")

def _get_usage_stats(self, response):
"""Get the usage statistics."""
"""Get the usage statistics from Anthropic response."""
return {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
@@ -135,9 +142,8 @@ def _get_message(self, response):
tool_message = self.convert_response_with_tool_use(response)
if tool_message:
return tool_message

return Message(
content=response.content[0].text if response.content else "",
role="assistant",
tool_calls=None,
refusal=None,
@@ -146,26 +152,22 @@ def convert_response_with_tool_use(self, response):
def convert_response_with_tool_use(self, response):
"""Convert Anthropic tool use response to the framework's format."""
tool_call = next(
(c for c in response.content if c.type == "tool_use"),
None,
)

if tool_call:
function = Function(
name=tool_call.name,
arguments=json.dumps(tool_call.input),
)
tool_call_obj = ChatCompletionMessageToolCall(
id=tool_call.id,
function=function,
type="function",
)
text_content = next(
(c.text for c in response.content if c.type == "text"), ""
)

return Message(
content=text_content or None,
tool_calls=[tool_call_obj] if tool_call else None,
@@ -177,11 +179,9 @@ def convert_response_with_tool_use(self, response):
def convert_tool_spec(self, openai_tools):
"""Convert OpenAI tool specification to Anthropic format."""
anthropic_tools = []

for tool in openai_tools:
if tool.get("type") != "function":
continue

function = tool["function"]
anthropic_tool = {
"name": function["name"],
@@ -193,7 +193,6 @@ },
},
}
anthropic_tools.append(anthropic_tool)

return anthropic_tools
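
To illustrate the mapping, a hedged sketch of what convert_tool_spec produces. The weather tool is invented, and the expected output assumes the lines elided above fill Anthropic's `input_schema` from the OpenAI `parameters`:

openai_tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool for illustration
        "description": "Look up current weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
anthropic_tools = AnthropicMessageConverter().convert_tool_spec(openai_tools)
# Assumed result, based on Anthropic's documented tool-use format:
# [{"name": "get_weather",
#   "description": "Look up current weather",
#   "input_schema": {"type": "object",
#                    "properties": {"city": {"type": "string"}},
#                    "required": ["city"]}}]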


@@ -204,21 +203,78 @@ def __init__(self, **config):
self.converter = AnthropicMessageConverter()

def chat_completions_create(self, model, messages, **kwargs):
"""Create a chat completion using the Anthropic API."""
"""
Create a chat completion using the Anthropic API.

If 'stream=True' is passed, return a generator that yields
`ChatCompletionStreamResponse` objects shaped like OpenAI's streaming chunks.
"""
stream = kwargs.pop("stream", False)

if not stream:
# Non-streaming call
kwargs = self._prepare_kwargs(kwargs)
system_message, converted_messages = self.converter.convert_request(messages)
response = self.client.messages.create(
model=model,
system=system_message,
messages=converted_messages,
**kwargs
)
return self.converter.convert_response(response)
else:
# Streaming call
return self._streaming_chat_completions_create(model, messages, **kwargs)

def _streaming_chat_completions_create(self, model, messages, **kwargs):
"""
Generator that yields chunk objects in the shape:
chunk.choices[0].delta.content
"""
kwargs = self._prepare_kwargs(kwargs)
system_message, converted_messages = self.converter.convert_request(messages)

first_chunk = True

with self.client.messages.stream(
model=model,
system=system_message,
messages=converted_messages,
**kwargs
) as stream_resp:

            for partial_text in stream_resp.text_stream:
                # Include `role='assistant'` only on the first chunk,
                # mirroring OpenAI's streaming format.
                delta = ChatCompletionStreamResponseDelta(
                    role="assistant" if first_chunk else None,
                    content=partial_text,
                )
                first_chunk = False
                yield ChatCompletionStreamResponse(
                    choices=[ChatCompletionStreamResponseChoice(delta=delta)]
                )

def _prepare_kwargs(self, kwargs):
"""Prepare kwargs for the API call."""
"""Prepare kwargs for the Anthropic API call."""
kwargs = kwargs.copy()
kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

if "tools" in kwargs:
kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"])

return kwargs
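
A minimal consumption sketch for the streaming path, assuming ANTHROPIC_API_KEY is set in the environment; the model name is an assumption, not something this diff pins down:

provider = AnthropicProvider()  # anthropic.Anthropic() reads ANTHROPIC_API_KEY
for chunk in provider.chat_completions_create(
    model="claude-3-5-sonnet-20241022",  # assumed model name
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
):
    text = chunk.choices[0].delta.content
    if text:
        print(text, end="", flush=True)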




69 changes: 52 additions & 17 deletions aisuite/providers/openai_provider.py
@@ -1,7 +1,6 @@
import openai
import os
from aisuite.provider import Provider, LLMError
from aisuite.providers.message_converter import OpenAICompliantMessageConverter


class OpenaiProvider(Provider):
@@ -14,27 +13,63 @@ def __init__(self, **config):
config.setdefault("api_key", os.getenv("OPENAI_API_KEY"))
if not config["api_key"]:
raise ValueError(
"OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
"OpenAI API key is missing. Please provide it in the config "
"or set the OPENAI_API_KEY environment variable."
)

        # NOTE: The api_key check above could be removed, since the OpenAI
        # client can infer certain values from environment variables,
        # e.g. OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL.

        # Pass the entire config to the OpenAI client constructor.
        # (This assumes the openai>=1.0 client, where `openai.OpenAI(...)` is the
        # constructor; if you use the legacy module-level `openai.api_key = ...`
        # style, adapt as needed.)
self.client = openai.OpenAI(**config)
self.transformer = OpenAICompliantMessageConverter()

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Create a chat completion using the OpenAI API.
        If 'stream=True' is passed via kwargs, return a generator that yields
        chunked responses in the OpenAI streaming format.
        """
        stream = kwargs.pop("stream", False)

        if not stream:
            # Non-streaming call
            return self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs,  # Pass any additional arguments to the OpenAI API
            )
        else:
            # Streaming call: return a generator that yields each chunk
            return self._streaming_chat_completions_create(model, messages, **kwargs)

def _streaming_chat_completions_create(self, model, messages, **kwargs):
"""
Internal helper method that yields chunked responses for streaming.
Each chunk is already in the OpenAI streaming format:

{
"id": ...,
"object": "chat.completion.chunk",
"created": ...,
"model": ...,
"choices": [
{
"delta": {
"role": "assistant" or "content": ...
}
}
]
}
"""
response_gen = self.client.chat.completions.create(
model=model,
messages=messages,
stream=True,
**kwargs
)

# Yield chunks as they arrive
for chunk in response_gen:
yield chunk
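
Since the OpenAI path yields the SDK's chunks unmodified, the consuming loop has the same shape as the Anthropic sketch above (model name assumed, OPENAI_API_KEY expected in the environment):

provider = OpenaiProvider()
for chunk in provider.chat_completions_create(
    model="gpt-4o-mini",  # assumed model name
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
):
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)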


