From ee7f341c1be9fc3b77baf8e6ebd355d6a1a4aa51 Mon Sep 17 00:00:00 2001 From: 7amou3 <993610+7amou3@users.noreply.github.com> Date: Sun, 20 Jul 2025 18:05:49 +0200 Subject: [PATCH] feat: TypeHandler Support --- docs/rules/DAP050.md | 20 ++ docs/rules/DAP051.md | 18 ++ docs/typehandlers.md | 26 ++ .../DapperInterceptorGenerator.Diagnostics.cs | 5 +- .../DapperInterceptorGenerator.cs | 228 ++++++++++++++---- .../CodeAnalysis/ParseState.cs | 30 ++- .../Internal/CodeWriter.cs | 30 ++- src/Dapper.AOT/TypeHandlerT.cs | 8 +- .../Interceptors/QueryStrictBind.output.cs | 12 +- .../QueryStrictBind.output.netfx.cs | 12 +- .../Interceptors/TypeHandler.input.cs | 35 +++ .../Interceptors/TypeHandler.output.cs | 198 +++++++++++++++ .../Interceptors/TypeHandler.output.netfx.cs | 198 +++++++++++++++ .../Interceptors/TypeHandler.output.netfx.txt | 4 + .../Interceptors/TypeHandler.output.txt | 4 + .../TestCommon/GeneratorWrapper.cs | 15 +- 16 files changed, 765 insertions(+), 78 deletions(-) create mode 100644 docs/rules/DAP050.md create mode 100644 docs/rules/DAP051.md create mode 100644 docs/typehandlers.md create mode 100644 test/Dapper.AOT.Test/Interceptors/TypeHandler.input.cs create mode 100644 test/Dapper.AOT.Test/Interceptors/TypeHandler.output.cs create mode 100644 test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.cs create mode 100644 test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.txt create mode 100644 test/Dapper.AOT.Test/Interceptors/TypeHandler.output.txt diff --git a/docs/rules/DAP050.md b/docs/rules/DAP050.md new file mode 100644 index 00000000..5b3c979b --- /dev/null +++ b/docs/rules/DAP050.md @@ -0,0 +1,20 @@ +# DAP050 + +Duplicate classes have been registered as type handlers for the same type, +meaning it's not possible to determine which to use when handling the type. +Note type handlers can be registered at the assembly and module level, so +ensure the type used for the `TValue` parameter in the attribute is only +specified once. + +Bad: + +``` c# +[module: TypeHandler] +[module: TypeHandler] +``` + +Good: + +``` c# +[module: TypeHandler] +``` \ No newline at end of file diff --git a/docs/rules/DAP051.md b/docs/rules/DAP051.md new file mode 100644 index 00000000..4b1931a4 --- /dev/null +++ b/docs/rules/DAP051.md @@ -0,0 +1,18 @@ +# DAP051 + +TypeHandler attribute points to a type handler class (TTypeHandler in TypeHandler) that is not a valid named type, +or cannot be constructed as required by Dapper AOT. +The type handler must be a concrete (non-abstract) class and must have a public parameterless constructor. + +Bad: + +``` c# +[module: TypeHandler] +[module: TypeHandler] +``` + +Good: + +``` c# +[module: TypeHandler] +``` \ No newline at end of file diff --git a/docs/typehandlers.md b/docs/typehandlers.md new file mode 100644 index 00000000..f207b635 --- /dev/null +++ b/docs/typehandlers.md @@ -0,0 +1,26 @@ +# Type Handlers + +Dapper AOT provides a mechanism to customize how specific .NET types are mapped to and from database values. +This is achieved through Type Handlers, which allow you to define custom serialization for parameters sent +to the database and custom deserialization for values read from query results. +This is similar to SqlMapper.TypeHandler in vanilla Dapper, but with a specific AOT-compatible registration process. + +To register your custom type handler, you use either an assembly or module level attribute. +Both have the same effect. +This attribute tells Dapper AOT which .NET type (TValue) your handler processes +and which class (TTypeHandler) is responsible for that processing. + +``` csharp +// Example using a module-level attribute +using Dapper; + +[module: TypeHandler] + +// Example using an assembly-level attribute (often in AssemblyInfo.cs or a shared file) +// [assembly: TypeHandler] +``` + +Your custom type handler class must: +* Inherit from Dapper.TypeHandler, where TValue is the .NET type it handles. +* Have a public parameterless constructor (new()). +* Implement (override) the necessary virtual methods for your specific conversion needs. \ No newline at end of file diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs index cc07fa11..4f16e65b 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs @@ -12,6 +12,9 @@ internal static readonly DiagnosticDescriptor LanguageVersionTooLow = LibraryWarning("DAP004", "Language version too low", "Interceptors require at least C# version 11"), CommandPropertyNotFound = LibraryWarning("DAP033", "Command property not found", "Command property {0}.{1} was not found or was not valid; attribute will be ignored"), - CommandPropertyReserved = LibraryWarning("DAP034", "Command property reserved", "Command property {1} is reserved for internal usage; attribute will be ignored"); + CommandPropertyReserved = LibraryWarning("DAP034", "Command property reserved", "Command property {1} is reserved for internal usage; attribute will be ignored"), + + DuplicateTypeHandlers = LibraryError("DAP050", "Duplicate type handlers", "Type {0} has multiple type handlers registered"), + InvalidTypeHandlerSymbol = LibraryError("DAP051", "Invalid type handler symbol", "Type handler symbol {0} is not a valid named type; attribute will be ignored"); } } diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs index 595282a6..76375690 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs @@ -243,7 +243,8 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm { try { - Generate(new(ctx, state)); + var (typeHandlerRegistry, typeHandlers) = InitTypeHandlers(ctx.ReportDiagnostic, state.Compilation); + Generate(new(ctx, state.Compilation, state.Nodes, typeHandlers, typeHandlerRegistry)); } catch (Exception ex) { @@ -276,6 +277,9 @@ internal void Generate(in GenerateState ctx) .Append("#pragma warning disable CS9270 // SDK-dependent change to interceptors usage").NewLine() .Append("namespace ").Append(FeatureKeys.CodegenNamespace).Append(" // interceptors must be in a known namespace").Indent().NewLine() .Append("file static class DapperGeneratedInterceptors").Indent().NewLine(); + + sb.AppendTypeHandlers(ctx.TypeHandlerRegistry.GetAllRegisteredHandlersForEmission()); + int methodIndex = 0, callSiteCount = 0; var factories = new CommandFactoryState(ctx.Compilation); @@ -759,7 +763,8 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy var hasGetOnlyMembers = members.Any(member => member is { IsGettable: true, IsSettable: false, IsInitOnly: false }); var useConstructorDeferred = map.Constructor is not null; var useFactoryMethodDeferred = map.FactoryMethod is not null; - + var typeHandlers = context.TypeHandlers; // Prevent ctx getting captured + // Implementation detail: // constructor takes advantage over factory method. var useDeferredConstruction = useConstructorDeferred || useFactoryMethodDeferred || hasInitOnlyMembers || hasGetOnlyMembers || hasRequiredMembers; @@ -824,7 +829,8 @@ void WriteTokenizeMethod() .Append("var type = reader.GetFieldType(columnOffset);").NewLine() .Append("switch (NormalizedHash(name))").Indent().NewLine(); - int token = 0; + int firstToken = 0; + int secondToken = map.Members.Length; foreach (var member in members) { if (member.IsMapped) @@ -835,17 +841,27 @@ void WriteTokenizeMethod() .AppendVerbatimLiteral(StringHashing.Normalize(dbName)).Append("):").Indent(false).NewLine(); if (flags.HasAny(OperationFlags.StrictTypes)) { - sb.Append("token = ").Append(token).Append(";").Append(token == 0 ? " // note: strict types" : ""); + sb.Append("token = ").Append(firstToken ).Append(";").Append(firstToken == 0 ? " // note: strict types" : ""); } else { - sb.Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(token) - .Append(" : ").Append(token + map.Members.Length).Append(";") - .Append(token == 0 ? " // two tokens for right-typed and type-flexible" : ""); + if (typeHandlers.TryGetValue(member.CodeType, out var typeHandlerSymbolFound) && typeHandlerSymbolFound is INamedTypeSymbol namedTypeHandlerSymbol) + { + sb.Append("token = ").Append(firstToken).Append(";"); + } + else + { + sb + .Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(firstToken) + .Append(" : ").Append(secondToken).Append(";"); + secondToken++; + } + + sb.Append(firstToken == 0 ? " // two tokens for right-typed and type-flexible" : ""); } sb.NewLine().Append("break;").Outdent(false).NewLine(); } - token++; + firstToken++; } sb.Outdent().NewLine() .Append("tokens[i] = token;").NewLine() @@ -886,7 +902,7 @@ void WriteReadMethod(in GenerateState context) sb.Append("public override ").Append(type).Append(" Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state)").Indent().NewLine(); - int token = 0; + int firstToken = 0; var deferredMethodArgumentsOrdered = new SortedList(); if (useDeferredConstruction) @@ -901,7 +917,7 @@ void WriteReadMethod(in GenerateState context) { if (member.IsMapped) { - var variableName = DeferredConstructionVariableName + token; + var variableName = DeferredConstructionVariableName + firstToken; if (Inspection.CouldBeNullable(member.CodeType)) sb.Append(CodeWriter.GetTypeName(member.CodeType.WithNullableAnnotation(NullableAnnotation.Annotated))); else sb.Append(CodeWriter.GetTypeName(member.CodeType)); @@ -920,7 +936,7 @@ void WriteReadMethod(in GenerateState context) deferredMethodArgumentsOrdered.Add(member.FactoryMethodParameterOrder.Value, variableName); } } - token++; + firstToken++; } } else @@ -942,7 +958,8 @@ void WriteReadMethod(in GenerateState context) } sb.Indent().NewLine().Append("switch (token)").Indent().NewLine(); - token = 0; + firstToken = 0; + int secondToken = members.Length; foreach (var member in members) { if (member.IsMapped) @@ -951,43 +968,53 @@ void WriteReadMethod(in GenerateState context) member.GetDbType(out var readerMethod); var nullCheck = Inspection.CouldBeNullable(memberType) ? $"reader.IsDBNull(columnOffset) ? ({CodeWriter.GetTypeName(memberType.WithNullableAnnotation(NullableAnnotation.Annotated))})null : " : ""; - sb.Append("case ").Append(token).Append(":").NewLine().Indent(false); + sb.Append("case ").Append(firstToken).Append(":").NewLine().Indent(false); // write `result.X = ` or `member0 = ` - if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token); + if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(firstToken); else sb.Append("result.").Append(member.CodeName); sb.Append(" = "); sb.Append(nullCheck); - if (readerMethod is null) + if (context.TypeHandlers.TryGetValue(memberType, out var typeHandlerSymbolFound) && typeHandlerSymbolFound is INamedTypeSymbol namedTypeHandlerSymbol) { - sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);"); + var (handlerPropertyName, _) = context.TypeHandlerRegistry.GetOrCreateHandlerInfo(namedTypeHandlerSymbol); + sb + .Append($"{handlerPropertyName}.Read(reader, columnOffset);").NewLine() + .Append("break;").NewLine().Outdent(false); } else { - sb.Append("reader.").Append(readerMethod).Append("(columnOffset);"); - } - + if (readerMethod is null) + { + sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);"); + } + else + { + sb.Append("reader.").Append(readerMethod).Append("(columnOffset);"); + } - sb.NewLine().Append("break;").NewLine().Outdent(false); + sb.NewLine().Append("break;").NewLine().Outdent(false); - // optionally emit type-forgiving version - if (!flags.HasAny(OperationFlags.StrictTypes)) - { - sb.Append("case ").Append(token + map.Members.Length).Append(":").NewLine().Indent(false); + // optionally emit type-forgiving version + if (!flags.HasAny(OperationFlags.StrictTypes)) + { + sb.Append("case ").Append(firstToken + map.Members.Length).Append(":").NewLine().Indent(false); - // write `result.X = ` or `member0 = ` - if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token); - else sb.Append("result.").Append(member.CodeName); + // write `result.X = ` or `member0 = ` + if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(firstToken); + else sb.Append("result.").Append(member.CodeName); - sb.Append(" = ") - .Append(nullCheck) - .Append("GetValue<") - .Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine() - .Append("break;").NewLine().Outdent(false); + sb.Append(" = ") + .Append(nullCheck) + .Append("GetValue<") + .Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine() + .Append("break;").NewLine().Outdent(false); + } + secondToken++; } + firstToken++; } - token++; } sb.Outdent().NewLine().Append("columnOffset++;").NewLine().Outdent().NewLine(); @@ -1042,14 +1069,14 @@ void WriteDeferredInitialization() if (deferredMethodArgumentsOrdered!.Count == members.Length) return; sb.Indent().NewLine(); - token = -1; + firstToken = -1; foreach (var member in members) { - token++; + firstToken++; if (member.IsMapped) { if (member.ConstructorParameterOrder is not null) continue; // already used in constructor arguments - sb.Append(member.CodeName).Append(" = ").Append(DeferredConstructionVariableName).Append(token).Append(',').NewLine(); + sb.Append(member.CodeName).Append(" = ").Append(DeferredConstructionVariableName).Append(firstToken).Append(',').NewLine(); } } sb.Outdent(withScope: false).Append("}"); @@ -1074,6 +1101,31 @@ void WriteDeferredMethodArgs() } } } + + internal static (TypeHandlerInstanceRegistry, IImmutableDictionary) InitTypeHandlers(Action reportDiagnostic, Compilation compilation) + { + var typeHandlerRegistry = new TypeHandlerInstanceRegistry(); + var typeHandlers = IdentifyTypeHandlers(reportDiagnostic, compilation); + foreach (var pair in typeHandlers) + { + var handlerSymbol = pair.Value; + + if (handlerSymbol is INamedTypeSymbol namedTypeHandlerSymbol) + { + typeHandlerRegistry.GetOrCreateHandlerInfo(namedTypeHandlerSymbol); + } + else + { + reportDiagnostic(Diagnostic.Create( + Diagnostics.InvalidTypeHandlerSymbol, + handlerSymbol.Locations.FirstOrDefault(), + handlerSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + )); + } + } + + return (typeHandlerRegistry, typeHandlers); + } [Flags] enum WriteArgsFlags @@ -1270,7 +1322,7 @@ private static void WriteArgs(in GenerateState ctx, ITypeSymbol? parameterType, } else { - sb.Append("p.Value = ").Append("AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine(); + AppendSetValue(ctx, sb, "p", source, member); } break; default: @@ -1301,29 +1353,33 @@ private static void WriteArgs(in GenerateState ctx, ITypeSymbol? parameterType, break; } - sb.Append("ps["); - if ((flags & WriteArgsFlags.NeedsTest) != 0) sb.AppendVerbatimLiteral(member.DbName); - else sb.Append(parameterIndex); - sb.Append("].Value = "); + var parameter = GetParameterIndex(flags, member.DbName, parameterIndex); switch (direction) { case ParameterDirection.Input: case ParameterDirection.InputOutput: - sb.Append("AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine(); + AppendSetValue(ctx, sb, parameter, source, member); break; default: - sb.Append("global::System.DBNull.Value;").NewLine(); + sb.Append(parameter).Append(".Value = global::System.DBNull.Value;").NewLine(); break; } break; case WriteArgsMode.PostProcess: // we already eliminated args that we don't need to look at - sb.Append(source).Append(".").Append(member.CodeName).Append(" = Parse<") - .Append(member.CodeType).Append(">(ps["); - if ((flags & WriteArgsFlags.NeedsTest) != 0) sb.AppendVerbatimLiteral(member.DbName); - else sb.Append(parameterIndex); - sb.Append("].Value);").NewLine(); + parameter = GetParameterIndex(flags, member.DbName, parameterIndex); + sb.Append(source).Append(".").Append(member.CodeName).Append(" = "); + if (ctx.TypeHandlers.TryGetValue(member.CodeType, out var handlerClassSymbolFound) && handlerClassSymbolFound is INamedTypeSymbol namedHandlerClassSymbol) + { + var (handlerPropertyName, _) = ctx.TypeHandlerRegistry.GetOrCreateHandlerInfo(namedHandlerClassSymbol); + sb.Append($"{handlerPropertyName}.Parse((global::System.Data.Common.DbParameter){parameter});").NewLine(); + } + else + { + sb.Append(source).Append(".").Append(member.CodeName).Append("Parse<") + .Append(member.CodeType).Append(">(").Append(parameter).Append(".Value);").NewLine(); + } break; } @@ -1335,6 +1391,22 @@ private static void WriteArgs(in GenerateState ctx, ITypeSymbol? parameterType, } } + static void AppendSetValue(in GenerateState ctx, CodeWriter sb, string parameter, string? source, in Inspection.ElementMember member) + { + if (member.CodeType != null && ctx.TypeHandlers.TryGetValue(member.CodeType, out var handlerClassSymbolFound) && handlerClassSymbolFound is INamedTypeSymbol namedHandlerClassSymbol) + { + var (handlerPropertyName, _) = ctx.TypeHandlerRegistry.GetOrCreateHandlerInfo(namedHandlerClassSymbol); + sb + .Append($"{handlerPropertyName}.SetValue(") + .Append($"(global::System.Data.Common.DbParameter){parameter}, ").Append(source).Append(".").Append(member.CodeName) + .Append(");").NewLine(); + } + else + { + sb.Append(parameter).Append(".Value = AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine(); + } + } + static void AppendDbParameterSetting(CodeWriter sb, string memberName, int? value) { if (value is not null) @@ -1405,6 +1477,15 @@ static bool IsSettableInstanceProperty(ISymbol? symbol, SpecialType type) => && !prop.IsIndexer && !prop.IsStatic; } + private static string GetParameterIndex(WriteArgsFlags flags, string dbName, int parameterIndex) + { + string index = ((flags & WriteArgsFlags.NeedsTest) != 0) + ? CodeWriter.CreateVerbatimLiteral(dbName) + : parameterIndex.ToString(CultureInfo.InvariantCulture); + + return "ps[" + index + "]"; + } + [Flags] private enum SpecialCommandFlags { @@ -1488,6 +1569,31 @@ static bool IsDerived(ITypeSymbol? type, ITypeSymbol baseType) } } + private static IImmutableDictionary IdentifyTypeHandlers(Action reportDiagnostic, Compilation compilation) + { + var assembly = compilation.Assembly; + var attributes = assembly.GetAttributes() + .Concat(assembly.Modules.SelectMany(x => x.GetAttributes())) + .Where(x => Inspection.IsDapperAttribute(x) && x.AttributeClass!.Name == "TypeHandlerAttribute"); + + var dictionary = ImmutableDictionary.CreateBuilder(SymbolEqualityComparer.Default); + foreach (var attribute in attributes) + { + var valueType = attribute.AttributeClass!.TypeArguments[0]; + var typeHandler = attribute.AttributeClass!.TypeArguments[1]; + if (dictionary.ContainsKey(valueType)) + { + reportDiagnostic(Diagnostic.Create(Diagnostics.DuplicateTypeHandlers, null, valueType.Name)); + } + else + { + dictionary.Add(valueType, typeHandler); + } + } + + return dictionary.ToImmutable(); + } + internal abstract class SourceState { public Location? Location { get; } @@ -1570,4 +1676,30 @@ public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? return hash; } } + + internal class TypeHandlerInstanceRegistry + { + private int _nextTypeHandlerIndex = 0; + private readonly Dictionary _typeHandlerMap = new(SymbolEqualityComparer.Default); + + public (string PropertyName, INamedTypeSymbol HandlerType) GetOrCreateHandlerInfo(INamedTypeSymbol typeHandlerSymbol) + { + if (!_typeHandlerMap.TryGetValue(typeHandlerSymbol, out var info)) + { + var index = Interlocked.Increment(ref _nextTypeHandlerIndex); + var propertyName = $"__Handler{index}"; + info = (index, propertyName, typeHandlerSymbol); + _typeHandlerMap[typeHandlerSymbol] = info; + } + return (info.HandlerPropertyName, info.HandlerTypeSymbol); + } + + public IEnumerable<(string PropertyName, string TypeHandlerFullName)> GetAllRegisteredHandlersForEmission() + { + return _typeHandlerMap.Values + .Select(info => (info.HandlerPropertyName, info.HandlerTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))) + .Distinct() + .ToList(); + } + } } \ No newline at end of file diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs index ee738995..06c70bf4 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs @@ -70,19 +70,26 @@ public GenerateState(GenerateContextProxy proxy) Nodes = proxy.Nodes; ctx = default; this.proxy = proxy; + TypeHandlerRegistry = proxy.TypeHandlerRegistry; + TypeHandlers = proxy.TypeHandlers; } - public GenerateState(SourceProductionContext ctx, in (Compilation Compilation, ImmutableArray Nodes) state) + public GenerateState(SourceProductionContext ctx, Compilation compilation, ImmutableArray nodes, + IImmutableDictionary typeHandlers, TypeHandlerInstanceRegistry typeHandlerRegistry) { - Compilation = state.Compilation; - Nodes = state.Nodes; + Compilation = compilation; + Nodes = nodes; this.ctx = ctx; proxy = null; + TypeHandlers = typeHandlers; + TypeHandlerRegistry = typeHandlerRegistry; } private readonly SourceProductionContext ctx; private readonly GenerateContextProxy? proxy; public readonly ImmutableArray Nodes; public readonly Compilation Compilation; public readonly GeneratorContext GeneratorContext = new(); + public readonly IImmutableDictionary TypeHandlers; + public readonly TypeHandlerInstanceRegistry TypeHandlerRegistry; internal void ReportDiagnostic(Diagnostic diagnostic) { @@ -121,9 +128,12 @@ internal abstract class GenerateContextProxy { public abstract Compilation Compilation { get; } public abstract ImmutableArray Nodes { get; } + public abstract TypeHandlerInstanceRegistry TypeHandlerRegistry { get; } + public abstract IImmutableDictionary TypeHandlers { get; } - public static GenerateContextProxy Create(in CompilationAnalysisContext context, ImmutableArray nodes) - => new CompilationAnalysisContextProxy(in context, nodes); + public static GenerateContextProxy Create(in CompilationAnalysisContext context, ImmutableArray nodes, + TypeHandlerInstanceRegistry typeHandlerRegistry, IImmutableDictionary typeHandlers) + => new CompilationAnalysisContextProxy(in context, nodes, typeHandlerRegistry, typeHandlers); internal virtual void AddSource(string hintName, string text) { } internal virtual void ReportDiagnostic(Diagnostic diagnostic) { } @@ -132,8 +142,14 @@ private sealed class CompilationAnalysisContextProxy : GenerateContextProxy { private readonly CompilationAnalysisContext context; private readonly ImmutableArray nodes; - public CompilationAnalysisContextProxy(in CompilationAnalysisContext context, ImmutableArray nodes) + private readonly TypeHandlerInstanceRegistry typeHandlerRegistry; + private readonly IImmutableDictionary typeHandlers; + + public CompilationAnalysisContextProxy(in CompilationAnalysisContext context, ImmutableArray nodes, + TypeHandlerInstanceRegistry typeHandlerRegistry, IImmutableDictionary typeHandlers) { + this.typeHandlerRegistry = typeHandlerRegistry; + this.typeHandlers = typeHandlers; this.context = context; this.nodes = nodes; } @@ -142,5 +158,7 @@ internal override void ReportDiagnostic(Diagnostic diagnostic) => context.ReportDiagnostic(diagnostic); public override Compilation Compilation => context.Compilation; public override ImmutableArray Nodes => nodes; + public override TypeHandlerInstanceRegistry TypeHandlerRegistry => typeHandlerRegistry; + public override IImmutableDictionary TypeHandlers => typeHandlers; } } diff --git a/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs b/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs index 7f3b03e1..daa3b760 100644 --- a/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs +++ b/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs @@ -1,10 +1,12 @@ -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using System; +using System; +using System.Collections.Generic; using System.Collections.Immutable; using System.Globalization; +using System.Linq; using System.Text; using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; namespace Dapper.Internal; @@ -155,6 +157,23 @@ private void AppendAsValueTuple(ITypeSymbol value) } } + public CodeWriter AppendTypeHandlers(IEnumerable<(string PropertyName, string TypeHandlerFullName)> typeHandlers) + { + if (typeHandlers == null || !typeHandlers.Any()) + { + return this; + } + + Append("#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.").NewLine(); + foreach (var (propertyName, typeHandlerFullName) in typeHandlers) + { + Append($"private static {typeHandlerFullName}? {propertyName.Replace("__Handler", "__handler")};").NewLine(); + Append($"private static {typeHandlerFullName} {propertyName} => {propertyName.Replace("__Handler", "__handler")} ??= new {typeHandlerFullName}();").NewLine(); + } + Append("#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.").NewLine(); + return this; + } + public static int CountGettableInstanceMembers(ImmutableArray members) { int count = 0; @@ -238,7 +257,10 @@ public CodeWriter AppendEnumLiteral(ITypeSymbol enumType, int value) } public CodeWriter AppendVerbatimLiteral(string? value) => Append( - value is null ? "null" : SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(value)).ToFullString()); + CreateVerbatimLiteral(value)); + + public static string CreateVerbatimLiteral(string? value) => + value is null ? "null" : SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(value)).ToFullString(); public CodeWriter Append(char value) { Core.Append(value); diff --git a/src/Dapper.AOT/TypeHandlerT.cs b/src/Dapper.AOT/TypeHandlerT.cs index 51b15bd7..0f2775a3 100644 --- a/src/Dapper.AOT/TypeHandlerT.cs +++ b/src/Dapper.AOT/TypeHandlerT.cs @@ -10,7 +10,7 @@ namespace Dapper; /// when processing values of type /// [ImmutableObject(true)] -[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, AllowMultiple = true)] +[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module, AllowMultiple = true)] public sealed class TypeHandlerAttribute : Attribute where TTypeHandler : TypeHandler, new() {} @@ -31,4 +31,10 @@ public virtual void SetValue(DbParameter parameter, T value) /// public virtual T Parse(DbParameter parameter) => CommandUtils.As(parameter.Value); + + /// + /// Reads the value from the results + /// + public virtual T Read(DbDataReader reader, int columnOffset) + => CommandUtils.As(reader.GetValue(columnOffset)); } \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.cs b/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.cs index 34e4571c..19607888 100644 --- a/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.cs +++ b/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.cs @@ -115,16 +115,16 @@ private RowFactory0() {} case 3: result.X = GetValue(reader, columnOffset); break; - case 2: + case 1: result.Z = reader.IsDBNull(columnOffset) ? (double?)null : reader.GetDouble(columnOffset); break; - case 5: + case 4: result.Z = reader.IsDBNull(columnOffset) ? (double?)null : GetValue(reader, columnOffset); break; - case 4: + case 2: result.Y = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset); break; - case 7: + case 5: result.Y = reader.IsDBNull(columnOffset) ? (string?)null : GetValue(reader, columnOffset); break; @@ -161,10 +161,10 @@ private RowFactory1() {} case 0: result.X = reader.GetInt32(columnOffset); break; - case 2: + case 1: result.Z = reader.IsDBNull(columnOffset) ? (double?)null : reader.GetDouble(columnOffset); break; - case 4: + case 2: result.Y = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset); break; diff --git a/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.netfx.cs b/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.netfx.cs index 34e4571c..19607888 100644 --- a/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.netfx.cs +++ b/test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.netfx.cs @@ -115,16 +115,16 @@ private RowFactory0() {} case 3: result.X = GetValue(reader, columnOffset); break; - case 2: + case 1: result.Z = reader.IsDBNull(columnOffset) ? (double?)null : reader.GetDouble(columnOffset); break; - case 5: + case 4: result.Z = reader.IsDBNull(columnOffset) ? (double?)null : GetValue(reader, columnOffset); break; - case 4: + case 2: result.Y = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset); break; - case 7: + case 5: result.Y = reader.IsDBNull(columnOffset) ? (string?)null : GetValue(reader, columnOffset); break; @@ -161,10 +161,10 @@ private RowFactory1() {} case 0: result.X = reader.GetInt32(columnOffset); break; - case 2: + case 1: result.Z = reader.IsDBNull(columnOffset) ? (double?)null : reader.GetDouble(columnOffset); break; - case 4: + case 2: result.Y = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset); break; diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.input.cs b/test/Dapper.AOT.Test/Interceptors/TypeHandler.input.cs new file mode 100644 index 00000000..c98a7ac6 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.input.cs @@ -0,0 +1,35 @@ +using Dapper; +using System.Data; +using System.Data.Common; + +[module: DapperAot] +[module: TypeHandler] + +public class CustomClassTypeHandler : TypeHandler +{ +} + +public class CustomClass +{ +} + +public static class Foo +{ + static void SomeCode(DbConnection connection, string bar, bool isBuffered) + { + _ = connection.Query("def"); + _ = connection.Query("def", new { Param = new CustomClass() }); + _ = connection.Query("@OutputValue = def", new CommandParameters()); + } + + public class CommandParameters + { + [DbValue(Direction = ParameterDirection.Output)] + public CustomClass OutputValue { get; set; } + } + + public class MyType + { + public CustomClass C { get; set; } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.cs b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.cs new file mode 100644 index 00000000..ce611232 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.cs @@ -0,0 +1,198 @@ +#nullable enable +#pragma warning disable IDE0078 // unnecessary suppression is necessary +#pragma warning disable CS9270 // SDK-dependent change to interceptors usage +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + private static global::CustomClassTypeHandler? __handler1; + private static global::CustomClassTypeHandler __Handler1 => __handler1 ??= new global::CustomClassTypeHandler(); + #pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 20, 24)] + internal static global::System.Collections.Generic.IEnumerable Query0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, Buffered, StoredProcedure, BindResultsByName + // returns data: global::Foo.MyType + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), DefaultCommandFactory).QueryBuffered(param, RowFactory0.Instance); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 21, 24)] + internal static global::System.Collections.Generic.IEnumerable Query1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, StoredProcedure, KnownParameters + // takes parameter: + // parameter map: Param + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QueryBuffered(param, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 22, 24)] + internal static global::System.Collections.Generic.IEnumerable Query2(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, Text, KnownParameters + // takes parameter: global::Foo.CommandParameters + // parameter map: OutputValue + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).QueryBuffered((global::Foo.CommandParameters)param!, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 3859557458U when NormalizedEquals(name, "c"): + token = 0; // two tokens for right-typed and type-flexible + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::Foo.MyType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + global::Foo.MyType result = new(); + foreach (var token in tokens) + { + switch (token) + { + case 0: + result.C = reader.IsDBNull(columnOffset) ? (global::CustomClass?)null : __Handler1.Read(reader, columnOffset); + break; + + } + columnOffset++; + + } + return result; + + } + + } + + private sealed class CommandFactory0 : CommonCommandFactory // + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Param"; + p.Direction = global::System.Data.ParameterDirection.Input; + __Handler1.SetValue((global::System.Data.Common.DbParameter)p, typed.Param); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + __Handler1.SetValue((global::System.Data.Common.DbParameter)ps[0], typed.Param); + + } + + } + + private sealed class CommandFactory1 : CommonCommandFactory + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "OutputValue"; + p.Direction = global::System.Data.ParameterDirection.Output; + p.Value = global::System.DBNull.Value; + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + ps[0].Value = global::System.DBNull.Value; + + } + public override bool RequirePostProcess => true; + + public override void PostProcess(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args, int rowCount) + { + var ps = cmd.Parameters; + args.OutputValue = __Handler1.Parse((global::System.Data.Common.DbParameter)ps[0]); + base.PostProcess(in cmd, args, rowCount); + + } + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.cs b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.cs new file mode 100644 index 00000000..ce611232 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.cs @@ -0,0 +1,198 @@ +#nullable enable +#pragma warning disable IDE0078 // unnecessary suppression is necessary +#pragma warning disable CS9270 // SDK-dependent change to interceptors usage +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + private static global::CustomClassTypeHandler? __handler1; + private static global::CustomClassTypeHandler __Handler1 => __handler1 ??= new global::CustomClassTypeHandler(); + #pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 20, 24)] + internal static global::System.Collections.Generic.IEnumerable Query0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, Buffered, StoredProcedure, BindResultsByName + // returns data: global::Foo.MyType + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), DefaultCommandFactory).QueryBuffered(param, RowFactory0.Instance); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 21, 24)] + internal static global::System.Collections.Generic.IEnumerable Query1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, StoredProcedure, KnownParameters + // takes parameter: + // parameter map: Param + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QueryBuffered(param, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 22, 24)] + internal static global::System.Collections.Generic.IEnumerable Query2(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, Text, KnownParameters + // takes parameter: global::Foo.CommandParameters + // parameter map: OutputValue + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).QueryBuffered((global::Foo.CommandParameters)param!, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 3859557458U when NormalizedEquals(name, "c"): + token = 0; // two tokens for right-typed and type-flexible + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::Foo.MyType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + global::Foo.MyType result = new(); + foreach (var token in tokens) + { + switch (token) + { + case 0: + result.C = reader.IsDBNull(columnOffset) ? (global::CustomClass?)null : __Handler1.Read(reader, columnOffset); + break; + + } + columnOffset++; + + } + return result; + + } + + } + + private sealed class CommandFactory0 : CommonCommandFactory // + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Param"; + p.Direction = global::System.Data.ParameterDirection.Input; + __Handler1.SetValue((global::System.Data.Common.DbParameter)p, typed.Param); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + __Handler1.SetValue((global::System.Data.Common.DbParameter)ps[0], typed.Param); + + } + + } + + private sealed class CommandFactory1 : CommonCommandFactory + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "OutputValue"; + p.Direction = global::System.Data.ParameterDirection.Output; + p.Value = global::System.DBNull.Value; + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + ps[0].Value = global::System.DBNull.Value; + + } + public override bool RequirePostProcess => true; + + public override void PostProcess(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args, int rowCount) + { + var ps = cmd.Parameters; + args.OutputValue = __Handler1.Parse((global::System.Data.Common.DbParameter)ps[0]); + base.PostProcess(in cmd, args, rowCount); + + } + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.txt b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.txt new file mode 100644 index 00000000..d4a5c195 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 3 of 3 possible call-sites using 3 interceptors, 2 commands and 1 readers diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.txt b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.txt new file mode 100644 index 00000000..d4a5c195 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 3 of 3 possible call-sites using 3 interceptors, 2 commands and 1 readers diff --git a/test/Dapper.AOT.Test/TestCommon/GeneratorWrapper.cs b/test/Dapper.AOT.Test/TestCommon/GeneratorWrapper.cs index bc877747..25004d90 100644 --- a/test/Dapper.AOT.Test/TestCommon/GeneratorWrapper.cs +++ b/test/Dapper.AOT.Test/TestCommon/GeneratorWrapper.cs @@ -1,8 +1,8 @@ -using Dapper.CodeAnalysis; +using System.Collections.Concurrent; +using System.Collections.Immutable; +using Dapper.CodeAnalysis; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Diagnostics; -using System.Collections.Concurrent; -using System.Collections.Immutable; using static Dapper.CodeAnalysis.DapperInterceptorGenerator; namespace Dapper.AOT.Test.TestCommon; @@ -39,9 +39,13 @@ public GenerationState(DapperInterceptorGenerator inner) [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0028:Simplify collection initialization", Justification = "This is fine")] private readonly ConcurrentBag _bag = new(); public void OnCompilationEnd(CompilationAnalysisContext context) - => inner.Generate(new GenerateState(GenerateContextProxy.Create(context, _bag.ToImmutableArray()))); - + { + var (registry, identifiedTypeHandlers) = InitTypeHandlers(context.ReportDiagnostic, context.Compilation); + var proxy = GenerateContextProxy.Create(context, [.. _bag], registry, identifiedTypeHandlers); + inner.Generate(new GenerateState(proxy)); + } + public void OnOperation(OperationAnalysisContext context) { if (!inner.PreFilter(context.Operation.Syntax, context.CancellationToken)) return; @@ -54,5 +58,4 @@ public void OnOperation(OperationAnalysisContext context) } } - }