53 changes: 26 additions & 27 deletions model-engine/model_engine_server/api/dependencies.py
@@ -94,6 +94,9 @@
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
from model_engine_server.infra.gateways.resources.onprem_queue_endpoint_resource_delegate import (
OnPremQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
)
@@ -114,6 +117,7 @@
FakeDockerRepository,
LiveTokenizerRepository,
LLMFineTuneRepository,
OnPremDockerRepository,
RedisModelEndpointCacheRepository,
S3FileLLMFineTuneEventsRepository,
S3FileLLMFineTuneRepository,
@@ -225,6 +229,8 @@ def _get_external_interfaces(
queue_delegate = FakeQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "onprem":
queue_delegate = OnPremQueueEndpointResourceDelegate()
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
@@ -238,6 +244,9 @@ def _get_external_interfaces(
elif infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "onprem":
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
elif infra_config().celery_broker_type_redis:
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
@@ -274,16 +283,12 @@ def _get_external_interfaces(
monitoring_metrics_gateway=monitoring_metrics_gateway,
use_asyncio=(not CIRCLECI),
)
filesystem_gateway = (
ABSFilesystemGateway()
if infra_config().cloud_provider == "azure"
else S3FilesystemGateway()
)
llm_artifact_gateway = (
ABSLLMArtifactGateway()
if infra_config().cloud_provider == "azure"
else S3LLMArtifactGateway()
)
if infra_config().cloud_provider == "azure":
filesystem_gateway = ABSFilesystemGateway()
llm_artifact_gateway = ABSLLMArtifactGateway()
else:
filesystem_gateway = S3FilesystemGateway()
llm_artifact_gateway = S3LLMArtifactGateway()
model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
filesystem_gateway=filesystem_gateway
)
@@ -328,18 +333,11 @@ def _get_external_interfaces(
hmi_config.cloud_file_llm_fine_tune_repository,
)
if infra_config().cloud_provider == "azure":
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_repository = ABSFileLLMFineTuneRepository(file_path=file_path)
llm_fine_tune_events_repository = ABSFileLLMFineTuneEventsRepository()
else:
llm_fine_tune_repository = S3FileLLMFineTuneRepository(
file_path=file_path,
)
llm_fine_tune_events_repository = (
ABSFileLLMFineTuneEventsRepository()
if infra_config().cloud_provider == "azure"
else S3FileLLMFineTuneEventsRepository()
)
llm_fine_tune_repository = S3FileLLMFineTuneRepository(file_path=file_path)
llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository()
llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService(
docker_image_batch_job_gateway=docker_image_batch_job_gateway,
docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository,
@@ -350,17 +348,18 @@ def _get_external_interfaces(
docker_image_batch_job_gateway=docker_image_batch_job_gateway
)

file_storage_gateway = (
ABSFileStorageGateway()
if infra_config().cloud_provider == "azure"
else S3FileStorageGateway()
)
if infra_config().cloud_provider == "azure":
file_storage_gateway = ABSFileStorageGateway()
else:
file_storage_gateway = S3FileStorageGateway()

docker_repository: DockerRepository
if CIRCLECI:
docker_repository = FakeDockerRepository()
elif infra_config().docker_repo_prefix.endswith("azurecr.io"):
elif infra_config().cloud_provider == "azure":
docker_repository = ACRDockerRepository()
elif infra_config().cloud_provider == "onprem":
docker_repository = OnPremDockerRepository()
else:
docker_repository = ECRDockerRepository()

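The queue delegate, task-queue gateway, and docker repository are now all selected by `infra_config().cloud_provider` rather than by inspecting unrelated settings (note the `docker_repo_prefix.endswith("azurecr.io")` check replaced above). A minimal runnable sketch of the dispatch, with strings standing in for the real delegate classes:

```python
from dataclasses import dataclass


@dataclass
class InfraConfig:
    cloud_provider: str  # "aws" | "azure" | "onprem"


def select_queue_delegate(config: InfraConfig, circleci: bool = False) -> str:
    # Mirrors the chain in _get_external_interfaces; the strings stand in for
    # FakeQueueEndpointResourceDelegate, ASBQueueEndpointResourceDelegate,
    # OnPremQueueEndpointResourceDelegate, and SQSQueueEndpointResourceDelegate.
    if circleci:
        return "FakeQueueEndpointResourceDelegate"
    if config.cloud_provider == "azure":
        return "ASBQueueEndpointResourceDelegate"
    if config.cloud_provider == "onprem":
        return "OnPremQueueEndpointResourceDelegate"
    return "SQSQueueEndpointResourceDelegate"


assert select_queue_delegate(InfraConfig("onprem")) == "OnPremQueueEndpointResourceDelegate"
```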
20 changes: 14 additions & 6 deletions model-engine/model_engine_server/common/config.py
@@ -90,21 +90,29 @@ def from_yaml(cls, yaml_path):

@property
def cache_redis_url(self) -> str:
cloud_provider = infra_config().cloud_provider

if cloud_provider == "onprem":
if self.cache_redis_aws_url:
logger.info("On-prem deployment using cache_redis_aws_url")
return self.cache_redis_aws_url
redis_host = os.getenv("REDIS_HOST", "redis")
redis_port = getattr(infra_config(), "redis_port", 6379)
return f"redis://{redis_host}:{redis_port}/0"

if self.cache_redis_aws_url:
assert infra_config().cloud_provider == "aws", "cache_redis_aws_url is only for AWS"
assert cloud_provider == "aws", "cache_redis_aws_url is only for AWS"
if self.cache_redis_aws_secret_name:
logger.warning(
"Both cache_redis_aws_url and cache_redis_aws_secret_name are set. Using cache_redis_aws_url"
)
return self.cache_redis_aws_url
elif self.cache_redis_aws_secret_name:
assert (
infra_config().cloud_provider == "aws"
), "cache_redis_aws_secret_name is only for AWS"
creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role
assert cloud_provider == "aws", "cache_redis_aws_secret_name is only for AWS"
creds = get_key_file(self.cache_redis_aws_secret_name)
return creds["cache-url"]

assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure"
assert self.cache_redis_azure_host and cloud_provider == "azure"
username = os.getenv("AZURE_OBJECT_ID")
token = DefaultAzureCredential().get_token("https://redis.azure.com/.default")
password = token.token
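A standalone sketch of the new on-prem branch of `cache_redis_url` (behavior as in the diff; the helper name is illustrative, the real property lives on the hmi_config object): an explicitly configured `cache_redis_aws_url` still wins, otherwise the URL is assembled from `REDIS_HOST` and the configured port.

```python
import os
from typing import Optional


def onprem_cache_redis_url(cache_redis_aws_url: Optional[str], redis_port: int = 6379) -> str:
    # An explicitly configured URL takes precedence, even on-prem.
    if cache_redis_aws_url:
        return cache_redis_aws_url
    redis_host = os.getenv("REDIS_HOST", "redis")
    return f"redis://{redis_host}:{redis_port}/0"


os.environ["REDIS_HOST"] = "redis.company.local"
print(onprem_cache_redis_url(None))  # redis://redis.company.local:6379/0
```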
17 changes: 15 additions & 2 deletions model-engine/model_engine_server/common/io.py
@@ -10,12 +10,11 @@

def open_wrapper(uri: str, mode: str = "rt", **kwargs):
client: Any
cloud_provider: str
# This follows the 5.1.0 smart_open API
try:
cloud_provider = infra_config().cloud_provider
except Exception:
cloud_provider = "aws"

if cloud_provider == "azure":
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
@@ -24,6 +23,20 @@ def open_wrapper(uri: str, mode: str = "rt", **kwargs):
f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net",
DefaultAzureCredential(),
)
elif cloud_provider == "onprem":
session = boto3.Session()
client_kwargs = {}

s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv(
"S3_ENDPOINT_URL"
)
if s3_endpoint:
client_kwargs["endpoint_url"] = s3_endpoint

addressing_style = getattr(infra_config(), "s3_addressing_style", "path")
client_kwargs["config"] = boto3.session.Config(s3={"addressing_style": addressing_style})

client = session.client("s3", **client_kwargs)
else:
profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE"))
session = boto3.Session(profile_name=profile_name)
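The onprem branch builds a plain boto3 client pointed at an S3-compatible endpoint. A hedged sketch of the same construction outside `open_wrapper` (the endpoint value is an assumed example):

```python
import os

import boto3
from botocore.config import Config

# Path-style addressing (http://host:9000/bucket/key rather than
# bucket.host:9000) is what MinIO expects by default.
endpoint_url = os.getenv("S3_ENDPOINT_URL", "http://minio-service:9000")  # assumed example endpoint
client = boto3.Session().client(
    "s3",
    endpoint_url=endpoint_url,
    config=Config(s3={"addressing_style": "path"}),
)
# Credentials resolve through boto3's standard chain (AWS_ACCESS_KEY_ID /
# AWS_SECRET_ACCESS_KEY environment variables, shared config files, ...).
```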
31 changes: 21 additions & 10 deletions model-engine/model_engine_server/core/celery/app.py
@@ -531,17 +531,28 @@ def _get_backend_url_and_conf(
backend_url = get_redis_endpoint(1)
elif backend_protocol == "s3":
backend_url = "s3://"
if aws_role is None:
aws_session = session(infra_config().profile_ml_worker)
if infra_config().cloud_provider == "aws":
if aws_role is None:
aws_session = session(infra_config().profile_ml_worker)
else:
aws_session = session(aws_role)
out_conf_changes.update(
{
"s3_boto3_session": aws_session,
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
else:
aws_session = session(aws_role)
out_conf_changes.update(
{
"s3_boto3_session": aws_session,
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
logger.info(
"Non-AWS deployment, using environment variables for S3 backend credentials"
)
out_conf_changes.update(
{
"s3_bucket": s3_bucket,
"s3_base_path": s3_base_path,
}
)
elif backend_protocol == "abs":
backend_url = f"azureblockblob://{os.getenv('ABS_ACCOUNT_NAME')}"
else:
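Because the non-AWS branch no longer injects an `s3_boto3_session`, Celery's S3 result backend falls back to boto3's default credential chain. A hedged sketch of what a MinIO-backed worker configuration could look like (bucket, path, and endpoint values are assumptions; the setting names are standard Celery S3-backend options):

```python
# Export credentials before the worker starts, e.g.:
#   export AWS_ACCESS_KEY_ID=minio-user
#   export AWS_SECRET_ACCESS_KEY=minio-secret
from celery import Celery

app = Celery("ml", broker="redis://redis:6379/0", backend="s3://")
app.conf.update(
    s3_bucket="model-engine",          # assumed bucket
    s3_base_path="celery-results/",    # assumed prefix
    s3_endpoint_url="http://minio-service:9000",  # assumed MinIO endpoint
)
```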
72 changes: 72 additions & 0 deletions model-engine/model_engine_server/core/configs/onprem.yaml
@@ -0,0 +1,72 @@
# On-premise deployment configuration
# This configuration file provides defaults for on-prem deployments
# Many values can be overridden via environment variables

cloud_provider: "onprem"
env: "production" # Can be: production, staging, development, local
k8s_cluster_name: "onprem-cluster"
dns_host_domain: "ml.company.local"
default_region: "us-east-1" # Placeholder for compatibility with cloud-agnostic code

# ====================
# Object Storage (MinIO/S3-compatible)
# ====================
s3_bucket: "model-engine"
# S3 endpoint URL - can be overridden by S3_ENDPOINT_URL env var
# Examples: "https://minio.company.local", "http://minio-service:9000"
s3_endpoint_url: "" # Set via S3_ENDPOINT_URL env var if not specified here
# MinIO requires path-style addressing (bucket in URL path, not subdomain)
s3_addressing_style: "path"

# ====================
# Redis Configuration
# ====================
# Redis is used for:
# - Celery task queue broker
# - Model endpoint caching
# - Inference autoscaling metrics
redis_host: "" # Set via REDIS_HOST env var (e.g., "redis.company.local" or "redis-service")
redis_port: 6379
# Whether to use Redis as Celery broker (true for on-prem)
celery_broker_type_redis: true

# ====================
# Celery Configuration
# ====================
# Backend protocol: "redis" for on-prem (not "s3" or "abs")
celery_backend_protocol: "redis"

# ====================
# Database Configuration
# ====================
# Database connection settings (credentials from environment variables)
# DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
db_host: "postgres" # Default hostname, can be overridden by DB_HOST env var
db_port: 5432
db_name: "llm_engine"
db_engine_pool_size: 20
db_engine_max_overflow: 10
db_engine_echo: false
db_engine_echo_pool: false
db_engine_disconnect_strategy: "pessimistic"

# ====================
# Docker Registry Configuration
# ====================
# Docker registry prefix for container images
# Examples: "registry.company.local", "harbor.company.local/ml-platform"
# Leave empty if using full image paths directly
docker_repo_prefix: "registry.company.local"

# ====================
# Monitoring & Observability
# ====================
# Prometheus server address for metrics (optional)
# prometheus_server_address: "http://prometheus:9090"

# ====================
# Not applicable for on-prem (kept for compatibility)
# ====================
ml_account_id: "onprem"
profile_ml_worker: "default"
profile_ml_inference_worker: "default"
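A hedged sketch of how the environment-variable overrides described in the comments above can resolve at load time (the real loader is `infra_config()`; the file path and precedence shown here are illustrative):

```python
import os

import yaml

with open("model-engine/model_engine_server/core/configs/onprem.yaml") as f:
    config = yaml.safe_load(f)

# Environment variables take precedence; the empty-string YAML defaults
# (e.g. s3_endpoint_url, redis_host) fall through to the next candidate.
s3_endpoint = os.getenv("S3_ENDPOINT_URL") or config.get("s3_endpoint_url") or None
redis_host = os.getenv("REDIS_HOST") or config.get("redis_host") or "redis"
db_host = os.getenv("DB_HOST", config.get("db_host", "postgres"))
print(s3_endpoint, redis_host, db_host)
```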
12 changes: 11 additions & 1 deletion model-engine/model_engine_server/db/base.py
@@ -59,7 +59,17 @@ def get_engine_url(
key_file = get_key_file_name(env) # type: ignore
logger.debug(f"Using key file {key_file}")

if infra_config().cloud_provider == "azure":
if infra_config().cloud_provider == "onprem":
user = os.environ.get("DB_USER", "postgres")
password = os.environ.get("DB_PASSWORD", "postgres")
host = os.environ.get("DB_HOST_RO") or os.environ.get("DB_HOST", "localhost")
port = os.environ.get("DB_PORT", "5432")
dbname = os.environ.get("DB_NAME", "llm_engine")
logger.info(f"Connecting to db {host}:{port}, name {dbname}")

engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"

elif infra_config().cloud_provider == "azure":
client = SecretClient(
vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net",
credential=DefaultAzureCredential(),
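The on-prem branch assembles a plain `postgresql://` URL from environment variables instead of fetching credentials from a cloud secret store. A sketch of the equivalent standalone construction, including the `pool_pre_ping` flag that corresponds to the `pessimistic` disconnect strategy in onprem.yaml (the `create_engine` call is illustrative; the real engine setup lives elsewhere in db/base.py):

```python
import os

from sqlalchemy import create_engine

user = os.environ.get("DB_USER", "postgres")
password = os.environ.get("DB_PASSWORD", "postgres")
# A read-only host may be supplied separately; otherwise fall back to DB_HOST.
host = os.environ.get("DB_HOST_RO") or os.environ.get("DB_HOST", "localhost")
port = os.environ.get("DB_PORT", "5432")
dbname = os.environ.get("DB_NAME", "llm_engine")

engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
engine = create_engine(engine_url, pool_pre_ping=True)  # pessimistic disconnect handling
```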
@@ -71,10 +71,18 @@ def validate_fields_present_for_framework_type(cls, field_values):
"type was selected."
)
else: # field_values["framework_type"] == ModelBundleFramework.CUSTOM:
assert field_values["ecr_repo"] and field_values["image_tag"], (
"Expected `ecr_repo` and `image_tag` to be non-null because the custom framework "
assert field_values["image_tag"], (
"Expected `image_tag` to be non-null because the custom framework "
"type was selected."
)
if not field_values.get("ecr_repo"):
from model_engine_server.core.config import infra_config

if infra_config().cloud_provider != "onprem":
raise ValueError(
"Expected `ecr_repo` to be non-null for custom framework. "
"For on-prem deployments, ecr_repo can be omitted to use direct image references."
)
return field_values

model_config = ConfigDict(from_attributes=True)
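The relaxed rule in plain form: `image_tag` is always required for custom bundles, while `ecr_repo` may be omitted only on-prem, where images are referenced directly. A standalone sketch of the same check (the function name is illustrative):

```python
def validate_custom_bundle(field_values: dict, cloud_provider: str) -> dict:
    assert field_values.get("image_tag"), (
        "Expected `image_tag` to be non-null because the custom framework type was selected."
    )
    if not field_values.get("ecr_repo") and cloud_provider != "onprem":
        raise ValueError(
            "Expected `ecr_repo` to be non-null for custom framework. "
            "For on-prem deployments, ecr_repo can be omitted to use direct image references."
        )
    return field_values


validate_custom_bundle({"image_tag": "v1.2.3"}, cloud_provider="onprem")  # passes
# validate_custom_bundle({"image_tag": "v1.2.3"}, cloud_provider="aws")   # raises ValueError
```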
7 changes: 6 additions & 1 deletion model-engine/model_engine_server/entrypoints/k8s_cache.py
@@ -42,6 +42,9 @@
ECRDockerRepository,
FakeDockerRepository,
)
from model_engine_server.infra.repositories.onprem_docker_repository import (
OnPremDockerRepository,
)
from model_engine_server.infra.repositories.db_model_endpoint_record_repository import (
DbModelEndpointRecordRepository,
)
@@ -124,8 +127,10 @@ async def main(args: Any):
docker_repo: DockerRepository
if CIRCLECI:
docker_repo = FakeDockerRepository()
elif infra_config().docker_repo_prefix.endswith("azurecr.io"):
elif infra_config().cloud_provider == "azure":
docker_repo = ACRDockerRepository()
elif infra_config().cloud_provider == "onprem":
docker_repo = OnPremDockerRepository()
else:
docker_repo = ECRDockerRepository()
while True:
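This selection chain is now duplicated verbatim in dependencies.py and k8s_cache.py. A hypothetical shared factory (not part of this PR; import paths are assumptions) would keep the two call sites from drifting:

```python
from model_engine_server.core.config import infra_config
from model_engine_server.domain.repositories import DockerRepository
from model_engine_server.infra.repositories import (
    ACRDockerRepository,
    ECRDockerRepository,
    FakeDockerRepository,
    OnPremDockerRepository,
)


def make_docker_repository(circleci: bool = False) -> DockerRepository:
    # Same ordering as both call sites: test stub first, then provider dispatch.
    if circleci:
        return FakeDockerRepository()
    if infra_config().cloud_provider == "azure":
        return ACRDockerRepository()
    if infra_config().cloud_provider == "onprem":
        return OnPremDockerRepository()
    return ECRDockerRepository()
```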