2020import matplotlib .pyplot as plt
2121
2222import legacy
23+ from age_estimator import AgeEstimator
2324
2425# ----------------------------------------------------------------------------
2526
@@ -224,6 +225,12 @@ def generate_images(
224225 f"Applying weight vector to style blocks { start_idx } to { end_idx } (inclusive)"
225226 )
226227
228+ # Initialize age estimator if using weight vector
229+ age_estimator = None
230+ if weight_vector is not None :
231+ print ("Initializing age estimator..." )
232+ age_estimator = AgeEstimator (ctx_id = 0 , det_size = (1024 , 1024 ))
233+
227234 # Generate images.
228235 all_images = [] # Store images for composite: list of lists (one per seed)
229236
@@ -233,9 +240,31 @@ def generate_images(
233240 w = G .mapping (z , label )
234241
235242 if weight_vec is not None :
236- # Generate images with weight modulation
237- seed_images = []
238243 start_idx , end_idx = style_range
244+
245+ # First, check age with alpha=0 if age estimator is available
246+ if age_estimator is not None :
247+ img = G .synthesis (w , noise_mode = noise_mode )
248+ img = (
249+ (img .permute (0 , 2 , 3 , 1 ) * 127.5 + 128 )
250+ .clamp (0 , 255 )
251+ .to (torch .uint8 )
252+ )
253+ img_array = img [0 ].cpu ().numpy ()
254+ pil_img = PIL .Image .fromarray (img_array , "RGB" )
255+ estimated_age = age_estimator (pil_img )
256+
257+ if estimated_age is None :
258+ print (f" ⚠️ Seed { seed } : No face detected, skipping..." )
259+ continue
260+ elif estimated_age <= 20 :
261+ print (f" ⚠️ Seed { seed } : Age { estimated_age :.1f} <= 20, skipping..." )
262+ continue
263+ else :
264+ print (f" ✓ Seed { seed } : Age { estimated_age :.1f} > 20, proceeding..." )
265+
266+ # If age check passed (or not applicable), generate all alphas
267+ seed_images = []
239268 for alpha in alphas :
240269 # Clone w to avoid modifying the original
241270 w_modified = w .clone ()
@@ -244,7 +273,7 @@ def generate_images(
244273 start_idx : end_idx + 1 , :
245274 ].unsqueeze (0 )
246275 assert w_modified .shape [1 :] == (G .num_ws , G .w_dim )
247- img = G .synthesis (w_modified , noise_mode = noise_mode )
276+ img = G .synthesis (w_modified , truncation_psi = truncation_psi , noise_mode = noise_mode )
248277 img = (
249278 (img .permute (0 , 2 , 3 , 1 ) * 127.5 + 128 )
250279 .clamp (0 , 255 )
@@ -264,7 +293,7 @@ def generate_images(
264293 all_images .append (seed_images )
265294 else :
266295 # Generate image without weight modulation
267- img = G .synthesis (w , noise_mode = noise_mode )
296+ img = G .synthesis (w , truncation_psi = truncation_psi , noise_mode = noise_mode )
268297 img = (img .permute (0 , 2 , 3 , 1 ) * 127.5 + 128 ).clamp (0 , 255 ).to (torch .uint8 )
269298 PIL .Image .fromarray (img [0 ].cpu ().numpy (), "RGB" ).save (
270299 f"{ outdir } /seed{ seed :04d} .png"
0 commit comments