@@ -126,6 +126,79 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
126126 return MutableTerm (symbols);
127127}
128128
129+ // / Map an associated type symbol to an associated type declaration.
130+ // /
131+ // / Note that the protocol graph is not part of the caching key; each
132+ // / protocol graph is a subgraph of the global inheritance graph, so
133+ // / the specific choice of subgraph does not change the result.
134+ AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol (
135+ Symbol symbol, const ProtocolGraph &protos) {
136+ auto found = AssocTypes.find (symbol);
137+ if (found != AssocTypes.end ())
138+ return found->second ;
139+
140+ assert (symbol.getKind () == Symbol::Kind::AssociatedType);
141+ auto *proto = symbol.getProtocols ()[0 ];
142+ auto name = symbol.getName ();
143+
144+ AssociatedTypeDecl *assocType = nullptr ;
145+
146+ // Special case: handle unknown protocols, since they can appear in the
147+ // invalid types that getCanonicalTypeInContext() must handle via
148+ // concrete substitution; see the definition of getCanonicalTypeInContext()
149+ // below for details.
150+ if (!protos.isKnownProtocol (proto)) {
151+ assert (symbol.getProtocols ().size () == 1 &&
152+ " Unknown associated type symbol must have a single protocol" );
153+ assocType = proto->getAssociatedType (name)->getAssociatedTypeAnchor ();
154+ } else {
155+ // An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
156+ // P0...Pn and an identifier 'A'.
157+ //
158+ // We map it back to a AssociatedTypeDecl as follows:
159+ //
160+ // - For each protocol Pn, look for associated types A in Pn itself,
161+ // and all protocols that Pn refines.
162+ //
163+ // - For each candidate associated type An in protocol Qn where
164+ // Pn refines Qn, get the associated type anchor An' defined in
165+ // protocol Qn', where Qn refines Qn'.
166+ //
167+ // - Out of all the candidiate pairs (Qn', An'), pick the one where
168+ // the protocol Qn' is the lowest element according to the linear
169+ // order defined by TypeDecl::compare().
170+ //
171+ // The associated type An' is then the canonical associated type
172+ // representative of the associated type symbol [P0&...&Pn:A].
173+ //
174+ for (auto *proto : symbol.getProtocols ()) {
175+ const auto &info = protos.getProtocolInfo (proto);
176+ auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
177+ otherAssocType = otherAssocType->getAssociatedTypeAnchor ();
178+
179+ if (otherAssocType->getName () == name &&
180+ (assocType == nullptr ||
181+ TypeDecl::compare (otherAssocType->getProtocol (),
182+ assocType->getProtocol ()) < 0 )) {
183+ assocType = otherAssocType;
184+ }
185+ };
186+
187+ for (auto *otherAssocType : info.AssociatedTypes ) {
188+ checkOtherAssocType (otherAssocType);
189+ }
190+
191+ for (auto *otherAssocType : info.InheritedAssociatedTypes ) {
192+ checkOtherAssocType (otherAssocType);
193+ }
194+ }
195+ }
196+
197+ assert (assocType && " Need to look harder" );
198+ AssocTypes[symbol] = assocType;
199+ return assocType;
200+ }
201+
129202// / Compute the interface type for a range of symbols, with an optional
130203// / root type.
131204// /
@@ -136,7 +209,7 @@ template<typename Iter>
136209Type getTypeForSymbolRange (Iter begin, Iter end, Type root,
137210 TypeArrayView<GenericTypeParamType> genericParams,
138211 const ProtocolGraph &protos,
139- ASTContext &ctx) {
212+ const RewriteContext &ctx) {
140213 Type result = root;
141214
142215 auto handleRoot = [&](GenericTypeParamType *genericParam) {
@@ -166,11 +239,11 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
166239 continue ;
167240
168241 case Symbol::Kind::Protocol:
169- handleRoot (GenericTypeParamType::get (0 , 0 , ctx));
242+ handleRoot (GenericTypeParamType::get (0 , 0 , ctx. getASTContext () ));
170243 continue ;
171244
172245 case Symbol::Kind::AssociatedType:
173- handleRoot (GenericTypeParamType::get (0 , 0 , ctx));
246+ handleRoot (GenericTypeParamType::get (0 , 0 , ctx. getASTContext () ));
174247
175248 // An associated type term at the root means we have a dependent
176249 // member type rooted at Self; handle the associated type below.
@@ -191,68 +264,9 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
191264 }
192265
193266 // We should have a resolved type at this point.
194- assert (symbol.getKind () == Symbol::Kind::AssociatedType);
195- auto *proto = symbol.getProtocols ()[0 ];
196- auto name = symbol.getName ();
197-
198- AssociatedTypeDecl *assocType = nullptr ;
199-
200- // Special case: handle unknown protocols, since they can appear in the
201- // invalid types that getCanonicalTypeInContext() must handle via
202- // concrete substitution; see the definition of getCanonicalTypeInContext()
203- // below for details.
204- if (!protos.isKnownProtocol (proto)) {
205- assert (root &&
206- " We only allow unknown protocols in getRelativeTypeForTerm()" );
207- assert (symbol.getProtocols ().size () == 1 &&
208- " Unknown associated type symbol must have a single protocol" );
209- assocType = proto->getAssociatedType (name)->getAssociatedTypeAnchor ();
210- } else {
211- // FIXME: Cache this
212- //
213- // An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
214- // P0...Pn and an identifier 'A'.
215- //
216- // We map it back to a AssociatedTypeDecl as follows:
217- //
218- // - For each protocol Pn, look for associated types A in Pn itself,
219- // and all protocols that Pn refines.
220- //
221- // - For each candidate associated type An in protocol Qn where
222- // Pn refines Qn, get the associated type anchor An' defined in
223- // protocol Qn', where Qn refines Qn'.
224- //
225- // - Out of all the candidiate pairs (Qn', An'), pick the one where
226- // the protocol Qn' is the lowest element according to the linear
227- // order defined by TypeDecl::compare().
228- //
229- // The associated type An' is then the canonical associated type
230- // representative of the associated type symbol [P0&...&Pn:A].
231- //
232- for (auto *proto : symbol.getProtocols ()) {
233- const auto &info = protos.getProtocolInfo (proto);
234- auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
235- otherAssocType = otherAssocType->getAssociatedTypeAnchor ();
236-
237- if (otherAssocType->getName () == name &&
238- (assocType == nullptr ||
239- TypeDecl::compare (otherAssocType->getProtocol (),
240- assocType->getProtocol ()) < 0 )) {
241- assocType = otherAssocType;
242- }
243- };
244-
245- for (auto *otherAssocType : info.AssociatedTypes ) {
246- checkOtherAssocType (otherAssocType);
247- }
248-
249- for (auto *otherAssocType : info.InheritedAssociatedTypes ) {
250- checkOtherAssocType (otherAssocType);
251- }
252- }
253- }
254-
255- assert (assocType && " Need to look harder" );
267+ auto *assocType =
268+ const_cast <RewriteContext &>(ctx)
269+ .getAssociatedTypeForSymbol (symbol, protos);
256270 result = DependentMemberType::get (result, assocType);
257271 }
258272
@@ -263,14 +277,14 @@ Type RewriteContext::getTypeForTerm(Term term,
263277 TypeArrayView<GenericTypeParamType> genericParams,
264278 const ProtocolGraph &protos) const {
265279 return getTypeForSymbolRange (term.begin (), term.end (), Type (),
266- genericParams, protos, Context );
280+ genericParams, protos, * this );
267281}
268282
269283Type RewriteContext::getTypeForTerm (const MutableTerm &term,
270284 TypeArrayView<GenericTypeParamType> genericParams,
271285 const ProtocolGraph &protos) const {
272286 return getTypeForSymbolRange (term.begin (), term.end (), Type (),
273- genericParams, protos, Context );
287+ genericParams, protos, * this );
274288}
275289
276290Type RewriteContext::getRelativeTypeForTerm (
@@ -281,7 +295,7 @@ Type RewriteContext::getRelativeTypeForTerm(
281295 auto genericParam = CanGenericTypeParamType::get (0 , 0 , Context);
282296 return getTypeForSymbolRange (
283297 term.begin () + prefix.size (), term.end (), genericParam,
284- { }, protos, Context );
298+ { }, protos, * this );
285299}
286300
287301// / We print stats in the destructor, which should get executed at the end of
0 commit comments