@@ -376,10 +376,18 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
376376
377377static bool checkDistributedTargetResultType (
378378 ModuleDecl *module , ValueDecl *valueDecl,
379- const llvm::SmallPtrSetImpl<ProtocolDecl *> &serializationRequirements,
379+ Type serializationRequirement,
380+ llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements,
380381 bool diagnose) {
381382 auto &C = valueDecl->getASTContext ();
382383
384+ if (serializationRequirement && serializationRequirement->hasError ()) {
385+ return false ;
386+ }
387+ if ((!serializationRequirement || serializationRequirement->hasError ()) && serializationRequirements.empty ()) {
388+ return false ; // error of the type would be diagnosed elsewhere
389+ }
390+
383391 Type resultType;
384392 if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
385393 resultType = func->mapTypeIntoContext (func->getResultInterfaceType ());
@@ -392,18 +400,27 @@ static bool checkDistributedTargetResultType(
392400 if (resultType->isVoid ())
393401 return false ;
394402
403+
404+ // Collect extra "SerializationRequirement: SomeProtocol" requirements
405+ if (serializationRequirement && !serializationRequirement->hasError ()) {
406+ auto srl = serializationRequirement->getExistentialLayout ();
407+ for (auto s: srl.getProtocols ()) {
408+ serializationRequirements.insert (s);
409+ }
410+ }
411+
395412 auto isCodableRequirement =
396413 checkDistributedSerializationRequirementIsExactlyCodable (
397- C, serializationRequirements );
414+ C, serializationRequirement );
398415
399- for (auto serializationReq : serializationRequirements) {
416+ for (auto serializationReq: serializationRequirements) {
400417 auto conformance =
401418 TypeChecker::conformsToProtocol (resultType, serializationReq, module );
402419 if (conformance.isInvalid ()) {
403420 if (diagnose) {
404421 llvm::StringRef conformanceToSuggest = isCodableRequirement ?
405- " Codable" : // Codable is a typealias, easier to diagnose like that
406- serializationReq->getNameStr ();
422+ " Codable" : // Codable is a typealias, easier to diagnose like that
423+ serializationReq->getNameStr ();
407424
408425 auto diag = valueDecl->diagnose (
409426 diag::distributed_actor_target_result_not_codable,
@@ -418,12 +435,12 @@ static bool checkDistributedTargetResultType(
418435 }
419436 }
420437 } // end if: diagnose
421-
438+
422439 return true ;
423440 }
424441 }
425442
426- return false ;
443+ return false ;
427444}
428445
429446bool swift::checkDistributedActorSystem (const NominalTypeDecl *system) {
@@ -487,66 +504,42 @@ bool CheckDistributedFunctionRequest::evaluate(
487504 }
488505
489506 auto &C = func->getASTContext ();
490- auto DC = func->getDeclContext ();
491507 auto module = func->getParentModule ();
492508
493509 // / If no distributed module is available, then no reason to even try checks.
494510 if (!C.getLoadedModule (C.Id_Distributed ))
495511 return true ;
496512
497- // === All parameters and the result type must conform
498- // SerializationRequirement
499513 llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements;
500- if (auto extension = dyn_cast<ExtensionDecl>(DC)) {
501- serializationRequirements = extractDistributedSerializationRequirements (
502- C, extension->getGenericRequirements ());
503- } else if (auto actor = dyn_cast<ClassDecl>(DC)) {
504- serializationRequirements = getDistributedSerializationRequirementProtocols (
505- getDistributedActorSystemType (actor)->getAnyNominal (),
506- C.getProtocol (KnownProtocolKind::DistributedActorSystem));
507- } else if (isa<ProtocolDecl>(DC)) {
508- if (auto seqReqTy =
509- getConcreteReplacementForMemberSerializationRequirement (func)) {
510- auto layout = seqReqTy->getExistentialLayout ();
511- for (auto req : layout.getProtocols ()) {
512- serializationRequirements.insert (req);
513- }
514- }
515-
516- // The distributed actor constrained protocol has no serialization requirements
517- // or actor system defined, so these will only be enforced, by implementations
518- // of DAs conforming to it, skip checks here.
519- if (serializationRequirements.empty ()) {
520- return false ;
521- }
522- } else {
523- llvm_unreachable (" Distributed function detected in type other than extension, "
524- " distributed actor, or protocol! This should not be possible "
525- " , please file a bug." );
526- }
527-
528- // If the requirement is exactly `Codable` we diagnose it ia bit nicer.
529- auto serializationRequirementIsCodable =
530- checkDistributedSerializationRequirementIsExactlyCodable (
531- C, serializationRequirements);
532-
533- for (auto param : *func->getParameters ()) {
534- // --- Check parameters for 'Codable' conformance
535- auto paramTy = func->mapTypeIntoContext (param->getInterfaceType ());
536-
537- for (auto req : serializationRequirements) {
538- if (TypeChecker::conformsToProtocol (paramTy, req, module ).isInvalid ()) {
539- auto diag = func->diagnose (
540- diag::distributed_actor_func_param_not_codable,
541- param->getArgumentName ().str (), param->getInterfaceType (),
542- func->getDescriptiveKind (),
543- serializationRequirementIsCodable ? " Codable"
544- : req->getNameStr ());
545-
546- if (auto paramNominalTy = paramTy->getAnyNominal ()) {
547- addCodableFixIt (paramNominalTy, diag);
548- } // else, no nominal type to suggest the fixit for, e.g. a closure
549- return true ;
514+ Type serializationReqType = getSerializationRequirementTypesForMember (func, serializationRequirements);
515+
516+ for (auto param: *func->getParameters ()) {
517+ // --- Check the parameter conforming to serialization requirements
518+ if (serializationReqType && !serializationReqType->hasError ()) {
519+ // If the requirement is exactly `Codable` we diagnose it ia bit nicer.
520+ auto serializationRequirementIsCodable =
521+ checkDistributedSerializationRequirementIsExactlyCodable (
522+ C, serializationReqType);
523+
524+ // --- Check parameters for 'SerializationRequirement' conformance
525+ auto paramTy = func->mapTypeIntoContext (param->getInterfaceType ());
526+
527+ auto srl = serializationReqType->getExistentialLayout ();
528+ for (auto req: srl.getProtocols ()) {
529+ if (TypeChecker::conformsToProtocol (paramTy, req, module ).isInvalid ()) {
530+ auto diag = func->diagnose (
531+ diag::distributed_actor_func_param_not_codable,
532+ param->getArgumentName ().str (), param->getInterfaceType (),
533+ func->getDescriptiveKind (),
534+ serializationRequirementIsCodable ? " Codable"
535+ : req->getNameStr ());
536+
537+ if (auto paramNominalTy = paramTy->getAnyNominal ()) {
538+ addCodableFixIt (paramNominalTy, diag);
539+ } // else, no nominal type to suggest the fixit for, e.g. a closure
540+
541+ return true ;
542+ }
550543 }
551544 }
552545
@@ -583,9 +576,10 @@ bool CheckDistributedFunctionRequest::evaluate(
583576 }
584577 }
585578
586- // --- Result type must be either void or a codable type
587- if (checkDistributedTargetResultType (module , func, serializationRequirements,
588- /* diagnose=*/ true )) {
579+ // --- Result type must be either void or a serialization requirement conforming type
580+ if (checkDistributedTargetResultType (
581+ module , func, serializationReqType, serializationRequirements,
582+ /* diagnose=*/ true )) {
589583 return true ;
590584 }
591585
@@ -639,8 +633,11 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
639633 systemDecl,
640634 C.getProtocol (KnownProtocolKind::DistributedActorSystem));
641635
636+ auto serializationRequirement =
637+ getSerializationRequirementTypesForMember (systemVar, serializationRequirements);
638+
642639 auto module = var->getModuleContext ();
643- if (checkDistributedTargetResultType (module , var, serializationRequirements, diagnose)) {
640+ if (checkDistributedTargetResultType (module , var, serializationRequirement, serializationRequirements, diagnose)) {
644641 return true ;
645642 }
646643
@@ -740,13 +737,14 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal
740737 (void )nominal->getDistributedActorIDProperty ();
741738}
742739
743- void TypeChecker::checkDistributedFunc (FuncDecl *func) {
740+ bool TypeChecker::checkDistributedFunc (FuncDecl *func) {
744741 if (!func->isDistributed ())
745- return ;
742+ return false ;
746743
747- swift::checkDistributedFunction (func);
744+ return swift::checkDistributedFunction (func);
748745}
749746
747+ // TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks
750748llvm::SmallPtrSet<ProtocolDecl *, 2 >
751749swift::getDistributedSerializationRequirementProtocols (
752750 NominalTypeDecl *nominal, ProtocolDecl *protocol) {
0 commit comments