@@ -581,6 +581,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
581581 auto &C = getASTContext ();
582582 auto module = getParentModule ();
583583
584+ auto func = dyn_cast<FuncDecl>(this );
585+ if (!func) {
586+ return false ;
587+ }
588+
584589 // === Check base name
585590 if (getBaseIdentifier () != C.Id_recordArgument ) {
586591 return false ;
@@ -614,6 +619,12 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
614619 return false ;
615620 }
616621
622+ // --- must be mutating, if it is defined in a struct
623+ if (isa<StructDecl>(getDeclContext ()) &&
624+ !func->isMutating ()) {
625+ return false ;
626+ }
627+
617628 // --- Check number of generic parameters
618629 auto genericParams = getGenericParams ();
619630 unsigned int expectedGenericParamNum = 1 ;
@@ -639,56 +650,60 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
639650 return false ;
640651 }
641652
642- // --- Check parameter: _ argument
643- auto argumentParam = params->get (0 );
644- if (!argumentParam->getArgumentName ().is (" " )) {
645- return false ;
646- }
647-
648- // === Check generic parameters in detail
649- // --- Check: Argument: SerializationRequirement
650653 GenericTypeParamDecl *ArgumentParam = genericParams->getParams ()[0 ];
651654
652- auto sig = getGenericSignature ();
653- auto requirements = sig.getRequirements ();
655+ // --- Check parameter: _ argument
656+ auto argumentParam = params->get (0 );
657+ if (!argumentParam->getArgumentName ().empty ()) {
658+ return false ;
659+ }
654660
655- if (requirements.size () != expectedRequirementsNum) {
656- return false ;
657- }
661+ auto argumentTy = argumentParam->getInterfaceType ();
662+ auto argumentInContextTy = mapTypeIntoContext (argumentTy);
663+ if (argumentInContextTy->getAnyNominal () == C.getRemoteCallArgumentDecl ()) {
664+ auto argGenericParams = argumentInContextTy->getStructOrBoundGenericStruct ()
665+ ->getGenericParams ()->getParams ();
666+ if (argGenericParams.size () != 1 ) {
667+ return false ;
668+ }
658669
659- // --- Check the expected requirements
660- // --- all the Argument requirements ---
661- // conforms_to: Argument Decodable
662- // conforms_to: Argument Encodable
663- // ...
670+ // the <Value> of the RemoteCallArgument<Value>
671+ auto remoteCallArgValueGenericTy =
672+ mapTypeIntoContext (argGenericParams[0 ]->getInterfaceType ())
673+ ->getDesugaredType ()
674+ ->getMetatypeInstanceType ();
675+ // expected (the <Value> from the recordArgument<Value>)
676+ auto expectedGenericParamTy = mapTypeIntoContext (
677+ ArgumentParam->getInterfaceType ()->getMetatypeInstanceType ());
678+
679+ if (!remoteCallArgValueGenericTy->isEqual (expectedGenericParamTy)) {
680+ return false ;
681+ }
682+ } else {
683+ return false ;
684+ }
664685
665- auto func = dyn_cast<FuncDecl>(this );
666- if (!func) {
667- return false ;
668- }
669686
670- auto resultType = func->mapTypeIntoContext (argumentParam->getInterfaceType ())
671- ->getDesugaredType ();
672- auto resultParamType = func->mapTypeIntoContext (
673- ArgumentParam->getInterfaceType ()->getMetatypeInstanceType ());
674- // The result of the function must be the `Res` generic argument.
675- if (!resultType->isEqual (resultParamType)) {
676- return false ;
677- }
687+ auto sig = getGenericSignature ();
688+ auto requirements = sig.getRequirements ();
678689
679- for (auto requirementProto : requirementProtos) {
680- auto conformance = module ->lookupConformance (resultType, requirementProto);
681- if (conformance.isInvalid ()) {
690+ if (requirements.size () != expectedRequirementsNum) {
682691 return false ;
683692 }
684- }
685693
686- // === Check result type: Void
687- if (!func->getResultInterfaceType ()->isVoid ()) {
688- return false ;
689- }
694+ // --- Check the expected requirements
695+ // --- all the Argument requirements ---
696+ // e.g.
697+ // conforms_to: Argument Decodable
698+ // conforms_to: Argument Encodable
699+ // ...
690700
691- return true ;
701+ // === Check result type: Void
702+ if (!func->getResultInterfaceType ()->isVoid ()) {
703+ return false ;
704+ }
705+
706+ return true ;
692707}
693708
694709bool
@@ -879,8 +894,8 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
879894 }
880895
881896 // --- Check parameter: _ errorType
882- auto argumentParam = params->get (0 );
883- if (!argumentParam ->getArgumentName ().is (" " )) {
897+ auto errorTypeParam = params->get (0 );
898+ if (!errorTypeParam ->getArgumentName ().is (" " )) {
884899 return false ;
885900 }
886901
@@ -1140,6 +1155,14 @@ NominalTypeDecl::getDistributedRemoteCallTargetInitFunction() const {
11401155 GetDistributedRemoteCallTargetInitFunctionRequest (mutableThis), nullptr );
11411156}
11421157
1158+ ConstructorDecl *
1159+ NominalTypeDecl::getDistributedRemoteCallArgumentInitFunction () const {
1160+ auto mutableThis = const_cast <NominalTypeDecl *>(this );
1161+ return evaluateOrDefault (
1162+ getASTContext ().evaluator ,
1163+ GetDistributedRemoteCallArgumentInitFunctionRequest (mutableThis), nullptr );
1164+ }
1165+
11431166AbstractFunctionDecl *ASTContext::getRemoteCallOnDistributedActorSystem (
11441167 NominalTypeDecl *actorOrSystem, bool isVoidReturn) const {
11451168 assert (actorOrSystem && " distributed actor (or system) decl must be provided" );
0 commit comments