diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index c0ace5cfc..175e4bb9d 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -35,6 +35,7 @@ timm>=1.0.3
 torch>=2.1.0
 torchvision>=0.15.0
 tqdm>=4.64.1
+transformers>=4.51.1
 umap-learn>=0.5.3
 wsidicom>=0.18.0
 zarr>=2.13.3, <3.0.0
diff --git a/tests/models/test_arch_sam.py b/tests/models/test_arch_sam.py
new file mode 100644
index 000000000..03e31a94b
--- /dev/null
+++ b/tests/models/test_arch_sam.py
@@ -0,0 +1,65 @@
+"""Unit test package for SAM."""
+
+from collections.abc import Callable
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from tiatoolbox.models.architecture.sam import SAM
+from tiatoolbox.utils import env_detection as toolbox_env
+from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
+
+ON_GPU = toolbox_env.has_gpu()
+
+# Test pretrained Model =============================
+
+
+def test_functional_sam(
+    remote_sample: Callable,
+) -> None:
+    """Test for SAM."""
+    # convert to pathlib Path to prevent wsireader complaint
+    tile_path = Path(remote_sample("patch-extraction-vf"))
+    img = imread(tile_path)
+
+    # test creation
+
+    model = SAM(device=select_device(on_gpu=ON_GPU))
+
+    # create image patch and prompts
+    patch = img[63:191, 750:878, :]
+
+    points = [[[64, 64]]]
+    boxes = [[[64, 64, 128, 128]]]
+
+    # test preproc
+    tensor = torch.from_numpy(img)
+    patch = np.expand_dims(model.preproc(tensor), axis=0)
+    patch = model.preproc(patch)
+
+    # test inference
+
+    mask_output, score_output = model.infer_batch(
+        model, patch, points, device=select_device(on_gpu=ON_GPU)
+    )
+
+    assert mask_output is not None, "Output should not be None"
+    assert len(mask_output) > 0, "Output should have at least one element"
+    assert len(score_output) > 0, "Output should have at least one element"
+
+    mask_output, score_output = model.infer_batch(
+        model, patch, box_coords=boxes, device=select_device(on_gpu=ON_GPU)
+    )
+
+    assert len(mask_output) > 0, "Output should have at least one element"
+    assert len(score_output) > 0, "Output should have at least one element"
+
+    mask_output, score_output = model.infer_batch(
+        model, patch, device=select_device(on_gpu=ON_GPU)
+    )
+
+    assert mask_output is not None, "Output should not be None"
+    assert len(mask_output) > 0, "Output should have at least one element"
+    assert len(score_output) > 0, "Output should have at least one element"
diff --git a/tests/models/test_prompt_segmentor.py b/tests/models/test_prompt_segmentor.py
new file mode 100644
index 000000000..1996a9448
--- /dev/null
+++ b/tests/models/test_prompt_segmentor.py
@@ -0,0 +1,273 @@
+"""Unit test package for Prompt Segmentor."""
+
+from __future__ import annotations
+
+# ! The garbage collector
+import multiprocessing
+import shutil
+from collections.abc import Callable
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from tiatoolbox.models import PromptSegmentor
+from tiatoolbox.models.architecture.sam import SAM
+from tiatoolbox.models.engine.semantic_segmentor import (
+    IOSegmentorConfig,
+)
+from tiatoolbox.utils import env_detection as toolbox_env
+from tiatoolbox.utils import imwrite
+from tiatoolbox.utils.misc import select_device
+from tiatoolbox.wsicore.wsireader import WSIReader
+
+ON_GPU = toolbox_env.has_gpu()
+BATCH_SIZE = 1 if not ON_GPU else 2
+try:
+    NUM_LOADER_WORKERS = multiprocessing.cpu_count()
+except NotImplementedError:
+    NUM_LOADER_WORKERS = 2
+
+
+def test_functional_segmentor(
+    remote_sample: Callable,
+    tmp_path: Path,
+) -> None:
+    """Functional test for segmentor."""
+    save_dir = tmp_path / "dump"
+    # # convert to pathlib Path to prevent wsireader complaint
+    resolution = 2.0
+    mini_wsi_svs = Path(remote_sample("patch-extraction-vf"))
+    reader = WSIReader.open(mini_wsi_svs, resolution)
+    thumb = reader.slide_thumbnail(resolution=resolution, units="mpp")
+    thumb = thumb[63:191, 750:878, :]
+    mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg"
+    imwrite(mini_wsi_jpg, thumb)
+
+    # preemptive clean up
+    shutil.rmtree(save_dir, ignore_errors=True)
+
+    model = SAM()
+
+    # test engine setup
+
+    _ = PromptSegmentor(None, BATCH_SIZE, NUM_LOADER_WORKERS)
+
+    prompt_segmentor = PromptSegmentor(model, BATCH_SIZE, NUM_LOADER_WORKERS)
+
+    ioconfig = IOSegmentorConfig(
+        input_resolutions=[
+            {"units": "mpp", "resolution": 4.0},
+        ],
+        output_resolutions=[{"units": "mpp", "resolution": 4.0}],
+        patch_input_shape=[512, 512],
+        patch_output_shape=[512, 512],
+        stride_shape=[512, 512],
+    )
+
+    # test inference
+
+    points = np.array([[[64, 64]], [[64, 64]]])  # Point on nuclei
+
+    # Run on tile mode with multi-prompt
+    # Test running with multiple images
+    shutil.rmtree(save_dir, ignore_errors=True)
+    output_list = prompt_segmentor.predict(
+        [mini_wsi_jpg, mini_wsi_jpg],
+        mode="tile",
+        multi_prompt=True,
+        device=select_device(on_gpu=ON_GPU),
+        point_coords=points,
+        ioconfig=ioconfig,
+        crash_on_exception=False,
+        save_dir=save_dir,
+    )
+
+    pred_1 = np.load(output_list[0][1] + "/0.raw.0.npy")
+    pred_2 = np.load(output_list[1][1] + "/0.raw.0.npy")
+    assert len(output_list) == 2
+    assert np.sum(pred_1 - pred_2) == 0
+
+    points = np.array([[[64, 64], [100, 40], [100, 70]]])  # Points on nuclei
+    boxes = np.array([[[10, 10, 50, 50], [80, 80, 110, 110]]])  # Boxes on nuclei
+
+    # Run on tile mode with single-prompt
+    # Also tests boxes
+    shutil.rmtree(save_dir, ignore_errors=True)
+    output_list = prompt_segmentor.predict(
+        [mini_wsi_jpg],
+        mode="tile",
+        multi_prompt=False,
+        device=select_device(on_gpu=ON_GPU),
+        point_coords=points,
+        box_coords=boxes,
+        ioconfig=ioconfig,
+        crash_on_exception=False,
+        save_dir=save_dir,
+    )
+
+    total_prompts = points.shape[1] + boxes.shape[1]
+    preds = [
+        np.load(output_list[0][1] + f"/{i}.raw.0.npy") for i in range(total_prompts)
+    ]
+
+    assert len(output_list) == 1
+    assert len(preds) == total_prompts
+
+    # Generate mask
+    mask = np.zeros((thumb.shape[0], thumb.shape[1]), dtype=np.uint8)
+    mask[32:120, 32:120] = 1
+    mini_wsi_msk = f"{tmp_path}/mini_svs_mask.jpg"
+    imwrite(mini_wsi_msk, mask)
+
+    ioconfig = IOSegmentorConfig(
+        input_resolutions=[
+            {"units": "baseline", "resolution": 1.0},
+        ],
+        output_resolutions=[{"units": "baseline", "resolution": 1.0}],
+        patch_input_shape=[512, 512],
+        patch_output_shape=[512, 512],
+        stride_shape=[512, 512],
+        save_resolution={"units": "baseline", "resolution": 1.0},
+    )
+
+    # Only point within mask should generate a segmentation
+    points = np.array([[[64, 64], [100, 40]]])
+    save_dir = tmp_path / "dump"
+
+    # Run on wsi mode with multi-prompt
+    # Also tests masks
+    shutil.rmtree(save_dir, ignore_errors=True)
+    output_list = prompt_segmentor.predict(
+        [mini_wsi_jpg],
+        masks=[mini_wsi_msk],
+        mode="wsi",
+        multi_prompt=True,
+        device=select_device(on_gpu=ON_GPU),
+        point_coords=points,
+        ioconfig=ioconfig,
+        crash_on_exception=False,
+        save_dir=save_dir,
+    )
+
+    # Check if db exists
+    assert Path(output_list[0][1] + ".0.db").exists()
+
+    points = np.array([[[10, 30]]])
+    boxes = np.array([[[10, 10, 30, 30]]])
+    # Test no prompts within mask
+    shutil.rmtree(save_dir, ignore_errors=True)
+    output_list = prompt_segmentor.predict(
+        [mini_wsi_jpg],
+        masks=[mini_wsi_msk],
+        mode="wsi",
+        multi_prompt=True,
+        device=select_device(on_gpu=ON_GPU),
+        point_coords=points,
+        box_coords=boxes,
+        ioconfig=ioconfig,
+        crash_on_exception=False,
+        save_dir=save_dir,
+    )
+    # Check if db exists
+    assert Path(output_list[0][1] + ".0.db").exists()
+
+    # Run on wsi mode with single-prompt
+    shutil.rmtree(save_dir, ignore_errors=True)
+    output_list = prompt_segmentor.predict(
+        [mini_wsi_jpg],
+        mode="wsi",
+        multi_prompt=False,
+        device=select_device(on_gpu=ON_GPU),
+        point_coords=points,
+        ioconfig=ioconfig,
+        crash_on_exception=False,
+        save_dir=save_dir,
+    )
+
+    # Check if db exists
+    assert Path(output_list[0][1] + ".0.db").exists()
+
+
+def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
+    """Functional crash tests for segmentor."""
+    # # convert to pathlib Path to prevent wsireader complaint
+    mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))
+    mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk"))
+
+    save_dir = tmp_path / "test_crash_segmentor"
+    prompt_segmentor = PromptSegmentor(batch_size=BATCH_SIZE)
+
+    # * test basic crash
+    with pytest.raises(TypeError, match=r".*`mask_reader`.*"):
+        prompt_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"]))
+    with pytest.raises(TypeError, match=r".*`mask_reader`.*"):
+        prompt_segmentor.get_mask_bounds(mini_wsi_msk)
+    with pytest.raises(TypeError, match=r".*mask_reader.*"):
+        prompt_segmentor.clip_coordinates(mini_wsi_msk, np.array(["a", "b", "c"]))
+
+    with pytest.raises(ValueError, match=r".*ndarray.*integer.*"):
+        prompt_segmentor.filter_coordinates(
+            WSIReader.open(mini_wsi_msk),
+            np.array([1.0, 2.0]),
+        )
+    with pytest.raises(ValueError, match=r".*ndarray.*integer.*"):
+        prompt_segmentor.clip_coordinates(
+            WSIReader.open(mini_wsi_msk),
+            np.array([1.0, 2.0]),
+        )
+    prompt_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True)
+    with pytest.raises(ValueError, match=r".*must be a valid file path.*"):
+        prompt_segmentor.get_reader(
+            mini_wsi_msk,
+            "not_exist",
+            "wsi",
+            auto_get_mask=True,
+        )
+
+    shutil.rmtree(save_dir, ignore_errors=True)  # default output dir test
+    with pytest.raises(ValueError, match=r".*valid mode.*"):
+        prompt_segmentor.predict([], mode="abc")
+
+    crash_segmentor = PromptSegmentor()
+
+    # * test crash segmentor
+    def _predict_one_wsi(
+        *args: dict,
+        **kwargs: dict,
+    ) -> tuple[WSIReader, str]:
+        """Override the predict function to test crash segmentor."""
+        msg = f"Test crash segmentor:{args} {kwargs}"
+        raise RuntimeError(msg)
+
+    crash_segmentor._predict_one_wsi = _predict_one_wsi
+    shutil.rmtree(save_dir, ignore_errors=True)
+    with pytest.raises(
+        RuntimeError,
+        match=r"Test crash segmentor:\(.*\) \{.*\}",
+    ):
+        crash_segmentor.predict(
+            [mini_wsi_svs],
+            mode="wsi",
+            multi_prompt=True,
+            device=select_device(on_gpu=ON_GPU),
+            patch_input_shape=[512, 512],
+            resolution=2.0,
+            units="mpp",
+            crash_on_exception=True,
+            save_dir=save_dir,
+        )
+
+    # test ignore crash
+    shutil.rmtree(save_dir, ignore_errors=True)
+    crash_segmentor.predict(
+        [mini_wsi_svs],
+        mode="wsi",
+        multi_prompt=True,
+        device=select_device(on_gpu=ON_GPU),
+        patch_input_shape=[512, 512],
+        resolution=2.0,
+        units="mpp",
+        crash_on_exception=False,
+        save_dir=save_dir,
+    )
diff --git a/tests/test_app_bokeh.py b/tests/test_app_bokeh.py
index ce97fb2fd..a926a86de 100644
--- a/tests/test_app_bokeh.py
+++ b/tests/test_app_bokeh.py
@@ -512,6 +512,50 @@ def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> No
     assert len(main.UI["type_column"].children) == 1
 
 
+def test_sam_segment(doc: Document, data_path: pytest.TempPathFactory) -> None:
+    """Test running SAM on points and a box."""
+    slide_select = doc.get_model_by_name("slide_select0")
+    slide_select.value = [data_path["slide2"].name]
+    run_button = doc.get_model_by_name("to_model0")
+    assert len(main.UI["color_column"].children) == 0
+    slide_select.value = [data_path["slide1"].name]
+    # set up a box selection
+    main.UI["box_source"].data = {
+        "x": [1200],
+        "y": [-2000],
+        "width": [400],
+        "height": [400],
+    }
+
+    # select SAM model and run it on box
+    model_select = doc.get_model_by_name("model_drop0")
+    model_select.value = "SAM"
+
+    click = ButtonClick(run_button)
+    run_button._trigger_event(click)
+    assert len(main.UI["color_column"].children) > 0
+
+    # test save functionality
+    save_button = doc.get_model_by_name("save_button0")
+    click = ButtonClick(save_button)
+    save_button._trigger_event(click)
+    saved_path = (
+        data_path["base_path"] / "overlays" / (data_path["slide1"].stem + ".db")
+    )
+    assert saved_path.exists()
+
+    # load an overlay with different types
+    cprop_select = doc.get_model_by_name("cprop0")
+    cprop_select.value = ["prob"]
+    layer_drop = doc.get_model_by_name("layer_drop0")
+    click = MenuItemClick(layer_drop, str(data_path["dat_anns"]))
+    layer_drop._trigger_event(click)
+    assert main.UI["vstate"].types == ["annotation"]
+    # check the per-type ui controls have been updated
+    assert len(main.UI["color_column"].children) == 1
+    assert len(main.UI["type_column"].children) == 1
+
+
 def test_alpha_sliders(doc: Document) -> None:
     """Test sliders for adjusting slide and overlay alpha."""
     slide_alpha = doc.get_model_by_name("slide_alpha0")
diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py
index 39d1441ce..42b758c33 100644
--- a/tiatoolbox/models/__init__.py
+++ b/tiatoolbox/models/__init__.py
@@ -8,6 +8,7 @@
 from .architecture.mapde import MapDe
 from .architecture.micronet import MicroNet
 from .architecture.nuclick import NuClick
+from .architecture.sam import SAM
 from .architecture.sccnn import SCCNN
 from .engine.multi_task_segmentor import MultiTaskSegmentor
 from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor
@@ -17,6 +18,7 @@
     PatchPredictor,
     WSIPatchDataset,
 )
+from .engine.prompt_segmentor import PromptSegmentor
 from .engine.semantic_segmentor import (
     DeepFeatureExtractor,
     IOSegmentorConfig,
@@ -25,6 +27,7 @@
 )
 
 __all__ = [
+    "SAM",
     "SCCNN",
     "HoVerNet",
     "HoVerNetPlus",
@@ -35,5 +38,6 @@
     "NuClick",
     "NucleusInstanceSegmentor",
     "PatchPredictor",
+    "PromptSegmentor",
"SemanticSegmentor", ] diff --git a/tiatoolbox/models/architecture/sam.py b/tiatoolbox/models/architecture/sam.py new file mode 100644 index 000000000..cb6236ce5 --- /dev/null +++ b/tiatoolbox/models/architecture/sam.py @@ -0,0 +1,223 @@ +"""Define SAM architecture.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import torch +from PIL import Image +from transformers import SamModel, SamProcessor + +from tiatoolbox.models.models_abc import ModelABC + +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.type_hints import IntBounds, IntPair + + +class SAM(ModelABC): + """Segment Anything Model (SAM) Architecture. + + Meta AI's zero-shot segmentation model. + SAM is used for interactive general-purpose segmentation. + + Currently supports SAM, which requires a checkpoint and model type. + + SAM accepts an RGB image patch along with a list of point and bounding + box coordinates as prompts. + + Args: + model_type (str): + Model type. + Currently supported: vit_b, vit_l, vit_h. + checkpoint_path (str): + Path to the model checkpoint. + device (str): + Device to run inference on. + + Examples: + >>> # instantiate SAM with checkpoint path and model type + >>> sam = SAM( + ... model_type="vit_b", + ... checkpoint_path="path/to/sam_checkpoint.pth" + ... ) + """ + + def __init__( + self: SAM, + model_path: str = "facebook/sam-vit-huge", + *, + device: str = "cpu", + ) -> None: + """Initialize :class:`SAM`.""" + super().__init__() + self.net_name = "SAM" + self.device = device + + self.model = SamModel.from_pretrained(model_path).to(device) + self.processor = SamProcessor.from_pretrained(model_path) + + def forward( # skipcq: PYL-W0221 + self: SAM, + imgs: list, + point_coords: list | None = None, + box_coords: list | None = None, + ) -> np.ndarray: + """Torch method. Defines forward pass on each image in the batch. + + Note: This architecture only uses a single layer, so only one forward pass + is needed. + + Args: + imgs (list): + List of images to process, of the shape NHWC. + point_coords (list): + List of point coordinates for each image. + box_coords (list): + Bounding box coordinates for each image. + + Returns: + list: + List of masks and scores for each image. 
+
+        """
+        masks, scores = [], []
+        for i, img in enumerate(imgs):
+            image = [Image.fromarray(img)]
+            embeddings, orig_sizes, reshaped_sizes = self._encode_image(image)
+            point_labels = None
+            points = None
+            boxes = None
+
+            # Processor expects coordinates to be lists
+            def format_coords(coords: np.ndarray | list) -> list:
+                """Helper function that converts coordinates to list format."""
+                if isinstance(coords, np.ndarray):
+                    return coords[:, None, :].tolist()
+                if isinstance(coords[0], np.ndarray):
+                    return [
+                        item.tolist() if isinstance(item, np.ndarray) else item
+                        for item in coords
+                    ]
+                return coords
+
+            if point_coords is not None:
+                points = point_coords[i]
+                # Convert point coordinates to list
+                if points is not None:
+                    point_labels = np.ones((1, len(points), 1), dtype=int).tolist()
+                    points = [format_coords(points)]
+
+            if box_coords is not None:
+                boxes = box_coords[i]
+                # Convert box coordinates to list
+                if boxes is not None:
+                    boxes = [format_coords(boxes)]
+            inputs = self.processor(
+                image,
+                input_points=points,
+                input_labels=point_labels,
+                input_boxes=boxes,
+                return_tensors="pt",
+            ).to(self.device)
+
+            # Replaces pixel_values with image embeddings
+            inputs.pop("pixel_values", None)
+            inputs.update(
+                {
+                    "image_embeddings": embeddings,
+                    "original_sizes": orig_sizes,
+                    "reshaped_input_sizes": reshaped_sizes,
+                }
+            )
+
+            with torch.inference_mode():
+                # Forward pass through the model
+                outputs = self.model(**inputs, multimask_output=False)
+                image_masks = self.processor.image_processor.post_process_masks(
+                    outputs.pred_masks.cpu(),
+                    inputs["original_sizes"].cpu(),
+                    inputs["reshaped_input_sizes"].cpu(),
+                )
+                image_scores = outputs.iou_scores.cpu()
+            masks.append(image_masks)
+            scores.append(image_scores)
+            torch.cuda.empty_cache()
+
+        return np.array(masks), np.array(scores)
+
+    @staticmethod
+    def infer_batch(
+        model: torch.nn.Module,
+        batch_data: list,
+        point_coords: list[list[IntPair]] | None = None,
+        box_coords: list[IntBounds] | None = None,
+        *,
+        device: str = "cpu",
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Run inference on an input batch.
+
+        Contains logic for forward operation as well as I/O aggregation.
+        SAM accepts a list of points and a single bounding box per image.
+
+        Args:
+            model (nn.Module):
+                PyTorch defined model.
+            batch_data (list):
+                A batch of data generated by
+                `torch.utils.data.DataLoader`.
+            point_coords (list):
+                Point coordinates for each image in the batch.
+            box_coords (list):
+                Bounding box coordinates for each image in the batch.
+            device (str):
+                Device to run inference on.
+
+        Returns:
+            tuple:
+                Masks and scores for each image in the batch.
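+
+        Example:
+            A minimal usage sketch, assuming ``patch`` is an RGB
+            ``np.ndarray`` (HWC); the point prompt below is illustrative only:
+
+            >>> sam = SAM()
+            >>> batch = np.expand_dims(patch, axis=0)
+            >>> masks, scores = SAM.infer_batch(
+            ...     sam, batch, point_coords=[[[64, 64]]], device="cpu"
+            ... )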
+
+        """
+        model.eval().to(device)
+
+        if isinstance(batch_data, torch.Tensor):
+            batch_data = batch_data.cpu().numpy()
+
+        with torch.inference_mode():
+            masks, scores = model(batch_data, point_coords, box_coords)
+
+        return masks, scores
+
+    def _encode_image(
+        self: SAM, image: list
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encodes image and stores size info for later mask post-processing."""
+        processed = self.processor(image, return_tensors="pt")
+        original_sizes = processed["original_sizes"]
+        reshaped_sizes = processed["reshaped_input_sizes"]
+
+        inputs = processed.to(self.device)
+        embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
+        return embeddings, original_sizes, reshaped_sizes
+
+    @staticmethod
+    def preproc(image: np.ndarray) -> np.ndarray:
+        """Pre-process an image: convert it into a format accepted by SAM (HWC)."""
+        # Move the tensor to the CPU if it's a PyTorch tensor
+        if isinstance(image, torch.Tensor):
+            image = image.permute(1, 2, 0).cpu().numpy()
+
+        return image[..., :3]  # Remove alpha channel if present
+
+    def to(
+        self: SAM,
+        device: str = "cpu",
+        dtype: torch.dtype | None = None,
+        *,
+        non_blocking: bool = False,
+    ) -> ModelABC | torch.nn.DataParallel[ModelABC]:
+        """Moves the model to the specified device."""
+        super().to(device, dtype=dtype, non_blocking=non_blocking)
+        self.device = device
+        self.model.to(device)
+        return self
diff --git a/tiatoolbox/models/engine/prompt_segmentor.py b/tiatoolbox/models/engine/prompt_segmentor.py
new file mode 100644
index 000000000..7ff6e3f2c
--- /dev/null
+++ b/tiatoolbox/models/engine/prompt_segmentor.py
@@ -0,0 +1,92 @@
+"""This module enables interactive segmentation."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import numpy as np
+import torch
+
+from tiatoolbox.models.architecture.sam import SAM
+from tiatoolbox.models.models_abc import model_to
+from tiatoolbox.utils.misc import dict_to_store_semantic_segmentor
+
+if TYPE_CHECKING:  # pragma: no cover
+    from tiatoolbox.type_hints import IntBounds, IntPair
+
+
+class PromptSegmentor:
+    """Engine for prompt-based segmentation of WSIs.
+
+    This class is designed to work with the SAM model architecture.
+    It allows for interactive segmentation by providing point and bounding box
+    coordinates as prompts. The model can be used in both tile and WSI modes,
+    where tile mode processes individual image patches and WSI mode processes
+    whole-slide images. The class also supports multi-prompt segmentation,
+    where multiple point and bounding box coordinates can be provided for
+    segmentation.
+
+    Args:
+        model (SAM):
+            Model architecture to use. If None, defaults to SAM.
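+
+    Examples:
+        A minimal sketch of driving the engine directly, assuming ``roi`` is an
+        RGB ``np.ndarray`` region and the prompt coordinates are placeholders:
+
+        >>> segmentor = PromptSegmentor()
+        >>> mpp, scale = segmentor.calc_mpp(roi.shape[:2], base_mpp=0.5)
+        >>> segmentor.offset = np.array([0, 0])
+        >>> db_path = segmentor.predict(
+        ...     imgs=[roi],
+        ...     point_coords=np.array([[[64, 64]]]),
+        ...     save_dir=Path("sam_out"),
+        ...     device="cpu",
+        ... )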
+
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module | None = None,
+    ) -> None:
+        """Initializes the PromptSegmentor."""
+        if model is None:
+            model = SAM()
+        self.model = model
+
+    def predict(  # skipcq: PYL-W0221
+        self,
+        imgs: list,
+        point_coords: list[list[IntPair]] | None = None,
+        box_coords: list[list[IntBounds]] | None = None,
+        save_dir: str | Path | None = None,
+        device: str = "cpu",
+    ) -> Path:
+        """Run SAM on the given image regions using the supplied prompts.
+
+        Args:
+            imgs (list):
+                List of RGB image regions (HWC :class:`numpy.ndarray`).
+            point_coords (list):
+                Point prompts for each image.
+            box_coords (list):
+                Bounding box prompts for each image.
+            save_dir (str or Path):
+                Output directory for the resulting annotation store.
+            device (str):
+                Device to run inference on.
+
+        Returns:
+            Path:
+                Path to the saved annotation store (``.db``) file.
+
+        """
+        # keep the device and prepared model accessible externally (used in testing)
+        self._device = device
+        self._model = model_to(model=self.model, device=device)
+        sample_outputs = self.model.infer_batch(
+            self.model,
+            torch.tensor(imgs[0]).unsqueeze(0),
+            point_coords=point_coords,
+            box_coords=box_coords,
+            device=self._device,
+        )
+        save_path = save_dir / f"{0}"
+        mask = np.any(sample_outputs[0][0][0], axis=0, keepdims=False)
+        dict_to_store_semantic_segmentor(
+            patch_output={"predictions": mask[0]},
+            scale_factor=self.scale,
+            offset=self.offset,
+            save_path=Path(f"{save_path}.{0}.db"),
+        )
+        return Path(f"{save_path}.{0}.db")
+
+    def calc_mpp(
+        self, area_dims: IntPair, base_mpp: float, fixed_size: int = 1500
+    ) -> tuple[float, float]:
+        """Calculates the microns per pixel for a fixed area of an image.
+
+        Args:
+            area_dims (tuple):
+                Dimensions of the area to be scaled.
+            base_mpp (float):
+                Microns per pixel of the base image.
+            fixed_size (int):
+                Fixed size of the area.
+
+        Returns:
+            tuple:
+                Microns per pixel required to scale the area to the fixed
+                size, and the scale factor that was applied.
+        """
+        scale = max(area_dims) / fixed_size if max(area_dims) > fixed_size else 1.0
+        self.scale = scale
+        return base_mpp * scale, scale
diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py
index b222d0266..5499bdef7 100644
--- a/tiatoolbox/models/engine/semantic_segmentor.py
+++ b/tiatoolbox/models/engine/semantic_segmentor.py
@@ -364,12 +364,16 @@ def __init__(
 
     def _get_reader(self: WSIStreamDataset, img_path: str | Path) -> WSIReader:
         """Get appropriate reader for input path."""
-        img_path = Path(img_path)
-        if self.mode == "wsi":
-            return WSIReader.open(img_path)
-        img = imread(img_path)
-        # initialise metadata for VirtualWSIReader.
-        # here, we simulate a whole-slide image, but with a single level.
+        if isinstance(img_path, np.ndarray):
+            # if img_path is a numpy array, it is already an image
+            img = img_path
+        else:
+            img_path = Path(img_path)
+            if self.mode == "wsi":
+                return WSIReader.open(img_path)
+            img = imread(img_path)
+            # initialise metadata for VirtualWSIReader.
+            # here, we simulate a whole-slide image, but with a single level.
         metadata = WSIMeta(
             mpp=np.array([1.0, 1.0]),
             objective_power=10,
@@ -721,7 +725,8 @@ def get_reader(
         auto_get_mask: bool,
     ) -> tuple[WSIReader, WSIReader]:
         """Define how to get reader for mask and source image."""
-        img_path = Path(img_path)
+        if not isinstance(img_path, np.ndarray):
+            img_path = Path(img_path)
         reader = WSIReader.open(img_path)
         mask_reader = None
 
diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py
index a81098acd..3751327ff 100644
--- a/tiatoolbox/utils/misc.py
+++ b/tiatoolbox/utils/misc.py
@@ -1208,6 +1208,7 @@ def process_contours(
     contours: list[np.ndarray],
     hierarchy: np.ndarray,
     scale_factor: tuple[float, float] = (1, 1),
+    offset: np.ndarray | None = None,
 ) -> list[Annotation]:
     """Process contours and hierarchy to create annotations.
 
@@ -1218,6 +1219,8 @@
             A list of hierarchy.
         scale_factor (tuple[float, float]):
             The scale factor to use when loading the annotations.
+        offset (np.ndarray | None):
+            Optional offset to be added to the coordinates of the annotations.
 
     Returns:
         list:
@@ -1231,6 +1234,8 @@
     for i, layer_ in enumerate(contours):
         coords: np.ndarray = layer_.squeeze()
         scaled_coords: np.ndarray = np.array([np.array(scale_factor) * coords])
+        if offset is not None:
+            scaled_coords += offset
 
         # save one points as a line, otherwise save the Polygon
         if len(layer_) > 2:  # noqa: PLR2004
@@ -1308,6 +1313,7 @@ def dict_to_store_semantic_segmentor(
     scale_factor: tuple[float, float],
     class_dict: dict | None = None,
     save_path: Path | None = None,
+    offset: np.ndarray | None = None,
 ) -> AnnotationStore | Path:
     """Converts output of TIAToolbox SemanticSegmentor engine to AnnotationStore.
 
@@ -1324,13 +1330,14 @@
         save_path (str or Path):
             Optional Output directory to save the Annotation
             Store results.
+        offset (np.ndarray | None):
+            Optional offset to be added to the coordinates of the annotations.
 
     Returns:
         (SQLiteStore or Path):
             An SQLiteStore containing Annotations for each patch or Path to
             file storing SQLiteStore containing Annotations for each patch.
-
     """
 
     preds = patch_output["predictions"]
@@ -1354,7 +1361,7 @@
     )
 
     contours = cast("list[np.ndarray]", contours)
-    annotations_list_ = process_contours(contours, hierarchy, scale_factor)
+    annotations_list_ = process_contours(contours, hierarchy, scale_factor, offset)
     annotations_list.extend(annotations_list_)
 
     _ = store.append_many(
diff --git a/tiatoolbox/visualization/bokeh_app/main.py b/tiatoolbox/visualization/bokeh_app/main.py
index 5df7acb04..7a461c3af 100644
--- a/tiatoolbox/visualization/bokeh_app/main.py
+++ b/tiatoolbox/visualization/bokeh_app/main.py
@@ -65,9 +65,8 @@
 # GitHub actions seems unable to find TIAToolbox unless this is here
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 from tiatoolbox import logger
-from tiatoolbox.models.engine.nucleus_instance_segmentor import (
-    NucleusInstanceSegmentor,
-)
+from tiatoolbox.models.engine.nucleus_instance_segmentor import NucleusInstanceSegmentor
+from tiatoolbox.models.engine.prompt_segmentor import PromptSegmentor
 from tiatoolbox.tools.pyramid import ZoomifyGenerator
 from tiatoolbox.utils.misc import select_device
 from tiatoolbox.utils.visualization import random_colors
@@ -1118,6 +1117,8 @@ def to_model_cb(attr: ButtonClick) -> None:  # noqa: ARG001
     """Callback to run currently selected model."""
     if UI["vstate"].current_model == "hovernet":
         segment_on_box()
+    elif UI["vstate"].current_model == "SAM":
+        sam_segment()
     # Add any other models here
     else:  # pragma: no cover
         logger.warning("unknown model")
@@ -1273,6 +1274,105 @@ def segment_on_box() -> None:
     rmtree(tmp_mask_dir)
 
 
+def sam_segment() -> None:
+    """Callback to run SAM using point and box prompts on the slide.
+
+    Runs the SAM :class:`PromptSegmentor` on the currently viewed region of
+    the WSI, using any points in ``pt_source`` and boxes in ``box_source``
+    as prompts.
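+
+    Note:
+        Bokeh stores a box as centre/width/height on a negated y axis; it is
+        converted to image coordinates as ``x0 = x - width / 2`` and
+        ``y0 = -(y + height / 2)``, e.g. a box drawn at centre
+        ``(1200, -2000)`` with size ``400 x 400`` becomes
+        ``[1000, 1800, 1400, 2200]`` (x0, y0, x1, y1).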
+
+    """
+    # Get point coordinates
+    x = np.round(UI["pt_source"].data["x"])
+    y = np.round(UI["pt_source"].data["y"])
+    point_coords = (
+        np.array([[[x[i], -y[i]] for i in range(len(x))]], np.uint32)
+        if len(x) > 0
+        else None
+    )
+
+    # Get box coordinates: convert bokeh centre/width/height boxes to
+    # top-left corner coordinates in the (non-negated) image frame
+    x = np.round(UI["box_source"].data["x"])
+    y = np.round(UI["box_source"].data["y"])
+    x = [
+        round(UI["box_source"].data["x"][i] - 0.5 * UI["box_source"].data["width"][i])
+        for i in range(len(x))
+    ]
+    y = [
+        -round(UI["box_source"].data["y"][i] + 0.5 * UI["box_source"].data["height"][i])
+        for i in range(len(y))
+    ]
+    width = [round(UI["box_source"].data["width"][i]) for i in range(len(x))]
+    height = [round(UI["box_source"].data["height"][i]) for i in range(len(x))]
+    box_coords = (
+        np.array(
+            [[[x[i], y[i], x[i] + width[i], height[i] + y[i]] for i in range(len(x))]],
+            np.uint32,
+        )
+        if len(x) > 0
+        else None
+    )
+
+    prompt_segmentor = PromptSegmentor()
+    tmp_save_dir = Path(tempfile.mkdtemp())
+
+    x_start = max(0, UI["p"].x_range.start)
+    y_start = max(0, -UI["p"].y_range.end)
+    x_end = min(UI["p"].x_range.end, UI["vstate"].dims[0])
+    y_end = min(-UI["p"].y_range.start, UI["vstate"].dims[1])
+
+    height = y_end - y_start
+    width = x_end - x_start
+    res, scale_factor = prompt_segmentor.calc_mpp(
+        (width, height), UI["vstate"].mpp[0], 1500
+    )
+
+    # read the region of interest from the slide
+    if UI["vstate"].wsi is None:
+        raise ValueError("No slide loaded, cannot run SAM")
+    roi = UI["vstate"].wsi.read_bounds(
+        (int(x_start), int(y_start), int(x_end), int(y_end)),
+        resolution=res,
+        units="mpp",
+    )
+
+    # transform point_coords and box_coords to the roi coordinate system
+    if point_coords is not None:
+        point_coords = (point_coords - np.array([[x_start, y_start]])) / scale_factor
+    if box_coords is not None:
+        box_coords = (
+            box_coords - np.array([[x_start, y_start, x_start, y_start]])
+        ) / scale_factor
+    prompt_segmentor.offset = np.array([x_start, y_start])
+
+    # Run SAM on the region with the collected prompts
+    prediction = prompt_segmentor.predict(
+        imgs=[roi],
+        device=select_device(on_gpu=torch.cuda.is_available()),
+        save_dir=tmp_save_dir / "sam_out",
+        point_coords=point_coords,
+        box_coords=box_coords,
+    )
+
+    ann_loc = str(prediction)
+
+    # slide_filename = UI["vstate"].slide_path.stem + ".db"
+    # destination = doc_config["overlay_folder"] / slide_filename
+
+    # Move the database file
+    # ! Need to check if this is necessary
+    # move(ann_loc, destination)
+
+    fname = make_safe_name(ann_loc)
+    resp = UI["s"].put(
+        f"http://{host2}:{port}/tileserver/overlay",
+        data={"overlay_path": fname},
+    )
+    ann_types = json.loads(resp.text)
+    update_ui_on_new_annotations(ann_types)
+
+
 # endregion
 
 # Set up main window
@@ -1501,7 +1601,7 @@ def gather_ui_elements(  # noqa: PLR0915
     )
     model_drop = Select(
         title="choose model:",
-        options=["hovernet"],
+        options=["hovernet", "SAM"],
         height=25,
         width=120,
         max_width=120,