Skip to content

Commit 447194b

Browse files
authored
Ensure PreSendHeader event is raised when response is started (#532)
Previously, we were just raining it when a flush occurred, which would be after headers are made readonly. Fixes #531
1 parent e4a9776 commit 447194b

File tree

4 files changed

+158
-44
lines changed

4 files changed

+158
-44
lines changed

src/Microsoft.AspNetCore.SystemWebAdapters.CoreServices/Features/HttpApplicationPreSendEventsResponseBodyFeature.cs

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.IO.Pipelines;
88
using System.Threading;
99
using System.Threading.Tasks;
10+
using System.Web;
1011
using Microsoft.AspNetCore.Http;
1112
using Microsoft.AspNetCore.Http.Features;
1213

@@ -50,7 +51,15 @@ public HttpApplicationPreSendEventsResponseBodyFeature(HttpContextCore context,
5051
public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
5152
=> SendFileFallback.SendFileAsync(Stream, path, offset, count, cancellationToken);
5253

53-
public Task StartAsync(CancellationToken cancellationToken = default) => _other.StartAsync(cancellationToken);
54+
public async Task StartAsync(CancellationToken cancellationToken = default)
55+
{
56+
if (_context.Features.Get<IHttpApplicationFeature>() is { } httpApplication)
57+
{
58+
await RaisePreSendRequestHeadersEvent(httpApplication);
59+
}
60+
61+
await _other.StartAsync(cancellationToken);
62+
}
5463

5564
public override void Advance(int bytes)
5665
{
@@ -69,29 +78,43 @@ public override void Advance(int bytes)
6978

7079
public override async ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = default)
7180
{
72-
// Only need to raise events if data will be flushed and the feature is available
73-
if (_byteCount > 0 && _context.Features.Get<IHttpApplicationFeature>() is { } httpApplication)
81+
// Only need to raise events if data will be flushed
82+
if (_byteCount == 0)
83+
{
84+
return default;
85+
}
86+
87+
if (_context.Features.Get<IHttpApplicationFeature>() is { } httpApplication)
7488
{
7589
_byteCount = 0;
7690

77-
if (_state is State.NotStarted)
78-
{
79-
_state = State.RaisingPreHeader;
80-
await _context.Features.GetRequiredFeature<IHttpApplicationFeature>().RaiseEventAsync(ApplicationEvent.PreSendRequestHeaders);
81-
_state = State.ReadyForContent;
82-
}
83-
84-
if (_state is State.ReadyForContent)
85-
{
86-
_state = State.RaisingPreContent;
87-
await httpApplication.RaiseEventAsync(ApplicationEvent.PreSendRequestContent);
88-
_state = State.ReadyForContent;
89-
}
91+
await RaisePreSendRequestHeadersEvent(httpApplication);
92+
await RaisePreSendRequestContentEvent(httpApplication);
9093
}
9194

9295
return await _pipe.FlushAsync(cancellationToken);
9396
}
9497

98+
private async ValueTask RaisePreSendRequestHeadersEvent(IHttpApplicationFeature httpApplication)
99+
{
100+
if (_state is State.NotStarted)
101+
{
102+
_state = State.RaisingPreHeader;
103+
await httpApplication.RaiseEventAsync(ApplicationEvent.PreSendRequestHeaders);
104+
_state = State.ReadyForContent;
105+
}
106+
}
107+
108+
private async ValueTask RaisePreSendRequestContentEvent(IHttpApplicationFeature httpApplication)
109+
{
110+
if (_state is State.ReadyForContent)
111+
{
112+
_state = State.RaisingPreContent;
113+
await httpApplication.RaiseEventAsync(ApplicationEvent.PreSendRequestContent);
114+
_state = State.ReadyForContent;
115+
}
116+
}
117+
95118
public override Memory<byte> GetMemory(int sizeHint = 0) => _pipe.GetMemory(sizeHint);
96119

97120
public override Span<byte> GetSpan(int sizeHint = 0) => _pipe.GetSpan(sizeHint);
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
namespace Microsoft.AspNetCore.SystemWebAdapters.CoreServices.Tests;
55

6-
public class BufferedModuleTests : ModuleTests
6+
public class BufferedModuleTests : ModuleTests<BufferedModuleTests>
77
{
8-
public BufferedModuleTests()
9-
: base(true)
10-
{
11-
}
128
}

test/Microsoft.AspNetCore.SystemWebAdapters.CoreServices.Tests/Modules/ModuleTests.cs

Lines changed: 115 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Web;
1010
using Microsoft.AspNetCore.Builder;
1111
using Microsoft.AspNetCore.Hosting;
12+
using Microsoft.AspNetCore.Http;
1213
using Microsoft.AspNetCore.Http.Features;
1314
using Microsoft.AspNetCore.SystemWebAdapters.Features;
1415
using Microsoft.AspNetCore.TestHost;
@@ -20,7 +21,7 @@
2021
namespace Microsoft.AspNetCore.SystemWebAdapters.CoreServices.Tests;
2122

2223
[Collection(nameof(SelfHostedTests))]
23-
public abstract class ModuleTests(bool isBuffered)
24+
public abstract class ModuleTests<T>
2425
{
2526
private static readonly ImmutableArray<ApplicationEvent> BeforeHandlerEvents =
2627
[
@@ -54,7 +55,27 @@ public abstract class ModuleTests(bool isBuffered)
5455
ApplicationEvent.EndRequest,
5556
];
5657

57-
public static IEnumerable<object[]> GetAllEvents()
58+
private static bool IsBuffered
59+
{
60+
get
61+
{
62+
if (typeof(T) == typeof(NotBufferedModuleTests))
63+
{
64+
return false;
65+
}
66+
else if (typeof(T) == typeof(BufferedModuleTests))
67+
{
68+
return true;
69+
}
70+
else
71+
{
72+
throw new ArgumentOutOfRangeException(typeof(T).FullName);
73+
}
74+
}
75+
}
76+
77+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1000:Do not declare static members on generic types", Justification = "Needed for xUnit theory tests")]
78+
public static TheoryData<ApplicationEvent, RegisterMode> GetAllEvents()
5879
{
5980
IEnumerable<ApplicationEvent> all =
6081
[
@@ -66,14 +87,17 @@ public static IEnumerable<object[]> GetAllEvents()
6687
];
6788

6889
var modes = Enum.GetValues<RegisterMode>();
90+
var data = new TheoryData<ApplicationEvent, RegisterMode>();
6991

7092
foreach (var notification in all)
7193
{
7294
foreach (var mode in modes)
7395
{
74-
yield return new object[] { notification, mode };
96+
data.Add(notification, mode);
7597
}
7698
}
99+
100+
return data;
77101
}
78102

79103
[MemberData(nameof(GetAllEvents))]
@@ -105,7 +129,7 @@ public async Task ModulesThrow(ApplicationEvent notification, RegisterMode mode)
105129

106130
Assert.Equal(expected, result);
107131

108-
IEnumerable<ApplicationEvent> GetExpected(ApplicationEvent notification)
132+
static IEnumerable<ApplicationEvent> GetExpected(ApplicationEvent notification)
109133
{
110134
foreach (var item in GetNotificationsUpTo(notification, isThrowing: true))
111135
{
@@ -119,9 +143,77 @@ IEnumerable<ApplicationEvent> GetExpected(ApplicationEvent notification)
119143
}
120144
}
121145

122-
private IEnumerable<ApplicationEvent> GetNotificationsUpTo(ApplicationEvent notification, bool isThrowing = false)
146+
[Fact]
147+
public async Task PreSendRequestHeadersAddHeaders()
123148
{
124-
return isBuffered
149+
// Arrange
150+
const string HeaderName = "name";
151+
const string HeaderValue = "value";
152+
const string Result = "Hello world!";
153+
154+
using var host = await new HostBuilder()
155+
.ConfigureWebHost(webBuilder => webBuilder
156+
.UseTestServer()
157+
.ConfigureServices(services =>
158+
{
159+
services.AddSystemWebAdapters()
160+
.AddHttpApplication(options =>
161+
{
162+
options.RegisterModule<BufferToggleModule>();
163+
options.RegisterModule<PreSendHeadersAddHeaderModule>();
164+
options.ArePreSendEventsEnabled = true;
165+
});
166+
167+
})
168+
.Configure(app =>
169+
{
170+
app.Use((ctx, next) =>
171+
{
172+
ctx.Features.Set(new ResultHeader { { HeaderName, HeaderValue } });
173+
return next(ctx);
174+
});
175+
app.UseSystemWebAdapters();
176+
177+
app.Run(ctx => ctx.Response.WriteAsync(Result));
178+
})).StartAsync();
179+
180+
// Act
181+
using var response = await host.GetTestClient().GetAsync(new Uri("/", UriKind.Relative));
182+
183+
// Assert
184+
Assert.True(response.Headers.TryGetValues(HeaderName, out var resultHeader));
185+
Assert.Equal([HeaderValue], resultHeader);
186+
Assert.Equal(Result, await response.Content.ReadAsStringAsync());
187+
}
188+
189+
private sealed class PreSendHeadersAddHeaderModule : IHttpModule
190+
{
191+
public void Dispose()
192+
{
193+
}
194+
195+
public void Init(HttpApplication application)
196+
{
197+
application.PreSendRequestHeaders += (s, o) =>
198+
{
199+
if (s is HttpApplication { Context: { } context })
200+
{
201+
foreach (var (name, value) in context.AsAspNetCore().Features.GetRequiredFeature<ResultHeader>())
202+
{
203+
context.Response.AddHeader(name, value);
204+
}
205+
}
206+
};
207+
}
208+
}
209+
210+
private sealed class ResultHeader : Dictionary<string, string>
211+
{
212+
}
213+
214+
private static IEnumerable<ApplicationEvent> GetNotificationsUpTo(ApplicationEvent notification, bool isThrowing = false)
215+
{
216+
return IsBuffered
125217
? GetExpectedBufferedNotificationsUntilAction(notification, isThrowing)
126218
: GetExpectedUnbufferedNotificationsUntilAction(notification, isThrowing);
127219

@@ -206,10 +298,9 @@ static IEnumerable<ApplicationEvent> GetExpectedUnbufferedNotificationsUntilActi
206298
}
207299
}
208300

209-
private async Task<List<ApplicationEvent>> RunAsync(string action, ApplicationEvent @event, RegisterMode mode)
301+
private static async Task<List<ApplicationEvent>> RunAsync(string action, ApplicationEvent @event, RegisterMode mode)
210302
{
211303
var notifier = new NotificationCollection();
212-
var module = isBuffered ? typeof(BufferedTestModule) : typeof(NotBufferedTestModule);
213304

214305
using var host = await new HostBuilder()
215306
.ConfigureWebHost(webBuilder =>
@@ -226,9 +317,11 @@ private async Task<List<ApplicationEvent>> RunAsync(string action, ApplicationEv
226317
services.AddSystemWebAdapters()
227318
.AddHttpApplication(options =>
228319
{
320+
options.RegisterModule<BufferToggleModule>();
321+
229322
if (mode == RegisterMode.Options)
230323
{
231-
options.RegisterModule(module);
324+
options.RegisterModule<NotificationTrackingModule>();
232325
}
233326

234327
options.ArePreSendEventsEnabled = true;
@@ -239,13 +332,13 @@ private async Task<List<ApplicationEvent>> RunAsync(string action, ApplicationEv
239332
{
240333
if (mode == RegisterMode.RegisterModule)
241334
{
242-
HttpApplication.RegisterModule(module);
335+
HttpApplication.RegisterModule(typeof(NotificationTrackingModule));
243336
}
244337
else if (mode == RegisterMode.RegisterModuleOnStartup)
245338
{
246339
app.ApplicationServices.GetRequiredService<IHostApplicationLifetime>().ApplicationStarted.Register(() =>
247340
{
248-
HttpApplication.RegisterModule(module);
341+
HttpApplication.RegisterModule(typeof(NotificationTrackingModule));
249342
});
250343
}
251344

@@ -316,16 +409,22 @@ public Action<IApplicationBuilder> Configure(Action<IApplicationBuilder> next)
316409
};
317410
}
318411

319-
private sealed class NotBufferedTestModule : BufferedTestModule
412+
/// <summary>
413+
/// Module used to toggle buffer output depending on the test suite we're running
414+
/// </summary>
415+
private sealed class BufferToggleModule : IHttpModule
320416
{
321-
public override void Init(HttpApplication application)
417+
public void Dispose()
418+
{
419+
}
420+
421+
public void Init(HttpApplication application)
322422
{
323-
application.BeginRequest += (s, o) => ((HttpApplication)s!).Context.Response.BufferOutput = false;
324-
base.Init(application);
423+
application.BeginRequest += (s, o) => ((HttpApplication)s!).Context.Response.BufferOutput = IsBuffered;
325424
}
326425
}
327426

328-
private class BufferedTestModule : EventsModule
427+
private sealed class NotificationTrackingModule : EventsModule
329428
{
330429
protected override void InvokeEvent(HttpContext context, string name)
331430
{
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
namespace Microsoft.AspNetCore.SystemWebAdapters.CoreServices.Tests;
55

6-
public class NotBufferedModuleTests : ModuleTests
6+
public class NotBufferedModuleTests : ModuleTests<NotBufferedModuleTests>
77
{
8-
public NotBufferedModuleTests()
9-
: base(false)
10-
{
11-
}
128
}

0 commit comments

Comments
 (0)