Skip to content

Commit 847cf88

Browse files
Khauneesh-AI and Keivan Vosoughi
authored and committed
CDP token for CAII inference
1 parent bf22fed commit 847cf88

File tree

4 files changed

+70
-45
lines changed

4 files changed

+70
-45
lines changed

.project-metadata.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ environment_variables:
3030
default: "your huggingface username"
3131
description: >-
3232
hf_username
33+
CDP_TOKEN:
34+
default: "API key for Cloudera AI Inference"
35+
description: >-
36+
CDP_TOKEN
3337
3438

3539

@@ -69,7 +73,7 @@ tasks:
6973
script: build/build_client.py
7074
arguments: None
7175
cpu: 2
72-
memory: 2
76+
memory: 4
7377
short_summary: Create job to build client application
7478
environment:
7579
TASK_TYPE: CREATE/RUN_JOB

app/core/config.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
import requests
66
import json
77
from fastapi.responses import JSONResponse
8+
import os
9+
from pathlib import Path
10+
from dotenv import load_dotenv
11+
load_dotenv()
812

913
class UseCase(str, Enum):
1014
CODE_GENERATION = "code_generation"
@@ -281,18 +285,55 @@ def get_examples_for_topic(use_case: UseCase, topic: str) -> List[Dict[str, str]
281285
}
282286
}
283287

288+
JWT_PATH = Path("/tmp/jwt")
289+
290+
def _get_caii_token() -> str:
291+
if (tok := os.getenv("CDP_TOKEN")):
292+
return tok
293+
try:
294+
payload = json.loads(open(JWT_PATH).read())
295+
except FileNotFoundError:
296+
raise HTTPException(
297+
status_code=status.HTTP_401_UNAUTHORIZED,
298+
detail="No CDP_TOKEN env‑var and no /tmp/jwt file")
299+
except json.JSONDecodeError:
300+
raise HTTPException(
301+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
302+
detail="Malformed /tmp/jwt")
303+
304+
if not (tok := payload.get("access_token")):
305+
raise HTTPException(
306+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
307+
detail="access_token missing in /tmp/jwt")
308+
return tok
309+
310+
def caii_check(endpoint: str, timeout: int = 3) -> requests.Response:
    """
    Return the GET /models response if everything is healthy.
    Raise HTTPException on *any* problem.
    """
    if not endpoint:
        raise HTTPException(400, "CAII endpoint not provided")

    bearer = _get_caii_token()
    # Callers may pass the chat endpoint; probe the sibling /models route.
    models_url = endpoint.removesuffix("/chat/completions") + "/models"
    auth_header = {"Authorization": f"Bearer {bearer}"}

    try:
        resp = requests.get(models_url, headers=auth_header, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        raise HTTPException(503, f"CAII endpoint unreachable: {exc}")

    # Map upstream status codes onto the API's own error contract.
    code = resp.status_code
    if code == 200:
        return resp
    if code in (401, 403):
        raise HTTPException(403, "Token is valid but has no access to this environment")
    if code == 404:
        raise HTTPException(404, "CAII endpoint or resource not found")
    if 500 <= code < 600:
        raise HTTPException(503, "CAII endpoint is downscaled; retry in ~15 min")
    raise HTTPException(code, resp.text)
284338

285-
def caii_check(caii_endpoint):
286-
API_KEY = json.load(open("/tmp/jwt"))["access_token"]
287-
headers = {
288-
"Authorization": f"Bearer {API_KEY}"
289-
}
290-
291-
292-
if caii_endpoint:
293-
caii_endpoint = caii_endpoint.removesuffix('/chat/completions')
294-
caii_endpoint = caii_endpoint + "/models"
295-
response = requests.get(caii_endpoint, headers=headers, timeout=3) # Will raise RequestException if fails
296-
297-
return response
298339

app/core/model_handlers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from openai import OpenAI
1212
from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError, JSONParsingError
1313
from app.core.telemetry_integration import track_llm_operation
14+
from app.core.config import _get_caii_token
1415

1516

1617

@@ -280,7 +281,8 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool):
280281
def _handle_caii_request(self, prompt: str):
281282
"""Original CAII implementation"""
282283
try:
283-
API_KEY = json.load(open("/tmp/jwt"))["access_token"]
284+
#API_KEY = json.load(open("/tmp/jwt"))["access_token"]
285+
API_KEY = _get_caii_token()
284286
MODEL_ID = self.model_id
285287
caii_endpoint = self.caii_endpoint
286288

app/main.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -408,15 +408,10 @@ async def generate_examples(request: SynthesisRequest):
408408
# Generate a request ID
409409
request_id = str(uuid.uuid4())
410410

411-
if request.inference_type== "CAII":
411+
if request.inference_type == "CAII":
412412
caii_endpoint = request.caii_endpoint
413-
response = caii_check(caii_endpoint)
414-
message = "The CAII endpoint you are tring to reach is downscaled, please try after >15 minutes while it autoscales, meanwhile please try another model"
415-
if response.status_code != 200:
416-
return JSONResponse(
417-
status_code=503, # Service Unavailable
418-
content={"status": "failed", "error": message}
419-
)
413+
414+
caii_check(caii_endpoint)
420415

421416

422417
is_demo = request.is_demo
@@ -464,13 +459,7 @@ async def generate_freeform_data(request: SynthesisRequest):
464459

465460
if request.inference_type == "CAII":
466461
caii_endpoint = request.caii_endpoint
467-
response = caii_check(caii_endpoint)
468-
message = "The CAII endpoint you are trying to reach is downscaled, please try after >15 minutes while it autoscales, meanwhile please try another model"
469-
if response.status_code != 200:
470-
return JSONResponse(
471-
status_code=503, # Service Unavailable
472-
content={"status": "failed", "error": message}
473-
)
462+
caii_check(caii_endpoint)
474463

475464
is_demo = request.is_demo
476465
mem = 4
@@ -514,15 +503,9 @@ async def evaluate_examples(request: EvaluationRequest):
514503
"""Evaluate generated QA pairs"""
515504
request_id = str(uuid.uuid4())
516505

517-
if request.inference_type== "CAII":
506+
if request.inference_type == "CAII":
518507
caii_endpoint = request.caii_endpoint
519-
response = caii_check(caii_endpoint)
520-
message = "The CAII endpoint you are tring to reach is downscaled, please try after >15 minutes while it autoscales, meanwhile please try another model"
521-
if response.status_code != 200:
522-
return JSONResponse(
523-
status_code=503, # Service Unavailable
524-
content={"status": "failed", "error": message}
525-
)
508+
caii_check(caii_endpoint)
526509

527510
is_demo = request.is_demo
528511
if is_demo:
@@ -541,13 +524,8 @@ async def evaluate_freeform(request: EvaluationRequest):
541524

542525
if request.inference_type == "CAII":
543526
caii_endpoint = request.caii_endpoint
544-
response = caii_check(caii_endpoint)
545-
message = "The CAII endpoint you are trying to reach is downscaled, please try after >15 minutes while it autoscales, meanwhile please try another model"
546-
if response.status_code != 200:
547-
return JSONResponse(
548-
status_code=503, # Service Unavailable
549-
content={"status": "failed", "error": message}
550-
)
527+
caii_check(caii_endpoint)
528+
551529

552530
is_demo = getattr(request, 'is_demo', True)
553531
if is_demo:

0 commit comments

Comments
 (0)