@@ -205,6 +205,42 @@ AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(Symbol symbol) {
205205 return assocType;
206206}
207207
208+ // / Find the most canonical associated type declaration with the given
209+ // / name among a set of conforming protocols stored in this property map
210+ // / entry.
211+ AssociatedTypeDecl *PropertyBag::getAssociatedType (Identifier name) {
212+ auto found = AssocTypes.find (name);
213+ if (found != AssocTypes.end ())
214+ return found->second ;
215+
216+ AssociatedTypeDecl *assocType = nullptr ;
217+
218+ for (auto *proto : ConformsTo) {
219+ auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
220+ otherAssocType = otherAssocType->getAssociatedTypeAnchor ();
221+
222+ if (otherAssocType->getName () == name &&
223+ (assocType == nullptr ||
224+ TypeDecl::compare (otherAssocType->getProtocol (),
225+ assocType->getProtocol ()) < 0 )) {
226+ assocType = otherAssocType;
227+ }
228+ };
229+
230+ for (auto *otherAssocType : proto->getAssociatedTypeMembers ()) {
231+ checkOtherAssocType (otherAssocType);
232+ }
233+ }
234+
235+ assert (assocType != nullptr && " Need to look harder" );
236+
237+ auto inserted = AssocTypes.insert (std::make_pair (name, assocType)).second ;
238+ assert (inserted);
239+ (void ) inserted;
240+
241+ return assocType;
242+ }
243+
208244// / Compute the interface type for a range of symbols, with an optional
209245// / root type.
210246// /
@@ -233,8 +269,8 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
233269 result = genericParam;
234270 };
235271
236- for (; begin != end; ++begin ) {
237- auto symbol = *begin ;
272+ for (auto *iter = begin; iter != end; ++iter ) {
273+ auto symbol = *iter ;
238274
239275 if (!result) {
240276 // A valid term always begins with a generic parameter, protocol or
@@ -253,7 +289,7 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
253289 handleRoot (GenericTypeParamType::get (/* type sequence*/ false , 0 , 0 ,
254290 ctx.getASTContext ()));
255291
256- // An associated type term at the root means we have a dependent
292+ // An associated type symbol at the root means we have a dependent
257293 // member type rooted at Self; handle the associated type below.
258294 break ;
259295
@@ -281,17 +317,48 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
281317 if (symbol.getKind () == Symbol::Kind::Protocol) {
282318#ifndef NDEBUG
283319 // Ensure that the domain of the suffix contains P.
284- if (begin + 1 < end) {
285- auto protos = (begin + 1 )->getProtocols ();
320+ if (iter + 1 < end) {
321+ auto protos = (iter + 1 )->getProtocols ();
286322 assert (std::find (protos.begin (), protos.end (), symbol.getProtocol ()));
287323 }
288324#endif
289325 continue ;
290326 }
291327
292328 // We should have a resolved type at this point.
293- auto *assocType =
294- ctx.getAssociatedTypeForSymbol (symbol);
329+ AssociatedTypeDecl *assocType;
330+
331+ if (begin == iter) {
332+ // FIXME: Eliminate this case once merged associated types are gone.
333+ assocType = ctx.getAssociatedTypeForSymbol (symbol);
334+ } else {
335+ // The protocol stored in an associated type symbol appearing in a
336+ // canonical term is not necessarily the right protocol to look for
337+ // an associated type declaration to get a canonical _type_, because
338+ // the reduction order on terms is different than the canonical order
339+ // on types.
340+ //
341+ // Instead, find all protocols that the prefix conforms to, and look
342+ // for an associated type in those protocols.
343+ MutableTerm prefix (begin, iter);
344+ assert (prefix.size () > 0 );
345+
346+ auto *props = map.lookUpProperties (prefix.rbegin (), prefix.rend ());
347+ assert (props != nullptr );
348+
349+ // Assert that the associated type's protocol appears among the set
350+ // of protocols that the prefix conforms to.
351+ #ifndef NDEBUG
352+ auto conformsTo = props->getConformsTo ();
353+ for (auto *otherProto : symbol.getProtocols ()) {
354+ assert (std::find (conformsTo.begin (), conformsTo.end (), otherProto)
355+ != conformsTo.end ());
356+ }
357+ #endif
358+
359+ assocType = props->getAssociatedType (symbol.getName ());
360+ }
361+
295362 result = DependentMemberType::get (result, assocType);
296363 }
297364
0 commit comments