|
| 1 | +from typing import Dict, List, Literal, Optional, Type, Union |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import supervision as sv |
| 5 | +from pydantic import AliasChoices, ConfigDict, Field, model_validator |
| 6 | + |
| 7 | +from inference.core.workflows.core_steps.visualizations.common.base import ( |
| 8 | + OUTPUT_IMAGE_KEY, |
| 9 | + VisualizationBlock, |
| 10 | + VisualizationManifest, |
| 11 | +) |
| 12 | +from inference.core.workflows.execution_engine.entities.base import WorkflowImageData |
| 13 | +from inference.core.workflows.execution_engine.entities.types import ( |
| 14 | + IMAGE_KIND, |
| 15 | + INSTANCE_SEGMENTATION_PREDICTION_KIND, |
| 16 | + INTEGER_KIND, |
| 17 | + KEYPOINT_DETECTION_PREDICTION_KIND, |
| 18 | + OBJECT_DETECTION_PREDICTION_KIND, |
| 19 | + STRING_KIND, |
| 20 | + Selector, |
| 21 | +) |
| 22 | +from inference.core.workflows.prototypes.block import BlockResult, WorkflowBlockManifest |
| 23 | + |
# Block type identifier; referenced by the manifest's `type` Literal below.
TYPE: str = "roboflow_core/icon_visualization@v1"

# One-line summary exposed through the manifest's JSON schema extras.
SHORT_DESCRIPTION = (
    "Draw icons on an image either at specific static coordinates "
    "or dynamically based on detections."
)

