Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/chat/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@ async function main() {
sessionId: process.env['MCP_RUN_SESSION_ID']
})

mcpx.addBeforePolicy("fetch", async (call, context) => {
const hostname = new URL(call.params.arguments.url).hostname
const counts = context.get("fetch.domain-counts") || {}
const currentCount = counts[hostname] || 0
// we can only call each domain 3 times for some reason!
// our lawyers are making us!
if (currentCount >= 3) {
console.log(`failed`)
return {
allowed: false,
reason: `The domain ${hostname} was called too many times: ${currentCount}`
};
}

return { allowed: true }
})

mcpx.addAfterPolicy("fetch", async (call, context, _result) => {
const hostname = new URL(call.params.arguments.url).hostname
const counts = context.get("fetch.domain-counts") || {}
const count = (counts[hostname] || 0) + 1
counts[hostname] = count
context.set("fetch.domain-counts", counts)
console.log({hostname, counts, count})

return { allowed: true }
})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about folding this into a single intercept "middleware" function?

mcpx.intercept("fetch", async (request, context, next) => {
  const { hostname } = new URL(request.params.arguments.url)
  const counts = context.get("fetch.domain-counts") || {}
  if (!context.has("fetch.domain-counts")) {
    context.set("fetch.domain-counts", counts)
  }

  counts[hostname] ??= 0
  // we can only call each domain 3 times for some reason!
  // our lawyers are making us!
  if (counts[hostname] >= 3) {
    console.log(`failed`)
    return {
      isError: true,
      text: `The domain ${hostname} was called too many times: ${currentCount}`
    }
  }

  const result = await next(request, context)

  ++counts[hostname]
  console.log({hostname, counts, count})

  return result
})

I.e., instead of an allow/deny response, give the library the ability to control handling the incoming tool call request or delegating it to the next layer down. This makes it possible to stub tool invocations during tests, call a tool multiple times, return its own response in special circumstances, etc. (You could implement a policy API on top of this.)


const messages = [{
role: 'system',
content: `
Expand Down
2 changes: 1 addition & 1 deletion examples/chat/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
import { McpxOpenAI } from './openai';
import { CallToolRequest, CallToolResult } from "@modelcontextprotocol/sdk/types";
import type { PolicyContext, PolicyFunction } from "./policy-enforcer";

export { McpxOpenAI }
export type { CallToolRequest, CallToolResult, PolicyFunction, PolicyContext}

18 changes: 15 additions & 3 deletions src/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import type { ChatCompletion, ChatCompletionCreateParamsNonStreaming, ChatComple
import { Session } from './session';
import type { RequestOptions } from 'openai/core';
import { Logger } from 'pino';
import { CallToolRequest } from '@modelcontextprotocol/sdk/types';
import { PolicyFunction } from './policy-enforcer';

export interface McpxOpenAIOptions {
openai: OpenAI;
Expand All @@ -23,6 +25,16 @@ export class McpxOpenAI {
this.#session = session
}

// Add a policy to run before the function call
addBeforePolicy(functionName: string, policy: PolicyFunction) {
this.#session.addBeforePolicy(functionName, policy)
}

// Add a policy to run after the function call
addAfterPolicy(functionName: string, policy: PolicyFunction) {
this.#session.addAfterPolicy(functionName, policy)
}

static async create(opts: McpxOpenAIOptions) {
const {openai, logger, sessionId, profile } = opts
const config = {
Expand Down Expand Up @@ -78,14 +90,14 @@ export class McpxOpenAI {
return;
}

console.info(toolCall)
try {
// process the tool call using mcpx
const toolResp = await this.#session.handleCallTool({
params: {
name: toolCall.function.name,
arguments: JSON.parse(toolCall.function.arguments),
}
arguments: JSON.parse(toolCall.function.arguments) as Record<string, unknown>,
},
method: "tools/call"
});

messages.push({
Expand Down
74 changes: 74 additions & 0 deletions src/policy-enforcer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import { CallToolRequest, CallToolResult } from "@modelcontextprotocol/sdk/types";

interface PolicyResult {
allowed: boolean;
reason?: string;
}

// A policy can inspect both the call and the result
export type PolicyFunction = (
call: CallToolRequest,
context: PolicyContext,
result?: CallToolResult,
) => Promise<PolicyResult>;

// Shared context that policies can use to track state
export class PolicyContext {
private state = new Map<string, any>();

get(key: string): any {
return this.state.get(key);
}

set(key: string, value: any): void {
this.state.set(key, value);
}
}

export class PolicyEnforcer {
private beforePolicies: Map<string, PolicyFunction[]> = new Map();
private afterPolicies: Map<string, PolicyFunction[]> = new Map();
private context = new PolicyContext();

// Add a policy to run before the function call
addBeforePolicy(functionName: string, policy: PolicyFunction) {
const policies = this.beforePolicies.get(functionName) || [];
policies.push(policy);
this.beforePolicies.set(functionName, policies);
}

// Add a policy to run after the function call
addAfterPolicy(functionName: string, policy: PolicyFunction) {
const policies = this.afterPolicies.get(functionName) || [];
policies.push(policy);
this.afterPolicies.set(functionName, policies);
}

async wrapCall(
call: CallToolRequest,
executor: (call: CallToolRequest) => Promise<CallToolResult>
): Promise<any> {
// NOTE: hack to normalized name
const normalizedName = call.params.name.replace(/^.*?_/, '')
const beforePolicies = this.beforePolicies.get(normalizedName) || [];
for (const policy of beforePolicies) {
const result = await policy(call, this.context);
if (!result.allowed) {
throw new Error(`Policy violation: ${result.reason}`);
}
}

// execute the tool call
const result = await executor(call);

const afterPolicies = this.afterPolicies.get(normalizedName) || [];
for (const policy of afterPolicies) {
const policyResult = await policy(call, this.context, result);
if (!policyResult.allowed) {
throw new Error(`Policy violation: ${policyResult.reason}`);
}
}

return result;
}
}
20 changes: 18 additions & 2 deletions src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { Logger } from 'pino'
import type { ChatCompletion, ChatCompletionMessageParam, ChatCompletionTool } from 'openai/resources';
import { CallToolRequest } from '@modelcontextprotocol/sdk/types'
import { PolicyEnforcer, PolicyFunction } from './policy-enforcer'

export interface SessionConfig {
authentication?: [string, string][]
Expand All @@ -30,12 +32,14 @@ export class Session {
#config: SessionConfig;
#session: McpxSession;
#logger?: Logger;
#policyEnforcer: PolicyEnforcer;
//@ts-ignore
tools: ChatCompletionTool[];

private constructor(opts: SessionOptions) {
this.#config = opts.config
this.#logger = opts.logger
this.#policyEnforcer = new PolicyEnforcer()
}

static async create(opts: SessionOptions) {
Expand All @@ -44,8 +48,20 @@ export class Session {
return s
}

async handleCallTool(opts: any) {
return this.#session.handleCallTool(opts)
// Add a policy to run before the function call
addBeforePolicy(functionName: string, policy: PolicyFunction) {
this.#policyEnforcer.addBeforePolicy(functionName, policy)
}

// Add a policy to run after the function call
addAfterPolicy(functionName: string, policy: PolicyFunction) {
this.#policyEnforcer.addAfterPolicy(functionName, policy)
}

async handleCallTool(request: CallToolRequest) {
return this.#policyEnforcer.wrapCall(request, async (call: CallToolRequest) => {
return this.#session.handleCallTool(call)
})
}

async load() {
Expand Down
Empty file added src/types.ts
Empty file.