Skip to content

Commit b46fa54

Browse files
AlexNinynine
and
nine
authored
feat: Add server hooks:OnRequestInitialization (#164)
* add hooks:beforeHandleRequest * Update request_handler.go use mcp.INVALID_REQUEST replace mcp.PARSE_ERROR * rename hook name rename OnBeforeHandleRequest to OnRequestInitialization * update tmpl file update tmpl file generate hooks.go/request_handler.go by "go generate" fix example function name error * update test update test message --------- Co-authored-by: nine <[email protected]>
1 parent 341ebc5 commit b46fa54

File tree

6 files changed

+83
-8
lines changed

6 files changed

+83
-8
lines changed

examples/everything/main.go

+5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ func NewMCPServer() *server.MCPServer {
4444
hooks.AddBeforeInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest) {
4545
fmt.Printf("beforeInitialize: %v, %v\n", id, message)
4646
})
47+
hooks.AddOnRequestInitialization(func(ctx context.Context, id any, message any) error {
48+
fmt.Printf("AddOnRequestInitialization: %v, %v\n", id, message)
49+
// authorization verification and other preprocessing tasks are performed.
50+
return nil
51+
})
4752
hooks.AddAfterInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
4853
fmt.Printf("afterInitialize: %v, %v, %v\n", id, message, result)
4954
})

server/hooks.go

+22
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/internal/gen/hooks.go.tmpl

+23
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, m
5959
// })
6060
type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error)
6161

62+
// OnRequestInitializationFunc is a function that called before handle diff request method
63+
// Should any errors arise during func execution, the service will promptly return the corresponding error message.
64+
type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error
65+
66+
6267
{{range .}}
6368
type OnBefore{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}})
6469
type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
@@ -70,6 +75,7 @@ type Hooks struct {
7075
OnBeforeAny []BeforeAnyHookFunc
7176
OnSuccess []OnSuccessHookFunc
7277
OnError []OnErrorHookFunc
78+
OnRequestInitialization []OnRequestInitializationFunc
7379
{{- range .}}
7480
OnBefore{{.HookName}} []OnBefore{{.HookName}}Func
7581
OnAfter{{.HookName}} []OnAfter{{.HookName}}Func
@@ -199,6 +205,23 @@ func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) {
199205
}
200206
}
201207

208+
func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) {
209+
c.OnRequestInitialization = append(c.OnRequestInitialization, hook)
210+
}
211+
212+
func (c *Hooks) onRequestInitialization(ctx context.Context, id any, message any) error {
213+
if c == nil {
214+
return nil
215+
}
216+
for _, hook := range c.OnRequestInitialization {
217+
err := hook(ctx, id, message)
218+
if err != nil {
219+
return err
220+
}
221+
}
222+
return nil
223+
}
224+
202225
{{- range .}}
203226
func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) {
204227
c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook)

server/internal/gen/request_handler.go.tmpl

+9
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ func (s *MCPServer) HandleMessage(
6363
return nil
6464
}
6565

66+
handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message)
67+
if handleErr != nil {
68+
return createErrorResponse(
69+
baseMessage.ID,
70+
mcp.INVALID_REQUEST,
71+
handleErr.Error(),
72+
)
73+
}
74+
6675
switch baseMessage.Method {
6776
{{- range .}}
6877
case mcp.{{.MethodName}}:

server/request_handler.go

+9
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/server_test.go

+15-8
Original file line numberDiff line numberDiff line change
@@ -1265,13 +1265,14 @@ var _ ClientSession = fakeSession{}
12651265
func TestMCPServer_WithHooks(t *testing.T) {
12661266
// Create hook counters to verify calls
12671267
var (
1268-
beforeAnyCount int
1269-
onSuccessCount int
1270-
onErrorCount int
1271-
beforePingCount int
1272-
afterPingCount int
1273-
beforeToolsCount int
1274-
afterToolsCount int
1268+
beforeAnyCount int
1269+
onSuccessCount int
1270+
onErrorCount int
1271+
beforePingCount int
1272+
afterPingCount int
1273+
beforeToolsCount int
1274+
afterToolsCount int
1275+
onRequestInitializationCount int
12751276
)
12761277

12771278
// Collectors for message and result types
@@ -1335,6 +1336,11 @@ func TestMCPServer_WithHooks(t *testing.T) {
13351336
afterToolsCount++
13361337
})
13371338

1339+
hooks.AddOnRequestInitialization(func(ctx context.Context, id any, message any) error {
1340+
onRequestInitializationCount++
1341+
return nil
1342+
})
1343+
13381344
// Create a server with the hooks
13391345
server := NewMCPServer(
13401346
"test-server",
@@ -1398,10 +1404,11 @@ func TestMCPServer_WithHooks(t *testing.T) {
13981404
assert.Equal(t, 1, afterPingCount, "afterPing should be called once")
13991405
assert.Equal(t, 1, beforeToolsCount, "beforeListTools should be called once")
14001406
assert.Equal(t, 1, afterToolsCount, "afterListTools should be called once")
1401-
14021407
// General hooks should be called for all methods
14031408
// beforeAny is called for all 4 methods (initialize, ping, tools/list, tools/call)
14041409
assert.Equal(t, 4, beforeAnyCount, "beforeAny should be called for each method")
1410+
// onRequestInitialization is called for all 4 methods (initialize, ping, tools/list, tools/call)
1411+
assert.Equal(t, 4, onRequestInitializationCount, "onRequestInitializationCount should be called for each method")
14051412
// onSuccess is called for all 3 success methods (initialize, ping, tools/list)
14061413
assert.Equal(t, 3, onSuccessCount, "onSuccess should be called after all successful invocations")
14071414

0 commit comments

Comments
 (0)