@@ -208,61 +208,69 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
208208 },
209209 },
210210 ModelType .FLUX_DEV : {
211- "height" : 1024 ,
212- "width" : 1024 ,
213211 "backbone" : "transformer" ,
214- "guidance_scale" : 3.5 ,
215- "max_sequence_length" : 512 ,
216212 "dataset" : {
217213 "name" : "Gustavosta/Stable-Diffusion-Prompts" ,
218214 "split" : "train" ,
219215 "column" : "Prompt" ,
220216 },
217+ "inference_extra_args" : {
218+ "height" : 1024 ,
219+ "width" : 1024 ,
220+ "guidance_scale" : 3.5 ,
221+ "max_sequence_length" : 512 ,
222+ },
221223 },
222224 ModelType .FLUX_SCHNELL : {
223- "height" : 1024 ,
224- "width" : 1024 ,
225225 "backbone" : "transformer" ,
226- "guidance_scale" : 3.5 ,
227- "max_sequence_length" : 512 ,
228226 "dataset" : {
229227 "name" : "Gustavosta/Stable-Diffusion-Prompts" ,
230228 "split" : "train" ,
231229 "column" : "Prompt" ,
232230 },
231+ "inference_extra_args" : {
232+ "height" : 1024 ,
233+ "width" : 1024 ,
234+ "guidance_scale" : 3.5 ,
235+ "max_sequence_length" : 512 ,
236+ },
233237 },
234238 ModelType .LTX_VIDEO_DEV : {
235- "height" : 512 ,
236- "width" : 704 ,
237239 "backbone" : "transformer" ,
238- "num_frames" : 121 ,
239- "negative_prompt" : "worst quality, inconsistent motion, blurry, jittery, distorted" ,
240240 "dataset" : {
241241 "name" : "Gustavosta/Stable-Diffusion-Prompts" ,
242242 "split" : "train" ,
243243 "column" : "Prompt" ,
244244 },
245+ "inference_extra_args" : {
246+ "height" : 512 ,
247+ "width" : 704 ,
248+ "num_frames" : 121 ,
249+ "negative_prompt" : "worst quality, inconsistent motion, blurry, jittery, distorted" ,
250+ },
245251 },
246252 ModelType .WAN22_T2V : {
247253 "backbone" : "transformer" ,
248- "height" : 720 ,
249- "width" : 1280 ,
250- "num_frames" : 81 ,
251- "fps" : 16 ,
252- "guidance_scale" : 4.0 ,
253- "guidance_scale_2" : 3.0 ,
254- "negative_prompt" : (
255- "vivid colors, overexposed, static, blurry details, subtitles, style, "
256- "work of art, painting, picture, still, overall grayish, worst quality, "
257- "low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, "
258- "poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, "
259- "static image, cluttered background, three legs, many people in the background, "
260- "walking backwards"
261- ),
262254 "dataset" : {"name" : "nkp37/OpenVid-1M" , "split" : "train" , "column" : "caption" },
263255 "from_pretrained_extra_args" : {
264256 "boundary_ratio" : 0.875 ,
265257 },
258+ "inference_extra_args" : {
259+ "height" : 720 ,
260+ "width" : 1280 ,
261+ "num_frames" : 81 ,
262+ "fps" : 16 ,
263+ "guidance_scale" : 4.0 ,
264+ "guidance_scale_2" : 3.0 ,
265+ "negative_prompt" : (
266+ "vivid colors, overexposed, static, blurry details, subtitles, style, "
267+ "work of art, painting, picture, still, overall grayish, worst quality, "
268+ "low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, "
269+ "poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, "
270+ "static image, cluttered background, three legs, many people in the background, "
271+ "walking backwards"
272+ ),
273+ },
266274 },
267275}
268276
@@ -567,7 +575,7 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None:
567575 batched_prompts: List of batched calibration prompts
568576 """
569577 self .logger .info (f"Starting calibration with { self .config .num_batches } batches" )
570- extra_args = MODEL_DEFAULTS .get (self .model_type , {})
578+ extra_args = MODEL_DEFAULTS .get (self .model_type , {}). get ( "inference_extra_args" , {})
571579
572580 with tqdm (total = self .config .num_batches , desc = "Calibration" , unit = "batch" ) as pbar :
573581 for i , prompt_batch in enumerate (batched_prompts ):
0 commit comments