@@ -1065,16 +1065,94 @@ visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI) {
10651065
10661066SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst (
10671067 ConvertEscapeToNoEscapeInst *Cvt) {
1068- auto *OrigThinToThick =
1069- dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted ());
1070- if (!OrigThinToThick)
1071- return nullptr ;
1072- auto origFunType = OrigThinToThick->getType ().getAs <SILFunctionType>();
1073- auto NewTy = origFunType->getWithExtInfo (origFunType->getExtInfo ().withNoEscape (true ));
1068+ // Rewrite conversion of `convert_function` of `thin_to_thick_function` as
1069+ // conversion of `thin_to_thick_function` of `convert_function`.
1070+ //
1071+ // (convert_escape_to_noescape (convert_function (thin_to_thick_function x)))
1072+ // =>
1073+ // (convert_escape_to_noescape (thin_to_thick_function (convert_function x)))
1074+ //
1075+ // This unblocks the `thin_to_thick_function` peephole optimization below.
1076+ if (auto *CFI = dyn_cast<ConvertFunctionInst>(Cvt->getConverted ())) {
1077+ if (CFI->getSingleUse ()) {
1078+ if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(CFI->getConverted ())) {
1079+ if (TTTFI->getSingleUse ()) {
1080+ auto convertedThickType = CFI->getType ().castTo <SILFunctionType>();
1081+ auto convertedThinType = convertedThickType->getWithRepresentation (
1082+ SILFunctionTypeRepresentation::Thin);
1083+ auto *newCFI = Builder.createConvertFunction (
1084+ CFI->getLoc (), TTTFI->getConverted (),
1085+ SILType::getPrimitiveObjectType (convertedThinType),
1086+ CFI->withoutActuallyEscaping ());
1087+ auto *newTTTFI = Builder.createThinToThickFunction (
1088+ TTTFI->getLoc (), newCFI, CFI->getType ());
1089+ replaceInstUsesWith (*CFI, newTTTFI);
1090+ }
1091+ }
1092+ }
1093+ }
1094+
1095+ // Rewrite conversion of `thin_to_thick_function` as `thin_to_thick_function`
1096+ // with a noescape function type.
1097+ //
1098+ // (convert_escape_to_noescape (thin_to_thick_function x)) =>
1099+ // (thin_to_thick_function [noescape] x)
1100+ if (auto *OrigThinToThick = dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted ())) {
1101+ auto origFunType = OrigThinToThick->getType ().getAs <SILFunctionType>();
1102+ auto NewTy = origFunType->getWithExtInfo (origFunType->getExtInfo ().withNoEscape (true ));
10741103
1075- return Builder.createThinToThickFunction (
1104+ return Builder.createThinToThickFunction (
10761105 OrigThinToThick->getLoc (), OrigThinToThick->getOperand (),
10771106 SILType::getPrimitiveObjectType (NewTy));
1107+ }
1108+
1109+ // Push conversion instructions inside `differentiable_function`. This
1110+ // unblocks more optimizations.
1111+ //
1112+ // Before:
1113+ // %x = differentiable_function(%orig, %jvp, %vjp)
1114+ // %y = convert_escape_to_noescape %x
1115+ //
1116+ // After:
1117+ // %orig' = convert_escape_to_noescape %orig
1118+ // %jvp' = convert_escape_to_noescape %jvp
1119+ // %vjp' = convert_escape_to_noescape %vjp
1120+ // %y = differentiable_function(%orig', %jvp', %vjp')
1121+ if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted ())) {
1122+ auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1123+ if (!DFI->hasExtractee (extractee))
1124+ return SILValue ();
1125+
1126+ auto operand = DFI->getExtractee (extractee);
1127+ auto fnType = operand->getType ().castTo <SILFunctionType>();
1128+ auto noEscapeFnType =
1129+ fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1130+ auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1131+ return Builder.createConvertEscapeToNoEscape (
1132+ operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1133+ };
1134+
1135+ SILValue originalNoEscape =
1136+ createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1137+ SILValue convertedJVP = createConvertEscapeToNoEscape (
1138+ NormalDifferentiableFunctionTypeComponent::JVP);
1139+ SILValue convertedVJP = createConvertEscapeToNoEscape (
1140+ NormalDifferentiableFunctionTypeComponent::VJP);
1141+
1142+ Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1143+ if (convertedJVP && convertedVJP)
1144+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1145+
1146+ auto *newDFI = Builder.createDifferentiableFunction (
1147+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1148+ originalNoEscape, derivativeFunctions);
1149+ assert (newDFI->getType () == Cvt->getType () &&
1150+ " New `@differentiable` function instruction should have same type "
1151+ " as the old `convert_escape_to_no_escape` instruction" );
1152+ return newDFI;
1153+ }
1154+
1155+ return nullptr ;
10781156}
10791157
10801158SILInstruction *
@@ -1207,6 +1285,54 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) {
12071285 return std::move (folder).optimizeWithSetValue (subCFI->getConverted ());
12081286 }
12091287
1288+ // Push conversion instructions inside `differentiable_function`. This
1289+ // unblocks more optimizations.
1290+ //
1291+ // Before:
1292+ // %x = differentiable_function(%orig, %jvp, %vjp)
1293+ // %y = convert_function %x
1294+ //
1295+ // After:
1296+ // %orig' = convert_function %orig
1297+ // %jvp' = convert_function %jvp
1298+ // %vjp' = convert_function %vjp
1299+ // %y = differentiable_function(%orig', %jvp', %vjp')
1300+ if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(cfi->getConverted ())) {
1301+ auto createConvertFunctionOfComponent =
1302+ [&](NormalDifferentiableFunctionTypeComponent extractee) {
1303+ if (!DFI->hasExtractee (extractee))
1304+ return SILValue ();
1305+
1306+ auto operand = DFI->getExtractee (extractee);
1307+ auto convertInstType =
1308+ cfi->getType ().castTo <SILFunctionType>();
1309+ auto convertedComponentFnType =
1310+ convertInstType->getDifferentiableComponentType (
1311+ extractee, Builder.getModule ());
1312+ auto convertedComponentType =
1313+ SILType::getPrimitiveObjectType (convertedComponentFnType);
1314+ return Builder.createConvertFunction (
1315+ operand.getLoc (), operand, convertedComponentType,
1316+ cfi->withoutActuallyEscaping ())->getResult (0 );
1317+ };
1318+ SILValue convertedOriginal = createConvertFunctionOfComponent (
1319+ NormalDifferentiableFunctionTypeComponent::Original);
1320+ SILValue convertedJVP = createConvertFunctionOfComponent (
1321+ NormalDifferentiableFunctionTypeComponent::JVP);
1322+ SILValue convertedVJP = createConvertFunctionOfComponent (
1323+ NormalDifferentiableFunctionTypeComponent::VJP);
1324+ Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1325+ if (convertedJVP && convertedVJP)
1326+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1327+ auto *newDFI = Builder.createDifferentiableFunction (
1328+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1329+ convertedOriginal, derivativeFunctions);
1330+ assert (newDFI->getType () == cfi->getType () &&
1331+ " New `@differentiable` function instruction should have same type "
1332+ " as the old `convert_function` instruction" );
1333+ return newDFI;
1334+ }
1335+
12101336 // Replace a convert_function that only has refcounting uses with its
12111337 // operand.
12121338 tryEliminateOnlyOwnershipUsedForwardingInst (cfi, getInstModCallbacks ());
0 commit comments