Skip to content
49 changes: 49 additions & 0 deletions cpu_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from functools import partial
from multiprocessing import Process

from inference.core.cache import cache
from inference.core.env import (
ACTIVE_LEARNING_ENABLED,
ENABLE_STREAM_API,
GCP_SERVERLESS,
LAMBDA,
MAX_ACTIVE_MODELS,
STREAM_API_PRELOADED_PROCESSES,
)
from inference.core.interfaces.http.http_api import HttpInterface
from inference.core.interfaces.stream_manager.manager_app.app import start
from inference.core.managers.active_learning import (
ActiveLearningManager,
BackgroundTaskActiveLearningManager,
)
from inference.core.managers.base import ModelManager
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.registries.roboflow import (
RoboflowModelRegistry,
)
from inference.models.utils import ROBOFLOW_MODEL_TYPES

if ENABLE_STREAM_API:
stream_manager_process = Process(
target=partial(start, expected_warmed_up_pipelines=STREAM_API_PRELOADED_PROCESSES),
)
stream_manager_process.start()

model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)

if ACTIVE_LEARNING_ENABLED:
if LAMBDA or GCP_SERVERLESS:
model_manager = ActiveLearningManager(
model_registry=model_registry, cache=cache
)
else:
model_manager = BackgroundTaskActiveLearningManager(
model_registry=model_registry, cache=cache
)
else:
model_manager = ModelManager(model_registry=model_registry)

model_manager = WithFixedSizeCache(model_manager, max_size=MAX_ACTIVE_MODELS)
model_manager.init_pingback()
interface = HttpInterface(model_manager)
app = interface.app
4 changes: 4 additions & 0 deletions inference/core/workflows/core_steps/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@
from inference.core.workflows.core_steps.transformations.image_slicer.v2 import (
ImageSlicerBlockV2,
)
from inference.core.workflows.core_steps.transformations.load_image_from_url.v1 import (
LoadImageFromUrlBlockV1,
)
from inference.core.workflows.core_steps.transformations.perspective_correction.v1 import (
PerspectiveCorrectionBlockV1,
)
Expand Down Expand Up @@ -533,6 +536,7 @@
def load_blocks() -> List[Type[WorkflowBlock]]:
return [
AbsoluteStaticCropBlockV1,
LoadImageFromUrlBlockV1,
DynamicCropBlockV1,
DetectionsFilterBlockV1,
DetectionOffsetBlockV1,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import hashlib
from typing import List, Literal, Type, Union
from uuid import uuid4

from pydantic import ConfigDict, Field

from inference.core.cache.lru_cache import LRUCache
from inference.core.utils.image_utils import load_image_from_url
from inference.core.workflows.execution_engine.entities.base import (
ImageParentMetadata,
OutputDefinition,
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
IMAGE_KIND,
STRING_KIND,
Selector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)

LONG_DESCRIPTION = """
Load an image from a URL.

This block downloads an image from the provided URL and makes it available
for use in the workflow pipeline. Optionally, the block can cache downloaded
images to avoid re-fetching the same URL multiple times.
"""

# Module-level cache instance following common pattern
image_cache = LRUCache(capacity=64)


class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "Load Image From URL",
"version": "v1",
"short_description": "Load an image from a URL.",
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "transformation",
"ui_manifest": {
"section": "transformation",
"icon": "fas fa-image",
"blockPriority": 1,
},
}
)
type: Literal["roboflow_core/load_image_from_url@v1"]
url: Union[str, Selector(kind=[STRING_KIND])] = Field(
description="URL of the image to load",
examples=["https://example.com/image.jpg", "$inputs.image_url"],
)
cache: Union[bool, Selector(kind=[BOOLEAN_KIND])] = Field(
default=True,
description="Whether to cache the downloaded image to avoid re-fetching",
examples=[True, False, "$inputs.cache_image"],
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(name="image", kind=[IMAGE_KIND]),
]

