@@ -21,10 +21,8 @@ using namespace umf_test;
2121
2222class CUDAMemoryAccessor : public MemoryAccessor {
2323 public:
24- void init (CUcontext hContext, CUdevice hDevice) {
25- hDevice_ = hDevice;
26- hContext_ = hContext;
27- }
24+ CUDAMemoryAccessor (CUcontext hContext, CUdevice hDevice)
25+ : hDevice_(hDevice), hContext_(hContext) {}
2826
2927 void fill (void *ptr, size_t size, const void *pattern,
3028 size_t pattern_size) {
@@ -53,7 +51,7 @@ class CUDAMemoryAccessor : public MemoryAccessor {
5351};
5452
5553using CUDAProviderTestParams =
56- std::tuple<umf_usm_memory_type_t , MemoryAccessor *>;
54+ std::tuple<cuda_memory_provider_params_t , MemoryAccessor *>;
5755
5856struct umfCUDAProviderTest
5957 : umf_test::test,
@@ -62,23 +60,12 @@ struct umfCUDAProviderTest
6260 void SetUp () override {
6361 test::SetUp ();
6462
65- auto [memory_type , accessor] = this ->GetParam ();
66- params = create_cuda_prov_params (memory_type) ;
63+ auto [cuda_params , accessor] = this ->GetParam ();
64+ params = cuda_params ;
6765 memAccessor = accessor;
68- if (memory_type == UMF_MEMORY_TYPE_DEVICE) {
69- ((CUDAMemoryAccessor *)memAccessor)
70- ->init ((CUcontext)params.cuda_context_handle ,
71- params.cuda_device_handle );
72- }
7366 }
7467
75- void TearDown () override {
76- if (params.cuda_context_handle ) {
77- int ret = destroy_context ((CUcontext)params.cuda_context_handle );
78- ASSERT_EQ (ret, 0 );
79- }
80- test::TearDown ();
81- }
68+ void TearDown () override { test::TearDown (); }
8269
8370 cuda_memory_provider_params_t params;
8471 MemoryAccessor *memAccessor = nullptr ;
@@ -87,6 +74,7 @@ struct umfCUDAProviderTest
8774TEST_P (umfCUDAProviderTest, basic) {
8875 const size_t size = 1024 * 8 ;
8976 const uint32_t pattern = 0xAB ;
77+ CUcontext expected_current_context = get_current_context ();
9078
9179 // create CUDA provider
9280 umf_memory_provider_handle_t provider = nullptr ;
@@ -113,6 +101,12 @@ TEST_P(umfCUDAProviderTest, basic) {
113101 // use the allocated memory - fill it with a 0xAB pattern
114102 memAccessor->fill (ptr, size, &pattern, sizeof (pattern));
115103
104+ CUcontext actual_mem_context = get_mem_context (ptr);
105+ ASSERT_EQ (actual_mem_context, (CUcontext)params.cuda_context_handle );
106+
107+ CUcontext actual_current_context = get_current_context ();
108+ ASSERT_EQ (actual_current_context, expected_current_context);
109+
116110 umf_usm_memory_type_t memoryTypeActual =
117111 get_mem_type ((CUcontext)params.cuda_context_handle , ptr);
118112 ASSERT_EQ (memoryTypeActual, params.memory_type );
@@ -132,6 +126,7 @@ TEST_P(umfCUDAProviderTest, basic) {
132126}
133127
134128TEST_P (umfCUDAProviderTest, allocInvalidSize) {
129+ CUcontext expected_current_context = get_current_context ();
135130 // create CUDA provider
136131 umf_memory_provider_handle_t provider = nullptr ;
137132 umf_result_t umf_result =
@@ -151,32 +146,32 @@ TEST_P(umfCUDAProviderTest, allocInvalidSize) {
151146 ASSERT_EQ (umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT);
152147 }
153148
154- // destroy context and try to alloc some memory
155- destroy_context ((CUcontext)params.cuda_context_handle );
156- params.cuda_context_handle = 0 ;
157- umf_result = umfMemoryProviderAlloc (provider, 128 , 0 , &ptr);
158- ASSERT_EQ (umf_result, UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC);
159-
160- const char *message;
161- int32_t error;
162- umfMemoryProviderGetLastNativeError (provider, &message, &error);
163- ASSERT_EQ (error, CUDA_ERROR_INVALID_CONTEXT);
164- const char *expected_message =
165- " CUDA_ERROR_INVALID_CONTEXT - invalid device context" ;
166- ASSERT_EQ (strncmp (message, expected_message, strlen (expected_message)), 0 );
149+ CUcontext actual_current_context = get_current_context ();
150+ ASSERT_EQ (actual_current_context, expected_current_context);
151+
152+ umfMemoryProviderDestroy (provider);
167153}
168154
169155// TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
170156
171- CUDAMemoryAccessor cuAccessor;
157+ cuda_memory_provider_params_t cuParams_device_memory =
158+ create_cuda_prov_params (UMF_MEMORY_TYPE_DEVICE);
159+ cuda_memory_provider_params_t cuParams_shared_memory =
160+ create_cuda_prov_params (UMF_MEMORY_TYPE_SHARED);
161+ cuda_memory_provider_params_t cuParams_host_memory =
162+ create_cuda_prov_params (UMF_MEMORY_TYPE_HOST);
163+
164+ CUDAMemoryAccessor
165+ cuAccessor ((CUcontext)cuParams_device_memory.cuda_context_handle,
166+ (CUdevice)cuParams_device_memory.cuda_device_handle);
172167HostMemoryAccessor hostAccessor;
173168
174169INSTANTIATE_TEST_SUITE_P (
175170 umfCUDAProviderTestSuite, umfCUDAProviderTest,
176171 ::testing::Values (
177- CUDAProviderTestParams{UMF_MEMORY_TYPE_DEVICE , &cuAccessor},
178- CUDAProviderTestParams{UMF_MEMORY_TYPE_SHARED , &hostAccessor},
179- CUDAProviderTestParams{UMF_MEMORY_TYPE_HOST , &hostAccessor}));
172+ CUDAProviderTestParams{cuParams_device_memory , &cuAccessor},
173+ CUDAProviderTestParams{cuParams_shared_memory , &hostAccessor},
174+ CUDAProviderTestParams{cuParams_host_memory , &hostAccessor}));
180175
181176// TODO: add IPC API
182177GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST (umfIpcTest);
@@ -185,5 +180,5 @@ INSTANTIATE_TEST_SUITE_P(umfCUDAProviderTestSuite, umfIpcTest,
185180 ::testing::Values(ipcTestParams{
186181 umfProxyPoolOps(), nullptr,
187182 umfCUDAMemoryProviderOps(),
188- &cuParams_device_memory, &l0Accessor }));
183+ &cuParams_device_memory, &cuAccessor }));
189184*/
0 commit comments