|
1 | 1 | package main |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "encoding/json" |
4 | 7 | "flag" |
5 | 8 | "fmt" |
| 9 | + "io" |
| 10 | + "net/http" |
6 | 11 | "strings" |
| 12 | + |
| 13 | + "github.com/sourcegraph/src-cli/internal/api" |
| 14 | + |
| 15 | + "github.com/sourcegraph/sourcegraph/lib/errors" |
7 | 16 | ) |
8 | 17 |
|
| 18 | +const McpPath = ".api/mcp/v1" |
| 19 | + |
9 | 20 | func init() { |
10 | 21 | flagSet := flag.NewFlagSet("mcp", flag.ExitOnError) |
11 | 22 | commands = append(commands, &command{ |
@@ -33,37 +44,115 @@ func mcpMain(args []string) error { |
33 | 44 | if !ok { |
34 | 45 | return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd) |
35 | 46 | } |
36 | | - return handleMcpTool(tool, args[1:]) |
37 | | -} |
38 | 47 |
|
39 | | -func handleMcpTool(tool *MCPToolDef, args []string) error { |
40 | | - fs, vars, err := buildArgFlagSet(tool) |
| 48 | + flagArgs := args[1:] // skip subcommand name |
| 49 | + if len(args) > 1 && args[1] == "schema" { |
| 50 | + return printSchemas(tool) |
| 51 | + } |
| 52 | + |
| 53 | + flags, vars, err := buildToolFlagSet(tool) |
41 | 54 | if err != nil { |
42 | 55 | return err |
43 | 56 | } |
| 57 | + if err := flags.Parse(flagArgs); err != nil { |
| 58 | + return err |
| 59 | + } |
| 60 | + sanitizeFlagValues(vars) |
44 | 61 |
|
45 | | - if err := fs.Parse(args); err != nil { |
| 62 | + if err := validateToolArgs(tool.InputSchema, args, vars); err != nil { |
46 | 63 | return err |
47 | 64 | } |
48 | 65 |
|
49 | | - inputSchema := tool.InputSchema |
| 66 | + apiClient := cfg.apiClient(nil, flags.Output()) |
| 67 | + return handleMcpTool(context.Background(), apiClient, tool, vars) |
| 68 | +} |
50 | 69 |
|
| 70 | +func printSchemas(tool *MCPToolDef) error { |
| 71 | + input, err := json.MarshalIndent(tool.InputSchema, "", " ") |
| 72 | + if err != nil { |
| 73 | + return err |
| 74 | + } |
| 75 | + output, err := json.MarshalIndent(tool.OutputSchema, "", " ") |
| 76 | + if err != nil { |
| 77 | + return err |
| 78 | + } |
| 79 | + |
| 80 | + fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output)) |
| 81 | + return nil |
| 82 | +} |
| 83 | + |
| 84 | +func validateToolArgs(inputSchema Schema, args []string, vars map[string]any) error { |
51 | 85 | for _, reqName := range inputSchema.Required { |
52 | 86 | if vars[reqName] == nil { |
53 | | - return fmt.Errorf("no value provided for required flag --%s", reqName) |
| 87 | + return errors.Newf("no value provided for required flag --%s", reqName) |
54 | 88 | } |
55 | 89 | } |
56 | 90 |
|
57 | 91 | if len(args) < len(inputSchema.Required) { |
58 | | - return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
| 92 | + return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
59 | 93 | } |
60 | 94 |
|
61 | | - derefFlagValues(vars) |
| 95 | + return nil |
| 96 | +} |
62 | 97 |
|
63 | | - fmt.Println("Flags") |
64 | | - for name, val := range vars { |
65 | | - fmt.Printf("--%s=%v\n", name, val) |
| 98 | +func handleMcpTool(ctx context.Context, client api.Client, tool *MCPToolDef, vars map[string]any) error { |
| 99 | + jsonRPC := struct { |
| 100 | + Version string `json:"jsonrpc"` |
| 101 | + ID int `json:"id"` |
| 102 | + Method string `json:"method"` |
| 103 | + Params any `json:"params"` |
| 104 | + }{ |
| 105 | + Version: "2.0", |
| 106 | + ID: 1, |
| 107 | + Method: "tools/call", |
| 108 | + Params: struct { |
| 109 | + Name string `json:"name"` |
| 110 | + Arguments map[string]any `json:"arguments"` |
| 111 | + }{ |
| 112 | + Name: tool.RawName, |
| 113 | + Arguments: vars, |
| 114 | + }, |
66 | 115 | } |
67 | 116 |
|
| 117 | + buf := bytes.NewBuffer(nil) |
| 118 | + data, err := json.Marshal(jsonRPC) |
| 119 | + if err != nil { |
| 120 | + return err |
| 121 | + } |
| 122 | + buf.Write(data) |
| 123 | + |
| 124 | + req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpPath, buf) |
| 125 | + if err != nil { |
| 126 | + return err |
| 127 | + } |
| 128 | + req.Header.Add("Content-Type", "application/json") |
| 129 | + req.Header.Add("Accept", "*/*") |
| 130 | + |
| 131 | + resp, err := client.Do(req) |
| 132 | + if err != nil { |
| 133 | + return err |
| 134 | + } |
| 135 | + |
| 136 | + data, err = io.ReadAll(resp.Body) |
| 137 | + if err != nil { |
| 138 | + return err |
| 139 | + } |
| 140 | + |
| 141 | + jsonData, err := parseSSEResponse(data) |
| 142 | + if err != nil { |
| 143 | + return err |
| 144 | + } |
| 145 | + |
| 146 | + fmt.Println(jsonData) |
68 | 147 | return nil |
69 | 148 | } |
| 149 | + |
| 150 | +func parseSSEResponse(data []byte) ([]byte, error) { |
| 151 | + lines := bytes.SplitSeq(data, []byte("\n")) |
| 152 | + for line := range lines { |
| 153 | + if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok { |
| 154 | + return jsonData, nil |
| 155 | + } |
| 156 | + } |
| 157 | + return nil, errors.New("no data found in SSE response") |
| 158 | +} |
0 commit comments