diff --git a/.github/workflows/integration_test_8gpu_models.yaml b/.github/workflows/integration_test_8gpu_models.yaml index 129049b8f6..b673da5adf 100644 --- a/.github/workflows/integration_test_8gpu_models.yaml +++ b/.github/workflows/integration_test_8gpu_models.yaml @@ -54,3 +54,4 @@ jobs: python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8 python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8 rm -rf artifacts-to-be-uploaded/*/checkpoint + rm -rf artifacts-to-be-uploaded/flux/*/inference_results/ diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py index 321ac1280c..a7ed51832f 100755 --- a/tests/integration_tests/flux.py +++ b/tests/integration_tests/flux.py @@ -26,20 +26,15 @@ def build_flux_test_list() -> list[OverrideDefinitions]: "--parallelism.data_parallel_shard_degree 2", "--parallelism.data_parallel_replicate_degree 2", "--parallelism.context_parallel_degree 2", - ] - ], - "HSDP+CP", - "hsdp+cp", - ngpu=8, - ), - OverrideDefinitions( - [ - [ "--validation.enable", - ] + "--validation.steps 5", + "--checkpoint.enable", + ], + [], ], - "Flux Validation Test", - "validation", + "HSDP+CP+Validation+Inference", + "hsdp+cp+validation+inference", + ngpu=8, ), ] return integration_tests_flavors @@ -63,7 +58,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir t5_encoder_version_arg = ( "--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/" ) - tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer" + hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer" all_ranks = ",".join(map(str, range(test_flavor.ngpu))) @@ -73,7 +68,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd # save checkpoint (idx == 0) and load it for generation (idx == 1) - if test_name == "test_generate" and idx == 1: + if test_name == "hsdp+cp+validation+inference" and idx == 1: # For flux generation, test using inference script cmd = ( f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} " @@ -84,7 +79,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir cmd += " " + random_init_encoder_arg cmd += " " + clip_encoder_version_arg cmd += " " + t5_encoder_version_arg - cmd += " " + tokenzier_path_arg + cmd += " " + hf_assets_path_arg if override_arg: cmd += " " + " ".join(override_arg) diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py index 0c06a385ef..b89887ad51 100644 --- a/torchtitan/models/flux/inference/infer.py +++ b/torchtitan/models/flux/inference/infer.py @@ -28,6 +28,12 @@ def inference(config: JobConfig): original_prompts = open(config.inference.prompts_path).readlines() total_prompts = len(original_prompts) + if total_prompts < world_size: + raise ValueError( + f"Number of prompts ({total_prompts}) must be >= number of ranks ({world_size}). " + f"FSDP all-gather will hang if some ranks have no prompts to process." + ) + # Distribute prompts across processes using round-robin assignment prompts = original_prompts[global_rank::world_size] @@ -45,7 +51,6 @@ def inference(config: JobConfig): config.job.dump_folder, config.inference.save_img_folder, ) - # Create mapping from local indices to global prompt indices global_ids = list(range(global_rank, total_prompts, world_size)) diff --git a/torchtitan/models/flux/train_configs/debug_model.toml b/torchtitan/models/flux/train_configs/debug_model.toml index 47a033c546..b943925c1c 100644 --- a/torchtitan/models/flux/train_configs/debug_model.toml +++ b/torchtitan/models/flux/train_configs/debug_model.toml @@ -21,6 +21,7 @@ enable_wandb = false [model] name = "flux" flavor = "flux-debug" +hf_assets_path = "tests/assets/tokenizer" [optimizer] name = "AdamW" @@ -48,6 +49,7 @@ autoencoder_path = "assets/hf/FLUX.1-dev/ae.safetensors" # Autoencoder to use f [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 +context_parallel_degree = 1 [activation_checkpoint] mode = "full"