@@ -54,10 +54,13 @@ class TypeListPackMatcher {
5454 ArrayRef<Element> lhsElements;
5555 ArrayRef<Element> rhsElements;
5656
57+ std::function<bool (Type)> IsPackExpansionType;
5758protected:
5859 TypeListPackMatcher (ASTContext &ctx, ArrayRef<Element> lhs,
59- ArrayRef<Element> rhs)
60- : ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}
60+ ArrayRef<Element> rhs,
61+ std::function<bool (Type)> isPackExpansionType)
62+ : ctx(ctx), lhsElements(lhs), rhsElements(rhs),
63+ IsPackExpansionType (isPackExpansionType) {}
6164
6265public:
6366 SmallVector<MatchedPair, 4 > pairs;
@@ -86,8 +89,8 @@ class TypeListPackMatcher {
8689 auto lhsType = getElementType (lhsElt);
8790 auto rhsType = getElementType (rhsElt);
8891
89- if (lhsType-> template is <PackExpansionType>( ) ||
90- rhsType-> template is <PackExpansionType>( )) {
92+ if (IsPackExpansionType (lhsType ) ||
93+ IsPackExpansionType (rhsType )) {
9194 break ;
9295 }
9396
@@ -115,8 +118,8 @@ class TypeListPackMatcher {
115118 auto lhsType = getElementType (lhsElt);
116119 auto rhsType = getElementType (rhsElt);
117120
118- if (lhsType-> template is <PackExpansionType>( ) ||
119- rhsType-> template is <PackExpansionType>( )) {
121+ if (IsPackExpansionType (lhsType ) ||
122+ IsPackExpansionType (rhsType )) {
120123 break ;
121124 }
122125
@@ -139,7 +142,7 @@ class TypeListPackMatcher {
139142 // to what remains of the right hand side.
140143 if (lhsElts.size () == 1 ) {
141144 auto lhsType = getElementType (lhsElts[0 ]);
142- if (auto *lhsExpansion = lhsType-> template getAs <PackExpansionType>( )) {
145+ if (IsPackExpansionType (lhsType )) {
143146 unsigned lhsIdx = prefixLength;
144147 unsigned rhsIdx = prefixLength;
145148
@@ -154,7 +157,7 @@ class TypeListPackMatcher {
154157 auto rhs = createPackBinding (rhsTypes);
155158
156159 // FIXME: Check lhs flags
157- pairs.emplace_back (lhsExpansion , rhs, lhsIdx, rhsIdx);
160+ pairs.emplace_back (lhsType , rhs, lhsIdx, rhsIdx);
158161 return false ;
159162 }
160163 }
@@ -163,7 +166,7 @@ class TypeListPackMatcher {
163166 // to what remains of the left hand side.
164167 if (rhsElts.size () == 1 ) {
165168 auto rhsType = getElementType (rhsElts[0 ]);
166- if (auto *rhsExpansion = rhsType-> template getAs <PackExpansionType>( )) {
169+ if (IsPackExpansionType (rhsType )) {
167170 unsigned lhsIdx = prefixLength;
168171 unsigned rhsIdx = prefixLength;
169172
@@ -178,7 +181,7 @@ class TypeListPackMatcher {
178181 auto lhs = createPackBinding (lhsTypes);
179182
180183 // FIXME: Check rhs flags
181- pairs.emplace_back (lhs, rhsExpansion , lhsIdx, rhsIdx);
184+ pairs.emplace_back (lhs, rhsType , lhsIdx, rhsIdx);
182185 return false ;
183186 }
184187 }
@@ -197,14 +200,11 @@ class TypeListPackMatcher {
197200 Type getElementType (const Element &) const ;
198201 ParameterTypeFlags getElementFlags (const Element &) const ;
199202
200- PackExpansionType * createPackBinding (ArrayRef<Type> types) const {
203+ Type createPackBinding (ArrayRef<Type> types) const {
201204 // If there is only one element and it's a PackExpansionType,
202205 // return it directly.
203- if (types.size () == 1 ) {
204- if (auto *expansionType = types.front ()->getAs <PackExpansionType>()) {
205- return expansionType;
206- }
207- }
206+ if (types.size () == 1 && IsPackExpansionType (types.front ()))
207+ return types.front ();
208208
209209 // Otherwise, wrap the elements in PackExpansionType(PackType(...)).
210210 auto *packType = PackType::get (ctx, types);
@@ -220,10 +220,12 @@ class TypeListPackMatcher {
220220// / other side.
221221class TuplePackMatcher : public TypeListPackMatcher <TupleTypeElt> {
222222public:
223- TuplePackMatcher (TupleType *lhsTuple, TupleType *rhsTuple)
224- : TypeListPackMatcher(lhsTuple->getASTContext (),
225- lhsTuple->getElements(),
226- rhsTuple->getElements()) {}
223+ TuplePackMatcher (
224+ TupleType *lhsTuple, TupleType *rhsTuple,
225+ std::function<bool (Type)> isPackExpansionType =
226+ [](Type T) { return T->is <PackExpansionType>(); })
227+ : TypeListPackMatcher(lhsTuple->getASTContext (), lhsTuple->getElements(),
228+ rhsTuple->getElements(), isPackExpansionType) {}
227229};
228230
229231// / Performs a structural match of two lists of (unlabeled) function
@@ -235,9 +237,12 @@ class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
235237// / other side.
236238class ParamPackMatcher : public TypeListPackMatcher <AnyFunctionType::Param> {
237239public:
238- ParamPackMatcher (ArrayRef<AnyFunctionType::Param> lhsParams,
239- ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
240- : TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
240+ ParamPackMatcher (
241+ ArrayRef<AnyFunctionType::Param> lhsParams,
242+ ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx,
243+ std::function<bool (Type)> isPackExpansionType =
244+ [](Type T) { return T->is <PackExpansionType>(); })
245+ : TypeListPackMatcher(ctx, lhsParams, rhsParams, isPackExpansionType) {}
241246};
242247
243248// / Performs a structural match of two lists of types.
@@ -248,8 +253,11 @@ class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
248253// / other side.
249254class PackMatcher : public TypeListPackMatcher <Type> {
250255public:
251- PackMatcher (ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
252- : TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
256+ PackMatcher (
257+ ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx,
258+ std::function<bool (Type)> isPackExpansionType =
259+ [](Type T) { return T->is <PackExpansionType>(); })
260+ : TypeListPackMatcher(ctx, lhsTypes, rhsTypes, isPackExpansionType) {}
253261};
254262
255263} // end namespace swift
0 commit comments