@@ -366,27 +366,6 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
366366 diffParams.push_back (valueAndIndex.value ());
367367}
368368
369- // / Collects the semantic results of the given function type in
370- // / `originalResults`. The semantic results are formal results followed by
371- // / semantic result parameters, in type order.
372- static void
373- getSemanticResults (SILFunctionType *functionType,
374- IndexSubset *parameterIndices,
375- SmallVectorImpl<SILResultInfo> &originalResults) {
376- // Collect original formal results.
377- originalResults.append (functionType->getResults ().begin (),
378- functionType->getResults ().end ());
379-
380- // Collect original semantic result parameters.
381- for (auto i : range (functionType->getNumParameters ())) {
382- auto param = functionType->getParameters ()[i];
383- if (!param.isAutoDiffSemanticResult ())
384- continue ;
385- if (param.getDifferentiability () != SILParameterDifferentiability::NotDifferentiable)
386- originalResults.emplace_back (param.getInterfaceType (), ResultConvention::Indirect);
387- }
388- }
389-
390369static CanGenericSignature buildDifferentiableGenericSignature (CanGenericSignature sig,
391370 CanType tanType,
392371 CanType origTypeOfAbstraction) {
@@ -563,7 +542,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
563542 SmallVector<ProtocolConformanceRef, 4 > substConformances;
564543
565544 SmallVector<SILResultInfo, 2 > originalResults;
566- getSemanticResults (originalFnTy, parameterIndices, originalResults);
545+ autodiff:: getSemanticResults (originalFnTy, parameterIndices, originalResults);
567546
568547 SmallVector<SILParameterInfo, 4 > diffParams;
569548 getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
@@ -647,7 +626,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
647626 SmallVector<ProtocolConformanceRef, 4 > substConformances;
648627
649628 SmallVector<SILResultInfo, 2 > originalResults;
650- getSemanticResults (originalFnTy, parameterIndices, originalResults);
629+ autodiff:: getSemanticResults (originalFnTy, parameterIndices, originalResults);
651630
652631 // Given a type, returns its formal SIL parameter info.
653632 auto getTangentParameterConventionForOriginalResult =
@@ -791,9 +770,9 @@ static CanSILFunctionType getAutoDiffPullbackType(
791770 llvm::makeArrayRef (substConformances));
792771 }
793772 return SILFunctionType::get (
794- GenericSignature (), SILFunctionType::ExtInfo (), SILCoroutineKind::None ,
795- ParameterConvention::Direct_Guaranteed, pullbackParams, {},
796- pullbackResults, llvm::None, substitutions,
773+ GenericSignature (), SILFunctionType::ExtInfo (), originalFnTy-> getCoroutineKind () ,
774+ ParameterConvention::Direct_Guaranteed,
775+ pullbackParams, {}, pullbackResults, llvm::None, substitutions,
797776 /* invocationSubstitutions*/ SubstitutionMap (), ctx);
798777}
799778
@@ -804,7 +783,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
804783// / - The invocation generic signature is replaced by the
805784// / `constrainedInvocationGenSig` argument.
806785static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType (
807- SILFunctionType *original, IndexSubset *parameterIndices,
786+ SILFunctionType *original, IndexSubset *parameterIndices, IndexSubset *resultIndices,
808787 LookupConformanceFn lookupConformance,
809788 CanGenericSignature constrainedInvocationGenSig) {
810789 auto originalInvocationGenSig = original->getInvocationGenericSignature ();
@@ -813,6 +792,25 @@ static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType(
813792 constrainedInvocationGenSig->areAllParamsConcrete () &&
814793 " derivative function cannot have invocation generic signature "
815794 " when original function doesn't" );
795+ if (auto patternSig = original->getPatternGenericSignature ()) {
796+ auto constrainedPatternSig =
797+ autodiff::getConstrainedDerivativeGenericSignature (
798+ original, parameterIndices, resultIndices,
799+ patternSig, lookupConformance).getCanonicalSignature ();
800+ auto constrainedPatternSubs =
801+ SubstitutionMap::get (constrainedPatternSig,
802+ QuerySubstitutionMap{original->getPatternSubstitutions ()},
803+ lookupConformance);
804+ return SILFunctionType::get (GenericSignature (),
805+ original->getExtInfo (), original->getCoroutineKind (),
806+ original->getCalleeConvention (),
807+ original->getParameters (), original->getYields (),
808+ original->getResults (), original->getOptionalErrorResult (),
809+ constrainedPatternSubs,
810+ /* invocationSubstitutions*/ SubstitutionMap (), original->getASTContext (),
811+ original->getWitnessMethodConformanceOrInvalid ());
812+ }
813+
816814 return original;
817815 }
818816
@@ -823,10 +821,10 @@ static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType(
823821 if (!constrainedInvocationGenSig)
824822 return original;
825823 constrainedInvocationGenSig =
826- autodiff::getConstrainedDerivativeGenericSignature (
827- original, parameterIndices, constrainedInvocationGenSig ,
828- lookupConformance)
829- .getCanonicalSignature ();
824+ autodiff::getConstrainedDerivativeGenericSignature (
825+ original, parameterIndices, resultIndices ,
826+ constrainedInvocationGenSig,
827+ lookupConformance) .getCanonicalSignature ();
830828
831829 SmallVector<SILParameterInfo, 4 > newParameters;
832830 newParameters.reserve (original->getNumParameters ());
@@ -882,9 +880,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
882880 return cachedResult;
883881
884882 SILFunctionType *constrainedOriginalFnTy =
885- getConstrainedAutoDiffOriginalFunctionType (this , parameterIndices,
883+ getConstrainedAutoDiffOriginalFunctionType (this , parameterIndices, resultIndices,
886884 lookupConformance,
887885 derivativeFnInvocationGenSig);
886+
888887 // Compute closure type.
889888 CanSILFunctionType closureType;
890889 switch (kind) {
@@ -957,11 +956,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
957956 IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
958957 LookupConformanceFn lookupConformance,
959958 CanGenericSignature transposeFnGenSig) {
959+ auto &ctx = getASTContext ();
960+
960961 // Get the "constrained" transpose function generic signature.
961962 if (!transposeFnGenSig)
962963 transposeFnGenSig = getSubstGenericSignature ();
963964 transposeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature (
964- this , parameterIndices, transposeFnGenSig,
965+ this , parameterIndices, IndexSubset::getDefault (ctx, 0 ),
966+ transposeFnGenSig,
965967 lookupConformance, /* isLinear*/ true )
966968 .getCanonicalSignature ();
967969
0 commit comments