diff --git a/cmd/cmd.go b/cmd/cmd.go index 090685a..286b59d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -17,6 +17,7 @@ var ( host string port int token string + tools string version bool ) @@ -31,6 +32,9 @@ func init() { flag.StringVar(&token, "token", "", "") flag.BoolVar(&flagPkg.ReadOnly, "r", false, "") flag.BoolVar(&flagPkg.ReadOnly, "read-only", false, "") + defaultTools := os.Getenv("GITEA_TOOLS") + flag.StringVar(&tools, "O", defaultTools, "") + flag.StringVar(&tools, "tools", defaultTools, "") flag.BoolVar(&flagPkg.Debug, "d", false, "") flag.BoolVar(&flagPkg.Debug, "debug", false, "") flag.BoolVar(&flagPkg.Insecure, "k", false, "") @@ -48,6 +52,7 @@ func init() { fmt.Fprintf(w, " -p, -port \tHTTP server port (default: 8080)\n") fmt.Fprintf(w, " -T, -token \tPersonal access token\n") fmt.Fprintf(w, " -r, -read-only\tExpose only read-only tools\n") + fmt.Fprintf(w, " -O, -tools \tComma-separated list of tool names to expose\n") fmt.Fprintf(w, " -d, -debug\tEnable debug mode\n") fmt.Fprintf(w, " -k, -insecure\tIgnore TLS certificate errors\n") fmt.Fprintf(w, " -v, -version\tPrint version and exit\n") @@ -59,6 +64,7 @@ func init() { fmt.Fprintf(w, " GITEA_HOST\tOverride Gitea host URL\n") fmt.Fprintf(w, " GITEA_INSECURE\tSet to 'true' to ignore TLS errors\n") fmt.Fprintf(w, " GITEA_READONLY\tSet to 'true' for read-only mode\n") + fmt.Fprintf(w, " GITEA_TOOLS\tComma-separated list of tool names to expose\n") fmt.Fprintf(w, " MCP_MODE\tOverride transport mode\n") w.Flush() } @@ -95,6 +101,16 @@ func init() { flagPkg.ReadOnly = true } + allowed := map[string]struct{}{} + for t := range strings.SplitSeq(tools, ",") { + if t = strings.TrimSpace(t); t != "" { + allowed[t] = struct{}{} + } + } + if len(allowed) > 0 { + flagPkg.AllowedTools = allowed + } + if os.Getenv("GITEA_DEBUG") == "true" { flagPkg.Debug = true } diff --git a/operation/operation.go b/operation/operation.go index 597e692..e5a9df6 100644 --- a/operation/operation.go +++ b/operation/operation.go @@ -27,53 +27,27 @@ import ( mcpContext "gitea.com/gitea/gitea-mcp/pkg/context" "gitea.com/gitea/gitea-mcp/pkg/flag" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/tool" "github.com/mark3labs/mcp-go/server" ) -var mcpServer *server.MCPServer +var ( + mcpServer *server.MCPServer + + domainTools = []*tool.Tool{ + user.Tool, actions.Tool, repo.Tool, notification.Tool, issue.Tool, + label.Tool, milestone.Tool, packages.Tool, pull.Tool, search.Tool, + version.Tool, wiki.Tool, timetracking.Tool, + } +) func RegisterTool(s *server.MCPServer) { - // User Tool - s.AddTools(user.Tool.Tools()...) - - // Actions Tool - s.AddTools(actions.Tool.Tools()...) - - // Repo Tool - s.AddTools(repo.Tool.Tools()...) - - // Notification Tool - s.AddTools(notification.Tool.Tools()...) - - // Issue Tool - s.AddTools(issue.Tool.Tools()...) - - // Label Tool - s.AddTools(label.Tool.Tools()...) - - // Milestone Tool - s.AddTools(milestone.Tool.Tools()...) - - // Package Tool - s.AddTools(packages.Tool.Tools()...) - - // Pull Tool - s.AddTools(pull.Tool.Tools()...) - - // Search Tool - s.AddTools(search.Tool.Tools()...) - - // Version Tool - s.AddTools(version.Tool.Tools()...) - - // Wiki Tool - s.AddTools(wiki.Tool.Tools()...) - - // Time Tracking Tool - s.AddTools(timetracking.Tool.Tools()...) - + for _, t := range domainTools { + s.AddTools(t.Tools()...) + } s.DeleteTools("") + tool.WarnUnmatchedAllowedTools(domainTools...) } // parseAuthToken extracts the token from an Authorization header. diff --git a/pkg/flag/flag.go b/pkg/flag/flag.go index 9ebffa1..9e537c4 100644 --- a/pkg/flag/flag.go +++ b/pkg/flag/flag.go @@ -7,7 +7,8 @@ var ( Version string Mode string - Insecure bool - ReadOnly bool - Debug bool + Insecure bool + ReadOnly bool + Debug bool + AllowedTools map[string]struct{} ) diff --git a/pkg/tool/tool.go b/pkg/tool/tool.go index b137451..84cc474 100644 --- a/pkg/tool/tool.go +++ b/pkg/tool/tool.go @@ -1,7 +1,11 @@ package tool import ( + "slices" + "strings" + "gitea.com/gitea/gitea-mcp/pkg/flag" + "gitea.com/gitea/gitea-mcp/pkg/log" "github.com/mark3labs/mcp-go/server" ) @@ -27,12 +31,48 @@ func (t *Tool) RegisterRead(s server.ServerTool) { } func (t *Tool) Tools() []server.ServerTool { - tools := make([]server.ServerTool, 0, len(t.write)+len(t.read)) - if flag.ReadOnly { - tools = append(tools, t.read...) - return tools + all := make([]server.ServerTool, 0, len(t.write)+len(t.read)) + if !flag.ReadOnly { + all = append(all, t.write...) } - tools = append(tools, t.write...) - tools = append(tools, t.read...) - return tools + all = append(all, t.read...) + if len(flag.AllowedTools) == 0 { + return all + } + filtered := make([]server.ServerTool, 0, len(all)) + for _, st := range all { + if _, ok := flag.AllowedTools[st.Tool.Name]; ok { + filtered = append(filtered, st) + } + } + return filtered +} + +// WarnUnmatchedAllowedTools logs any names in flag.AllowedTools that don't +// match a tool registered on any of the given domains. No-op if the allowlist +// is empty. +func WarnUnmatchedAllowedTools(domains ...*Tool) { + if len(flag.AllowedTools) == 0 { + return + } + known := map[string]struct{}{} + for _, d := range domains { + for _, st := range d.read { + known[st.Tool.Name] = struct{}{} + } + for _, st := range d.write { + known[st.Tool.Name] = struct{}{} + } + } + var unmatched []string + for name := range flag.AllowedTools { + if _, ok := known[name]; !ok { + unmatched = append(unmatched, name) + } + } + if len(unmatched) == 0 { + return + } + slices.Sort(unmatched) + log.Warnf("Unknown tools in --tools allowlist (ignored): %s", strings.Join(unmatched, ", ")) } diff --git a/pkg/tool/tool_test.go b/pkg/tool/tool_test.go new file mode 100644 index 0000000..7b62857 --- /dev/null +++ b/pkg/tool/tool_test.go @@ -0,0 +1,100 @@ +package tool + +import ( + "slices" + "testing" + + "gitea.com/gitea/gitea-mcp/pkg/flag" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func makeTool(name string) server.ServerTool { + return server.ServerTool{Tool: mcp.NewTool(name)} +} + +func names(sts []server.ServerTool) []string { + out := make([]string, len(sts)) + for i, st := range sts { + out[i] = st.Tool.Name + } + return out +} + +func TestTools(t *testing.T) { + tests := []struct { + name string + readOnly bool + allowed map[string]struct{} + read []string + write []string + want []string + }{ + { + name: "no filters returns write then read", + read: []string{"r1", "r2"}, + write: []string{"w1", "w2"}, + want: []string{"w1", "w2", "r1", "r2"}, + }, + { + name: "read-only excludes write", + readOnly: true, + read: []string{"r1", "r2"}, + write: []string{"w1"}, + want: []string{"r1", "r2"}, + }, + { + name: "allowlist keeps only listed", + allowed: map[string]struct{}{"r1": {}, "w1": {}}, + read: []string{"r1", "r2"}, + write: []string{"w1", "w2"}, + want: []string{"w1", "r1"}, + }, + { + name: "allowlist intersected with read-only drops write entries", + readOnly: true, + allowed: map[string]struct{}{"r1": {}, "w1": {}}, + read: []string{"r1", "r2"}, + write: []string{"w1", "w2"}, + want: []string{"r1"}, + }, + { + name: "allowlist with only unknown names returns empty", + allowed: map[string]struct{}{"unknown": {}}, + read: []string{"r1"}, + write: []string{"w1"}, + want: []string{}, + }, + { + name: "empty allowlist map passes through", + allowed: map[string]struct{}{}, + read: []string{"r1"}, + write: []string{"w1"}, + want: []string{"w1", "r1"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origRO, origAllow := flag.ReadOnly, flag.AllowedTools + t.Cleanup(func() { + flag.ReadOnly, flag.AllowedTools = origRO, origAllow + }) + flag.ReadOnly = tt.readOnly + flag.AllowedTools = tt.allowed + + tr := New() + for _, n := range tt.read { + tr.RegisterRead(makeTool(n)) + } + for _, n := range tt.write { + tr.RegisterWrite(makeTool(n)) + } + + got := names(tr.Tools()) + if !slices.Equal(got, tt.want) { + t.Errorf("Tools() = %v, want %v", got, tt.want) + } + }) + } +}