@@ -1711,6 +1711,8 @@ class LoadableByAddress : public SILModuleTransform {
17111711 bool recreateDifferentiabilityWitnessFunction (
17121712 SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);
17131713
1714+ bool shouldTransformGlobal (SILGlobalVariable *global);
1715+
17141716private:
17151717 llvm::SetVector<SILFunction *> modFuncs;
17161718 llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -2907,6 +2909,24 @@ void LoadableByAddress::updateLoweredTypes(SILFunction *F) {
29072909 F->rewriteLoweredTypeUnsafe (newFuncTy);
29082910}
29092911
2912+ bool LoadableByAddress::shouldTransformGlobal (SILGlobalVariable *global) {
2913+ SILInstruction *init = global->getStaticInitializerValue ();
2914+ if (!init)
2915+ return false ;
2916+ auto silTy = global->getLoweredType ();
2917+ if (!isa<SILFunctionType>(silTy.getASTType ()))
2918+ return false ;
2919+
2920+ auto *decl = global->getDecl ();
2921+ IRGenModule *currIRMod = getIRGenModule ()->IRGen .getGenModule (
2922+ decl ? decl->getDeclContext () : nullptr );
2923+ auto silFnTy = global->getLoweredFunctionType ();
2924+ GenericEnvironment *genEnv = getSubstGenericEnvironment (silFnTy);
2925+ if (MapperCache.shouldTransformFunctionType (genEnv, silFnTy, *currIRMod))
2926+ return true ;
2927+ return false ;
2928+ }
2929+
29102930// / The entry point to this function transformation.
29112931void LoadableByAddress::run () {
29122932 // Set the SIL state before the PassManager has a chance to run
@@ -2922,10 +2942,23 @@ void LoadableByAddress::run() {
29222942
29232943 // Scan the module for all references of the modified functions:
29242944 llvm::SetVector<FunctionRefBaseInst *> funcRefs;
2945+ llvm::SetVector<SILInstruction *> globalRefs;
29252946 for (SILFunction &CurrF : *getModule ()) {
29262947 for (SILBasicBlock &BB : CurrF) {
29272948 for (SILInstruction &I : BB) {
2928- if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
2949+ if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(&I)) {
2950+ auto *global = allocGlobal->getReferencedGlobal ();
2951+ if (shouldTransformGlobal (global))
2952+ globalRefs.insert (allocGlobal);
2953+ } else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(&I)) {
2954+ auto *global = globalAddr->getReferencedGlobal ();
2955+ if (shouldTransformGlobal (global))
2956+ globalRefs.insert (globalAddr);
2957+ } else if (auto *globalVal = dyn_cast<GlobalValueInst>(&I)) {
2958+ auto *global = globalVal->getReferencedGlobal ();
2959+ if (shouldTransformGlobal (global))
2960+ globalRefs.insert (globalVal);
2961+ } else if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
29292962 SILFunction *RefF = FRI->getInitiallyReferencedFunction ();
29302963 if (modFuncs.count (RefF) != 0 ) {
29312964 // Go over the uses and add them to lists to modify
@@ -2954,7 +2987,7 @@ void LoadableByAddress::run() {
29542987 case SILInstructionKind::LinearFunctionExtractInst:
29552988 case SILInstructionKind::DifferentiableFunctionExtractInst: {
29562989 conversionInstrs.insert (
2957- cast<SingleValueInstruction>(currInstr));
2990+ cast<SingleValueInstruction>(currInstr));
29582991 break ;
29592992 }
29602993 case SILInstructionKind::BuiltinInst: {
@@ -3032,6 +3065,99 @@ void LoadableByAddress::run() {
30323065 updateLoweredTypes (F);
30333066 }
30343067
3068+ auto computeNewResultType = [&](SILType ty, IRGenModule *mod) -> SILType {
3069+ auto currSILFunctionType = ty.castTo <SILFunctionType>();
3070+ GenericEnvironment *genEnv =
3071+ getSubstGenericEnvironment (currSILFunctionType);
3072+ return SILType::getPrimitiveObjectType (
3073+ MapperCache.getNewSILFunctionType (genEnv, currSILFunctionType, *mod));
3074+ };
3075+
3076+ // Update globals' initializer.
3077+ SmallVector<SILGlobalVariable *, 16 > deadGlobals;
3078+ for (SILGlobalVariable &global : getModule ()->getSILGlobals ()) {
3079+ SILInstruction *init = global.getStaticInitializerValue ();
3080+ if (!init)
3081+ continue ;
3082+ auto silTy = global.getLoweredType ();
3083+ if (!isa<SILFunctionType>(silTy.getASTType ()))
3084+ continue ;
3085+ auto *decl = global.getDecl ();
3086+ IRGenModule *currIRMod = getIRGenModule ()->IRGen .getGenModule (
3087+ decl ? decl->getDeclContext () : nullptr );
3088+ auto silFnTy = global.getLoweredFunctionType ();
3089+ GenericEnvironment *genEnv = getSubstGenericEnvironment (silFnTy);
3090+
3091+ // Update the global's type.
3092+ if (MapperCache.shouldTransformFunctionType (genEnv, silFnTy, *currIRMod)) {
3093+ auto newSILFnType =
3094+ MapperCache.getNewSILFunctionType (genEnv, silFnTy, *currIRMod);
3095+ global.unsafeSetLoweredType (
3096+ SILType::getPrimitiveObjectType (newSILFnType));
3097+
3098+ // Rewrite the init basic block...
3099+ SmallVector<SILInstruction *, 8 > initBlockInsts;
3100+ for (auto it = global.begin (), end = global.end (); it != end; ++it) {
3101+ initBlockInsts.push_back (const_cast <SILInstruction *>(&*it));
3102+ }
3103+ for (auto *oldInst : initBlockInsts) {
3104+ if (auto *f = dyn_cast<FunctionRefInst>(oldInst)) {
3105+ SILBuilder builder (&global);
3106+ auto *newInst = builder.createFunctionRef (
3107+ f->getLoc (), f->getInitiallyReferencedFunction (), f->getKind ());
3108+ f->replaceAllUsesWith (newInst);
3109+ global.unsafeRemove (f, *getModule ());
3110+ } else if (auto *cvt = dyn_cast<ConvertFunctionInst>(oldInst)) {
3111+ auto newType = computeNewResultType (cvt->getType (), currIRMod);
3112+ SILBuilder builder (&global);
3113+ auto *newInst = builder.createConvertFunction (
3114+ cvt->getLoc (), cvt->getOperand (), newType,
3115+ cvt->withoutActuallyEscaping ());
3116+ cvt->replaceAllUsesWith (newInst);
3117+ global.unsafeRemove (cvt, *getModule ());
3118+ } else if (auto *thinToThick =
3119+ dyn_cast<ThinToThickFunctionInst>(oldInst)) {
3120+ auto newType =
3121+ computeNewResultType (thinToThick->getType (), currIRMod);
3122+ SILBuilder builder (&global);
3123+ auto *newInstr = builder.createThinToThickFunction (
3124+ thinToThick->getLoc (), thinToThick->getOperand (), newType);
3125+ thinToThick->replaceAllUsesWith (newInstr);
3126+ global.unsafeRemove (thinToThick, *getModule ());
3127+ } else {
3128+ auto *sv = cast<SingleValueInstruction>(oldInst);
3129+ auto *newInst = cast<SingleValueInstruction>(oldInst->clone ());
3130+ global.unsafeAppend (newInst);
3131+ sv->replaceAllUsesWith (newInst);
3132+ global.unsafeRemove (oldInst, *getModule ());
3133+ }
3134+ }
3135+ }
3136+ }
3137+
3138+ // Rewrite global variable users.
3139+ for (auto *inst : globalRefs) {
3140+ if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(inst)) {
3141+ // alloc_global produces no results.
3142+ SILBuilderWithScope builder (inst);
3143+ builder.createAllocGlobal (allocGlobal->getLoc (),
3144+ allocGlobal->getReferencedGlobal ());
3145+ allocGlobal->eraseFromParent ();
3146+ } else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(inst)) {
3147+ SILBuilderWithScope builder (inst);
3148+ auto *newInst = builder.createGlobalAddr (
3149+ globalAddr->getLoc (), globalAddr->getReferencedGlobal ());
3150+ globalAddr->replaceAllUsesWith (newInst);
3151+ globalAddr->eraseFromParent ();
3152+ } else if (auto *globalVal = dyn_cast<GlobalValueInst>(inst)) {
3153+ SILBuilderWithScope builder (inst);
3154+ auto *newInst = builder.createGlobalValue (
3155+ globalVal->getLoc (), globalVal->getReferencedGlobal ());
3156+ globalVal->replaceAllUsesWith (newInst);
3157+ globalVal->eraseFromParent ();
3158+ }
3159+ }
3160+
30353161 // Update all references:
30363162 // Note: We don't need to update the witness tables and vtables
30373163 // They just contain a pointer to the function
0 commit comments