Skip to content

Commit 005809c

Browse files
authored
feat(cmd/server): support reading initial prompt from stdin (#139)
Supports passing initial prompt via stdin, e.g. ``` echo "a very long prompt" | agentapi server claude ```
1 parent d99fcde commit 005809c

File tree

3 files changed

+78
-16
lines changed

3 files changed

+78
-16
lines changed

cmd/server/server.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"io"
78
"log/slog"
89
"net/http"
910
"os"
1011
"sort"
1112
"strings"
1213

14+
"github.com/mattn/go-isatty"
1315
"github.com/spf13/cobra"
1416
"github.com/spf13/viper"
1517
"golang.org/x/xerrors"
@@ -88,6 +90,19 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
8890
return xerrors.Errorf("term height must be at least 10")
8991
}
9092

93+
// Read stdin if it's piped, to be used as initial prompt
94+
initialPrompt := viper.GetString(FlagInitialPrompt)
95+
if initialPrompt == "" {
96+
if !isatty.IsTerminal(os.Stdin.Fd()) {
97+
if stdinData, err := io.ReadAll(os.Stdin); err != nil {
98+
return xerrors.Errorf("failed to read stdin: %w", err)
99+
} else if len(stdinData) > 0 {
100+
initialPrompt = string(stdinData)
101+
logger.Info("Read initial prompt from stdin", "bytes", len(stdinData))
102+
}
103+
}
104+
}
105+
91106
printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
92107
var process *termexec.Process
93108
if printOpenAPI {
@@ -112,7 +127,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
112127
ChatBasePath: viper.GetString(FlagChatBasePath),
113128
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
114129
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
115-
InitialPrompt: viper.GetString(FlagInitialPrompt),
130+
InitialPrompt: initialPrompt,
116131
})
117132
if err != nil {
118133
return xerrors.Errorf("failed to create server: %w", err)
@@ -213,7 +228,7 @@ func CreateServerCmd() *cobra.Command {
213228
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
214229
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
215230
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
216-
{FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"},
231+
{FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"},
217232
}
218233

219234
for _, spec := range flagSpecs {

e2e/echo_test.go

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222

2323
const (
2424
testTimeout = 30 * time.Second
25-
operationTimeout = 5 * time.Second
25+
operationTimeout = 10 * time.Second
2626
healthCheckTimeout = 10 * time.Second
2727
)
2828

@@ -40,15 +40,14 @@ func TestE2E(t *testing.T) {
4040
t.Run("basic", func(t *testing.T) {
4141
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
4242
defer cancel()
43-
script, apiClient := setup(ctx, t)
44-
require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout))
43+
script, apiClient := setup(ctx, t, nil)
4544
messageReq := agentapisdk.PostMessageParams{
4645
Content: "This is a test message.",
4746
Type: agentapisdk.MessageTypeUser,
4847
}
4948
_, err := apiClient.PostMessage(ctx, messageReq)
5049
require.NoError(t, err, "Failed to send message via SDK")
51-
require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout))
50+
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "post message"))
5251
msgResp, err := apiClient.GetMessages(ctx)
5352
require.NoError(t, err, "Failed to get messages via SDK")
5453
require.Len(t, msgResp.Messages, 3)
@@ -61,7 +60,7 @@ func TestE2E(t *testing.T) {
6160
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
6261
defer cancel()
6362

64-
script, apiClient := setup(ctx, t)
63+
script, apiClient := setup(ctx, t, nil)
6564
messageReq := agentapisdk.PostMessageParams{
6665
Content: "What is the answer to life, the universe, and everything?",
6766
Type: agentapisdk.MessageTypeUser,
@@ -71,7 +70,7 @@ func TestE2E(t *testing.T) {
7170
statusResp, err := apiClient.GetStatus(ctx)
7271
require.NoError(t, err)
7372
require.Equal(t, agentapisdk.StatusRunning, statusResp.Status)
74-
require.NoError(t, waitAgentAPIStable(ctx, apiClient, 5*time.Second))
73+
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "post message"))
7574
msgResp, err := apiClient.GetMessages(ctx)
7675
require.NoError(t, err, "Failed to get messages via SDK")
7776
require.Len(t, msgResp.Messages, 3)
@@ -82,11 +81,45 @@ func TestE2E(t *testing.T) {
8281
require.Equal(t, script[1].ResponseMessage, strings.TrimSpace(parts[0]))
8382
require.Equal(t, script[2].ResponseMessage, strings.TrimSpace(parts[1]))
8483
})
84+
85+
t.Run("stdin", func(t *testing.T) {
86+
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
87+
defer cancel()
88+
89+
script, apiClient := setup(ctx, t, &params{
90+
cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) {
91+
defCmd, defArgs := defaultCmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath)
92+
script := fmt.Sprintf(`echo "hello agent" | %s %s`, defCmd, strings.Join(defArgs, " "))
93+
return "/bin/sh", []string{"-c", script}
94+
},
95+
})
96+
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "stdin"))
97+
msgResp, err := apiClient.GetMessages(ctx)
98+
require.NoError(t, err, "Failed to get messages via SDK")
99+
require.Len(t, msgResp.Messages, 3)
100+
require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content))
101+
require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content))
102+
})
85103
}
86104

