@@ -121,6 +121,319 @@ static FuncDecl *deriveDistributedActor_resolve(DerivedConformance &derived) {
121121 return factoryDecl;
122122}
123123
124+ /* *****************************************************************************/
125+ /* ************** INVOKE HANDLER ON-RETURN FUNCTION ****************************/
126+ /* *****************************************************************************/
127+
128+ namespace {
129+ struct DoInvokeOnReturnContext {
130+ ParamDecl *handlerParam;
131+ ParamDecl *resultBufferParam;
132+ };
133+ } // namespace
134+
135+ static std::pair<BraceStmt *, bool >
136+ deriveBodyDistributed_doInvokeOnReturn (AbstractFunctionDecl *afd, void *arg) {
137+ auto &C = afd->getASTContext ();
138+ auto *context = static_cast <DoInvokeOnReturnContext *>(arg);
139+
140+ // mock locations, we're a thunk and don't really need detailed locations
141+ const SourceLoc sloc = SourceLoc ();
142+ const DeclNameLoc dloc = DeclNameLoc ();
143+ bool implicit = true ;
144+
145+ auto returnTypeParam = afd->getParameters ()->get (0 );
146+ SmallVector<ASTNode, 8 > stmts;
147+
148+ VarDecl *resultVar =
149+ new (C) VarDecl (/* isStatic=*/ false , VarDecl::Introducer::Let, sloc,
150+ C.getIdentifier (" result" ), afd);
151+ {
152+ auto resultLoadCall = CallExpr::createImplicit (
153+ C,
154+ UnresolvedDotExpr::createImplicit (
155+ C,
156+ /* base=*/
157+ new (C) DeclRefExpr (ConcreteDeclRef (context->resultBufferParam ),
158+ dloc, implicit),
159+ /* baseName=*/ DeclBaseName (C.getIdentifier (" load" )),
160+ /* argLabels=*/
161+ {C.getIdentifier (" fromByteOffset" ), C.getIdentifier (" as" )}),
162+ ArgumentList::createImplicit (
163+ C, {Argument (sloc, C.getIdentifier (" as" ),
164+ new (C) DeclRefExpr (ConcreteDeclRef (returnTypeParam),
165+ dloc, implicit))}));
166+
167+ auto resultPattern = NamedPattern::createImplicit (C, resultVar);
168+ auto resultPB = PatternBindingDecl::createImplicit (
169+ C, swift::StaticSpellingKind::None, resultPattern,
170+ /* expr=*/ resultLoadCall, afd);
171+
172+ stmts.push_back (resultPB);
173+ stmts.push_back (resultVar);
174+ }
175+
176+ // call the ad-hoc `handler.onReturn`
177+ {
178+ // Find the ad-hoc requirement ensured function on the concrete handler:
179+ auto onReturnFunc = C.getOnReturnOnDistributedTargetInvocationResultHandler (
180+ context->handlerParam ->getInterfaceType ()->getAnyNominal ());
181+ assert (onReturnFunc && " did not find ad-hoc requirement witness!" );
182+
183+ Expr *callExpr = CallExpr::createImplicit (
184+ C,
185+ UnresolvedDotExpr::createImplicit (
186+ C,
187+ /* base=*/
188+ new (C) DeclRefExpr (ConcreteDeclRef (context->handlerParam ), dloc,
189+ implicit),
190+ /* baseName=*/ onReturnFunc->getBaseName (),
191+ /* paramList=*/ onReturnFunc->getParameters ()),
192+ ArgumentList::forImplicitCallTo (
193+ DeclNameRef (onReturnFunc->getName ()),
194+ {new (C) DeclRefExpr (ConcreteDeclRef (resultVar), dloc, implicit)},
195+ C));
196+ callExpr = TryExpr::createImplicit (C, sloc, callExpr);
197+ callExpr = AwaitExpr::createImplicit (C, sloc, callExpr);
198+
199+ stmts.push_back (callExpr);
200+ }
201+
202+ auto body = BraceStmt::create (C, sloc, {stmts}, sloc, implicit);
203+ return {body, /* isTypeChecked=*/ false };
204+ }
205+
206+ // Create local function:
207+ // func invokeOnReturn<R: Self.SerializationRequirement>(
208+ // _ returnType: R.Type
209+ // ) async throws {
210+ // let value = resultBuffer.load(as: returnType)
211+ // try await handler.onReturn(value: value)
212+ // }
213+ static FuncDecl* createLocalFunc_doInvokeOnReturn (
214+ ASTContext& C, FuncDecl* parentFunc,
215+ NominalTypeDecl* systemNominal,
216+ ParamDecl* handlerParam,
217+ ParamDecl* resultBufParam) {
218+ auto DC = parentFunc;
219+ auto DAS = C.getDistributedActorSystemDecl ();
220+ auto doInvokeLocalFuncIdent = C.getIdentifier (" doInvokeOnReturn" );
221+
222+ // mock locations, we're a synthesized func and don't need real locations
223+ const SourceLoc sloc = SourceLoc ();
224+
225+ // <R: Self.SerializationRequirement>
226+ // We create the generic param at invalid depth, which means it'll be filled
227+ // by semantic analysis.
228+ auto *resultGenericParamDecl = GenericTypeParamDecl::createImplicit (
229+ parentFunc, C.getIdentifier (" R" ), /* depth*/ 0 , /* index*/ 0 );
230+ GenericParamList *doInvokeGenericParamList =
231+ GenericParamList::create (C, sloc, {resultGenericParamDecl}, sloc);
232+
233+ auto returnTypeIdent = C.getIdentifier (" returnType" );
234+ auto resultTyParamDecl =
235+ ParamDecl::createImplicit (C,
236+ /* argument=*/ returnTypeIdent,
237+ /* parameter=*/ returnTypeIdent,
238+ resultGenericParamDecl->getInterfaceType (), DC);
239+ ParameterList *doInvokeParamsList =
240+ ParameterList::create (C, {resultTyParamDecl});
241+
242+ SmallVector<Requirement, 2 > requirements;
243+ for (auto p : getDistributedSerializationRequirementProtocols (systemNominal, DAS)) {
244+ auto requirement =
245+ Requirement (RequirementKind::Conformance,
246+ resultGenericParamDecl->getDeclaredInterfaceType (),
247+ p->getDeclaredInterfaceType ());
248+ requirements.push_back (requirement);
249+ }
250+ GenericSignature doInvokeGenSig =
251+ buildGenericSignature (C, parentFunc->getGenericSignature (),
252+ {resultGenericParamDecl->getDeclaredInterfaceType ()
253+ ->castTo <GenericTypeParamType>()},
254+ std::move (requirements),
255+ /* allowInverses=*/ true );
256+
257+ FuncDecl *doInvokeOnReturnFunc = FuncDecl::createImplicit (
258+ C, swift::StaticSpellingKind::None,
259+ DeclName (C, doInvokeLocalFuncIdent, doInvokeParamsList),
260+ sloc,
261+ /* async=*/ true ,
262+ /* throws=*/ true ,
263+ /* ThrownType=*/ Type (),
264+ doInvokeGenericParamList, doInvokeParamsList,
265+ /* returnType=*/ C.TheEmptyTupleType , parentFunc);
266+ doInvokeOnReturnFunc->setImplicit ();
267+ doInvokeOnReturnFunc->setSynthesized ();
268+ doInvokeOnReturnFunc->setGenericSignature (doInvokeGenSig);
269+
270+ auto *doInvokeContext = C.Allocate <DoInvokeOnReturnContext>();
271+ doInvokeContext->handlerParam = handlerParam;
272+ doInvokeContext->resultBufferParam = resultBufParam;
273+ doInvokeOnReturnFunc->setBodySynthesizer (
274+ deriveBodyDistributed_doInvokeOnReturn, doInvokeContext);
275+
276+ return doInvokeOnReturnFunc;
277+ }
278+
279+ static std::pair<BraceStmt *, bool >
280+ deriveBodyDistributed_invokeHandlerOnReturn (AbstractFunctionDecl *afd,
281+ void *context) {
282+ auto implicit = true ;
283+ ASTContext &C = afd->getASTContext ();
284+ auto DC = afd->getDeclContext ();
285+ auto DAS = C.getDistributedActorSystemDecl ();
286+
287+ // mock locations, we're a thunk and don't really need detailed locations
288+ const SourceLoc sloc = SourceLoc ();
289+ const DeclNameLoc dloc = DeclNameLoc ();
290+
291+ NominalTypeDecl *nominal = dyn_cast<NominalTypeDecl>(DC);
292+ assert (nominal);
293+
294+ auto func = dyn_cast<FuncDecl>(afd);
295+ assert (func);
296+
297+ // === parameters
298+ auto params = func->getParameters ();
299+ assert (params->size () == 3 );
300+ auto handlerParam = params->get (0 );
301+ auto resultBufParam = params->get (1 );
302+ auto metatypeParam = params->get (2 );
303+
304+ auto serializationRequirementTypeTy =
305+ getDistributedSerializationRequirementType (nominal, DAS);
306+
307+ auto serializationRequirementMetaTypeTy =
308+ ExistentialMetatypeType::get (serializationRequirementTypeTy);
309+
310+ // Statements
311+ SmallVector<ASTNode, 8 > stmts;
312+
313+ // --- `let m = metatype as! SerializationRequirement.Type`
314+ VarDecl *metatypeVar =
315+ new (C) VarDecl (/* isStatic=*/ false , VarDecl::Introducer::Let, sloc,
316+ C.getIdentifier (" m" ), func);
317+ {
318+ metatypeVar->setImplicit ();
319+ metatypeVar->setSynthesized ();
320+
321+ // metatype as! <<concrete SerializationRequirement.Type>>
322+ auto metatypeRef =
323+ new (C) DeclRefExpr (ConcreteDeclRef (metatypeParam), dloc, implicit);
324+ auto metatypeSRCastExpr = ForcedCheckedCastExpr::createImplicit (
325+ C, metatypeRef, serializationRequirementMetaTypeTy);
326+
327+ auto metatypePattern = NamedPattern::createImplicit (C, metatypeVar);
328+ auto metatypePB = PatternBindingDecl::createImplicit (
329+ C, swift::StaticSpellingKind::None, metatypePattern,
330+ /* expr=*/ metatypeSRCastExpr, func);
331+
332+ stmts.push_back (metatypePB);
333+ stmts.push_back (metatypeVar);
334+ }
335+
336+ // --- Declare the local function `doInvokeOnReturn`...
337+ FuncDecl *doInvokeOnReturnFunc = createLocalFunc_doInvokeOnReturn (
338+ C, func,
339+ nominal, handlerParam, resultBufParam);
340+ stmts.push_back (doInvokeOnReturnFunc);
341+
342+ // --- try await _openExistential(metatypeVar, do: <<doInvokeLocalFunc>>)
343+ {
344+ auto openExistentialBaseIdent = C.getIdentifier (" _openExistential" );
345+ auto doIdent = C.getIdentifier (" do" );
346+
347+ auto openExArgs = ArgumentList::createImplicit (
348+ C, {
349+ Argument (sloc, Identifier (),
350+ new (C) DeclRefExpr (ConcreteDeclRef (metatypeVar), dloc,
351+ implicit)),
352+ Argument (sloc, doIdent,
353+ new (C) DeclRefExpr (ConcreteDeclRef (doInvokeOnReturnFunc),
354+ dloc, implicit)),
355+ });
356+ Expr *tryAwaitDoOpenExistential =
357+ CallExpr::createImplicit (C,
358+ UnresolvedDeclRefExpr::createImplicit (
359+ C, openExistentialBaseIdent),
360+ openExArgs);
361+
362+ tryAwaitDoOpenExistential =
363+ AwaitExpr::createImplicit (C, sloc, tryAwaitDoOpenExistential);
364+ tryAwaitDoOpenExistential =
365+ TryExpr::createImplicit (C, sloc, tryAwaitDoOpenExistential);
366+
367+ stmts.push_back (tryAwaitDoOpenExistential);
368+ }
369+
370+ auto body = BraceStmt::create (C, sloc, {stmts}, sloc, implicit);
371+ return {body, /* isTypeChecked=*/ false };
372+ }
373+
374+ // / Synthesizes the
375+ // /
376+ // / \verbatim
377+ // / static func invokeHandlerOnReturn(
378+ // // handler: ResultHandler,
379+ // // resultBuffer: UnsafeRawPointer,
380+ // // metatype _metatype: Any.Type
381+ // // ) async throws
382+ // / \endverbatim
383+ static FuncDecl *deriveDistributedActorSystem_invokeHandlerOnReturn (
384+ DerivedConformance &derived) {
385+ auto system = derived.Nominal ;
386+ auto &C = system->getASTContext ();
387+
388+ // auto serializationRequirementType = getDistributedActorSystemType(decl);
389+ auto resultHandlerType = getDistributedActorSystemResultHandlerType (system);
390+ auto unsafeRawPointerType = C.getUnsafeRawPointerType ();
391+ auto anyTypeType = ExistentialMetatypeType::get (C.TheAnyType ); // Any.Type
392+
393+ // auto serializationRequirementType =
394+ // getDistributedSerializationRequirementType(system, DAS);
395+
396+ // params:
397+ // - handler: Self.ResultHandler
398+ // - resultBuffer:
399+ // - metatype _metatype: Any.Type
400+ auto *params = ParameterList::create (
401+ C,
402+ /* LParenLoc=*/ SourceLoc (),
403+ /* params=*/
404+ {
405+ ParamDecl::createImplicit (
406+ C, C.Id_handler , C.Id_handler ,
407+ system->mapTypeIntoContext (resultHandlerType), system),
408+ ParamDecl::createImplicit (
409+ C, C.Id_resultBuffer , C.Id_resultBuffer ,
410+ unsafeRawPointerType, system),
411+ ParamDecl::createImplicit (
412+ C, C.Id_metatype , C.Id_metatype ,
413+ anyTypeType, system)
414+ },
415+ /* RParenLoc=*/ SourceLoc ());
416+
417+ // Func name: invokeHandlerOnReturn(handler:resultBuffer:metatype)
418+ DeclName name (C, C.Id_invokeHandlerOnReturn , params);
419+
420+ // Expected type: (Self.ResultHandler, UnsafeRawPointer, any Any.Type) async
421+ // throws -> ()
422+ auto *funcDecl =
423+ FuncDecl::createImplicit (C, StaticSpellingKind::None, name, SourceLoc (),
424+ /* async=*/ true ,
425+ /* throws=*/ true ,
426+ /* ThrownType=*/ Type (),
427+ /* genericParams=*/ nullptr , params,
428+ /* returnType*/ TupleType::getEmpty (C), system);
429+ funcDecl->setSynthesized (true );
430+ funcDecl->copyFormalAccessFrom (system, /* sourceIsParentContext=*/ true );
431+ funcDecl->setBodySynthesizer (deriveBodyDistributed_invokeHandlerOnReturn);
432+
433+ derived.addMembersToConformanceContext ({funcDecl});
434+ return funcDecl;
435+ }
436+
124437/* *****************************************************************************/
125438/* ****************************** PROPERTIES ***********************************/
126439/* *****************************************************************************/
@@ -581,6 +894,14 @@ std::pair<Type, TypeDecl *> DerivedConformance::deriveDistributedActor(
581894
582895ValueDecl *
583896DerivedConformance::deriveDistributedActorSystem (ValueDecl *requirement) {
897+ if (auto func = dyn_cast<FuncDecl>(requirement)) {
898+ // just a simple name check is enough here,
899+ // if we are invoked here we know for sure it is for the "right" function
900+ if (func->getName ().getBaseName () == Context.Id_invokeHandlerOnReturn ) {
901+ return deriveDistributedActorSystem_invokeHandlerOnReturn (*this );
902+ }
903+ }
904+
584905 return nullptr ;
585906}
586907
0 commit comments