@@ -81,6 +81,10 @@ struct ArgumentDecoderInfo {
8181 // / The type of `decodeNextArgument` method.
8282 CanSILFunctionType MethodType;
8383
84+ // / Protocol requirements associated with the generic
85+ // / parameter `Argument` of this decode method.
86+ GenericSignature::RequiredProtocols ProtocolRequirements;
87+
8488 // Witness metadata for conformance to DistributedTargetInvocationDecoder
8589 // protocol.
8690 WitnessMetadata Witness;
@@ -90,19 +94,31 @@ struct ArgumentDecoderInfo {
9094 FunctionPointer decodeNextArgumentPtr,
9195 CanSILFunctionType decodeNextArgumentTy)
9296 : Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
93- MethodType (decodeNextArgumentTy) {
97+ MethodType (decodeNextArgumentTy),
98+ ProtocolRequirements(findProtocolRequirements(decodeNextArgumentTy)) {
9499 Witness.SelfMetadata = decoderType;
95100 Witness.SelfWitnessTable = decoderWitnessTable;
96101 }
97102
98103 CanSILFunctionType getMethodType () const { return MethodType; }
99104
100- WitnessMetadata * getWitnessMetadata () const {
101- return const_cast <WitnessMetadata *>(&Witness) ;
105+ ArrayRef<ProtocolDecl *> getProtocolRequirements () const {
106+ return ProtocolRequirements ;
102107 }
103108
104109 // / Form a callee to a decode method - `decodeNextArgument`.
105110 Callee getCallee () const ;
111+
112+ private:
113+ static GenericSignature::RequiredProtocols
114+ findProtocolRequirements (CanSILFunctionType decodeMethodTy) {
115+ auto signature = decodeMethodTy->getInvocationGenericSignature ();
116+ auto genericParams = signature.getGenericParams ();
117+
118+ // func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
119+ assert (genericParams.size () == 1 );
120+ return signature->getRequiredProtocols (genericParams.front ());
121+ }
106122};
107123
108124class DistributedAccessor {
@@ -140,6 +156,10 @@ class DistributedAccessor {
140156 llvm::Value *argumentType, const SILParameterInfo ¶m,
141157 Explosion &arguments);
142158
159+ void lookupWitnessTables (llvm::Value *value,
160+ ArrayRef<ProtocolDecl *> protocols,
161+ Explosion &witnessTables);
162+
143163 // / Load witness table addresses (if any) from the given buffer
144164 // / into the given argument explosion.
145165 // /
@@ -385,13 +405,17 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
385405 // substitution Argument -> <argument metadata>
386406 decodeArgs.add (argumentType);
387407
408+ // Lookup witness tables for the requirement on the argument type.
409+ lookupWitnessTables (argumentType, decoder.getProtocolRequirements (),
410+ decodeArgs);
411+
388412 Address calleeErrorSlot;
389413 llvm::Value *decodeError = nullptr ;
390414
391415 emission->begin ();
392416 {
393417 emission->setArgs (decodeArgs, /* isOutlined=*/ false ,
394- /* witnessMetadata=*/ decoder. getWitnessMetadata () );
418+ /* witnessMetadata=*/ nullptr );
395419
396420 Explosion result;
397421 emission->emitToExplosion (result, /* isOutlined=*/ false );
@@ -492,6 +516,37 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
492516 }
493517}
494518
519+ void DistributedAccessor::lookupWitnessTables (
520+ llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
521+ Explosion &witnessTables) {
522+ auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer ();
523+
524+ for (auto *protocol : protocols) {
525+ auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor (protocol);
526+ auto *witnessTable =
527+ IGF.Builder .CreateCall (conformsToProtocol, {value, protocolDescriptor});
528+
529+ auto failBB = IGF.createBasicBlock (" missing-witness" );
530+ auto contBB = IGF.createBasicBlock (" " );
531+
532+ auto isNull = IGF.Builder .CreateICmpEQ (
533+ witnessTable, llvm::ConstantPointerNull::get (IGM.WitnessTablePtrTy ));
534+ IGF.Builder .CreateCondBr (isNull, failBB, contBB);
535+
536+ // This operation shouldn't fail because runtime should have checked that
537+ // a particular argument type conforms to `SerializationRequirement`
538+ // of the distributed actor the decoder is used for. If it does fail
539+ // then accessor should trap.
540+ {
541+ IGF.Builder .emitBlock (failBB);
542+ IGF.emitTrap (" missing witness table" , /* EmitUnreachable=*/ true );
543+ }
544+
545+ IGF.Builder .emitBlock (contBB);
546+ witnessTables.add (witnessTable);
547+ }
548+ }
549+
495550void DistributedAccessor::emitLoadOfWitnessTables (llvm::Value *witnessTables,
496551 llvm::Value *numTables,
497552 unsigned expectedWitnessTables,
@@ -730,22 +785,70 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
730785
731786ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder (
732787 llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
733- auto &C = IGM.Context ;
788+ auto *actor = getDistributedActorOf (Target);
789+ auto expansionContext = IGM.getMaximalTypeExpansionContext ();
790+
791+ auto *decodeFn = IGM.Context .getDistributedActorArgumentDecodingMethod (actor);
792+ assert (decodeFn && " no suitable decoder?" );
793+
794+ auto methodTy = IGM.getSILTypes ().getConstantFunctionType (
795+ expansionContext, SILDeclRef (decodeFn));
796+
797+ auto fpKind = FunctionPointerKind::defaultAsync ();
798+ auto signature = IGM.getSignature (methodTy, fpKind);
799+
800+ // If the decoder class is `final`, let's emit a direct reference.
801+ auto *decoderDecl = decodeFn->getDeclContext ()->getSelfNominalTypeDecl ();
802+
803+ // If decoder is a class, need to load it first because generic parameter
804+ // is passed indirectly. This is good for structs and enums because
805+ // `decodeNextArgument` is a mutating method, but not for classes because
806+ // in that case heap object is mutated directly.
807+ bool usesDispatchThunk = false ;
734808
735- auto decoderProtocol = C. getDistributedTargetInvocationDecoderDecl ();
736- SILDeclRef decodeNextArgumentRef (
737- decoderProtocol-> getSingleRequirement (C. Id_decodeNextArgument ) );
809+ if ( auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
810+ auto selfTy = methodTy-> getSelfParameter (). getSILStorageType (
811+ IGM. getSILModule (), methodTy, expansionContext );
738812
739- llvm::Constant *fnPtr =
740- IGM.getAddrOfDispatchThunk (decodeNextArgumentRef, NotForDefinition);
813+ auto &classTI = IGM.getTypeInfo (selfTy).as <ClassTypeInfo>();
814+ auto &classLayout = classTI.getClassLayout (IGM, selfTy,
815+ /* forBackwardDeployment=*/ false );
741816
742- auto fnType = IGM.getSILTypes ().getConstantFunctionType (
743- IGM.getMaximalTypeExpansionContext (), decodeNextArgumentRef);
817+ llvm::Value *typedDecoderPtr = IGF.Builder .CreateBitCast (
818+ decoder, classLayout.getType ()->getPointerTo ()->getPointerTo ());
819+
820+ Explosion instance;
821+
822+ classTI.loadAsTake (IGF,
823+ {typedDecoderPtr, classTI.getStorageType (),
824+ classTI.getBestKnownAlignment ()},
825+ instance);
826+
827+ decoder = instance.claimNext ();
828+
829+ // / When using library evolution functions have another "dispatch thunk"
830+ // / so we must use this instead of the decodeFn directly.
831+ usesDispatchThunk =
832+ getMethodDispatch (decodeFn) == swift::MethodDispatch::Class &&
833+ classDecl->hasResilientMetadata ();
834+ }
835+
836+ FunctionPointer methodPtr;
837+
838+ if (usesDispatchThunk) {
839+ auto fnPtr = IGM.getAddrOfDispatchThunk (SILDeclRef (decodeFn), NotForDefinition);
840+ methodPtr = FunctionPointer::createUnsigned (
841+ methodTy, fnPtr, signature, /* useSignature=*/ true );
842+ } else {
843+ SILFunction *decodeSILFn = IGM.getSILModule ().lookUpFunction (SILDeclRef (decodeFn));
844+ auto fnPtr = IGM.getAddrOfSILFunction (decodeSILFn, NotForDefinition,
845+ /* isDynamicallyReplaceable=*/ false );
846+ methodPtr = FunctionPointer::forDirect (
847+ classifyFunctionPointerKind (decodeSILFn), fnPtr,
848+ /* secondaryValue=*/ nullptr , signature);
849+ }
744850
745- auto sig = IGM.getSignature (fnType);
746- auto fn = FunctionPointer::forDirect (fnType, fnPtr,
747- /* secondaryValue=*/ nullptr , sig, true );
748- return {decoder, decoderTy, witnessTable, fn, fnType};
851+ return {decoder, decoderTy, witnessTable, methodPtr, methodTy};
749852}
750853
751854SILType DistributedAccessor::getResultType () const {
0 commit comments