Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions packages/workers-ai-provider/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "workers-ai-provider",
"description": "Workers AI Provider for the vercel AI SDK",
"type": "module",
"version": "0.5.2",
"version": "0.5.3-v5-beta",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"repository": {
Expand All @@ -21,10 +21,25 @@
"test:ci": "vitest --watch=false",
"test": "vitest"
},
"files": ["dist", "src", "README.md", "package.json"],
"keywords": ["workers", "cloudflare", "ai", "vercel", "sdk", "provider", "chat", "serverless"],
"files": [
"dist",
"src",
"README.md",
"package.json"
],
"keywords": [
"workers",
"cloudflare",
"ai",
"vercel",
"sdk",
"provider",
"chat",
"serverless"
],
"dependencies": {
"@ai-sdk/provider": "^1.1.3"
"@ai-sdk/provider": "2.0.0-beta.1",
"@ai-sdk/provider-utils": "2.2.8"
},
"devDependencies": {
"@cloudflare/workers-types": "^4.20250525.0"
Expand Down
113 changes: 38 additions & 75 deletions packages/workers-ai-provider/src/autorag-chat-language-model.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import {
type LanguageModelV1,
type LanguageModelV1CallWarning,
UnsupportedFunctionalityError,
type LanguageModelV2,
type LanguageModelV2CallWarning,
} from "@ai-sdk/provider";

import type { AutoRAGChatSettings } from "./autorag-chat-settings";
import { convertToWorkersAIChatMessages } from "./convert-to-workersai-chat-messages";
import { mapWorkersAIUsage } from "./map-workersai-usage";
import { getMappedStream } from "./streaming";
import { prepareToolsAndToolChoice, processToolCalls } from "./utils";
import type { TextGenerationModels } from "./workersai-models";

type AutoRAGChatConfig = {
Expand All @@ -17,13 +15,15 @@ type AutoRAGChatConfig = {
gateway?: GatewayOptions;
};

export class AutoRAGChatLanguageModel implements LanguageModelV1 {
readonly specificationVersion = "v1";
export class AutoRAGChatLanguageModel implements LanguageModelV2 {
readonly specificationVersion = "v2";
readonly defaultObjectGenerationMode = "json";

readonly modelId: TextGenerationModels;
readonly settings: AutoRAGChatSettings;

readonly supportedUrls = {}

private readonly config: AutoRAGChatConfig;

constructor(
Expand All @@ -41,14 +41,13 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {
}

private getArgs({
mode,
prompt,
frequencyPenalty,
presencePenalty,
}: Parameters<LanguageModelV1["doGenerate"]>[0]) {
const type = mode.type;

const warnings: LanguageModelV1CallWarning[] = [];
tools,
toolChoice,
}: Parameters<LanguageModelV2["doGenerate"]>[0]) {
const warnings: LanguageModelV2CallWarning[] = [];

if (frequencyPenalty != null) {
warnings.push({
Expand All @@ -72,87 +71,51 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {
messages: convertToWorkersAIChatMessages(prompt),
};

switch (type) {
case "regular": {
return {
args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) },
warnings,
};
}

case "object-json": {
return {
args: {
...baseArgs,
response_format: {
type: "json_schema",
json_schema: mode.schema,
},
tools: undefined,
},
warnings,
};
}

case "object-tool": {
return {
args: {
...baseArgs,
tool_choice: "any",
tools: [{ type: "function", function: mode.tool }],
},
warnings,
};
}

// @ts-expect-error - this is unreachable code
// TODO: fixme
case "object-grammar": {
throw new UnsupportedFunctionalityError({
functionality: "object-grammar mode",
});
}

default: {
const exhaustiveCheck = type satisfies never;
throw new Error(`Unsupported type: ${exhaustiveCheck}`);
}
return {
args: {
...baseArgs,
tool_choice: toolChoice,
tools
},
warnings,
}

}

async doGenerate(
options: Parameters<LanguageModelV1["doGenerate"]>[0],
): Promise<Awaited<ReturnType<LanguageModelV1["doGenerate"]>>> {
const { args, warnings } = this.getArgs(options);
options: Parameters<LanguageModelV2["doGenerate"]>[0],
): Promise<Awaited<ReturnType<LanguageModelV2["doGenerate"]>>> {
const { warnings } = this.getArgs(options);

const { messages } = convertToWorkersAIChatMessages(options.prompt);

const output = await this.config.binding.aiSearch({
query: messages.map(({ content, role }) => `${role}: ${content}`).join("\n\n"),
});

//@ts-ignore
return {
text: output.response,
toolCalls: processToolCalls(output),
// content: output.response,
// toolCalls: processToolCalls(output),
finishReason: "stop", // TODO: mapWorkersAIFinishReason(response.finish_reason),
rawCall: { rawPrompt: args.messages, rawSettings: args },
// rawCall: { rawPrompt: args.messages, rawSettings: args },
usage: mapWorkersAIUsage(output),
warnings,
sources: output.data.map(({ file_id, filename, score }) => ({
id: file_id,
sourceType: "url",
url: filename,
providerMetadata: {
attributes: { score },
},
})),
// sources: output.data.map(({ file_id, filename, score }) => ({
// id: file_id,
// sourceType: "url",
// url: filename,
// providerMetadata: {
// attributes: { score },
// },
// })),
};
}

async doStream(
options: Parameters<LanguageModelV1["doStream"]>[0],
): Promise<Awaited<ReturnType<LanguageModelV1["doStream"]>>> {
const { args, warnings } = this.getArgs(options);
options: Parameters<LanguageModelV2["doStream"]>[0],
): Promise<Awaited<ReturnType<LanguageModelV2["doStream"]>>> {
// const { args, warnings } = this.getArgs(options);

const { messages } = convertToWorkersAIChatMessages(options.prompt);

Expand All @@ -165,8 +128,8 @@ export class AutoRAGChatLanguageModel implements LanguageModelV1 {

return {
stream: getMappedStream(response),
rawCall: { rawPrompt: args.messages, rawSettings: args },
warnings,
// rawCall: { rawPrompt: args.messages, rawSettings: args },
// warnings,
};
}
}
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
import type { LanguageModelV1Prompt, LanguageModelV1ProviderMetadata } from "@ai-sdk/provider";
import type { LanguageModelV2Prompt, LanguageModelV2ProviderMetadata } from "@ai-sdk/provider";
import type { WorkersAIChatPrompt } from "./workersai-chat-prompt";

export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): {
export function convertToWorkersAIChatMessages(prompt: LanguageModelV2Prompt): {
messages: WorkersAIChatPrompt;
images: {
mimeType: string | undefined;
image: Uint8Array;
providerMetadata: LanguageModelV1ProviderMetadata | undefined;
providerMetadata: LanguageModelV2ProviderMetadata | undefined;
}[];
} {
const messages: WorkersAIChatPrompt = [];
const images: {
mimeType: string | undefined;
image: Uint8Array;
providerMetadata: LanguageModelV1ProviderMetadata | undefined;
providerMetadata: LanguageModelV2ProviderMetadata | undefined;
}[] = [];

for (const { role, content } of prompt) {
switch (role) {
case "system": {
messages.push({ role: "system", content });
messages.push({
role: "system",
content: content
.map((part) => {
if (part.type === "text") {
return part.text;
}
return "";
})
.join("\n")
});
break;
}

Expand Down Expand Up @@ -95,10 +105,10 @@ export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): {
tool_calls:
toolCalls.length > 0
? toolCalls.map(({ function: { name, arguments: args } }) => ({
id: "null",
type: "function",
function: { name, arguments: args },
}))
id: "null",
type: "function",
function: { name, arguments: args },
}))
: undefined,
});

Expand Down
67 changes: 44 additions & 23 deletions packages/workers-ai-provider/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import {
type ProviderV2,
} from '@ai-sdk/provider';
import { AutoRAGChatLanguageModel } from "./autorag-chat-language-model";
import type { AutoRAGChatSettings } from "./autorag-chat-settings";
import { createRun } from "./utils";
Expand All @@ -15,42 +18,47 @@ import type {
TextGenerationModels,
} from "./workersai-models";


export type WorkersAISettings = (
| {
/**
* Provide a Cloudflare AI binding.
*/
binding: Ai;

/**
* Credentials must be absent when a binding is given.
*/
accountId?: never;
apiKey?: never;
}
/**
* Provide a Cloudflare AI binding.
*/
binding: Ai;

/**
* Credentials must be absent when a binding is given.
*/
accountId?: never;
apiKey?: never;
}
| {
/**
* Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
*/
accountId: string;
apiKey: string;
/**
* Both binding must be absent if credentials are used directly.
*/
binding?: never;
}
/**
* Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
*/
accountId: string;
apiKey: string;
/**
* Both binding must be absent if credentials are used directly.
*/
binding?: never;
}
) & {
/**
* Optionally specify a gateway.
*/
gateway?: GatewayOptions;
};

export interface WorkersAI {
export interface WorkersAI extends ProviderV2 {
(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
/**
* Creates a model for text generation.
**/

/**
* @deprecated Use `.languageModel()` instead.
**/
chat(
modelId: TextGenerationModels,
settings?: WorkersAIChatSettings,
Expand All @@ -73,8 +81,19 @@ export interface WorkersAI {

/**
* Creates a model for image generation.
* @deprecated use .imageModel() instead.
**/
image(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;

/**
* Creates a model for text generation.
**/
languageModel(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;

/**
* Creates a model for image generation.
**/
imageModel(modelId: string, settings?: WorkersAIImageSettings): WorkersAIImageModel;
}

/**
Expand All @@ -83,6 +102,7 @@ export interface WorkersAI {
export function createWorkersAI(options: WorkersAISettings): WorkersAI {
// Use a binding if one is directly provided. Otherwise use credentials to create
// a `run` method that calls the Cloudflare REST API.
console.log("Creating Workers AI provider with options:", options);
let binding: Ai | undefined;

if (options.binding) {
Expand Down Expand Up @@ -131,7 +151,8 @@ export function createWorkersAI(options: WorkersAISettings): WorkersAI {
return createChatModel(modelId, settings);
};

provider.chat = createChatModel;
provider.chat = createChatModel; // Deprecated alias for `languageModel`
provider.languageModel = createChatModel;
provider.embedding = createEmbeddingModel;
provider.textEmbedding = createEmbeddingModel;
provider.textEmbeddingModel = createEmbeddingModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type { LanguageModelV1FinishReason } from "@ai-sdk/provider";
import type { LanguageModelV2FinishReason } from "@ai-sdk/provider";

export function mapWorkersAIFinishReason(
finishReason: string | null | undefined,
): LanguageModelV1FinishReason {
): LanguageModelV2FinishReason {
switch (finishReason) {
case "stop":
return "stop";
Expand Down
Loading