Skip to content

Commit cbdb311

Browse files
authored
[FLUX] Add FLUX inference test in CI (#1969)
1 parent 1b9cfda commit cbdb311

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

.github/workflows/integration_test_8gpu_models.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ jobs:
5454
python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8
5555
python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8
5656
rm -rf artifacts-to-be-uploaded/*/checkpoint
57+
rm -rf artifacts-to-be-uploaded/flux/*/inference_results/

tests/integration_tests/flux.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,15 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
2626
"--parallelism.data_parallel_shard_degree 2",
2727
"--parallelism.data_parallel_replicate_degree 2",
2828
"--parallelism.context_parallel_degree 2",
29-
]
30-
],
31-
"HSDP+CP",
32-
"hsdp+cp",
33-
ngpu=8,
34-
),
35-
OverrideDefinitions(
36-
[
37-
[
3829
"--validation.enable",
39-
]
30+
"--validation.steps 5",
31+
"--checkpoint.enable",
32+
],
33+
[],
4034
],
41-
"Flux Validation Test",
42-
"validation",
35+
"HSDP+CP+Validation+Inference",
36+
"hsdp+cp+validation+inference",
37+
ngpu=8,
4338
),
4439
]
4540
return integration_tests_flavors
@@ -63,7 +58,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
6358
t5_encoder_version_arg = (
6459
"--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/"
6560
)
66-
tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer"
61+
hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer"
6762

6863
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
6964

@@ -73,7 +68,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
7368
cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
7469

7570
# save checkpoint (idx == 0) and load it for generation (idx == 1)
76-
if test_name == "test_generate" and idx == 1:
71+
if test_name == "hsdp+cp+validation+inference" and idx == 1:
7772
# For flux generation, test using inference script
7873
cmd = (
7974
f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
@@ -84,7 +79,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
8479
cmd += " " + random_init_encoder_arg
8580
cmd += " " + clip_encoder_version_arg
8681
cmd += " " + t5_encoder_version_arg
87-
cmd += " " + tokenzier_path_arg
82+
cmd += " " + hf_assets_path_arg
8883
if override_arg:
8984
cmd += " " + " ".join(override_arg)
9085

torchtitan/models/flux/inference/infer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def inference(config: JobConfig):
2828
original_prompts = open(config.inference.prompts_path).readlines()
2929
total_prompts = len(original_prompts)
3030

31+
if total_prompts < world_size:
32+
raise ValueError(
33+
f"Number of prompts ({total_prompts}) must be >= number of ranks ({world_size}). "
34+
f"FSDP all-gather will hang if some ranks have no prompts to process."
35+
)
36+
3137
# Distribute prompts across processes using round-robin assignment
3238
prompts = original_prompts[global_rank::world_size]
3339

@@ -45,7 +51,6 @@ def inference(config: JobConfig):
4551
config.job.dump_folder,
4652
config.inference.save_img_folder,
4753
)
48-
4954
# Create mapping from local indices to global prompt indices
5055
global_ids = list(range(global_rank, total_prompts, world_size))
5156

torchtitan/models/flux/train_configs/debug_model.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ enable_wandb = false
2121
[model]
2222
name = "flux"
2323
flavor = "flux-debug"
24+
hf_assets_path = "tests/assets/tokenizer"
2425

2526
[optimizer]
2627
name = "AdamW"
@@ -48,6 +49,7 @@ autoencoder_path = "assets/hf/FLUX.1-dev/ae.safetensors" # Autoencoder to use f
4849
[parallelism]
4950
data_parallel_replicate_degree = 1
5051
data_parallel_shard_degree = -1
52+
context_parallel_degree = 1
5153

5254
[activation_checkpoint]
5355
mode = "full"

0 commit comments

Comments
 (0)