@@ -4149,6 +4149,36 @@ struct AsyncHandlerDesc {
41494149 return params ();
41504150 }
41514151
4152+ // / Get the type of the error that will be thrown by the \c async method or \c
4153+ // / None if the completion handler doesn't accept an error parameter.
4154+ // / This may be more specialized than the generic 'Error' type if the
4155+ // / completion handler of the converted function takes a more specialized
4156+ // / error type.
4157+ Optional<swift::Type> getErrorType () const {
4158+ if (HasError) {
4159+ switch (Type) {
4160+ case HandlerType::INVALID:
4161+ return None;
4162+ case HandlerType::PARAMS:
4163+ // The last parameter of the completion handler is the error param
4164+ return params ().back ().getPlainType ()->lookThroughSingleOptionalType ();
4165+ case HandlerType::RESULT:
4166+ assert (
4167+ params ().size () == 1 &&
4168+ " Result handler should have the Result type as the only parameter" );
4169+ auto ResultType =
4170+ params ().back ().getPlainType ()->getAs <BoundGenericType>();
4171+ auto GenericArgs = ResultType->getGenericArgs ();
4172+ assert (GenericArgs.size () == 2 && " Result should have two params" );
4173+ // The second (last) generic parameter of the Result type is the error
4174+ // type.
4175+ return GenericArgs.back ();
4176+ }
4177+ } else {
4178+ return None;
4179+ }
4180+ }
4181+
41524182 // / The `CallExpr` if the given node is a call to the `Handler`
41534183 CallExpr *getAsHandlerCall (ASTNode Node) const {
41544184 if (!isValid ())
@@ -5310,6 +5340,262 @@ class AsyncConverter : private SourceEntityWalker {
53105340 }
53115341 }
53125342};
5343+
5344+ // / When adding an async alternative method for the function declaration \c FD,
5345+ // / this class tries to create a function body for the legacy function (the one
5346+ // / with a completion handler), which calls the newly converted async function.
5347+ // / There are certain situations in which we fail to create such a body, e.g.
5348+ // / if the completion handler has the signature `(String, Error?) -> Void` in
5349+ // / which case we can't synthesize the result of type \c String in the error
5350+ // / case.
5351+ class LegacyAlternativeBodyCreator {
5352+ // / The old function declaration for which an async alternative has been added
5353+ // / and whose body shall be rewritten to call the newly added async
5354+ // / alternative.
5355+ FuncDecl *FD;
5356+
5357+ // / The description of the completion handler in the old function declaration.
5358+ AsyncHandlerDesc HandlerDesc;
5359+
5360+ std::string Buffer;
5361+ llvm::raw_string_ostream OS;
5362+
5363+ // / Adds the call to the refactored 'async' method without the 'await'
5364+ // / keyword to the output stream.
5365+ void addCallToAsyncMethod () {
5366+ OS << FD->getBaseName () << " (" ;
5367+ bool FirstParam = true ;
5368+ for (auto Param : *FD->getParameters ()) {
5369+ if (Param == HandlerDesc.Handler ) {
5370+ // / We don't need to pass the completion handler to the async method.
5371+ continue ;
5372+ }
5373+ if (!FirstParam) {
5374+ OS << " , " ;
5375+ } else {
5376+ FirstParam = false ;
5377+ }
5378+ if (!Param->getArgumentName ().empty ()) {
5379+ OS << Param->getArgumentName () << " : " ;
5380+ }
5381+ OS << Param->getParameterName ();
5382+ }
5383+ OS << " )" ;
5384+ }
5385+
5386+ // / If the returned error type is more specialized than \c Error, adds an
5387+ // / 'as! CustomError' cast to the more specialized error type to the output
5388+ // / stream.
5389+ void addCastToCustomErrorTypeIfNecessary () {
5390+ auto ErrorType = *HandlerDesc.getErrorType ();
5391+ if (ErrorType->getCanonicalType () !=
5392+ FD->getASTContext ().getExceptionType ()) {
5393+ OS << " as! " ;
5394+ ErrorType->lookThroughSingleOptionalType ()->print (OS);
5395+ }
5396+ }
5397+
5398+ // / Adds the \c Index -th parameter to the completion handler.
5399+ // / If \p HasResult is \c true, it is assumed that a variable named 'result'
5400+ // / contains the result returned from the async alternative. If the callback
5401+ // / also takes an error parameter, \c nil passed to the completion handler for
5402+ // / the error.
5403+ // / If \p HasResult is \c false, it is a assumed that a variable named 'error'
5404+ // / contains the error thrown from the async method and 'nil' will be passed
5405+ // / to the completion handler for all result parameters.
5406+ void addCompletionHandlerArgument (size_t Index, bool HasResult) {
5407+ if (HandlerDesc.HasError && Index == HandlerDesc.params ().size () - 1 ) {
5408+ // The error parameter is the last argument of the completion handler.
5409+ if (!HasResult) {
5410+ OS << " error" ;
5411+ addCastToCustomErrorTypeIfNecessary ();
5412+ } else {
5413+ OS << " nil" ;
5414+ }
5415+ } else {
5416+ if (!HasResult) {
5417+ OS << " nil" ;
5418+ } else if (HandlerDesc
5419+ .getSuccessParamAsyncReturnType (
5420+ HandlerDesc.params ()[Index].getPlainType ())
5421+ ->isVoid ()) {
5422+ // Void return types are not returned by the async function, synthesize
5423+ // a Void instance.
5424+ OS << " ()" ;
5425+ } else if (HandlerDesc.getSuccessParams ().size () > 1 ) {
5426+ // If the async method returns a tuple, we need to pass its elements to
5427+ // the completion handler separately. For example:
5428+ //
5429+ // func foo() async -> (String, Int) {}
5430+ //
5431+ // causes the following legacy body to be created:
5432+ //
5433+ // func foo(completion: (String, Int) -> Void) {
5434+ // async {
5435+ // let result = await foo()
5436+ // completion(result.0, result.1)
5437+ // }
5438+ // }
5439+ OS << " result." << Index;
5440+ } else {
5441+ OS << " result" ;
5442+ }
5443+ }
5444+ }
5445+
5446+ // / Adds the call to the completion handler. See \c
5447+ // / getCompletionHandlerArgument for how the arguments are synthesized if the
5448+ // / completion handler takes arguments, not a \c Result type.
5449+ void addCallToCompletionHandler (bool HasResult) {
5450+ OS << HandlerDesc.Handler ->getParameterName () << " (" ;
5451+
5452+ // Construct arguments to pass to the completion handler
5453+ switch (HandlerDesc.Type ) {
5454+ case HandlerType::INVALID:
5455+ llvm_unreachable (" Cannot be rewritten" );
5456+ break ;
5457+ case HandlerType::PARAMS: {
5458+ for (size_t I = 0 ; I < HandlerDesc.params ().size (); ++I) {
5459+ if (I > 0 ) {
5460+ OS << " , " ;
5461+ }
5462+ addCompletionHandlerArgument (I, HasResult);
5463+ }
5464+ break ;
5465+ }
5466+ case HandlerType::RESULT: {
5467+ if (HasResult) {
5468+ OS << " .success(result)" ;
5469+ } else {
5470+ OS << " .failure(error" ;
5471+ addCastToCustomErrorTypeIfNecessary ();
5472+ OS << " )" ;
5473+ }
5474+ break ;
5475+ }
5476+ }
5477+ OS << " )" ; // Close the call to the completion handler
5478+ }
5479+
5480+ // / Adds the result type of the converted async function.
5481+ void addAsyncFuncReturnType () {
5482+ SmallVector<Type, 2 > Scratch;
5483+ auto ReturnTypes = HandlerDesc.getAsyncReturnTypes (Scratch);
5484+ if (ReturnTypes.size () > 1 ) {
5485+ OS << " (" ;
5486+ }
5487+
5488+ llvm::interleave (
5489+ ReturnTypes, [&](Type Ty) { Ty->print (OS); }, [&]() { OS << " , " ; });
5490+
5491+ if (ReturnTypes.size () > 1 ) {
5492+ OS << " )" ;
5493+ }
5494+ }
5495+
5496+ // / If the async alternative function is generic, adds the type annotation
5497+ // / to the 'return' variable in the legacy function so that the generic
5498+ // / parameters of the legacy function are passed to the generic function.
5499+ // / For example for
5500+ // / \code
5501+ // / func foo<GenericParam>() async -> GenericParam {}
5502+ // / \endcode
5503+ // / we generate
5504+ // / \code
5505+ // / func foo<GenericParam>(completion: (T) -> Void) {
5506+ // / async {
5507+ // / let result: GenericParam = await foo()
5508+ // / <------------>
5509+ // / completion(result)
5510+ // / }
5511+ // / }
5512+ // / \endcode
5513+ // / This function adds the range marked by \c <----->
5514+ void addResultTypeAnnotationIfNecessary () {
5515+ if (FD->isGeneric ()) {
5516+ OS << " : " ;
5517+ addAsyncFuncReturnType ();
5518+ }
5519+ }
5520+
5521+ public:
5522+ LegacyAlternativeBodyCreator (FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5523+ : FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5524+
5525+ bool canRewriteLegacyBody () {
5526+ if (FD == nullptr || FD->getBody () == nullptr ) {
5527+ return false ;
5528+ }
5529+ if (FD->hasThrows ()) {
5530+ assert (!HandlerDesc.isValid () && " We shouldn't have found a handler desc "
5531+ " if the original function throws" );
5532+ return false ;
5533+ }
5534+ switch (HandlerDesc.Type ) {
5535+ case HandlerType::INVALID:
5536+ return false ;
5537+ case HandlerType::PARAMS: {
5538+ if (HandlerDesc.HasError ) {
5539+ // The non-error parameters must be optional so that we can set them to
5540+ // nil in the error case.
5541+ // The error parameter must be optional so we can set it to nil in the
5542+ // success case.
5543+ // Otherwise we can't synthesize the values to return for these
5544+ // parameters.
5545+ return llvm::all_of (HandlerDesc.params (),
5546+ [](AnyFunctionType::Param Param) -> bool {
5547+ return Param.getPlainType ()->isOptional ();
5548+ });
5549+ } else {
5550+ return true ;
5551+ }
5552+ }
5553+ case HandlerType::RESULT:
5554+ return true ;
5555+ }
5556+ }
5557+
5558+ std::string create () {
5559+ assert (Buffer.empty () &&
5560+ " LegacyAlternativeBodyCreator can only be used once" );
5561+ assert (canRewriteLegacyBody () &&
5562+ " Cannot create a legacy body if the body can't be rewritten" );
5563+ OS << " {\n " ; // start function body
5564+ OS << " async {\n " ;
5565+ if (HandlerDesc.HasError ) {
5566+ OS << " do {\n " ;
5567+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5568+ OS << " let result" ;
5569+ addResultTypeAnnotationIfNecessary ();
5570+ OS << " = " ;
5571+ }
5572+ OS << " try await " ;
5573+ addCallToAsyncMethod ();
5574+ OS << " \n " ;
5575+ addCallToCompletionHandler (/* HasResult=*/ true );
5576+ OS << " \n "
5577+ << " } catch {\n " ;
5578+ addCallToCompletionHandler (/* HasResult=*/ false );
5579+ OS << " \n "
5580+ << " }\n " ; // end catch
5581+ } else {
5582+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5583+ OS << " let result" ;
5584+ addResultTypeAnnotationIfNecessary ();
5585+ OS << " = " ;
5586+ }
5587+ OS << " await " ;
5588+ addCallToAsyncMethod ();
5589+ OS << " \n " ;
5590+ addCallToCompletionHandler (/* HasResult=*/ true );
5591+ OS << " \n " ;
5592+ }
5593+ OS << " }\n " ; // end 'async'
5594+ OS << " }\n " ; // end function body
5595+ return Buffer;
5596+ }
5597+ };
5598+
53135599} // namespace asyncrefactorings
53145600
53155601bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -5416,6 +5702,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
54165702 EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
54175703 " @available(*, deprecated, message: \" Prefer async "
54185704 " alternative instead\" )\n " );
5705+ LegacyAlternativeBodyCreator LegacyBody (FD, HandlerDesc);
5706+ if (LegacyBody.canRewriteLegacyBody ()) {
5707+ EditConsumer.accept (SM,
5708+ Lexer::getCharSourceRangeFromSourceRange (
5709+ SM, FD->getBody ()->getSourceRange ()),
5710+ LegacyBody.create ());
5711+ }
54195712 Converter.insertAfter (FD, EditConsumer);
54205713
54215714 return false ;
0 commit comments