87-
func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Client) {
105+
type params struct {
106+
cmdFn func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string)
107+
}
108+
109+
func defaultCmdFn(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) {
110+
return binaryPath, []string{"server", fmt.Sprintf("--port=%d", serverPort), "--", "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath}
111+
}
112+
113+
func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agentapisdk.Client) {
88114
t.Helper()
89115

116+
if p == nil {
117+
p = &params{}
118+
}
119+
if p.cmdFn == nil {
120+
p.cmdFn = defaultCmdFn
121+
}
122+
90123
scriptFilePath := filepath.Join("testdata", filepath.Base(t.Name())+".json")
91124
data, err := os.ReadFile(scriptFilePath)
92125
require.NoError(t, err, "Failed to read test script file: %s", scriptFilePath)
@@ -116,10 +149,9 @@ func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Clien
116149
cwd, err := os.Getwd()
117150
require.NoError(t, err, "Failed to get current working directory")
118151

119-
cmd := exec.CommandContext(ctx, binaryPath, "server",
120-
fmt.Sprintf("--port=%d", serverPort),
121-
"--",
122-
"go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath)
152+
bin, args := p.cmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath)
153+
t.Logf("Running command: %s %s", bin, strings.Join(args, " "))
154+
cmd := exec.CommandContext(ctx, bin, args...)
123155

124156
// Capture output for debugging
125157
stdout, err := cmd.StdoutPipe()
@@ -160,7 +192,7 @@ func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Clien
160192
apiClient, err := agentapisdk.NewClient(serverURL)
161193
require.NoError(t, err, "Failed to create agentapi SDK client")
162194

163-
require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout))
195+
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "setup"))
164196
return script, apiClient
165197
}
166198

@@ -198,21 +230,30 @@ func waitForServer(ctx context.Context, t testing.TB, url string, timeout time.D
198230
}
199231
}
200232

201-
func waitAgentAPIStable(ctx context.Context, apiClient *agentapisdk.Client, waitFor time.Duration) error {
233+
func waitAgentAPIStable(ctx context.Context, t testing.TB, apiClient *agentapisdk.Client, waitFor time.Duration, msg string) error {
234+
t.Helper()
202235
waitCtx, waitCancel := context.WithTimeout(ctx, waitFor)
203236
defer waitCancel()
204237

205-
tick := time.NewTicker(100 * time.Millisecond)
238+
start := time.Now()
239+
tick := time.NewTicker(time.Millisecond)
206240
defer tick.Stop()
241+
var prevStatus agentapisdk.AgentStatus
242+
defer func() {
243+
elapsed := time.Since(start)
244+
t.Logf("%s: agent API status: %s (elapsed: %s)", msg, prevStatus, elapsed.Round(100*time.Millisecond))
245+
}()
207246
for {
208247
select {
209248
case <-waitCtx.Done():
210249
return waitCtx.Err()
211250
case <-tick.C:
251+
tick.Reset(100 * time.Millisecond)
212252
sr, err := apiClient.GetStatus(ctx)
213253
if err != nil {
214254
continue
215255
}
256+
prevStatus = sr.Status
216257
if sr.Status == agentapisdk.StatusStable {
217258
return nil
218259
}

e2e/testdata/stdin.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[
2+
{
3+
"expectMessage": "hello agent",
4+
"responseMessage": "Hello! I'm ready to help you. Please send me a message to echo back."
5+
}
6+
]

0 commit comments

Comments
 (0)