Skip to content

Commit 28326b7

Browse files
committed
Merge branch 'timotheschmidt/add-weight-vector' of https://github.com/EmoTim/stylegan2-ada-pytorch into timotheschmidt/add-weight-vector
2 parents 4200143 + 1ee6a11 commit 28326b7

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

create_alpha_patchwork.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def create_patchwork(
204204
plt.tight_layout()
205205

206206
# Save the figure
207-
print(f"Saving patchwork to: {output_path}")
207+
print(f"Saving patchwork to: {output_path}/alpha_patch_work")
208208
plt.savefig(output_path, dpi=150, bbox_inches="tight")
209209
plt.close()
210210

generate.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""Generate images using pretrained network pickle."""
1010

1111
import os
12+
import re
1213
from typing import List, Optional
1314

1415
import click
@@ -24,22 +25,25 @@
2425
# ----------------------------------------------------------------------------
2526

2627

28+
def float_range(s: str) -> List[float]:
29+
30+
parts = s.split(":")
31+
if len(parts) == 3:
32+
start, stop, num = float(parts[0]), float(parts[1]), int(parts[2])
33+
return list(np.round(np.linspace(start, stop, num), decimals=1))
34+
else:
35+
raise ValueError(
36+
f"Linspace format requires exactly 3 values (start:stop:num), got {len(parts)}"
37+
)
38+
2739
def num_range(s: str) -> List[int]:
28-
"""Accept either 'start:stop:num' for linspace or comma separated list 'a,b,c,...' for explicit values."""
29-
30-
# Check if it's linspace format (colon-separated: start:stop:num)
31-
if ":" in s:
32-
parts = s.split(":")
33-
if len(parts) == 3:
34-
start, stop, num = float(parts[0]), float(parts[1]), int(parts[2])
35-
return list(np.linspace(start, stop, num).astype(int))
36-
else:
37-
raise ValueError(
38-
f"Linspace format requires exactly 3 values (start:stop:num), got {len(parts)}"
39-
)
40+
'''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
4041

41-
# Otherwise treat as comma-separated list
42-
vals = s.split(",")
42+
range_re = re.compile(r'^(\d+)-(\d+)$')
43+
m = range_re.match(s)
44+
if m:
45+
return list(range(int(m.group(1)), int(m.group(2))+1))
46+
vals = s.split(',')
4347
return [int(x) for x in vals]
4448

4549

@@ -87,7 +91,7 @@ def num_range(s: str) -> List[int]:
8791
)
8892
@click.option(
8993
"--alphas",
90-
type=num_range,
94+
type=float_range,
9195
help='Alpha values for weight modulation (e.g., "-10,0,10" or "-10-10")',
9296
default="-5, -2, -1, 0, 1, 2, 5",
9397
show_default=True,
@@ -99,6 +103,14 @@ def num_range(s: str) -> List[int]:
99103
default=(0, 17),
100104
show_default=True,
101105
)
106+
@click.option(
107+
"--create-composite",
108+
type=bool,
109+
help="Bolean to create or not a composite grid of the generated images",
110+
default=False,
111+
show_default=True,
112+
)
113+
102114
def generate_images(
103115
ctx: click.Context,
104116
network_pkl: str,
@@ -111,6 +123,7 @@ def generate_images(
111123
weight_vector: Optional[str],
112124
alphas: List[int],
113125
style_range: tuple,
126+
create_composite: bool = False
114127
):
115128
"""Generate images using pretrained network pickle.
116129
@@ -287,7 +300,7 @@ def generate_images(
287300
)
288301

289302
# Create composite image if weight modulation was used
290-
if weight_vec is not None and len(all_images) > 0:
303+
if create_composite and weight_vec is not None and len(all_images) > 0:
291304
print("Creating composite image...")
292305
num_rows = len(all_images)
293306
num_cols = len(alphas)

0 commit comments

Comments
 (0)