
Commit 182c9fd

avik-pal authored and wsmoses committed
feat(JLL): gpu performance model + HLO IR utilities
1 parent 86bea8f commit 182c9fd

File tree

2 files changed: +358 −2 lines changed

deps/ReactantExtra/API.cpp

Lines changed: 339 additions & 2 deletions
@@ -82,6 +82,11 @@
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
 
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h"
 #include "xla/hlo/translate/stablehlo.h"
 
@@ -155,6 +160,13 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/hlo_cost_analysis.h"
 
+#if defined(REACTANT_CUDA) || defined(REACTANT_ROCM)
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/model/gpu_performance_model_base.h"
+#include "xla/stream_executor/device_description.h"
+#endif
+
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
 
 #include "llvm/Support/ExtensibleRTTI.h"
@@ -763,6 +775,7 @@ struct DeviceProperties {
 
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime.h"
+#include "xla/stream_executor/cuda/cuda_compute_capability.h"
 
 REACTANT_ABI int32_t ReactantCudaDriverGetVersion() {
   int32_t data;
@@ -844,6 +857,94 @@ REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
   return;
 }
 
+inline stream_executor::SemanticVersion
+GetStreamExecutorVersion(int32_t version) {
+  return stream_executor::SemanticVersion(version / 1000, (version % 1000) / 10,
+                                          version % 10);
+}
+
+inline int32_t GetCudaIntegerAttribute(cudaDeviceAttr attribute,
+                                       int32_t device_id) {
+  int32_t value;
+  ReactantHandleCuResult(cudaDeviceGetAttribute(&value, attribute, device_id));
+  return value;
+}
+
+static int32_t CUDACoresPerSM(int32_t major, int32_t minor) {
+  switch (major) {
+  case 2:
+    return 32;
+  case 3:
+    return 192;
+  case 7:
+    return 64;
+  case 8:
+    return minor == 0 ? 64 : 128;
+  default:
+    return 128;
+  }
+}
+
+REACTANT_ABI stream_executor::DeviceDescription *
+CudaGetStreamExecutorDeviceDescription(int32_t device_id) {
+  stream_executor::DeviceDescription *device_description =
+      new stream_executor::DeviceDescription();
+
+  cudaDeviceProp props;
+  cudaGetDeviceProperties(&props, device_id);
+
+  device_description->set_gpu_compute_capability(
+      stream_executor::CudaComputeCapability(props.major, props.minor));
+
+  device_description->set_threads_per_block_limit(props.maxThreadsPerBlock);
+  device_description->set_threads_per_warp(props.warpSize);
+  device_description->set_shared_memory_per_block(props.sharedMemPerBlock);
+  device_description->set_shared_memory_per_block_optin(GetCudaIntegerAttribute(
+      cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
+  device_description->set_shared_memory_per_core(GetCudaIntegerAttribute(
+      cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id));
+  device_description->set_threads_per_core_limit(GetCudaIntegerAttribute(
+      cudaDevAttrMaxThreadsPerMultiProcessor, device_id));
+  device_description->set_core_count(props.multiProcessorCount);
+  device_description->set_fpus_per_core(
+      CUDACoresPerSM(props.major, props.minor));
+  device_description->set_block_dim_limit_x(props.maxGridSize[0]);
+  device_description->set_block_dim_limit_y(props.maxGridSize[1]);
+  device_description->set_block_dim_limit_z(props.maxGridSize[2]);
+
+  // Memory bandwidth (bytes/sec) ≈ 2 * memClock(Hz) * busWidth(bytes)
+  // props.memoryClockRate is in kHz; bus width is in bits.
+  const double mem_clock_hz =
+      static_cast<double>(props.memoryClockRate) * 1000.0;
+  const double bus_bytes = static_cast<double>(props.memoryBusWidth) / 8.0;
+  const double bandwidth_Bps = 2.0 * mem_clock_hz * bus_bytes; // DDR assumption
+  device_description->set_memory_bandwidth(
+      static_cast<uint64_t>(bandwidth_Bps));
+
+  device_description->set_l2_cache_size(
+      GetCudaIntegerAttribute(cudaDevAttrL2CacheSize, device_id));
+
+  // SM clock (GHz). props.clockRate is kHz.
+  device_description->set_clock_rate_ghz(static_cast<double>(props.clockRate) /
+                                         1.0e6);
+  device_description->set_device_memory_size(props.totalGlobalMem);
+
+  // Registers
+  device_description->set_registers_per_core_limit(GetCudaIntegerAttribute(
+      cudaDevAttrMaxRegistersPerMultiprocessor, device_id));
+  device_description->set_registers_per_block_limit(
+      GetCudaIntegerAttribute(cudaDevAttrMaxRegistersPerBlock, device_id));
+
+  // CUDA versions
+  int drv = 0, rtm = 0;
+  cudaRuntimeGetVersion(&rtm);
+  device_description->set_runtime_version(GetStreamExecutorVersion(rtm));
+  cudaDriverGetVersion(&drv);
+  device_description->set_driver_version(GetStreamExecutorVersion(drv));
+
+  return device_description;
+}
+
 #else
 
 REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; }
@@ -863,8 +964,18 @@ REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
     const char *binary, const char *fnname, int32_t *regs, int32_t *spills,
     int32_t *maxThreads) {}
 
+REACTANT_ABI stream_executor::DeviceDescription *
+CudaGetStreamExecutorDeviceDescription(int32_t device_id) {
+  return nullptr;
+}
+
 #endif
 
+REACTANT_ABI const char *
+deviceDescriptionToString(stream_executor::DeviceDescription *device) {
+  return cstr_from_string(device->ToString());
+}
+
 REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
   auto unsafe = MyValueOrThrow(buffer->client()->UnsafeBufferPointer(buffer));
   return (void *)unsafe;
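As a rough caller-side sketch (not part of this diff; the helper name and the printf-based dump are illustrative), the two new ABI entry points above could be combined like this on a CUDA build:

#include <cstdio>

// Illustrative only: build a description for device 0 and print it.
void DumpDeviceZeroDescription() {
  stream_executor::DeviceDescription *desc =
      CudaGetStreamExecutorDeviceDescription(/*device_id=*/0);
  if (desc == nullptr)  // builds without CUDA/ROCm support return nullptr
    return;
  std::printf("%s\n", deviceDescriptionToString(desc));
  delete desc;  // the ABI allocates with new; caller-side ownership is assumed here
}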
@@ -1725,8 +1836,27 @@ PjRtLoadedExecutableGetHloModules(xla::PjRtLoadedExecutable *exec,
   }
 }
 
-REACTANT_ABI const char *HloModuleToString(HeldHloModule *hlo_module) {
-  return cstr_from_string(hlo_module->obj()->ToString());
+HloPrintOptions getHloPrintOptions(int32_t print_options) {
+  switch (print_options) {
+  case 0:
+    return HloPrintOptions::Default();
+  case 1:
+    return HloPrintOptions::ShortParsable();
+  case 2:
+    return HloPrintOptions::Canonical();
+  case 3:
+    return HloPrintOptions::Fingerprint();
+  case 4:
+    return HloPrintOptions::ModuleFingerprint();
+  default:
+    ReactantThrowError("Invalid print_options");
+  }
+}
+
+REACTANT_ABI const char *HloModuleToString(HeldHloModule *hlo_module,
+                                           int32_t print_options) {
+  return cstr_from_string(
+      hlo_module->obj()->ToString(getHloPrintOptions(print_options)));
 }
 
 REACTANT_ABI void FreeHloModule(HeldHloModule *hlo_module) {
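A small illustrative wrapper (not in the commit) showing how the integer print_options argument of the updated HloModuleToString maps onto XLA's HloPrintOptions:

// Illustrative only. print_options: 0 = Default, 1 = ShortParsable,
// 2 = Canonical, 3 = Fingerprint, 4 = ModuleFingerprint; anything else throws.
const char *DumpModuleShortParsable(HeldHloModule *hlo_module) {
  return HloModuleToString(hlo_module, /*print_options=*/1);
}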
@@ -3163,3 +3293,210 @@ REACTANT_ABI HeldHloModule *convertMlirModuleToHloModule(MlirModule mod) {
   std::move(MyValueOrThrow(xla::ConvertStablehloToHlo(cmod_op)));
   return reactant::capture(hlo_module);
 }
+
+REACTANT_ABI HeldHloModule *
+parseAndReturnUnverifiedHloModule(const char *cstr) {
+  absl::string_view str(cstr);
+  auto hlo_module_status = xla::ParseAndReturnUnverifiedModule(str);
+  if (!hlo_module_status.ok()) {
+    ReactantThrowError(hlo_module_status.status().ToString().c_str());
+  }
+  std::shared_ptr<xla::HloModule> hlo_module =
+      std::move(hlo_module_status.value());
+  return reactant::capture(hlo_module);
+}
+
+REACTANT_ABI xla::HloComputation *
+hloModuleGetEntryComputation(HeldHloModule *hlo_module) {
+  return hlo_module->obj()->entry_computation();
+}
+
+REACTANT_ABI void freeHloComputation(HloComputation *hlo_computation) {
+  delete hlo_computation;
+}
+
+REACTANT_ABI const char *hloComputationToString(HloComputation *hlo_computation,
+                                                int32_t print_options) {
+  return cstr_from_string(
+      hlo_computation->ToString(getHloPrintOptions(print_options)));
+}
+
+REACTANT_ABI int64_t
+hloComputationInstructionCount(HloComputation *hlo_computation) {
+  return hlo_computation->instruction_count();
+}
+
+REACTANT_ABI void
+hloComputationGetInstructionsPostOrder(HloComputation *hlo_computation,
+                                       int64_t num_instructions,
+                                       HloInstruction **hlo_instructions) {
+  std::vector<HloInstruction *> instructions =
+      hlo_computation->MakeInstructionPostOrder();
+  assert(instructions.size() == num_instructions);
+  for (int i = 0; i < num_instructions; i++) {
+    hlo_instructions[i] = instructions[i];
+  }
+}
+
+REACTANT_ABI void freeHloInstruction(HloInstruction *hlo_instruction) {
+  delete hlo_instruction;
+}
+
+REACTANT_ABI const char *hloInstructionToString(HloInstruction *hlo_instruction,
+                                                int32_t print_options) {
+  return cstr_from_string(
+      hlo_instruction->ToString(getHloPrintOptions(print_options)));
+}
+
+REACTANT_ABI uint8_t hloInstructionHasToApply(HloInstruction *hlo_instruction) {
+  return hlo_instruction->has_to_apply();
+}
+
+REACTANT_ABI HloComputation *
+hloInstructionGetToApply(HloInstruction *hlo_instruction) {
+  return hlo_instruction->to_apply();
+}
+
+REACTANT_ABI uint8_t hloInstructionGetOpcode(HloInstruction *hlo_instruction) {
+  return static_cast<uint8_t>(hlo_instruction->opcode());
+}
+
+REACTANT_ABI const char *hloOpcodeToString(uint8_t opcode) {
+  return cstr_from_string(xla::HloOpcodeString(static_cast<HloOpcode>(opcode)));
+}
+
+REACTANT_ABI uint8_t hloInstructionIsFusion(HloInstruction *hlo_instruction) {
+  if (dynamic_cast<HloFusionInstruction *>(hlo_instruction)) {
+    return 1;
+  }
+  return 0;
+}
+
+REACTANT_ABI uint8_t
+hloInstructionGetFusionKind(HloInstruction *hlo_instruction) {
+  if (auto hlo_instruction_fusion =
+          dynamic_cast<HloFusionInstruction *>(hlo_instruction)) {
+    return static_cast<uint8_t>(hlo_instruction_fusion->fusion_kind());
+  }
+  ReactantThrowError("hloInstructionGetFusionKind: not a fusion instruction");
+}
+
+REACTANT_ABI const char *hloFusionKindToString(uint8_t kind) {
+  return cstr_from_string(
+      xla::ToString(static_cast<HloInstruction::FusionKind>(kind)));
+}
+
+REACTANT_ABI HloComputation *
+hloInstructionFusedInstructionsComputation(HloInstruction *hlo_instruction) {
+  if (auto hlo_instruction_fusion =
+          dynamic_cast<HloFusionInstruction *>(hlo_instruction)) {
+    return hlo_instruction_fusion->fused_instructions_computation();
+  }
+  ReactantThrowError("hloInstructionFusedInstructionsComputation: not a fusion "
+                     "instruction");
+}
+
+struct JLEstimateRunTimeData {
+  int64_t flops;
+  int64_t bytes_read;
+  int64_t bytes_written;
+  int64_t read_time_ns;
+  int64_t write_time_ns;
+  int64_t compute_time_ns;
+  int64_t execution_time_ns;
+};
+
+#if defined(REACTANT_CUDA) || defined(REACTANT_ROCM)
+namespace details {
+
+// Cost analysis for individual instructions.
+class GPUPerformanceModel {
+public:
+  GPUPerformanceModel(mlir::MLIRContext *mlir_context,
+                      stream_executor::DeviceDescription *device_description)
+      : mlir_context_(std::move(mlir_context)),
+        symbolic_expr_context_(mlir_context_),
+        device_description_(*device_description),
+        hlo_cost_analysis_options_{.count_multiple_input_accesses = true},
+        fusion_analysis_cache_(device_description_),
+        gpu_hlo_cost_analysis_(hlo_cost_analysis_options_, device_description_),
+        gpu_performance_model_(device_description_, fusion_analysis_cache_,
+                               gpu_performance_model_cache_,
+                               &symbolic_expr_context_) {}
+
+  void RunAnalysisOnHloModule(std::shared_ptr<xla::HloModule> hlo_module) {
+    hlo_module->entry_computation()->Accept(&gpu_hlo_cost_analysis_);
+    ran_analysis_ = true;
+  }
+
+  xla::gpu::EstimateRunTimeData
+  EstimateRunTimeForInstruction(HloInstruction *hlo_instruction) {
+    if (!ran_analysis_) {
+      ReactantThrowError("Must call RunAnalysisOnHloModule before calling "
+                         "EstimateRunTimeForInstruction");
+    }
+    return gpu_performance_model_.EstimateRunTimeForInstruction(
+        hlo_instruction, &gpu_hlo_cost_analysis_);
+  }
+
+private:
+  mlir::MLIRContext *mlir_context_;
+  xla::SymbolicExprContext symbolic_expr_context_;
+  xla::gpu::GpuHloCostAnalysis::Options hlo_cost_analysis_options_;
+  stream_executor::DeviceDescription device_description_;
+  xla::gpu::HloFusionAnalysisCache fusion_analysis_cache_;
+  xla::gpu::GpuHloCostAnalysis gpu_hlo_cost_analysis_;
+  xla::gpu::GpuPerformanceModelCache gpu_performance_model_cache_;
+  xla::gpu::GpuPerformanceModel gpu_performance_model_;
+  bool ran_analysis_ = false;
+};
+
+} // namespace details
+
+REACTANT_ABI details::GPUPerformanceModel *CreateGPUPerformanceModel(
+    MlirContext ctx, stream_executor::DeviceDescription *device_description) {
+  return new details::GPUPerformanceModel(unwrap(ctx), device_description);
+}
+
+REACTANT_ABI void
+RunAnalysisOnHloModule(details::GPUPerformanceModel *gpu_performance_model,
+                       HeldHloModule *hlo_module) {
+  gpu_performance_model->RunAnalysisOnHloModule(hlo_module->obj());
+}
+
+REACTANT_ABI void EstimateRunTimeForInstruction(
+    details::GPUPerformanceModel *gpu_performance_model,
+    HloInstruction *hlo_instruction, JLEstimateRunTimeData *jldata) {
+  auto data =
+      gpu_performance_model->EstimateRunTimeForInstruction(hlo_instruction);
+  jldata->flops = data.flops;
+  jldata->bytes_read = data.bytes_read;
+  jldata->bytes_written = data.bytes_written;
+  jldata->read_time_ns = absl::ToInt64Nanoseconds(data.read_time);
+  jldata->write_time_ns = absl::ToInt64Nanoseconds(data.write_time);
+  jldata->compute_time_ns = absl::ToInt64Nanoseconds(data.compute_time);
+  jldata->execution_time_ns = absl::ToInt64Nanoseconds(data.exec_time);
+}
+
+#else
+
+REACTANT_ABI void *CreateGPUPerformanceModelWrapper(
+    MlirContext ctx, stream_executor::DeviceDescription *device_description) {
+  return nullptr;
+}
+
+REACTANT_ABI void RunAnalysisOnHloModule(void *gpu_performance_model,
+                                         HloModule *hlo_module) {
+  ReactantThrowError("RunAnalysisOnHloModule is only supported if Reactant "
+                     "was compiled with CUDA or ROCM support.");
+}
+
+REACTANT_ABI void EstimateRunTimeForInstruction(void *gpu_performance_model,
+                                                HloInstruction *hlo_instruction,
+                                                JLEstimateRunTimeData *jldata) {
+  ReactantThrowError(
+      "EstimateRunTimeForInstruction is only supported if Reactant "
+      "was compiled with CUDA or ROCM support.");
+}
+
+#endif
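Taken together, a plausible end-to-end use of the new HLO traversal and GPU cost-model entry points might look like the sketch below (illustrative only; the driver function, and calling these C ABI helpers directly from C++, are assumptions rather than part of the commit):

#include <vector>

// Illustrative only: parse HLO text, run the cost analysis once, then
// estimate per-instruction run times on a CUDA/ROCm build.
void EstimateModuleCosts(MlirContext ctx, const char *hlo_text,
                         stream_executor::DeviceDescription *desc) {
  HeldHloModule *module = parseAndReturnUnverifiedHloModule(hlo_text);
  auto *model = CreateGPUPerformanceModel(ctx, desc);
  RunAnalysisOnHloModule(model, module);  // must run before any estimate

  HloComputation *entry = hloModuleGetEntryComputation(module);
  int64_t n = hloComputationInstructionCount(entry);
  std::vector<HloInstruction *> instrs(n);
  hloComputationGetInstructionsPostOrder(entry, n, instrs.data());

  for (HloInstruction *instr : instrs) {
    JLEstimateRunTimeData data;
    EstimateRunTimeForInstruction(model, instr, &data);
    // data.flops, data.bytes_read/bytes_written, and the *_time_ns fields
    // are now populated for this instruction.
  }
  FreeHloModule(module);
}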
