Skip to content

feat (core, react): refactor useChat to use MessageStore #5770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: v5
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions examples/next-openai/app/use-chat-v2/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
'use client';

import { useChatV2 } from '@ai-sdk/react';

export default function Chat() {
const { messages, input, handleInputChange, handleSubmit, addToolResult } =
useChatV2({
api: '/api/use-chat-tools',
maxSteps: 5,

// run client-side tools that are automatically executed:
async onToolCall({ toolCall }) {
// artificial 2 second delay
await new Promise(resolve => setTimeout(resolve, 2000));

if (toolCall.toolName === 'getLocation') {
const cities = [
'New York',
'Los Angeles',
'Chicago',
'San Francisco',
];
return cities[Math.floor(Math.random() * cities.length)];
}
},

onError(error) {
console.error(error);
},
});

return (
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
{messages?.map(message => (
<div key={message.id} className="whitespace-pre-wrap">
<strong>{`${message.role}: `}</strong>
{message.parts.map((part, index) => {
switch (part.type) {
case 'text':
return <div key={index}>{part.text}</div>;
case 'step-start':
return index > 0 ? (
<div key={index} className="text-gray-500">
<hr className="my-2 border-gray-300" />
</div>
) : null;
case 'tool-invocation': {
switch (part.toolInvocation.toolName) {
case 'askForConfirmation': {
switch (part.toolInvocation.state) {
case 'call':
return (
<div key={index} className="text-gray-500">
{part.toolInvocation.args.message}
<div className="flex gap-2">
<button
className="px-4 py-2 font-bold text-white bg-blue-500 rounded hover:bg-blue-700"
onClick={() =>
addToolResult({
toolCallId: part.toolInvocation.toolCallId,
result: 'Yes, confirmed.',
})
}
>
Yes
</button>
<button
className="px-4 py-2 font-bold text-white bg-red-500 rounded hover:bg-red-700"
onClick={() =>
addToolResult({
toolCallId: part.toolInvocation.toolCallId,
result: 'No, denied',
})
}
>
No
</button>
</div>
</div>
);
case 'result':
return (
<div key={index} className="text-gray-500">
Location access allowed:{' '}
{part.toolInvocation.result}
</div>
);
}
break;
}

case 'getLocation': {
switch (part.toolInvocation.state) {
case 'call':
return (
<div key={index} className="text-gray-500">
Getting location...
</div>
);
case 'result':
return (
<div key={index} className="text-gray-500">
Location: {part.toolInvocation.result}
</div>
);
}
break;
}

case 'getWeatherInformation': {
switch (part.toolInvocation.state) {
// example of pre-rendering streaming tool calls:
case 'partial-call':
return (
<pre key={index}>
{JSON.stringify(part.toolInvocation, null, 2)}
</pre>
);
case 'call':
return (
<div key={index} className="text-gray-500">
Getting weather information for{' '}
{part.toolInvocation.args.city}...
</div>
);
case 'result':
return (
<div key={index} className="text-gray-500">
Weather in {part.toolInvocation.args.city}:{' '}
{part.toolInvocation.result}
</div>
);
}
break;
}
}
}
}
})}
<br />
</div>
))}

<form onSubmit={handleSubmit}>
<input
className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
value={input}
placeholder="Say something..."
onChange={handleInputChange}
/>
</form>
</div>
);
}
22 changes: 20 additions & 2 deletions packages/ai/core/types/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
LanguageModelV2Source,
} from '@ai-sdk/provider';
import { FetchFunction, ToolCall, ToolResult } from '@ai-sdk/provider-utils';
import { ChatState, ChatStoreEvent } from '../util/chat-store';
import { LanguageModelUsage } from './duplicated/usage';

export type IdGenerator = () => string;
Expand Down Expand Up @@ -280,6 +281,11 @@ Additional data to be sent to the API endpoint.

export type UseChatOptions = {
/**
* The initial messages of the chat.
*/
initialMessages?: Message[];

/**
Keeps the last message when an error happens. Defaults to `true`.

@deprecated This option will be removed in the next major release.
Expand All @@ -300,9 +306,21 @@ Keeps the last message when an error happens. Defaults to `true`.
id?: string;

/**
* Initial messages of the chat. Useful to load an existing chat history.
* Optional initialization object for the chat store.
*/
initialMessages?: Message[];
chats?: Record<string, Pick<ChatState, 'messages'>>;
/**
* Optional callback function that is called when chat store changes.
*/
onChatStoreChange?: ({
event,
chatId,
state,
}: {
event: ChatStoreEvent;
chatId: string;
state: ChatState;
}) => void;

/**
* Initial input of the chat.
Expand Down
110 changes: 110 additions & 0 deletions packages/ai/core/util/call-chat-api-v2.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import { IdGenerator, JSONValue, UseChatOptions } from '../types';
import { ChatStore } from './chat-store';
import { processChatResponseV2 } from './process-chat-response-v2';
import { processChatTextResponseV2 } from './process-chat-text-response-v2';

// use function to allow for mocking in tests:
const getOriginalFetch = () => fetch;

export async function callChatApiV2({
api,
body,
streamProtocol = 'data',
credentials,
headers,
abortController,
restoreMessagesOnFailure,
onResponse,
onUpdate,
onFinish,
onToolCall,
generateId,
fetch = getOriginalFetch(),
store,
chatId,
}: {
api: string;
body: Record<string, any>;
streamProtocol: 'data' | 'text' | undefined;
credentials: RequestCredentials | undefined;
headers: HeadersInit | undefined;
abortController: (() => AbortController | null) | undefined;
restoreMessagesOnFailure: () => void;
onResponse: ((response: Response) => void | Promise<void>) | undefined;
onUpdate: (options: { data: JSONValue[] | undefined }) => void;
onFinish: UseChatOptions['onFinish'];
onToolCall: UseChatOptions['onToolCall'];
generateId: IdGenerator;
fetch: ReturnType<typeof getOriginalFetch> | undefined;
store: ChatStore;
chatId: string;
}) {
const response = await fetch(api, {
method: 'POST',
body: JSON.stringify(body),
headers: {
'Content-Type': 'application/json',
...headers,
},
signal: abortController?.()?.signal,
credentials,
}).catch(err => {
restoreMessagesOnFailure();
throw err;
});

if (onResponse) {
try {
await onResponse(response);
} catch (err) {
throw err;
}
}

if (!response.ok) {
restoreMessagesOnFailure();
throw new Error(
(await response.text()) ?? 'Failed to fetch the chat response.',
);
}

if (!response.body) {
throw new Error('The response body is empty.');
}

switch (streamProtocol) {
case 'text': {
await processChatTextResponseV2({
chatId,
stream: response.body,
update: onUpdate,
onFinish,
generateId,
store,
});
return;
}

case 'data': {
await processChatResponseV2({
chatId,
stream: response.body,
update: onUpdate,
store,
onToolCall,
onFinish({ message, finishReason, usage }) {
if (onFinish && message != null) {
onFinish(message, { usage, finishReason });
}
},
generateId,
});
return;
}

default: {
const exhaustiveCheck: never = streamProtocol;
throw new Error(`Unknown stream protocol: ${exhaustiveCheck}`);
}
}
}
Loading