|
1 | 1 | package main |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "encoding/json" |
4 | 7 | "flag" |
5 | 8 | "fmt" |
| 9 | + "io" |
| 10 | + "net/http" |
6 | 11 | "strings" |
7 | 12 |
|
| 13 | + "github.com/sourcegraph/src-cli/internal/api" |
8 | 14 | "github.com/sourcegraph/src-cli/internal/mcp" |
| 15 | + |
| 16 | + "github.com/sourcegraph/sourcegraph/lib/errors" |
| 17 | + |
| 18 | + |
9 | 19 | ) |
10 | 20 |
|
| 21 | +const McpPath = ".api/mcp/v1" |
| 22 | + |
11 | 23 | func init() { |
12 | 24 | flagSet := flag.NewFlagSet("mcp", flag.ExitOnError) |
13 | 25 | commands = append(commands, &command{ |
@@ -35,37 +47,115 @@ func mcpMain(args []string) error { |
35 | 47 | if !ok { |
36 | 48 | return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd) |
37 | 49 | } |
38 | | - return handleMcpTool(tool, args[1:]) |
39 | | -} |
40 | 50 |
|
41 | | -func handleMcpTool(tool *mcp.MCPToolDef, args []string) error { |
42 | | - fs, vars, err := buildArgFlagSet(tool) |
| 51 | + flagArgs := args[1:] // skip subcommand name |
| 52 | + if len(args) > 1 && args[1] == "schema" { |
| 53 | + return printSchemas(tool) |
| 54 | + } |
| 55 | + |
| 56 | + flags, vars, err := buildToolFlagSet(tool) |
43 | 57 | if err != nil { |
44 | 58 | return err |
45 | 59 | } |
| 60 | + if err := flags.Parse(flagArgs); err != nil { |
| 61 | + return err |
| 62 | + } |
| 63 | + sanitizeFlagValues(vars) |
46 | 64 |
|
47 | | - if err := fs.Parse(args); err != nil { |
| 65 | + if err := validateToolArgs(tool.InputSchema, args, vars); err != nil { |
48 | 66 | return err |
49 | 67 | } |
50 | 68 |
|
51 | | - inputSchema := tool.InputSchema |
| 69 | + apiClient := cfg.apiClient(nil, flags.Output()) |
| 70 | + return handleMcpTool(context.Background(), apiClient, tool, vars) |
| 71 | +} |
52 | 72 |
|
| 73 | +func printSchemas(tool *mcp.MCPToolDef) error { |
| 74 | + input, err := json.MarshalIndent(tool.InputSchema, "", " ") |
| 75 | + if err != nil { |
| 76 | + return err |
| 77 | + } |
| 78 | + output, err := json.MarshalIndent(tool.OutputSchema, "", " ") |
| 79 | + if err != nil { |
| 80 | + return err |
| 81 | + } |
| 82 | + |
| 83 | + fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output)) |
| 84 | + return nil |
| 85 | +} |
| 86 | + |
| 87 | +func validateToolArgs(inputSchema mcp.Schema, args []string, vars map[string]any) error { |
53 | 88 | for _, reqName := range inputSchema.Required { |
54 | 89 | if vars[reqName] == nil { |
55 | | - return fmt.Errorf("no value provided for required flag --%s", reqName) |
| 90 | + return errors.Newf("no value provided for required flag --%s", reqName) |
56 | 91 | } |
57 | 92 | } |
58 | 93 |
|
59 | 94 | if len(args) < len(inputSchema.Required) { |
60 | | - return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
| 95 | + return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n")) |
| 96 | + } |
| 97 | + |
| 98 | + return nil |
| 99 | +} |
| 100 | + |
| 101 | +func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.MCPToolDef, vars map[string]any) error { |
| 102 | + jsonRPC := struct { |
| 103 | + Version string `json:"jsonrpc"` |
| 104 | + ID int `json:"id"` |
| 105 | + Method string `json:"method"` |
| 106 | + Params any `json:"params"` |
| 107 | + }{ |
| 108 | + Version: "2.0", |
| 109 | + ID: 1, |
| 110 | + Method: "tools/call", |
| 111 | + Params: struct { |
| 112 | + Name string `json:"name"` |
| 113 | + Arguments map[string]any `json:"arguments"` |
| 114 | + }{ |
| 115 | + Name: tool.RawName, |
| 116 | + Arguments: vars, |
| 117 | + }, |
| 118 | + } |
| 119 | + |
| 120 | + buf := bytes.NewBuffer(nil) |
| 121 | + data, err := json.Marshal(jsonRPC) |
| 122 | + if err != nil { |
| 123 | + return err |
61 | 124 | } |
| 125 | + buf.Write(data) |
62 | 126 |
|
63 | | - derefFlagValues(vars) |
| 127 | + req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpPath, buf) |
| 128 | + if err != nil { |
| 129 | + return err |
| 130 | + } |
| 131 | + req.Header.Add("Content-Type", "application/json") |
| 132 | + req.Header.Add("Accept", "*/*") |
64 | 133 |
|
65 | | - fmt.Println("Flags") |
66 | | - for name, val := range vars { |
67 | | - fmt.Printf("--%s=%v\n", name, val) |
| 134 | + resp, err := client.Do(req) |
| 135 | + if err != nil { |
| 136 | + return err |
68 | 137 | } |
69 | 138 |
|
| 139 | + data, err = io.ReadAll(resp.Body) |
| 140 | + if err != nil { |
| 141 | + return err |
| 142 | + } |
| 143 | + |
| 144 | + jsonData, err := parseSSEResponse(data) |
| 145 | + if err != nil { |
| 146 | + return err |
| 147 | + } |
| 148 | + |
| 149 | + fmt.Println(string(jsonData)) |
70 | 150 | return nil |
71 | 151 | } |
| 152 | + |
| 153 | +func parseSSEResponse(data []byte) ([]byte, error) { |
| 154 | + lines := bytes.SplitSeq(data, []byte("\n")) |
| 155 | + for line := range lines { |
| 156 | + if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok { |
| 157 | + return jsonData, nil |
| 158 | + } |
| 159 | + } |
| 160 | + return nil, errors.New("no data found in SSE response") |
| 161 | +} |
0 commit comments