@@ -441,6 +441,8 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
441441 isAutoClosure ? CTP_AutoclosureDefaultParameter : CTP_DefaultParameter,
442442 paramType, /* isDiscarded=*/ false );
443443
444+ auto paramInterfaceTy = paramType->mapTypeOutOfContext ();
445+
444446 {
445447 // Buffer all of the diagnostics produced by \c typeCheckExpression
446448 // since in some cases we need to try type-checking again with a
@@ -459,6 +461,11 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
459461 if (!ctx.TypeCheckerOpts .EnableTypeInferenceFromDefaultArguments )
460462 return Type ();
461463
464+ // Parameter type doesn't have any generic parameters mentioned
465+ // in it, so there is nothing to infer.
466+ if (!paramInterfaceTy->hasTypeParameter ())
467+ return Type ();
468+
462469 // Ignore any diagnostics emitted by the original type-check.
463470 diagnostics.abort ();
464471 }
@@ -475,40 +482,76 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
475482 // If both of aforementioned conditions are true, let's attempt
476483 // to open generic parameter and infer the type of this default
477484 // expression.
478- auto interfaceType = paramType->mapTypeOutOfContext ();
479- if (!interfaceType->isTypeParameter ())
480- return Type ();
485+ OpenedTypeMap genericParameters;
486+
487+ ConstraintSystemOptions options;
488+ options |= ConstraintSystemFlags::AllowFixes;
489+
490+ ConstraintSystem cs (DC, options);
491+
492+ auto *locator = cs.getConstraintLocator (
493+ defaultValue, LocatorPathElt::ContextualType (
494+ defaultExprTarget.getExprContextualTypePurpose ()));
481495
482- auto containsType = [&](Type type, Type contained) {
483- return type.findIf (
484- [&contained](Type nested) { return nested->isEqual (contained); });
496+ auto getCanonicalGenericParamTy = [](GenericTypeParamType *GP) {
497+ return cast<GenericTypeParamType>(GP->getCanonicalType ());
485498 };
486499
487- // Anchor of this default expression.
500+ // Find and open all of the generic parameters used by the parameter
501+ // and replace them with type variables.
502+ auto contextualTy = paramInterfaceTy.transform ([&](Type type) -> Type {
503+ assert (!type->is <UnboundGenericType>());
504+
505+ if (auto *GP = type->getAs <GenericTypeParamType>()) {
506+ return cs.openGenericParameter (DC->getParent (), GP, genericParameters,
507+ locator);
508+ }
509+ return type;
510+ });
511+
512+ auto containsTypes = [&](Type type, OpenedTypeMap &toFind) {
513+ return type.findIf ([&](Type nested) {
514+ if (auto *GP = nested->getAs <GenericTypeParamType>())
515+ return toFind.count (getCanonicalGenericParamTy (GP)) > 0 ;
516+ return false ;
517+ });
518+ };
519+
520+ auto containsGenericParamsExcluding = [&](Type type,
521+ OpenedTypeMap &exclusions) -> bool {
522+ return type.findIf ([&](Type type) {
523+ if (auto *GP = type->getAs <GenericTypeParamType>())
524+ return !exclusions.count (getCanonicalGenericParamTy (GP));
525+ return false ;
526+ });
527+ };
528+
529+ // Anchor of this default expression i.e. function, subscript
530+ // or enum case.
488531 auto *anchor = cast<ValueDecl>(DC->getParent ()->getAsDecl ());
489532
490- // Check whether generic parameter is only mentioned once in
533+ // Check whether generic parameters are only mentioned once in
491534 // the anchor's signature.
492535 {
493536 auto anchorTy = anchor->getInterfaceType ()->castTo <GenericFunctionType>();
494537
495- // Reject if generic parameter could be inferred from result type.
496- if (containsType (anchorTy->getResult (), interfaceType )) {
538+ // Reject if generic parameters could be inferred from result type.
539+ if (containsTypes (anchorTy->getResult (), genericParameters )) {
497540 ctx.Diags .diagnose (
498541 defaultValue->getLoc (),
499542 diag::cannot_default_generic_parameter_inferrable_from_result,
500- interfaceType );
543+ paramInterfaceTy );
501544 return Type ();
502545 }
503546
504- // Reject if generic parameter is used in multiple different positions
547+ // Reject if generic parameters are used in multiple different positions
505548 // in the parameter list.
506549
507550 llvm::SmallVector<unsigned , 2 > affectedParams;
508551 for (unsigned i : indices (anchorTy->getParams ())) {
509552 const auto ¶m = anchorTy->getParams ()[i];
510553
511- if (containsType (param.getPlainType (), interfaceType ))
554+ if (containsTypes (param.getPlainType (), genericParameters ))
512555 affectedParams.push_back (i);
513556 }
514557
@@ -524,27 +567,14 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
524567 defaultValue->getLoc (),
525568 diag::
526569 cannot_default_generic_parameter_inferrable_from_another_parameter,
527- interfaceType , params.str ());
570+ paramInterfaceTy , params.str ());
528571 return Type ();
529572 }
530573 }
531574
532575 auto signature = DC->getGenericSignatureOfContext ();
533576 assert (signature && " generic parameter without signature?" );
534577
535- ConstraintSystemOptions options;
536- options |= ConstraintSystemFlags::AllowFixes;
537-
538- ConstraintSystem cs (DC, options);
539-
540- auto *locator = cs.getConstraintLocator (
541- defaultValue, LocatorPathElt::ContextualType (
542- defaultExprTarget.getExprContextualTypePurpose ()));
543-
544- // A replacement for generic parameter type to associate any generic
545- // requirements with.
546- auto *contextualTy = cs.createTypeVariable (locator, /* flags=*/ 0 );
547-
548578 auto *requirementBaseLocator = cs.getConstraintLocator (
549579 locator, LocatorPathElt::OpenedGeneric (signature));
550580
@@ -553,76 +583,84 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
553583 // a dependent member type), that means it could be inferred through
554584 // them e.g. `T: X.Y` or `T == U`.
555585 {
556- auto isViable = [](Type type) {
557- return !(type->hasTypeParameter () && type->hasDependentMember ());
558- };
559-
560586 auto recordRequirement = [&](unsigned index, Requirement requirement,
561587 ConstraintLocator *locator) {
562588 cs.openGenericRequirement (DC->getParent (), index, requirement,
563589 /* skipSelfProtocolConstraint=*/ false , locator,
564- [](Type type) -> Type { return type; });
590+ [&](Type type) -> Type {
591+ return cs.openType (type, genericParameters);
592+ });
593+ };
594+
595+ auto diagnoseInvalidRequirement = [&](Requirement requirement) {
596+ SmallString<32 > reqBuf;
597+ llvm::raw_svector_ostream req (reqBuf);
598+
599+ requirement.print (req, PrintOptions ());
600+
601+ ctx.Diags .diagnose (
602+ defaultValue->getLoc (),
603+ diag::cannot_default_generic_parameter_invalid_requirement,
604+ paramInterfaceTy, req.str ());
565605 };
566606
567607 auto requirements = signature.getRequirements ();
568608 for (unsigned reqIdx = 0 ; reqIdx != requirements.size (); ++reqIdx) {
569609 auto &requirement = requirements[reqIdx];
570610
571611 switch (requirement.getKind ()) {
572- case RequirementKind::Conformance: {
573- if (!requirement.getFirstType ()->isEqual (interfaceType))
574- continue ;
575-
576- recordRequirement (reqIdx,
577- {RequirementKind::Conformance, contextualTy,
578- requirement.getSecondType ()},
579- requirementBaseLocator);
580- break ;
581- }
612+ case RequirementKind::SameType: {
613+ auto lhsTy = requirement.getFirstType ();
614+ auto rhsTy = requirement.getSecondType ();
582615
583- case RequirementKind::Superclass: {
584- auto subclassTy = requirement.getFirstType ();
585- auto superclassTy = requirement.getSecondType ();
616+ // Unrelated requirement.
617+ if (!containsTypes (lhsTy, genericParameters) &&
618+ !containsTypes (rhsTy, genericParameters))
619+ continue ;
586620
587- if (subclassTy->isEqual (interfaceType) && isViable (superclassTy)) {
588- recordRequirement (
589- reqIdx, {RequirementKind::Superclass, contextualTy, superclassTy},
590- requirementBaseLocator);
621+ // Allow a subset of generic same-type requirements that only mention
622+ // "in scope" generic parameters e.g. `T.X == Int` or `T == U.Z`
623+ if (!containsGenericParamsExcluding (lhsTy, genericParameters) &&
624+ !containsGenericParamsExcluding (rhsTy, genericParameters)) {
625+ recordRequirement (reqIdx, requirement, requirementBaseLocator);
626+ continue ;
591627 }
592628
593- break ;
594- }
595-
596- case RequirementKind::SameType: {
597- // If there is a same-type constraint that involves our parameter
598- // type, fail the type-check since the type could be inferred
599- // through other positions.
600- if (containsType (requirement.getFirstType (), interfaceType) ||
601- containsType (requirement.getSecondType (), interfaceType)) {
602- SmallString<32 > reqBuf;
603- llvm::raw_svector_ostream req (reqBuf);
604-
605- requirement.print (req, PrintOptions ());
606-
607- ctx.Diags .diagnose (
608- defaultValue->getLoc (),
609- diag::
610- cannot_default_generic_parameter_inferrable_through_same_type,
611- interfaceType, req.str ());
629+ // If there is a same-type constraint that involves out of scope
630+ // generic parameters mixed with in-scope ones, fail the type-check
631+ // since the type could be inferred through other positions.
632+ {
633+ diagnoseInvalidRequirement (requirement);
612634 return Type ();
613635 }
614-
615- continue ;
616636 }
617637
638+ case RequirementKind::Conformance:
639+ case RequirementKind::Superclass:
618640 case RequirementKind::Layout:
619- if (!requirement.getFirstType ()->isEqual (interfaceType))
641+ auto adheringTy = requirement.getFirstType ();
642+
643+ // Unrelated requirement.
644+ if (!containsTypes (adheringTy, genericParameters))
620645 continue ;
621646
622- recordRequirement (reqIdx,
623- {RequirementKind::Layout, contextualTy,
624- requirement.getLayoutConstraint ()},
625- requirementBaseLocator);
647+ // If adhering type has a mix or in- and out-of-scope parameters
648+ // mentioned we need to diagnose.
649+ if (containsGenericParamsExcluding (adheringTy, genericParameters)) {
650+ diagnoseInvalidRequirement (requirement);
651+ return Type ();
652+ }
653+
654+ if (requirement.getKind () == RequirementKind::Superclass) {
655+ auto superclassTy = requirement.getSecondType ();
656+
657+ if (containsGenericParamsExcluding (superclassTy, genericParameters)) {
658+ diagnoseInvalidRequirement (requirement);
659+ return Type ();
660+ }
661+ }
662+
663+ recordRequirement (reqIdx, requirement, requirementBaseLocator);
626664 break ;
627665 }
628666 }
0 commit comments