From ebbda93813218d51ceeb23a483dec57ce72ab813 Mon Sep 17 00:00:00 2001 From: Maximillian Polhill Date: Wed, 23 Apr 2025 19:46:34 +0000 Subject: [PATCH 1/2] Add initial Gist tools: ListGists, CreateGist --- pkg/github/gists.go | 167 +++++++++++++++++++ pkg/github/gists_test.go | 343 +++++++++++++++++++++++++++++++++++++++ pkg/github/tools.go | 9 + 3 files changed, 519 insertions(+) create mode 100644 pkg/github/gists.go create mode 100644 pkg/github/gists_test.go diff --git a/pkg/github/gists.go b/pkg/github/gists.go new file mode 100644 index 00000000..a4e840fc --- /dev/null +++ b/pkg/github/gists.go @@ -0,0 +1,167 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ListGists creates a tool to list gists for a user +func ListGists(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_gists", + mcp.WithDescription(t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user")), + mcp.WithString("username", + mcp.Description("GitHub username (omit for authenticated user's gists)"), + ), + mcp.WithString("since", + mcp.Description("Only gists updated after this time (ISO 8601 timestamp)"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + username, err := OptionalParam[string](request, "username") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + since, err := OptionalParam[string](request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.GistListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + }, + } + + // Parse since timestamp if provided + if since != "" { + sinceTime, err := parseISOTimestamp(since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil + } + opts.Since = sinceTime + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + gists, resp, err := client.Gists.List(ctx, username, opts) + if err != nil { + return nil, fmt.Errorf("failed to list gists: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil + } + + r, err := json.Marshal(gists) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// CreateGist creates a tool to create a new gist +func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("create_gist", + mcp.WithDescription(t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist")), + mcp.WithString("description", + mcp.Description("Description of the gist"), + ), + mcp.WithString("filename", + mcp.Required(), + mcp.Description("Filename for simple single-file gist creation"), + ), + mcp.WithString("content", + mcp.Required(), + mcp.Description("Content for simple single-file gist creation"), + ), + mcp.WithBoolean("public", + mcp.Description("Whether the gist is public"), + mcp.DefaultBool(false), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + description, err := OptionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + filename, err := requiredParam[string](request, "filename") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + content, err := requiredParam[string](request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + public, err := OptionalParam[bool](request, "public") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Public: github.Ptr(public), + Description: github.Ptr(description), + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + createdGist, resp, err := client.Gists.Create(ctx, gist) + if err != nil { + return nil, fmt.Errorf("failed to create gist: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil + } + + r, err := json.Marshal(createdGist) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go new file mode 100644 index 00000000..e6121d17 --- /dev/null +++ b/pkg/github/gists_test.go @@ -0,0 +1,343 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListGists(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := ListGists(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_gists", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "username") + assert.Contains(t, tool.InputSchema.Properties, "since") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + assert.Empty(t, tool.InputSchema.Required) + + // Setup mock gists for success case + mockGists := []*github.Gist{ + { + ID: github.Ptr("gist1"), + Description: github.Ptr("First Gist"), + HTMLURL: github.Ptr("https://gist.github.com/user/gist1"), + Public: github.Ptr(true), + CreatedAt: &github.Timestamp{Time: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)}, + Owner: &github.User{Login: github.Ptr("user")}, + Files: map[github.GistFilename]github.GistFile{ + "file1.txt": { + Filename: github.Ptr("file1.txt"), + Content: github.Ptr("content of file 1"), + }, + }, + }, + { + ID: github.Ptr("gist2"), + Description: github.Ptr("Second Gist"), + HTMLURL: github.Ptr("https://gist.github.com/testuser/gist2"), + Public: github.Ptr(false), + CreatedAt: &github.Timestamp{Time: time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC)}, + Owner: &github.User{Login: github.Ptr("testuser")}, + Files: map[github.GistFilename]github.GistFile{ + "file2.js": { + Filename: github.Ptr("file2.js"), + Content: github.Ptr("console.log('hello');"), + }, + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedGists []*github.Gist + expectedErrMsg string + }{ + { + name: "list authenticated user's gists", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetGists, + mockGists, + ), + ), + requestArgs: map[string]interface{}{}, + expectError: false, + expectedGists: mockGists, + }, + { + name: "list specific user's gists", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUsersGistsByUsername, + mockResponse(t, http.StatusOK, mockGists), + ), + ), + requestArgs: map[string]interface{}{ + "username": "testuser", + }, + expectError: false, + expectedGists: mockGists, + }, + { + name: "list gists with pagination and since parameter", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetGists, + expectQueryParams(t, map[string]string{ + "since": "2023-01-01T00:00:00Z", + "page": "2", + "per_page": "5", + }).andThen( + mockResponse(t, http.StatusOK, mockGists), + ), + ), + ), + requestArgs: map[string]interface{}{ + "since": "2023-01-01T00:00:00Z", + "page": float64(2), + "perPage": float64(5), + }, + expectError: false, + expectedGists: mockGists, + }, + { + name: "invalid since parameter", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetGists, + mockGists, + ), + ), + requestArgs: map[string]interface{}{ + "since": "invalid-date", + }, + expectError: true, + expectedErrMsg: "invalid since timestamp", + }, + { + name: "list gists fails with error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetGists, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "failed to list gists", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListGists(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + // For errors returned as part of the result, not as an error + assert.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedGists []*github.Gist + err = json.Unmarshal([]byte(textContent.Text), &returnedGists) + require.NoError(t, err) + + assert.Len(t, returnedGists, len(tc.expectedGists)) + for i, gist := range returnedGists { + assert.Equal(t, *tc.expectedGists[i].ID, *gist.ID) + assert.Equal(t, *tc.expectedGists[i].Description, *gist.Description) + assert.Equal(t, *tc.expectedGists[i].HTMLURL, *gist.HTMLURL) + assert.Equal(t, *tc.expectedGists[i].Public, *gist.Public) + } + }) + } +} + +func Test_CreateGist(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := CreateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "create_gist", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.Contains(t, tool.InputSchema.Properties, "filename") + assert.Contains(t, tool.InputSchema.Properties, "content") + assert.Contains(t, tool.InputSchema.Properties, "public") + + // Verify required parameters + assert.Contains(t, tool.InputSchema.Required, "filename") + assert.Contains(t, tool.InputSchema.Required, "content") + + // Setup mock data for test cases + createdGist := &github.Gist{ + ID: github.Ptr("new-gist-id"), + Description: github.Ptr("Test Gist"), + HTMLURL: github.Ptr("https://gist.github.com/user/new-gist-id"), + Public: github.Ptr(false), + CreatedAt: &github.Timestamp{Time: time.Now()}, + Owner: &github.User{Login: github.Ptr("user")}, + Files: map[github.GistFilename]github.GistFile{ + "test.go": { + Filename: github.Ptr("test.go"), + Content: github.Ptr("package main\n\nfunc main() {\n\tfmt.Println(\"Hello, Gist!\")\n}"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedErrMsg string + expectedGist *github.Gist + }{ + { + name: "create gist successfully", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostGists, + mockResponse(t, http.StatusCreated, createdGist), + ), + ), + requestArgs: map[string]interface{}{ + "filename": "test.go", + "content": "package main\n\nfunc main() {\n\tfmt.Println(\"Hello, Gist!\")\n}", + "description": "Test Gist", + "public": false, + }, + expectError: false, + expectedGist: createdGist, + }, + { + name: "missing required filename", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "content": "test content", + "description": "Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: filename", + }, + { + name: "missing required content", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "filename": "test.go", + "description": "Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: content", + }, + { + name: "api returns error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostGists, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "filename": "test.go", + "content": "package main", + "description": "Test Gist", + }, + expectError: true, + expectedErrMsg: "failed to create gist", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := CreateGist(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + // For errors returned as part of the result, not as an error + assert.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + assert.NotNil(t, result) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var gist *github.Gist + err = json.Unmarshal([]byte(textContent.Text), &gist) + require.NoError(t, err) + + assert.Equal(t, *tc.expectedGist.ID, *gist.ID) + assert.Equal(t, *tc.expectedGist.Description, *gist.Description) + assert.Equal(t, *tc.expectedGist.HTMLURL, *gist.HTMLURL) + assert.Equal(t, *tc.expectedGist.Public, *gist.Public) + + // Verify file content + for filename, expectedFile := range tc.expectedGist.Files { + actualFile, exists := gist.Files[filename] + assert.True(t, exists) + assert.Equal(t, *expectedFile.Filename, *actualFile.Filename) + assert.Equal(t, *expectedFile.Content, *actualFile.Content) + } + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index ce10c4ad..a047bf67 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -75,6 +75,13 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, ) // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") + gists := toolsets.NewToolset("gists", "GitHub Gist related tools"). + AddReadTools( + toolsets.NewServerTool(ListGists(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(CreateGist(getClient, t)), + ) // Add toolsets to the group tsg.AddToolset(repos) @@ -83,6 +90,8 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(pullRequests) tsg.AddToolset(codeSecurity) tsg.AddToolset(experiments) + tsg.AddToolset(gists) + // Enable the requested features if err := tsg.EnableToolsets(passedToolsets); err != nil { From 668de2c180f622632dca1a532e91f7a6e59008c1 Mon Sep 17 00:00:00 2001 From: Maximillian Polhill Date: Fri, 25 Apr 2025 13:44:46 -0700 Subject: [PATCH 2/2] Add UpdateGist tool --- pkg/github/gists.go | 80 +++++++++++++++++++ pkg/github/gists_test.go | 164 +++++++++++++++++++++++++++++++++++++++ pkg/github/tools.go | 1 + 3 files changed, 245 insertions(+) diff --git a/pkg/github/gists.go b/pkg/github/gists.go index a4e840fc..94f2da39 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -165,3 +165,83 @@ func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (to return mcp.NewToolResultText(string(r)), nil } } + +// UpdateGist creates a tool to edit an existing gist +func UpdateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_gist", + mcp.WithDescription(t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist")), + mcp.WithString("gist_id", + mcp.Required(), + mcp.Description("ID of the gist to update"), + ), + mcp.WithString("description", + mcp.Description("Updated description of the gist"), + ), + mcp.WithString("filename", + mcp.Required(), + mcp.Description("Filename to update or create"), + ), + mcp.WithString("content", + mcp.Required(), + mcp.Description("Content for the file"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + gistID, err := requiredParam[string](request, "gist_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + description, err := OptionalParam[string](request, "description") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + filename, err := requiredParam[string](request, "filename") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + content, err := requiredParam[string](request, "content") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Description: github.Ptr(description), + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) + if err != nil { + return nil, fmt.Errorf("failed to update gist: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil + } + + r, err := json.Marshal(updatedGist) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} \ No newline at end of file diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index e6121d17..ff17ec0e 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -341,3 +341,167 @@ func Test_CreateGist(t *testing.T) { }) } } + +func Test_UpdateGist(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := UpdateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "update_gist", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "gist_id") + assert.Contains(t, tool.InputSchema.Properties, "description") + assert.Contains(t, tool.InputSchema.Properties, "filename") + assert.Contains(t, tool.InputSchema.Properties, "content") + + // Verify required parameters + assert.Contains(t, tool.InputSchema.Required, "gist_id") + assert.Contains(t, tool.InputSchema.Required, "filename") + assert.Contains(t, tool.InputSchema.Required, "content") + + // Setup mock data for test cases + updatedGist := &github.Gist{ + ID: github.Ptr("existing-gist-id"), + Description: github.Ptr("Updated Test Gist"), + HTMLURL: github.Ptr("https://gist.github.com/user/existing-gist-id"), + Public: github.Ptr(true), + UpdatedAt: &github.Timestamp{Time: time.Now()}, + Owner: &github.User{Login: github.Ptr("user")}, + Files: map[github.GistFilename]github.GistFile{ + "updated.go": { + Filename: github.Ptr("updated.go"), + Content: github.Ptr("package main\n\nfunc main() {\n\tfmt.Println(\"Updated Gist!\")\n}"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedErrMsg string + expectedGist *github.Gist + }{ + { + name: "update gist successfully", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchGistsByGistId, + mockResponse(t, http.StatusOK, updatedGist), + ), + ), + requestArgs: map[string]interface{}{ + "gist_id": "existing-gist-id", + "filename": "updated.go", + "content": "package main\n\nfunc main() {\n\tfmt.Println(\"Updated Gist!\")\n}", + "description": "Updated Test Gist", + }, + expectError: false, + expectedGist: updatedGist, + }, + { + name: "missing required gist_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "filename": "updated.go", + "content": "updated content", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: gist_id", + }, + { + name: "missing required filename", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "gist_id": "existing-gist-id", + "content": "updated content", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: filename", + }, + { + name: "missing required content", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "gist_id": "existing-gist-id", + "filename": "updated.go", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "missing required parameter: content", + }, + { + name: "api returns error", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchGistsByGistId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "gist_id": "nonexistent-gist-id", + "filename": "updated.go", + "content": "package main", + "description": "Updated Test Gist", + }, + expectError: true, + expectedErrMsg: "failed to update gist", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := UpdateGist(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + // For errors returned as part of the result, not as an error + assert.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + assert.NotNil(t, result) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var gist *github.Gist + err = json.Unmarshal([]byte(textContent.Text), &gist) + require.NoError(t, err) + + assert.Equal(t, *tc.expectedGist.ID, *gist.ID) + assert.Equal(t, *tc.expectedGist.Description, *gist.Description) + assert.Equal(t, *tc.expectedGist.HTMLURL, *gist.HTMLURL) + + // Verify file content + for filename, expectedFile := range tc.expectedGist.Files { + actualFile, exists := gist.Files[filename] + assert.True(t, exists) + assert.Equal(t, *expectedFile.Filename, *actualFile.Filename) + assert.Equal(t, *expectedFile.Content, *actualFile.Content) + } + }) + } +} \ No newline at end of file diff --git a/pkg/github/tools.go b/pkg/github/tools.go index a047bf67..cceee3fe 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -81,6 +81,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, ). AddWriteTools( toolsets.NewServerTool(CreateGist(getClient, t)), + toolsets.NewServerTool(UpdateGist(getClient, t)), ) // Add toolsets to the group