# Longer description exposed through the manifest's JSON schema extras.
LONG_DESCRIPTION = """
The `IconVisualization` block draws icons on an image using Supervision's `sv.IconAnnotator`.
It supports two modes:
1. **Static Mode**: Position an icon at a fixed location (e.g., for watermarks)
2. **Dynamic Mode**: Position icons based on detection coordinates
"""
| 32 | + |
| 33 | + |
class IconManifest(VisualizationManifest):
    """Manifest (input schema) for the icon visualization block.

    Declares the icon image, the placement mode and the mode-specific
    parameters:

    * ``dynamic`` mode places the icon relative to each detection in
      ``predictions`` using ``position``.
    * ``static`` mode places the icon at ``x_position`` / ``y_position``.

    Literal values are cross-checked by ``validate_mode_parameters``; values
    supplied as selectors (e.g. ``"$inputs.mode"``) are not literal at
    validation time and are therefore not checked there.
    """

    type: Literal[f"{TYPE}", "IconVisualization"]
    model_config = ConfigDict(
        json_schema_extra={
            "name": "Icon Visualization",
            "version": "v1",
            "short_description": SHORT_DESCRIPTION,
            "long_description": LONG_DESCRIPTION,
            "license": "Apache-2.0",
            "block_type": "visualization",
            "search_keywords": ["annotator", "icon", "watermark"],
            "ui_manifest": {
                "section": "visualization",
                "icon": "far fa-image",
                "blockPriority": 5,
                "supervision": True,
                "warnings": [
                    {
                        "property": "copy_image",
                        "value": False,
                        "message": "This setting will mutate its input image. If the input is used by other blocks, it may cause unexpected behavior.",
                    }
                ],
            },
        }
    )

    icon: Selector(kind=[IMAGE_KIND]) = Field(
        title="Icon Image",
        description="The icon image to place on the input image (PNG with transparency recommended)",
        examples=["$inputs.icon", "$steps.image_loader.image"],
        json_schema_extra={
            "always_visible": True,
            "order": 3,
        },
    )

    mode: Union[
        Literal["static", "dynamic"],
        Selector(kind=[STRING_KIND]),
    ] = Field(
        default="dynamic",
        description="Mode for placing icons: 'static' for fixed position (watermark), 'dynamic' for detection-based",
        examples=["static", "dynamic", "$inputs.mode"],
        json_schema_extra={
            "always_visible": True,
            "order": 1,
        },
    )

    predictions: Optional[
        Selector(
            kind=[
                OBJECT_DETECTION_PREDICTION_KIND,
                INSTANCE_SEGMENTATION_PREDICTION_KIND,
                KEYPOINT_DETECTION_PREDICTION_KIND,
            ]
        )
    ] = Field(
        default=None,
        description="Model predictions to place icons on (required for dynamic mode)",
        examples=["$steps.object_detection_model.predictions"],
        json_schema_extra={
            "relevant_for": {
                "mode": {"values": ["dynamic"], "required": True},
            },
            "order": 4,
        },
    )

    icon_width: Union[int, Selector(kind=[INTEGER_KIND])] = Field(
        default=64,
        description="Width of the icon in pixels",
        examples=[64, "$inputs.icon_width"],
        json_schema_extra={
            "always_visible": True,
        },
    )

    icon_height: Union[int, Selector(kind=[INTEGER_KIND])] = Field(
        default=64,
        description="Height of the icon in pixels",
        examples=[64, "$inputs.icon_height"],
        json_schema_extra={
            "always_visible": True,
        },
    )

    position: Optional[
        Union[
            Literal[
                "CENTER",
                "CENTER_LEFT",
                "CENTER_RIGHT",
                "TOP_CENTER",
                "TOP_LEFT",
                "TOP_RIGHT",
                "BOTTOM_LEFT",
                "BOTTOM_CENTER",
                "BOTTOM_RIGHT",
                "CENTER_OF_MASS",
            ],
            Selector(kind=[STRING_KIND]),
        ]
    ] = Field(
        default="TOP_CENTER",
        description="Position relative to detection for dynamic mode",
        examples=["TOP_CENTER", "$inputs.position"],
        json_schema_extra={
            "relevant_for": {
                "mode": {"values": ["dynamic"], "required": False},
            },
        },
    )

    x_position: Optional[Union[int, Selector(kind=[INTEGER_KIND])]] = Field(
        default=10,
        description="X coordinate for static mode. Positive values from left edge, negative from right edge",
        examples=[10, -10, "$inputs.x_position"],
        json_schema_extra={
            "relevant_for": {
                "mode": {"values": ["static"], "required": True},
            },
        },
    )

    y_position: Optional[Union[int, Selector(kind=[INTEGER_KIND])]] = Field(
        default=10,
        description="Y coordinate for static mode. Positive values from top edge, negative from bottom edge",
        examples=[10, -10, "$inputs.y_position"],
        json_schema_extra={
            "relevant_for": {
                "mode": {"values": ["static"], "required": True},
            },
        },
    )

    @model_validator(mode="after")
    def validate_mode_parameters(self) -> "IconManifest":
        """Check that the parameters required by a literal ``mode`` are present.

        Returns:
            The validated manifest instance.

        Raises:
            ValueError: If ``mode`` is ``"dynamic"`` and ``predictions`` is
                missing, or ``mode`` is ``"static"`` and either coordinate
                is ``None``.
        """
        if self.mode == "dynamic" and self.predictions is None:
            raise ValueError("The 'predictions' field is required for dynamic mode")
        # Both coordinates are Optional in the schema but are compared with
        # `< 0` when the block runs in static mode, so reject None up front.
        if self.mode == "static" and (
            self.x_position is None or self.y_position is None
        ):
            raise ValueError(
                "The 'x_position' and 'y_position' fields are required for static mode"
            )
        return self

    @classmethod
    def get_execution_engine_compatibility(cls) -> Optional[str]:
        """Return the Execution Engine version range this block declares."""
        return ">=1.3.0,<2.0.0"
