Skip to content

Commit eecc471

Browse files
authored
Fix CI dependency error for nvidia-nvshmem-cu12 when using PyTorch nightly and other CI lint errors from pyrefly change. (#1165)
1 parent 078e708 commit eecc471

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

.github/workflows/lint.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ jobs:
3535
- name: Install PyTorch
3636
run: |
3737
source .venv/bin/activate
38+
# Install nvidia-nvshmem-cu12 from cu129 index (missing on cu128)
39+
uv pip install -U --pre nvidia-nvshmem-cu12 --index-url https://download.pytorch.org/whl/nightly/cu129
3840
uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
3941
4042
- name: Install lint dependencies

.github/workflows/test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ jobs:
113113
uv pip install -U "torch==2.9.*" --index-url https://download.pytorch.org/whl/${{ matrix.runtime-version }}
114114
else
115115
# Default to nightly
116+
if [[ "${{ matrix.runtime-version }}" == "cu128" ]]; then
117+
# Install nvidia-nvshmem-cu12 from cu129 index (missing on cu128)
118+
uv pip install -U --pre nvidia-nvshmem-cu12 --index-url https://download.pytorch.org/whl/nightly/cu129
119+
fi
116120
uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.runtime-version }}
117121
fi
118122

benchmarks/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def load_kernel_config(
741741

742742
def process_single_kernel_mapping(
743743
kernel_name: str, mapping: dict[str, Any]
744-
) -> tuple[str, ...]:
744+
) -> tuple[Any, ...]:
745745
"""Process a single kernel mapping configuration."""
746746
if not isinstance(mapping, dict):
747747
raise ValueError(
@@ -785,11 +785,11 @@ def process_single_kernel_mapping(
785785

786786

787787
def merge_kernel_configs(
788-
base_mappings: dict[str, tuple[str, ...]],
788+
base_mappings: dict[str, tuple[Any, ...]],
789789
base_metrics: dict[str, dict[str, str]],
790-
custom_mappings: dict[str, tuple[str, ...]],
790+
custom_mappings: dict[str, tuple[Any, ...]],
791791
custom_metrics: dict[str, dict[str, str]],
792-
) -> tuple[dict[str, tuple[str, ...]], dict[str, dict[str, str]]]:
792+
) -> tuple[dict[str, tuple[Any, ...]], dict[str, dict[str, str]]]:
793793
"""Merge custom kernel configurations with base configurations.
794794
795795
Custom configs extend and can override base configs.

helion/_logging/_internal.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from dataclasses import dataclass
44
from dataclasses import field
5+
import functools
56
import logging
67
import os
78
from typing import Callable
9+
from typing import Generic
810
from typing import ParamSpec
911

1012
LOG_ENV_VAR = "HELION_LOGS"
@@ -82,14 +84,11 @@ def init_logs() -> None:
8284
P = ParamSpec("P")
8385

8486

85-
class LazyString:
87+
class LazyString(Generic[P]):
8688
def __init__(
8789
self, func: Callable[P, str], *args: P.args, **kwargs: P.kwargs
8890
) -> None:
89-
# pyrefly: ignore [invalid-type-var]
90-
self.func: Callable[P, str] = func
91-
self.args: tuple[object, ...] = args
92-
self.kwargs: object = kwargs
91+
self._callable: Callable[[], str] = functools.partial(func, *args, **kwargs)
9392

9493
def __str__(self) -> str:
95-
return self.func(*self.args, **self.kwargs)
94+
return self._callable()

0 commit comments

Comments
 (0)