@@ -730,9 +730,29 @@ std::vector<int64_t> row_major(int64_t dim) {
730730}
731731static void noop () {}
732732
733+ struct DeviceProperties {
734+ size_t totalGlobalMem;
735+ size_t sharedMemPerBlock;
736+ int regsPerBlock;
737+ int warpSize;
738+ int maxThreadsPerBlock;
739+ int maxThreadsDim[3 ];
740+ int maxGridSize[3 ];
741+ int clockRate;
742+ size_t totalConstMem;
743+ int major;
744+ int minor;
745+ int multiProcessorCount;
746+ int canMapHostMemory;
747+ int computeMode;
748+ int l2CacheSize;
749+ int maxThreadsPerMultiProcessor;
750+ };
751+
733752#ifdef REACTANT_CUDA
734753
735754#include " third_party/gpus/cuda/include/cuda.h"
755+ #include " third_party/gpus/cuda/include/cuda_runtime.h"
736756
737757REACTANT_ABI int32_t ReactantCudaDriverGetVersion () {
738758 int32_t data;
@@ -769,6 +789,33 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() {
769789 return warpSize;
770790}
771791
792+ REACTANT_ABI void ReactantCudaDeviceGetProperties (DeviceProperties *jlprops,
793+ int32_t device_id) {
794+ cudaDeviceProp props;
795+ ReactantHandleCuResult (cudaGetDeviceProperties (&props, device_id));
796+
797+ jlprops->totalGlobalMem = props.totalGlobalMem ;
798+ jlprops->sharedMemPerBlock = props.sharedMemPerBlock ;
799+ jlprops->regsPerBlock = props.regsPerBlock ;
800+ jlprops->warpSize = props.warpSize ;
801+ jlprops->maxThreadsPerBlock = props.maxThreadsPerBlock ;
802+ jlprops->maxThreadsDim [0 ] = props.maxThreadsDim [0 ];
803+ jlprops->maxThreadsDim [1 ] = props.maxThreadsDim [1 ];
804+ jlprops->maxThreadsDim [2 ] = props.maxThreadsDim [2 ];
805+ jlprops->maxGridSize [0 ] = props.maxGridSize [0 ];
806+ jlprops->maxGridSize [1 ] = props.maxGridSize [1 ];
807+ jlprops->maxGridSize [2 ] = props.maxGridSize [2 ];
808+ jlprops->clockRate = props.clockRate ;
809+ jlprops->totalConstMem = props.totalConstMem ;
810+ jlprops->major = props.major ;
811+ jlprops->minor = props.minor ;
812+ jlprops->multiProcessorCount = props.multiProcessorCount ;
813+ jlprops->canMapHostMemory = props.canMapHostMemory ;
814+ jlprops->computeMode = props.computeMode ;
815+ jlprops->l2CacheSize = props.l2CacheSize ;
816+ jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor ;
817+ }
818+
772819#else
773820
774821REACTANT_ABI int32_t ReactantCudaDriverGetVersion () { return 0 ; }
@@ -781,6 +828,9 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() { return 0; }
781828
782829REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads () { return 0 ; }
783830
831+ REACTANT_ABI void ReactantCudaDeviceGetProperties (DeviceProperties *jlprops,
832+ int32_t device_id) {}
833+
784834#endif
785835
786836REACTANT_ABI void *UnsafeBufferPointer (PjRtBuffer *buffer) {
@@ -1955,6 +2005,15 @@ REACTANT_ABI bool ifrt_DeviceIsAddressable(ifrt::Device *device) {
19552005 return device->IsAddressable ();
19562006}
19572007
2008+ REACTANT_ABI int64_t ifrt_DeviceGetLocalHardwareId (ifrt::Device *device) {
2009+ if (!llvm::isa<ifrt::PjRtDevice>(device)) {
2010+ ReactantThrowError (
2011+ " ifrt_device_get_allocator_stats: only supported for ifrt-pjrt." );
2012+ }
2013+ auto ifrt_pjrt_device = llvm::dyn_cast<ifrt::PjRtDevice>(device);
2014+ return ifrt_pjrt_device->pjrt_device ()->local_hardware_id ().value ();
2015+ }
2016+
19582017static xla::ifrt::RCReferenceWrapper<ifrt::DeviceList>
19592018ifrt_CreateDeviceListFromDevices (ifrt::Client *client,
19602019 ifrt::Device **device_list,
0 commit comments