
Commit a0ce99c

fix: address review comments
This commit addresses comments regarding the OpenAI chat completion implementation in the meta_reference provider.

**Tool Augmentation**

- Add `augment_raw_messages_for_tools()` to properly inject tool definitions into prompts
- Support model-family-specific tool formats:
  * Llama 3.1/3.2 multimodal: JsonCustomToolGenerator with JSON format
  * Llama 3.2/3.3/4: PythonListCustomToolGenerator with Python list format
- Handle tool_choice hints (auto/required/specific tool)
- Preserve existing system messages while adding tool context

**Streaming & Tool Call Detection**

- Implement streaming support via `params.stream` with `_stream_chat_completion()`
- Add tool call detection by decoding assistant messages after generation
- Set proper `finish_reason` based on content ("stop" vs "tool_calls")
- Convert internal ToolCall format to OpenAI-compatible types
- Stream chunks incrementally with proper delta formatting

**Type Corrections**

- Fix response_format handling in generators.py to properly extract the schema from the OpenAIJSONSchema TypedDict and use the correct ResponseFormatType enum
- Use correct OpenAI types: OpenAIChatCompletionToolCall, OpenAIChunkChoice, OpenAIChoiceDelta, OpenAIChatCompletionToolCallFunction

Signed-off-by: Charlie Doern <[email protected]>
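To make the tool-augmentation bullets concrete, here is a minimal Python sketch of how model-family-specific tool formats might be injected while preserving an existing system message. Only `augment_raw_messages_for_tools()` and the two generator names come from the commit message; the `ToolDefinition` dataclass, the `render_tool_instructions` helper, the family identifiers, and all signatures are hypothetical stand-ins, not the provider's actual API.

```python
import json
from dataclasses import dataclass, field


@dataclass
class ToolDefinition:
    """Hypothetical stand-in for the provider's internal tool definition type."""
    name: str
    description: str
    parameters: dict = field(default_factory=dict)


def render_tool_instructions(model_family: str, tools: list[ToolDefinition]) -> str:
    """Render tool definitions in the format the given Llama family expects."""
    if model_family in ("llama3_1", "llama3_2_multimodal"):
        # JSON format, analogous to JsonCustomToolGenerator
        blobs = [
            {"name": t.name, "description": t.description, "parameters": t.parameters}
            for t in tools
        ]
        return "You have access to the following tools:\n" + json.dumps(blobs, indent=2)
    # Python-list format, analogous to PythonListCustomToolGenerator (3.2/3.3/4)
    entries = [f"{t.name}: {t.description}" for t in tools]
    return "Available functions:\n[" + ", ".join(repr(e) for e in entries) + "]"


def augment_raw_messages_for_tools(
    messages: list[dict], tools: list[ToolDefinition], model_family: str
) -> list[dict]:
    """Inject tool context, preserving any existing system message."""
    tool_text = render_tool_instructions(model_family, tools)
    out = [dict(m) for m in messages]
    if out and out[0].get("role") == "system":
        # Append tool instructions to the existing system prompt rather than replacing it
        out[0]["content"] = f"{out[0]['content']}\n\n{tool_text}"
    else:
        out.insert(0, {"role": "system", "content": tool_text})
    return out
```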
1 parent dac1ff1 commit a0ce99c

File tree

2 files changed (+345, -9 lines)

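Before the diff itself, a minimal sketch of the streaming tool-call conversion the commit message describes: after generation, the decoded assistant message is inspected, internal tool calls are converted to OpenAI-compatible types, and the final chunk's `finish_reason` is set to "tool_calls" or "stop" accordingly. The `OpenAI*` classes below are local stand-ins that only mirror the type names from the commit; the real field layouts in llama_stack, and the internal ToolCall attribute names (`call_id`, `tool_name`, `arguments_json`), are assumptions.

```python
from dataclasses import dataclass


@dataclass
class OpenAIChatCompletionToolCallFunction:
    name: str
    arguments: str  # JSON-encoded argument string, per the OpenAI wire format


@dataclass
class OpenAIChatCompletionToolCall:
    id: str
    function: OpenAIChatCompletionToolCallFunction
    type: str = "function"


@dataclass
class OpenAIChoiceDelta:
    content: str | None = None
    tool_calls: list[OpenAIChatCompletionToolCall] | None = None


@dataclass
class OpenAIChunkChoice:
    delta: OpenAIChoiceDelta
    index: int = 0
    finish_reason: str | None = None


def final_chunk_choice(decoded_message) -> OpenAIChunkChoice:
    """Build the last streamed choice: tool calls present -> "tool_calls", else "stop"."""
    tool_calls = getattr(decoded_message, "tool_calls", None)
    if tool_calls:
        converted = [
            OpenAIChatCompletionToolCall(
                id=tc.call_id,
                function=OpenAIChatCompletionToolCallFunction(
                    name=tc.tool_name, arguments=tc.arguments_json
                ),
            )
            for tc in tool_calls
        ]
        return OpenAIChunkChoice(
            delta=OpenAIChoiceDelta(tool_calls=converted),
            finish_reason="tool_calls",
        )
    return OpenAIChunkChoice(
        delta=OpenAIChoiceDelta(content=decoded_message.content),
        finish_reason="stop",
    )
```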

src/llama_stack/providers/inline/inference/meta_reference/generators.py

Lines changed: 9 additions & 3 deletions
@@ -14,7 +14,9 @@
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIResponseFormatJSONSchema,
     ResponseFormat,
+    ResponseFormatType,
     SamplingParams,
     TopPSamplingStrategy,
 )
@@ -163,7 +165,8 @@ def chat_completion(
         sampling_params = SamplingParams()
         if request.temperature is not None or request.top_p is not None:
             sampling_params.strategy = TopPSamplingStrategy(
-                temperature=request.temperature or 1.0, top_p=request.top_p or 1.0
+                temperature=request.temperature if request.temperature is not None else 1.0,
+                top_p=request.top_p if request.top_p is not None else 1.0,
             )
         if request.max_tokens:
             sampling_params.max_tokens = request.max_tokens
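The switch from `or` to an explicit `is not None` check matters when a caller passes 0: since `0.0` is falsy, `request.temperature or 1.0` silently replaces an intentional temperature of zero with the default. A quick illustration:

```python
temperature = 0.0  # caller explicitly asks for temperature 0

# Old behavior: 0.0 is falsy, so the default wins
print(temperature or 1.0)  # 1.0  (caller's intent lost)

# Fixed behavior: only a missing value (None) falls back to the default
print(temperature if temperature is not None else 1.0)  # 0.0  (intent kept)
```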
@@ -177,9 +180,12 @@ def chat_completion(
         # Get logits processor for response format
         logits_processor = None
         if request.response_format:
-            if isinstance(request.response_format, dict) and request.response_format.get("type") == "json_schema":
+            if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
+                # Extract the actual schema from OpenAIJSONSchema TypedDict
+                schema_dict = request.response_format.json_schema.get("schema") or {}
                 json_schema_format = JsonSchemaResponseFormat(
-                    type="json_schema", json_schema=request.response_format.get("json_schema", {})
+                    type=ResponseFormatType.json_schema,
+                    json_schema=schema_dict,
                 )
                 logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
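The nested `.get("schema")` is needed because the OpenAI-style `json_schema` response format wraps the actual JSON Schema in an envelope that also carries metadata such as a name, so the old code handed the whole envelope, not a schema, to the logits processor. A hedged sketch of the shape being unwrapped (the weather example is invented; the envelope layout follows the public OpenAI structured-outputs format):

```python
# OpenAI-style json_schema response_format: the JSON Schema itself lives under
# the nested "schema" key, next to metadata like "name".
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "weather_report",   # label for the schema, not part of it
        "schema": {                 # <- what JsonSchemaResponseFormat needs
            "type": "object",
            "properties": {"temp_c": {"type": "number"}},
            "required": ["temp_c"],
        },
    },
}

# Mirrors the fixed code path: unwrap the envelope before building the
# internal JsonSchemaResponseFormat.
schema_dict = response_format["json_schema"].get("schema") or {}
assert "properties" in schema_dict
```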
