Skip to content

Commit 09431f3

Browse files
committed
Add custom model endpoint management system
- Add CustomEndpointManager for CRUD operations on custom endpoints
- Support for all providers: CAII, Bedrock, OpenAI, OpenAI Compatible, Gemini
- Complete REST API for endpoint management (POST/GET/PUT/DELETE)
- Integration with existing model handlers for automatic credential lookup
- Move custom_endpoint_manager to app/core/ for better organization
- Add OpenAI_Endpoint_Compatible_Key to environment variables
- Fix ImportError: replace Example_eval with EvaluationExample
- Restore missing fields: max_concurrent_topics, max_workers, Example_eval, etc.
- Add comprehensive API documentation and curl examples
1 parent 06537b4 commit 09431f3

File tree

8 files changed

+667
-19
lines changed

8 files changed

+667
-19
lines changed

.project-metadata.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ environment_variables:
4242
default: null
4343
description: >-
4444
Gemini API Key. Check the Google Gemini documentation for information about role access
45+
OpenAI_Endpoint_Compatible_Key:
46+
default: null
47+
description: >-
48+
API Key for OpenAI Compatible endpoints. Used for custom OpenAI-compatible model endpoints.
4549
# runtimes
4650
runtimes:
4751
- editor: JupyterLab
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import json
2+
import os
3+
import uuid
4+
from datetime import datetime, timezone
5+
from typing import List, Dict, Optional, Any
6+
from pathlib import Path
7+
8+
from app.models.request_models import (
9+
CustomEndpoint, CustomCAIIEndpoint, CustomBedrockEndpoint,
10+
CustomOpenAIEndpoint, CustomOpenAICompatibleEndpoint, CustomGeminiEndpoint
11+
)
12+
from app.core.exceptions import APIError
13+
14+
15+
class CustomEndpointManager:
    """
    Manager for custom model endpoints persisted in a JSON file

    The file holds one JSON object with the keys ``version``, ``endpoints``
    (a mapping of endpoint ID to that endpoint's dumped fields),
    ``created_at`` and ``last_updated``. Every operation re-reads the file;
    no endpoint data is cached on the instance.
    """

    def __init__(self, config_file: str = "custom_model_endpoints.json"):
        """
        Initialize the custom endpoint manager

        Args:
            config_file: Path to the JSON file storing custom endpoints
        """
        self.config_file = Path(config_file)
        self._ensure_config_file_exists()

    def _ensure_config_file_exists(self):
        """Create the configuration file with an empty endpoint map if it is missing"""
        if self.config_file.exists():
            return

        # A nested config path must work on first run, so create missing parents.
        self.config_file.parent.mkdir(parents=True, exist_ok=True)

        # One timestamp for both fields so a fresh file is internally consistent.
        now = datetime.now(timezone.utc).isoformat()
        initial_config = {
            "version": "1.0",
            "endpoints": {},
            "created_at": now,
            "last_updated": now,
        }
        with open(self.config_file, 'w', encoding="utf-8") as f:
            json.dump(initial_config, f, indent=2)

    def _load_config(self) -> Dict[str, Any]:
        """
        Load configuration from file

        Returns:
            The parsed configuration; ``"endpoints"`` is always present and a dict

        Raises:
            APIError: (500) if the file cannot be read, is not valid JSON, or
                does not have the expected object structure
        """
        try:
            with open(self.config_file, 'r', encoding="utf-8") as f:
                config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
            raise APIError(f"Failed to load custom endpoints configuration: {str(e)}", 500) from e

        # Every caller indexes config["endpoints"]; reject shapes that would
        # otherwise surface later as a bare KeyError/TypeError.
        if not isinstance(config, dict) or not isinstance(config.get("endpoints", {}), dict):
            raise APIError(
                "Failed to load custom endpoints configuration: unexpected file structure", 500
            )
        config.setdefault("endpoints", {})
        return config

    def _save_config(self, config: Dict[str, Any]):
        """
        Save configuration to file, refreshing ``last_updated``

        Args:
            config: The full configuration object to persist

        Raises:
            APIError: (500) if serialization or writing fails
        """
        try:
            config["last_updated"] = datetime.now(timezone.utc).isoformat()
            # Serialize before opening the file: opening with 'w' truncates it,
            # so a non-serializable value must fail here, while the previously
            # stored endpoints are still intact on disk.
            payload = json.dumps(config, indent=2)
            with open(self.config_file, 'w', encoding="utf-8") as f:
                f.write(payload)
        except Exception as e:
            raise APIError(f"Failed to save custom endpoints configuration: {str(e)}", 500) from e

    def add_endpoint(self, endpoint: "CustomEndpoint") -> str:
        """
        Add a new custom endpoint

        Sets ``created_at`` and ``updated_at`` on the passed-in endpoint to
        the current UTC time before storing it.

        Args:
            endpoint: Custom endpoint configuration

        Returns:
            endpoint_id: The ID of the added endpoint

        Raises:
            APIError: (400) if an endpoint with the same ID already exists,
                (500) if the configuration cannot be loaded or saved
        """
        config = self._load_config()

        # Check if endpoint_id already exists
        if endpoint.endpoint_id in config["endpoints"]:
            raise APIError(f"Endpoint with ID '{endpoint.endpoint_id}' already exists", 400)

        # Add timestamps
        now = datetime.now(timezone.utc).isoformat()
        endpoint.created_at = now
        endpoint.updated_at = now

        # Store endpoint configuration
        config["endpoints"][endpoint.endpoint_id] = endpoint.model_dump()

        self._save_config(config)
        return endpoint.endpoint_id

    def get_endpoint(self, endpoint_id: str) -> Optional["CustomEndpoint"]:
        """
        Get a specific custom endpoint by ID

        Args:
            endpoint_id: The endpoint ID to retrieve

        Returns:
            CustomEndpoint or None if not found

        Raises:
            APIError: If the configuration cannot be loaded or the stored
                entry cannot be parsed
        """
        config = self._load_config()
        endpoint_data = config["endpoints"].get(endpoint_id)

        if not endpoint_data:
            return None

        return self._parse_endpoint(endpoint_data)

    def get_all_endpoints(self) -> List["CustomEndpoint"]:
        """
        Get all custom endpoints

        Entries that fail to parse are skipped with a printed warning, so one
        bad entry does not hide the others.

        Returns:
            List of all custom endpoints that parsed successfully
        """
        config = self._load_config()
        endpoints = []

        for endpoint_data in config["endpoints"].values():
            try:
                endpoint = self._parse_endpoint(endpoint_data)
                endpoints.append(endpoint)
            except Exception as e:
                # Deliberate best-effort: report and keep going.
                print(f"Warning: Failed to parse endpoint {endpoint_data.get('endpoint_id', 'unknown')}: {e}")
                continue

        return endpoints

    def get_endpoints_by_provider(self, provider_type: str) -> List["CustomEndpoint"]:
        """
        Get all endpoints for a specific provider

        Args:
            provider_type: The provider type to filter by

        Returns:
            List of endpoints for the specified provider
        """
        all_endpoints = self.get_all_endpoints()
        return [ep for ep in all_endpoints if ep.provider_type == provider_type]

    def update_endpoint(self, endpoint_id: str, updated_endpoint: "CustomEndpoint") -> bool:
        """
        Update an existing custom endpoint

        The stored entry is replaced wholesale. The original ``created_at`` is
        carried over, ``updated_at`` is set to the current UTC time, and the
        endpoint's ``endpoint_id`` is forced to ``endpoint_id``. These fields
        are written onto the passed-in ``updated_endpoint`` object.

        Args:
            endpoint_id: The endpoint ID to update
            updated_endpoint: Updated endpoint configuration

        Returns:
            True if updated successfully, False if endpoint not found

        Raises:
            APIError: If the configuration cannot be loaded or saved
        """
        config = self._load_config()

        if endpoint_id not in config["endpoints"]:
            return False

        # Preserve original created_at timestamp
        original_created_at = config["endpoints"][endpoint_id].get("created_at")

        # Update timestamps
        updated_endpoint.created_at = original_created_at
        updated_endpoint.updated_at = datetime.now(timezone.utc).isoformat()
        updated_endpoint.endpoint_id = endpoint_id  # Ensure ID consistency

        # Update endpoint configuration
        config["endpoints"][endpoint_id] = updated_endpoint.model_dump()

        self._save_config(config)
        return True

    def delete_endpoint(self, endpoint_id: str) -> bool:
        """
        Delete a custom endpoint

        Args:
            endpoint_id: The endpoint ID to delete

        Returns:
            True if deleted successfully, False if endpoint not found

        Raises:
            APIError: If the configuration cannot be loaded or saved
        """
        config = self._load_config()

        if endpoint_id not in config["endpoints"]:
            return False

        del config["endpoints"][endpoint_id]
        self._save_config(config)
        return True

    def _parse_endpoint(self, endpoint_data: Dict[str, Any]) -> "CustomEndpoint":
        """
        Parse endpoint data into appropriate CustomEndpoint subclass

        Args:
            endpoint_data: Raw endpoint data from config

        Returns:
            Parsed CustomEndpoint instance

        Raises:
            APIError: (400) if ``provider_type`` is not a known provider,
                (500) if the matching endpoint class rejects the data
        """
        provider_type = endpoint_data.get("provider_type")

        # Built per call (not at class level) so the endpoint classes are only
        # looked up when parsing actually happens.
        endpoint_classes = {
            "caii": CustomCAIIEndpoint,
            "bedrock": CustomBedrockEndpoint,
            "openai": CustomOpenAIEndpoint,
            "openai_compatible": CustomOpenAICompatibleEndpoint,
            "gemini": CustomGeminiEndpoint,
        }
        endpoint_cls = (
            endpoint_classes.get(provider_type) if isinstance(provider_type, str) else None
        )

        # Raised outside the try below: inside it, the generic handler would
        # catch this error and re-wrap it as a 500.
        if endpoint_cls is None:
            raise APIError(f"Unknown provider type: {provider_type}", 400)

        try:
            return endpoint_cls(**endpoint_data)
        except Exception as e:
            raise APIError(f"Failed to parse endpoint configuration: {str(e)}", 500) from e

    def validate_endpoint_id(self, endpoint_id: str) -> bool:
        """
        Validate endpoint ID format

        A valid ID is a non-empty string made only of ASCII letters, digits,
        hyphens and underscores.

        Args:
            endpoint_id: The endpoint ID to validate

        Returns:
            True if valid, False otherwise
        """
        if not endpoint_id or not isinstance(endpoint_id, str):
            return False

        # Allow alphanumeric, hyphens, and underscores.
        # fullmatch, not match with '$': '$' also matches just before a
        # trailing newline, which would let "abc\n" through.
        import re
        return re.fullmatch(r'[a-zA-Z0-9_-]+', endpoint_id) is not None

    def get_endpoint_stats(self) -> Dict[str, Any]:
        """
        Get statistics about custom endpoints

        Only endpoints that parse successfully are counted.

        Returns:
            Dictionary with ``total_endpoints``, ``provider_counts``
            (provider type -> number of endpoints) and ``endpoint_ids``
        """
        endpoints = self.get_all_endpoints()

        provider_counts = {}
        for endpoint in endpoints:
            provider_type = endpoint.provider_type
            provider_counts[provider_type] = provider_counts.get(provider_type, 0) + 1

        return {
            "total_endpoints": len(endpoints),
            "provider_counts": provider_counts,
            "endpoint_ids": [ep.endpoint_id for ep in endpoints]
        }

