@@ -81,16 +81,22 @@ struct ArgumentDecoderInfo {
8181 // / The type of `decodeNextArgument` method.
8282 CanSILFunctionType MethodType;
8383
84- // Witness metadata for conformance to DistributedTargetInvocationDecoder
85- // protocol.
84+ // / Witness metadata for conformance to DistributedTargetInvocationDecoder
85+ // / protocol.
8686 WitnessMetadata Witness;
8787
88+ // / Indicates whether `decodeNextArgument` is referenced through
89+ // / a protocol witness thunk.
90+ bool UsesWitnessDispatch;
91+
8892 ArgumentDecoderInfo (llvm::Value *decoder, llvm::Value *decoderType,
8993 llvm::Value *decoderWitnessTable,
9094 FunctionPointer decodeNextArgumentPtr,
91- CanSILFunctionType decodeNextArgumentTy)
95+ CanSILFunctionType decodeNextArgumentTy,
96+ bool usesWitnessDispatch)
9297 : Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
93- MethodType (decodeNextArgumentTy) {
98+ MethodType (decodeNextArgumentTy),
99+ UsesWitnessDispatch(usesWitnessDispatch) {
94100 Witness.SelfMetadata = decoderType;
95101 Witness.SelfWitnessTable = decoderWitnessTable;
96102 }
@@ -101,6 +107,20 @@ struct ArgumentDecoderInfo {
101107 return const_cast <WitnessMetadata *>(&Witness);
102108 }
103109
110+ // / Protocol requirements associated with the generic
111+ // / parameter `Argument` of this decode method.
112+ GenericSignature::RequiredProtocols getProtocolRequirements () const {
113+ if (UsesWitnessDispatch)
114+ return {};
115+
116+ auto signature = MethodType->getInvocationGenericSignature ();
117+ auto genericParams = signature.getGenericParams ();
118+
119+ // func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
120+ assert (genericParams.size () == 1 );
121+ return signature->getRequiredProtocols (genericParams.front ());
122+ }
123+
104124 // / Form a callee to a decode method - `decodeNextArgument`.
105125 Callee getCallee () const ;
106126};
@@ -140,6 +160,10 @@ class DistributedAccessor {
140160 llvm::Value *argumentType, const SILParameterInfo ¶m,
141161 Explosion &arguments);
142162
163+ void lookupWitnessTables (llvm::Value *value,
164+ ArrayRef<ProtocolDecl *> protocols,
165+ Explosion &witnessTables);
166+
143167 // / Load witness table addresses (if any) from the given buffer
144168 // / into the given argument explosion.
145169 // /
@@ -385,13 +409,18 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
385409 // substitution Argument -> <argument metadata>
386410 decodeArgs.add (argumentType);
387411
412+ // Lookup witness tables for the requirement on the argument type.
413+ lookupWitnessTables (argumentType, decoder.getProtocolRequirements (),
414+ decodeArgs);
415+
388416 Address calleeErrorSlot;
389417 llvm::Value *decodeError = nullptr ;
390418
391419 emission->begin ();
392420 {
393421 emission->setArgs (decodeArgs, /* isOutlined=*/ false ,
394- /* witnessMetadata=*/ decoder.getWitnessMetadata ());
422+ decoder.UsesWitnessDispatch ? decoder.getWitnessMetadata ()
423+ : nullptr );
395424
396425 Explosion result;
397426 emission->emitToExplosion (result, /* isOutlined=*/ false );
@@ -492,6 +521,43 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
492521 }
493522}
494523
524+ void DistributedAccessor::lookupWitnessTables (
525+ llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
526+ Explosion &witnessTables) {
527+ if (protocols.empty ())
528+ return ;
529+
530+ auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer ();
531+
532+ for (auto *protocol : protocols) {
533+ if (!Lowering::TypeConverter::protocolRequiresWitnessTable (protocol))
534+ continue ;
535+
536+ auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor (protocol);
537+ auto *witnessTable =
538+ IGF.Builder .CreateCall (conformsToProtocol, {value, protocolDescriptor});
539+
540+ auto failBB = IGF.createBasicBlock (" missing-witness" );
541+ auto contBB = IGF.createBasicBlock (" " );
542+
543+ auto isNull = IGF.Builder .CreateICmpEQ (
544+ witnessTable, llvm::ConstantPointerNull::get (IGM.WitnessTablePtrTy ));
545+ IGF.Builder .CreateCondBr (isNull, failBB, contBB);
546+
547+ // This operation shouldn't fail because runtime should have checked that
548+ // a particular argument type conforms to `SerializationRequirement`
549+ // of the distributed actor the decoder is used for. If it does fail
550+ // then accessor should trap.
551+ {
552+ IGF.Builder .emitBlock (failBB);
553+ IGF.emitTrap (" missing witness table" , /* EmitUnreachable=*/ true );
554+ }
555+
556+ IGF.Builder .emitBlock (contBB);
557+ witnessTables.add (witnessTable);
558+ }
559+ }
560+
495561void DistributedAccessor::emitLoadOfWitnessTables (llvm::Value *witnessTables,
496562 llvm::Value *numTables,
497563 unsigned expectedWitnessTables,
@@ -731,21 +797,91 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
731797ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder (
732798 llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
733799 auto &C = IGM.Context ;
800+ auto *actor = getDistributedActorOf (Target);
801+ auto expansionContext = IGM.getMaximalTypeExpansionContext ();
802+
803+ auto *decodeFn = C.getDistributedActorArgumentDecodingMethod (actor);
804+
805+ // If distributed actor is generic over actor system, we have to
806+ // use witness to reference `decodeNextArgument`.
807+ if (!decodeFn) {
808+ auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl ();
809+ auto decodeNextArgRequirement =
810+ decoderProtocol->getSingleRequirement (C.Id_decodeNextArgument );
811+ assert (decodeNextArgRequirement);
812+ SILDeclRef decodeNextArgumentRef (decodeNextArgRequirement);
813+
814+ llvm::Constant *fnPtr =
815+ IGM.getAddrOfDispatchThunk (decodeNextArgumentRef, NotForDefinition);
816+ auto fnType = IGM.getSILTypes ().getConstantFunctionType (
817+ IGM.getMaximalTypeExpansionContext (), decodeNextArgumentRef);
818+
819+ auto sig = IGM.getSignature (fnType);
820+ auto fn = FunctionPointer::forDirect (fnType, fnPtr,
821+ /* secondaryValue=*/ nullptr , sig, true );
822+ return {decoder, decoderTy, witnessTable,
823+ fn, fnType, /* usesWitnessDispatch=*/ true };
824+ }
825+
826+ auto methodTy = IGM.getSILTypes ().getConstantFunctionType (
827+ expansionContext, SILDeclRef (decodeFn));
828+
829+ auto fpKind = FunctionPointerKind::defaultAsync ();
830+ auto signature = IGM.getSignature (methodTy, fpKind);
831+
832+ // If the decoder class is `final`, let's emit a direct reference.
833+ auto *decoderDecl = decodeFn->getDeclContext ()->getSelfNominalTypeDecl ();
734834
735- auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl ();
736- SILDeclRef decodeNextArgumentRef (
737- decoderProtocol->getSingleRequirement (C.Id_decodeNextArgument ));
835+ // If decoder is a class, need to load it first because generic parameter
836+ // is passed indirectly. This is good for structs and enums because
837+ // `decodeNextArgument` is a mutating method, but not for classes because
838+ // in that case heap object is mutated directly.
839+ bool usesDispatchThunk = false ;
738840
739- llvm::Constant *fnPtr =
740- IGM.getAddrOfDispatchThunk (decodeNextArgumentRef, NotForDefinition);
841+ if (auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
842+ auto selfTy = methodTy->getSelfParameter ().getSILStorageType (
843+ IGM.getSILModule (), methodTy, expansionContext);
741844
742- auto fnType = IGM.getSILTypes ().getConstantFunctionType (
743- IGM.getMaximalTypeExpansionContext (), decodeNextArgumentRef);
845+ auto &classTI = IGM.getTypeInfo (selfTy).as <ClassTypeInfo>();
846+ auto &classLayout = classTI.getClassLayout (IGM, selfTy,
847+ /* forBackwardDeployment=*/ false );
848+
849+ llvm::Value *typedDecoderPtr = IGF.Builder .CreateBitCast (
850+ decoder, classLayout.getType ()->getPointerTo ()->getPointerTo ());
851+
852+ Explosion instance;
853+
854+ classTI.loadAsTake (IGF,
855+ {typedDecoderPtr, classTI.getStorageType (),
856+ classTI.getBestKnownAlignment ()},
857+ instance);
858+
859+ decoder = instance.claimNext ();
860+
861+ // / When using library evolution functions have another "dispatch thunk"
862+ // / so we must use this instead of the decodeFn directly.
863+ usesDispatchThunk =
864+ getMethodDispatch (decodeFn) == swift::MethodDispatch::Class &&
865+ classDecl->hasResilientMetadata ();
866+ }
867+
868+ FunctionPointer methodPtr;
869+
870+ if (usesDispatchThunk) {
871+ auto fnPtr = IGM.getAddrOfDispatchThunk (SILDeclRef (decodeFn), NotForDefinition);
872+ methodPtr = FunctionPointer::createUnsigned (
873+ methodTy, fnPtr, signature, /* useSignature=*/ true );
874+ } else {
875+ SILFunction *decodeSILFn = IGM.getSILModule ().lookUpFunction (SILDeclRef (decodeFn));
876+ auto fnPtr = IGM.getAddrOfSILFunction (decodeSILFn, NotForDefinition,
877+ /* isDynamicallyReplaceable=*/ false );
878+ methodPtr = FunctionPointer::forDirect (
879+ classifyFunctionPointerKind (decodeSILFn), fnPtr,
880+ /* secondaryValue=*/ nullptr , signature);
881+ }
744882
745- auto sig = IGM.getSignature (fnType);
746- auto fn = FunctionPointer::forDirect (fnType, fnPtr,
747- /* secondaryValue=*/ nullptr , sig, true );
748- return {decoder, decoderTy, witnessTable, fn, fnType};
883+ return {decoder, decoderTy, witnessTable,
884+ methodPtr, methodTy, /* usesWitnessDispatch=*/ false };
749885}
750886
751887SILType DistributedAccessor::getResultType () const {
0 commit comments