@@ -26,20 +26,15 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
2626 "--parallelism.data_parallel_shard_degree 2" ,
2727 "--parallelism.data_parallel_replicate_degree 2" ,
2828 "--parallelism.context_parallel_degree 2" ,
29- ]
30- ],
31- "HSDP+CP" ,
32- "hsdp+cp" ,
33- ngpu = 8 ,
34- ),
35- OverrideDefinitions (
36- [
37- [
3829 "--validation.enable" ,
39- ]
30+ "--validation.steps 5" ,
31+ "--checkpoint.enable" ,
32+ ],
33+ [],
4034 ],
41- "Flux Validation Test" ,
42- "validation" ,
35+ "HSDP+CP+Validation+Inference" ,
36+ "hsdp+cp+validation+inference" ,
37+ ngpu = 8 ,
4338 ),
4439 ]
4540 return integration_tests_flavors
@@ -63,7 +58,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
6358 t5_encoder_version_arg = (
6459 "--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/"
6560 )
66- tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer"
61+ hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer"
6762
6863 all_ranks = "," .join (map (str , range (test_flavor .ngpu )))
6964
@@ -73,7 +68,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
7368 cmd = f'TORCH_TRACE="{ output_dir } /{ test_name } /compile_trace" ' + cmd
7469
7570 # save checkpoint (idx == 0) and load it for generation (idx == 1)
76- if test_name == "test_generate " and idx == 1 :
71+ if test_name == "hsdp+cp+validation+inference " and idx == 1 :
7772 # For flux generation, test using inference script
7873 cmd = (
7974 f"CONFIG_FILE={ full_path } NGPU={ test_flavor .ngpu } LOG_RANK={ all_ranks } "
@@ -84,7 +79,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
8479 cmd += " " + random_init_encoder_arg
8580 cmd += " " + clip_encoder_version_arg
8681 cmd += " " + t5_encoder_version_arg
87- cmd += " " + tokenzier_path_arg
82+ cmd += " " + hf_assets_path_arg
8883 if override_arg :
8984 cmd += " " + " " .join (override_arg )
9085
0 commit comments