Skip to content

Commit 2292ba9

Browse files
committed
feat: API to get device properties for cuda
1 parent 5587a2f commit 2292ba9

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,29 @@ std::vector<int64_t> row_major(int64_t dim) {
730730
}
731731
// Deliberately empty function: takes no arguments, returns nothing and has no
// body. Its users are outside this view.
static void noop() {}
732732

733+
// Plain-data snapshot of selected CUDA device properties.
//
// ReactantCudaDeviceGetProperties fills this struct field-by-field from the
// cudaDeviceProp member of the same name, so the CUDA runtime documentation
// for cudaDeviceProp is the reference for the meaning and units of each value.
//
// The struct is written through a raw pointer supplied by the caller of an
// exported function, so the field order and types below are part of the
// binary interface: do not reorder, resize or insert fields without updating
// the caller's mirror of this layout.
// NOTE(review): the pointer is named `jlprops`, so the mirror is presumably a
// Julia struct -- confirm.
struct DeviceProperties {
  // Memory sizes and per-block resource limits.
  size_t totalGlobalMem;
  size_t sharedMemPerBlock;
  int regsPerBlock;

  // Launch-geometry limits; the two arrays hold one entry per dimension.
  int warpSize;
  int maxThreadsPerBlock;
  int maxThreadsDim[3];
  int maxGridSize[3];

  int clockRate;
  size_t totalConstMem;

  // Compute capability, split into its two components.
  int major;
  int minor;

  // Remaining device-wide characteristics.
  int multiProcessorCount;
  int canMapHostMemory;
  int computeMode;
  int l2CacheSize;
  int maxThreadsPerMultiProcessor;
};
751+
733752
#ifdef REACTANT_CUDA
734753

735754
#include "third_party/gpus/cuda/include/cuda.h"
755+
#include "third_party/gpus/cuda/include/cuda_runtime.h"
736756

737757
REACTANT_ABI int32_t ReactantCudaDriverGetVersion() {
738758
int32_t data;
@@ -769,6 +789,33 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() {
769789
return warpSize;
770790
}
771791

792+
// Queries the CUDA runtime for the properties of device `device_id` and copies
// a fixed subset of them into `*jlprops`.
//
// Parameters:
//   jlprops   - caller-owned output struct; it is dereferenced unconditionally,
//               so it must point to writable storage for one DeviceProperties.
//   device_id - device ordinal handed to cudaGetDeviceProperties.
//
// The status returned by cudaGetDeviceProperties is forwarded to
// ReactantHandleCuResult; what happens on failure is decided there.
REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
                                                  int32_t device_id) {
  // Value-initialize so that every field copied below is well defined (zero)
  // even if the query fails and the error handler returns to us.
  cudaDeviceProp props = {};
  // NOTE(review): ReactantHandleCuResult is defined outside this view and its
  // name suggests it expects a driver-API CUresult, while
  // cudaGetDeviceProperties returns a runtime-API cudaError_t -- confirm that
  // the handler reports runtime error codes correctly.
  ReactantHandleCuResult(cudaGetDeviceProperties(&props, device_id));

  jlprops->totalGlobalMem = props.totalGlobalMem;
  jlprops->sharedMemPerBlock = props.sharedMemPerBlock;
  jlprops->regsPerBlock = props.regsPerBlock;
  jlprops->warpSize = props.warpSize;
  jlprops->maxThreadsPerBlock = props.maxThreadsPerBlock;

  // Both limit arrays have exactly three entries, one per launch dimension.
  for (int axis = 0; axis < 3; ++axis) {
    jlprops->maxThreadsDim[axis] = props.maxThreadsDim[axis];
    jlprops->maxGridSize[axis] = props.maxGridSize[axis];
  }

  jlprops->clockRate = props.clockRate;
  jlprops->totalConstMem = props.totalConstMem;
  jlprops->major = props.major;
  jlprops->minor = props.minor;
  jlprops->multiProcessorCount = props.multiProcessorCount;
  jlprops->canMapHostMemory = props.canMapHostMemory;
  jlprops->computeMode = props.computeMode;
  jlprops->l2CacheSize = props.l2CacheSize;
  jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor;
}
818+
772819
#else
773820

774821
// Fallback compiled when REACTANT_CUDA is not defined: performs no query and
// always returns 0.
REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; }
@@ -781,6 +828,9 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() { return 0; }
781828

782829
// Fallback compiled when REACTANT_CUDA is not defined: performs no query and
// always returns 0.
REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; }
783830

831+
// Fallback compiled when REACTANT_CUDA is not defined.
//
// No device can be queried in this configuration, so the output struct is
// zero-filled. This matches the sibling fallbacks in this branch, which all
// report 0, and means a caller never reads stale or uninitialized fields.
//
// Parameters:
//   jlprops   - caller-owned output struct; ignored when null.
//   device_id - unused in this configuration.
REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
                                                  int32_t device_id) {
  (void)device_id;
  if (jlprops != nullptr) {
    *jlprops = DeviceProperties{};
  }
}
833+
784834
#endif
785835

786836
REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
@@ -1955,6 +2005,15 @@ REACTANT_ABI bool ifrt_DeviceIsAddressable(ifrt::Device *device) {
19552005
return device->IsAddressable();
19562006
}
19572007

2008+
// Returns the local hardware id of the PjRt device that backs `device`.
//
// Only devices that are ifrt::PjRtDevice instances are supported; for any
// other device kind the failure is reported through ReactantThrowError.
//
// Parameters:
//   device - IFRT device to inspect.
// Returns:
//   The value of pjrt_device()->local_hardware_id() for the wrapped device.
REACTANT_ABI int64_t ifrt_DeviceGetLocalHardwareId(ifrt::Device *device) {
  // A single dyn_cast performs both the type test and the conversion.
  auto *pjrt_backed = llvm::dyn_cast<ifrt::PjRtDevice>(device);
  if (pjrt_backed == nullptr) {
    // The message names this function; it previously named
    // ifrt_device_get_allocator_stats, which pointed at the wrong entry point.
    ReactantThrowError(
        "ifrt_DeviceGetLocalHardwareId: only supported for ifrt-pjrt.");
  }
  return pjrt_backed->pjrt_device()->local_hardware_id().value();
}
2016+
19582017
static xla::ifrt::RCReferenceWrapper<ifrt::DeviceList>
19592018
ifrt_CreateDeviceListFromDevices(ifrt::Client *client,
19602019
ifrt::Device **device_list,

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ cc_library(
983983
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor",
984984
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",
985985
"-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads",
986+
"-Wl,-exported_symbol,_ReactantCudaDeviceGetProperties",
986987
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
987988
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
988989
"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",

0 commit comments

Comments
 (0)