Skip to content

Commit e94f2f9

Browse files
committed
make json rpc 2.0 call based on cli args
- basic sse resposne parsing - remove zero / nil arguments before making the request
1 parent 619a3a0 commit e94f2f9

File tree

3 files changed

+124
-17
lines changed

3 files changed

+124
-17
lines changed

cmd/src/mcp.go

Lines changed: 102 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
package main
22

33
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
47
"flag"
58
"fmt"
9+
"io"
10+
"net/http"
611
"strings"
712

13+
"github.com/sourcegraph/src-cli/internal/api"
814
"github.com/sourcegraph/src-cli/internal/mcp"
15+
16+
"github.com/sourcegraph/sourcegraph/lib/errors"
17+
18+
919
)
1020

21+
const McpPath = ".api/mcp/v1"
22+
1123
func init() {
1224
flagSet := flag.NewFlagSet("mcp", flag.ExitOnError)
1325
commands = append(commands, &command{
@@ -35,37 +47,115 @@ func mcpMain(args []string) error {
3547
if !ok {
3648
return fmt.Errorf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd)
3749
}
38-
return handleMcpTool(tool, args[1:])
39-
}
4050

41-
func handleMcpTool(tool *mcp.MCPToolDef, args []string) error {
42-
fs, vars, err := buildArgFlagSet(tool)
51+
flagArgs := args[1:] // skip subcommand name
52+
if len(args) > 1 && args[1] == "schema" {
53+
return printSchemas(tool)
54+
}
55+
56+
flags, vars, err := buildToolFlagSet(tool)
4357
if err != nil {
4458
return err
4559
}
60+
if err := flags.Parse(flagArgs); err != nil {
61+
return err
62+
}
63+
sanitizeFlagValues(vars)
4664

47-
if err := fs.Parse(args); err != nil {
65+
if err := validateToolArgs(tool.InputSchema, args, vars); err != nil {
4866
return err
4967
}
5068

51-
inputSchema := tool.InputSchema
69+
apiClient := cfg.apiClient(nil, flags.Output())
70+
return handleMcpTool(context.Background(), apiClient, tool, vars)
71+
}
5272

73+
func printSchemas(tool *mcp.MCPToolDef) error {
74+
input, err := json.MarshalIndent(tool.InputSchema, "", " ")
75+
if err != nil {
76+
return err
77+
}
78+
output, err := json.MarshalIndent(tool.OutputSchema, "", " ")
79+
if err != nil {
80+
return err
81+
}
82+
83+
fmt.Printf("Input:\n%v\nOutput:\n%v\n", string(input), string(output))
84+
return nil
85+
}
86+
87+
func validateToolArgs(inputSchema mcp.Schema, args []string, vars map[string]any) error {
5388
for _, reqName := range inputSchema.Required {
5489
if vars[reqName] == nil {
55-
return fmt.Errorf("no value provided for required flag --%s", reqName)
90+
return errors.Newf("no value provided for required flag --%s", reqName)
5691
}
5792
}
5893

5994
if len(args) < len(inputSchema.Required) {
60-
return fmt.Errorf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n"))
95+
return errors.Newf("not enough arguments provided - the following flags are required:\n%s", strings.Join(inputSchema.Required, "\n"))
96+
}
97+
98+
return nil
99+
}
100+
101+
func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.MCPToolDef, vars map[string]any) error {
102+
jsonRPC := struct {
103+
Version string `json:"jsonrpc"`
104+
ID int `json:"id"`
105+
Method string `json:"method"`
106+
Params any `json:"params"`
107+
}{
108+
Version: "2.0",
109+
ID: 1,
110+
Method: "tools/call",
111+
Params: struct {
112+
Name string `json:"name"`
113+
Arguments map[string]any `json:"arguments"`
114+
}{
115+
Name: tool.RawName,
116+
Arguments: vars,
117+
},
118+
}
119+
120+
buf := bytes.NewBuffer(nil)
121+
data, err := json.Marshal(jsonRPC)
122+
if err != nil {
123+
return err
61124
}
125+
buf.Write(data)
62126

63-
derefFlagValues(vars)
127+
req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpPath, buf)
128+
if err != nil {
129+
return err
130+
}
131+
req.Header.Add("Content-Type", "application/json")
132+
req.Header.Add("Accept", "*/*")
64133

65-
fmt.Println("Flags")
66-
for name, val := range vars {
67-
fmt.Printf("--%s=%v\n", name, val)
134+
resp, err := client.Do(req)
135+
if err != nil {
136+
return err
68137
}
69138

139+
data, err = io.ReadAll(resp.Body)
140+
if err != nil {
141+
return err
142+
}
143+
144+
jsonData, err := parseSSEResponse(data)
145+
if err != nil {
146+
return err
147+
}
148+
149+
fmt.Println(string(jsonData))
70150
return nil
71151
}
152+
153+
func parseSSEResponse(data []byte) ([]byte, error) {
154+
lines := bytes.SplitSeq(data, []byte("\n"))
155+
for line := range lines {
156+
if jsonData, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
157+
return jsonData, nil
158+
}
159+
}
160+
return nil, errors.New("no data found in SSE response")
161+
}

cmd/src/mcp_args.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,37 @@ func (s *strSliceFlag) String() string {
2222
return strings.Join(s.vals, ",")
2323
}
2424

25-
func derefFlagValues(vars map[string]any) {
25+
func sanitizeFlagValues(vars map[string]any) {
2626
for k, v := range vars {
2727
rfl := reflect.ValueOf(v)
2828
if rfl.Kind() == reflect.Pointer {
2929
vv := rfl.Elem().Interface()
3030
if slice, ok := vv.(strSliceFlag); ok {
3131
vv = slice.vals
3232
}
33-
vars[k] = vv
33+
if isNil(vv) {
34+
delete(vars, k)
35+
} else {
36+
vars[k] = vv
37+
}
3438
}
3539
}
3640
}
3741

38-
func buildArgFlagSet(tool *MCPToolDef) (*flag.FlagSet, map[string]any, error) {
42+
func isNil(v any) bool {
43+
if v == nil {
44+
return true
45+
}
46+
rv := reflect.ValueOf(v)
47+
switch rv.Kind() {
48+
case reflect.Slice, reflect.Map, reflect.Pointer, reflect.Interface:
49+
return rv.IsNil()
50+
default:
51+
return false
52+
}
53+
}
54+
55+
func buildToolFlagSet(tool *MCPToolDef) (*flag.FlagSet, map[string]any, error) {
3956
fs := flag.NewFlagSet(tool.Name(), flag.ContinueOnError)
4057
flagVars := map[string]any{}
4158

cmd/src/mcp_args_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestFlagSetParse(t *testing.T) {
4949
t.Fatalf("failed to load tool json: %v", err)
5050
}
5151

52-
flagSet, vars, err := buildArgFlagSet(defs["sg_test_tool"])
52+
flagSet, vars, err := buildToolFlagSet(defs["sg_test_tool"])
5353
if err != nil {
5454
t.Fatalf("failed to build flagset from mcp tool definition: %v", err)
5555
}
@@ -63,7 +63,7 @@ func TestFlagSetParse(t *testing.T) {
6363
if err := flagSet.Parse(args); err != nil {
6464
t.Fatalf("flagset parsing failed: %v", err)
6565
}
66-
derefFlagValues(vars)
66+
sanitizeFlagValues(vars)
6767

6868
if v, ok := vars["repos"].([]string); ok {
6969
if len(v) != 2 {

0 commit comments

Comments
 (0)