Skip to content

Commit d28c26a

Browse files
authored
[ROCm] fix: obtain AMD GPU memory info through rocm_smi library (microsoft#21190)
### Description Previously ROCMExecutionProvider uses `hipMemGetInfo` to obtain the sizes of total memory and available memory. However, this API has been broken since ROCm 5.7. In this PR, we use `rocm_smi` library instead of `hipMemGetInfo`. ### Motivation and Context `hipMemGetInfo` API has been broken since ROCm 5.7 and inference with ROCMExecutionProvider will lead to following errors: ``` HIP failure 1: invalid argument ; GPU=0 ; hostname=4cc4900475fe ; file=/onnxruntime/onnxruntime/core/providers/rocm/rocm_execution_provider.cc ; line=229 ; expr=hipMemGetInfo(&free, &total); ``` MIOpen has a brute-force fix for this (https://github.com/ROCm/MIOpen/blob/911e67189592c311374940493f2099f3abced60d/src/hip/handlehip.cpp#L72). Instead of hard-coding available memory to 16GB, I suppose we could obtain memory info through `rocm_smi` library as in this PR.
1 parent fffd430 commit d28c26a

File tree

7 files changed

+25
-5
lines changed

7 files changed

+25
-5
lines changed

cmake/onnxruntime_providers_rocm.cmake

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@
4949

5050
find_library(RCCL_LIB rccl REQUIRED)
5151
find_library(ROCTRACER_LIB roctracer64 REQUIRED)
52-
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB})
52+
find_package(rocm_smi REQUIRED)
53+
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB})
54+
include_directories(${ROCM_SMI_INCLUDE_DIR})
55+
link_directories(${ROCM_SMI_LIB_DIR})
5356

5457
file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS
5558
"${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h"

onnxruntime/core/providers/rocm/nn/conv.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState<mio
4949
}
5050

5151
size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState<miopenConvAlgoPerf_t>& s,
52-
const miopenConvFwdAlgorithm_t* algo, int n_algo) {
52+
const miopenConvFwdAlgorithm_t* algo, int n_algo, int device_id = 0) {
5353
// TODO: get maximum available size from memory arena
5454
size_t free, total;
55-
HIP_CALL_THROW(hipMemGetInfo(&free, &total));
55+
onnxruntime::rocm::hipMemGetInfoAlt(device_id, &free, &total);
5656
// Assuming 10% of fragmentation
5757
free = static_cast<size_t>(static_cast<double>(free) * 0.9);
5858
size_t max_ws_size = 0;
@@ -283,7 +283,7 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
283283
int algo_count = 1;
284284
const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
285285
static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
286-
size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
286+
size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos, rocm_ep->GetDeviceId())
287287
: AlgoSearchWorkspaceSize;
288288
IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
289289
MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(

onnxruntime/core/providers/rocm/rocm_call.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template Status RocmCall<hiprandStatus_t, false>(hiprandStatus_t retCode, const
143143
template void RocmCall<hiprandStatus_t, true>(hiprandStatus_t retCode, const char* exprString, const char* libName, hiprandStatus_t successCode, const char* msg, const char* file, const int line);
144144
template Status RocmCall<hipfftResult, false>(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line);
145145
template void RocmCall<hipfftResult, true>(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line);
146+
template Status RocmCall<rsmi_status_t, false>(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line);
147+
template void RocmCall<rsmi_status_t, true>(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line);
146148

147149
#ifdef ORT_USE_NCCL
148150
template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);

onnxruntime/core/providers/rocm/rocm_common.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,17 @@ inline int warpSizeDynamic() {
6767
return deviceProp.warpSize;
6868
}
6969

70+
inline void hipMemGetInfoAlt(uint32_t deviceId, size_t* pFree, size_t* pTotal) {
71+
const auto status = hipMemGetInfo(pFree, pTotal);
72+
if (status != hipSuccess) {
73+
size_t usedMemory = 0;
74+
ROCMSMI_CALL_THROW(rsmi_init(0));
75+
ROCMSMI_CALL_THROW(rsmi_dev_memory_total_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, pTotal));
76+
ROCMSMI_CALL_THROW(rsmi_dev_memory_usage_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, &usedMemory));
77+
*pFree = *pTotal - usedMemory;
78+
ROCMSMI_CALL_THROW(rsmi_shut_down());
79+
}
80+
}
81+
7082
} // namespace rocm
7183
} // namespace onnxruntime

onnxruntime/core/providers/rocm/rocm_execution_provider.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in
261261

262262
size_t free = 0;
263263
size_t total = 0;
264-
HIP_CALL_THROW(hipMemGetInfo(&free, &total));
264+
onnxruntime::rocm::hipMemGetInfoAlt(info_.device_id, &free, &total);
265265

266266
OverrideTunableOpInfoByEnv(info_);
267267

onnxruntime/core/providers/rocm/rocm_pch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <hipsparse/hipsparse.h>
1515
#include <miopen/miopen.h>
1616
#include <rocblas/rocblas.h>
17+
#include <rocm_smi/rocm_smi.h>
1718

1819
#ifdef ORT_USE_NCCL
1920
#include <rccl/rccl.h>

onnxruntime/core/providers/rocm/shared_inc/rocm_call.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ std::conditional_t<THRW, void, Status> RocmCall(
1717

1818
#define HIP_CALL(expr) (RocmCall<hipError_t, false>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
1919
#define ROCBLAS_CALL(expr) (RocmCall<rocblas_status, false>((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__))
20+
#define ROCMSMI_CALL(expr) (RocmCall<rsmi_status_t, false>((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__))
2021

2122
#define HIPSPARSE_CALL(expr) (RocmCall<hipsparseStatus_t, false>((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))
2223
#define HIPRAND_CALL(expr) (RocmCall<hiprandStatus_t, false>((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__))
@@ -27,6 +28,7 @@ std::conditional_t<THRW, void, Status> RocmCall(
2728

2829
#define HIP_CALL_THROW(expr) (RocmCall<hipError_t, true>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
2930
#define ROCBLAS_CALL_THROW(expr) (RocmCall<rocblas_status, true>((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__))
31+
#define ROCMSMI_CALL_THROW(expr) (RocmCall<rsmi_status_t, true>((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__))
3032

3133
#define HIPSPARSE_CALL_THROW(expr) (RocmCall<hipsparseStatus_t, true>((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))
3234
#define HIPRAND_CALL_THROW(expr) (RocmCall<hiprandStatus_t, true>((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__))

0 commit comments

Comments
 (0)