Skip to content

Commit 1ee6a11

Browse files
committed
Merge branch 'timotheschmidt/add-weight-vector' of https://github.com/EmoTim/stylegan2-ada-pytorch into timotheschmidt/add-weight-vector
2 parents 2ce9c51 + 11d9e1a commit 1ee6a11

File tree

5 files changed

+589
-10
lines changed

5 files changed

+589
-10
lines changed

age_estimator.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import insightface
2+
from insightface.app import FaceAnalysis
3+
import numpy as np
4+
from PIL import Image
5+
6+
7+
class AgeEstimator:
8+
def __init__(
9+
self, ctx_id: int = 0, det_size: tuple[int, int] = (1024, 1024)
10+
) -> None:
11+
self.app = FaceAnalysis(name="buffalo_l") # SOTA Model including age
12+
self.app.prepare(ctx_id=ctx_id, det_size=det_size) # 0 = gpu
13+
14+
def estimate_age(self, pil_img: Image.Image) -> float | None:
15+
"""
16+
Estimate age from a PIL Image.
17+
18+
Args:
19+
pil_img: PIL Image object
20+
21+
Returns:
22+
Estimated age as float, or None if no face detected
23+
"""
24+
img = np.array(pil_img)
25+
faces = self.app.get(img)
26+
if len(faces) == 0:
27+
return None
28+
return faces[0].age
29+
30+
def __call__(self, pil_img: Image.Image) -> float | None:
31+
"""Allow instance to be called directly."""
32+
return self.estimate_age(pil_img)

alpha_patch_work.png

11.5 MB
Loading

generate.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import matplotlib.pyplot as plt
2121

2222
import legacy
23+
from age_estimator import AgeEstimator
2324

2425
# ----------------------------------------------------------------------------
2526

@@ -224,6 +225,12 @@ def generate_images(
224225
f"Applying weight vector to style blocks {start_idx} to {end_idx} (inclusive)"
225226
)
226227

228+
# Initialize age estimator if using weight vector
229+
age_estimator = None
230+
if weight_vector is not None:
231+
print("Initializing age estimator...")
232+
age_estimator = AgeEstimator(ctx_id=0, det_size=(1024, 1024))
233+
227234
# Generate images.
228235
all_images = [] # Store images for composite: list of lists (one per seed)
229236

@@ -233,9 +240,31 @@ def generate_images(
233240
w = G.mapping(z, label)
234241

235242
if weight_vec is not None:
236-
# Generate images with weight modulation
237-
seed_images = []
238243
start_idx, end_idx = style_range
244+
245+
# First, check age with alpha=0 if age estimator is available
246+
if age_estimator is not None:
247+
img = G.synthesis(w, noise_mode=noise_mode)
248+
img = (
249+
(img.permute(0, 2, 3, 1) * 127.5 + 128)
250+
.clamp(0, 255)
251+
.to(torch.uint8)
252+
)
253+
img_array = img[0].cpu().numpy()
254+
pil_img = PIL.Image.fromarray(img_array, "RGB")
255+
estimated_age = age_estimator(pil_img)
256+
257+
if estimated_age is None:
258+
print(f" ⚠️ Seed {seed}: No face detected, skipping...")
259+
continue
260+
elif estimated_age <= 20:
261+
print(f" ⚠️ Seed {seed}: Age {estimated_age:.1f} <= 20, skipping...")
262+
continue
263+
else:
264+
print(f" ✓ Seed {seed}: Age {estimated_age:.1f} > 20, proceeding...")
265+
266+
# If age check passed (or not applicable), generate all alphas
267+
seed_images = []
239268
for alpha in alphas:
240269
# Clone w to avoid modifying the original
241270
w_modified = w.clone()
@@ -244,7 +273,7 @@ def generate_images(
244273
start_idx : end_idx + 1, :
245274
].unsqueeze(0)
246275
assert w_modified.shape[1:] == (G.num_ws, G.w_dim)
247-
img = G.synthesis(w_modified, noise_mode=noise_mode)
276+
img = G.synthesis(w_modified, truncation_psi=truncation_psi, noise_mode=noise_mode)
248277
img = (
249278
(img.permute(0, 2, 3, 1) * 127.5 + 128)
250279
.clamp(0, 255)
@@ -264,7 +293,7 @@ def generate_images(
264293
all_images.append(seed_images)
265294
else:
266295
# Generate image without weight modulation
267-
img = G.synthesis(w, noise_mode=noise_mode)
296+
img = G.synthesis(w, truncation_psi=truncation_psi, noise_mode=noise_mode)
268297
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
269298
PIL.Image.fromarray(img[0].cpu().numpy(), "RGB").save(
270299
f"{outdir}/seed{seed:04d}.png"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies = [
88
"click>=8.3.0",
99
"imageio>=2.37.2",
1010
"imageio-ffmpeg>=0.6.0",
11+
"insightface>=0.7.3",
1112
"matplotlib>=3.10.7",
1213
"ninja>=1.13.0",
1314
"numpy>=2.3.4",

0 commit comments

Comments
 (0)