@@ -26,21 +26,17 @@ def inference(config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     global_rank = int(os.environ["RANK"])
 
-    single_prompt_mode = config.inference.prompt is not None
-
-    # Use single prompt if specified, otherwise read from file
-    if single_prompt_mode:
-        original_prompts = [config.inference.prompt]
-        logger.info(f"Using single prompt: {config.inference.prompt}")
-        bs = 1
-        # If only single prompt, each rank will generate an image with the same prompt
-        prompts = original_prompts
-    else:
-        original_prompts = open(config.inference.prompts_path).readlines()
-        logger.info(f"Reading prompts from: {config.inference.prompts_path}")
-        bs = config.inference.local_batch_size
-        # Distribute prompts across processes using round-robin assignment
-        prompts = original_prompts[global_rank::world_size]
+    original_prompts = open(config.inference.prompts_path).readlines()
+    logger.info(f"Reading prompts from: {config.inference.prompts_path}")
+    if len(original_prompts) < world_size:
+        raise ValueError(
+            f"Number of prompts ({len(original_prompts)}) must be >= number of ranks ({world_size}). "
+            f"FSDP all-gather will hang if some ranks have no prompts to process."
+        )
+
+    bs = config.inference.local_batch_size
+    # Distribute prompts across processes using round-robin assignment
+    prompts = original_prompts[global_rank::world_size]
 
     total_prompts = len(original_prompts)
 
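For reference, a minimal standalone sketch (not part of the diff) of the round-robin split used above and of why the new prompt-count guard is needed; the prompt list and world size are made-up example values:

```python
# Hypothetical example values; only the slicing pattern mirrors the diff above.
original_prompts = ["a cat", "a dog", "a boat", "a tree", "a house"]
world_size = 4

for global_rank in range(world_size):
    prompts = original_prompts[global_rank::world_size]
    print(f"rank {global_rank}: {prompts}")
# rank 0: ['a cat', 'a house']
# rank 1: ['a dog']
# rank 2: ['a boat']
# rank 3: ['a tree']

# If len(original_prompts) < world_size, some ranks get an empty slice, skip the
# generation loop, and never join the collectives (e.g. FSDP all-gathers) issued
# by the busy ranks -- the hang that the new ValueError prevents.
```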
@@ -58,13 +54,7 @@ def inference(config: JobConfig):
             config.inference.save_img_folder,
         )
         # Create mapping from local indices to global prompt indices
-    if single_prompt_mode:
-        # In single prompt mode, all ranks process the same prompt (index 0)
-        # But each rank generates a different image (different seed/rank)
-        global_ids = [0] * len(prompts)
-    else:
-        # In multi-prompt mode, use round-robin distribution
-        global_ids = list(range(global_rank, total_prompts, world_size))
+    global_ids = list(range(global_rank, total_prompts, world_size))
 
     for i in range(0, len(prompts), bs):
         images = generate_image(
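Similarly, a small sketch (again not part of the diff, with made-up values) showing how the simplified global_ids line stays aligned with the round-robin prompt slice, so each generated image can still be indexed by its position in the prompt file:

```python
# Hypothetical values; mirrors prompts = original_prompts[global_rank::world_size]
# and global_ids = list(range(global_rank, total_prompts, world_size)).
total_prompts, world_size, global_rank = 5, 2, 1

local_positions = list(range(total_prompts))[global_rank::world_size]  # [1, 3]
global_ids = list(range(global_rank, total_prompts, world_size))       # [1, 3]
assert global_ids == local_positions  # the i-th local prompt maps to global index global_ids[i]
```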