Merged
120 changes: 74 additions & 46 deletions core/llm/llms/Gemini.ts
@@ -1,4 +1,4 @@
-import { streamSse } from "@continuedev/fetch";
+import { streamResponse } from "@continuedev/fetch";
 import { v4 as uuidv4 } from "uuid";
 import {
   AssistantChatMessage,
@@ -312,57 +312,84 @@ class Gemini extends BaseLLM {
   }
 
   public async *processGeminiResponse(
-    response: Response,
+    stream: AsyncIterable<string>,
   ): AsyncGenerator<ChatMessage> {
-    for await (const chunk of streamSse(response)) {
-      let data: GeminiChatResponse;
-      try {
-        data = JSON.parse(chunk) as GeminiChatResponse;
-      } catch (e) {
-        continue;
-      }
-
-      if ("error" in data) {
-        throw new Error(data.error.message);
-      }
-
-      const contentParts = data?.candidates?.[0]?.content?.parts;
-      if (contentParts) {
-        const textParts: MessagePart[] = [];
-        const toolCalls: ToolCallDelta[] = [];
-
-        for (const part of contentParts) {
-          if ("text" in part) {
-            textParts.push({ type: "text", text: part.text });
-          } else if ("functionCall" in part) {
-            toolCalls.push({
-              type: "function",
-              id: part.functionCall.id ?? uuidv4(),
-              function: {
-                name: part.functionCall.name,
-                arguments:
-                  typeof part.functionCall.args === "string"
-                    ? part.functionCall.args
-                    : JSON.stringify(part.functionCall.args),
-              },
-            });
-          } else {
-            console.warn("Unsupported gemini part type received", part);
-          }
-        }
-
-        const assistantMessage: AssistantChatMessage = {
-          role: "assistant",
-          content: textParts.length ? textParts : "",
-        };
-        if (toolCalls.length > 0) {
-          assistantMessage.toolCalls = toolCalls;
-        }
-        if (textParts.length || toolCalls.length) {
-          yield assistantMessage;
-        }
-      } else {
-        console.warn("Unexpected response format:", data);
-      }
+    let buffer = "";
+    for await (const chunk of stream) {
+      buffer += chunk;
+      if (buffer.startsWith("[")) {
+        buffer = buffer.slice(1);
+      }
+      if (buffer.endsWith("]")) {
+        buffer = buffer.slice(0, -1);
+      }
+      if (buffer.startsWith(",")) {
+        buffer = buffer.slice(1);
+      }
+
+      const parts = buffer.split("\n,");
+
+      let foundIncomplete = false;
+      for (let i = 0; i < parts.length; i++) {
+        const part = parts[i];
+        let data: GeminiChatResponse;
+        try {
+          data = JSON.parse(part) as GeminiChatResponse;
+        } catch (e) {
+          foundIncomplete = true;
+          continue; // incomplete JSON element; wait for more data in the buffer
+        }
+
+        if ("error" in data) {
+          throw new Error(data.error.message);
+        }
+
+        // In case of max tokens reached, gemini will sometimes return content with no parts, even though that doesn't match the API spec
+        const contentParts = data?.candidates?.[0]?.content?.parts;
+        if (contentParts) {
+          const textParts: MessagePart[] = [];
+          const toolCalls: ToolCallDelta[] = [];
+
+          for (const part of contentParts) {
+            if ("text" in part) {
+              textParts.push({ type: "text", text: part.text });
+            } else if ("functionCall" in part) {
+              toolCalls.push({
+                type: "function",
+                id: part.functionCall.id ?? uuidv4(),
+                function: {
+                  name: part.functionCall.name,
+                  arguments:
+                    typeof part.functionCall.args === "string"
+                      ? part.functionCall.args
+                      : JSON.stringify(part.functionCall.args),
+                },
+              });
+            } else {
+              // Note: function responses shouldn't be streamed, images not supported
+              console.warn("Unsupported gemini part type received", part);
+            }
+          }
+
+          const assistantMessage: AssistantChatMessage = {
+            role: "assistant",
+            content: textParts.length ? textParts : "",
+          };
+          if (toolCalls.length > 0) {
+            assistantMessage.toolCalls = toolCalls;
+          }
+          if (textParts.length || toolCalls.length) {
+            yield assistantMessage;
+          }
+        } else {
+          // Handle the case where the expected data structure is not found
+          console.warn("Unexpected response format:", data);
+        }
+      }
+      if (foundIncomplete) {
+        buffer = parts[parts.length - 1];
+      } else {
+        buffer = "";
+      }
     }
   }
@@ -387,9 +414,10 @@ class Gemini extends BaseLLM {
       body: JSON.stringify(body),
       signal,
     });
-
-    for await (const chunk of this.processGeminiResponse(response)) {
-      yield chunk;
+    for await (const message of this.processGeminiResponse(
+      streamResponse(response),
+    )) {
+      yield message;
     }
   }
   private async *streamChatBison(
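
The parsing strategy in processGeminiResponse is easiest to see in isolation. Below is a stand-alone sketch (not part of the PR; parseJsonArrayStream and mockStream are hypothetical names) of the same idea: Gemini's REST streaming endpoint returns a single JSON array whose elements arrive as text chunks, so the code strips the array framing, splits on "\n,", and carries any element that was cut off mid-object over to the next chunk. Splitting on "\n," relies on Gemini placing each array element on its own line; a boundary that lands elsewhere is caught by the JSON.parse failure path.

async function* parseJsonArrayStream(
  stream: AsyncIterable<string>,
): AsyncGenerator<unknown> {
  let buffer = "";
  for await (const chunk of stream) {
    buffer += chunk;
    // Trim the enclosing array syntax: `[{...}\n,{...}\n,{...}]`.
    if (buffer.startsWith("[")) buffer = buffer.slice(1);
    if (buffer.endsWith("]")) buffer = buffer.slice(0, -1);
    if (buffer.startsWith(",")) buffer = buffer.slice(1);

    const parts = buffer.split("\n,");
    buffer = "";
    for (let i = 0; i < parts.length; i++) {
      try {
        yield JSON.parse(parts[i]);
      } catch {
        // An element split across chunks fails to parse; keep the tail
        // in the buffer so the next chunk can complete it.
        if (i === parts.length - 1) {
          buffer = parts[i];
        }
      }
    }
  }
}

// Example: one JSON object split across two network chunks.
async function* mockStream(): AsyncGenerator<string> {
  yield '[{"candidates": [{"content": {"parts": [{"te';
  yield 'xt": "Hello"}]}}]}\n,{"usageMetadata": {"totalTokenCount": 5}}]';
}

(async () => {
  for await (const obj of parseJsonArrayStream(mockStream())) {
    console.log(obj); // two complete objects, despite the split
  }
})();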
4 changes: 2 additions & 2 deletions core/llm/llms/VertexAI.ts
@@ -1,6 +1,6 @@
 import { AuthClient, GoogleAuth, JWT, auth } from "google-auth-library";
 
-import { streamSse } from "@continuedev/fetch";
+import { streamResponse, streamSse } from "@continuedev/fetch";
 import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
 import { renderChatMessage, stripImages } from "../../util/messageContent.js";
 import { BaseLLM } from "../index.js";
@@ -287,7 +287,7 @@ class VertexAI extends BaseLLM {
       body: JSON.stringify(body),
       signal,
     });
-    yield* this.geminiInstance.processGeminiResponse(response);
+    yield* this.geminiInstance.processGeminiResponse(streamResponse(response));
   }
 
   private async *streamChatBison(
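
For context, streamResponse is assumed here to decode the raw response body into plain text chunks, with no SSE framing (unlike streamSse, whose event parsing the new buffer logic in processGeminiResponse replaces). A minimal sketch of such a helper under that assumption, using the standard Web Streams reader API; the actual @continuedev/fetch implementation may differ:

async function* streamResponseSketch(
  response: Response,
): AsyncGenerator<string> {
  if (!response.body) {
    throw new Error("No response body to stream");
  }
  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        break;
      }
      // Decode bytes incrementally; `stream: true` keeps multi-byte
      // characters that straddle chunk boundaries intact.
      yield decoder.decode(value, { stream: true });
    }
  } finally {
    reader.releaseLock();
  }
}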
109 changes: 68 additions & 41 deletions packages/openai-adapters/src/apis/Gemini.ts
@@ -1,4 +1,4 @@
-import { streamSse } from "@continuedev/fetch";
+import { streamResponse } from "@continuedev/fetch";
 import { OpenAI } from "openai/index";
 import {
   ChatCompletion,
@@ -284,58 +284,85 @@ export class GeminiApi implements BaseLlmApi {
   }
 
   async *handleStreamResponse(response: any, model: string) {
+    let buffer = "";
     let usage: UsageInfo | undefined = undefined;
-    for await (const chunk of streamSse(response as any)) {
-      let data;
-      try {
-        data = JSON.parse(chunk);
-      } catch (e) {
-        continue;
-      }
-      if (data.error) {
-        throw new Error(data.error.message);
-      }
-
-      if (data.usageMetadata) {
-        usage = {
-          prompt_tokens: data.usageMetadata.promptTokenCount || 0,
-          completion_tokens: data.usageMetadata.candidatesTokenCount || 0,
-          total_tokens: data.usageMetadata.totalTokenCount || 0,
-        };
-      }
-      const contentParts = data?.candidates?.[0]?.content?.parts;
-      if (contentParts) {
-        for (const part of contentParts) {
-          if ("text" in part) {
-            yield chatChunk({
-              content: part.text,
-              model,
-            });
-          } else if ("functionCall" in part) {
-            yield chatChunkFromDelta({
-              model,
-              delta: {
-                tool_calls: [
-                  {
-                    index: 0,
-                    id: part.functionCall.id ?? uuidv4(),
-                    type: "function",
-                    function: {
-                      name: part.functionCall.name,
-                      arguments: JSON.stringify(part.functionCall.args),
-                    },
-                  },
-                ],
-              },
-            });
-          }
-        }
-      } else {
-        console.warn("Unexpected response format:", data);
-      }
+    for await (const chunk of streamResponse(response as any)) {
+      buffer += chunk;
+      if (buffer.startsWith("[")) {
+        buffer = buffer.slice(1);
+      }
+      if (buffer.endsWith("]")) {
+        buffer = buffer.slice(0, -1);
+      }
+      if (buffer.startsWith(",")) {
+        buffer = buffer.slice(1);
+      }
+
+      const parts = buffer.split("\n,");
+
+      let foundIncomplete = false;
+      for (let i = 0; i < parts.length; i++) {
+        const part = parts[i];
+        let data;
+        try {
+          data = JSON.parse(part);
+        } catch (e) {
+          foundIncomplete = true;
+          continue; // incomplete JSON element; wait for more data in the buffer
+        }
+        if (data.error) {
+          throw new Error(data.error.message);
+        }
+
+        // Check for usage metadata
+        if (data.usageMetadata) {
+          usage = {
+            prompt_tokens: data.usageMetadata.promptTokenCount || 0,
+            completion_tokens: data.usageMetadata.candidatesTokenCount || 0,
+            total_tokens: data.usageMetadata.totalTokenCount || 0,
+          };
+        }
+
+        // In case of max tokens reached, gemini will sometimes return content with no parts, even though that doesn't match the API spec
+        const contentParts = data?.candidates?.[0]?.content?.parts;
+        if (contentParts) {
+          for (const part of contentParts) {
+            if ("text" in part) {
+              yield chatChunk({
+                content: part.text,
+                model,
+              });
+            } else if ("functionCall" in part) {
+              yield chatChunkFromDelta({
+                model,
+                delta: {
+                  tool_calls: [
+                    {
+                      index: 0,
+                      id: part.functionCall.id ?? uuidv4(),
+                      type: "function",
+                      function: {
+                        name: part.functionCall.name,
+                        arguments: JSON.stringify(part.functionCall.args),
+                      },
+                    },
+                  ],
+                },
+              });
+            }
+          }
+        } else {
+          console.warn("Unexpected response format:", data);
+        }
+      }
+      if (foundIncomplete) {
+        buffer = parts[parts.length - 1];
+      } else {
+        buffer = "";
+      }
     }
 
     // Emit usage at the end if we have it
     if (usage) {
       yield usageChatChunk({
         model,
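
A hypothetical consumer sketch (not from the PR; collectCompletion is an invented name, and chunk typing is loosened to `any` to keep it self-contained) showing how a caller might drain the OpenAI-style chunk stream that handleStreamResponse produces, gathering text deltas and the trailing usage chunk, assuming the usage chunk carries an OpenAI-style `usage` field:

async function collectCompletion(
  chunks: AsyncIterable<any>,
): Promise<{ text: string; usage?: any }> {
  let text = "";
  let usage: any;
  for await (const chunk of chunks) {
    // Text deltas arrive as OpenAI chat.completion.chunk objects.
    text += chunk.choices?.[0]?.delta?.content ?? "";
    // The adapter emits usage once, after the content chunks.
    if (chunk.usage) {
      usage = chunk.usage;
    }
  }
  return { text, usage };
}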