Skip to content

Commit d408dd2

Browse files
committed
ToolBuilder
1 parent 0ce8367 commit d408dd2

File tree

3 files changed

+147
-13
lines changed

3 files changed

+147
-13
lines changed

src/Chat.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,23 @@ import { PromptBuilder } from "./PromptBuilder";
44
import { ExtractArgs, ExtractChatArgs, ReplaceChatArgs } from "./types";
55

66
export class Chat<
7+
ToolNames extends string,
78
TMessages extends
89
| []
910
| [...OpenAI.Chat.CreateChatCompletionRequestMessage[], OpenAI.Chat.CreateChatCompletionRequestMessage],
10-
TSuppliedInputArgs extends ExtractChatArgs<TMessages, {}>
11+
TSuppliedInputArgs extends ExtractChatArgs<TMessages, {}>,
1112
> {
1213
constructor(
1314
public messages: F.Narrow<TMessages>,
14-
public args: F.Narrow<TSuppliedInputArgs>
15+
public args: F.Narrow<TSuppliedInputArgs>,
16+
public tools = {} as Record<ToolNames, Tool>,
17+
///
18+
public mustUseTool: boolean = false
1519
) {}
1620

21+
toJSONSchema() {
22+
}
23+
1724
toArray() {
1825
return (this.messages as TMessages).map((m) => ({
1926
role: m.role,

src/ToolBuilder.ts

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
interface Tool<I = unknown, O = unknown> {
2+
name: string;
3+
type: "query" | "mutation"
4+
build: (input: I) => O;
5+
}
6+
7+
export class ToolBuilder<TType extends "query" | "mutation" = "query", I = unknown, O = unknown> {
8+
private name: string;
9+
private implementation?: (input: I) => O;
10+
private type: TType;
11+
12+
constructor(name: string, type: TType = "query" as TType) {
13+
this.name = name;
14+
this.type = type;
15+
}
16+
17+
addInputValidation<T = I>(): ToolBuilder<TType, T, O> {
18+
// Implementation here
19+
return this as unknown as ToolBuilder<TType, T, O>;
20+
}
21+
22+
addOutputValidation<T = O>(): ToolBuilder<TType, I, T> {
23+
// Implementation here
24+
return this as unknown as ToolBuilder<TType, I, T>;
25+
}
26+
27+
query(queryFunction: (input: I) => O): ToolBuilder<"query", I, O> {
28+
29+
return {
30+
...this,
31+
implementation: queryFunction,
32+
type: "query"
33+
};
34+
}
35+
36+
mutation(mutationFunction: (input: I) => O): ToolBuilder<"mutation", I, O> {
37+
return {
38+
...this,
39+
implementation: mutationFunction,
40+
type: "mutation"
41+
};
42+
}
43+
44+
build(): Tool<I, O> {
45+
return {
46+
name: this.name,
47+
build: this.implementation!,
48+
type: this.type
49+
};
50+
}
51+
}

src/__tests__/Chat.test.ts

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ import { strict as assert } from "node:assert";
22
import { Chat } from "../Chat";
33
import { system, user, assistant } from "../ChatHelpers";
44
import { Equal, Expect } from "./types.test";
5+
import { ToolBuilder } from "../ToolBuilder";
56

67
describe("Chat", () => {
78
it("should allow empty array", () => {
8-
const chat = new Chat([], {}).toArray();
9+
const chat = new Chat([], {}, {}).toArray();
910
type test = Expect<Equal<typeof chat, []>>;
1011
assert.deepEqual(chat, []);
1112
});
@@ -43,7 +44,6 @@ describe("Chat", () => {
4344
it("should allow chat of all diffent types", () => {
4445
const chat = new Chat(
4546
[
46-
// ^?
4747
user(`Tell me a {{jokeType1}} joke`),
4848
assistant(`{{var2}} joke?`),
4949
system(`joke? {{var3}}`),
@@ -64,21 +64,97 @@ describe("Chat", () => {
6464
});
6565

6666
it("should allow chat of all diffent types with no args", () => {
67-
const chat = new Chat(
68-
[
69-
// ^?
70-
user(`Tell me a joke`),
71-
assistant(`joke?`),
72-
system(`joke?`),
73-
],
74-
{}
75-
).toArray();
7667
const usrMsg = user("Tell me a joke");
7768
const astMsg = assistant("joke?");
7869
const sysMsg = system("joke?");
70+
71+
const chat = new Chat([usrMsg, astMsg, sysMsg], {}).toArray();
7972
type test = Expect<
8073
Equal<typeof chat, [typeof usrMsg, typeof astMsg, typeof sysMsg]>
8174
>;
8275
assert.deepEqual(chat, [usrMsg, astMsg, sysMsg]);
8376
});
77+
78+
it("should allow me to pass in tools", () => {
79+
const usrMsg = user("Tell me a joke");
80+
const astMsg = assistant("joke?");
81+
const sysMsg = system("joke?");
82+
const tools = {
83+
google: new ToolBuilder("google")
84+
.addInputValidation<{ query: string }>()
85+
.addOutputValidation<{ results: string[] }>()
86+
.query(({ query }) => {
87+
return {
88+
results: ["foo", "bar"],
89+
};
90+
}),
91+
wikipedia: new ToolBuilder("wikipedia")
92+
.addInputValidation<{ page: string }>()
93+
.addOutputValidation<{ results: string[] }>()
94+
.query(({ page }) => {
95+
return {
96+
results: ["foo", "bar"],
97+
};
98+
}),
99+
sendEmail: new ToolBuilder("sendEmail")
100+
.addInputValidation<{ to: string; subject: string; body: string }>()
101+
.addOutputValidation<{ success: boolean }>()
102+
.mutation(({ to, subject, body }) => {
103+
return {
104+
success: true,
105+
};
106+
}),
107+
} as const;
108+
109+
const chat = new Chat([usrMsg, astMsg, sysMsg], {}, tools);
110+
111+
type tests = [
112+
Expect<
113+
Equal<
114+
typeof chat,
115+
Chat<
116+
keyof typeof tools,
117+
[typeof usrMsg, typeof astMsg, typeof sysMsg],
118+
{}
119+
>
120+
>
121+
>,
122+
Expect<
123+
Equal<
124+
typeof tools,
125+
{
126+
readonly google: ToolBuilder<
127+
"query",
128+
{
129+
query: string;
130+
},
131+
{
132+
results: string[];
133+
}
134+
>;
135+
readonly wikipedia: ToolBuilder<
136+
"query",
137+
{
138+
page: string;
139+
},
140+
{
141+
results: string[];
142+
}
143+
>;
144+
readonly sendEmail: ToolBuilder<
145+
"mutation",
146+
{
147+
to: string;
148+
subject: string;
149+
body: string;
150+
},
151+
{
152+
success: boolean;
153+
}
154+
>;
155+
}
156+
>
157+
>
158+
];
159+
});
84160
});

0 commit comments

Comments
 (0)