@@ -4149,6 +4149,36 @@ struct AsyncHandlerDesc {
41494149 return params ();
41504150 }
41514151
4152+ // / Get the type of the error that will be thrown by the \c async method or \c
4153+ // / None if the completion handler doesn't accept an error parameter.
4154+ // / This may be more specialized than the generic 'Error' type if the
4155+ // / completion handler of the converted function takes a more specialized
4156+ // / error type.
4157+ Optional<swift::Type> getErrorType () const {
4158+ if (HasError) {
4159+ switch (Type) {
4160+ case HandlerType::INVALID:
4161+ return None;
4162+ case HandlerType::PARAMS:
4163+ // The last parameter of the completion handler is the error param
4164+ return params ().back ().getPlainType ()->lookThroughSingleOptionalType ();
4165+ case HandlerType::RESULT:
4166+ assert (
4167+ params ().size () == 1 &&
4168+ " Result handler should have the Result type as the only parameter" );
4169+ auto ResultType =
4170+ params ().back ().getPlainType ()->getAs <BoundGenericType>();
4171+ auto GenericArgs = ResultType->getGenericArgs ();
4172+ assert (GenericArgs.size () == 2 && " Result should have two params" );
4173+ // The second (last) generic parameter of the Result type is the error
4174+ // type.
4175+ return GenericArgs.back ();
4176+ }
4177+ } else {
4178+ return None;
4179+ }
4180+ }
4181+
41524182 // / The `CallExpr` if the given node is a call to the `Handler`
41534183 CallExpr *getAsHandlerCall (ASTNode Node) const {
41544184 if (!isValid ())
@@ -5319,6 +5349,262 @@ class AsyncConverter : private SourceEntityWalker {
53195349 }
53205350 }
53215351};
5352+
5353+ // / When adding an async alternative method for the function declaration \c FD,
5354+ // / this class tries to create a function body for the legacy function (the one
5355+ // / with a completion handler), which calls the newly converted async function.
5356+ // / There are certain situations in which we fail to create such a body, e.g.
5357+ // / if the completion handler has the signature `(String, Error?) -> Void` in
5358+ // / which case we can't synthesize the result of type \c String in the error
5359+ // / case.
5360+ class LegacyAlternativeBodyCreator {
5361+ // / The old function declaration for which an async alternative has been added
5362+ // / and whose body shall be rewritten to call the newly added async
5363+ // / alternative.
5364+ FuncDecl *FD;
5365+
5366+ // / The description of the completion handler in the old function declaration.
5367+ AsyncHandlerDesc HandlerDesc;
5368+
5369+ std::string Buffer;
5370+ llvm::raw_string_ostream OS;
5371+
5372+ // / Adds the call to the refactored 'async' method without the 'await'
5373+ // / keyword to the output stream.
5374+ void addCallToAsyncMethod () {
5375+ OS << FD->getBaseName () << " (" ;
5376+ bool FirstParam = true ;
5377+ for (auto Param : *FD->getParameters ()) {
5378+ if (Param == HandlerDesc.Handler ) {
5379+ // / We don't need to pass the completion handler to the async method.
5380+ continue ;
5381+ }
5382+ if (!FirstParam) {
5383+ OS << " , " ;
5384+ } else {
5385+ FirstParam = false ;
5386+ }
5387+ if (!Param->getArgumentName ().empty ()) {
5388+ OS << Param->getArgumentName () << " : " ;
5389+ }
5390+ OS << Param->getParameterName ();
5391+ }
5392+ OS << " )" ;
5393+ }
5394+
5395+ // / If the returned error type is more specialized than \c Error, adds an
5396+ // / 'as! CustomError' cast to the more specialized error type to the output
5397+ // / stream.
5398+ void addCastToCustomErrorTypeIfNecessary () {
5399+ auto ErrorType = *HandlerDesc.getErrorType ();
5400+ if (ErrorType->getCanonicalType () !=
5401+ FD->getASTContext ().getExceptionType ()) {
5402+ OS << " as! " ;
5403+ ErrorType->lookThroughSingleOptionalType ()->print (OS);
5404+ }
5405+ }
5406+
5407+ // / Adds the \c Index -th parameter to the completion handler.
5408+ // / If \p HasResult is \c true, it is assumed that a variable named 'result'
5409+ // / contains the result returned from the async alternative. If the callback
5410+ // / also takes an error parameter, \c nil passed to the completion handler for
5411+ // / the error.
5412+ // / If \p HasResult is \c false, it is a assumed that a variable named 'error'
5413+ // / contains the error thrown from the async method and 'nil' will be passed
5414+ // / to the completion handler for all result parameters.
5415+ void addCompletionHandlerArgument (size_t Index, bool HasResult) {
5416+ if (HandlerDesc.HasError && Index == HandlerDesc.params ().size () - 1 ) {
5417+ // The error parameter is the last argument of the completion handler.
5418+ if (!HasResult) {
5419+ OS << " error" ;
5420+ addCastToCustomErrorTypeIfNecessary ();
5421+ } else {
5422+ OS << " nil" ;
5423+ }
5424+ } else {
5425+ if (!HasResult) {
5426+ OS << " nil" ;
5427+ } else if (HandlerDesc
5428+ .getSuccessParamAsyncReturnType (
5429+ HandlerDesc.params ()[Index].getPlainType ())
5430+ ->isVoid ()) {
5431+ // Void return types are not returned by the async function, synthesize
5432+ // a Void instance.
5433+ OS << " ()" ;
5434+ } else if (HandlerDesc.getSuccessParams ().size () > 1 ) {
5435+ // If the async method returns a tuple, we need to pass its elements to
5436+ // the completion handler separately. For example:
5437+ //
5438+ // func foo() async -> (String, Int) {}
5439+ //
5440+ // causes the following legacy body to be created:
5441+ //
5442+ // func foo(completion: (String, Int) -> Void) {
5443+ // async {
5444+ // let result = await foo()
5445+ // completion(result.0, result.1)
5446+ // }
5447+ // }
5448+ OS << " result." << Index;
5449+ } else {
5450+ OS << " result" ;
5451+ }
5452+ }
5453+ }
5454+
5455+ // / Adds the call to the completion handler. See \c
5456+ // / getCompletionHandlerArgument for how the arguments are synthesized if the
5457+ // / completion handler takes arguments, not a \c Result type.
5458+ void addCallToCompletionHandler (bool HasResult) {
5459+ OS << HandlerDesc.Handler ->getParameterName () << " (" ;
5460+
5461+ // Construct arguments to pass to the completion handler
5462+ switch (HandlerDesc.Type ) {
5463+ case HandlerType::INVALID:
5464+ llvm_unreachable (" Cannot be rewritten" );
5465+ break ;
5466+ case HandlerType::PARAMS: {
5467+ for (size_t I = 0 ; I < HandlerDesc.params ().size (); ++I) {
5468+ if (I > 0 ) {
5469+ OS << " , " ;
5470+ }
5471+ addCompletionHandlerArgument (I, HasResult);
5472+ }
5473+ break ;
5474+ }
5475+ case HandlerType::RESULT: {
5476+ if (HasResult) {
5477+ OS << " .success(result)" ;
5478+ } else {
5479+ OS << " .failure(error" ;
5480+ addCastToCustomErrorTypeIfNecessary ();
5481+ OS << " )" ;
5482+ }
5483+ break ;
5484+ }
5485+ }
5486+ OS << " )" ; // Close the call to the completion handler
5487+ }
5488+
5489+ // / Adds the result type of the converted async function.
5490+ void addAsyncFuncReturnType () {
5491+ SmallVector<Type, 2 > Scratch;
5492+ auto ReturnTypes = HandlerDesc.getAsyncReturnTypes (Scratch);
5493+ if (ReturnTypes.size () > 1 ) {
5494+ OS << " (" ;
5495+ }
5496+
5497+ llvm::interleave (
5498+ ReturnTypes, [&](Type Ty) { Ty->print (OS); }, [&]() { OS << " , " ; });
5499+
5500+ if (ReturnTypes.size () > 1 ) {
5501+ OS << " )" ;
5502+ }
5503+ }
5504+
5505+ // / If the async alternative function is generic, adds the type annotation
5506+ // / to the 'return' variable in the legacy function so that the generic
5507+ // / parameters of the legacy function are passed to the generic function.
5508+ // / For example for
5509+ // / \code
5510+ // / func foo<GenericParam>() async -> GenericParam {}
5511+ // / \endcode
5512+ // / we generate
5513+ // / \code
5514+ // / func foo<GenericParam>(completion: (T) -> Void) {
5515+ // / async {
5516+ // / let result: GenericParam = await foo()
5517+ // / <------------>
5518+ // / completion(result)
5519+ // / }
5520+ // / }
5521+ // / \endcode
5522+ // / This function adds the range marked by \c <----->
5523+ void addResultTypeAnnotationIfNecessary () {
5524+ if (FD->isGeneric ()) {
5525+ OS << " : " ;
5526+ addAsyncFuncReturnType ();
5527+ }
5528+ }
5529+
5530+ public:
5531+ LegacyAlternativeBodyCreator (FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5532+ : FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5533+
5534+ bool canRewriteLegacyBody () {
5535+ if (FD == nullptr || FD->getBody () == nullptr ) {
5536+ return false ;
5537+ }
5538+ if (FD->hasThrows ()) {
5539+ assert (!HandlerDesc.isValid () && " We shouldn't have found a handler desc "
5540+ " if the original function throws" );
5541+ return false ;
5542+ }
5543+ switch (HandlerDesc.Type ) {
5544+ case HandlerType::INVALID:
5545+ return false ;
5546+ case HandlerType::PARAMS: {
5547+ if (HandlerDesc.HasError ) {
5548+ // The non-error parameters must be optional so that we can set them to
5549+ // nil in the error case.
5550+ // The error parameter must be optional so we can set it to nil in the
5551+ // success case.
5552+ // Otherwise we can't synthesize the values to return for these
5553+ // parameters.
5554+ return llvm::all_of (HandlerDesc.params (),
5555+ [](AnyFunctionType::Param Param) -> bool {
5556+ return Param.getPlainType ()->isOptional ();
5557+ });
5558+ } else {
5559+ return true ;
5560+ }
5561+ }
5562+ case HandlerType::RESULT:
5563+ return true ;
5564+ }
5565+ }
5566+
5567+ std::string create () {
5568+ assert (Buffer.empty () &&
5569+ " LegacyAlternativeBodyCreator can only be used once" );
5570+ assert (canRewriteLegacyBody () &&
5571+ " Cannot create a legacy body if the body can't be rewritten" );
5572+ OS << " {\n " ; // start function body
5573+ OS << " async {\n " ;
5574+ if (HandlerDesc.HasError ) {
5575+ OS << " do {\n " ;
5576+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5577+ OS << " let result" ;
5578+ addResultTypeAnnotationIfNecessary ();
5579+ OS << " = " ;
5580+ }
5581+ OS << " try await " ;
5582+ addCallToAsyncMethod ();
5583+ OS << " \n " ;
5584+ addCallToCompletionHandler (/* HasResult=*/ true );
5585+ OS << " \n "
5586+ << " } catch {\n " ;
5587+ addCallToCompletionHandler (/* HasResult=*/ false );
5588+ OS << " \n "
5589+ << " }\n " ; // end catch
5590+ } else {
5591+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5592+ OS << " let result" ;
5593+ addResultTypeAnnotationIfNecessary ();
5594+ OS << " = " ;
5595+ }
5596+ OS << " await " ;
5597+ addCallToAsyncMethod ();
5598+ OS << " \n " ;
5599+ addCallToCompletionHandler (/* HasResult=*/ true );
5600+ OS << " \n " ;
5601+ }
5602+ OS << " }\n " ; // end 'async'
5603+ OS << " }\n " ; // end function body
5604+ return Buffer;
5605+ }
5606+ };
5607+
53225608} // namespace asyncrefactorings
53235609
53245610bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -5425,6 +5711,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
54255711 EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
54265712 " @available(*, deprecated, message: \" Prefer async "
54275713 " alternative instead\" )\n " );
5714+ LegacyAlternativeBodyCreator LegacyBody (FD, HandlerDesc);
5715+ if (LegacyBody.canRewriteLegacyBody ()) {
5716+ EditConsumer.accept (SM,
5717+ Lexer::getCharSourceRangeFromSourceRange (
5718+ SM, FD->getBody ()->getSourceRange ()),
5719+ LegacyBody.create ());
5720+ }
54285721 Converter.insertAfter (FD, EditConsumer);
54295722
54305723 return false ;
0 commit comments