Skip to content

Commit 1906a4b

Browse files
committed
Add JLJITLinkMemoryManager (ports memory manager to JITLink)
Ports our RTDyLD memory manager to JITLink in order to avoid memory use regressions after switching to JITLink everywhere (JuliaLang#60031). This is essentially a direct port: finalization must happen all at once, because it invalidates all allocation `wr_ptr`s. I decided it wasn't worth it to associate `OnFinalizedFunction` callbacks with each block, since they are large enough to make it extremely likely that all in-flight allocations land in the same block; everything must be relocated before finalization can happen. I plan to add support for DualMapAllocator on ARM64 macOS, as well as an alternative for executable memory to come later. For now, we fall back to the old MapperJITLinkMemoryManager.
1 parent 0c34bde commit 1906a4b

File tree

2 files changed

+180
-43
lines changed

2 files changed

+180
-43
lines changed

src/cgmemmgr.cpp

Lines changed: 179 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
#include "llvm-version.h"
44
#include "platform.h"
55

6+
#include <llvm/ExecutionEngine/JITLink/JITLink.h>
7+
#include <llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h>
8+
#include <llvm/ExecutionEngine/Orc/MapperJITLinkMemoryManager.h>
69
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
10+
711
#include "julia.h"
812
#include "julia_internal.h"
913

@@ -460,26 +464,36 @@ struct Block {
460464
}
461465
};
462466

467+
struct Allocation {
468+
// Address to write to (the one returned by the allocation function)
469+
void *wr_addr;
470+
// Runtime address
471+
void *rt_addr;
472+
size_t sz;
473+
bool relocated;
474+
};
475+
463476
class RWAllocator {
464477
static constexpr int nblocks = 8;
465478
Block blocks[nblocks]{};
466479
public:
467480
RWAllocator() JL_NOTSAFEPOINT = default;
468-
void *alloc(size_t size, size_t align) JL_NOTSAFEPOINT
481+
Allocation alloc(size_t size, size_t align) JL_NOTSAFEPOINT
469482
{
470483
size_t min_size = (size_t)-1;
471484
int min_id = 0;
472485
for (int i = 0;i < nblocks && blocks[i].ptr;i++) {
473486
if (void *ptr = blocks[i].alloc(size, align))
474-
return ptr;
487+
return {ptr, ptr, size, false};
475488
if (blocks[i].avail < min_size) {
476489
min_size = blocks[i].avail;
477490
min_id = i;
478491
}
479492
}
480493
size_t block_size = get_block_size(size);
481494
blocks[min_id].reset(map_anon_page(block_size), block_size);
482-
return blocks[min_id].alloc(size, align);
495+
void *ptr = blocks[min_id].alloc(size, align);
496+
return {ptr, ptr, size, false};
483497
}
484498
};
485499

@@ -519,16 +533,6 @@ struct SplitPtrBlock : public Block {
519533
}
520534
};
521535

