3737 get_tested_classes , # Official class list
3838 CellMapLossWrapper , # NaN-aware loss
3939)
40+ from cellmap_segmentation_challenge import config as cellmap_cfg
4041
4142# PyTC model building (import only, no modification)
4243from connectomics .models import build_model
@@ -53,24 +54,48 @@ class CellMapLightningModule(pl.LightningModule):
5354 Uses PyTC models as-is, no modifications needed.
5455 """
5556
56- def __init__ (self , model , criterion , optimizer_config , scheduler_config = None , classes = None ):
57+ def __init__ (
58+ self ,
59+ model ,
60+ criterion ,
61+ optimizer_config ,
62+ scheduler_config = None ,
63+ classes = None ,
64+ target_shape = None ,
65+ ):
5766 super ().__init__ ()
5867 self .model = model
5968 self .criterion = criterion
6069 self .optimizer_config = optimizer_config
6170 self .scheduler_config = scheduler_config
6271 self .classes = classes or []
72+ self .target_shape = target_shape
6373
6474 # Save hyperparameters
6575 self .save_hyperparameters (ignore = ['model' , 'criterion' ])
6676
77+ def _maybe_resample (self , images : torch .Tensor , labels : torch .Tensor ):
78+ """Optionally resample images/labels to a fixed shape to avoid scale filtering."""
79+ if self .target_shape is None :
80+ return images , labels
81+ # Expect 5D tensors (B, C, D, H, W). Use nearest for labels.
82+ target = tuple (self .target_shape )
83+ images_rs = torch .nn .functional .interpolate (
84+ images , size = target , mode = "trilinear" , align_corners = False
85+ )
86+ labels_rs = torch .nn .functional .interpolate (
87+ labels , size = target , mode = "nearest"
88+ )
89+ return images_rs , labels_rs
90+
6791 def forward (self , x ):
6892 return self .model (x )
6993
7094 def training_step (self , batch , batch_idx ):
7195 images = batch ['input' ]
7296 labels = batch ['output' ]
7397
98+ images , labels = self ._maybe_resample (images , labels )
7499 predictions = self (images )
75100 loss = self .criterion (predictions , labels )
76101
@@ -81,6 +106,7 @@ def validation_step(self, batch, batch_idx):
81106 images = batch ['input' ]
82107 labels = batch ['output' ]
83108
109+ images , labels = self ._maybe_resample (images , labels )
84110 predictions = self (images )
85111 loss = self .criterion (predictions , labels )
86112
@@ -137,17 +163,24 @@ def configure_optimizers(self):
137163 }
138164
139165
140- def train_cellmap (config_path : str ):
166+ def train_cellmap (config_path : str , data_root : str | None = None , target_shape = None ):
141167 """
142168 Main training function using CellMap's official tools + PyTC models.
143169
144170 Args:
145171 config_path: Path to Python config file (CellMap style)
172+ data_root: Optional override for the CellMap dataset root
146173 """
147174 # Load config (CellMap's safe config loader)
148175 print (f"Loading config from: { config_path } " )
149176 config = load_safe_config (config_path )
150177
178+ # Allow CLI overrides
179+ if data_root :
180+ setattr (config , "data_root" , data_root )
181+ if target_shape :
182+ setattr (config , "target_shape" , target_shape )
183+
151184 # Extract config values
152185 model_name = getattr (config , 'model_name' , 'mednext' )
153186 classes = getattr (config , 'classes' , get_tested_classes ())
@@ -180,16 +213,33 @@ def train_cellmap(config_path: str):
180213 print (f" Max epochs: { max_epochs } " )
181214 print (f" GPUs: { num_gpus } " )
182215 print (f" Precision: { precision } " )
216+ if target_shape :
217+ print (f" Target resample shape: { target_shape } " )
218+
219+ # Resolve data root override (defaults to package SEARCH_PATH under repo/data)
220+ data_root = getattr (config , "data_root" , None )
221+ target_shape = getattr (config , "target_shape" , target_shape )
222+ search_path = cellmap_cfg .SEARCH_PATH
223+ if data_root :
224+ search_path = os .path .normpath (
225+ os .path .join (data_root , "{dataset}/{dataset}.zarr/recon-1/{name}" )
226+ )
227+ print (f"Using data root: { data_root } " )
228+ else :
229+ print (f"Using default CellMap search path: { search_path } " )
183230
184231 # Generate datasplit CSV if doesn't exist (CellMap's official utility)
185232 if not os .path .exists (datasplit_path ):
186233 print (f"Generating datasplit CSV: { datasplit_path } " )
234+ # If resampling is enabled, allow all scales (skip filtering)
235+ scale_filter = None if target_shape else input_array_info .get ('scale' )
187236 make_datasplit_csv (
188237 classes = classes ,
189238 csv_path = datasplit_path ,
190239 validation_prob = 0.15 ,
191- scale = input_array_info . get ( 'scale' ) ,
240+ scale = scale_filter ,
192241 force_all_classes = 'validate' ,
242+ search_path = search_path ,
193243 )
194244 else :
195245 print (f"Using existing datasplit: { datasplit_path } " )
@@ -238,6 +288,7 @@ def train_cellmap(config_path: str):
238288 optimizer_config = {'lr' : learning_rate , 'weight_decay' : 1e-5 },
239289 scheduler_config = getattr (config , 'scheduler_config' , {'name' : 'constant' }),
240290 classes = classes ,
291+ target_shape = target_shape ,
241292 )
242293
243294 # Setup callbacks
@@ -296,6 +347,21 @@ def train_cellmap(config_path: str):
296347
297348 parser = argparse .ArgumentParser (description = 'Train PyTC models on CellMap data' )
298349 parser .add_argument ('config' , type = str , help = 'Path to config file' )
350+ parser .add_argument (
351+ '--data-root' ,
352+ type = str ,
353+ help = 'Override dataset root (e.g., /projects/weilab/dataset/cellmap)' ,
354+ default = None ,
355+ )
356+ parser .add_argument (
357+ '--target-shape' ,
358+ nargs = 3 ,
359+ metavar = ('D' , 'H' , 'W' ),
360+ type = int ,
361+ help = 'Resample input/label to this shape (e.g., --target-shape 128 128 128) to include all scales' ,
362+ default = None ,
363+ )
299364 args = parser .parse_args ()
300365
301- train_cellmap (args .config )
366+ target_shape = tuple (args .target_shape ) if args .target_shape else None
367+ train_cellmap (args .config , data_root = args .data_root , target_shape = target_shape )
0 commit comments