8282#include " xla/pjrt/pjrt_executable.h"
8383#include " xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
8484
85+ #include " xla/hlo/ir/hlo_computation.h"
86+ #include " xla/hlo/ir/hlo_instruction.h"
87+ #include " xla/hlo/ir/hlo_instructions.h"
88+ #include " xla/hlo/ir/hlo_module.h"
89+ #include " xla/hlo/parser/hlo_parser.h"
8590#include " xla/hlo/translate/hlo_to_mhlo/hlo_utils.h"
8691#include " xla/hlo/translate/stablehlo.h"
8792
155160#include " xla/hlo/ir/hlo_module.h"
156161#include " xla/service/hlo_cost_analysis.h"
157162
163+ #if defined(REACTANT_CUDA) || defined(REACTANT_ROCM)
164+ #include " xla/service/gpu/model/gpu_hlo_cost_analysis.h"
165+ #include " xla/service/gpu/model/gpu_performance_model.h"
166+ #include " xla/service/gpu/model/gpu_performance_model_base.h"
167+ #include " xla/stream_executor/device_description.h"
168+ #endif
169+
158170#include " jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
159171
160172#include " llvm/Support/ExtensibleRTTI.h"
@@ -763,6 +775,7 @@ struct DeviceProperties {
763775
764776#include " third_party/gpus/cuda/include/cuda.h"
765777#include " third_party/gpus/cuda/include/cuda_runtime.h"
778+ #include " xla/stream_executor/cuda/cuda_compute_capability.h"
766779
767780REACTANT_ABI int32_t ReactantCudaDriverGetVersion () {
768781 int32_t data;
@@ -844,6 +857,94 @@ REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
844857 return ;
845858}
846859
860+ inline stream_executor::SemanticVersion
861+ GetStreamExecutorVersion (int32_t version) {
862+ return stream_executor::SemanticVersion (version / 1000 , (version % 1000 ) / 10 ,
863+ version % 10 );
864+ }
865+
// Queries a single integer-valued CUDA device attribute for `device_id`.
// Failures are routed through ReactantHandleCuResult (defined elsewhere in
// this file), matching the error-handling convention of the other CUDA calls.
inline int32_t GetCudaIntegerAttribute(cudaDeviceAttr attribute,
                                       int32_t device_id) {
  int32_t value;
  ReactantHandleCuResult(cudaDeviceGetAttribute(&value, attribute, device_id));
  return value;
}
872+
// Approximate number of FP32 CUDA cores per streaming multiprocessor for a
// given compute capability (per the CUDA C Programming Guide's
// compute-capability table). Unknown/future architectures fall back to 128.
//
// Fix: compute capability 6.0 (Pascal GP100) has 64 FP32 cores/SM but
// previously fell through to the 128 default; 6.1/6.2 correctly have 128.
static int32_t CUDACoresPerSM(int32_t major, int32_t minor) {
  switch (major) {
  case 2: // Fermi
    return 32;
  case 3: // Kepler
    return 192;
  case 5: // Maxwell
    return 128;
  case 6: // Pascal: GP100 (6.0) has 64; GP10x (6.1/6.2) have 128
    return minor == 0 ? 64 : 128;
  case 7: // Volta (7.0/7.2) and Turing (7.5)
    return 64;
  case 8: // Ampere: GA100 (8.0) has 64; GA10x (8.6/8.9) have 128
    return minor == 0 ? 64 : 128;
  default: // Hopper (9.0) and newer: 128
    return 128;
  }
}
887+
888+ REACTANT_ABI stream_executor::DeviceDescription *
889+ CudaGetStreamExecutorDeviceDescription (int32_t device_id) {
890+ stream_executor::DeviceDescription *device_description =
891+ new stream_executor::DeviceDescription ();
892+
893+ cudaDeviceProp props;
894+ cudaGetDeviceProperties (&props, device_id);
895+
896+ device_description->set_gpu_compute_capability (
897+ stream_executor::CudaComputeCapability (props.major , props.minor ));
898+
899+ device_description->set_threads_per_block_limit (props.maxThreadsPerBlock );
900+ device_description->set_threads_per_warp (props.warpSize );
901+ device_description->set_shared_memory_per_block (props.sharedMemPerBlock );
902+ device_description->set_shared_memory_per_block_optin (GetCudaIntegerAttribute (
903+ cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
904+ device_description->set_shared_memory_per_core (GetCudaIntegerAttribute (
905+ cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id));
906+ device_description->set_threads_per_core_limit (GetCudaIntegerAttribute (
907+ cudaDevAttrMaxThreadsPerMultiProcessor, device_id));
908+ device_description->set_core_count (props.multiProcessorCount );
909+ device_description->set_fpus_per_core (
910+ CUDACoresPerSM (props.major , props.minor ));
911+ device_description->set_block_dim_limit_x (props.maxGridSize [0 ]);
912+ device_description->set_block_dim_limit_y (props.maxGridSize [1 ]);
913+ device_description->set_block_dim_limit_z (props.maxGridSize [2 ]);
914+
915+ // Memory bandwidth (bytes/sec) ≈ 2 * memClock(Hz) * busWidth(bytes)
916+ // props.memoryClockRate is in kHz; bus width is in bits.
917+ const double mem_clock_hz =
918+ static_cast <double >(props.memoryClockRate ) * 1000.0 ;
919+ const double bus_bytes = static_cast <double >(props.memoryBusWidth ) / 8.0 ;
920+ const double bandwidth_Bps = 2.0 * mem_clock_hz * bus_bytes; // DDR assumption
921+ device_description->set_memory_bandwidth (
922+ static_cast <uint64_t >(bandwidth_Bps));
923+
924+ device_description->set_l2_cache_size (
925+ GetCudaIntegerAttribute (cudaDevAttrL2CacheSize, device_id));
926+
927+ // SM clock (GHz). props.clockRate is kHz.
928+ device_description->set_clock_rate_ghz (static_cast <double >(props.clockRate ) /
929+ 1.0e6 );
930+ device_description->set_device_memory_size (props.totalGlobalMem );
931+
932+ // Registers
933+ device_description->set_registers_per_core_limit (GetCudaIntegerAttribute (
934+ cudaDevAttrMaxRegistersPerMultiprocessor, device_id));
935+ device_description->set_registers_per_block_limit (
936+ GetCudaIntegerAttribute (cudaDevAttrMaxRegistersPerBlock, device_id));
937+
938+ // CUDA versions
939+ int drv = 0 , rtm = 0 ;
940+ cudaRuntimeGetVersion (&rtm);
941+ device_description->set_runtime_version (GetStreamExecutorVersion (rtm));
942+ cudaDriverGetVersion (&drv);
943+ device_description->set_driver_version (GetStreamExecutorVersion (drv));
944+
945+ return device_description;
946+ }
947+
847948#else
848949
849950REACTANT_ABI int32_t ReactantCudaDriverGetVersion () { return 0 ; }
@@ -863,8 +964,18 @@ REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
863964 const char *binary, const char *fnname, int32_t *regs, int32_t *spills,
864965 int32_t *maxThreads) {}
865966
// Stub for builds without CUDA support: no device description is available,
// so callers receive nullptr and must check before use.
REACTANT_ABI stream_executor::DeviceDescription *
CudaGetStreamExecutorDeviceDescription(int32_t device_id) {
  return nullptr;
}
971+
866972#endif
867973
// Renders a DeviceDescription as a C string (allocated via cstr_from_string).
// NOTE(review): `device` is dereferenced unconditionally, but the non-CUDA
// build of CudaGetStreamExecutorDeviceDescription returns nullptr — confirm
// callers never forward that result here.
REACTANT_ABI const char *
deviceDescriptionToString(stream_executor::DeviceDescription *device) {
  return cstr_from_string(device->ToString());
}
978+
868979REACTANT_ABI void *UnsafeBufferPointer (PjRtBuffer *buffer) {
869980 auto unsafe = MyValueOrThrow (buffer->client ()->UnsafeBufferPointer (buffer));
870981 return (void *)unsafe;
@@ -1725,8 +1836,27 @@ PjRtLoadedExecutableGetHloModules(xla::PjRtLoadedExecutable *exec,
17251836 }
17261837}
17271838
1728- REACTANT_ABI const char *HloModuleToString (HeldHloModule *hlo_module) {
1729- return cstr_from_string (hlo_module->obj ()->ToString ());
1839+ HloPrintOptions getHloPrintOptions (int32_t print_options) {
1840+ switch (print_options) {
1841+ case 0 :
1842+ return HloPrintOptions::Default ();
1843+ case 1 :
1844+ return HloPrintOptions::ShortParsable ();
1845+ case 2 :
1846+ return HloPrintOptions::Canonical ();
1847+ case 3 :
1848+ return HloPrintOptions::Fingerprint ();
1849+ case 4 :
1850+ return HloPrintOptions::ModuleFingerprint ();
1851+ default :
1852+ ReactantThrowError (" Invalid print_options" );
1853+ }
1854+ }
1855+
1856+ REACTANT_ABI const char *HloModuleToString (HeldHloModule *hlo_module,
1857+ int32_t print_options) {
1858+ return cstr_from_string (
1859+ hlo_module->obj ()->ToString (getHloPrintOptions (print_options)));
17301860}
17311861
17321862REACTANT_ABI void FreeHloModule (HeldHloModule *hlo_module) {
@@ -3163,3 +3293,210 @@ REACTANT_ABI HeldHloModule *convertMlirModuleToHloModule(MlirModule mod) {
31633293 std::move (MyValueOrThrow (xla::ConvertStablehloToHlo (cmod_op)));
31643294 return reactant::capture (hlo_module);
31653295}
3296+
3297+ REACTANT_ABI HeldHloModule *
3298+ parseAndReturnUnverifiedHloModule (const char *cstr) {
3299+ absl::string_view str (cstr);
3300+ auto hlo_module_status = xla::ParseAndReturnUnverifiedModule (str);
3301+ if (!hlo_module_status.ok ()) {
3302+ ReactantThrowError (hlo_module_status.status ().ToString ().c_str ());
3303+ }
3304+ std::shared_ptr<xla::HloModule> hlo_module =
3305+ std::move (hlo_module_status.value ());
3306+ return reactant::capture (hlo_module);
3307+ }
3308+
// Returns the module's entry computation.
// NOTE(review): entry_computation() appears to return a module-owned pointer;
// confirm callers do not pass it to freeHloComputation.
REACTANT_ABI xla::HloComputation *
hloModuleGetEntryComputation(HeldHloModule *hlo_module) {
  return hlo_module->obj()->entry_computation();
}
3313+
// Deletes an HloComputation from the bindings side.
// NOTE(review): computations obtained from hloModuleGetEntryComputation (or
// from a fusion instruction) look module-owned — deleting one of those would
// double-free. Verify callers only pass computations they actually own.
REACTANT_ABI void freeHloComputation(HloComputation *hlo_computation) {
  delete hlo_computation;
}
3317+
3318+ REACTANT_ABI const char *hloComputationToString (HloComputation *hlo_computation,
3319+ int32_t print_options) {
3320+ return cstr_from_string (
3321+ hlo_computation->ToString (getHloPrintOptions (print_options)));
3322+ }
3323+
// Number of instructions in the computation. Callers can use this to size the
// output buffer for hloComputationGetInstructionsPostOrder.
REACTANT_ABI int64_t
hloComputationInstructionCount(HloComputation *hlo_computation) {
  return hlo_computation->instruction_count();
}
3328+
3329+ REACTANT_ABI void
3330+ hloComputationGetInstructionsPostOrder (HloComputation *hlo_computation,
3331+ int64_t num_instructions,
3332+ HloInstruction **hlo_instructions) {
3333+ std::vector<HloInstruction *> instructions =
3334+ hlo_computation->MakeInstructionPostOrder ();
3335+ assert (instructions.size () == num_instructions);
3336+ for (int i = 0 ; i < num_instructions; i++) {
3337+ hlo_instructions[i] = instructions[i];
3338+ }
3339+ }
3340+
// Deletes an HloInstruction from the bindings side.
// NOTE(review): instructions returned by
// hloComputationGetInstructionsPostOrder look computation-owned — deleting
// one of those would double-free. Verify callers only pass instructions they
// actually own.
REACTANT_ABI void freeHloInstruction(HloInstruction *hlo_instruction) {
  delete hlo_instruction;
}
3344+
3345+ REACTANT_ABI const char *hloInstructionToString (HloInstruction *hlo_instruction,
3346+ int32_t print_options) {
3347+ return cstr_from_string (
3348+ hlo_instruction->ToString (getHloPrintOptions (print_options)));
3349+ }
3350+
// Returns 1 if the instruction carries a `to_apply` computation, 0 otherwise.
// Boolean is widened to uint8_t for the C ABI.
REACTANT_ABI uint8_t hloInstructionHasToApply(HloInstruction *hlo_instruction) {
  return hlo_instruction->has_to_apply();
}
3354+
// Returns the instruction's `to_apply` computation. Callers should check
// hloInstructionHasToApply first; behavior when there is no to_apply is
// whatever to_apply() does upstream (presumably a CHECK failure — confirm).
REACTANT_ABI HloComputation *
hloInstructionGetToApply(HloInstruction *hlo_instruction) {
  return hlo_instruction->to_apply();
}
3359+
// Returns the instruction's opcode as a raw byte (decode with
// hloOpcodeToString). Assumes xla::HloOpcode fits in uint8_t — TODO confirm
// against the pinned XLA version.
REACTANT_ABI uint8_t hloInstructionGetOpcode(HloInstruction *hlo_instruction) {
  return static_cast<uint8_t>(hlo_instruction->opcode());
}
3363+
// Converts a raw opcode byte (as produced by hloInstructionGetOpcode) back to
// its XLA opcode name. Returned C string is allocated by cstr_from_string.
REACTANT_ABI const char *hloOpcodeToString(uint8_t opcode) {
  return cstr_from_string(xla::HloOpcodeString(static_cast<HloOpcode>(opcode)));
}
3367+
3368+ REACTANT_ABI uint8_t hloInstructionIsFusion (HloInstruction *hlo_instruction) {
3369+ if (dynamic_cast <HloFusionInstruction *>(hlo_instruction)) {
3370+ return 1 ;
3371+ }
3372+ return 0 ;
3373+ }
3374+
3375+ REACTANT_ABI uint8_t
3376+ hloInstructionGetFusionKind (HloInstruction *hlo_instruction) {
3377+ if (auto hlo_instruction_fusion =
3378+ dynamic_cast <HloFusionInstruction *>(hlo_instruction)) {
3379+ return static_cast <uint8_t >(hlo_instruction_fusion->fusion_kind ());
3380+ }
3381+ ReactantThrowError (" hloInstructionGetFusionKind: not a fusion instruction" );
3382+ }
3383+
// Converts a raw fusion-kind byte (as produced by hloInstructionGetFusionKind)
// to its XLA string name. Returned C string is allocated by cstr_from_string.
REACTANT_ABI const char *hloFusionKindToString(uint8_t kind) {
  return cstr_from_string(
      xla::ToString(static_cast<HloInstruction::FusionKind>(kind)));
}
3388+
3389+ REACTANT_ABI HloComputation *
3390+ hloInstructionFusedInstructionsComputation (HloInstruction *hlo_instruction) {
3391+ if (auto hlo_instruction_fusion =
3392+ dynamic_cast <HloFusionInstruction *>(hlo_instruction)) {
3393+ return hlo_instruction_fusion->fused_instructions_computation ();
3394+ }
3395+ ReactantThrowError (" hloInstructionFusedInstructionsComputation: not a fusion "
3396+ " instruction" );
3397+ }
3398+
// Plain-old-data mirror of xla::gpu::EstimateRunTimeData with the
// absl::Duration fields flattened to integer nanoseconds, so the struct can
// cross the C ABI (filled in by EstimateRunTimeForInstruction).
struct JLEstimateRunTimeData {
  int64_t flops;             // estimated floating-point operation count
  int64_t bytes_read;        // estimated bytes read
  int64_t bytes_written;     // estimated bytes written
  int64_t read_time_ns;      // estimated read time, nanoseconds
  int64_t write_time_ns;     // estimated write time, nanoseconds
  int64_t compute_time_ns;   // estimated compute time, nanoseconds
  int64_t execution_time_ns; // estimated total execution time, nanoseconds
};
3408+
3409+ #if defined(REACTANT_CUDA) || defined(REACTANT_ROCM)
3410+ namespace details {
3411+
3412+ // Cost analysis for individual instructions.
3413+ class GPUPerformanceModel {
3414+ public:
3415+ GPUPerformanceModel (mlir::MLIRContext *mlir_context,
3416+ stream_executor::DeviceDescription *device_description)
3417+ : mlir_context_(std::move(mlir_context)),
3418+ symbolic_expr_context_ (mlir_context_),
3419+ device_description_(*device_description),
3420+ hlo_cost_analysis_options_{.count_multiple_input_accesses = true },
3421+ fusion_analysis_cache_ (device_description_),
3422+ gpu_hlo_cost_analysis_(hlo_cost_analysis_options_, device_description_),
3423+ gpu_performance_model_(device_description_, fusion_analysis_cache_,
3424+ gpu_performance_model_cache_,
3425+ &symbolic_expr_context_) {}
3426+
3427+ void RunAnalysisOnHloModule (std::shared_ptr<xla::HloModule> hlo_module) {
3428+ hlo_module->entry_computation ()->Accept (&gpu_hlo_cost_analysis_);
3429+ ran_analysis_ = true ;
3430+ }
3431+
3432+ xla::gpu::EstimateRunTimeData
3433+ EstimateRunTimeForInstruction (HloInstruction *hlo_instruction) {
3434+ if (!ran_analysis_) {
3435+ ReactantThrowError (" Must call RunAnalysisOnHloModule before calling "
3436+ " EstimateRunTimeForInstruction" );
3437+ }
3438+ return gpu_performance_model_.EstimateRunTimeForInstruction (
3439+ hlo_instruction, &gpu_hlo_cost_analysis_);
3440+ }
3441+
3442+ private:
3443+ mlir::MLIRContext *mlir_context_;
3444+ xla::SymbolicExprContext symbolic_expr_context_;
3445+ xla::gpu::GpuHloCostAnalysis::Options hlo_cost_analysis_options_;
3446+ stream_executor::DeviceDescription device_description_;
3447+ xla::gpu::HloFusionAnalysisCache fusion_analysis_cache_;
3448+ xla::gpu::GpuHloCostAnalysis gpu_hlo_cost_analysis_;
3449+ xla::gpu::GpuPerformanceModelCache gpu_performance_model_cache_;
3450+ xla::gpu::GpuPerformanceModel gpu_performance_model_;
3451+ bool ran_analysis_ = false ;
3452+ };
3453+
3454+ } // namespace details
3455+
// Allocates a GPU performance model bound to the unwrapped MLIR context and
// the given device description. Ownership of the returned pointer passes to
// the caller; `ctx` must outlive the model.
REACTANT_ABI details::GPUPerformanceModel *CreateGPUPerformanceModel(
    MlirContext ctx, stream_executor::DeviceDescription *device_description) {
  return new details::GPUPerformanceModel(unwrap(ctx), device_description);
}
3460+
// Runs GPU HLO cost analysis over the held module's entry computation.
// Must be called before EstimateRunTimeForInstruction.
REACTANT_ABI void
RunAnalysisOnHloModule(details::GPUPerformanceModel *gpu_performance_model,
                       HeldHloModule *hlo_module) {
  gpu_performance_model->RunAnalysisOnHloModule(hlo_module->obj());
}
3466+
3467+ REACTANT_ABI void EstimateRunTimeForInstruction (
3468+ details::GPUPerformanceModel *gpu_performance_model,
3469+ HloInstruction *hlo_instruction, JLEstimateRunTimeData *jldata) {
3470+ auto data =
3471+ gpu_performance_model->EstimateRunTimeForInstruction (hlo_instruction);
3472+ jldata->flops = data.flops ;
3473+ jldata->bytes_read = data.bytes_read ;
3474+ jldata->bytes_written = data.bytes_written ;
3475+ jldata->read_time_ns = absl::ToInt64Nanoseconds (data.read_time );
3476+ jldata->write_time_ns = absl::ToInt64Nanoseconds (data.write_time );
3477+ jldata->compute_time_ns = absl::ToInt64Nanoseconds (data.compute_time );
3478+ jldata->execution_time_ns = absl::ToInt64Nanoseconds (data.exec_time );
3479+ }
3480+
3481+ #else
3482+
3483+ REACTANT_ABI void *CreateGPUPerformanceModelWrapper (
3484+ MlirContext ctx, stream_executor::DeviceDescription *device_description) {
3485+ return nullptr ;
3486+ }
3487+
3488+ REACTANT_ABI void RunAnalysisOnHloModule (void *gpu_performance_model,
3489+ HloModule *hlo_module) {
3490+ ReactantThrowError (" RunAnalysisOnHloModule is only supported if Reactant "
3491+ " was compiled with CUDA or ROCM support." );
3492+ }
3493+
// Stub for builds without CUDA/ROCm support: always throws; `jldata` is never
// written.
REACTANT_ABI void EstimateRunTimeForInstruction(void *gpu_performance_model,
                                                HloInstruction *hlo_instruction,
                                                JLEstimateRunTimeData *jldata) {
  ReactantThrowError(
      "EstimateRunTimeForInstruction is only supported if Reactant "
      "was compiled with CUDA or ROCM support.");
}
3501+
3502+ #endif
0 commit comments