Skip to content

Commit a5c544b

Browse files
author
Donglai Wei
committed
fiber inference
1 parent e528dc8 commit a5c544b

File tree

14 files changed

+898
-225
lines changed

14 files changed

+898
-225
lines changed

CLAUDE.md

Lines changed: 734 additions & 0 deletions
Large diffs are not rendered by default.

QUICKSTART.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ conda create -n pytc python=3.10 -y
5454
conda activate pytc
5555

5656
# Install pre-built packages (avoids compilation)
57-
conda install -c conda-forge numpy h5py cython connected-components-3d -y
57+
conda install -c conda-forge numpy=1.23 h5py cython connected-components-3d mahotas -y
5858

5959
# Install PyTorch (adjust for your CUDA version)
6060
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121

TROUBLESHOOTING.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,37 @@ pip install -e .
4747

4848
---
4949

50+
### ❌ "AttributeError: module 'numpy' has no attribute 'float'"
51+
52+
**Cause:** Mahotas version incompatibility with NumPy 2.0+. This occurs with older mahotas versions (< 1.4.18).
53+
54+
**Solution 1 - Upgrade packages (recommended):**
55+
```bash
56+
pip install --upgrade numpy mahotas
57+
```
58+
59+
**Solution 2 - Pin compatible versions:**
60+
```bash
61+
pip install numpy>=1.23.0 mahotas>=1.4.18
62+
```
63+
64+
**Why?** Mahotas 1.4.18+ is compatible with NumPy 2.x. The deprecated `np.float` alias was removed in NumPy 2.0.
65+
66+
---
67+
68+
### ❌ "Matplotlib requires numpy>=1.23"
69+
70+
**Cause:** Matplotlib requires NumPy 1.23 or higher for compatibility.
71+
72+
**Solution:**
73+
```bash
74+
conda activate pytc
75+
conda install -c conda-forge matplotlib -y
76+
pip install -e . --no-build-isolation
77+
```
78+
79+
---
80+
5081
### ❌ "Could not find a version that satisfies connected-components-3d"
5182

5283
**Cause:** Python version incompatibility (cc3d requires Python 3.10).

connectomics/config/hydra_config.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,6 @@ class CheckpointConfig:
518518
save_top_k: int = 1
519519
save_last: bool = True
520520
save_every_n_epochs: int = 10
521-
dirpath: str = "checkpoints/"
522521
checkpoint_filename: Optional[str] = None # Auto-generated from monitor if None
523522
use_timestamp: bool = True # Create timestamped subdirectories (YYYYMMDD_HHMMSS)
524523

@@ -815,8 +814,14 @@ class AugmentationConfig:
815814

816815
@dataclass
817816
class InferenceDataConfig:
818-
"""Inference data configuration."""
817+
"""Inference data configuration.
819818
819+
Output path is automatically computed from checkpoint directory:
820+
- Checkpoint: outputs/experiment_name/YYYYMMDD_HHMMSS/checkpoints/last.ckpt
821+
- Inference: outputs/experiment_name/YYYYMMDD_HHMMSS/inference/last.ckpt/{output_name}
822+
"""
823+
824+
test_path: str = "" # Base path for test data (e.g., "/path/to/dataset/test/")
820825
test_image: Any = None # str, List[str], or None - Can be single file or list of files
821826
test_label: Any = None # str, List[str], or None - Can be single file or list of files
822827
test_mask: Any = None # str, List[str], or None - Optional mask for inference
@@ -826,7 +831,9 @@ class InferenceDataConfig:
826831
test_transpose: List[int] = field(
827832
default_factory=list
828833
) # Axis permutation for test data (e.g., [2,1,0] for xyz->zyx)
829-
output_path: str = "results/"
834+
output_name: str = (
835+
"predictions.h5" # Output filename (auto-pathed to inference/{checkpoint}/{output_name})
836+
)
830837

831838
# 2D data support
832839
do_2d: bool = False # Enable 2D data processing for inference
@@ -838,8 +845,7 @@ class SlidingWindowConfig:
838845

