Skip to content

Commit 13fcea3

Browse files
committed
build flagset from inputschema to parse args
1 parent 75745b3 commit 13fcea3

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

cmd/src/mcp_args.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"strings"
7+
)
8+
9+
var _ flag.Value = (*strSliceFlag)(nil)
10+
11+
type strSliceFlag struct {
12+
vals []string
13+
}
14+
15+
func (s *strSliceFlag) Set(v string) error {
16+
s.vals = append(s.vals, v)
17+
return nil
18+
}
19+
20+
func (s *strSliceFlag) String() string {
21+
return strings.Join(s.vals, ",")
22+
}
23+
24+
func buildArgFlagSet(tool *MCPToolDef) (*flag.FlagSet, map[string]any, error) {
25+
fs := flag.NewFlagSet(tool.Name(), flag.ContinueOnError)
26+
flagVars := map[string]any{}
27+
28+
for name, pVal := range tool.InputSchema.Properties {
29+
switch pv := pVal.(type) {
30+
case *SchemaPrimitive:
31+
switch pv.Kind {
32+
case "integer":
33+
dst := fs.Int(name, 0, pv.Description)
34+
flagVars[name] = dst
35+
36+
case "boolean":
37+
dst := fs.Bool(name, false, pv.Description)
38+
flagVars[name] = dst
39+
case "string":
40+
dst := fs.String(name, "", pv.Description)
41+
flagVars[name] = dst
42+
default:
43+
return nil, nil, fmt.Errorf("unknown schema primitive kind %q", pv.Kind)
44+
45+
}
46+
case *SchemaArray:
47+
strSlice := new(strSliceFlag)
48+
fs.Var(strSlice, name, pv.Description)
49+
flagVars[name] = strSlice
50+
case *SchemaObject:
51+
// not supported yet
52+
}
53+
}
54+
55+
return fs, flagVars, nil
56+
}

cmd/src/mcp_args_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestFlagSetParse(t *testing.T) {
8+
toolJSON := []byte(`{
9+
"tools": [
10+
{
11+
"name": "sg_test_tool",
12+
"description": "test description",
13+
"inputSchema": {
14+
"type": "object",
15+
"$schema": "https://localhost/schema-draft/2025-07",
16+
"required": ["values"],
17+
"properties": {
18+
"repos": {
19+
"type": "array",
20+
"items": {
21+
"type": "string"
22+
}
23+
},
24+
"tag": {
25+
"type": "string",
26+
"items": true
27+
},
28+
"count": {
29+
"type": "integer"
30+
},
31+
"boolFlag": {
32+
"type": "boolean"
33+
}
34+
}
35+
},
36+
"outputSchema": {
37+
"type": "object",
38+
"$schema": "https://localhost/schema-draft/2025-07",
39+
"properties": {
40+
"result": { "type": "string" }
41+
}
42+
}
43+
}
44+
]
45+
}`)
46+
47+
defs, err := LoadMCPToolDefinitions(toolJSON)
48+
if err != nil {
49+
t.Fatalf("failed to load tool json: %v", err)
50+
}
51+
52+
flagSet, vars, err := buildArgFlagSet(defs["sg_test_tool"])
53+
if err != nil {
54+
t.Fatalf("failed to build flagset from mcp tool definition: %v", err)
55+
}
56+
57+
if len(vars) == 0 {
58+
t.Fatalf("vars from buildArgFlagSet should not be empty")
59+
}
60+
61+
args := []string{"-repos=A", "-repos=B", "-count=10", "-boolFlag", "-tag=testTag"}
62+
63+
if err := flagSet.Parse(args); err != nil {
64+
t.Fatalf("flagset parsing failed: %v", err)
65+
}
66+
derefFlagValues(vars)
67+
68+
if v, ok := vars["repos"].([]string); ok {
69+
if len(v) != 2 {
70+
t.Fatalf("expected flag 'repos' values to have length %d but got %d", 2, len(v))
71+
}
72+
} else {
73+
t.Fatalf("expected flag 'repos' to have type of []string but got %T", v)
74+
}
75+
if v, ok := vars["tag"].(string); ok {
76+
if v != "testTag" {
77+
t.Fatalf("expected flag 'tag' values to have value %q but got %q", "testTag", v)
78+
}
79+
} else {
80+
t.Fatalf("expected flag 'tag' to have type of string but got %T", v)
81+
}
82+
if v, ok := vars["count"].(int); ok {
83+
if v != 10 {
84+
t.Fatalf("expected flag 'count' values to have value %d but got %d", 10, v)
85+
}
86+
} else {
87+
t.Fatalf("expected flag 'count' to have type of int but got %T", v)
88+
}
89+
if v, ok := vars["boolFlag"].(bool); ok {
90+
if v != true {
91+
t.Fatalf("expected flag 'boolFlag' values to have value %v but got %v", true, v)
92+
}
93+
} else {
94+
t.Fatalf("expected flag 'boolFlag' to have type of bool but got %T", v)
95+
}
96+
97+
}

0 commit comments

Comments
 (0)