@@ -408,32 +408,6 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
408408 return buildGenericSignature (ctx, sig, {}, reqs).getCanonicalSignature ();
409409}
410410
411- // / Given an original type, computes its tangent type for the purpose of
412- // / building a linear map using this type. When the original type is an
413- // / archetype or contains a type parameter, appends a new generic parameter and
414- // / a corresponding replacement type to the given containers.
415- static CanType getAutoDiffTangentTypeForLinearMap (
416- Type originalType,
417- LookupConformanceFn lookupConformance,
418- SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
419- SmallVectorImpl<Type> &substReplacements,
420- ASTContext &context
421- ) {
422- auto maybeTanType = originalType->getAutoDiffTangentSpace (lookupConformance);
423- assert (maybeTanType && " Type does not have a tangent space?" );
424- auto tanType = maybeTanType->getCanonicalType ();
425- // If concrete, the tangent type is concrete.
426- if (!tanType->hasArchetype () && !tanType->hasTypeParameter ())
427- return tanType;
428- // Otherwise, the tangent type is a new generic parameter substituted for the
429- // tangent type.
430- auto gpIndex = substGenericParams.size ();
431- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, context);
432- substGenericParams.push_back (gpType);
433- substReplacements.push_back (tanType);
434- return gpType;
435- }
436-
437411// / Returns the differential type for the given original function type,
438412// / parameter indices, and result index.
439413static CanSILFunctionType getAutoDiffDifferentialType (
@@ -510,32 +484,45 @@ static CanSILFunctionType getAutoDiffDifferentialType(
510484 getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
511485 SmallVector<SILParameterInfo, 8 > differentialParams;
512486 for (auto ¶m : diffParams) {
513- auto paramTanType = getAutoDiffTangentTypeForLinearMap (
514- param.getInterfaceType (), lookupConformance,
515- substGenericParams, substReplacements, ctx);
516- auto paramConv = getTangentParameterConvention (
517- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
518- param.getInterfaceType ()
519- ->getAutoDiffTangentSpace (lookupConformance)
520- ->getCanonicalType (),
521- param.getConvention ());
522- differentialParams.push_back ({paramTanType, paramConv});
487+ auto paramTan =
488+ param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
489+ assert (paramTan && " Parameter type does not have a tangent space?" );
490+ auto paramTanType = paramTan->getCanonicalType ();
491+ auto paramConv = getTangentParameterConvention (paramTanType,
492+ param.getConvention ());
493+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
494+ differentialParams.push_back (
495+ {paramTan->getCanonicalType (), paramConv});
496+ } else {
497+ auto gpIndex = substGenericParams.size ();
498+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
499+ substGenericParams.push_back (gpType);
500+ substReplacements.push_back (paramTanType);
501+ differentialParams.push_back ({gpType, paramConv});
502+ }
523503 }
524504 SmallVector<SILResultInfo, 1 > differentialResults;
525505 for (auto resultIndex : resultIndices->getIndices ()) {
526506 // Handle formal original result.
527507 if (resultIndex < originalFnTy->getNumResults ()) {
528508 auto &result = originalResults[resultIndex];
529- auto resultTanType = getAutoDiffTangentTypeForLinearMap (
530- result.getInterfaceType (), lookupConformance,
531- substGenericParams, substReplacements, ctx);
532- auto resultConv = getTangentResultConvention (
533- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
534- result.getInterfaceType ()
535- ->getAutoDiffTangentSpace (lookupConformance)
536- ->getCanonicalType (),
537- result.getConvention ());
538- differentialResults.push_back ({resultTanType, resultConv});
509+ auto resultTan =
510+ result.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
511+ assert (resultTan && " Result type does not have a tangent space?" );
512+ auto resultTanType = resultTan->getCanonicalType ();
513+ auto resultConv =
514+ getTangentResultConvention (resultTanType, result.getConvention ());
515+ if (!resultTanType->hasArchetype () &&
516+ !resultTanType->hasTypeParameter ()) {
517+ differentialResults.push_back (
518+ {resultTan->getCanonicalType (), resultConv});
519+ } else {
520+ auto gpIndex = substGenericParams.size ();
521+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
522+ substGenericParams.push_back (gpType);
523+ substReplacements.push_back (resultTanType);
524+ differentialResults.push_back ({gpType, resultConv});
525+ }
539526 continue ;
540527 }
541528 // Handle original `inout` parameter.
@@ -550,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
550537 if (parameterIndices->contains (paramIndex))
551538 continue ;
552539 auto inoutParam = originalFnTy->getParameters ()[paramIndex];
553- auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
554- inoutParam. getInterfaceType (), lookupConformance,
555- substGenericParams, substReplacements, ctx );
540+ auto paramTan = inoutParam. getInterfaceType ()-> getAutoDiffTangentSpace (
541+ lookupConformance);
542+ assert (paramTan && " Parameter type does not have a tangent space? " );
556543 differentialResults.push_back (
557- {inoutParamTanType , ResultConvention::Indirect});
544+ {paramTan-> getCanonicalType () , ResultConvention::Indirect});
558545 }
559546
560547 SubstitutionMap substitutions;
@@ -661,16 +648,23 @@ static CanSILFunctionType getAutoDiffPullbackType(
661648 // Handle formal original result.
662649 if (resultIndex < originalFnTy->getNumResults ()) {
663650 auto &origRes = originalResults[resultIndex];
664- auto resultTanType = getAutoDiffTangentTypeForLinearMap (
665- origRes.getInterfaceType (), lookupConformance,
666- substGenericParams, substReplacements, ctx);
667- auto paramConv = getTangentParameterConventionForOriginalResult (
668- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
669- origRes.getInterfaceType ()
670- ->getAutoDiffTangentSpace (lookupConformance)
671- ->getCanonicalType (),
672- origRes.getConvention ());
673- pullbackParams.push_back ({resultTanType, paramConv});
651+ auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
652+ lookupConformance);
653+ assert (resultTan && " Result type does not have a tangent space?" );
654+ auto resultTanType = resultTan->getCanonicalType ();
655+ auto paramTanConvention = getTangentParameterConventionForOriginalResult (
656+ resultTanType, origRes.getConvention ());
657+ if (!resultTanType->hasArchetype () &&
658+ !resultTanType->hasTypeParameter ()) {
659+ auto resultTanType = resultTan->getCanonicalType ();
660+ pullbackParams.push_back ({resultTanType, paramTanConvention});
661+ } else {
662+ auto gpIndex = substGenericParams.size ();
663+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
664+ substGenericParams.push_back (gpType);
665+ substReplacements.push_back (resultTanType);
666+ pullbackParams.push_back ({gpType, paramTanConvention});
667+ }
674668 continue ;
675669 }
676670 // Handle original `inout` parameter.
@@ -680,18 +674,28 @@ static CanSILFunctionType getAutoDiffPullbackType(
680674 auto paramIndex =
681675 std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
682676 auto inoutParam = originalFnTy->getParameters ()[paramIndex];
677+ auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
678+ lookupConformance);
679+ assert (paramTan && " Parameter type does not have a tangent space?" );
683680 // The pullback parameter convention depends on whether the original `inout`
684681 // paramater is a differentiability parameter.
685682 // - If yes, the pullback parameter convention is `@inout`.
686683 // - If no, the pullback parameter convention is `@in_guaranteed`.
687- auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
688- inoutParam.getInterfaceType (), lookupConformance,
689- substGenericParams, substReplacements, ctx);
690684 bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
691685 auto paramTanConvention = isWrtInoutParameter
692- ? inoutParam.getConvention ()
693- : ParameterConvention::Indirect_In_Guaranteed;
694- pullbackParams.push_back ({inoutParamTanType, paramTanConvention});
686+ ? inoutParam.getConvention ()
687+ : ParameterConvention::Indirect_In_Guaranteed;
688+ auto paramTanType = paramTan->getCanonicalType ();
689+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
690+ pullbackParams.push_back (
691+ SILParameterInfo (paramTanType, paramTanConvention));
692+ } else {
693+ auto gpIndex = substGenericParams.size ();
694+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
695+ substGenericParams.push_back (gpType);
696+ substReplacements.push_back (paramTanType);
697+ pullbackParams.push_back ({gpType, paramTanConvention});
698+ }
695699 }
696700
697701 // Collect pullback results.
@@ -703,16 +707,21 @@ static CanSILFunctionType getAutoDiffPullbackType(
703707 // and always appear as pullback parameters.
704708 if (param.isIndirectInOut ())
705709 continue ;
706- auto paramTanType = getAutoDiffTangentTypeForLinearMap (
707- param.getInterfaceType (), lookupConformance,
708- substGenericParams, substReplacements, ctx);
710+ auto paramTan =
711+ param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
712+ assert (paramTan && " Parameter type does not have a tangent space?" );
713+ auto paramTanType = paramTan->getCanonicalType ();
709714 auto resultTanConvention = getTangentResultConventionForOriginalParameter (
710- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
711- param.getInterfaceType ()
712- ->getAutoDiffTangentSpace (lookupConformance)
713- ->getCanonicalType (),
714- param.getConvention ());
715- pullbackResults.push_back ({paramTanType, resultTanConvention});
715+ paramTanType, param.getConvention ());
716+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
717+ pullbackResults.push_back ({paramTanType, resultTanConvention});
718+ } else {
719+ auto gpIndex = substGenericParams.size ();
720+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
721+ substGenericParams.push_back (gpType);
722+ substReplacements.push_back (paramTanType);
723+ pullbackResults.push_back ({gpType, resultTanConvention});
724+ }
716725 }
717726 SubstitutionMap substitutions;
718727 if (!substGenericParams.empty ()) {
0 commit comments