Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/integration_test_8gpu_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@ jobs:
python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8
python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8
rm -rf artifacts-to-be-uploaded/*/checkpoint
rm -rf artifacts-to-be-uploaded/flux/*/inference_results/
25 changes: 10 additions & 15 deletions tests/integration_tests/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,15 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.data_parallel_replicate_degree 2",
"--parallelism.context_parallel_degree 2",
]
],
"HSDP+CP",
"hsdp+cp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--validation.enable",
]
"--validation.steps 5",
"--checkpoint.enable",
],
[],
],
"Flux Validation Test",
"validation",
"HSDP+CP+Validation+Inference",
"hsdp+cp+validation+inference",
ngpu=8,
),
]
return integration_tests_flavors
Expand All @@ -63,7 +58,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
t5_encoder_version_arg = (
"--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/"
)
tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer"
hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer"

all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

Expand All @@ -73,7 +68,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd

# save checkpoint (idx == 0) and load it for generation (idx == 1)
if test_name == "test_generate" and idx == 1:
if test_name == "hsdp+cp+validation+inference" and idx == 1:
# For flux generation, test using inference script
cmd = (
f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
Expand All @@ -84,7 +79,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
cmd += " " + random_init_encoder_arg
cmd += " " + clip_encoder_version_arg
cmd += " " + t5_encoder_version_arg
cmd += " " + tokenzier_path_arg
cmd += " " + hf_assets_path_arg
if override_arg:
cmd += " " + " ".join(override_arg)

Expand Down
7 changes: 6 additions & 1 deletion torchtitan/models/flux/inference/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def inference(config: JobConfig):
original_prompts = open(config.inference.prompts_path).readlines()
total_prompts = len(original_prompts)

if total_prompts < world_size:
raise ValueError(
f"Number of prompts ({total_prompts}) must be >= number of ranks ({world_size}). "
f"FSDP all-gather will hang if some ranks have no prompts to process."
)

# Distribute prompts across processes using round-robin assignment
prompts = original_prompts[global_rank::world_size]

Expand All @@ -45,7 +51,6 @@ def inference(config: JobConfig):
config.job.dump_folder,
config.inference.save_img_folder,
)

# Create mapping from local indices to global prompt indices
global_ids = list(range(global_rank, total_prompts, world_size))

Expand Down
2 changes: 2 additions & 0 deletions torchtitan/models/flux/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ enable_wandb = false
[model]
name = "flux"
flavor = "flux-debug"
hf_assets_path = "tests/assets/tokenizer"

[optimizer]
name = "AdamW"
Expand Down Expand Up @@ -48,6 +49,7 @@ autoencoder_path = "assets/hf/FLUX.1-dev/ae.safetensors" # Autoencoder to use f
[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
context_parallel_degree = 1

[activation_checkpoint]
mode = "full"
Expand Down
Loading