diff --git a/cmd/src/mcp.go b/cmd/src/mcp.go index ec5683e4f7..6e46721f9b 100644 --- a/cmd/src/mcp.go +++ b/cmd/src/mcp.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "strings" "github.com/sourcegraph/src-cli/internal/mcp" ) @@ -16,7 +17,7 @@ func init() { } func mcpMain(args []string) error { fmt.Println("NOTE: This command is still experimental") - tools, err := mcp.LoadToolDefinitions() + tools, err := mcp.LoadDefaultToolDefinitions() if err != nil { return err } @@ -38,6 +39,33 @@ func mcpMain(args []string) error { } func handleMcpTool(tool *mcp.ToolDef, args []string) error { - fmt.Printf("handling tool %q args: %+v", tool.Name, args) + fs, vars, err := mcp.BuildArgFlagSet(tool) + if err != nil { + return err + } + + if err := fs.Parse(args); err != nil { + return err + } + + inputSchema := tool.InputSchema + + for _, reqName := range inputSchema.Required { + if vars[reqName] == nil { + return fmt.Errorf("no value provided for required flag --%s", reqName) + } + } + + if len(args) < len(inputSchema.Required) { + return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) + } + + mcp.DerefFlagValues(vars) + + fmt.Println("Flags") + for name, val := range vars { + fmt.Printf("--%s=%v\n", name, val) + } + return nil } diff --git a/internal/mcp/mcp_args.go b/internal/mcp/mcp_args.go new file mode 100644 index 0000000000..5b5b1ccc01 --- /dev/null +++ b/internal/mcp/mcp_args.go @@ -0,0 +1,75 @@ +package mcp + +import ( + "flag" + "fmt" + "reflect" + "strings" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +var _ flag.Value = (*strSliceFlag)(nil) + +type strSliceFlag struct { + vals []string +} + +func (s *strSliceFlag) Set(v string) error { + s.vals = append(s.vals, v) + return nil +} + +func (s *strSliceFlag) String() string { + return strings.Join(s.vals, ",") +} + +func DerefFlagValues(vars map[string]any) { + for k, v := range vars { + rfl := reflect.ValueOf(v) + if rfl.Kind() == reflect.Pointer { + vv := rfl.Elem().Interface() + if slice, ok := vv.(strSliceFlag); ok { + vv = slice.vals + } + vars[k] = vv + } + } +} + +func BuildArgFlagSet(tool *ToolDef) (*flag.FlagSet, map[string]any, error) { + if tool == nil { + return nil, nil, errors.New("cannot build flagset on nil Tool Definition") + } + fs := flag.NewFlagSet(tool.Name, flag.ContinueOnError) + flagVars := map[string]any{} + + for name, pVal := range tool.InputSchema.Properties { + switch pv := pVal.(type) { + case *SchemaPrimitive: + switch pv.Kind { + case "integer": + dst := fs.Int(name, 0, pv.Description) + flagVars[name] = dst + + case "boolean": + dst := fs.Bool(name, false, pv.Description) + flagVars[name] = dst + case "string": + dst := fs.String(name, "", pv.Description) + flagVars[name] = dst + default: + return nil, nil, fmt.Errorf("unknown schema primitive kind %q", pv.Kind) + + } + case *SchemaArray: + strSlice := new(strSliceFlag) + fs.Var(strSlice, name, pv.Description) + flagVars[name] = strSlice + case *SchemaObject: + // not supported yet + } + } + + return fs, flagVars, nil +} diff --git a/internal/mcp/mcp_args_test.go b/internal/mcp/mcp_args_test.go new file mode 100644 index 0000000000..17d5b466e0 --- /dev/null +++ b/internal/mcp/mcp_args_test.go @@ -0,0 +1,97 @@ +package mcp + +import ( + "testing" +) + +func TestFlagSetParse(t *testing.T) { + toolJSON := []byte(`{ + "tools": [ + { + "name": "sg_test_tool", + "description": "test description", + "inputSchema": { + "type": "object", + "$schema": "https://localhost/schema-draft/2025-07", + "required": ["values"], + "properties": { + "repos": { + "type": "array", + "items": { + "type": "string" + } + }, + "tag": { + "type": "string", + "items": true + }, + "count": { + "type": "integer" + }, + "boolFlag": { + "type": "boolean" + } + } + }, + "outputSchema": { + "type": "object", + "$schema": "https://localhost/schema-draft/2025-07", + "properties": { + "result": { "type": "string" } + } + } + } + ] + }`) + + defs, err := loadToolDefinitions(toolJSON) + if err != nil { + t.Fatalf("failed to load tool json: %v", err) + } + + flagSet, vars, err := BuildArgFlagSet(defs["test-tool"]) + if err != nil { + t.Fatalf("failed to build flagset from mcp tool definition: %v", err) + } + + if len(vars) == 0 { + t.Fatalf("vars from buildArgFlagSet should not be empty") + } + + args := []string{"-repos=A", "-repos=B", "-count=10", "-boolFlag", "-tag=testTag"} + + if err := flagSet.Parse(args); err != nil { + t.Fatalf("flagset parsing failed: %v", err) + } + DerefFlagValues(vars) + + if v, ok := vars["repos"].([]string); ok { + if len(v) != 2 { + t.Fatalf("expected flag 'repos' values to have length %d but got %d", 2, len(v)) + } + } else { + t.Fatalf("expected flag 'repos' to have type of []string but got %T", v) + } + if v, ok := vars["tag"].(string); ok { + if v != "testTag" { + t.Fatalf("expected flag 'tag' values to have value %q but got %q", "testTag", v) + } + } else { + t.Fatalf("expected flag 'tag' to have type of string but got %T", v) + } + if v, ok := vars["count"].(int); ok { + if v != 10 { + t.Fatalf("expected flag 'count' values to have value %d but got %d", 10, v) + } + } else { + t.Fatalf("expected flag 'count' to have type of int but got %T", v) + } + if v, ok := vars["boolFlag"].(bool); ok { + if v != true { + t.Fatalf("expected flag 'boolFlag' values to have value %v but got %v", true, v) + } + } else { + t.Fatalf("expected flag 'boolFlag' to have type of bool but got %T", v) + } + +} diff --git a/internal/mcp/mcp_parse.go b/internal/mcp/mcp_parse.go index e610a271e9..55e9650fd2 100644 --- a/internal/mcp/mcp_parse.go +++ b/internal/mcp/mcp_parse.go @@ -68,7 +68,7 @@ type parser struct { errors []error } -func LoadToolDefinitions() (map[string]*ToolDef, error) { +func LoadDefaultToolDefinitions() (map[string]*ToolDef, error) { return loadToolDefinitions(mcpToolListJSON) }