Skip to content

Commit a18b392

Browse files
author
Donglai Wei
committed
add resample
1 parent 9c6d27c commit a18b392

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

scripts/cellmap/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ cd ..
2525

2626
# Verify installation
2727
python -c "from cellmap_segmentation_challenge.utils import TEST_CROPS; print(f'{len(TEST_CROPS)} test crops loaded')"
28+
29+
# Point to your dataset location via CLI (preferred)
30+
# Example:
31+
# python scripts/cellmap/train_cellmap.py scripts/cellmap/configs/mednext_cos7.py --data-root /projects/weilab/dataset/cellmap
32+
# If your crops are at mixed resolutions, pass --target-shape to resample every batch to one fixed shape; crops at all scales are then included instead of being filtered by scale:
33+
# python scripts/cellmap/train_cellmap.py scripts/cellmap/configs/mednext_cos7.py --data-root /projects/weilab/dataset/cellmap --target-shape 128 128 128
2834
```
2935

3036
### 2. Quick Test (10 epochs)

scripts/cellmap/train_cellmap.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
get_tested_classes, # Official class list
3838
CellMapLossWrapper, # NaN-aware loss
3939
)
40+
from cellmap_segmentation_challenge import config as cellmap_cfg
4041

4142
# PyTC model building (import only, no modification)
4243
from connectomics.models import build_model
@@ -53,24 +54,48 @@ class CellMapLightningModule(pl.LightningModule):
5354
Uses PyTC models as-is, no modifications needed.
5455
"""
5556

56-
def __init__(self, model, criterion, optimizer_config, scheduler_config=None, classes=None):
57+
def __init__(
58+
self,
59+
model,
60+
criterion,
61+
optimizer_config,
62+
scheduler_config=None,
63+
classes=None,
64+
target_shape=None,
65+
):
5766
super().__init__()
5867
self.model = model
5968
self.criterion = criterion
6069
self.optimizer_config = optimizer_config
6170
self.scheduler_config = scheduler_config
6271
self.classes = classes or []
72+
self.target_shape = target_shape
6373

6474
# Save hyperparameters
6575
self.save_hyperparameters(ignore=['model', 'criterion'])
6676

77+
def _maybe_resample(self, images: torch.Tensor, labels: torch.Tensor):
78+
"""Optionally resample images/labels to a fixed shape to avoid scale filtering."""
79+
if self.target_shape is None:
80+
return images, labels
81+
# Expect 5D tensors (B, C, D, H, W). Use nearest for labels.
82+
target = tuple(self.target_shape)
83+
images_rs = torch.nn.functional.interpolate(
84+
images, size=target, mode="trilinear", align_corners=False
85+
)
86+
labels_rs = torch.nn.functional.interpolate(
87+
labels, size=target, mode="nearest"
88+
)
89+
return images_rs, labels_rs
90+
6791
def forward(self, x):
6892
return self.model(x)
6993

7094
def training_step(self, batch, batch_idx):
7195
images = batch['input']
7296
labels = batch['output']
7397

98+
images, labels = self._maybe_resample(images, labels)
7499
predictions = self(images)
75100
loss = self.criterion(predictions, labels)
76101

@@ -81,6 +106,7 @@ def validation_step(self, batch, batch_idx):
81106
images = batch['input']
82107
labels = batch['output']
83108

109+
images, labels = self._maybe_resample(images, labels)
84110
predictions = self(images)
85111
loss = self.criterion(predictions, labels)
86112

@@ -137,17 +163,24 @@ def configure_optimizers(self):
137163
}
138164

139165

140-
def train_cellmap(config_path: str):
166+
def train_cellmap(config_path: str, data_root: str | None = None, target_shape=None):
141167
"""
142168
Main training function using CellMap's official tools + PyTC models.
143169
144170
Args:
145171
config_path: Path to Python config file (CellMap style)
172+
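data_root: Optional override for the CellMap dataset root
target_shape: Optional (D, H, W) shape; when set, each batch is resampled to it and the datasplit is generated without scale filtering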
146173
"""
147174
# Load config (CellMap's safe config loader)
148175
print(f"Loading config from: {config_path}")
149176
config = load_safe_config(config_path)
150177

178+
# Allow CLI overrides
179+
if data_root:
180+
setattr(config, "data_root", data_root)
181+
if target_shape:
182+
setattr(config, "target_shape", target_shape)
183+
151184
# Extract config values
152185
model_name = getattr(config, 'model_name', 'mednext')
153186
classes = getattr(config, 'classes', get_tested_classes())
@@ -180,16 +213,33 @@ def train_cellmap(config_path: str):
180213
print(f" Max epochs: {max_epochs}")
181214
print(f" GPUs: {num_gpus}")
182215
print(f" Precision: {precision}")
216+
if target_shape:
217+
print(f" Target resample shape: {target_shape}")
218+
219+
# Resolve data root override (defaults to package SEARCH_PATH under repo/data)
220+
data_root = getattr(config, "data_root", None)
221+
target_shape = getattr(config, "target_shape", target_shape)
222+
search_path = cellmap_cfg.SEARCH_PATH
223+
if data_root:
224+
search_path = os.path.normpath(
225+
os.path.join(data_root, "{dataset}/{dataset}.zarr/recon-1/{name}")
226+
)
227+
print(f"Using data root: {data_root}")
228+
else:
229+
print(f"Using default CellMap search path: {search_path}")
183230

184231
# Generate the datasplit CSV if it doesn't exist (CellMap's official utility)
185232
if not os.path.exists(datasplit_path):
186233
print(f"Generating datasplit CSV: {datasplit_path}")
234+
# If resampling is enabled, allow all scales (skip filtering)
235+
scale_filter = None if target_shape else input_array_info.get('scale')
187236
make_datasplit_csv(
188237
classes=classes,
189238
csv_path=datasplit_path,
190239
validation_prob=0.15,
191-
scale=input_array_info.get('scale'),
240+
scale=scale_filter,
192241
force_all_classes='validate',
242+
search_path=search_path,
193243
)
194244
else:
195245
print(f"Using existing datasplit: {datasplit_path}")
@@ -238,6 +288,7 @@ def train_cellmap(config_path: str):
238288
optimizer_config={'lr': learning_rate, 'weight_decay': 1e-5},
239289
scheduler_config=getattr(config, 'scheduler_config', {'name': 'constant'}),
240290
classes=classes,
291+
target_shape=target_shape,
241292
)
242293

243294
# Setup callbacks
@@ -296,6 +347,21 @@ def train_cellmap(config_path: str):
296347

297348
parser = argparse.ArgumentParser(description='Train PyTC models on CellMap data')
298349
parser.add_argument('config', type=str, help='Path to config file')
350+
parser.add_argument(
351+
'--data-root',
352+
type=str,
353+
help='Override dataset root (e.g., /projects/weilab/dataset/cellmap)',
354+
default=None,
355+
)
356+
parser.add_argument(
357+
'--target-shape',
358+
nargs=3,
359+
metavar=('D', 'H', 'W'),
360+
type=int,
361+
help='Resample input/label to this shape (e.g., --target-shape 128 128 128) to include all scales',
362+
default=None,
363+
)
299364
args = parser.parse_args()
300365

301-
train_cellmap(args.config)
366+
target_shape = tuple(args.target_shape) if args.target_shape else None
367+
train_cellmap(args.config, data_root=args.data_root, target_shape=target_shape)

0 commit comments

Comments
 (0)