@@ -1218,12 +1218,15 @@ forEachFunctionParam(AnyFunctionType::CanParamArrayRef substParams,
12181218 // Honor ignoreFinalParam for the substituted parameters on all paths.
12191219 if (ignoreFinalParam) substParams = substParams.drop_back ();
12201220
1221- // If this isn't a function type, use the substituted type.
1222- if (isTypeParameterOrOpaqueArchetype ()) {
1221+ // If we don't have a function type, use the substituted type.
1222+ if (isTypeParameterOrOpaqueArchetype () ||
1223+ getKind () == Kind::OpaqueFunction ||
1224+ getKind () == Kind::OpaqueDerivativeFunction) {
12231225 for (auto substParamIndex : indices (substParams)) {
12241226 handleScalar (substParamIndex, substParamIndex,
12251227 substParams[substParamIndex].getParameterFlags (),
1226- *this , substParams[substParamIndex]);
1228+ AbstractionPattern::getOpaque (),
1229+ substParams[substParamIndex]);
12271230 }
12281231 return ;
12291232 }
@@ -1829,38 +1832,57 @@ class SubstFunctionTypePatternVisitor
18291832 SmallVector<Requirement, 2 > substRequirements;
18301833 SmallVector<Type, 2 > substReplacementTypes;
18311834 CanType substYieldType;
1835+ bool WithinExpansion = false ;
18321836
18331837 SubstFunctionTypePatternVisitor (TypeConverter &TC)
18341838 : TC(TC) {}
1835-
1839+
18361840 // Creates and returns a fresh type parameter in the substituted generic
18371841 // signature if `pattern` is a type parameter or opaque archetype. Returns
18381842 // null otherwise.
1839- CanType handleTypeParameterInAbstractionPattern (AbstractionPattern pattern,
1840- CanType substTy) {
1843+ CanType handleTypeParameter (AbstractionPattern pattern, CanType substTy) {
18411844 if (!pattern.isTypeParameterOrOpaqueArchetype ())
18421845 return CanType ();
18431846
1844- // If so, let's put a fresh generic parameter in the substituted signature
1845- // here.
18461847 unsigned paramIndex = substGenericParams.size ();
18471848
1848- bool isParameterPack = false ;
1849- if (substTy->isParameterPack () || substTy->is <PackArchetypeType>())
1850- isParameterPack = true ;
1851- else if (pattern.isTypeParameterPack ())
1852- isParameterPack = true ;
1849+ // Pack parameters that aren't within expansions should just be
1850+ // abstracted as scalars.
1851+ bool isParameterPack = (WithinExpansion && pattern.isTypeParameterPack ());
18531852
18541853 auto gp = GenericTypeParamType::get (isParameterPack, 0 , paramIndex,
18551854 TC.Context );
18561855 substGenericParams.push_back (gp);
1857- if (isParameterPack) {
1858- substReplacementTypes.push_back (
1859- PackType::getSingletonPackExpansion (substTy));
1856+
1857+ CanType replacement;
1858+
1859+ if (WithinExpansion) {
1860+ // If we're within an expansion, and there are substitutions in the
1861+ // abstraction pattern, use those instead of substTy. substTy is not
1862+ // contextually meaningful in this case; see handlePackExpansion.
1863+ if (auto subs = pattern.getGenericSubstitutions ()) {
1864+ replacement = pattern.getType ().subst (subs)->getCanonicalType ();
1865+
1866+ // If we don't have substitutions, but we're abstracting a pack
1867+ // parameter, assume that we're lowering a function type using
1868+ // itself as its pattern or something like. The substituted type
1869+ // should be `each T` for some pack reference; wrap that in a pack.
1870+ } else if (isParameterPack) {
1871+ replacement = CanPackType::getSingletonPackExpansion (substTy);
1872+
1873+ // Otherwise, just use substTy.
1874+ } else {
1875+ replacement = substTy;
1876+ }
1877+
1878+ // Otherwise, we can just use substTy.
18601879 } else {
1861- substReplacementTypes.push_back (substTy);
1880+ assert (!isParameterPack);
1881+ assert (!isa<PackType>(substTy));
1882+ replacement = substTy;
18621883 }
1863-
1884+ substReplacementTypes.push_back (replacement);
1885+
18641886 if (auto layout = pattern.getLayoutConstraint ()) {
18651887 // Look at the layout constraint on this position in the abstraction pattern
18661888 // and carry it over, with some generalization to the point it affects
@@ -1914,7 +1936,7 @@ class SubstFunctionTypePatternVisitor
19141936 }
19151937
19161938 CanType visit (CanType t, AbstractionPattern pattern) {
1917- if (auto gp = handleTypeParameterInAbstractionPattern (pattern, t))
1939+ if (auto gp = handleTypeParameter (pattern, t))
19181940 return gp;
19191941
19201942 return CanTypeVisitor::visit (t, pattern);
@@ -1960,7 +1982,7 @@ class SubstFunctionTypePatternVisitor
19601982 if (!orig->hasTypeParameter ()
19611983 && !orig->hasArchetype ()
19621984 && !orig->hasOpaqueArchetype ()) {
1963- return CanType ( subst) ;
1985+ return subst;
19641986 }
19651987
19661988 // If the substituted type is a subclass of the abstraction pattern
@@ -2067,26 +2089,81 @@ class SubstFunctionTypePatternVisitor
20672089
20682090 CanType visitPackExpansionType (CanPackExpansionType pack,
20692091 AbstractionPattern pattern) {
2070- // Avoid walking into the pattern and count type if we can help it.
2071- if (!pack->hasTypeParameter () && !pack->hasArchetype () &&
2072- !pack->hasOpaqueArchetype ()) {
2073- return CanType (pack);
2092+ llvm_unreachable (" shouldn't encounter pack expansion by itself" );
2093+ }
2094+
2095+ CanType handlePackExpansion (AbstractionPattern origExpansion,
2096+ CanType candidateSubstType) {
2097+ // When we're within a pack expansion, pack references matching that
2098+ // expansion should be abstracted as packs. The substitution will be
2099+ // the pack substitution for that parameter recorded in the pattern.
2100+
2101+ // Remember that we're within an expansion.
2102+ // FIXME: when we introduce PackReferenceType we'll need to be clear
2103+ // about which pack expansions to treat this way.
2104+ llvm::SaveAndRestore<bool > scope (WithinExpansion, true );
2105+
2106+ auto origPatternType = origExpansion.getPackExpansionPatternType ();
2107+
2108+ // We only really need a subst type here if we don't have
2109+ // substitutions in the pattern, because handleTypeParameter
2110+ // will always those substitutions within an expansion if
2111+ // they're available. And if we don't have substitutions in the
2112+ // pattern, we can't map the pack expansion to a concrete set
2113+ // of expanded components, so we should have exactly one subst
2114+ // type.
2115+ CanType substPatternType;
2116+ if (origExpansion.getGenericSubstitutions ()) {
2117+ substPatternType = origPatternType.getType ();
2118+ } else {
2119+ assert (candidateSubstType);
2120+ substPatternType =
2121+ cast<PackExpansionType>(candidateSubstType).getPatternType ();
20742122 }
20752123
2076- auto substPatternType = visit (pack.getPatternType (),
2077- pattern.getPackExpansionPatternType ());
2078- auto substCountType = visit (pack.getCountType (),
2079- AbstractionPattern::getOpaque ());
2124+ // Recursively visit the pattern type.
2125+ auto patternTy = visit (substPatternType, origPatternType);
20802126
2081- SmallVector<Type> rootParameterPacks;
2082- substPatternType-> getTypeParameterPacks (rootParameterPacks );
2127+ // Find a pack parameter from the pattern to expand over.
2128+ auto countParam = findExpandedPackParameter (patternTy );
20832129
2084- for ( auto parameterPack : rootParameterPacks) {
2085- substRequirements. emplace_back (RequirementKind::SameShape ,
2086- parameterPack, substCountType);
2087- }
2130+ // If that didn't work, we should be able to find an expansion
2131+ // to use from either the substituted type or the subs. At worst ,
2132+ // we can make one.
2133+ assert (countParam && " implementable but lazy " );
20882134
2089- return CanPackExpansionType::get (substPatternType, substCountType);
2135+ return CanPackExpansionType::get (patternTy, countParam);
2136+ }
2137+
2138+ static CanType findExpandedPackParameter (CanType patternType) {
2139+ struct Walker : public TypeWalker {
2140+ CanType Result;
2141+ Action walkToTypePre (Type _ty) override {
2142+ auto ty = CanType (_ty);
2143+
2144+ // Don't recurse inside pack expansions.
2145+ if (isa<PackExpansionType>(ty)) {
2146+ return Action::SkipChildren;
2147+ }
2148+
2149+ // Consider type parameters.
2150+ if (ty->isTypeParameter ()) {
2151+ auto param = ty->getRootGenericParam ();
2152+ if (param->isParameterPack ()) {
2153+ Result = CanType (param);
2154+ return Action::Stop;
2155+ }
2156+ return Action::SkipChildren;
2157+ }
2158+
2159+ // Otherwise continue.
2160+ return Action::Continue;
2161+ }
2162+ };
2163+
2164+ Walker walker;
2165+ patternType.walk (walker);
2166+ return walker.Result ;
20902167 }
20912168
20922169 CanType visitExistentialType (CanExistentialType exist,
@@ -2121,14 +2198,31 @@ class SubstFunctionTypePatternVisitor
21212198 }
21222199
21232200 CanType visitTupleType (CanTupleType tuple, AbstractionPattern pattern) {
2124- // Break down the tuple.
2201+ assert (pattern.isTuple ());
2202+
2203+ // It's pretty weird for us to end up in this case with an
2204+ // open-coded tuple pattern, but it happens with opaque derivative
2205+ // functions in autodiff.
2206+ CanTupleType origTupleTypeForLabels = pattern.getAs <TupleType>();
2207+ if (!origTupleTypeForLabels) origTupleTypeForLabels = tuple;
2208+
21252209 SmallVector<TupleTypeElt, 4 > tupleElts;
2126- for (unsigned i = 0 ; i < tuple->getNumElements (); ++i) {
2127- auto elt = tuple->getElement (i);
2128- auto substEltTy = visit (tuple.getElementType (i),
2129- pattern.getTupleElementType (i));
2130- tupleElts.emplace_back (substEltTy, elt.getName ());
2131- }
2210+ pattern.forEachTupleElement (tuple,
2211+ [&](unsigned origEltIndex, unsigned substEltIndex,
2212+ AbstractionPattern origEltType, CanType substEltType) {
2213+ auto eltTy = visit (substEltType, origEltType);
2214+ auto &origElt = origTupleTypeForLabels->getElement (origEltIndex);
2215+ tupleElts.push_back (origElt.getWithType (eltTy));
2216+ }, [&](unsigned origEltIndex, unsigned substEltIndex,
2217+ AbstractionPattern origExpansionType,
2218+ CanTupleEltTypeArrayRef substEltTypes) {
2219+ CanType candidateSubstType;
2220+ if (!substEltTypes.empty ())
2221+ candidateSubstType = substEltTypes[0 ];
2222+ auto eltTy = handlePackExpansion (origExpansionType, candidateSubstType);
2223+ auto &origElt = origTupleTypeForLabels->getElement (origEltIndex);
2224+ tupleElts.push_back (origElt.getWithType (eltTy));
2225+ });
21322226
21332227 return CanType (TupleType::get (tupleElts, TC.Context ));
21342228 }
@@ -2138,19 +2232,29 @@ class SubstFunctionTypePatternVisitor
21382232 CanType yieldType,
21392233 AbstractionPattern yieldPattern) {
21402234 SmallVector<FunctionType::Param, 4 > newParams;
2141-
2142- for (unsigned i = 0 ; i < func->getParams ().size (); ++i) {
2143- auto param = func->getParams ()[i];
2144- // Lower the formal type of the argument binding, eliminating variadicity.
2145- auto newParamTy = visit (CanType (param.getParameterType (true )),
2146- pattern.getFunctionParamType (i));
2147- auto newParam = FunctionType::Param (newParamTy,
2148- param.getLabel (),
2149- param.getParameterFlags ()
2150- .withVariadic (false ),
2151- param.getInternalLabel ());
2152- newParams.push_back (newParam);
2153- }
2235+ auto addParam = [&](ParameterTypeFlags oldFlags, CanType newType) {
2236+ newParams.push_back (FunctionType::Param (
2237+ newType, /* label*/ Identifier (), oldFlags.withVariadic (false ),
2238+ /* internal label*/ Identifier ()));
2239+ };
2240+
2241+ pattern.forEachFunctionParam (func.getParams (), /* ignore self*/ false ,
2242+ [&](unsigned origParamIndex, unsigned substParamIndex,
2243+ ParameterTypeFlags origFlags, AbstractionPattern origParamType,
2244+ AnyFunctionType::CanParam substParam) {
2245+ auto newParamTy = visit (substParam.getParameterType (), origParamType);
2246+ addParam (origFlags, newParamTy);
2247+ }, [&](unsigned origParamIndex, unsigned substParamIndex,
2248+ ParameterTypeFlags origFlags,
2249+ AbstractionPattern origExpansionType,
2250+ AnyFunctionType::CanParamArrayRef substParams) {
2251+ CanType candidateSubstType;
2252+ if (!substParams.empty ())
2253+ candidateSubstType = substParams[0 ].getParameterType ();
2254+ auto expansionType =
2255+ handlePackExpansion (origExpansionType, candidateSubstType);
2256+ addParam (origFlags, expansionType);
2257+ });
21542258
21552259 if (yieldType) {
21562260 substYieldType = visit (yieldType, yieldPattern);
@@ -2229,9 +2333,9 @@ const {
22292333 yieldType = yieldType->getReducedType (substSig);
22302334
22312335 return std::make_tuple (
2232- AbstractionPattern (substSig, substTy->getReducedType (substSig)),
2336+ AbstractionPattern (subMap, substSig, substTy->getReducedType (substSig)),
22332337 subMap,
22342338 yieldType
2235- ? AbstractionPattern (substSig, yieldType)
2339+ ? AbstractionPattern (subMap, substSig, yieldType)
22362340 : AbstractionPattern::getInvalid ());
22372341}
0 commit comments