@@ -53,48 +53,26 @@ ze_result_t DriverHandleImp::createContext(const ze_context_desc_t *desc,
5353 }
5454 }
5555
56- bool multiOsContextDriver = false ;
5756 for (auto devicePair : context->getDevices ()) {
5857 auto neoDevice = devicePair.second ->getNEODevice ();
59- multiOsContextDriver |= devicePair.second ->isMultiDeviceCapable ();
6058 context->rootDeviceIndices .insert (neoDevice->getRootDeviceIndex ());
6159 context->deviceBitfields .insert ({neoDevice->getRootDeviceIndex (),
6260 neoDevice->getDeviceBitfield ()});
6361 }
6462
65- if (this ->mainContext == nullptr ) {
66- this ->mainContext = context;
67-
68- if (this ->getMemoryManager () == nullptr ) {
69- this ->setMemoryManager (context->getDevices ().begin ()->second ->getNEODevice ()->getMemoryManager ());
70- }
71-
72- this ->setSvmAllocsManager (new NEO::SVMAllocsManager (this ->getMemoryManager (), multiOsContextDriver));
73-
74- this ->getMemoryManager ()->setForceNonSvmForExternalHostPtr (true );
75-
76- if (NEO::DebugManager.flags .EnableHostPointerImport .get () == 1 ) {
77- createHostPointerManager ();
78- }
79- }
80-
8163 return ZE_RESULT_SUCCESS;
8264}
8365
8466NEO::MemoryManager *DriverHandleImp::getMemoryManager () {
85- return this ->mainContext -> getMemoryManager () ;
67+ return this ->memoryManager ;
8668}
8769
8870void DriverHandleImp::setMemoryManager (NEO::MemoryManager *memoryManager) {
89- this ->mainContext -> setMemoryManager ( memoryManager) ;
71+ this ->memoryManager = memoryManager ;
9072}
9173
9274NEO::SVMAllocsManager *DriverHandleImp::getSvmAllocsManager () {
93- return this ->mainContext ->getSvmAllocsManager ();
94- }
95-
96- void DriverHandleImp::setSvmAllocsManager (NEO::SVMAllocsManager *svmManager) {
97- this ->mainContext ->setSvmAllocsManager (svmManager);
75+ return this ->svmAllocsManager ;
9876}
9977
10078ze_result_t DriverHandleImp::getApiVersion (ze_api_version_t *version) {
@@ -155,6 +133,10 @@ DriverHandleImp::~DriverHandleImp() {
155133 for (auto &device : this ->devices ) {
156134 delete device;
157135 }
136+ if (this ->svmAllocsManager ) {
137+ delete this ->svmAllocsManager ;
138+ this ->svmAllocsManager = nullptr ;
139+ }
158140}
159141
160142ze_result_t DriverHandleImp::initialize (std::vector<std::unique_ptr<NEO::Device>> neoDevices) {
@@ -169,6 +151,13 @@ ze_result_t DriverHandleImp::initialize(std::vector<std::unique_ptr<NEO::Device>
169151 continue ;
170152 }
171153
154+ if (this ->memoryManager == nullptr ) {
155+ this ->memoryManager = neoDevice->getMemoryManager ();
156+ if (this ->memoryManager == nullptr ) {
157+ return ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY;
158+ }
159+ }
160+
172161 const auto rootDeviceIndex = neoDevice->getRootDeviceIndex ();
173162 auto rootDeviceEnvironment = neoDevice->getExecutionEnvironment ()->rootDeviceEnvironments [rootDeviceIndex].get ();
174163
@@ -200,12 +189,21 @@ ze_result_t DriverHandleImp::initialize(std::vector<std::unique_ptr<NEO::Device>
200189 return ZE_RESULT_ERROR_UNINITIALIZED;
201190 }
202191
192+ this ->svmAllocsManager = new NEO::SVMAllocsManager (memoryManager, multiOsContextDriver);
193+ if (this ->svmAllocsManager == nullptr ) {
194+ return ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY;
195+ }
196+
203197 this ->numDevices = static_cast <uint32_t >(this ->devices .size ());
204198
205199 extensionFunctionsLookupMap = getExtensionFunctionsLookupMap ();
206200
207201 uuidTimestamp = static_cast <uint64_t >(std::chrono::system_clock::now ().time_since_epoch ().count ());
208202
203+ if (NEO::DebugManager.flags .EnableHostPointerImport .get () == 1 ) {
204+ createHostPointerManager ();
205+ }
206+
209207 return ZE_RESULT_SUCCESS;
210208}
211209
@@ -225,6 +223,8 @@ DriverHandle *DriverHandle::create(std::vector<std::unique_ptr<NEO::Device>> dev
225223
226224 GlobalDriver = driverHandle;
227225
226+ driverHandle->getMemoryManager ()->setForceNonSvmForExternalHostPtr (true );
227+
228228 return driverHandle;
229229}
230230
@@ -250,8 +250,8 @@ bool DriverHandleImp::findAllocationDataForRange(const void *buffer,
250250 NEO::SvmAllocationData **allocData) {
251251 // Make sure the host buffer does not overlap any existing allocation
252252 const char *baseAddress = reinterpret_cast <const char *>(buffer);
253- NEO::SvmAllocationData *beginAllocData = getSvmAllocsManager () ->getSVMAlloc (baseAddress);
254- NEO::SvmAllocationData *endAllocData = getSvmAllocsManager () ->getSVMAlloc (baseAddress + size - 1 );
253+ NEO::SvmAllocationData *beginAllocData = svmAllocsManager ->getSVMAlloc (baseAddress);
254+ NEO::SvmAllocationData *endAllocData = svmAllocsManager ->getSVMAlloc (baseAddress + size - 1 );
255255
256256 if (allocData) {
257257 if (beginAllocData) {
@@ -275,8 +275,8 @@ std::vector<NEO::SvmAllocationData *> DriverHandleImp::findAllocationsWithinRang
275275 std::vector<NEO::SvmAllocationData *> allocDataArray;
276276 const char *baseAddress = reinterpret_cast <const char *>(buffer);
277277 // Check if the host buffer overlaps any existing allocation
278- NEO::SvmAllocationData *beginAllocData = this -> getSvmAllocsManager () ->getSVMAlloc (baseAddress);
279- NEO::SvmAllocationData *endAllocData = this -> getSvmAllocsManager () ->getSVMAlloc (baseAddress + size - 1 );
278+ NEO::SvmAllocationData *beginAllocData = svmAllocsManager ->getSVMAlloc (baseAddress);
279+ NEO::SvmAllocationData *endAllocData = svmAllocsManager ->getSVMAlloc (baseAddress + size - 1 );
280280
281281 // Add the allocation that matches the beginning address
282282 if (beginAllocData) {
0 commit comments