1+ using System . Collections . Generic ;
2+ using System . Threading . Tasks ;
3+ using FluentAssertions . Analyzers . Utilities ;
4+ using Microsoft . CodeAnalysis ;
5+ using Microsoft . CodeAnalysis . CodeActions ;
6+ using Microsoft . CodeAnalysis . CodeFixes ;
7+ using Microsoft . CodeAnalysis . CSharp . Syntax ;
8+ using Microsoft . CodeAnalysis . Operations ;
9+ using CreateChangedDocument = System . Func < System . Threading . CancellationToken , System . Threading . Tasks . Task < Microsoft . CodeAnalysis . Document > > ;
10+
11+ namespace FluentAssertions . Analyzers ;
12+
13+ public abstract class TestingFrameworkCodeFixProvider : CodeFixProvider
14+ {
15+ protected const string Title = "Replace with FluentAssertions" ;
16+
17+ public override FixAllProvider GetFixAllProvider ( ) => WellKnownFixAllProviders . BatchFixer ;
18+
19+ public override async Task RegisterCodeFixesAsync ( CodeFixContext context )
20+ {
21+ var root = await context . Document . GetSyntaxRootAsync ( context . CancellationToken ) ;
22+ var semanticModel = await context . Document . GetSemanticModelAsync ( context . CancellationToken ) ;
23+
24+ var testContext = new TestingFrameworkCodeFixContext ( semanticModel . Compilation ) ;
25+ foreach ( var diagnostic in context . Diagnostics )
26+ {
27+ var node = root . FindNode ( diagnostic . Location . SourceSpan ) ;
28+ if ( node is not InvocationExpressionSyntax invocationExpression )
29+ {
30+ continue ;
31+ }
32+
33+ var operation = semanticModel . GetOperation ( invocationExpression , context . CancellationToken ) ;
34+ if ( operation is not IInvocationOperation invocation )
35+ {
36+ continue ;
37+ }
38+
39+ var fix = TryComputeFix ( invocation , context , testContext , diagnostic ) ;
40+ if ( fix is not null )
41+ {
42+ context . RegisterCodeFix ( CodeAction . Create ( Title , fix , equivalenceKey : Title ) , diagnostic ) ;
43+ }
44+ }
45+ }
46+
47+ protected abstract CreateChangedDocument TryComputeFix ( IInvocationOperation invocation , CodeFixContext context , TestingFrameworkCodeFixContext t , Diagnostic diagnostic ) ;
48+
49+ protected static bool ArgumentsAreTypeOf ( IInvocationOperation invocation , params ITypeSymbol [ ] types ) => ArgumentsAreTypeOf ( invocation , 0 , types ) ;
50+ protected static bool ArgumentsAreTypeOf ( IInvocationOperation invocation , int startFromIndex , params ITypeSymbol [ ] types )
51+ {
52+ if ( invocation . TargetMethod . Parameters . Length != types . Length + startFromIndex )
53+ {
54+ return false ;
55+ }
56+
57+ for ( int i = startFromIndex ; i < types . Length ; i ++ )
58+ {
59+ if ( ! invocation . TargetMethod . Parameters [ i ] . Type . EqualsSymbol ( types [ i ] ) )
60+ {
61+ return false ;
62+ }
63+ }
64+
65+ return true ;
66+ }
67+
68+ protected static bool ArgumentsAreGenericTypeOf ( IInvocationOperation invocation , params ITypeSymbol [ ] types )
69+ {
70+ const int generics = 1 ;
71+ if ( invocation . TargetMethod . Parameters . Length != types . Length )
72+ {
73+ return false ;
74+ }
75+
76+ if ( invocation . TargetMethod . TypeArguments . Length != generics )
77+ {
78+ return false ;
79+ }
80+
81+ var genericType = invocation . TargetMethod . TypeArguments [ 0 ] ;
82+
83+ for ( int i = 0 ; i < types . Length ; i ++ )
84+ {
85+ if ( invocation . TargetMethod . Parameters [ i ] . Type is not INamedTypeSymbol parameterType )
86+ {
87+ return false ;
88+ }
89+
90+ if ( parameterType . TypeArguments . IsEmpty && parameterType . EqualsSymbol ( genericType ) )
91+ {
92+ continue ;
93+ }
94+
95+ if ( parameterType . TypeArguments . Length != generics
96+ || ! ( parameterType . TypeArguments [ 0 ] . EqualsSymbol ( genericType ) && parameterType . OriginalDefinition . EqualsSymbol ( types [ i ] ) ) )
97+ {
98+ return false ;
99+ }
100+ }
101+
102+ return true ;
103+ }
104+
105+ protected static bool ArgumentsCount ( IInvocationOperation invocation , int arguments )
106+ {
107+ return invocation . TargetMethod . Parameters . Length == arguments ;
108+ }
109+
110+ protected sealed class TestingFrameworkCodeFixContext ( Compilation compilation )
111+ {
112+ public INamedTypeSymbol Object { get ; } = compilation . ObjectType ;
113+ public INamedTypeSymbol String { get ; } = compilation . GetTypeByMetadataName ( "System.String" ) ;
114+ public INamedTypeSymbol Int32 { get ; } = compilation . GetTypeByMetadataName ( "System.Int32" ) ;
115+ public INamedTypeSymbol Float { get ; } = compilation . GetTypeByMetadataName ( "System.Single" ) ;
116+ public INamedTypeSymbol Double { get ; } = compilation . GetTypeByMetadataName ( "System.Double" ) ;
117+ public INamedTypeSymbol Decimal { get ; } = compilation . GetTypeByMetadataName ( "System.Decimal" ) ;
118+ public INamedTypeSymbol Boolean { get ; } = compilation . GetTypeByMetadataName ( "System.Boolean" ) ;
119+ public INamedTypeSymbol Action { get ; } = compilation . GetTypeByMetadataName ( "System.Action" ) ;
120+ public INamedTypeSymbol Type { get ; } = compilation . GetTypeByMetadataName ( "System.Type" ) ;
121+ public INamedTypeSymbol DateTime { get ; } = compilation . GetTypeByMetadataName ( "System.DateTime" ) ;
122+ public INamedTypeSymbol TimeSpan { get ; } = compilation . GetTypeByMetadataName ( "System.TimeSpan" ) ;
123+ public INamedTypeSymbol FuncOfObject { get ; } = compilation . GetTypeByMetadataName ( "System.Func`1" ) . Construct ( compilation . ObjectType ) ;
124+ public INamedTypeSymbol FuncOfTask { get ; } = compilation . GetTypeByMetadataName ( "System.Func`1" ) . Construct ( compilation . GetTypeByMetadataName ( "System.Threading.Tasks.Task" ) ) ;
125+ public IArrayTypeSymbol ObjectArray { get ; } = compilation . CreateArrayTypeSymbol ( compilation . ObjectType ) ;
126+ public INamedTypeSymbol CultureInfo { get ; } = compilation . GetTypeByMetadataName ( "System.Globalization.CultureInfo" ) ;
127+ public INamedTypeSymbol StringComparison { get ; } = compilation . GetTypeByMetadataName ( "System.StringComparison" ) ;
128+ public INamedTypeSymbol Regex { get ; } = compilation . GetTypeByMetadataName ( "System.Text.RegularExpressions.Regex" ) ;
129+ public INamedTypeSymbol ICollection { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.ICollection" ) ;
130+ public INamedTypeSymbol IComparer { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.IComparer" ) ;
131+ public INamedTypeSymbol IEqualityComparerOfT1 { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.Generic.IEqualityComparer`1" ) ;
132+ public INamedTypeSymbol IEnumerableOfT1 { get ; } = compilation . GetTypeByMetadataName ( "System.Collections.Generic.IEnumerable`1" ) ;
133+
134+ public INamedTypeSymbol Identity { get ; } = null ;
135+ }
136+ }
0 commit comments