diff --git a/core/llm/llms/Gemini.ts b/core/llm/llms/Gemini.ts
index 40b0a656ef6..7caf2db7d75 100644
--- a/core/llm/llms/Gemini.ts
+++ b/core/llm/llms/Gemini.ts
@@ -1,4 +1,4 @@
-import { streamSse } from "@continuedev/fetch";
+import { streamResponse } from "@continuedev/fetch";
 import { v4 as uuidv4 } from "uuid";
 import {
   AssistantChatMessage,
@@ -312,57 +312,84 @@ class Gemini extends BaseLLM {
   }

   public async *processGeminiResponse(
-    response: Response,
+    stream: AsyncIterable<string>,
   ): AsyncGenerator<ChatMessage> {
-    for await (const chunk of streamSse(response)) {
-      let data: GeminiChatResponse;
-      try {
-        data = JSON.parse(chunk) as GeminiChatResponse;
-      } catch (e) {
-        continue;
+    let buffer = "";
+    for await (const chunk of stream) {
+      buffer += chunk;
+      if (buffer.startsWith("[")) {
+        buffer = buffer.slice(1);
       }
-
-      if ("error" in data) {
-        throw new Error(data.error.message);
+      if (buffer.endsWith("]")) {
+        buffer = buffer.slice(0, -1);
+      }
+      if (buffer.startsWith(",")) {
+        buffer = buffer.slice(1);
       }

-      const contentParts = data?.candidates?.[0]?.content?.parts;
-      if (contentParts) {
-        const textParts: MessagePart[] = [];
-        const toolCalls: ToolCallDelta[] = [];
-
-        for (const part of contentParts) {
-          if ("text" in part) {
-            textParts.push({ type: "text", text: part.text });
-          } else if ("functionCall" in part) {
-            toolCalls.push({
-              type: "function",
-              id: part.functionCall.id ?? uuidv4(),
-              function: {
-                name: part.functionCall.name,
-                arguments:
-                  typeof part.functionCall.args === "string"
-                    ? part.functionCall.args
-                    : JSON.stringify(part.functionCall.args),
-              },
-            });
-          } else {
-            console.warn("Unsupported gemini part type received", part);
-          }
+      const parts = buffer.split("\n,");
+
+      let foundIncomplete = false;
+      for (let i = 0; i < parts.length; i++) {
+        const part = parts[i];
+        let data: GeminiChatResponse;
+        try {
+          data = JSON.parse(part) as GeminiChatResponse;
+        } catch (e) {
+          foundIncomplete = true;
+          continue; // yo!
         }

-        const assistantMessage: AssistantChatMessage = {
-          role: "assistant",
-          content: textParts.length ? textParts : "",
-        };
-        if (toolCalls.length > 0) {
-          assistantMessage.toolCalls = toolCalls;
+        if ("error" in data) {
+          throw new Error(data.error.message);
         }
-        if (textParts.length || toolCalls.length) {
-          yield assistantMessage;
+
+        // In case of max tokens reached, gemini will sometimes return content with no parts, even though that doesn't match the API spec
+        const contentParts = data?.candidates?.[0]?.content?.parts;
+        if (contentParts) {
+          const textParts: MessagePart[] = [];
+          const toolCalls: ToolCallDelta[] = [];
+
+          for (const part of contentParts) {
+            if ("text" in part) {
+              textParts.push({ type: "text", text: part.text });
+            } else if ("functionCall" in part) {
+              toolCalls.push({
+                type: "function",
+                id: part.functionCall.id ?? uuidv4(),
+                function: {
+                  name: part.functionCall.name,
+                  arguments:
+                    typeof part.functionCall.args === "string"
+                      ? part.functionCall.args
+                      : JSON.stringify(part.functionCall.args),
+                },
+              });
+            } else {
+              // Note: function responses shouldn't be streamed, images not supported
+              console.warn("Unsupported gemini part type received", part);
+            }
+          }
+
+          const assistantMessage: AssistantChatMessage = {
+            role: "assistant",
+            content: textParts.length ? textParts : "",
+          };
+          if (toolCalls.length > 0) {
+            assistantMessage.toolCalls = toolCalls;
+          }
+          if (textParts.length || toolCalls.length) {
+            yield assistantMessage;
+          }
+        } else {
+          // Handle the case where the expected data structure is not found
+          console.warn("Unexpected response format:", data);
         }
+      }
+      if (foundIncomplete) {
+        buffer = parts[parts.length - 1];
       } else {
-        console.warn("Unexpected response format:", data);
+        buffer = "";
       }
     }
   }
@@ -387,9 +414,10 @@ class Gemini extends BaseLLM {
       body: JSON.stringify(body),
       signal,
     });
-
-    for await (const chunk of this.processGeminiResponse(response)) {
-      yield chunk;
+    for await (const message of this.processGeminiResponse(
+      streamResponse(response),
+    )) {
+      yield message;
     }
   }
   private async *streamChatBison(
diff --git a/core/llm/llms/VertexAI.ts b/core/llm/llms/VertexAI.ts
index 97600e054f8..258bccc3d2f 100644
--- a/core/llm/llms/VertexAI.ts
+++ b/core/llm/llms/VertexAI.ts
@@ -1,6 +1,6 @@
 import { AuthClient, GoogleAuth, JWT, auth } from "google-auth-library";

-import { streamSse } from "@continuedev/fetch";
+import { streamResponse, streamSse } from "@continuedev/fetch";
 import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
 import { renderChatMessage, stripImages } from "../../util/messageContent.js";
 import { BaseLLM } from "../index.js";
@@ -287,7 +287,7 @@ class VertexAI extends BaseLLM {
       body: JSON.stringify(body),
       signal,
     });
-    yield* this.geminiInstance.processGeminiResponse(response);
+    yield* this.geminiInstance.processGeminiResponse(streamResponse(response));
   }

   private async *streamChatBison(
diff --git a/packages/openai-adapters/src/apis/Gemini.ts b/packages/openai-adapters/src/apis/Gemini.ts
index 64d926f1b8c..b0f32ad46e4 100644
--- a/packages/openai-adapters/src/apis/Gemini.ts
+++ b/packages/openai-adapters/src/apis/Gemini.ts
@@ -1,4 +1,4 @@
-import { streamSse } from "@continuedev/fetch";
+import { streamResponse } from "@continuedev/fetch";
 import { OpenAI } from "openai/index";
 import {
   ChatCompletion,
@@ -284,58 +284,85 @@ export class GeminiApi implements BaseLlmApi {
   }

   async *handleStreamResponse(response: any, model: string) {
+    let buffer = "";
     let usage: UsageInfo | undefined = undefined;
-    for await (const chunk of streamSse(response as any)) {
-      let data;
-      try {
-        data = JSON.parse(chunk);
-      } catch (e) {
-        continue;
+    for await (const chunk of streamResponse(response as any)) {
+      buffer += chunk;
+      if (buffer.startsWith("[")) {
+        buffer = buffer.slice(1);
       }
-      if (data.error) {
-        throw new Error(data.error.message);
+      if (buffer.endsWith("]")) {
+        buffer = buffer.slice(0, -1);
       }
-
-      if (data.usageMetadata) {
-        usage = {
-          prompt_tokens: data.usageMetadata.promptTokenCount || 0,
-          completion_tokens: data.usageMetadata.candidatesTokenCount || 0,
-          total_tokens: data.usageMetadata.totalTokenCount || 0,
-        };
+      if (buffer.startsWith(",")) {
+        buffer = buffer.slice(1);
       }

-      const contentParts = data?.candidates?.[0]?.content?.parts;
-      if (contentParts) {
-        for (const part of contentParts) {
-          if ("text" in part) {
-            yield chatChunk({
-              content: part.text,
-              model,
-            });
-          } else if ("functionCall" in part) {
-            yield chatChunkFromDelta({
-              model,
-              delta: {
-                tool_calls: [
-                  {
-                    index: 0,
-                    id: part.functionCall.id ?? uuidv4(),
-                    type: "function",
-                    function: {
-                      name: part.functionCall.name,
-                      arguments: JSON.stringify(part.functionCall.args),
+      const parts = buffer.split("\n,");
+
+      let foundIncomplete = false;
+      for (let i = 0; i < parts.length; i++) {
+        const part = parts[i];
+        let data;
+        try {
+          data = JSON.parse(part);
+        } catch (e) {
+          foundIncomplete = true;
+          continue; // yo!
+        }
+        if (data.error) {
+          throw new Error(data.error.message);
+        }
+
+        // Check for usage metadata
+        if (data.usageMetadata) {
+          usage = {
+            prompt_tokens: data.usageMetadata.promptTokenCount || 0,
+            completion_tokens: data.usageMetadata.candidatesTokenCount || 0,
+            total_tokens: data.usageMetadata.totalTokenCount || 0,
+          };
+        }
+
+        // In case of max tokens reached, gemini will sometimes return content with no parts, even though that doesn't match the API spec
+        const contentParts = data?.candidates?.[0]?.content?.parts;
+        if (contentParts) {
+          for (const part of contentParts) {
+            if ("text" in part) {
+              yield chatChunk({
+                content: part.text,
+                model,
+              });
+            } else if ("functionCall" in part) {
+              yield chatChunkFromDelta({
+                model,
+                delta: {
+                  tool_calls: [
+                    {
+                      index: 0,
+                      id: part.functionCall.id ?? uuidv4(),
+                      type: "function",
+                      function: {
+                        name: part.functionCall.name,
+                        arguments: JSON.stringify(part.functionCall.args),
+                      },
                     },
-                  },
-                ],
-              },
-            });
+                  ],
+                },
+              });
+            }
           }
+        } else {
+          console.warn("Unexpected response format:", data);
         }
+      }
+      if (foundIncomplete) {
+        buffer = parts[parts.length - 1];
       } else {
-        console.warn("Unexpected response format:", data);
+        buffer = "";
       }
     }
+    // Emit usage at the end if we have it
     if (usage) {
       yield usageChatChunk({
         model,