diff --git a/src/Microsoft.FeatureManagement/FeatureManagementBuilderExtensions.cs b/src/Microsoft.FeatureManagement/FeatureManagementBuilderExtensions.cs index f8635a79..b12adb8f 100644 --- a/src/Microsoft.FeatureManagement/FeatureManagementBuilderExtensions.cs +++ b/src/Microsoft.FeatureManagement/FeatureManagementBuilderExtensions.cs @@ -60,17 +60,27 @@ public static IFeatureManagementBuilder WithVariantService(this IFeatu if (builder.Services.Any(descriptor => descriptor.ServiceType == typeof(IFeatureManager) && descriptor.Lifetime == ServiceLifetime.Scoped)) { - builder.Services.AddScoped>(sp => new VariantServiceProvider( - featureName, - sp.GetRequiredService(), - sp.GetRequiredService>())); + builder.Services.AddScoped>(sp => + { + IEnumerable serviceDescriptors = builder.Services.Where(d => d.ServiceType == typeof(TService)); + return new VariantServiceProvider( + featureName, + sp.GetRequiredService(), + serviceDescriptors, + sp); + }); } else { - builder.Services.AddSingleton>(sp => new VariantServiceProvider( - featureName, - sp.GetRequiredService(), - sp.GetRequiredService>())); + builder.Services.AddSingleton>(sp => + { + IEnumerable serviceDescriptors = builder.Services.Where(d => d.ServiceType == typeof(TService)); + return new VariantServiceProvider( + featureName, + sp.GetRequiredService(), + serviceDescriptors, + sp); + }); } return builder; diff --git a/src/Microsoft.FeatureManagement/VariantServiceProvider.cs b/src/Microsoft.FeatureManagement/VariantServiceProvider.cs index d4b3f514..d1a4b0cf 100644 --- a/src/Microsoft.FeatureManagement/VariantServiceProvider.cs +++ b/src/Microsoft.FeatureManagement/VariantServiceProvider.cs @@ -5,9 +5,9 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; namespace Microsoft.FeatureManagement { @@ -16,26 +16,55 @@ namespace Microsoft.FeatureManagement /// internal class VariantServiceProvider : IVariantServiceProvider where TService : class { - private readonly IEnumerable _services; private readonly IVariantFeatureManager _featureManager; private readonly string _featureName; private readonly ConcurrentDictionary _variantServiceCache; + private readonly IServiceProvider _serviceProvider; + private readonly Dictionary _variantNameToDescriptor; // ImplementationType/Instance descriptors mapped by variant name. + private readonly List _factoryDescriptors; // Descriptors that require factory invocation to discover variant name. /// /// Creates a variant service provider. /// /// The feature flag that should be used to determine which variant of the service should be used. /// The feature manager to get the assigned variant of the feature flag. - /// Implementation variants of TService. - /// Thrown if is null. - /// Thrown if is null. - /// Thrown if is null. - public VariantServiceProvider(string featureName, IVariantFeatureManager featureManager, IEnumerable services) + /// Service descriptors for implementation variants of TService. + /// The service provider / scope used to activate implementations lazily. + public VariantServiceProvider(string featureName, IVariantFeatureManager featureManager, IEnumerable serviceDescriptors, IServiceProvider serviceProvider) { _featureName = featureName ?? throw new ArgumentNullException(nameof(featureName)); _featureManager = featureManager ?? throw new ArgumentNullException(nameof(featureManager)); - _services = services ?? throw new ArgumentNullException(nameof(services)); - _variantServiceCache = new ConcurrentDictionary(); + if (serviceDescriptors == null) throw new ArgumentNullException(nameof(serviceDescriptors)); + _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); + _variantServiceCache = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + _variantNameToDescriptor = new Dictionary(StringComparer.OrdinalIgnoreCase); + _factoryDescriptors = new List(); + + // Precompute mapping for descriptors whose variant name can be determined without instantiation. + foreach (ServiceDescriptor descriptor in serviceDescriptors) + { + if (descriptor.ImplementationType != null) + { + string name = GetVariantName(descriptor.ImplementationType); + if (!_variantNameToDescriptor.ContainsKey(name)) + { + _variantNameToDescriptor.Add(name, descriptor); + } + } + else if (descriptor.ImplementationInstance != null) + { + string name = GetVariantName(descriptor.ImplementationInstance.GetType()); + if (!_variantNameToDescriptor.ContainsKey(name)) + { + _variantNameToDescriptor.Add(name, descriptor); + } + } + else if (descriptor.ImplementationFactory != null) + { + // Factory descriptors require instantiation to discover variant name; hold for later. + _factoryDescriptors.Add(descriptor); + } + } } /// @@ -47,25 +76,73 @@ public async ValueTask GetServiceAsync(CancellationToken cancellationT { Debug.Assert(_featureName != null); - Variant variant = await _featureManager.GetVariantAsync(_featureName, cancellationToken); + Variant variant = await _featureManager.GetVariantAsync(_featureName, cancellationToken).ConfigureAwait(false); + + if (variant == null) + { + return null; + } + + return _variantServiceCache.GetOrAdd(variant.Name, ResolveVariant); + } + + private TService ResolveVariant(string variantName) + { + // Try fast path using precomputed mapping. + if (_variantNameToDescriptor.TryGetValue(variantName, out ServiceDescriptor descriptor)) + { + return ActivateDescriptor(descriptor); + } + + // Need to probe factory descriptors lazily. + foreach (ServiceDescriptor factoryDescriptor in _factoryDescriptors) + { + TService instance = ActivateDescriptor(factoryDescriptor); - TService implementation = null; + if (instance == null) + { + continue; + } + + string discoveredName = GetVariantName(instance.GetType()); + + // Cache the mapping for future lookups. + if (!_variantNameToDescriptor.ContainsKey(discoveredName)) + { + _variantNameToDescriptor.Add(discoveredName, factoryDescriptor); + } + + if (string.Equals(discoveredName, variantName, StringComparison.OrdinalIgnoreCase)) + { + return instance; + } + } + + return null; + } + + private TService ActivateDescriptor(ServiceDescriptor descriptor) + { + if (descriptor.ImplementationInstance != null) + { + return (TService)descriptor.ImplementationInstance; + } + + if (descriptor.ImplementationType != null) + { + // Use ActivatorUtilities to honor DI for dependencies of the implementation type. + return (TService)ActivatorUtilities.GetServiceOrCreateInstance(_serviceProvider, descriptor.ImplementationType); + } - if (variant != null) + if (descriptor.ImplementationFactory != null) { - implementation = _variantServiceCache.GetOrAdd( - variant.Name, - (_) => _services.FirstOrDefault( - service => IsMatchingVariantName( - service.GetType(), - variant.Name)) - ); + return (TService)descriptor.ImplementationFactory(_serviceProvider); } - return implementation; + return null; } - private bool IsMatchingVariantName(Type implementationType, string variantName) + private string GetVariantName(Type implementationType) { string implementationName = ((VariantServiceAliasAttribute)Attribute.GetCustomAttribute(implementationType, typeof(VariantServiceAliasAttribute)))?.Alias; @@ -74,7 +151,7 @@ private bool IsMatchingVariantName(Type implementationType, string variantName) implementationName = implementationType.Name; } - return string.Equals(implementationName, variantName, StringComparison.OrdinalIgnoreCase); + return implementationName; } } } diff --git a/tests/Tests.FeatureManagement/FeatureManagementTest.cs b/tests/Tests.FeatureManagement/FeatureManagementTest.cs index 10636b84..0d5edfe0 100644 --- a/tests/Tests.FeatureManagement/FeatureManagementTest.cs +++ b/tests/Tests.FeatureManagement/FeatureManagementTest.cs @@ -1803,59 +1803,75 @@ public async Task VariantBasedInjection() services = new ServiceCollection(); Assert.Throws(() => - { - services.AddFeatureManagement() - .WithVariantService("DummyFeature1") - .WithVariantService("DummyFeature2"); - } + { + services.AddFeatureManagement() + .WithVariantService("DummyFeature1") + .WithVariantService("DummyFeature2"); + } ); } [Fact] - public async Task VariantFeatureFlagWithContextualFeatureFilter() + public async Task VariantServiceLazyInstantiation() { + // Reset counters + AlgorithmBeta.Instances = 0; + AlgorithmOmega.Instances = 0; + IConfiguration configuration = new ConfigurationBuilder() .AddJsonFile("appsettings.json") .Build(); IServiceCollection services = new ServiceCollection(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(sp => new AlgorithmOmega("OMEGA")); + services.AddSingleton(configuration) .AddFeatureManagement() - .AddFeatureFilter(); - - ServiceProvider serviceProvider = services.BuildServiceProvider(); - - ContextualTestFilter contextualTestFeatureFilter = (ContextualTestFilter)serviceProvider.GetRequiredService>().First(f => f is ContextualTestFilter); + .AddFeatureFilter() + .WithVariantService(Features.VariantImplementationFeature); - contextualTestFeatureFilter.ContextualCallback = (ctx, accountContext) => - { - var allowedAccounts = new List(); + var targetingContextAccessor = new OnDemandTargetingContextAccessor(); + services.AddSingleton(targetingContextAccessor); - ctx.Parameters.Bind("AllowedAccounts", allowedAccounts); + ServiceProvider serviceProvider = services.BuildServiceProvider(); - return allowedAccounts.Contains(accountContext.AccountId); - }; + // At this point none of the implementations should have been instantiated yet because provider hasn't requested them. + Assert.Equal(0, AlgorithmBeta.Instances); + Assert.Equal(0, AlgorithmOmega.Instances); + IVariantServiceProvider variantProvider = serviceProvider.GetRequiredService>(); IVariantFeatureManager featureManager = serviceProvider.GetRequiredService(); - var context = new AppContext(); - - context.AccountId = "NotEnabledAccount"; - - Assert.False(await featureManager.IsEnabledAsync(Features.ContextualFeatureWithVariant, context)); - - Variant variant = await featureManager.GetVariantAsync(Features.ContextualFeatureWithVariant, context); - - Assert.Equal("Small", variant.Name); - - context.AccountId = "abc"; + targetingContextAccessor.Current = new TargetingContext { UserId = "Guest" }; + IAlgorithm algorithm = await variantProvider.GetServiceAsync(CancellationToken.None); + Assert.Null(algorithm); + Assert.Equal(0, AlgorithmBeta.Instances); + Assert.Equal(0, AlgorithmOmega.Instances); - Assert.True(await featureManager.IsEnabledAsync(Features.ContextualFeatureWithVariant, context)); + targetingContextAccessor.Current = new TargetingContext { UserId = "UserBeta" }; + algorithm = await variantProvider.GetServiceAsync(CancellationToken.None); + Assert.NotNull(algorithm); + Assert.Equal("Beta", algorithm.Style); + Assert.Equal(1, AlgorithmBeta.Instances); + Assert.Equal(0, AlgorithmOmega.Instances); - variant = await featureManager.GetVariantAsync(Features.ContextualFeatureWithVariant, context); + targetingContextAccessor.Current = new TargetingContext { UserId = "UserOmega" }; + algorithm = await variantProvider.GetServiceAsync(CancellationToken.None); + Assert.NotNull(algorithm); + Assert.Equal("OMEGA", algorithm.Style); + Assert.Equal(1, AlgorithmBeta.Instances); + Assert.Equal(1, AlgorithmOmega.Instances); - Assert.Equal("Big", variant.Name); + // Re-resolve Beta variant should not create additional instance because singleton already constructed previously + targetingContextAccessor.Current = new TargetingContext { UserId = "UserBeta" }; + algorithm = await variantProvider.GetServiceAsync(CancellationToken.None); + Assert.NotNull(algorithm); + Assert.Equal("Beta", algorithm.Style); + Assert.Equal(1, AlgorithmBeta.Instances); + Assert.Equal(1, AlgorithmOmega.Instances); } } diff --git a/tests/Tests.FeatureManagement/VariantServices.cs b/tests/Tests.FeatureManagement/VariantServices.cs index 942b110c..6670c589 100644 --- a/tests/Tests.FeatureManagement/VariantServices.cs +++ b/tests/Tests.FeatureManagement/VariantServices.cs @@ -9,20 +9,24 @@ interface IAlgorithm class AlgorithmBeta : IAlgorithm { + public static int Instances; // Tracks constructed instances public string Style { get; set; } public AlgorithmBeta() { + Instances++; Style = "Beta"; } } class AlgorithmSigma : IAlgorithm { + public static int Instances; // Tracks constructed instances public string Style { get; set; } public AlgorithmSigma() { + Instances++; Style = "Sigma"; } } @@ -30,10 +34,12 @@ public AlgorithmSigma() [VariantServiceAlias("Omega")] class AlgorithmOmega : IAlgorithm { + public static int Instances; // Tracks constructed instances public string Style { get; set; } public AlgorithmOmega(string style) { + Instances++; Style = style; } }