Skip to content

Commit 5462657

Browse files
authored
fix: build for cuda 13 + mac (#1844)
* fix: build for cuda 13 + mac * ci: build reactant jll * ci: pass in enzyme_jax_commit to reduce builds * Apply suggestion from @avik-pal * Apply suggestion from @avik-pal
1 parent 6e5667c commit 5462657

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
name: "Build Reactant_jll"
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
paths:
8+
- ".github/workflows/build-reactantjll.yml"
9+
- "deps/ReactantExtra/API.cpp"
10+
- "deps/ReactantExtra/BUILD"
11+
- "deps/ReactantExtra/WORKSPACE"
12+
- "deps/ReactantExtra/workspace.bzl"
13+
14+
concurrency:
15+
# Skip intermediate builds: always.
16+
# Cancel intermediate builds: only if it is a pull request build.
17+
group: ${{ github.workflow }}-${{ github.ref }}
18+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
19+
20+
jobs:
21+
enzyme-jax-commit:
22+
name: Extract ENZYMEXLA_COMMIT from WORKSPACE
23+
runs-on: ubuntu-latest
24+
timeout-minutes: 10
25+
26+
steps:
27+
- name: Checkout repository
28+
uses: actions/checkout@v4
29+
30+
- name: Extract ENZYMEXLA_COMMIT from WORKSPACE
31+
id: extract_enzyme_jax_commit
32+
run: |
33+
ENZYMEXLA_COMMIT=$(grep -oP 'ENZYMEXLA_COMMIT = "\K[^"]+' deps/ReactantExtra/WORKSPACE)
34+
echo "enzyme_jax_commit=$ENZYMEXLA_COMMIT" >> $GITHUB_OUTPUT
35+
outputs:
36+
enzyme_jax_commit: ${{ steps.extract_enzyme_jax_commit.outputs.enzyme_jax_commit }}
37+
38+
build-jll:
39+
name: Build Reactant_jll
40+
if: github.event.pull_request.draft == false
41+
uses: EnzymeAD/ReactantBuilder/.github/workflows/build-reactant-reusable.yml@main
42+
needs: enzyme-jax-commit
43+
with:
44+
reactantbuilder_ref: "main"
45+
reactant_commit: ${{ github.event.pull_request.head.sha }}
46+
enzyme_jax_commit: ${{ needs.enzyme-jax-commit.outputs.enzyme_jax_commit }}

deps/ReactantExtra/API.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -914,8 +914,9 @@ CudaGetStreamExecutorDeviceDescription(int32_t device_id) {
914914

915915
// Memory bandwidth (bytes/sec) ≈ 2 * memClock(Hz) * busWidth(bytes)
916916
// props.memoryClockRate is in kHz; bus width is in bits.
917-
const double mem_clock_hz =
918-
static_cast<double>(props.memoryClockRate) * 1000.0;
917+
const double mem_clock_hz = static_cast<double>(GetCudaIntegerAttribute(
918+
cudaDevAttrMemoryClockRate, device_id)) *
919+
1000.0;
919920
const double bus_bytes = static_cast<double>(props.memoryBusWidth) / 8.0;
920921
const double bandwidth_Bps = 2.0 * mem_clock_hz * bus_bytes; // DDR assumption
921922
device_description->set_memory_bandwidth(
@@ -925,8 +926,10 @@ CudaGetStreamExecutorDeviceDescription(int32_t device_id) {
925926
GetCudaIntegerAttribute(cudaDevAttrL2CacheSize, device_id));
926927

927928
// SM clock (GHz). props.clockRate is kHz.
928-
device_description->set_clock_rate_ghz(static_cast<double>(props.clockRate) /
929-
1.0e6);
929+
device_description->set_clock_rate_ghz(
930+
static_cast<double>(
931+
GetCudaIntegerAttribute(cudaDevAttrClockRate, device_id)) /
932+
1.0e6);
930933
device_description->set_device_memory_size(props.totalGlobalMem);
931934

932935
// Registers
@@ -3480,7 +3483,7 @@ REACTANT_ABI void EstimateRunTimeForInstruction(
34803483

34813484
#else
34823485

3483-
REACTANT_ABI void *CreateGPUPerformanceModelWrapper(
3486+
REACTANT_ABI void *CreateGPUPerformanceModel(
34843487
MlirContext ctx, stream_executor::DeviceDescription *device_description) {
34853488
return nullptr;
34863489
}

0 commit comments

Comments
 (0)