
Commit 47b546a

Authored by Ettore Di Giacinto <mudler@localai.io>
feat(mcp): add LocalAI endpoint to stream live results of the agent (#7274)
* feat(mcp): add LocalAI endpoint to stream live results of the agent
* wip
* Refactoring
* MCP UX integration
* Enhance UX
* Support also non-SSE

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent a09d49d commit 47b546a

File tree: 7 files changed, +1188 -105 lines changed


core/config/model_config.go

Lines changed: 38 additions & 0 deletions
@@ -9,6 +9,7 @@ import (
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/downloader"
 	"github.com/mudler/LocalAI/pkg/functions"
+	"github.com/mudler/cogito"
 	"gopkg.in/yaml.v3"
 )

@@ -668,3 +669,40 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {

 	return true
 }
+
+// BuildCogitoOptions generates cogito options from the model configuration.
+// Callers append the request context, MCP sessions, and callbacks for status,
+// reasoning, tool calls, and tool results on top of these defaults.
+func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
+	cogitoOpts := []cogito.Option{
+		cogito.WithIterations(3),  // default to 3 iterations
+		cogito.WithMaxAttempts(3), // default to 3 attempts
+		cogito.WithForceReasoning(),
+	}
+
+	// Apply agent configuration options
+	if c.Agent.EnableReasoning {
+		cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
+	}
+
+	if c.Agent.EnablePlanning {
+		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
+	}
+
+	if c.Agent.EnableMCPPrompts {
+		cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
+	}
+
+	if c.Agent.EnablePlanReEvaluator {
+		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
+	}
+
+	if c.Agent.MaxIterations != 0 {
+		cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
+	}
+
+	if c.Agent.MaxAttempts != 0 {
+		cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
+	}
+
+	return cogitoOpts
+}
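
As a usage illustration, a minimal sketch (not part of this commit) that exercises BuildCogitoOptions; the Agent field names are the ones read above, while the main wiring around them is assumed:

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/core/config"
)

func main() {
	var cfg config.ModelConfig
	cfg.Agent.EnablePlanning = true // appends cogito.EnableAutoPlan
	cfg.Agent.MaxIterations = 5     // overrides the default WithIterations(3)

	opts := cfg.BuildCogitoOptions()
	fmt.Printf("built %d cogito options\n", len(opts))
}

The new streaming endpoint below composes these options with the per-request context, the MCP sessions, and the event callbacks.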

core/http/app.go

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ func API(application *application.Application) (*echo.Echo, error) {
 		opcache = services.NewOpCache(application.GalleryService())
 	}

-	routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
+	routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
 	routes.RegisterOpenAIRoutes(e, requestExtractor, application)
 	if !application.ApplicationConfig().DisableWebUI {
 		routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)

core/http/endpoints/localai/mcp.go

Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@

package localai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/templates"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/cogito"
	"github.com/rs/zerolog/log"
)

// MCP SSE Event Types
type MCPReasoningEvent struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

type MCPToolCallEvent struct {
	Type      string                 `json:"type"`
	Name      string                 `json:"name"`
	Arguments map[string]interface{} `json:"arguments"`
	Reasoning string                 `json:"reasoning"`
}

type MCPToolResultEvent struct {
	Type   string `json:"type"`
	Name   string `json:"name"`
	Result string `json:"result"`
}

type MCPStatusEvent struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

type MCPAssistantEvent struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

type MCPErrorEvent struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions
// @Summary Stream MCP chat completions with reasoning, tool calls, and results
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/mcp/chat/completions [post]
func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()
		created := int(time.Now().Unix())

		// Handle correlation ID
		id := c.Request().Header.Get("X-Correlation-ID")
		if id == "" {
			id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
		}

		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return echo.ErrBadRequest
		}

		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || config == nil {
			return echo.ErrBadRequest
		}

		if config.MCP.Servers == "" && config.MCP.Stdio == "" {
			return fmt.Errorf("no MCP servers configured")
		}

		// Get MCP config from model config
		remote, stdio, err := config.MCP.MCPConfigFromYAML()
		if err != nil {
			return fmt.Errorf("failed to get MCP config: %w", err)
		}

		// Check if we have tools in cache, or whether an initial connection is needed
		sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
		if err != nil {
			return fmt.Errorf("failed to get MCP sessions: %w", err)
		}

		if len(sessions) == 0 {
			return fmt.Errorf("no working MCP servers found")
		}

		// Build fragment from messages
		fragment := cogito.NewEmptyFragment()
		for _, message := range input.Messages {
			fragment = fragment.AddMessage(message.Role, message.StringContent)
		}

		port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
		apiKey := ""
		if len(appConfig.ApiKeys) > 0 {
			apiKey = appConfig.ApiKeys[0]
		}

		ctxWithCancellation, cancel := context.WithCancel(ctx)
		defer cancel()

		// TODO: instead of connecting to the API, we should just wire this internally
		// and act like completion.go.
		// We can do this as cogito expects an interface and we can create one that
		// we satisfy to just call internally ComputeChoices
		defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)

		// Build cogito options using the consolidated method
		cogitoOpts := config.BuildCogitoOptions()
		cogitoOpts = append(
			cogitoOpts,
			cogito.WithContext(ctxWithCancellation),
			cogito.WithMCPs(sessions...),
		)

		// Check if streaming is requested
		toStream := input.Stream

		if !toStream {
			// Non-streaming mode: execute synchronously and return JSON response
			cogitoOpts = append(
				cogitoOpts,
				cogito.WithStatusCallback(func(s string) {
					log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
				}),
				cogito.WithReasoningCallback(func(s string) {
					log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
				}),
				cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
					log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("reasoning", t.Reasoning).Interface("arguments", t.Arguments).Msg("[model agent] Tool call")
					return true
				}),
				cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
					log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("result", t.Result).Interface("tool_arguments", t.ToolArguments).Msg("[model agent] Tool call result")
				}),
			)

			f, err := cogito.ExecuteTools(
				defaultLLM, fragment,
				cogitoOpts...,
			)
			if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
				return err
			}

			f, err = defaultLLM.Ask(ctxWithCancellation, f)
			if err != nil {
				return err
			}

			resp := &schema.OpenAIResponse{
				ID:      id,
				Created: created,
				Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
				Object:  "chat.completion",
			}

			jsonResult, _ := json.Marshal(resp)
			log.Debug().Msgf("Response: %s", jsonResult)

			// Return the prediction in the response body
			return c.JSON(200, resp)
		}

		// Streaming mode: use SSE
		// Set up SSE headers
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Cache-Control", "no-cache")
		c.Response().Header().Set("Connection", "keep-alive")
		c.Response().Header().Set("X-Correlation-ID", id)

		// Create channel for streaming events
		events := make(chan interface{})
		ended := make(chan error, 1)

		// Set up callbacks for streaming
		statusCallback := func(s string) {
			events <- MCPStatusEvent{
				Type:    "status",
				Message: s,
			}
		}

		reasoningCallback := func(s string) {
			events <- MCPReasoningEvent{
				Type:    "reasoning",
				Content: s,
			}
		}

		toolCallCallback := func(t *cogito.ToolChoice) bool {
			events <- MCPToolCallEvent{
				Type:      "tool_call",
				Name:      t.Name,
				Arguments: t.Arguments,
				Reasoning: t.Reasoning,
			}
			return true
		}

		toolCallResultCallback := func(t cogito.ToolStatus) {
			events <- MCPToolResultEvent{
				Type:   "tool_result",
				Name:   t.Name,
				Result: t.Result,
			}
		}

		cogitoOpts = append(cogitoOpts,
			cogito.WithStatusCallback(statusCallback),
			cogito.WithReasoningCallback(reasoningCallback),
			cogito.WithToolCallBack(toolCallCallback),
			cogito.WithToolCallResultCallback(toolCallResultCallback),
		)

		// Execute tools in a goroutine
		go func() {
			defer close(events)

			f, err := cogito.ExecuteTools(
				defaultLLM, fragment,
				cogitoOpts...,
			)
			if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
				events <- MCPErrorEvent{
					Type:    "error",
					Message: fmt.Sprintf("Failed to execute tools: %v", err),
				}
				ended <- err
				return
			}

			// Get final response
			f, err = defaultLLM.Ask(ctxWithCancellation, f)
			if err != nil {
				events <- MCPErrorEvent{
					Type:    "error",
					Message: fmt.Sprintf("Failed to get response: %v", err),
				}
				ended <- err
				return
			}

			// Stream final assistant response
			content := f.LastMessage().Content
			events <- MCPAssistantEvent{
				Type:    "assistant",
				Content: content,
			}

			ended <- nil
		}()

		// Stream events to client
	LOOP:
		for {
			select {
			case <-ctx.Done():
				// Context was cancelled (client disconnected or request cancelled)
				log.Debug().Msgf("Request context cancelled, stopping stream")
				cancel()
				break LOOP
			case event := <-events:
				if event == nil {
					// Channel closed
					break LOOP
				}
				eventData, err := json.Marshal(event)
				if err != nil {
					log.Debug().Msgf("Failed to marshal event: %v", err)
					continue
				}
				log.Debug().Msgf("Sending event: %s", string(eventData))
				_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
				if err != nil {
					log.Debug().Msgf("Sending event failed: %v", err)
					cancel()
					return err
				}
				c.Response().Flush()
			case err := <-ended:
				if err == nil {
					// Send done signal
					fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
					c.Response().Flush()
					break LOOP
				}
				log.Error().Msgf("Stream ended with error: %v", err)
				errorEvent := MCPErrorEvent{
					Type:    "error",
					Message: err.Error(),
				}
				errorData, marshalErr := json.Marshal(errorEvent)
				if marshalErr != nil {
					fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
				} else {
					fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
				}
				fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
				c.Response().Flush()
				return nil
			}
		}

		log.Debug().Msgf("Stream ended")
		return nil
	}
}
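
Each SSE frame carries one JSON-encoded event ("status", "reasoning", "tool_call", "tool_result", "assistant", or "error"), and the stream ends with a "data: [DONE]" sentinel. A minimal Go client sketch (not part of this commit) that consumes the endpoint; the address, model name, and prompt are placeholder assumptions:

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Placeholder model name and prompt; assumes LocalAI on 127.0.0.1:8080.
	body := strings.NewReader(`{"model":"my-agent","stream":true,"messages":[{"role":"user","content":"hello"}]}`)
	resp, err := http.Post("http://127.0.0.1:8080/v1/mcp/chat/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The handler writes one "data: <json>" line per event, followed by a
	// blank line, and finishes with "data: [DONE]" (see the select loop above).
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break
		}
		// payload decodes into one of the MCP*Event shapes defined above, e.g.
		// {"type":"tool_call","name":"...","arguments":{...},"reasoning":"..."}
		fmt.Println(payload)
	}
}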
