Commit 1ccb799

[None][chore] Relocate rlhf_utils.py (#8938)
Signed-off-by: shuyix <219646547+shuyixiong@users.noreply.github.com>
Parent: 972c21c

3 files changed: +2 -14 lines
rlhf_utils.py: file renamed without changes.
Lines changed: 0 additions & 12 deletions

@@ -1,9 +1,7 @@
 import os
 import sys
-from pathlib import Path
 
 import pytest
-from utils.cpp_paths import llm_root  # noqa: F401
 
 from tensorrt_llm._utils import mpi_disabled
 
@@ -23,13 +21,3 @@ def pytest_configure(config):
         pytest.skip(
             "Ray tests are only tested in Ray CI stage or with --run-ray flag",
             allow_module_level=True)
-
-
-@pytest.fixture(scope="function")
-def add_worker_extension_path(llm_root: Path):
-    worker_extension_path = str(llm_root / "examples" / "llm-api" / "rlhf")
-    original_python_path = os.environ.get('PYTHONPATH', '')
-    os.environ['PYTHONPATH'] = os.pathsep.join(
-        filter(None, [worker_extension_path, original_python_path]))
-    yield
-    os.environ['PYTHONPATH'] = original_python_path
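
The deleted fixture existed only to prepend examples/llm-api/rlhf to PYTHONPATH so that the bare string "rlhf_utils.WorkerExtension" could be imported. With rlhf_utils.py relocated into the tensorrt_llm.llmapi package, the fully qualified path resolves through a normal import and no environment mutation is needed. A minimal sketch of dotted-path resolution (illustrative only: resolve_class is a hypothetical helper, not TensorRT-LLM's actual loader, and running it assumes tensorrt_llm is installed):

import importlib

def resolve_class(dotted_path: str):
    """Import 'pkg.module.ClassName' and return the class object."""
    module_name, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Before this commit, "rlhf_utils.WorkerExtension" was importable only while
# the fixture had examples/llm-api/rlhf on PYTHONPATH. The packaged path
# resolves in any environment where tensorrt_llm is installed:
cls = resolve_class("tensorrt_llm.llmapi.rlhf_utils.WorkerExtension")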

tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py

Lines changed: 2 additions & 2 deletions

@@ -139,15 +139,15 @@ def run_generate(llm, hf_model, prompts, sampling_params):
     return llm_logits, ref_logits
 
 
-def test_llm_update_weights(add_worker_extension_path):
+def test_llm_update_weights():
     llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
     kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
 
     hf_model = HFModel(llama_model_path)
 
     llm = LLM(
         model=llama_model_path,
-        ray_worker_extension_cls="rlhf_utils.WorkerExtension",
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
         tensor_parallel_size=1,
         pipeline_parallel_size=1,
         kv_cache_config=kv_cache_config,