@classmethod
def get_execution_engine_compatibility(cls) -> str:
return ">=1.0.0,<2.0.0"


class LoadImageFromUrlBlockV1(WorkflowBlock):
@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return BlockManifest

def run(self, url: str, cache: bool = True, **kwargs) -> BlockResult:
try:
# Generate cache key using URL hash (following common pattern)
cache_key = hashlib.md5(url.encode("utf-8")).hexdigest()

# Check cache if enabled
if cache:
cached_image = image_cache.get(cache_key)
if cached_image is not None:
return {"image": cached_image}

# Load image using secure utility
numpy_image = load_image_from_url(value=url)

# Create proper parent metadata
parent_metadata = ImageParentMetadata(parent_id=str(uuid4()))

workflow_image = WorkflowImageData(
parent_metadata=parent_metadata,
numpy_image=numpy_image,
)

# Store in cache if enabled
if cache:
image_cache.set(cache_key, workflow_image)

return {"image": workflow_image}
except Exception as e:
raise RuntimeError(f"Failed to load image from URL {url}: {str(e)}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import numpy as np
import pytest
from pydantic import ValidationError
from unittest.mock import patch

from inference.core.workflows.core_steps.transformations.load_image_from_url.v1 import (
BlockManifest,
LoadImageFromUrlBlockV1,
)
from inference.core.workflows.execution_engine.entities.base import (
ImageParentMetadata,
WorkflowImageData,
)


@pytest.mark.parametrize("type_alias", ["roboflow_core/load_image_from_url@v1"])
@pytest.mark.parametrize("url_input", ["https://example.com/image.jpg", "$inputs.image_url"])
@pytest.mark.parametrize("cache_input", [True, False, "$inputs.cache_enabled"])
def test_load_image_from_url_manifest_validation_when_valid_input_given(
type_alias: str, url_input: str, cache_input
) -> None:
# given
raw_manifest = {
"type": type_alias,
"name": "load_image",
"url": url_input,
"cache": cache_input,
}

# when
result = BlockManifest.model_validate(raw_manifest)

# then
assert result == BlockManifest(
name="load_image",
type=type_alias,
url=url_input,
cache=cache_input,
)


@pytest.mark.parametrize("field_to_delete", ["type", "name", "url"])
def test_load_image_from_url_manifest_validation_when_required_field_missing(
field_to_delete: str,
) -> None:
# given
raw_manifest = {
"type": "roboflow_core/load_image_from_url@v1",
"name": "load_image",
"url": "https://example.com/image.jpg",
"cache": True,
}
del raw_manifest[field_to_delete]

# when
with pytest.raises(ValidationError):
_ = BlockManifest.model_validate(raw_manifest)


def test_load_image_from_url_manifest_validation_with_default_cache() -> None:
# given
raw_manifest = {
"type": "roboflow_core/load_image_from_url@v1",
"name": "load_image",
"url": "https://example.com/image.jpg",
# cache field omitted - should default to True
}

# when
result = BlockManifest.model_validate(raw_manifest)

# then
assert result.cache is True


@patch("inference.core.workflows.core_steps.transformations.load_image_from_url.v1.load_image_from_url")
def test_load_image_from_url_block_run_success(mock_load_image_from_url) -> None:
# given
test_url = "https://www.peta.org/wp-content/uploads/2023/05/wild-raccoon.jpg"
mock_numpy_image = np.zeros((480, 640, 3), dtype=np.uint8)
mock_load_image_from_url.return_value = mock_numpy_image

block = LoadImageFromUrlBlockV1()

# when
result = block.run(url=test_url, cache=True)

# then
assert "image" in result
assert isinstance(result["image"], WorkflowImageData)
assert np.array_equal(result["image"].numpy_image, mock_numpy_image)
assert isinstance(result["image"].parent_metadata, ImageParentMetadata)
assert result["image"].parent_metadata.parent_id is not None

# Verify the underlying function was called with correct parameters
mock_load_image_from_url.assert_called_once_with(value=test_url)


@patch("inference.core.workflows.core_steps.transformations.load_image_from_url.v1.load_image_from_url")
def test_load_image_from_url_block_run_caching_behavior(mock_load_image_from_url) -> None:
# given
test_url = "https://example.com/cached-image.jpg"
mock_numpy_image = np.zeros((50, 50, 3), dtype=np.uint8)
mock_load_image_from_url.return_value = mock_numpy_image

block = LoadImageFromUrlBlockV1()

# when - first call should load the image
result1 = block.run(url=test_url, cache=True)

# when - second call with same URL should use cache
result2 = block.run(url=test_url, cache=True)

# then
assert "image" in result1
assert "image" in result2

# Both results should have identical image data
assert np.array_equal(result1["image"].numpy_image, result2["image"].numpy_image)

# The underlying function should only be called once due to caching
mock_load_image_from_url.assert_called_once_with(value=test_url)


@patch("inference.core.workflows.core_steps.transformations.load_image_from_url.v1.load_image_from_url")
def test_load_image_from_url_block_run_error_handling(mock_load_image_from_url) -> None:
# given
test_url = "https://nonexistent.example.com/image.jpg"
mock_load_image_from_url.side_effect = Exception("Could not load image from url")

block = LoadImageFromUrlBlockV1()

# when/then
with pytest.raises(RuntimeError) as exc_info:
block.run(url=test_url, cache=False)

assert "Failed to load image from URL" in str(exc_info.value)
assert test_url in str(exc_info.value)
mock_load_image_from_url.assert_called_once_with(value=test_url)


def test_load_image_from_url_block_manifest_outputs() -> None:
# given/when
outputs = BlockManifest.describe_outputs()

# then
assert len(outputs) == 1
assert outputs[0].name == "image"
assert "image" in [kind.name for kind in outputs[0].kind]


def test_load_image_from_url_block_compatibility() -> None:
# given/when
compatibility = BlockManifest.get_execution_engine_compatibility()

# then
assert compatibility == ">=1.0.0,<2.0.0"


# Tests for Requirement 4: URL validation at runtime
@patch("inference.core.workflows.core_steps.transformations.load_image_from_url.v1.load_image_from_url")
def test_load_image_from_url_block_validates_invalid_url_format_at_runtime(mock_load_image_from_url) -> None:
# given
invalid_url = "not-a-valid-url"
mock_load_image_from_url.side_effect = Exception("Providing images via non https:// URL is not supported")

block = LoadImageFromUrlBlockV1()

# when/then
with pytest.raises(RuntimeError) as exc_info:
block.run(url=invalid_url, cache=False)

assert "Failed to load image from URL" in str(exc_info.value)
assert invalid_url in str(exc_info.value)
mock_load_image_from_url.assert_called_once_with(value=invalid_url)


# Tests for Requirement 5: Image extension validation
@patch("inference.core.workflows.core_steps.transformations.load_image_from_url.v1.load_image_from_url")
def test_load_image_from_url_block_validates_non_image_extension_at_runtime(mock_load_image_from_url) -> None:
# given
non_image_url = "https://example.com/document.pdf"
mock_load_image_from_url.side_effect = Exception("Could not decode bytes as image")

block = LoadImageFromUrlBlockV1()

# when/then
with pytest.raises(RuntimeError) as exc_info:
block.run(url=non_image_url, cache=False)

assert "Failed to load image from URL" in str(exc_info.value)
assert non_image_url in str(exc_info.value)
mock_load_image_from_url.assert_called_once_with(value=non_image_url)
3 changes: 3 additions & 0 deletions watch-dev.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

PROJECT=roboflow-platform ENABLE_BUILDER=True ENABLE_STREAM_API=True watchmedo auto-restart --pattern="*.py" --recursive -- uvicorn cpu_http:app --port 9001
Loading