839846
window_size: Optional[List[int]] = None
840847
sw_batch_size: Optional[int] = None # If None, will use system.inference.batch_size
841-
overlap: Optional[float] = 0.5 # Overlap ratio (0-1), or None to use stride instead
842-
stride: Optional[List[int]] = None # Explicit stride (overrides overlap if set)
848+
stride: Optional[List[int]] = None # Explicit stride for controlling window movement
843849
blending: str = "gaussian" # 'gaussian' or 'constant' - blending mode for overlapping patches
844850
sigma_scale: float = (
845851
0.125 # Gaussian sigma scale (only for blending='gaussian'); larger = smoother blending
@@ -860,7 +866,7 @@ class TestTimeAugmentationConfig:
860866
None # Single activation for all channels: 'softmax', 'sigmoid', 'tanh', None (deprecated, use channel_activations)
861867
)
862868
channel_activations: Optional[List[Any]] = (
863-
None # Per-channel activations: [[0, 'sigmoid'], [1, 'sigmoid'], [2, 'tanh']]
869+
None # Per-channel activations: [[start_ch, end_ch, 'activation'], ...] e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
864870
)
865871
select_channel: Any = (
866872
None # Channel selection: null (all), [1] (foreground), -1 (all) (applied even with null flip_axes)

connectomics/config/hydra_utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ def resolve_data_paths(cfg: Config) -> Config:
241241
2. Expanding glob patterns to actual file lists
242242
3. Flattening nested lists from glob expansion
243243
244+
Supported paths:
245+
- Training: cfg.data.train_path + cfg.data.train_image/train_label/train_mask
246+
- Validation: cfg.data.val_path + cfg.data.val_image/val_label/val_mask
247+
- Testing (legacy): cfg.data.test_path + cfg.data.test_image/test_label/test_mask
248+
- Inference (primary): cfg.inference.data.test_path + cfg.inference.data.test_image/test_label/test_mask
249+
244250
Args:
245251
cfg: Config object to resolve paths for
246252
@@ -253,6 +259,12 @@ def resolve_data_paths(cfg: Config) -> Config:
253259
>>> resolve_data_paths(cfg)
254260
>>> print(cfg.data.train_image)
255261
['/data/barcode/PT37/img1_raw.tif', '/data/barcode/PT37/img2_raw.tif', '/data/barcode/file.tif']
262+
263+
>>> cfg.inference.data.test_path = "/data/test/"
264+
>>> cfg.inference.data.test_image = ["volume_*.tif"]
265+
>>> resolve_data_paths(cfg)
266+
>>> print(cfg.inference.data.test_image)
267+
['/data/test/volume_1.tif', '/data/test/volume_2.tif']
256268
"""
257269
import os
258270
from glob import glob
@@ -304,18 +316,18 @@ def _combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]) ->
304316
cfg.data.val_mask = _combine_path(cfg.data.val_path, cfg.data.val_mask)
305317
cfg.data.val_json = _combine_path(cfg.data.val_path, cfg.data.val_json)
306318

307-
# Resolve test paths
319+
# Resolve test paths (legacy support for cfg.data.test_path)
308320
if cfg.data.test_path:
309321
cfg.data.test_image = _combine_path(cfg.data.test_path, cfg.data.test_image)
310322
cfg.data.test_label = _combine_path(cfg.data.test_path, cfg.data.test_label)
311323
cfg.data.test_mask = _combine_path(cfg.data.test_path, cfg.data.test_mask)
312324
cfg.data.test_json = _combine_path(cfg.data.test_path, cfg.data.test_json)
313325

314-
# Also resolve inference data paths
315-
if cfg.data.test_path and cfg.inference.data:
316-
cfg.inference.data.test_image = _combine_path(cfg.data.test_path, cfg.inference.data.test_image)
317-
cfg.inference.data.test_label = _combine_path(cfg.data.test_path, cfg.inference.data.test_label)
318-
cfg.inference.data.test_mask = _combine_path(cfg.data.test_path, cfg.inference.data.test_mask)
326+
# Resolve inference data paths (primary location for test_path)
327+
if hasattr(cfg.inference.data, 'test_path') and cfg.inference.data.test_path:
328+
cfg.inference.data.test_image = _combine_path(cfg.inference.data.test_path, cfg.inference.data.test_image)
329+
cfg.inference.data.test_label = _combine_path(cfg.inference.data.test_path, cfg.inference.data.test_label)
330+
cfg.inference.data.test_mask = _combine_path(cfg.inference.data.test_path, cfg.inference.data.test_mask)
319331

320332
return cfg
321333

connectomics/decoding/postprocess.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def watershed_split(
139139
mask = np.zeros(distance.shape, dtype=bool)
140140
mask[tuple(coords.T)] = True
141141
markers = cc3d.connected_components(mask)
142-
split_objects = mahotas.cwatershed(-distance, markers, mask=cropped)
142+
split_objects = mahotas.cwatershed(-distance, markers)
143+
split_objects[~cropped] = 0 # Apply mask manually (mahotas 1.4.18 doesn't support mask parameter)
143144

144145
seg_id = np.unique(split_objects)
145146
new_id = []

0 commit comments

Comments
 (0)