Skip to content

Commit 2ce9c51

Browse files
committed
small changes
1 parent 548c70a commit 2ce9c51

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

create_alpha_patchwork.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def create_patchwork(
204204
plt.tight_layout()
205205

206206
# Save the figure
207-
print(f"Saving patchwork to: {output_path}")
207+
print(f"Saving patchwork to: {output_path}/alpha_patch_work")
208208
plt.savefig(output_path, dpi=150, bbox_inches="tight")
209209
plt.close()
210210

generate.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""Generate images using pretrained network pickle."""
1010

1111
import os
12+
import re
1213
from typing import List, Optional
1314

1415
import click
@@ -23,22 +24,25 @@
2324
# ----------------------------------------------------------------------------
2425

2526

27+
def float_range(s: str) -> List[float]:
28+
29+
parts = s.split(":")
30+
if len(parts) == 3:
31+
start, stop, num = float(parts[0]), float(parts[1]), int(parts[2])
32+
return list(np.round(np.linspace(start, stop, num), decimals=1))
33+
else:
34+
raise ValueError(
35+
f"Linspace format requires exactly 3 values (start:stop:num), got {len(parts)}"
36+
)
37+
2638
def num_range(s: str) -> List[int]:
27-
"""Accept either 'start:stop:num' for linspace or comma separated list 'a,b,c,...' for explicit values."""
28-
29-
# Check if it's linspace format (colon-separated: start:stop:num)
30-
if ":" in s:
31-
parts = s.split(":")
32-
if len(parts) == 3:
33-
start, stop, num = float(parts[0]), float(parts[1]), int(parts[2])
34-
return list(np.linspace(start, stop, num).astype(int))
35-
else:
36-
raise ValueError(
37-
f"Linspace format requires exactly 3 values (start:stop:num), got {len(parts)}"
38-
)
39+
'''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
3940

40-
# Otherwise treat as comma-separated list
41-
vals = s.split(",")
41+
range_re = re.compile(r'^(\d+)-(\d+)$')
42+
m = range_re.match(s)
43+
if m:
44+
return list(range(int(m.group(1)), int(m.group(2))+1))
45+
vals = s.split(',')
4246
return [int(x) for x in vals]
4347

4448

@@ -86,7 +90,7 @@ def num_range(s: str) -> List[int]:
8690
)
8791
@click.option(
8892
"--alphas",
89-
type=num_range,
93+
type=float_range,
9094
help='Alpha values for weight modulation (e.g., "-10,0,10" or "-10-10")',
9195
default="-5, -2, -1, 0, 1, 2, 5",
9296
show_default=True,
@@ -98,6 +102,14 @@ def num_range(s: str) -> List[int]:
98102
default=(0, 17),
99103
show_default=True,
100104
)
105+
@click.option(
106+
"--create-composite",
107+
type=bool,
108+
help="Bolean to create or not a composite grid of the generated images",
109+
default=False,
110+
show_default=True,
111+
)
112+
101113
def generate_images(
102114
ctx: click.Context,
103115
network_pkl: str,
@@ -110,6 +122,7 @@ def generate_images(
110122
weight_vector: Optional[str],
111123
alphas: List[int],
112124
style_range: tuple,
125+
create_composite: bool = False
113126
):
114127
"""Generate images using pretrained network pickle.
115128
@@ -258,7 +271,7 @@ def generate_images(
258271
)
259272

260273
# Create composite image if weight modulation was used
261-
if weight_vec is not None and len(all_images) > 0:
274+
if create_composite and weight_vec is not None and len(all_images) > 0:
262275
print("Creating composite image...")
263276
num_rows = len(all_images)
264277
num_cols = len(alphas)

0 commit comments

Comments
 (0)