File tree Expand file tree Collapse file tree 1 file changed +4
-6
lines changed
torchtitan/models/flux/inference Expand file tree Collapse file tree 1 file changed +4
-6
lines changed Original file line number Diff line number Diff line change @@ -25,21 +25,19 @@ def inference(config: JobConfig):
2525 # Distributed processing setup: Each GPU/process handles a subset of prompts
2626 world_size = int (os .environ ["WORLD_SIZE" ])
2727 global_rank = int (os .environ ["RANK" ])
28-
2928 original_prompts = open (config .inference .prompts_path ).readlines ()
30- logger .info (f"Reading prompts from: { config .inference .prompts_path } " )
31- if len (original_prompts ) < world_size :
29+ total_prompts = len (original_prompts )
30+
31+ if total_prompts < world_size :
3232 raise ValueError (
33- f"Number of prompts ({ len ( prompts ) } ) must be >= number of ranks ({ world_size } ). "
33+ f"Number of prompts ({ total_prompts } ) must be >= number of ranks ({ world_size } ). "
3434 f"FSDP all-gather will hang if some ranks have no prompts to process."
3535 )
3636
3737 bs = config .inference .local_batch_size
3838 # Distribute prompts across processes using round-robin assignment
3939 prompts = original_prompts [global_rank ::world_size ]
4040
41- total_prompts = len (original_prompts )
42-
4341 trainer .checkpointer .load (step = config .checkpoint .load_step )
4442 t5_tokenizer , clip_tokenizer = build_flux_tokenizer (config )
4543
You can’t perform that action at this time.
0 commit comments