@@ -408,6 +408,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
408408 return buildGenericSignature (ctx, sig, {}, reqs).getCanonicalSignature ();
409409}
410410
411+ // / Given an original type, computes its tangent type for the purpose of
412+ // / building a linear map using this type. When the original type is an
413+ // / archetype or contains a type parameter, appends a new generic parameter and
414+ // / a corresponding replacement type to the given containers.
415+ static CanType getAutoDiffTangentTypeForLinearMap (
416+ Type originalType,
417+ LookupConformanceFn lookupConformance,
418+ SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
419+ SmallVectorImpl<Type> &substReplacements,
420+ ASTContext &context
421+ ) {
422+ auto maybeTanType = originalType->getAutoDiffTangentSpace (lookupConformance);
423+ assert (maybeTanType && " Type does not have a tangent space?" );
424+ auto tanType = maybeTanType->getCanonicalType ();
425+ // If concrete, the tangent type is concrete.
426+ if (!tanType->hasArchetype () && !tanType->hasTypeParameter ())
427+ return tanType;
428+ // Otherwise, the tangent type is a new generic parameter substituted for the
429+ // tangent type.
430+ auto gpIndex = substGenericParams.size ();
431+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, context);
432+ substGenericParams.push_back (gpType);
433+ substReplacements.push_back (tanType);
434+ return gpType;
435+ }
436+
411437// / Returns the differential type for the given original function type,
412438// / parameter indices, and result index.
413439static CanSILFunctionType getAutoDiffDifferentialType (
@@ -484,45 +510,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
484510 getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
485511 SmallVector<SILParameterInfo, 8 > differentialParams;
486512 for (auto ¶m : diffParams) {
487- auto paramTan =
488- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
489- assert (paramTan && " Parameter type does not have a tangent space?" );
490- auto paramTanType = paramTan->getCanonicalType ();
491- auto paramConv = getTangentParameterConvention (paramTanType,
492- param.getConvention ());
493- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
494- differentialParams.push_back (
495- {paramTan->getCanonicalType (), paramConv});
496- } else {
497- auto gpIndex = substGenericParams.size ();
498- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
499- substGenericParams.push_back (gpType);
500- substReplacements.push_back (paramTanType);
501- differentialParams.push_back ({gpType, paramConv});
502- }
513+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
514+ param.getInterfaceType (), lookupConformance,
515+ substGenericParams, substReplacements, ctx);
516+ auto paramConv = getTangentParameterConvention (
517+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
518+ param.getInterfaceType ()
519+ ->getAutoDiffTangentSpace (lookupConformance)
520+ ->getCanonicalType (),
521+ param.getConvention ());
522+ differentialParams.push_back ({paramTanType, paramConv});
503523 }
504524 SmallVector<SILResultInfo, 1 > differentialResults;
505525 for (auto resultIndex : resultIndices->getIndices ()) {
506526 // Handle formal original result.
507527 if (resultIndex < originalFnTy->getNumResults ()) {
508528 auto &result = originalResults[resultIndex];
509- auto resultTan =
510- result.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
511- assert (resultTan && " Result type does not have a tangent space?" );
512- auto resultTanType = resultTan->getCanonicalType ();
513- auto resultConv =
514- getTangentResultConvention (resultTanType, result.getConvention ());
515- if (!resultTanType->hasArchetype () &&
516- !resultTanType->hasTypeParameter ()) {
517- differentialResults.push_back (
518- {resultTan->getCanonicalType (), resultConv});
519- } else {
520- auto gpIndex = substGenericParams.size ();
521- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
522- substGenericParams.push_back (gpType);
523- substReplacements.push_back (resultTanType);
524- differentialResults.push_back ({gpType, resultConv});
525- }
529+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
530+ result.getInterfaceType (), lookupConformance,
531+ substGenericParams, substReplacements, ctx);
532+ auto resultConv = getTangentResultConvention (
533+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
534+ result.getInterfaceType ()
535+ ->getAutoDiffTangentSpace (lookupConformance)
536+ ->getCanonicalType (),
537+ result.getConvention ());
538+ differentialResults.push_back ({resultTanType, resultConv});
526539 continue ;
527540 }
528541 // Handle original `inout` parameter.
@@ -537,11 +550,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
537550 if (parameterIndices->contains (paramIndex))
538551 continue ;
539552 auto inoutParam = originalFnTy->getParameters ()[paramIndex];
540- auto paramTan = inoutParam. getInterfaceType ()-> getAutoDiffTangentSpace (
541- lookupConformance);
542- assert (paramTan && " Parameter type does not have a tangent space? " );
553+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
554+ inoutParam. getInterfaceType (), lookupConformance,
555+ substGenericParams, substReplacements, ctx );
543556 differentialResults.push_back (
544- {paramTan-> getCanonicalType () , ResultConvention::Indirect});
557+ {inoutParamTanType , ResultConvention::Indirect});
545558 }
546559
547560 SubstitutionMap substitutions;
@@ -648,23 +661,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
648661 // Handle formal original result.
649662 if (resultIndex < originalFnTy->getNumResults ()) {
650663 auto &origRes = originalResults[resultIndex];
651- auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
652- lookupConformance);
653- assert (resultTan && " Result type does not have a tangent space?" );
654- auto resultTanType = resultTan->getCanonicalType ();
655- auto paramTanConvention = getTangentParameterConventionForOriginalResult (
656- resultTanType, origRes.getConvention ());
657- if (!resultTanType->hasArchetype () &&
658- !resultTanType->hasTypeParameter ()) {
659- auto resultTanType = resultTan->getCanonicalType ();
660- pullbackParams.push_back ({resultTanType, paramTanConvention});
661- } else {
662- auto gpIndex = substGenericParams.size ();
663- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
664- substGenericParams.push_back (gpType);
665- substReplacements.push_back (resultTanType);
666- pullbackParams.push_back ({gpType, paramTanConvention});
667- }
664+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
665+ origRes.getInterfaceType (), lookupConformance,
666+ substGenericParams, substReplacements, ctx);
667+ auto paramConv = getTangentParameterConventionForOriginalResult (
668+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
669+ origRes.getInterfaceType ()
670+ ->getAutoDiffTangentSpace (lookupConformance)
671+ ->getCanonicalType (),
672+ origRes.getConvention ());
673+ pullbackParams.push_back ({resultTanType, paramConv});
668674 continue ;
669675 }
670676 // Handle original `inout` parameter.
@@ -674,28 +680,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
674680 auto paramIndex =
675681 std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
676682 auto inoutParam = originalFnTy->getParameters ()[paramIndex];
677- auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
678- lookupConformance);
679- assert (paramTan && " Parameter type does not have a tangent space?" );
680683 // The pullback parameter convention depends on whether the original `inout`
681684 // paramater is a differentiability parameter.
682685 // - If yes, the pullback parameter convention is `@inout`.
683686 // - If no, the pullback parameter convention is `@in_guaranteed`.
687+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
688+ inoutParam.getInterfaceType (), lookupConformance,
689+ substGenericParams, substReplacements, ctx);
684690 bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
685691 auto paramTanConvention = isWrtInoutParameter
686- ? inoutParam.getConvention ()
687- : ParameterConvention::Indirect_In_Guaranteed;
688- auto paramTanType = paramTan->getCanonicalType ();
689- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
690- pullbackParams.push_back (
691- SILParameterInfo (paramTanType, paramTanConvention));
692- } else {
693- auto gpIndex = substGenericParams.size ();
694- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
695- substGenericParams.push_back (gpType);
696- substReplacements.push_back (paramTanType);
697- pullbackParams.push_back ({gpType, paramTanConvention});
698- }
692+ ? inoutParam.getConvention ()
693+ : ParameterConvention::Indirect_In_Guaranteed;
694+ pullbackParams.push_back ({inoutParamTanType, paramTanConvention});
699695 }
700696
701697 // Collect pullback results.
@@ -707,21 +703,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
707703 // and always appear as pullback parameters.
708704 if (param.isIndirectInOut ())
709705 continue ;
710- auto paramTan =
711- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
712- assert (paramTan && " Parameter type does not have a tangent space?" );
713- auto paramTanType = paramTan->getCanonicalType ();
706+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
707+ param.getInterfaceType (), lookupConformance,
708+ substGenericParams, substReplacements, ctx);
714709 auto resultTanConvention = getTangentResultConventionForOriginalParameter (
715- paramTanType, param.getConvention ());
716- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
717- pullbackResults.push_back ({paramTanType, resultTanConvention});
718- } else {
719- auto gpIndex = substGenericParams.size ();
720- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
721- substGenericParams.push_back (gpType);
722- substReplacements.push_back (paramTanType);
723- pullbackResults.push_back ({gpType, resultTanConvention});
724- }
710+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
711+ param.getInterfaceType ()
712+ ->getAutoDiffTangentSpace (lookupConformance)
713+ ->getCanonicalType (),
714+ param.getConvention ());
715+ pullbackResults.push_back ({paramTanType, resultTanConvention});
725716 }
726717 SubstitutionMap substitutions;
727718 if (!substGenericParams.empty ()) {
0 commit comments