@@ -28,6 +28,51 @@ getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
2828 return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
2929}
3030
31+ static ValueDecl *getEqualEqualOperator (NominalTypeDecl *decl) {
32+ auto id = decl->getASTContext ().Id_EqualsOperator ;
33+
34+ auto isValid = [&](ValueDecl *equalEqualOp) -> bool {
35+ auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp);
36+ if (!equalEqual || !equalEqual->hasParameterList ())
37+ return false ;
38+ auto params = equalEqual->getParameters ();
39+ if (params->size () != 2 )
40+ return false ;
41+ auto lhs = params->get (0 );
42+ auto rhs = params->get (1 );
43+ if (lhs->isInOut () || rhs->isInOut ())
44+ return false ;
45+ auto lhsTy = lhs->getType ();
46+ auto rhsTy = rhs->getType ();
47+ if (!lhsTy || !rhsTy)
48+ return false ;
49+ auto lhsNominal = lhsTy->getAnyNominal ();
50+ auto rhsNominal = rhsTy->getAnyNominal ();
51+ if (lhsNominal != rhsNominal || lhsNominal != decl)
52+ return false ;
53+ return true ;
54+ };
55+
56+ // First look for `func ==` declared as a member.
57+ auto memberResults = decl->lookupDirect (id);
58+ for (const auto &member : memberResults) {
59+ if (isValid (member))
60+ return member;
61+ }
62+
63+ // If no member `func ==` was found, look for out-of-class definitions in the
64+ // same module.
65+ auto module = decl->getModuleContext ();
66+ llvm::SmallVector<ValueDecl *> nonMemberResults;
67+ module ->lookupValue (id, NLKind::UnqualifiedLookup, nonMemberResults);
68+ for (const auto &nonMember : nonMemberResults) {
69+ if (isValid (nonMember))
70+ return nonMember;
71+ }
72+
73+ return nullptr ;
74+ }
75+
3176bool swift::isIterator (const clang::CXXRecordDecl *clangDecl) {
3277 return getIteratorCategoryDecl (clangDecl);
3378}
@@ -103,24 +148,8 @@ void swift::conformToCxxIteratorIfNeeded(
103148 return ;
104149
105150 // Check if present: `func ==`
106- // FIXME: this only detects `operator==` declared as a member.
107- auto equalEquals = decl->lookupDirect (ctx.Id_EqualsOperator );
108- if (equalEquals.empty ())
109- return ;
110- auto equalEqual = dyn_cast<FuncDecl>(equalEquals.front ());
111- if (!equalEqual || !equalEqual->hasParameterList ())
112- return ;
113- auto equalEqualParams = equalEqual->getParameters ();
114- if (equalEqualParams->size () != 2 )
115- return ;
116- auto equalEqualLHS = equalEqualParams->get (0 );
117- auto equalEqualRHS = equalEqualParams->get (1 );
118- if (equalEqualLHS->isInOut () || equalEqualRHS->isInOut ())
119- return ;
120- auto equalEqualLHSTy = equalEqualLHS->getType ();
121- auto equalEqualRHSTy = equalEqualRHS->getType ();
122- if (!equalEqualLHSTy || !equalEqualRHSTy ||
123- equalEqualLHSTy->getAnyNominal () != equalEqualRHSTy->getAnyNominal ())
151+ auto equalEqual = getEqualEqualOperator (decl);
152+ if (!equalEqual)
124153 return ;
125154
126155 impl.addSynthesizedTypealias (decl, ctx.getIdentifier (" Pointee" ),
0 commit comments