Skip to content

Commit ce47d3c

Browse files
add support for on-prem
1 parent 94ad7f6 commit ce47d3c

20 files changed

+457
-158
lines changed

model-engine/model_engine_server/api/dependencies.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@
9494
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
9595
LiveEndpointResourceGateway,
9696
)
97+
from model_engine_server.infra.gateways.resources.onprem_queue_endpoint_resource_delegate import (
98+
OnPremQueueEndpointResourceDelegate,
99+
)
97100
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
98101
QueueEndpointResourceDelegate,
99102
)
@@ -114,6 +117,7 @@
114117
FakeDockerRepository,
115118
LiveTokenizerRepository,
116119
LLMFineTuneRepository,
120+
OnPremDockerRepository,
117121
RedisModelEndpointCacheRepository,
118122
S3FileLLMFineTuneEventsRepository,
119123
S3FileLLMFineTuneRepository,
@@ -225,6 +229,8 @@ def _get_external_interfaces(
225229
queue_delegate = FakeQueueEndpointResourceDelegate()
226230
elif infra_config().cloud_provider == "azure":
227231
queue_delegate = ASBQueueEndpointResourceDelegate()
232+
elif infra_config().cloud_provider == "onprem":
233+
queue_delegate = OnPremQueueEndpointResourceDelegate()
228234
else:
229235
queue_delegate = SQSQueueEndpointResourceDelegate(
230236
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
@@ -238,6 +244,9 @@ def _get_external_interfaces(
238244
elif infra_config().cloud_provider == "azure":
239245
inference_task_queue_gateway = servicebus_task_queue_gateway
240246
infra_task_queue_gateway = servicebus_task_queue_gateway
247+
elif infra_config().cloud_provider == "onprem":
248+
inference_task_queue_gateway = redis_task_queue_gateway
249+
infra_task_queue_gateway = redis_task_queue_gateway
241250
elif infra_config().celery_broker_type_redis:
242251
inference_task_queue_gateway = redis_task_queue_gateway
243252
infra_task_queue_gateway = redis_task_queue_gateway
@@ -274,16 +283,12 @@ def _get_external_interfaces(
274283
monitoring_metrics_gateway=monitoring_metrics_gateway,
275284
use_asyncio=(not CIRCLECI),
276285
)
277-
filesystem_gateway = (
278-
ABSFilesystemGateway()
279-
if infra_config().cloud_provider == "azure"
280-
else S3FilesystemGateway()
281-
)
282-
llm_artifact_gateway = (
283-
ABSLLMArtifactGateway()
284-
if infra_config().cloud_provider == "azure"
285-
else S3LLMArtifactGateway()
286-
)
286+
if infra_config().cloud_provider == "azure":
287+
filesystem_gateway = ABSFilesystemGateway()
288+
llm_artifact_gateway = ABSLLMArtifactGateway()
289+
else:
290+
filesystem_gateway = S3FilesystemGateway()
291+
llm_artifact_gateway = S3LLMArtifactGateway()
287292
model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
288293
filesystem_gateway=filesystem_gateway
289294
)
@@ -328,18 +333,11 @@ def _get_external_interfaces(
328333
hmi_config.cloud_file_llm_fine_tune_repository,
329334
)
330335
if infra_config().cloud_provider == "azure":
331-
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(
332-
file_path=file_path,
333-
)
336+
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(file_path=file_path)
337+
llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository()
334338
else:
335-
llm_fine_tune_repository = S3FileLLMFineTuneRepository(
336-
file_path=file_path,
337-
)
338-
llm_fine_tune_events_repository = (
339-
ABSFileLLMFineTuneEventsRepository()
340-
if infra_config().cloud_provider == "azure"
341-
else S3FileLLMFineTuneEventsRepository()
342-
)
339+
llm_fine_tune_repository = S3FileLLMFineTuneRepository(file_path=file_path)
340+
llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository()
343341
llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService(
344342
docker_image_batch_job_gateway=docker_image_batch_job_gateway,
345343
docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository,
@@ -350,17 +348,18 @@ def _get_external_interfaces(
350348
docker_image_batch_job_gateway=docker_image_batch_job_gateway
351349
)
352350

353-
file_storage_gateway = (
354-
ABSFileStorageGateway()
355-
if infra_config().cloud_provider == "azure"
356-
else S3FileStorageGateway()
357-
)
351+
if infra_config().cloud_provider == "azure":
352+
file_storage_gateway = ABSFileStorageGateway()
353+
else:
354+
file_storage_gateway = S3FileStorageGateway()
358355

359356
docker_repository: DockerRepository
360357
if CIRCLECI:
361358
docker_repository = FakeDockerRepository()
362-
elif infra_config().docker_repo_prefix.endswith("azurecr.io"):
359+
elif infra_config().cloud_provider == "azure":
363360
docker_repository = ACRDockerRepository()
361+
elif infra_config().cloud_provider == "onprem":
362+
docker_repository = OnPremDockerRepository()
364363
else:
365364
docker_repository = ECRDockerRepository()
366365

model-engine/model_engine_server/common/config.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,29 @@ def from_yaml(cls, yaml_path):
9090

9191
@property
9292
def cache_redis_url(self) -> str:
93+
cloud_provider = infra_config().cloud_provider
94+
95+
if cloud_provider == "onprem":
96+
if self.cache_redis_aws_url:
97+
logger.info("On-prem deployment using cache_redis_aws_url")
98+
return self.cache_redis_aws_url
99+
redis_host = os.getenv("REDIS_HOST", "redis")
100+
redis_port = getattr(infra_config(), "redis_port", 6379)
101+
return f"redis://{redis_host}:{redis_port}/0"
102+
93103
if self.cache_redis_aws_url:
94-
assert infra_config().cloud_provider == "aws", "cache_redis_aws_url is only for AWS"
104+
assert cloud_provider == "aws", "cache_redis_aws_url is only for AWS"
95105
if self.cache_redis_aws_secret_name:
96106
logger.warning(
97107
"Both cache_redis_aws_url and cache_redis_aws_secret_name are set. Using cache_redis_aws_url"
98108
)
99109
return self.cache_redis_aws_url
100110
elif self.cache_redis_aws_secret_name:
101-
assert (
102-
infra_config().cloud_provider == "aws"
103-
), "cache_redis_aws_secret_name is only for AWS"
104-
creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role
111+
assert cloud_provider == "aws", "cache_redis_aws_secret_name is only for AWS"
112+
creds = get_key_file(self.cache_redis_aws_secret_name)
105113
return creds["cache-url"]
106114

107-
assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
115+
assert self.cache_redis_azure_host and cloud_provider == "azure"
108116
username = os.getenv("AZURE_OBJECT_ID")
109117
token = DefaultAzureCredential().get_token("https://redis.azure.com/.default")
110118
password = token.token

model-engine/model_engine_server/common/io.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010

1111
def open_wrapper(uri: str, mode: str = "rt", **kwargs):
1212
client: Any
13-
cloud_provider: str
14-
# This follows the 5.1.0 smart_open API
1513
try:
1614
cloud_provider = infra_config().cloud_provider
1715
except Exception:
1816
cloud_provider = "aws"
17+
1918
if cloud_provider == "azure":
2019
from azure.identity import DefaultAzureCredential
2120
from azure.storage.blob import BlobServiceClient
@@ -24,6 +23,22 @@ def open_wrapper(uri: str, mode: str = "rt", **kwargs):
2423
f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net",
2524
DefaultAzureCredential(),
2625
)
26+
elif cloud_provider == "onprem":
27+
session = boto3.Session()
28+
client_kwargs = {}
29+
30+
s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv(
31+
"S3_ENDPOINT_URL"
32+
)
33+
if s3_endpoint:
34+
client_kwargs["endpoint_url"] = s3_endpoint
35+
36+
addressing_style = getattr(infra_config(), "s3_addressing_style", "path")
37+
client_kwargs["config"] = boto3.session.Config(
38+
s3={"addressing_style": addressing_style}
39+
)
40+
41+
client = session.client("s3", **client_kwargs)
2742
else:
2843
profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE"))
2944
session = boto3.Session(profile_name=profile_name)

model-engine/model_engine_server/core/celery/app.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -531,17 +531,26 @@ def _get_backend_url_and_conf(
531531
backend_url = get_redis_endpoint(1)
532532
elif backend_protocol == "s3":
533533
backend_url = "s3://"
534-
if aws_role is None:
535-
aws_session = session(infra_config().profile_ml_worker)
534+
if infra_config().cloud_provider == "aws":
535+
if aws_role is None:
536+
aws_session = session(infra_config().profile_ml_worker)
537+
else:
538+
aws_session = session(aws_role)
539+
out_conf_changes.update(
540+
{
541+
"s3_boto3_session": aws_session,
542+
"s3_bucket": s3_bucket,
543+
"s3_base_path": s3_base_path,
544+
}
545+
)
536546
else:
537-
aws_session = session(aws_role)
538-
out_conf_changes.update(
539-
{
540-
"s3_boto3_session": aws_session,
541-
"s3_bucket": s3_bucket,
542-
"s3_base_path": s3_base_path,
543-
}
544-
)
547+
logger.info("Non-AWS deployment, using environment variables for S3 backend credentials")
548+
out_conf_changes.update(
549+
{
550+
"s3_bucket": s3_bucket,
551+
"s3_base_path": s3_base_path,
552+
}
553+
)
545554
elif backend_protocol == "abs":
546555
backend_url = f"azureblockblob://{os.getenv('ABS_ACCOUNT_NAME')}"
547556
else:
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# On-premise deployment configuration
2+
# This configuration file provides defaults for on-prem deployments
3+
# Many values can be overridden via environment variables
4+
5+
cloud_provider: "onprem"
6+
env: "production" # Can be: production, staging, development, local
7+
k8s_cluster_name: "onprem-cluster"
8+
dns_host_domain: "ml.company.local"
9+
default_region: "us-east-1" # Placeholder for compatibility with cloud-agnostic code
10+
11+
# ====================
12+
# Object Storage (MinIO/S3-compatible)
13+
# ====================
14+
s3_bucket: "model-engine"
15+
# S3 endpoint URL - can be overridden by S3_ENDPOINT_URL env var
16+
# Examples: "https://minio.company.local", "http://minio-service:9000"
17+
s3_endpoint_url: "" # Set via S3_ENDPOINT_URL env var if not specified here
18+
# MinIO requires path-style addressing (bucket in URL path, not subdomain)
19+
s3_addressing_style: "path"
20+
21+
# ====================
22+
# Redis Configuration
23+
# ====================
24+
# Redis is used for:
25+
# - Celery task queue broker
26+
# - Model endpoint caching
27+
# - Inference autoscaling metrics
28+
redis_host: "" # Set via REDIS_HOST env var (e.g., "redis.company.local" or "redis-service")
29+
redis_port: 6379
30+
# Whether to use Redis as Celery broker (true for on-prem)
31+
celery_broker_type_redis: true
32+
33+
# ====================
34+
# Celery Configuration
35+
# ====================
36+
# Backend protocol: "redis" for on-prem (not "s3" or "abs")
37+
celery_backend_protocol: "redis"
38+
39+
# ====================
40+
# Database Configuration
41+
# ====================
42+
# Database connection settings (credentials from environment variables)
43+
# DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
44+
db_host: "postgres" # Default hostname, can be overridden by DB_HOST env var
45+
db_port: 5432
46+
db_name: "llm_engine"
47+
db_engine_pool_size: 20
48+
db_engine_max_overflow: 10
49+
db_engine_echo: false
50+
db_engine_echo_pool: false
51+
db_engine_disconnect_strategy: "pessimistic"
52+
53+
# ====================
54+
# Docker Registry Configuration
55+
# ====================
56+
# Docker registry prefix for container images
57+
# Examples: "registry.company.local", "harbor.company.local/ml-platform"
58+
# Leave empty if using full image paths directly
59+
docker_repo_prefix: "registry.company.local"
60+
61+
# ====================
62+
# Monitoring & Observability
63+
# ====================
64+
# Prometheus server address for metrics (optional)
65+
# prometheus_server_address: "http://prometheus:9090"
66+
67+
# ====================
68+
# Not applicable for on-prem (kept for compatibility)
69+
# ====================
70+
ml_account_id: "onprem"
71+
profile_ml_worker: "default"
72+
profile_ml_inference_worker: "default"

model-engine/model_engine_server/db/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,17 @@ def get_engine_url(
5959
key_file = get_key_file_name(env) # type: ignore
6060
logger.debug(f"Using key file {key_file}")
6161

62-
if infra_config().cloud_provider == "azure":
62+
if infra_config().cloud_provider == "onprem":
63+
user = os.environ.get("DB_USER", "postgres")
64+
password = os.environ.get("DB_PASSWORD", "postgres")
65+
host = os.environ.get("DB_HOST_RO") or os.environ.get("DB_HOST", "localhost")
66+
port = os.environ.get("DB_PORT", "5432")
67+
dbname = os.environ.get("DB_NAME", "llm_engine")
68+
logger.info(f"Connecting to db {host}:{port}, name {dbname}")
69+
70+
engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
71+
72+
elif infra_config().cloud_provider == "azure":
6373
client = SecretClient(
6474
vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net",
6575
credential=DefaultAzureCredential(),

model-engine/model_engine_server/domain/entities/model_bundle_entity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def validate_fields_present_for_framework_type(cls, field_values):
7171
"type was selected."
7272
)
7373
else: # field_values["framework_type"] == ModelBundleFramework.CUSTOM:
74-
assert field_values["ecr_repo"] and field_values["image_tag"], (
75-
"Expected `ecr_repo` and `image_tag` to be non-null because the custom framework "
74+
assert field_values["image_tag"], (
75+
"Expected `image_tag` to be non-null because the custom framework "
7676
"type was selected."
7777
)
7878
return field_values

model-engine/model_engine_server/entrypoints/k8s_cache.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
ECRDockerRepository,
4343
FakeDockerRepository,
4444
)
45+
from model_engine_server.infra.repositories.onprem_docker_repository import (
46+
OnPremDockerRepository,
47+
)
4548
from model_engine_server.infra.repositories.db_model_endpoint_record_repository import (
4649
DbModelEndpointRecordRepository,
4750
)
@@ -124,8 +127,10 @@ async def main(args: Any):
124127
docker_repo: DockerRepository
125128
if CIRCLECI:
126129
docker_repo = FakeDockerRepository()
127-
elif infra_config().docker_repo_prefix.endswith("azurecr.io"):
130+
elif infra_config().cloud_provider == "azure":
128131
docker_repo = ACRDockerRepository()
132+
elif infra_config().cloud_provider == "onprem":
133+
docker_repo = OnPremDockerRepository()
129134
else:
130135
docker_repo = ECRDockerRepository()
131136
while True:
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Any, Dict, Sequence
2+
3+
from model_engine_server.core.loggers import logger_name, make_logger
4+
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
5+
QueueEndpointResourceDelegate,
6+
QueueInfo,
7+
)
8+
9+
logger = make_logger(logger_name())
10+
11+
__all__: Sequence[str] = ("OnPremQueueEndpointResourceDelegate",)
12+
13+
14+
class OnPremQueueEndpointResourceDelegate(QueueEndpointResourceDelegate):
15+
async def create_queue_if_not_exists(
16+
self,
17+
endpoint_id: str,
18+
endpoint_name: str,
19+
endpoint_created_by: str,
20+
endpoint_labels: Dict[str, Any],
21+
) -> QueueInfo:
22+
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
23+
24+
logger.debug(
25+
f"On-prem queue for endpoint {endpoint_id}: {queue_name} "
26+
f"(Redis queues don't require explicit creation)"
27+
)
28+
29+
return QueueInfo(queue_name=queue_name, queue_url=None)
30+
31+
async def delete_queue(self, endpoint_id: str) -> None:
32+
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
33+
logger.debug(
34+
f"Delete request for queue {queue_name} (no-op for Redis-based queues)"
35+
)
36+
37+
async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]:
38+
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
39+
40+
logger.debug(f"Getting attributes for queue {queue_name}")
41+
42+
return {
43+
"Attributes": {
44+
"ApproximateNumberOfMessages": "0",
45+
"QueueName": queue_name,
46+
},
47+
"ResponseMetadata": {
48+
"HTTPStatusCode": 200,
49+
},
50+
}

0 commit comments

Comments
 (0)