@@ -70,6 +70,19 @@ llvm::Value *irgen::emitDistributedActorInitializeRemote(
7070
7171namespace {
7272
73+ using ThunkOrRequirement = llvm::PointerUnion<SILFunction *, AbstractFunctionDecl *>;
74+
75+ static LinkEntity
76+ getAccessorLinking (ThunkOrRequirement accessorFor) {
77+ if (auto *method = accessorFor.dyn_cast <SILFunction *>()) {
78+ assert (method->isDistributed ());
79+ return LinkEntity::forDistributedTargetAccessor (method);
80+ }
81+
82+ auto *requirement = accessorFor.get <AbstractFunctionDecl *>();
83+ return LinkEntity::forDistributedTargetAccessor (requirement);
84+ }
85+
7386struct ArgumentDecoderInfo {
7487 // / The instance of the decoder this information belongs to.
7588 llvm::Value *Decoder;
@@ -128,32 +141,47 @@ struct ArgumentDecoderInfo {
128141struct AccessorTarget {
129142private:
130143 IRGenFunction &IGF;
131- SILFunction * Target;
144+ ThunkOrRequirement Target;
132145
133146 CanSILFunctionType Type;
134147
148+ mutable std::optional<WitnessMetadata> Witness;
149+
135150public:
136- AccessorTarget (IRGenFunction &IGF, SILFunction *target)
137- : IGF(IGF), Target(target), Type(target->getLoweredFunctionType ()) {}
151+ AccessorTarget (IRGenFunction &IGF, ThunkOrRequirement target)
152+ : IGF(IGF), Target(target) {
153+ if (auto *thunk = target.dyn_cast <SILFunction *>()) {
154+ Type = thunk->getLoweredFunctionType ();
155+ } else {
156+ auto *requirement = target.get <AbstractFunctionDecl *>();
157+ Type = IGF.IGM .getSILTypes ().getConstantFunctionType (
158+ IGF.IGM .getMaximalTypeExpansionContext (),
159+ SILDeclRef (requirement).asDistributed ());
160+ }
161+ }
138162
139- DeclContext *getDeclContext () const { return Target->getDeclContext (); }
163+ DeclContext *getDeclContext () const {
164+ if (auto *thunk = Target.dyn_cast <SILFunction *>())
165+ return thunk->getDeclContext ();
166+ return Target.get <AbstractFunctionDecl *>();
167+ }
140168
141169 CanSILFunctionType getType () const { return Type; }
142170
143- bool isGeneric () const { return Target->isGeneric (); }
171+ bool isGeneric () const {
172+ auto sig = Type->getInvocationGenericSignature ();
173+ return sig && !sig->areAllParamsConcrete ();
174+ }
144175
145- Callee getCallee (llvm::Value *actorSelf) const ;
176+ Callee getCallee (llvm::Value *actorSelf);
146177
147- LinkEntity getLinking () const {
148- return LinkEntity::forDistributedTargetAccessor (Target);
149- }
178+ LinkEntity getLinking () const { return getAccessorLinking (Target); }
150179
151- WitnessMetadata *getWitnessMetadata () const {
152- return nullptr ;
153- }
180+ // / Witness metadata is computed lazily upon the first request.
181+ WitnessMetadata *getWitnessMetadata (llvm::Value *actorSelf);
154182
155183public:
156- FunctionPointer getPointerToTarget () const ;
184+ FunctionPointer getPointerToTarget (llvm::Value *actorSelf) ;
157185};
158186
159187class DistributedAccessor {
@@ -175,7 +203,7 @@ class DistributedAccessor {
175203 SmallVector<std::pair<Address, /* type=*/ llvm::Value *>, 4 > LoadedArguments;
176204
177205public:
178- DistributedAccessor (IRGenFunction &IGF, SILFunction * target,
206+ DistributedAccessor (IRGenFunction &IGF, ThunkOrRequirement target,
179207 CanSILFunctionType accessorTy);
180208
181209 void emit ();
@@ -313,27 +341,24 @@ static CanSILFunctionType getAccessorType(IRGenModule &IGM) {
313341}
314342
315343llvm::Function *
316- IRGenModule::getAddrOfDistributedTargetAccessor (SILFunction *F ,
344+ IRGenModule::getAddrOfDistributedTargetAccessor (LinkEntity accessor ,
317345 ForDefinition_t forDefinition) {
318- auto entity = LinkEntity::forDistributedTargetAccessor (F);
319-
320- llvm::Function *&entry = GlobalFuncs[entity];
346+ llvm::Function *&entry = GlobalFuncs[accessor];
321347 if (entry) {
322348 if (forDefinition)
323- updateLinkageForDefinition (*this , entry, entity );
349+ updateLinkageForDefinition (*this , entry, accessor );
324350 return entry;
325351 }
326352
327353 Signature signature = getSignature (getAccessorType (*this ));
328- LinkInfo link = LinkInfo::get (*this , entity , forDefinition);
354+ LinkInfo link = LinkInfo::get (*this , accessor , forDefinition);
329355
330356 return createFunction (*this , link, signature);
331357}
332358
333- void IRGenModule::emitDistributedTargetAccessor (SILFunction *target) {
334- assert (target->isDistributed ());
335-
336- auto *f = getAddrOfDistributedTargetAccessor (target, ForDefinition);
359+ void IRGenModule::emitDistributedTargetAccessor (ThunkOrRequirement target) {
360+ auto *f = getAddrOfDistributedTargetAccessor (getAccessorLinking (target),
361+ ForDefinition);
337362
338363 if (!f->isDeclaration ())
339364 return ;
@@ -343,7 +368,7 @@ void IRGenModule::emitDistributedTargetAccessor(SILFunction *target) {
343368}
344369
345370DistributedAccessor::DistributedAccessor (IRGenFunction &IGF,
346- SILFunction * target,
371+ ThunkOrRequirement target,
347372 CanSILFunctionType accessorTy)
348373 : IGM(IGF.IGM), IGF(IGF), Target(IGF, target), AccessorType(accessorTy),
349374 AsyncLayout(getAsyncContextLayout(IGM, AccessorType, AccessorType,
@@ -540,6 +565,35 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
540565 }
541566}
542567
568+ static llvm::Value *lookupWitnessTable (IRGenFunction &IGF, llvm::Value *witness,
569+ ProtocolDecl *protocol) {
570+ assert (Lowering::TypeConverter::protocolRequiresWitnessTable (protocol));
571+
572+ auto &IGM = IGF.IGM ;
573+ auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor (protocol);
574+ auto *witnessTable = IGF.Builder .CreateCall (
575+ IGM.getConformsToProtocolFunctionPointer (), {witness, protocolDescriptor});
576+
577+ auto failBB = IGF.createBasicBlock (" missing-witness" );
578+ auto contBB = IGF.createBasicBlock (" " );
579+
580+ auto isNull = IGF.Builder .CreateICmpEQ (
581+ witnessTable, llvm::ConstantPointerNull::get (IGM.WitnessTablePtrTy ));
582+ IGF.Builder .CreateCondBr (isNull, failBB, contBB);
583+
584+ // This operation shouldn't fail because the compuler should have
585+ // checked that the given witness conforms to the protocol. If it
586+ // does fail then accessor should trap.
587+ {
588+ IGF.Builder .emitBlock (failBB);
589+ IGF.emitTrap (" missing witness table" , /* EmitUnreachable=*/ true );
590+ }
591+
592+ IGF.Builder .emitBlock (contBB);
593+
594+ return witnessTable;
595+ }
596+
543597void DistributedAccessor::lookupWitnessTables (
544598 llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
545599 Explosion &witnessTables) {
@@ -552,28 +606,7 @@ void DistributedAccessor::lookupWitnessTables(
552606 if (!Lowering::TypeConverter::protocolRequiresWitnessTable (protocol))
553607 continue ;
554608
555- auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor (protocol);
556- auto *witnessTable =
557- IGF.Builder .CreateCall (conformsToProtocol, {value, protocolDescriptor});
558-
559- auto failBB = IGF.createBasicBlock (" missing-witness" );
560- auto contBB = IGF.createBasicBlock (" " );
561-
562- auto isNull = IGF.Builder .CreateICmpEQ (
563- witnessTable, llvm::ConstantPointerNull::get (IGM.WitnessTablePtrTy ));
564- IGF.Builder .CreateCondBr (isNull, failBB, contBB);
565-
566- // This operation shouldn't fail because runtime should have checked that
567- // a particular argument type conforms to `SerializationRequirement`
568- // of the distributed actor the decoder is used for. If it does fail
569- // then accessor should trap.
570- {
571- IGF.Builder .emitBlock (failBB);
572- IGF.emitTrap (" missing witness table" , /* EmitUnreachable=*/ true );
573- }
574-
575- IGF.Builder .emitBlock (contBB);
576- witnessTables.add (witnessTable);
609+ witnessTables.add (lookupWitnessTable (IGF, value, protocol));
577610 }
578611}
579612
@@ -759,7 +792,7 @@ void DistributedAccessor::emit() {
759792
760793 emission->begin ();
761794 emission->setArgs (arguments, /* isOutlined=*/ false ,
762- Target.getWitnessMetadata ());
795+ Target.getWitnessMetadata (actorSelf ));
763796
764797 // Load result of the thunk into the location provided by the caller.
765798 // This would only generate code for direct results, if thunk has an
@@ -790,39 +823,75 @@ void DistributedAccessor::emit() {
790823 }
791824}
792825
793- FunctionPointer AccessorTarget::getPointerToTarget () const {
826+ FunctionPointer AccessorTarget::getPointerToTarget (llvm::Value *actorSelf) {
794827 auto &IGM = IGF.IGM ;
795- auto fpKind = classifyFunctionPointerKind (Target);
796- auto signature = IGM.getSignature (Type, fpKind);
797828
798- auto *fnPtr =
799- llvm::ConstantExpr::getBitCast (IGM.getAddrOfAsyncFunctionPointer (Target),
800- signature.getType ()->getPointerTo ());
829+ if (auto *thunk = Target.dyn_cast <SILFunction *>()) {
830+ auto fpKind = classifyFunctionPointerKind (thunk);
831+ auto signature = IGM.getSignature (Type, fpKind);
832+
833+ auto *fnPtr =
834+ llvm::ConstantExpr::getBitCast (IGM.getAddrOfAsyncFunctionPointer (thunk),
835+ signature.getType ()->getPointerTo ());
836+
837+ return FunctionPointer::forDirect (
838+ FunctionPointer::Kind (Type), fnPtr,
839+ IGM.getAddrOfSILFunction (thunk, NotForDefinition), signature);
840+ }
841+
842+ auto *requirementDecl = Target.get <AbstractFunctionDecl *>();
843+ auto *protocol = requirementDecl->getDeclContext ()->getSelfProtocolDecl ();
844+ SILDeclRef requirementRef = SILDeclRef (requirementDecl).asDistributed ();
845+
846+ if (!IGM.isResilient (protocol, ResilienceExpansion::Maximal)) {
847+ auto *witness = getWitnessMetadata (actorSelf);
848+ return emitWitnessMethodValue (IGF, witness->SelfWitnessTable ,
849+ requirementRef);
850+ }
801851
802- return FunctionPointer::forDirect (
803- FunctionPointer::Kind (Type), fnPtr,
804- IGM.getAddrOfSILFunction (Target, NotForDefinition), signature);
852+ auto fnPtr = IGM.getAddrOfDispatchThunk (requirementRef, NotForDefinition);
853+ auto sig = IGM.getSignature (Type);
854+ return FunctionPointer::forDirect (Type, fnPtr,
855+ /* secondaryValue=*/ nullptr , sig, true );
805856}
806857
807- Callee AccessorTarget::getCallee (llvm::Value *actorSelf) const {
858+ Callee AccessorTarget::getCallee (llvm::Value *actorSelf) {
808859 CalleeInfo info{Type, Type, SubstitutionMap ()};
809- return {std::move (info), getPointerToTarget (), actorSelf};
860+ return {std::move (info), getPointerToTarget (actorSelf), actorSelf};
861+ }
862+
863+ WitnessMetadata *AccessorTarget::getWitnessMetadata (llvm::Value *actorSelf) {
864+ if (Target.is <SILFunction *>())
865+ return nullptr ;
866+
867+ if (!Witness) {
868+ WitnessMetadata witness;
869+
870+ auto *requirement = Target.get <AbstractFunctionDecl *>();
871+ auto *protocol = requirement->getDeclContext ()->getSelfProtocolDecl ();
872+ assert (protocol);
873+
874+ witness.SelfMetadata = actorSelf;
875+ witness.SelfWitnessTable = lookupWitnessTable (
876+ IGF, emitHeapMetadataRefForUnknownHeapObject (IGF, actorSelf), protocol);
877+
878+ Witness = witness;
879+ }
880+
881+ return &(*Witness);
810882}
811883
812884ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder (
813885 llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
814886 auto &C = IGM.Context ;
815- DeclContext *targetContext = Target.getDeclContext ();
887+ auto *thunk = cast<AbstractFunctionDecl>( Target.getDeclContext () );
816888 auto expansionContext = IGM.getMaximalTypeExpansionContext ();
817889
818890 // / If the context was a function, unwrap it and look for the decode method
819891 // / based off a concrete class; If we're not in a concrete class, we'll be
820892 // / using a witness for the decoder so returning null is okey.
821- FuncDecl *decodeFn = nullptr ;
822- if (auto func = dyn_cast<AbstractFunctionDecl>(targetContext)) {
823- decodeFn = C.getDistributedActorArgumentDecodingMethod (
824- func->getDeclContext ()->getSelfNominalTypeDecl ());
825- }
893+ FuncDecl *decodeFn = C.getDistributedActorArgumentDecodingMethod (
894+ thunk->getDeclContext ()->getSelfNominalTypeDecl ());
826895
827896 // If distributed actor is generic over actor system, we have to
828897 // use witness to reference `decodeNextArgument`.
0 commit comments