@@ -396,6 +396,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
396396 GenericSignature ()).getCanonicalSignature ();
397397}
398398
399+ // / Given an original type, computes its tangent type for the purpose of
400+ // / building a linear map using this type. When the original type is an
401+ // / archetype or contains a type parameter, appends a new generic parameter and
402+ // / a corresponding replacement type to the given containers.
403+ static CanType getAutoDiffTangentTypeForLinearMap (
404+ Type originalType,
405+ LookupConformanceFn lookupConformance,
406+ SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
407+ SmallVectorImpl<Type> &substReplacements,
408+ ASTContext &context
409+ ) {
410+ auto maybeTanType = originalType->getAutoDiffTangentSpace (lookupConformance);
411+ assert (maybeTanType && " Type does not have a tangent space?" );
412+ auto tanType = maybeTanType->getCanonicalType ();
413+ // If concrete, the tangent type is concrete.
414+ if (!tanType->hasArchetype () && !tanType->hasTypeParameter ())
415+ return tanType;
416+ // Otherwise, the tangent type is a new generic parameter substituted for the
417+ // tangent type.
418+ auto gpIndex = substGenericParams.size ();
419+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, context);
420+ substGenericParams.push_back (gpType);
421+ substReplacements.push_back (tanType);
422+ return gpType;
423+ }
424+
399425// / Returns the differential type for the given original function type,
400426// / parameter indices, and result index.
401427static CanSILFunctionType getAutoDiffDifferentialType (
@@ -471,45 +497,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
471497 getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
472498 SmallVector<SILParameterInfo, 8 > differentialParams;
473499 for (auto ¶m : diffParams) {
474- auto paramTan =
475- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
476- assert (paramTan && " Parameter type does not have a tangent space?" );
477- auto paramTanType = paramTan->getCanonicalType ();
478- auto paramConv = getTangentParameterConvention (paramTanType,
479- param.getConvention ());
480- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
481- differentialParams.push_back (
482- {paramTan->getCanonicalType (), paramConv});
483- } else {
484- auto gpIndex = substGenericParams.size ();
485- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
486- substGenericParams.push_back (gpType);
487- substReplacements.push_back (paramTanType);
488- differentialParams.push_back ({gpType, paramConv});
489- }
500+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
501+ param.getInterfaceType (), lookupConformance,
502+ substGenericParams, substReplacements, ctx);
503+ auto paramConv = getTangentParameterConvention (
504+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
505+ param.getInterfaceType ()
506+ ->getAutoDiffTangentSpace (lookupConformance)
507+ ->getCanonicalType (),
508+ param.getConvention ());
509+ differentialParams.push_back ({paramTanType, paramConv});
490510 }
491511 SmallVector<SILResultInfo, 1 > differentialResults;
492512 for (auto resultIndex : resultIndices->getIndices ()) {
493513 // Handle formal original result.
494514 if (resultIndex < originalFnTy->getNumResults ()) {
495515 auto &result = originalResults[resultIndex];
496- auto resultTan =
497- result.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
498- assert (resultTan && " Result type does not have a tangent space?" );
499- auto resultTanType = resultTan->getCanonicalType ();
500- auto resultConv =
501- getTangentResultConvention (resultTanType, result.getConvention ());
502- if (!resultTanType->hasArchetype () &&
503- !resultTanType->hasTypeParameter ()) {
504- differentialResults.push_back (
505- {resultTan->getCanonicalType (), resultConv});
506- } else {
507- auto gpIndex = substGenericParams.size ();
508- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
509- substGenericParams.push_back (gpType);
510- substReplacements.push_back (resultTanType);
511- differentialResults.push_back ({gpType, resultConv});
512- }
516+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
517+ result.getInterfaceType (), lookupConformance,
518+ substGenericParams, substReplacements, ctx);
519+ auto resultConv = getTangentResultConvention (
520+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
521+ result.getInterfaceType ()
522+ ->getAutoDiffTangentSpace (lookupConformance)
523+ ->getCanonicalType (),
524+ result.getConvention ());
525+ differentialResults.push_back ({resultTanType, resultConv});
513526 continue ;
514527 }
515528 // Handle original `inout` parameter.
@@ -524,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
524537 if (parameterIndices->contains (paramIndex))
525538 continue ;
526539 auto inoutParam = originalFnTy->getParameters ()[paramIndex];
527- auto paramTan = inoutParam. getInterfaceType ()-> getAutoDiffTangentSpace (
528- lookupConformance);
529- assert (paramTan && " Parameter type does not have a tangent space? " );
540+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
541+ inoutParam. getInterfaceType (), lookupConformance,
542+ substGenericParams, substReplacements, ctx );
530543 differentialResults.push_back (
531- {paramTan-> getCanonicalType () , ResultConvention::Indirect});
544+ {inoutParamTanType , ResultConvention::Indirect});
532545 }
533546
534547 SubstitutionMap substitutions;
@@ -635,23 +648,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
635648 // Handle formal original result.
636649 if (resultIndex < originalFnTy->getNumResults ()) {
637650 auto &origRes = originalResults[resultIndex];
638- auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
639- lookupConformance);
640- assert (resultTan && " Result type does not have a tangent space?" );
641- auto resultTanType = resultTan->getCanonicalType ();
642- auto paramTanConvention = getTangentParameterConventionForOriginalResult (
643- resultTanType, origRes.getConvention ());
644- if (!resultTanType->hasArchetype () &&
645- !resultTanType->hasTypeParameter ()) {
646- auto resultTanType = resultTan->getCanonicalType ();
647- pullbackParams.push_back ({resultTanType, paramTanConvention});
648- } else {
649- auto gpIndex = substGenericParams.size ();
650- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
651- substGenericParams.push_back (gpType);
652- substReplacements.push_back (resultTanType);
653- pullbackParams.push_back ({gpType, paramTanConvention});
654- }
651+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
652+ origRes.getInterfaceType (), lookupConformance,
653+ substGenericParams, substReplacements, ctx);
654+ auto paramConv = getTangentParameterConventionForOriginalResult (
655+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
656+ origRes.getInterfaceType ()
657+ ->getAutoDiffTangentSpace (lookupConformance)
658+ ->getCanonicalType (),
659+ origRes.getConvention ());
660+ pullbackParams.push_back ({resultTanType, paramConv});
655661 continue ;
656662 }
657663 // Handle original `inout` parameter.
@@ -661,28 +667,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
661667 auto paramIndex =
662668 std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
663669 auto inoutParam = originalFnTy->getParameters ()[paramIndex];
664- auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
665- lookupConformance);
666- assert (paramTan && " Parameter type does not have a tangent space?" );
667670 // The pullback parameter convention depends on whether the original `inout`
668671 // paramater is a differentiability parameter.
669672 // - If yes, the pullback parameter convention is `@inout`.
670673 // - If no, the pullback parameter convention is `@in_guaranteed`.
674+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
675+ inoutParam.getInterfaceType (), lookupConformance,
676+ substGenericParams, substReplacements, ctx);
671677 bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
672678 auto paramTanConvention = isWrtInoutParameter
673- ? inoutParam.getConvention ()
674- : ParameterConvention::Indirect_In_Guaranteed;
675- auto paramTanType = paramTan->getCanonicalType ();
676- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
677- pullbackParams.push_back (
678- SILParameterInfo (paramTanType, paramTanConvention));
679- } else {
680- auto gpIndex = substGenericParams.size ();
681- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
682- substGenericParams.push_back (gpType);
683- substReplacements.push_back (paramTanType);
684- pullbackParams.push_back ({gpType, paramTanConvention});
685- }
679+ ? inoutParam.getConvention ()
680+ : ParameterConvention::Indirect_In_Guaranteed;
681+ pullbackParams.push_back ({inoutParamTanType, paramTanConvention});
686682 }
687683
688684 // Collect pullback results.
@@ -694,21 +690,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
694690 // and always appear as pullback parameters.
695691 if (param.isIndirectInOut ())
696692 continue ;
697- auto paramTan =
698- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
699- assert (paramTan && " Parameter type does not have a tangent space?" );
700- auto paramTanType = paramTan->getCanonicalType ();
693+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
694+ param.getInterfaceType (), lookupConformance,
695+ substGenericParams, substReplacements, ctx);
701696 auto resultTanConvention = getTangentResultConventionForOriginalParameter (
702- paramTanType, param.getConvention ());
703- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
704- pullbackResults.push_back ({paramTanType, resultTanConvention});
705- } else {
706- auto gpIndex = substGenericParams.size ();
707- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
708- substGenericParams.push_back (gpType);
709- substReplacements.push_back (paramTanType);
710- pullbackResults.push_back ({gpType, resultTanConvention});
711- }
697+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
698+ param.getInterfaceType ()
699+ ->getAutoDiffTangentSpace (lookupConformance)
700+ ->getCanonicalType (),
701+ param.getConvention ());
702+ pullbackResults.push_back ({paramTanType, resultTanConvention});
712703 }
713704 SubstitutionMap substitutions;
714705 if (!substGenericParams.empty ()) {
0 commit comments