@@ -362,7 +362,8 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
362362}
363363
364364static CanGenericSignature buildDifferentiableGenericSignature (CanGenericSignature sig,
365- CanType tanType) {
365+ CanType tanType,
366+ CanType origTypeOfAbstraction) {
366367 if (!sig)
367368 return sig;
368369
@@ -390,6 +391,20 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
390391 }
391392 }
392393
394+ if (origTypeOfAbstraction) {
395+ (void ) origTypeOfAbstraction.findIf ([&](Type t) -> bool {
396+ if (auto *at = t->getAs <ArchetypeType>()) {
397+ types.insert (at->getInterfaceType ()->getCanonicalType ());
398+ for (auto *proto : at->getConformsTo ()) {
399+ reqs.push_back (Requirement (RequirementKind::Conformance,
400+ at->getInterfaceType (),
401+ proto->getDeclaredInterfaceType ()));
402+ }
403+ }
404+ return false ;
405+ });
406+ }
407+
393408 return evaluateOrDefault (
394409 ctx.evaluator ,
395410 AbstractGenericSignatureRequest{sig.getPointer (), {}, reqs},
@@ -427,14 +442,15 @@ static CanType getAutoDiffTangentTypeForLinearMap(
427442static CanSILFunctionType getAutoDiffDifferentialType (
428443 SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
429444 IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
445+ CanType origTypeOfAbstraction,
430446 TypeConverter &TC) {
431447 // Given the tangent type and the corresponding original parameter's
432448 // convention, returns the tangent parameter's convention.
433449 auto getTangentParameterConvention =
434450 [&](CanType tanType,
435451 ParameterConvention origParamConv) -> ParameterConvention {
436452 auto sig = buildDifferentiableGenericSignature (
437- originalFnTy->getSubstGenericSignature (), tanType);
453+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
438454
439455 tanType = tanType->getCanonicalType (sig);
440456 AbstractionPattern pattern (sig, tanType);
@@ -462,7 +478,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
462478 [&](CanType tanType,
463479 ResultConvention origResConv) -> ResultConvention {
464480 auto sig = buildDifferentiableGenericSignature (
465- originalFnTy->getSubstGenericSignature (), tanType);
481+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
466482
467483 tanType = tanType->getCanonicalType (sig);
468484 AbstractionPattern pattern (sig, tanType);
@@ -565,7 +581,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
565581static CanSILFunctionType getAutoDiffPullbackType (
566582 SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
567583 IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
568- TypeConverter &TC) {
584+ CanType origTypeOfAbstraction, TypeConverter &TC) {
569585 auto &ctx = originalFnTy->getASTContext ();
570586 SmallVector<GenericTypeParamType *, 4 > substGenericParams;
571587 SmallVector<Requirement, 4 > substRequirements;
@@ -582,7 +598,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
582598 [&](CanType tanType,
583599 ResultConvention origResConv) -> ParameterConvention {
584600 auto sig = buildDifferentiableGenericSignature (
585- originalFnTy->getSubstGenericSignature (), tanType);
601+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
586602
587603 tanType = tanType->getCanonicalType (sig);
588604 AbstractionPattern pattern (sig, tanType);
@@ -613,7 +629,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
613629 [&](CanType tanType,
614630 ParameterConvention origParamConv) -> ResultConvention {
615631 auto sig = buildDifferentiableGenericSignature (
616- originalFnTy->getSubstGenericSignature (), tanType);
632+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
617633
618634 tanType = tanType->getCanonicalType (sig);
619635 AbstractionPattern pattern (sig, tanType);
@@ -780,7 +796,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
780796 AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
781797 LookupConformanceFn lookupConformance,
782798 CanGenericSignature derivativeFnInvocationGenSig,
783- bool isReabstractionThunk) {
799+ bool isReabstractionThunk,
800+ CanType origTypeOfAbstraction) {
784801 assert (parameterIndices);
785802 assert (!parameterIndices->isEmpty () && " Parameter indices must not be empty" );
786803 assert (resultIndices);
@@ -810,12 +827,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
810827 case AutoDiffDerivativeFunctionKind::JVP:
811828 closureType =
812829 getAutoDiffDifferentialType (constrainedOriginalFnTy, parameterIndices,
813- resultIndices, lookupConformance, TC);
830+ resultIndices, lookupConformance,
831+ origTypeOfAbstraction, TC);
814832 break ;
815833 case AutoDiffDerivativeFunctionKind::VJP:
816834 closureType =
817835 getAutoDiffPullbackType (constrainedOriginalFnTy, parameterIndices,
818- resultIndices, lookupConformance, TC);
836+ resultIndices, lookupConformance,
837+ origTypeOfAbstraction, TC);
819838 break ;
820839 }
821840 // Compute the derivative function parameters.
0 commit comments