522-
struct Allocation {
523-
// Address to write to (the one returned by the allocation function)
524-
void *wr_addr;
525-
// Runtime address
526-
void *rt_addr;
527-
size_t sz;
528-
bool relocated;
529-
};
530-
531-
template<bool exec>
532536
class ROAllocator {
533537
protected:
534538
static constexpr int nblocks = 8;
@@ -556,7 +560,7 @@ class ROAllocator {
556560
}
557561
// Allocations that have not been finalized yet.
558562
SmallVector<Allocation, 16> allocations;
559-
void *alloc(size_t size, size_t align) JL_NOTSAFEPOINT
563+
Allocation alloc(size_t size, size_t align) JL_NOTSAFEPOINT
560564
{
561565
size_t min_size = (size_t)-1;
562566
int min_id = 0;
@@ -572,8 +576,9 @@ class ROAllocator {
572576
wr_ptr = get_wr_ptr(block, ptr, size, align);
573577
}
574578
block.state |= SplitPtrBlock::Alloc;
575-
allocations.push_back(Allocation{wr_ptr, ptr, size, false});
576-
return wr_ptr;
579+
Allocation a{wr_ptr, ptr, size, false};
580+
allocations.push_back(a);
581+
return a;
577582
}
578583
if (block.avail < min_size) {
579584
min_size = block.avail;
@@ -598,14 +603,16 @@ class ROAllocator {
598603
ptr = wr_ptr;
599604
#else
600605
block.state = SplitPtrBlock::Alloc | SplitPtrBlock::InitAlloc;
601-
allocations.push_back(Allocation{ptr, ptr, size, false});
606+
Allocation a{ptr, ptr, size, false};
607+
allocations.push_back(a);
602608
#endif
603-
return ptr;
609+
return a;
604610
}
605611
};
606612

607-
template<bool exec>
608-
class DualMapAllocator : public ROAllocator<exec> {
613+
class DualMapAllocator : public ROAllocator {
614+
bool exec;
615+
609616
protected:
610617
void *get_wr_ptr(SplitPtrBlock &block, void *rt_ptr, size_t, size_t) override JL_NOTSAFEPOINT
611618
{
@@ -666,7 +673,7 @@ class DualMapAllocator : public ROAllocator<exec> {
666673
}
667674
}
668675
public:
669-
DualMapAllocator() JL_NOTSAFEPOINT
676+
DualMapAllocator(bool exec) JL_NOTSAFEPOINT : exec(exec)
670677
{
671678
assert(anon_hdl != -1);
672679
}
@@ -679,13 +686,13 @@ class DualMapAllocator : public ROAllocator<exec> {
679686
finalize_block(block, true);
680687
block.reset(nullptr, 0);
681688
}
682-
ROAllocator<exec>::finalize();
689+
ROAllocator::finalize();
683690
}
684691
};
685692

686693
#ifdef _OS_LINUX_
687-
template<bool exec>
688-
class SelfMemAllocator : public ROAllocator<exec> {
694+
class SelfMemAllocator : public ROAllocator {
695+
bool exec;
689696
SmallVector<Block, 16> temp_buff;
690697
protected:
691698
void *get_wr_ptr(SplitPtrBlock &block, void *rt_ptr,
@@ -722,9 +729,7 @@ class SelfMemAllocator : public ROAllocator<exec> {
722729
}
723730
}
724731
public:
725-
SelfMemAllocator() JL_NOTSAFEPOINT
726-
: ROAllocator<exec>(),
727-
temp_buff()
732+
SelfMemAllocator(bool exec) JL_NOTSAFEPOINT : exec(exec), temp_buff()
728733
{
729734
assert(get_self_mem_fd() != -1);
730735
}
@@ -758,7 +763,7 @@ class SelfMemAllocator : public ROAllocator<exec> {
758763
}
759764
if (cached)
760765
temp_buff.resize(1);
761-
ROAllocator<exec>::finalize();
766+
ROAllocator::finalize();
762767
}
763768
};
764769
#endif // _OS_LINUX_
@@ -772,8 +777,8 @@ class RTDyldMemoryManagerJL : public SectionMemoryManager {
772777
void operator=(const RTDyldMemoryManagerJL&) = delete;
773778
SmallVector<EHFrame, 16> pending_eh;
774779
RWAllocator rw_alloc;
775-
std::unique_ptr<ROAllocator<false>> ro_alloc;
776-
std::unique_ptr<ROAllocator<true>> exe_alloc;
780+
std::unique_ptr<ROAllocator> ro_alloc;
781+
std::unique_ptr<ROAllocator> exe_alloc;
777782
size_t total_allocated;
778783

779784
public:
@@ -787,13 +792,13 @@ class RTDyldMemoryManagerJL : public SectionMemoryManager {
787792
{
788793
#ifdef _OS_LINUX_
789794
if (!ro_alloc && get_self_mem_fd() != -1) {
790-
ro_alloc.reset(new SelfMemAllocator<false>());
791-
exe_alloc.reset(new SelfMemAllocator<true>());
795+
ro_alloc.reset(new SelfMemAllocator(false));
796+
exe_alloc.reset(new SelfMemAllocator(true));
792797
}
793798
#endif
794799
if (!ro_alloc && init_shared_map() != -1) {
795-
ro_alloc.reset(new DualMapAllocator<false>());
796-
exe_alloc.reset(new DualMapAllocator<true>());
800+
ro_alloc.reset(new DualMapAllocator(false));
801+
exe_alloc.reset(new DualMapAllocator(true));
797802
}
798803
}
799804
~RTDyldMemoryManagerJL() override JL_NOTSAFEPOINT
@@ -847,7 +852,7 @@ uint8_t *RTDyldMemoryManagerJL::allocateCodeSection(uintptr_t Size,
847852
jl_timing_counter_inc(JL_TIMING_COUNTER_JITSize, Size);
848853
jl_timing_counter_inc(JL_TIMING_COUNTER_JITCodeSize, Size);
849854
if (exe_alloc)
850-
return (uint8_t*)exe_alloc->alloc(Size, Alignment);
855+
return (uint8_t*)exe_alloc->alloc(Size, Alignment).wr_addr;
851856
return SectionMemoryManager::allocateCodeSection(Size, Alignment, SectionID,
852857
SectionName);
853858
}
@@ -862,9 +867,9 @@ uint8_t *RTDyldMemoryManagerJL::allocateDataSection(uintptr_t Size,
862867
jl_timing_counter_inc(JL_TIMING_COUNTER_JITSize, Size);
863868
jl_timing_counter_inc(JL_TIMING_COUNTER_JITDataSize, Size);
864869
if (!isReadOnly)
865-
return (uint8_t*)rw_alloc.alloc(Size, Alignment);
870+
return (uint8_t*)rw_alloc.alloc(Size, Alignment).wr_addr;
866871
if (ro_alloc)
867-
return (uint8_t*)ro_alloc->alloc(Size, Alignment);
872+
return (uint8_t*)ro_alloc->alloc(Size, Alignment).wr_addr;
868873
return SectionMemoryManager::allocateDataSection(Size, Alignment, SectionID,
869874
SectionName, isReadOnly);
870875
}
@@ -919,6 +924,138 @@ void RTDyldMemoryManagerJL::deregisterEHFrames(uint8_t *Addr,
919924
}
920925
#endif
921926

927+
class JLJITLinkMemoryManager : public jitlink::JITLinkMemoryManager {
928+
using OnFinalizedFunction =
929+
jitlink::JITLinkMemoryManager::InFlightAlloc::OnFinalizedFunction;
930+
931+
std::mutex Mutex;
932+
RWAllocator RWAlloc;
933+
std::unique_ptr<ROAllocator> ROAlloc;
934+
std::unique_ptr<ROAllocator> ExeAlloc;
935+
SmallVector<OnFinalizedFunction> FinalizedCallbacks;
936+
uint32_t InFlight{0};
937+
938+
public:
939+
class InFlightAlloc;
940+
941+
static std::unique_ptr<JITLinkMemoryManager> Create()
942+
{
943+
#ifdef _OS_LINUX_
944+
bool ok = get_self_mem_fd() != -1;
945+
#else
946+
bool ok = init_shared_map() != -1;
947+
#endif
948+
if (ok)
949+
return std::unique_ptr<JLJITLinkMemoryManager>(new JLJITLinkMemoryManager());
950+
951+
return cantFail(
952+
orc::MapperJITLinkMemoryManager::CreateWithMapper<orc::InProcessMemoryMapper>(
953+
/*Reservation Granularity*/ 16 * 1024 * 1024));
954+
}
955+
956+
void allocate(const jitlink::JITLinkDylib *JD, jitlink::LinkGraph &G,
957+
OnAllocatedFunction OnAllocated) override;
958+
959+
void deallocate(std::vector<FinalizedAlloc> Allocs,
960+
OnDeallocatedFunction OnDeallocated) override
961+
{
962+
jl_unreachable();
963+
}
964+
965+
protected:
966+
JLJITLinkMemoryManager()
967+
#ifdef _OS_LINUX_
968+
: ROAlloc(std::make_unique<SelfMemAllocator>(false)),
969+
ExeAlloc(std::make_unique<SelfMemAllocator>(true))
970+
#else
971+
: ROAlloc(std::make_unique<DualMapAllocator>(false)),
972+
ExeAlloc(std::make_unique<DualMapAllocator>(true))
973+
#endif
974+
{
975+
}
976+
977+
void finalize(OnFinalizedFunction OnFinalized)
978+
{
979+
std::unique_lock Lock{Mutex};
980+
FinalizedCallbacks.push_back(std::move(OnFinalized));
981+
982+
if (--InFlight > 0)
983+
return;
984+
985+
ROAlloc->finalize();
986+
ExeAlloc->finalize();
987+
988+
for (auto &CB : FinalizedCallbacks)
989+
std::move(CB)(FinalizedAlloc{});
990+
FinalizedCallbacks.clear();
991+
}
992+
};
993+
994+
class JLJITLinkMemoryManager::InFlightAlloc
995+
: public jitlink::JITLinkMemoryManager::InFlightAlloc {
996+
JLJITLinkMemoryManager &MM;
997+
jitlink::LinkGraph &G;
998+
999+
public:
1000+
InFlightAlloc(JLJITLinkMemoryManager &MM, jitlink::LinkGraph &G) : MM(MM), G(G) {}
1001+
1002+
void abandon(OnAbandonedFunction OnAbandoned) override { jl_unreachable(); }
1003+
1004+
void finalize(OnFinalizedFunction OnFinalized) override
1005+
{
1006+
auto *GP = &G;
1007+
MM.finalize([GP, OnFinalized =
1008+
std::move(OnFinalized)](Expected<FinalizedAlloc> FA) mutable {
1009+
if (!FA)
1010+
return OnFinalized(FA.takeError());
1011+
// Need to handle dealloc actions when we GC code
1012+
auto E = orc::shared::runFinalizeActions(GP->allocActions());
1013+
if (!E)
1014+
return OnFinalized(E.takeError());
1015+
OnFinalized(std::move(FA));
1016+
});
1017+
}
1018+
};
1019+
1020+
using orc::MemProt;
1021+
1022+
void JLJITLinkMemoryManager::allocate(const jitlink::JITLinkDylib *JD,
1023+
jitlink::LinkGraph &G,
1024+
OnAllocatedFunction OnAllocated)
1025+
{
1026+
jitlink::BasicLayout BL{G};
1027+
1028+
{
1029+
std::unique_lock Lock{Mutex};
1030+
for (auto &[AG, Seg] : BL.segments()) {
1031+
if (AG.getMemLifetime() == orc::MemLifetime::NoAlloc)
1032+
continue;
1033+
assert(AG.getMemLifetime() == orc::MemLifetime::Standard);
1034+
1035+
auto Prot = AG.getMemProt();
1036+
uint64_t Alignment = Seg.Alignment.value();
1037+
uint64_t Size = Seg.ContentSize + Seg.ZeroFillSize;
1038+
Allocation Alloc;
1039+
if (Prot == (MemProt::Read | MemProt::Write))
1040+
Alloc = RWAlloc.alloc(Size, Alignment);
1041+
else if (Prot == MemProt::Read)
1042+
Alloc = ROAlloc->alloc(Size, Alignment);
1043+
else if (Prot == (MemProt::Read | MemProt::Exec))
1044+
Alloc = ExeAlloc->alloc(Size, Alignment);
1045+
else
1046+
abort();
1047+
1048+
Seg.Addr = orc::ExecutorAddr::fromPtr(Alloc.rt_addr);
1049+
Seg.WorkingMem = (char *)Alloc.wr_addr;
1050+
}
1051+
}
1052+
1053+
if (auto Err = BL.apply())
1054+
return OnAllocated(std::move(Err));
1055+
1056+
++InFlight;
1057+
OnAllocated(std::make_unique<InFlightAlloc>(*this, G));
1058+
}
9221059
}
9231060

9241061
RTDyldMemoryManager* createRTDyldMemoryManager() JL_NOTSAFEPOINT
@@ -930,3 +1067,8 @@ size_t getRTDyldMemoryManagerTotalBytes(RTDyldMemoryManager *mm) JL_NOTSAFEPOINT
9301067
{
9311068
return ((RTDyldMemoryManagerJL*)mm)->getTotalBytes();
9321069
}
1070+
1071+
std::unique_ptr<jitlink::JITLinkMemoryManager> createJITLinkMemoryManager()
1072+
{
1073+
return JLJITLinkMemoryManager::Create();
1074+
}

src/jitlayers.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,12 +1208,6 @@ class JLMemoryUsagePlugin : public ObjectLinkingLayer::Plugin {
12081208
#pragma clang diagnostic ignored "-Wunused-function"
12091209
#endif
12101210

1211-
// TODO: Port our memory management optimisations to JITLink instead of using the
1212-
// default InProcessMemoryManager.
1213-
std::unique_ptr<jitlink::JITLinkMemoryManager> createJITLinkMemoryManager() JL_NOTSAFEPOINT {
1214-
return cantFail(orc::MapperJITLinkMemoryManager::CreateWithMapper<orc::InProcessMemoryMapper>(/*Reservation Granularity*/ 16 * 1024 * 1024));
1215-
}
1216-
12171211
#ifdef _COMPILER_CLANG_
12181212
#pragma clang diagnostic pop
12191213
#endif
@@ -1237,6 +1231,7 @@ class JLEHFrameRegistrar final : public jitlink::EHFrameRegistrar {
12371231
};
12381232

12391233
RTDyldMemoryManager *createRTDyldMemoryManager(void) JL_NOTSAFEPOINT;
1234+
std::unique_ptr<jitlink::JITLinkMemoryManager> createJITLinkMemoryManager() JL_NOTSAFEPOINT;
12401235

12411236
// A simple forwarding class, since OrcJIT v2 needs a unique_ptr, while we have a shared_ptr
12421237
class ForwardingMemoryManager : public RuntimeDyld::MemoryManager {

0 commit comments

Comments
 (0)