@@ -156,7 +156,7 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
156156 return lookupOperator (decl, decl->getASTContext ().Id_EqualsOperator , isValid);
157157}
158158
159- static ValueDecl *getMinusOperator (NominalTypeDecl *decl) {
159+ static FuncDecl *getMinusOperator (NominalTypeDecl *decl) {
160160 auto binaryIntegerProto =
161161 decl->getASTContext ().getProtocol (KnownProtocolKind::BinaryInteger);
162162 auto module = decl->getModuleContext ();
@@ -188,11 +188,12 @@ static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
188188 return true ;
189189 };
190190
191- return lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ),
192- isValid);
191+ ValueDecl *result =
192+ lookupOperator (decl, decl->getASTContext ().getIdentifier (" -" ), isValid);
193+ return dyn_cast_or_null<FuncDecl>(result);
193194}
194195
195- static ValueDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
196+ static FuncDecl *getPlusEqualOperator (NominalTypeDecl *decl, Type distanceTy) {
196197 auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
197198 auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
198199 if (!plusEqual || !plusEqual->hasParameterList ())
@@ -219,14 +220,15 @@ static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
219220 return true ;
220221 };
221222
222- return lookupOperator (decl, decl->getASTContext ().getIdentifier (" +=" ),
223- isValid);
223+ ValueDecl *result =
224+ lookupOperator (decl, decl->getASTContext ().getIdentifier (" +=" ), isValid);
225+ return dyn_cast_or_null<FuncDecl>(result);
224226}
225227
226- static void instantiateTemplatedOperator (
227- ClangImporter::Implementation &impl,
228- const clang::ClassTemplateSpecializationDecl *classDecl,
229- clang::BinaryOperatorKind operatorKind) {
228+ static clang::FunctionDecl *
229+ instantiateTemplatedOperator ( ClangImporter::Implementation &impl,
230+ const clang::CXXRecordDecl *classDecl,
231+ clang::BinaryOperatorKind operatorKind) {
230232
231233 clang::ASTContext &clangCtx = impl.getClangASTContext ();
232234 clang::Sema &clangSema = impl.getClangSema ();
@@ -252,6 +254,7 @@ static void instantiateTemplatedOperator(
252254 if (auto clangCallee = best->Function ) {
253255 auto lookupTable = impl.findLookupTable (classDecl);
254256 addEntryToLookupTable (*lookupTable, clangCallee, impl.getNameImporter ());
257+ return clangCallee;
255258 }
256259 break ;
257260 }
@@ -260,6 +263,95 @@ static void instantiateTemplatedOperator(
260263 case clang::OR_Deleted:
261264 break ;
262265 }
266+
267+ return nullptr ;
268+ }
269+
270+ // / Warning: This function emits an error and stops compilation if the
271+ // / underlying operator function is unavailable in Swift for the current target
272+ // / (see `clang::Sema::DiagnoseAvailabilityOfDecl`).
273+ static bool synthesizeCXXOperator (ClangImporter::Implementation &impl,
274+ const clang::CXXRecordDecl *classDecl,
275+ clang::BinaryOperatorKind operatorKind,
276+ clang::QualType lhsTy, clang::QualType rhsTy,
277+ clang::QualType returnTy) {
278+ auto &clangCtx = impl.getClangASTContext ();
279+ auto &clangSema = impl.getClangSema ();
280+
281+ clang::OverloadedOperatorKind opKind =
282+ clang::BinaryOperator::getOverloadedOperator (operatorKind);
283+ const char *opSpelling = clang::getOperatorSpelling (opKind);
284+
285+ auto declName = clang::DeclarationName (&clangCtx.Idents .get (opSpelling));
286+
287+ // Determine the Clang decl context where the new operator function will be
288+ // created. We use the translation unit as the decl context of the new
289+ // operator, otherwise, the operator might get imported as a static member
290+ // function of a different type (e.g. an operator declared inside of a C++
291+ // namespace would get imported as a member function of a Swift enum), which
292+ // would make the operator un-discoverable to Swift name lookup.
293+ auto declContext =
294+ const_cast <clang::CXXRecordDecl *>(classDecl)->getDeclContext ();
295+ while (!declContext->isTranslationUnit ()) {
296+ declContext = declContext->getParent ();
297+ }
298+
299+ auto equalEqualTy = clangCtx.getFunctionType (
300+ returnTy, {lhsTy, rhsTy}, clang::FunctionProtoType::ExtProtoInfo ());
301+
302+ // Create a `bool operator==(T, T)` function.
303+ auto equalEqualDecl = clang::FunctionDecl::Create (
304+ clangCtx, declContext, clang::SourceLocation (), clang::SourceLocation (),
305+ declName, equalEqualTy, clangCtx.getTrivialTypeSourceInfo (returnTy),
306+ clang::StorageClass::SC_Static);
307+ equalEqualDecl->setImplicit ();
308+ equalEqualDecl->setImplicitlyInline ();
309+ // If this is a static member function of a class, it needs to be public.
310+ equalEqualDecl->setAccess (clang::AccessSpecifier::AS_public);
311+
312+ // Create the parameters of the function. They are not referenced from source
313+ // code, so they don't need to have a name.
314+ auto lhsParamId = nullptr ;
315+ auto lhsTyInfo = clangCtx.getTrivialTypeSourceInfo (lhsTy);
316+ auto lhsParamDecl = clang::ParmVarDecl::Create (
317+ clangCtx, equalEqualDecl, clang::SourceLocation (),
318+ clang::SourceLocation (), lhsParamId, lhsTy, lhsTyInfo,
319+ clang::StorageClass::SC_None, /* DefArg*/ nullptr );
320+ auto lhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
321+ clangCtx, lhsParamDecl, false , lhsTy, clang::ExprValueKind::VK_LValue,
322+ clang::SourceLocation ());
323+
324+ auto rhsParamId = nullptr ;
325+ auto rhsTyInfo = clangCtx.getTrivialTypeSourceInfo (rhsTy);
326+ auto rhsParamDecl = clang::ParmVarDecl::Create (
327+ clangCtx, equalEqualDecl, clang::SourceLocation (),
328+ clang::SourceLocation (), rhsParamId, rhsTy, rhsTyInfo,
329+ clang::StorageClass::SC_None, nullptr );
330+ auto rhsParamRefExpr = new (clangCtx) clang::DeclRefExpr (
331+ clangCtx, rhsParamDecl, false , rhsTy, clang::ExprValueKind::VK_LValue,
332+ clang::SourceLocation ());
333+
334+ equalEqualDecl->setParams ({lhsParamDecl, rhsParamDecl});
335+
336+ // Lookup the `operator==` function that will be called under the hood.
337+ clang::UnresolvedSet<16 > operators;
338+ // Note: calling `CreateOverloadedBinOp` emits an error if the looked up
339+ // function is unavailable for the current target.
340+ auto underlyingCallResult = clangSema.CreateOverloadedBinOp (
341+ clang::SourceLocation (), operatorKind, operators, lhsParamRefExpr,
342+ rhsParamRefExpr);
343+ if (!underlyingCallResult.isUsable ())
344+ return false ;
345+ auto underlyingCall = underlyingCallResult.get ();
346+
347+ auto equalEqualBody = clang::ReturnStmt::Create (
348+ clangCtx, clang::SourceLocation (), underlyingCall, nullptr );
349+ equalEqualDecl->setBody (equalEqualBody);
350+
351+ impl.synthesizedAndAlwaysVisibleDecls .insert (equalEqualDecl);
352+ auto lookupTable = impl.findLookupTable (classDecl);
353+ addEntryToLookupTable (*lookupTable, equalEqualDecl, impl.getNameImporter ());
354+ return true ;
263355}
264356
265357bool swift::isIterator (const clang::CXXRecordDecl *clangDecl) {
@@ -274,6 +366,7 @@ void swift::conformToCxxIteratorIfNeeded(
274366 assert (decl);
275367 assert (clangDecl);
276368 ASTContext &ctx = decl->getASTContext ();
369+ clang::ASTContext &clangCtx = clangDecl->getASTContext ();
277370
278371 if (!ctx.getProtocol (KnownProtocolKind::UnsafeCxxInputIterator))
279372 return ;
@@ -349,15 +442,28 @@ void swift::conformToCxxIteratorIfNeeded(
349442 if (!successorTy || successorTy->getAnyNominal () != decl)
350443 return ;
351444
352- // If this is a templated class, `operator==` might be templated as well.
353- // Try to instantiate it.
354- if (auto templateSpec =
355- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
356- instantiateTemplatedOperator (impl, templateSpec,
357- clang::BinaryOperatorKind::BO_EQ);
358- }
359445 // Check if present: `func ==`
360446 auto equalEqual = getEqualEqualOperator (decl);
447+ if (!equalEqual) {
448+ // If this class is inherited, `operator==` might be defined for a base
449+ // class. If this is a templated class, `operator==` might be templated as
450+ // well. Try to instantiate it.
451+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
452+ impl, clangDecl, clang::BinaryOperatorKind::BO_EQ);
453+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
454+ // If `operator==` was instantiated successfully, try to find `func ==`
455+ // again.
456+ equalEqual = getEqualEqualOperator (decl);
457+ if (!equalEqual) {
458+ // If `func ==` still can't be found, it might be defined for a base
459+ // class of the current class.
460+ auto paramTy = clangCtx.getRecordType (clangDecl);
461+ synthesizeCXXOperator (impl, clangDecl, clang::BinaryOperatorKind::BO_EQ,
462+ paramTy, paramTy, clangCtx.BoolTy );
463+ equalEqual = getEqualEqualOperator (decl);
464+ }
465+ }
466+ }
361467 if (!equalEqual)
362468 return ;
363469
@@ -371,18 +477,46 @@ void swift::conformToCxxIteratorIfNeeded(
371477
372478 // Try to conform to UnsafeCxxRandomAccessIterator if possible.
373479
374- if (auto templateSpec =
375- dyn_cast<clang::ClassTemplateSpecializationDecl>(clangDecl)) {
376- instantiateTemplatedOperator (impl, templateSpec,
377- clang::BinaryOperatorKind::BO_Sub);
480+ // Check if present: `func -`
481+ auto minus = getMinusOperator (decl);
482+ if (!minus) {
483+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
484+ impl, clangDecl, clang::BinaryOperatorKind::BO_Sub);
485+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
486+ minus = getMinusOperator (decl);
487+ if (!minus) {
488+ clang::QualType returnTy = instantiated->getReturnType ();
489+ auto paramTy = clangCtx.getRecordType (clangDecl);
490+ synthesizeCXXOperator (impl, clangDecl,
491+ clang::BinaryOperatorKind::BO_Sub, paramTy,
492+ paramTy, returnTy);
493+ minus = getMinusOperator (decl);
494+ }
495+ }
378496 }
379- auto minus = dyn_cast_or_null<FuncDecl>(getMinusOperator (decl));
380497 if (!minus)
381498 return ;
382499 auto distanceTy = minus->getResultInterfaceType ();
383500 // distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
384501
385- auto plusEqual = dyn_cast_or_null<FuncDecl>(getPlusEqualOperator (decl, distanceTy));
502+ auto plusEqual = getPlusEqualOperator (decl, distanceTy);
503+ if (!plusEqual) {
504+ clang::FunctionDecl *instantiated = instantiateTemplatedOperator (
505+ impl, clangDecl, clang::BinaryOperatorKind::BO_AddAssign);
506+ if (instantiated && !impl.isUnavailableInSwift (instantiated)) {
507+ plusEqual = getPlusEqualOperator (decl, distanceTy);
508+ if (!plusEqual) {
509+ clang::QualType returnTy = instantiated->getReturnType ();
510+ auto clangMinus = cast<clang::FunctionDecl>(minus->getClangDecl ());
511+ auto lhsTy = clangCtx.getRecordType (clangDecl);
512+ auto rhsTy = clangMinus->getReturnType ();
513+ synthesizeCXXOperator (impl, clangDecl,
514+ clang::BinaryOperatorKind::BO_AddAssign, lhsTy,
515+ rhsTy, returnTy);
516+ plusEqual = getPlusEqualOperator (decl, distanceTy);
517+ }
518+ }
519+ }
386520 if (!plusEqual)
387521 return ;
388522
0 commit comments