app/core/model_endpoints.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,10 @@ async def bound(p: _CaiiPair):
324324
# Single orchestrator used by the api endpoint
325325
# ────────────────────────────────────────────────────────────────
326326
async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]:
327-
"""Collect and health-check models from all providers."""
327+
"""Collect and health-check models from all providers, including custom endpoints."""
328+
329+
# Import here to avoid circular imports
330+
from app.core.custom_endpoint_manager import CustomEndpointManager
328331

329332
# Bedrock
330333
bedrock_all = list_bedrock_models()
@@ -350,6 +353,10 @@ async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]:
350353
"google_gemini": {
351354
"enabled": gemini_enabled,
352355
"disabled": gemini_disabled,
356+
},
357+
"openai_compatible": {
358+
"enabled": [],
359+
"disabled": [],
353360
}
354361
}
355362

@@ -364,5 +371,51 @@ async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]:
364371
else:
365372
catalog["CAII"] = {}
366373

374+
# Add custom endpoints
375+
try:
376+
custom_manager = CustomEndpointManager()
377+
custom_endpoints = custom_manager.get_all_endpoints()
378+
379+
for endpoint in custom_endpoints:
380+
provider_key = _get_catalog_key_for_provider(endpoint.provider_type)
381+
382+
if provider_key not in catalog:
383+
catalog[provider_key] = {"enabled": [], "disabled": []}
384+
385+
# For now, assume custom endpoints are enabled (we could add health checks later)
386+
if endpoint.provider_type in ["caii"]:
387+
# CAII format: {"model": name, "endpoint": url}
388+
catalog[provider_key]["enabled"].append({
389+
"model": endpoint.model_id,
390+
"endpoint": getattr(endpoint, 'endpoint_url', ''),
391+
"custom": True,
392+
"endpoint_id": endpoint.endpoint_id,
393+
"display_name": endpoint.display_name
394+
})
395+
else:
396+
# Other providers: just the model name with custom metadata
397+
catalog[provider_key]["enabled"].append({
398+
"model": endpoint.model_id,
399+
"custom": True,
400+
"endpoint_id": endpoint.endpoint_id,
401+
"display_name": endpoint.display_name,
402+
"provider_type": endpoint.provider_type
403+
})
404+
405+
except Exception as e:
406+
print(f"Warning: Failed to load custom endpoints: {e}")
407+
367408
return catalog
368409

410+
411+
def _get_catalog_key_for_provider(provider_type: str) -> str:
412+
"""Map provider types to catalog keys"""
413+
mapping = {
414+
"bedrock": "aws_bedrock",
415+
"openai": "openai",
416+
"openai_compatible": "openai_compatible",
417+
"gemini": "google_gemini",
418+
"caii": "CAII"
419+
}
420+
return mapping.get(provider_type, provider_type)
421+

0 commit comments

Comments
 (0)