diff --git a/samples/AspNetCoreMcpServer/AspNetCoreMcpServer.csproj b/samples/AspNetCoreMcpServer/AspNetCoreMcpServer.csproj index 59ab49828..5a23275ef 100644 --- a/samples/AspNetCoreMcpServer/AspNetCoreMcpServer.csproj +++ b/samples/AspNetCoreMcpServer/AspNetCoreMcpServer.csproj @@ -4,7 +4,6 @@ net9.0 enable enable - true diff --git a/samples/AspNetCoreMcpServer/EventStore/EventStoreCleanupService.cs b/samples/AspNetCoreMcpServer/EventStore/EventStoreCleanupService.cs new file mode 100644 index 000000000..cc79dc1a7 --- /dev/null +++ b/samples/AspNetCoreMcpServer/EventStore/EventStoreCleanupService.cs @@ -0,0 +1,47 @@ +using ModelContextProtocol.Server; + +namespace AspNetCoreMcpServer.EventStore; + +public class EventStoreCleanupService : BackgroundService +{ + private readonly TimeSpan _jobRunFrequencyInMinutes; + private readonly ILogger _logger; + private readonly IEventStoreCleaner? _eventStoreCleaner; + + public EventStoreCleanupService(ILogger logger, IConfiguration configuration, IEventStoreCleaner? eventStoreCleaner = null) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + + _eventStoreCleaner = eventStoreCleaner; + _jobRunFrequencyInMinutes = TimeSpan.FromMinutes(configuration.GetValue("EventStore:CleanupJobRunFrequencyInMinutes", 30)); + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + + if (_eventStoreCleaner is null) + { + _logger.LogWarning("No event store cleaner implementation provided. Event store cleanup job will not run."); + return; + } + + _logger.LogInformation("Event store cleanup job started."); + + while (!stoppingToken.IsCancellationRequested) + { + try + { + _logger.LogInformation("Running event store cleanup job at {CurrentTimeInUtc}.", DateTime.UtcNow); + _eventStoreCleaner.CleanEventStore(); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error running event store cleanup job."); + } + + await Task.Delay(_jobRunFrequencyInMinutes, stoppingToken); + } + + _logger.LogInformation("Event store cleanup job stopping."); + } +} diff --git a/samples/AspNetCoreMcpServer/EventStore/IEventStoreCleaner.cs b/samples/AspNetCoreMcpServer/EventStore/IEventStoreCleaner.cs new file mode 100644 index 000000000..ba4d495d6 --- /dev/null +++ b/samples/AspNetCoreMcpServer/EventStore/IEventStoreCleaner.cs @@ -0,0 +1,15 @@ +namespace AspNetCoreMcpServer.EventStore; + +/// +/// Interface for cleaning up the event store +/// +public interface IEventStoreCleaner +{ + + /// + /// Cleans up the event store by removing outdated or unnecessary events. + /// + /// This method is typically used to maintain the event store's size and performance by clearing + /// events that are no longer needed. + void CleanEventStore(); +} diff --git a/samples/AspNetCoreMcpServer/EventStore/InMemoryEventStore.cs b/samples/AspNetCoreMcpServer/EventStore/InMemoryEventStore.cs new file mode 100644 index 000000000..0706c86bf --- /dev/null +++ b/samples/AspNetCoreMcpServer/EventStore/InMemoryEventStore.cs @@ -0,0 +1,97 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Net.ServerSentEvents; + +namespace AspNetCoreMcpServer.EventStore; + +/// +/// Represents an in-memory implementation of an event store that stores and replays events associated with specific +/// streams. This class is designed to handle events of type where the data payload is a . +/// +/// The provides functionality to store events for a given stream and +/// replay events after a specified event ID. It supports resumability for specific types of requests and ensures events +/// are replayed in the correct order. +public sealed class InMemoryEventStore : IEventStore, IEventStoreCleaner +{ + public const string EventIdDelimiter = "_"; + private static ConcurrentDictionary>> _eventStore = new(); + + private readonly ILogger _logger; + private readonly TimeSpan _eventsRetentionDurationInMinutes; + + public InMemoryEventStore(IConfiguration configuration, ILogger logger) + { + _eventsRetentionDurationInMinutes = TimeSpan.FromMinutes(configuration.GetValue("EventStore:EventsRetentionDurationInMinutes", 60)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + public void StoreEvent(string streamId, SseItem messageItem) + { + // remove ElicitationCreate method check to support resumability for other type of requests + if (messageItem.Data is JsonRpcRequest jsonRpcReq && jsonRpcReq.Method == RequestMethods.ElicitationCreate) + { + var sseItemList = _eventStore.GetOrAdd(streamId, (key) => new List>()); + sseItemList.Add(messageItem); + } + + if (messageItem.Data is JsonRpcResponse jsonRpcResp && + _eventStore.TryGetValue(streamId, out var itemList)) + { + itemList.Add(messageItem); + } + } + + public async Task ReplayEventsAfter(string lastEventId, Action>> sendEvents) + { + var streamId = lastEventId.Split(EventIdDelimiter)[0]; + var events = _eventStore.GetValueOrDefault(streamId, new()); + var sortedAndFilteredEventsToSend = events + .Where(e => e.Data is not null && e.EventId != null) + .OrderBy(e => e.EventId) + // Sending events with EventId greater than lastEventId. + .SkipWhile(e => string.Compare(e.EventId!, lastEventId, StringComparison.Ordinal) <= 0) + .Select(e => + new SseItem(e.Data!, e.EventType) + { + EventId = e.EventId, + ReconnectionInterval = e.ReconnectionInterval + }); + sendEvents(SseItemsAsyncEnumerable(sortedAndFilteredEventsToSend)); + } + + private static async IAsyncEnumerable> SseItemsAsyncEnumerable(IEnumerable> enumerableItems) + { + foreach (var sseItem in enumerableItems) + { + yield return sseItem; + } + } + + public string? GetEventId(string streamId, JsonRpcMessage message) + { + return $"{streamId}{EventIdDelimiter}{DateTime.UtcNow.Ticks}"; + } + + public void CleanEventStore() + { + var cutoffTime = DateTime.UtcNow - _eventsRetentionDurationInMinutes; + _logger.LogInformation("Cleaning up events older than {CutoffTime} from event store.", cutoffTime); + + foreach (var key in _eventStore.Keys) + { + if (_eventStore.TryGetValue(key, out var itemList)) + { + itemList.RemoveAll(item => item.EventId != null && + long.TryParse(item.EventId.Split(EventIdDelimiter)[1], out var ticks) && + new DateTime(ticks) < cutoffTime); + if (itemList.Count == 0) + { + _logger.LogInformation("Removing empty event stream with key {EventStreamKey} from event store.", key); + _eventStore.TryRemove(key, out _); + } + } + } + } +} diff --git a/samples/AspNetCoreMcpServer/Program.cs b/samples/AspNetCoreMcpServer/Program.cs index 96f89bffa..00ad4967c 100644 --- a/samples/AspNetCoreMcpServer/Program.cs +++ b/samples/AspNetCoreMcpServer/Program.cs @@ -1,15 +1,18 @@ +using AspNetCoreMcpServer.EventStore; +using AspNetCoreMcpServer.Resources; +using AspNetCoreMcpServer.Tools; +using Microsoft.Extensions.DependencyInjection.Extensions; +using ModelContextProtocol.Server; using OpenTelemetry; using OpenTelemetry.Metrics; using OpenTelemetry.Trace; -using AspNetCoreMcpServer.Tools; -using AspNetCoreMcpServer.Resources; using System.Net.Http.Headers; var builder = WebApplication.CreateBuilder(args); builder.Services.AddMcpServer() .WithHttpTransport() .WithTools() - .WithTools() + .WithTools() // this tool collect user information through elicitation .WithTools() .WithResources(); @@ -30,6 +33,11 @@ client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("weather-tool", "1.0")); }); +// adding InMemoryEventStore to support stream resumability and background cleanup service +builder.Services.TryAddSingleton(); +builder.Services.TryAddSingleton(); +builder.Services.AddHostedService(); + var app = builder.Build(); app.MapMcp(); diff --git a/samples/AspNetCoreMcpServer/Tools/CollectUserInformationTool.cs b/samples/AspNetCoreMcpServer/Tools/CollectUserInformationTool.cs new file mode 100644 index 000000000..e5d9b56b8 --- /dev/null +++ b/samples/AspNetCoreMcpServer/Tools/CollectUserInformationTool.cs @@ -0,0 +1,143 @@ +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Text.Json; + +namespace AspNetCoreMcpServer.Tools; + +[McpServerToolType] +public sealed class CollectUserInformationTool +{ + public enum InfoType + { + contact, + preferences, + feedback + } + + [McpServerTool(Name = "collect-user-info"), Description("A tool that collects user information through elicitation")] + public static async Task ElicitationEcho(McpServer thisServer, [Description("Type of information to collect")] InfoType infoType) + { + ElicitRequestParams elicitRequestParams; + switch (infoType) + { + case InfoType.contact: + elicitRequestParams = new ElicitRequestParams() + { + Message = "Please provide your contact information", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary() + { + ["name"] = new ElicitRequestParams.StringSchema + { + Title = "Full name", + Description = "Your full name", + }, + ["email"] = new ElicitRequestParams.StringSchema + { + Title = "Email address", + Description = "Your email address", + Format = "email", + }, + ["phone"] = new ElicitRequestParams.StringSchema + { + Title = "Phone number", + Description = "Your phone number (optional)", + } + }, + Required = new List { "name", "email" } + } + }; + break; + + case InfoType.preferences: + elicitRequestParams = new ElicitRequestParams() + { + Message = "Please set your preferences", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary() + { + ["theme"] = new ElicitRequestParams.EnumSchema + { + Title = "Theme", + Description = "Choose your preferred theme", + Enum = new List { "light", "dark", "auto" }, + EnumNames = new List { "Light", "Dark", "Auto" } + }, + ["notifications"] = new ElicitRequestParams.BooleanSchema + { + Title = "Enable notifications", + Description = "Would you like to receive notifications?", + Default = true, + }, + ["frequency"] = new ElicitRequestParams.EnumSchema + { + Title = "Notification frequency", + Description = "How often would you like notifications?", + Enum = new List { "daily", "weekly", "monthly" }, + EnumNames = new List { "Daily", "Weekly", "Monthly" } + } + }, + Required = new List { "theme" } + } + }; + + break; + + case InfoType.feedback: + elicitRequestParams = new ElicitRequestParams() + { + Message = "Please provide your feedback", + RequestedSchema = new ElicitRequestParams.RequestSchema + { + Properties = new Dictionary() + { + ["rating"] = new ElicitRequestParams.NumberSchema + { + Title = "Rating", + Description = "Rate your experience (1-5)", + Minimum = 1, + Maximum = 5, + }, + ["comments"] = new ElicitRequestParams.StringSchema + { + Title = "Comments", + Description = "Additional comments (optional)", + MaxLength = 500, + }, + ["recommend"] = new ElicitRequestParams.BooleanSchema + { + Title = "Would you recommend this?", + Description = "Would you recommend this to others?", + } + }, + Required = new List { "rating", "recommend" } + } + }; + + break; + + default: + throw new Exception($"Unknown info type: ${infoType}"); + + } + + + var result = await thisServer.ElicitAsync(elicitRequestParams); + var textResult = result.Action switch + { + "accept" => $"Thank you! Collected ${infoType} information: {JsonSerializer.Serialize(result.Content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary)))}", + "decline" => "No information was collected. User declined ${infoType} information request.", + "cancel" => "Information collection was cancelled by the user.", + _ => "Error collecting ${infoType} information: ${error}" + }; + + return new CallToolResult() + { + Content = [ new TextContentBlock { Text = textResult } ], + }; + } +} diff --git a/samples/AspNetCoreMcpServer/appsettings.json b/samples/AspNetCoreMcpServer/appsettings.json index 10f68b8c8..7e2f39d16 100644 --- a/samples/AspNetCoreMcpServer/appsettings.json +++ b/samples/AspNetCoreMcpServer/appsettings.json @@ -5,5 +5,9 @@ "Microsoft.AspNetCore": "Warning" } }, - "AllowedHosts": "*" + "AllowedHosts": "*", + "EventStore": { + "CleanupJobRunFrequencyInMinutes": 30, + "EventsRetentionDurationInMinutes": 60 + } } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 9f4af7ea5..a4aec7336 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -7,8 +7,10 @@ using Microsoft.Net.Http.Headers; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Net.ServerSentEvents; using System.Security.Claims; using System.Security.Cryptography; +using System.Text.Json; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.AspNetCore; @@ -20,9 +22,11 @@ internal sealed class StreamableHttpHandler( StatefulSessionManager sessionManager, IHostApplicationLifetime hostApplicationLifetime, IServiceProvider applicationServices, - ILoggerFactory loggerFactory) + ILoggerFactory loggerFactory, + IEventStore? eventStore = null) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; + private const string LastEventIdHeaderName = "Last-Event-Id"; private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); @@ -88,6 +92,20 @@ await WriteJsonRpcErrorAsync(context, return; } + // eventId format is _ + var lastEventId = context.Request.Headers[LastEventIdHeaderName].ToString(); + if (!string.IsNullOrEmpty(lastEventId) && eventStore is not null) + { + InitializeSseResponse(context); + await eventStore.ReplayEventsAfter(lastEventId, async (enumerableEvents) => + await SseFormatter.WriteAsync(enumerableEvents, + context.Response.Body, + (item, bufferWriter) => JsonSerializer.Serialize(new Utf8JsonWriter(bufferWriter), item.Data, s_messageTypeInfo), + context.RequestAborted)); + + return; + } + if (!session.TryStartGetRequest()) { await WriteJsonRpcErrorAsync(context, @@ -105,11 +123,11 @@ await WriteJsonRpcErrorAsync(context, try { await using var _ = await session.AcquireReferenceAsync(cancellationToken); - InitializeSseResponse(context); + InitializeSseResponse(context); - // We should flush headers to indicate a 200 success quickly, because the initialization response - // will be sent in response to a different POST request. It might be a while before we send a message - // over this response body. + // We should flush headers to indicate a 200 success quickly, because the initialization response + // will be sent in response to a different POST request. It might be a while before we send a message + // over this response body. await context.Response.Body.FlushAsync(cancellationToken); await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken); } @@ -190,7 +208,7 @@ private async ValueTask StartNewSessionAsync(HttpContext if (!HttpServerTransportOptions.Stateless) { sessionId = MakeNewSessionId(); - transport = new() + transport = new(eventStore) { SessionId = sessionId, FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, @@ -273,7 +291,7 @@ internal static string MakeNewSessionId() { Span buffer = stackalloc byte[16]; RandomNumberGenerator.Fill(buffer); - return WebEncoders.Base64UrlEncode(buffer); + return WebEncoders.Base64UrlEncode(buffer).Replace("_", "-"); } internal static async Task ReadJsonRpcMessageAsync(HttpContext context) { diff --git a/src/ModelContextProtocol.Core/Server/IEventStore.cs b/src/ModelContextProtocol.Core/Server/IEventStore.cs new file mode 100644 index 000000000..329a077fe --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/IEventStore.cs @@ -0,0 +1,38 @@ +using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.Server; + +/// +/// Interface for resumability support via event storage +/// +public interface IEventStore +{ + /// + /// Stores an event in the specified stream and returns the unique identifier of the stored event. + /// + /// This method asynchronously stores the provided event in the specified stream. The returned + /// event identifier can be used to retrieve or reference the stored event in the future. + /// The identifier of the stream where the event will be stored. Cannot be null or empty. + /// The event item to be stored, which may contain a JSON-RPC message or be null. + void StoreEvent(string streamId, SseItem messageItem); + + + /// + /// Replays events that occurred after the specified event ID. + /// + /// The ID of the last event that was processed. Events occurring after this ID will be replayed. + /// A callback action that processes the replayed events as an asynchronous enumerable of + /// containing objects. + /// A task that represents the asynchronous operation of replaying events. + Task ReplayEventsAfter(string lastEventId, Action>> sendEvents); + + /// + /// Retrieves the event identifier associated with a specific JSON-RPC message in the given stream. + /// + /// The unique identifier of the stream containing the message. + /// The JSON-RPC message for which the event identifier is to be retrieved. + /// The event identifier as a string, or if no event identifier is associated with the + /// message. + string? GetEventId(string streamId, JsonRpcMessage message); +} diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs index a2314e623..dfde263b7 100644 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs @@ -7,7 +7,10 @@ namespace ModelContextProtocol.Server; -internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOptions? channelOptions = null) : IAsyncDisposable +internal sealed class SseWriter( + string? messageEndpoint = null, + BoundedChannelOptions? channelOptions = null, + IEventStore? eventStore = null) : IAsyncDisposable { private readonly Channel> _messages = Channel.CreateBounded>(channelOptions ?? new BoundedChannelOptions(1) { @@ -60,8 +63,29 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationTok return false; } + var transport = message.Context?.RelatedTransport; + var sseItem = new SseItem(message, SseParser.EventTypeDefault); + + if (eventStore is not null && + transport is StreamableHttpPostTransport postTransport && + !string.IsNullOrEmpty(postTransport.pendingStreamId)) + { + var streamId = postTransport.pendingStreamId!; + sseItem = new SseItem(message, SseParser.EventTypeDefault) + { + EventId = eventStore.GetEventId(streamId, message) + }; + + // store the requests and response to the pending request + if (message is JsonRpcRequest jsonRpcReq || + (message is JsonRpcResponse jsonRpcResp && jsonRpcResp.Id == postTransport.pendingRequestId)) + { + eventStore.StoreEvent(streamId, sseItem); + } + } + // Emit redundant "event: message" lines for better compatibility with other SDKs. - await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false); + await _messages.Writer.WriteAsync(sseItem, cancellationToken).ConfigureAwait(false); return true; } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 1109c2b2b..9516f4503 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; using System.Diagnostics; using System.IO.Pipelines; using System.Net.ServerSentEvents; @@ -13,15 +14,22 @@ namespace ModelContextProtocol.Server; /// Handles processing the request/response body pairs for the Streamable HTTP transport. /// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream) : ITransport +internal sealed class StreamableHttpPostTransport( + StreamableHttpServerTransport parentTransport, + Stream responseStream, + IEventStore? eventStore = null) : ITransport { - private readonly SseWriter _sseWriter = new(); + private readonly SseWriter _sseWriter = new(eventStore: eventStore); private RequestId _pendingRequest; + private string? _pendingStreamId; public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.Context.RelatedTransport should only be used for sending messages."); string? ITransport.SessionId => parentTransport.SessionId; + public string? pendingStreamId => _pendingStreamId; + public RequestId pendingRequestId => _pendingRequest; + /// /// True, if data was written to the respond body. /// False, if nothing was written because the request body did not contain any messages to respond to. @@ -34,6 +42,7 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio if (message is JsonRpcRequest request) { _pendingRequest = request.Id; + _pendingStreamId = Guid.NewGuid().ToString(); // Invoke the initialize request callback if applicable. if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index ee943ea70..d1d313478 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,5 +1,7 @@ using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; using System.IO.Pipelines; +using System.Net.ServerSentEvents; using System.Security.Claims; using System.Threading.Channels; @@ -19,7 +21,7 @@ namespace ModelContextProtocol.Server; /// such as when streaming completion results or providing progress updates during long-running operations. /// /// -public sealed class StreamableHttpServerTransport : ITransport +public sealed class StreamableHttpServerTransport(IEventStore? eventStore = null) : ITransport { // For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages. private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1) @@ -27,7 +29,7 @@ public sealed class StreamableHttpServerTransport : ITransport SingleReader = true, SingleWriter = false, FullMode = BoundedChannelFullMode.DropOldest, - }); + }, eventStore: eventStore); private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1) { SingleReader = true, @@ -117,7 +119,7 @@ public async Task HandlePostRequestAsync(JsonRpcMessage message, Stream re Throw.IfNull(responseStream); using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await using var postTransport = new StreamableHttpPostTransport(this, responseStream); + await using var postTransport = new StreamableHttpPostTransport(this, responseStream, eventStore); return await postTransport.HandlePostAsync(message, postCts.Token).ConfigureAwait(false); }