| 181 | + |
| 182 | + |
class IconVisualizationBlockV1(VisualizationBlock):
    """Workflow block that draws an icon onto an image with ``sv.IconAnnotator``.

    In ``"static"`` mode the icon is drawn once at pixel coordinates; in
    ``"dynamic"`` mode it is drawn for every detection, anchored at the
    requested ``sv.Position``. ``sv.IconAnnotator.annotate`` is called with an
    ``icon_path``, so the icon pixels are written to a temporary PNG for the
    duration of each call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Annotators keyed by icon size and anchor so they are built only once.
        self.annotatorCache = {}

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        """Return the manifest class describing this block's inputs."""
        return IconManifest

    def getAnnotator(
        self,
        icon_width: int,
        icon_height: int,
        position: Optional[str] = None,
    ) -> Optional[sv.annotators.base.BaseAnnotator]:
        """Return a cached ``sv.IconAnnotator`` for the given size and anchor.

        Args:
            icon_width: Icon width in pixels.
            icon_height: Icon height in pixels.
            position: Name of an ``sv.Position`` member, or ``None``.

        Returns:
            The cached (or newly created) annotator, or ``None`` when
            ``position`` is ``None``.
        """
        if position is None:
            return None
        key = f"dynamic_{icon_width}_{icon_height}_{position}"
        if key not in self.annotatorCache:
            self.annotatorCache[key] = sv.IconAnnotator(
                icon_resolution_wh=(icon_width, icon_height),
                icon_position=getattr(sv.Position, position),
            )
        return self.annotatorCache[key]

    @staticmethod
    def _has_alpha(pixels: Optional[np.ndarray]) -> bool:
        """Tell whether ``pixels`` is a decoded image with four channels."""
        return pixels is not None and pixels.ndim == 3 and pixels.shape[2] == 4

    @classmethod
    def _load_icon_pixels(cls, icon: WorkflowImageData) -> np.ndarray:
        """Return the icon pixels, recovering an alpha channel when possible.

        The icon's ``numpy_image`` may come back as 3-channel even when the
        source had transparency, so for 3-channel icons the original source is
        re-read with ``cv2.IMREAD_UNCHANGED``: first from a local file
        reference, then from the base64 payload. Recovery is best-effort; on
        any failure the 3-channel pixels are returned unchanged.
        """
        import base64

        import cv2

        icon_np = icon.numpy_image.copy()
        # Only 3-channel icons can gain an alpha channel; checking ndim first
        # keeps 2-D grayscale icons from failing on shape[2].
        if icon_np.ndim != 3 or icon_np.shape[2] != 3:
            return icon_np

        reference = getattr(icon, "_image_reference", None)
        if (
            isinstance(reference, str)
            and reference
            and not reference.startswith("http")
        ):
            try:
                reloaded = cv2.imread(reference, cv2.IMREAD_UNCHANGED)
            except Exception:
                # Deliberate best-effort: fall through to the next source.
                reloaded = None
            if cls._has_alpha(reloaded):
                return reloaded

        encoded = getattr(icon, "_base64_image", None)
        if encoded:
            try:
                raw = base64.b64decode(encoded)
                decoded = cv2.imdecode(
                    np.frombuffer(raw, np.uint8), cv2.IMREAD_UNCHANGED
                )
            except Exception:
                # Deliberate best-effort: keep the opaque 3-channel pixels.
                decoded = None
            if cls._has_alpha(decoded):
                return decoded

        return icon_np

    @staticmethod
    def _with_alpha_channel(icon_np: np.ndarray) -> np.ndarray:
        """Expand grayscale or 3-channel pixels to 4 channels, fully opaque."""
        import cv2

        if icon_np.ndim == 2:
            icon_np = cv2.cvtColor(icon_np, cv2.COLOR_GRAY2BGR)
        if icon_np.shape[2] == 3:
            opaque = np.full(icon_np.shape[:2] + (1,), 255, dtype=icon_np.dtype)
            icon_np = np.concatenate([icon_np, opaque], axis=2)
        return icon_np

    @staticmethod
    def _static_detections(
        img_width: int,
        img_height: int,
        icon_width: int,
        icon_height: int,
        x_position: int,
        y_position: int,
    ) -> sv.Detections:
        """Build a one-box ``sv.Detections`` centred where the icon should go.

        Non-negative coordinates are the icon's top-left corner measured from
        the image's top-left; negative coordinates are offsets from the
        right/bottom edge.
        """
        if x_position < 0:
            left = img_width + x_position - icon_width
        else:
            left = x_position
        if y_position < 0:
            top = img_height + y_position - icon_height
        else:
            top = y_position

        # The annotator is anchored at the box CENTER, so a tiny box around
        # the icon's intended centre places the icon at (left, top).
        center_x = left + icon_width // 2
        center_y = top + icon_height // 2
        return sv.Detections(
            xyxy=np.array(
                [[center_x - 1, center_y - 1, center_x + 1, center_y + 1]],
                dtype=np.float64,
            ),
            class_id=np.array([0]),
            confidence=np.array([1.0]),
        )

    def run(
        self,
        image: WorkflowImageData,
        copy_image: bool,
        mode: str,
        icon: WorkflowImageData,
        predictions: Optional[sv.Detections],
        icon_width: int,
        icon_height: int,
        position: Optional[str],
        x_position: Optional[int],
        y_position: Optional[int],
    ) -> BlockResult:
        """Draw ``icon`` on ``image`` and return the annotated image.

        Args:
            image: Image to annotate.
            copy_image: When true, annotate a copy instead of the input array.
            mode: ``"static"`` or ``"dynamic"``; any other value draws nothing.
            icon: Icon image to overlay.
            predictions: Detections used in dynamic mode; nothing is drawn
                when this is ``None`` or empty.
            icon_width: Icon width in pixels.
            icon_height: Icon height in pixels.
            position: Name of the ``sv.Position`` anchor for dynamic mode.
            x_position: Static-mode X coordinate (negative = from right edge).
            y_position: Static-mode Y coordinate (negative = from bottom edge).

        Returns:
            A dict mapping ``OUTPUT_IMAGE_KEY`` to the annotated image data.

        Raises:
            ValueError: In static mode when a coordinate is ``None``.
            RuntimeError: When the temporary icon file cannot be written.
        """
        import os
        import tempfile

        import cv2

        annotated_image = image.numpy_image.copy() if copy_image else image.numpy_image

        annotator = None
        detections = None
        if mode == "static":
            if x_position is None or y_position is None:
                raise ValueError(
                    "The 'x_position' and 'y_position' fields are required for static mode"
                )
            img_height, img_width = annotated_image.shape[:2]
            detections = self._static_detections(
                img_width=img_width,
                img_height=img_height,
                icon_width=icon_width,
                icon_height=icon_height,
                x_position=x_position,
                y_position=y_position,
            )
            annotator = self.getAnnotator(
                icon_width=icon_width,
                icon_height=icon_height,
                position="CENTER",
            )
        elif mode == "dynamic" and predictions is not None and len(predictions) > 0:
            detections = predictions
            annotator = self.getAnnotator(
                icon_width=icon_width,
                icon_height=icon_height,
                position=position,
            )

        # Only touch the icon and the filesystem when something will be drawn.
        if annotator is not None:
            icon_np = self._with_alpha_channel(self._load_icon_pixels(icon))

            # Reserve a path and close the handle before writing to it.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
                icon_path = handle.name
            try:
                if not cv2.imwrite(icon_path, icon_np):
                    raise RuntimeError(
                        f"Failed to write temporary icon file: {icon_path}"
                    )
                annotated_image = annotator.annotate(
                    scene=annotated_image,
                    detections=detections,
                    icon_path=icon_path,
                )
            finally:
                # Covers failures in writing as well as in annotating.
                os.unlink(icon_path)

        return {
            OUTPUT_IMAGE_KEY: WorkflowImageData.copy_and_replace(
                origin_image_data=image, numpy_image=annotated_image
            )
        }
0 commit comments