@@ -28,7 +28,7 @@ public void Initialize (IncrementalGeneratorInitializationContext context)
2828 private void Execute ( SourceProductionContext context , ( Compilation Left , ImmutableArray < ClassDeclarationSyntax > Right ) arg2 )
2929 {
3030 INamedTypeSymbol assertType = arg2 . Left . GetTypeByMetadataName ( "Xunit.Assert" )
31- ?? throw new NotSupportedException ( "Referencing codebase does not include Xunit, could not find Xunit.Assert" ) ;
31+ ?? throw new NotSupportedException ( "Referencing codebase does not include Xunit, could not find Xunit.Assert" ) ;
3232
3333 GenerateMethods ( assertType , context , "Equal" , false ) ;
3434
@@ -70,7 +70,7 @@ private void Execute (SourceProductionContext context, (Compilation Left, Immuta
7070 GenerateMethods ( assertType , context , "Subset" , true ) ;
7171 GenerateMethods ( assertType , context , "Superset" , true ) ;
7272
73- // GenerateMethods (assertType, context, "Throws", true);
73+ // GenerateMethods (assertType, context, "Throws", true);
7474 // GenerateMethods (assertType, context, "ThrowsAny", true);
7575 GenerateMethods ( assertType , context , "True" , false ) ;
7676 }
@@ -208,25 +208,49 @@ out string typeParams
208208 dec = dec . WithTypeParameterList ( SyntaxFactory . TypeParameterList ( SyntaxFactory . SeparatedList ( typeParameters ) ) ) ;
209209
210210 // Handle type parameter constraints
211- List < TypeParameterConstraintClauseSyntax > constraintClauses = methodSymbol . TypeParameters
212- . Where ( tp => tp . ConstraintTypes . Length > 0 )
213- . Select (
214- tp =>
215- SyntaxFactory . TypeParameterConstraintClause ( tp . Name )
216- . WithConstraints (
217- SyntaxFactory
218- . SeparatedList < TypeParameterConstraintSyntax > (
219- tp . ConstraintTypes . Select (
220- constraintType =>
221- SyntaxFactory . TypeConstraint (
222- SyntaxFactory . ParseTypeName (
223- constraintType
224- . ToDisplayString ( ) ) )
225- )
226- )
227- )
228- )
229- . ToList ( ) ;
211+ List < TypeParameterConstraintClauseSyntax > constraintClauses = new ( ) ;
212+
213+ foreach ( ITypeParameterSymbol tp in methodSymbol . TypeParameters )
214+ {
215+ List < TypeParameterConstraintSyntax > constraints = new ( ) ;
216+
217+ // Add class/struct constraints
218+ if ( tp . HasReferenceTypeConstraint )
219+ {
220+ constraints . Add ( SyntaxFactory . ClassOrStructConstraint ( SyntaxKind . ClassConstraint ) ) ;
221+ }
222+ else if ( tp . HasValueTypeConstraint )
223+ {
224+ constraints . Add ( SyntaxFactory . ClassOrStructConstraint ( SyntaxKind . StructConstraint ) ) ;
225+ }
226+ else if ( tp . HasNotNullConstraint )
227+ {
228+ // Add notnull constraint
229+ constraints . Add ( SyntaxFactory . TypeConstraint ( SyntaxFactory . IdentifierName ( "notnull" ) ) ) ;
230+ }
231+
232+ // Add type constraints
233+ foreach ( ITypeSymbol constraintType in tp . ConstraintTypes )
234+ {
235+ constraints . Add (
236+ SyntaxFactory . TypeConstraint (
237+ SyntaxFactory . ParseTypeName ( constraintType . ToDisplayString ( ) ) ) ) ;
238+ }
239+
240+ // Add new() constraint
241+ if ( tp . HasConstructorConstraint )
242+ {
243+ constraints . Add ( SyntaxFactory . ConstructorConstraint ( ) ) ;
244+ }
245+
246+ // Only add constraint clause if there are constraints
247+ if ( constraints . Any ( ) )
248+ {
249+ constraintClauses . Add (
250+ SyntaxFactory . TypeParameterConstraintClause ( tp . Name )
251+ . WithConstraints ( SyntaxFactory . SeparatedList ( constraints ) ) ) ;
252+ }
253+ }
230254
231255 if ( constraintClauses . Any ( ) )
232256 {
@@ -281,12 +305,12 @@ private ParameterSyntax CreateParameter (IParameterSymbol p)
281305 if ( p . RefKind != RefKind . None )
282306 {
283307 SyntaxKind modifierKind = p . RefKind switch
284- {
285- RefKind . Ref => SyntaxKind . RefKeyword ,
286- RefKind . Out => SyntaxKind . OutKeyword ,
287- RefKind . In => SyntaxKind . InKeyword ,
288- _ => throw new NotSupportedException ( $ "Unsupported RefKind: { p . RefKind } ")
289- } ;
308+ {
309+ RefKind . Ref => SyntaxKind . RefKeyword ,
310+ RefKind . Out => SyntaxKind . OutKeyword ,
311+ RefKind . In => SyntaxKind . InKeyword ,
312+ _ => throw new NotSupportedException ( $ "Unsupported RefKind: { p . RefKind } ")
313+ } ;
290314
291315
292316 modifiers . Add ( SyntaxFactory . Token ( modifierKind ) ) ;
@@ -302,23 +326,23 @@ private ParameterSyntax CreateParameter (IParameterSymbol p)
302326 if ( p . HasExplicitDefaultValue )
303327 {
304328 ExpressionSyntax defaultValueExpression = p . ExplicitDefaultValue switch
305- {
306- null => SyntaxFactory . LiteralExpression ( SyntaxKind . NullLiteralExpression ) ,
307- bool b => SyntaxFactory . LiteralExpression (
308- b
309- ? SyntaxKind . TrueLiteralExpression
310- : SyntaxKind . FalseLiteralExpression ) ,
311- int i => SyntaxFactory . LiteralExpression (
312- SyntaxKind . NumericLiteralExpression ,
313- SyntaxFactory . Literal ( i ) ) ,
314- double d => SyntaxFactory . LiteralExpression (
315- SyntaxKind . NumericLiteralExpression ,
316- SyntaxFactory . Literal ( d ) ) ,
317- string s => SyntaxFactory . LiteralExpression (
318- SyntaxKind . StringLiteralExpression ,
319- SyntaxFactory . Literal ( s ) ) ,
320- _ => SyntaxFactory . ParseExpression ( p . ExplicitDefaultValue . ToString ( ) ) // Fallback
321- } ;
329+ {
330+ null => SyntaxFactory . LiteralExpression ( SyntaxKind . NullLiteralExpression ) ,
331+ bool b => SyntaxFactory . LiteralExpression (
332+ b
333+ ? SyntaxKind . TrueLiteralExpression
334+ : SyntaxKind . FalseLiteralExpression ) ,
335+ int i => SyntaxFactory . LiteralExpression (
336+ SyntaxKind . NumericLiteralExpression ,
337+ SyntaxFactory . Literal ( i ) ) ,
338+ double d => SyntaxFactory . LiteralExpression (
339+ SyntaxKind . NumericLiteralExpression ,
340+ SyntaxFactory . Literal ( d ) ) ,
341+ string s => SyntaxFactory . LiteralExpression (
342+ SyntaxKind . StringLiteralExpression ,
343+ SyntaxFactory . Literal ( s ) ) ,
344+ _ => SyntaxFactory . ParseExpression ( p . ExplicitDefaultValue . ToString ( ) ) // Fallback
345+ } ;
322346
323347 parameterSyntax = parameterSyntax . WithDefault (
324348 SyntaxFactory . EqualsValueClause ( defaultValueExpression )
@@ -330,4 +354,4 @@ private ParameterSyntax CreateParameter (IParameterSymbol p)
330354
331355 // Helper method to check if a parameter name is a reserved keyword
332356 private bool IsReservedKeyword ( string name ) { return string . Equals ( name , "object" ) ; }
333- }
357+ }
0 commit comments