From 48b96af08535fb7f2b139a89468b803a6c6bbd9d Mon Sep 17 00:00:00 2001
From: Keivan Vosoughi
Date: Thu, 28 Aug 2025 11:16:03 -0700
Subject: [PATCH 01/12] fix: resolve TypeScript errors in TemplateCard and update Finish component for DSE-47182

---
 app/client/src/pages/DataGenerator/Finish.tsx | 19 ++++++++++++++++---
 app/client/src/pages/Home/TemplateCard.tsx    | 10 +++++-----
 2 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/app/client/src/pages/DataGenerator/Finish.tsx b/app/client/src/pages/DataGenerator/Finish.tsx
index b54e5fee..60f7d4de 100644
--- a/app/client/src/pages/DataGenerator/Finish.tsx
+++ b/app/client/src/pages/DataGenerator/Finish.tsx
@@ -255,7 +255,8 @@ const Finish = () => {
       title: 'Review Dataset',
       description: 'Review your dataset to ensure it properly fits your usecase.',
       icon: ,
-      href: getFilesURL(genDatasetResp?.export_path?.local || "")
+      href: getFilesURL(genDatasetResp?.export_path?.local || ""),
+      external: true
     },
     {
       avatar: '',
@@ -278,7 +279,8 @@ const Finish = () => {
       title: 'Review Dataset',
       description: 'Once your dataset finishes generating, you can review your dataset in the workbench files',
       icon: ,
-      href: getFilesURL('')
+      href: getFilesURL(''),
+      external: true
     },
     {
       avatar: '',
@@ -361,7 +363,18 @@ const Finish = () => {
           (
+          renderItem={({ title, href, icon, description, external }, i) => (
+            external ?
+
+
+              }
+              title={title}
+              description={description}
+            />
+
+            :
+
 props.theme.color};
-  background-color: ${props => props.theme.backgroundColor};
-  border: 1px solid ${props => props.theme.borderColor};
+const StyledTag = styled(Tag)<{ $theme: { color: string; backgroundColor: string; borderColor: string } }>`
+  color: ${props => props.$theme.color} !important;
+  background-color: ${props => props.$theme.backgroundColor} !important;
+  border: 1px solid ${props => props.$theme.borderColor} !important;
 `;
@@ -150,7 +150,7 @@ const TemplateCard: React.FC = ({ template }) => {
   const { color, backgroundColor, borderColor } = getTemplateTagColors(theme as string);
   return (
-
+
     {tag}
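
A note on the TemplateCard hunk above: the fix uses the styled-components transient-prop pattern, in which a prop prefixed with "$" is consumed by the styled wrapper and never forwarded to the underlying Ant Design Tag, and switching from theme to $theme also sidesteps the clash with styled-components' own ThemeProvider theme prop, which is presumably the TypeScript error the subject line refers to. The snippet below is an illustrative sketch, not repository code; the TagTheme interface name and the usage line are assumptions.

    import styled from 'styled-components';
    import { Tag } from 'antd';

    // Shape of the theme object returned by getTemplateTagColors (assumed name).
    interface TagTheme {
      color: string;
      backgroundColor: string;
      borderColor: string;
    }

    // "$theme" is a transient prop: styled-components strips it before rendering
    // the Ant Design Tag, so it never reaches the DOM, and the generic argument
    // gives the template literal a typed props object.
    const StyledTag = styled(Tag)<{ $theme: TagTheme }>`
      color: ${props => props.$theme.color} !important;
      background-color: ${props => props.$theme.backgroundColor} !important;
      border: 1px solid ${props => props.$theme.borderColor} !important;
    `;

    // Assumed usage inside TemplateCard:
    // <StyledTag $theme={{ color, backgroundColor, borderColor }}>{tag}</StyledTag>
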
From 2599758bed920d317489410c53b8715d7aad653b Mon Sep 17 00:00:00 2001
From: Khauneesh Saigal
Date: Thu, 11 Sep 2025 00:02:28 +0530
Subject: [PATCH 02/12] support for openai compatible models

---
 app/core/model_handlers.py   | 62 ++++++++++++++++++++++++++++++++++++
 app/models/request_models.py |  1 +
 2 files changed, 63 insertions(+)

diff --git a/app/core/model_handlers.py b/app/core/model_handlers.py
index 92391b8c..b61a0ec7 100644
--- a/app/core/model_handlers.py
+++ b/app/core/model_handlers.py
@@ -180,6 +180,8 @@ def generate_response(
         return self._handle_caii_request(prompt)
     if self.inference_type == "openai":
         return self._handle_openai_request(prompt)
+    if self.inference_type == "openai_compatible":
+        return self._handle_openai_compatible_request(prompt)
     if self.inference_type == "gemini":
         return self._handle_gemini_request(prompt)
     raise ModelHandlerError(f"Unsupported inference_type={self.inference_type}", 400)
@@ -342,6 +344,66 @@ def _handle_openai_request(self, prompt: str):
         except Exception as e:
             raise ModelHandlerError(f"OpenAI request failed: {e}", 500)

+    # ---------- OpenAI Compatible -------------------------------------------------------
+    def _handle_openai_compatible_request(self, prompt: str):
+        """Handle OpenAI compatible endpoints with proper timeout configuration"""
+        try:
+            import httpx
+            from openai import OpenAI
+
+            # Get API key from environment variable (only credential needed)
+            api_key = os.getenv('OpenAI_Endpoint_Compatible_Key')
+            if not api_key:
+                raise ModelHandlerError("OpenAI_Endpoint_Compatible_Key environment variable not set", 500)
+
+            # Base URL comes from caii_endpoint parameter (passed during initialization)
+            openai_compatible_endpoint = self.caii_endpoint
+            if not openai_compatible_endpoint:
+                raise ModelHandlerError("OpenAI compatible endpoint not provided", 500)
+
+            # Configure timeout for OpenAI compatible client (same as OpenAI v1.57.2)
+            timeout_config = httpx.Timeout(
+                connect=self.OPENAI_CONNECT_TIMEOUT,
+                read=self.OPENAI_READ_TIMEOUT,
+                write=10.0,
+                pool=5.0
+            )
+
+            # Configure httpx client with certificate verification for private cloud
+            if os.path.exists("/etc/ssl/certs/ca-certificates.crt"):
+                http_client = httpx.Client(
+                    verify="/etc/ssl/certs/ca-certificates.crt",
+                    timeout=timeout_config
+                )
+            else:
+                http_client = httpx.Client(timeout=timeout_config)
+
+            # Remove trailing '/chat/completions' if present (similar to CAII handling)
+            openai_compatible_endpoint = openai_compatible_endpoint.removesuffix('/chat/completions')
+
+            client = OpenAI(
+                api_key=api_key,
+                base_url=openai_compatible_endpoint,
+                http_client=http_client
+            )
+
+            completion = client.chat.completions.create(
+                model=self.model_id,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=self.model_params.max_tokens,
+                temperature=self.model_params.temperature,
+                top_p=self.model_params.top_p,
+                stream=False,
+            )
+
+            print("generated via OpenAI Compatible endpoint")
+            response_text = completion.choices[0].message.content
+
+            return self._extract_json_from_text(response_text) if not self.custom_p else response_text
+
+        except Exception as e:
+            raise ModelHandlerError(f"OpenAI Compatible request failed: {str(e)}", status_code=500)
+
     # ---------- Gemini -------------------------------------------------------
     def _handle_gemini_request(self, prompt: str):
         if genai is None:
diff --git a/app/models/request_models.py b/app/models/request_models.py
index eaf9d7e0..7cf9e61d 100644
--- a/app/models/request_models.py
+++ b/app/models/request_models.py
@@ -123,6 +123,7
@@ class SynthesisRequest(BaseModel): # Optional fields that can override defaults inference_type: Optional[str] = "aws_bedrock" caii_endpoint: Optional[str] = None + openai_compatible_endpoint: Optional[str] = None topics: Optional[List[str]] = None doc_paths: Optional[List[str]] = None input_path: Optional[List[str]] = None From 00eb07a337521c77782504f994eb67d9e2a2cf71 Mon Sep 17 00:00:00 2001 From: Khauneesh Saigal Date: Tue, 16 Sep 2025 22:24:28 +0530 Subject: [PATCH 03/12] Refactor: Separate freeform functionality from legacy SFT/Custom_Workflow - Split synthesis_service.py into: * synthesis_service.py (freeform only) * synthesis_legacy_service.py (SFT & Custom_Workflow) - Split evaluator_service.py into: * evaluator_service.py (freeform only) * evaluator_legacy_service.py (SFT & Custom_Workflow) - Updated main.py to route requests to appropriate services - Updated all dependent files (synthesis_job, model_alignment, run scripts) - Created comprehensive test coverage for both legacy and freeform services - Added .gitignore patterns to prevent committing generated data files - Maintained full backward compatibility - all endpoints work as before This refactoring isolates freeform functionality while preserving existing SFT and Custom_Workflow features without breaking changes. --- .gitignore | 10 + app/main.py | 19 +- app/run_eval_job.py | 3 +- app/run_job.py | 3 +- app/services/evaluator_legacy_service.py | 407 ++++++++++ app/services/evaluator_service.py | 349 +-------- app/services/model_alignment.py | 8 +- app/services/synthesis_job.py | 8 +- app/services/synthesis_legacy_service.py | 700 ++++++++++++++++++ app/services/synthesis_service.py | 598 +-------------- tests/integration/test_evaluate_api.py | 10 +- tests/integration/test_evaluate_legacy_api.py | 113 +++ tests/integration/test_synthesis_api.py | 4 +- .../integration/test_synthesis_legacy_api.py | 91 +++ tests/unit/test_evaluator_freeform_service.py | 80 ++ tests/unit/test_evaluator_legacy_service.py | 83 +++ tests/unit/test_evaluator_service.py | 18 +- tests/unit/test_synthesis_freeform_service.py | 70 ++ tests/unit/test_synthesis_legacy_service.py | 56 ++ tests/unit/test_synthesis_service.py | 10 +- 20 files changed, 1661 insertions(+), 979 deletions(-) create mode 100644 app/services/evaluator_legacy_service.py create mode 100644 app/services/synthesis_legacy_service.py create mode 100644 tests/integration/test_evaluate_legacy_api.py create mode 100644 tests/integration/test_synthesis_legacy_api.py create mode 100644 tests/unit/test_evaluator_freeform_service.py create mode 100644 tests/unit/test_evaluator_legacy_service.py create mode 100644 tests/unit/test_synthesis_freeform_service.py create mode 100644 tests/unit/test_synthesis_legacy_service.py diff --git a/.gitignore b/.gitignore index 0549d46a..a4dfe0da 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,16 @@ qa_pairs* Khauneesh/ *job_args* +# Generated data files +freeform_data_*.json +row_data_*.json +lending_*.json +seeds_*.json +SeedsInstructions.json +*_example.json +nm.json +french_input.json + # DB *metadata.db-shm *metadata.db-wal diff --git a/app/main.py b/app/main.py index b81bc05c..e82bd296 100644 --- a/app/main.py +++ b/app/main.py @@ -42,8 +42,10 @@ sys.path.append(str(ROOT_DIR)) from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import SynthesisRequest, EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, 
JsonDataSize, RelativePath, Technique from app.services.synthesis_service import SynthesisService +from app.services.synthesis_legacy_service import SynthesisLegacyService from app.services.export_results import Export_Service from app.core.prompt_templates import PromptBuilder, PromptHandler @@ -66,8 +68,10 @@ #****************************************Initialize************************************************ # Initialize services -synthesis_service = SynthesisService() -evaluator_service = EvaluatorService() +synthesis_service = SynthesisService() # Freeform only +synthesis_legacy_service = SynthesisLegacyService() # SFT and Custom_Workflow +evaluator_service = EvaluatorService() # Freeform only +evaluator_legacy_service = EvaluatorLegacyService() # SFT and Custom_Workflow export_service = Export_Service() db_manager = DatabaseManager() @@ -552,9 +556,11 @@ async def generate_examples(request: SynthesisRequest): if is_demo== True: if request.input_path: - return await synthesis_service.generate_result(request,is_demo, request_id=request_id) + # Custom_Workflow technique - route to legacy service + return await synthesis_legacy_service.generate_result(request,is_demo, request_id=request_id) else: - return await synthesis_service.generate_examples(request,is_demo, request_id=request_id) + # SFT technique - route to legacy service + return await synthesis_legacy_service.generate_examples(request,is_demo, request_id=request_id) else: return synthesis_job.generate_job(request, core, mem, request_id=request_id) @@ -626,7 +632,8 @@ async def evaluate_examples(request: EvaluationRequest): is_demo = request.is_demo if is_demo: - return evaluator_service.evaluate_results(request, request_id=request_id) + # SFT and Custom_Workflow evaluation - route to legacy service + return evaluator_legacy_service.evaluate_results(request, request_id=request_id) else: return synthesis_job.evaluate_job(request, request_id=request_id) @@ -1242,7 +1249,7 @@ def is_empty(self): async def health_check(): """Get API health status""" #return {"status": "healthy"} - return synthesis_service.get_health_check() + return synthesis_legacy_service.get_health_check() @app.get("/{use_case}/example_payloads") async def get_example_payloads(use_case:UseCase): diff --git a/app/run_eval_job.py b/app/run_eval_job.py index 9591cee4..7a60714d 100644 --- a/app/run_eval_job.py +++ b/app/run_eval_job.py @@ -31,6 +31,7 @@ from app.models.request_models import EvaluationRequest, ModelParameters from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService import asyncio import nest_asyncio @@ -40,7 +41,7 @@ async def run_eval(request, job_name, request_id): try: - job = EvaluatorService() + job = EvaluatorLegacyService() result = job.evaluate_results(request,job_name, is_demo=False, request_id=request_id) return result except Exception as e: diff --git a/app/run_job.py b/app/run_job.py index 6d15833a..213c6d0a 100644 --- a/app/run_job.py +++ b/app/run_job.py @@ -32,6 +32,7 @@ import json from app.models.request_models import SynthesisRequest from app.services.synthesis_service import SynthesisService +from app.services.synthesis_legacy_service import SynthesisLegacyService import asyncio import nest_asyncio # Add this import @@ -41,7 +42,7 @@ async def run_synthesis(request, job_name, request_id): """Run standard synthesis job for question-answer pairs""" try: - job = SynthesisService() + job = SynthesisLegacyService() if request.input_path: result = await 
job.generate_result(request, job_name, is_demo=False, request_id=request_id) else: diff --git a/app/services/evaluator_legacy_service.py b/app/services/evaluator_legacy_service.py new file mode 100644 index 00000000..9d0656a4 --- /dev/null +++ b/app/services/evaluator_legacy_service.py @@ -0,0 +1,407 @@ +import boto3 +from typing import Dict, List, Optional, Any +from typing import Dict, List, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed +from app.models.request_models import Example, ModelParameters, EvaluationRequest +from app.core.model_handlers import create_handler +from app.core.prompt_templates import PromptBuilder, PromptHandler +from app.services.aws_bedrock import get_bedrock_client +from app.core.database import DatabaseManager +from app.core.config import UseCase, Technique, get_model_family +from app.services.check_guardrail import ContentGuardrail +from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError +import os +from datetime import datetime, timezone +import json +import logging +from logging.handlers import RotatingFileHandler +from app.core.telemetry_integration import track_llm_operation +from functools import partial + +class EvaluatorLegacyService: + """Legacy service for evaluating generated QA pairs using Claude with parallel processing (SFT and Custom_Workflow only)""" + + def __init__(self, max_workers: int = 4): + self.bedrock_client = get_bedrock_client() + self.db = DatabaseManager() + self.max_workers = max_workers + self.guard = ContentGuardrail() + self._setup_logging() + + def _setup_logging(self): + """Set up logging configuration""" + os.makedirs('logs', exist_ok=True) + + self.logger = logging.getLogger('evaluator_legacy_service') + self.logger.setLevel(logging.INFO) + + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + + # File handler for general logs + file_handler = RotatingFileHandler( + 'logs/evaluator_legacy_service.log', + maxBytes=10*1024*1024, # 10MB + backupCount=5 + ) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + # File handler for errors + error_handler = RotatingFileHandler( + 'logs/evaluator_legacy_service_errors.log', + maxBytes=10*1024*1024, + backupCount=5 + ) + error_handler.setLevel(logging.ERROR) + error_handler.setFormatter(formatter) + self.logger.addHandler(error_handler) + + + #@track_llm_operation("evaluate_single_pair") + def evaluate_single_pair(self, qa_pair: Dict, model_handler, request: EvaluationRequest, request_id=None) -> Dict: + """Evaluate a single QA pair""" + try: + # Default error response + error_response = { + request.output_key: qa_pair.get(request.output_key, "Unknown"), + request.output_value: qa_pair.get(request.output_value, "Unknown"), + "evaluation": { + "score": 0, + "justification": "Error during evaluation" + } + } + + try: + self.logger.info(f"Evaluating QA pair: {qa_pair.get(request.output_key, '')[:50]}...") + except Exception as e: + self.logger.error(f"Error logging QA pair: {str(e)}") + + try: + # Validate input qa_pair structure + if not all(key in qa_pair for key in [request.output_key, request.output_value]): + error_msg = "Missing required keys in qa_pair" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + prompt = PromptBuilder.build_eval_prompt( + request.model_id, + request.use_case, + qa_pair[request.output_key], + qa_pair[request.output_value], + request.examples, + request.custom_prompt + ) + #print(prompt) 
+ except Exception as e: + error_msg = f"Error building evaluation prompt: {str(e)}" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + try: + response = model_handler.generate_response(prompt, request_id=request_id) + except ModelHandlerError as e: + self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") + raise + except Exception as e: + error_msg = f"Error generating model response: {str(e)}" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + if not response: + error_msg = "Failed to parse model response" + self.logger.warning(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + try: + score = response[0].get('score', "no score key") + justification = response[0].get('justification', 'No justification provided') + if score== "no score key": + self.logger.info(f"Unsuccessful QA pair evaluation with score: {score}") + justification = "The evaluated pair did not generate valid score and justification" + score = 0 + else: + self.logger.info(f"Successfully evaluated QA pair with score: {score}") + + return { + "question": qa_pair[request.output_key], + "solution": qa_pair[request.output_value], + "evaluation": { + "score": score, + "justification": justification + } + } + except Exception as e: + error_msg = f"Error processing model response: {str(e)}" + self.logger.error(error_msg) + error_response["evaluation"]["justification"] = error_msg + return error_response + + except ModelHandlerError: + raise + except Exception as e: + self.logger.error(f"Critical error in evaluate_single_pair: {str(e)}") + return error_response + + #@track_llm_operation("evaluate_topic") + def evaluate_topic(self, topic: str, qa_pairs: List[Dict], model_handler, request: EvaluationRequest, request_id=None) -> Dict: + """Evaluate all QA pairs for a given topic in parallel""" + try: + self.logger.info(f"Starting evaluation for topic: {topic} with {len(qa_pairs)} QA pairs") + evaluated_pairs = [] + failed_pairs = [] + + try: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + try: + evaluate_func = partial( + self.evaluate_single_pair, + model_handler=model_handler, + request=request, request_id=request_id + ) + + future_to_pair = { + executor.submit(evaluate_func, pair): pair + for pair in qa_pairs + } + + for future in as_completed(future_to_pair): + try: + result = future.result() + evaluated_pairs.append(result) + except ModelHandlerError: + raise + except Exception as e: + error_msg = f"Error processing future result: {str(e)}" + self.logger.error(error_msg) + failed_pairs.append({ + "error": error_msg, + "pair": future_to_pair[future] + }) + + except Exception as e: + error_msg = f"Error in parallel execution: {str(e)}" + self.logger.error(error_msg) + raise + + except ModelHandlerError: + raise + except Exception as e: + error_msg = f"Error in ThreadPoolExecutor setup: {str(e)}" + self.logger.error(error_msg) + raise + + try: + # Calculate statistics only from successful evaluations + scores = [pair["evaluation"]["score"] for pair in evaluated_pairs if pair.get("evaluation", {}).get("score") is not None] + + if scores: + average_score = sum(scores) / len(scores) + average_score = round(average_score, 2) + min_score = min(scores) + max_score = max(scores) + else: + average_score = min_score = max_score = 0 + + topic_stats = { + "average_score": average_score, + "min_score": min_score, + 
"max_score": max_score, + "evaluated_pairs": evaluated_pairs, + "failed_pairs": failed_pairs, + "total_evaluated": len(evaluated_pairs), + "total_failed": len(failed_pairs) + } + + self.logger.info(f"Completed evaluation for topic: {topic}. Average score: {topic_stats['average_score']:.2f}") + return topic_stats + + except Exception as e: + error_msg = f"Error calculating topic statistics: {str(e)}" + self.logger.error(error_msg) + return { + "average_score": 0, + "min_score": 0, + "max_score": 0, + "evaluated_pairs": evaluated_pairs, + "failed_pairs": failed_pairs, + "error": error_msg + } + except ModelHandlerError: + raise + except Exception as e: + error_msg = f"Critical error in evaluate_topic: {str(e)}" + self.logger.error(error_msg) + return { + "average_score": 0, + "min_score": 0, + "max_score": 0, + "evaluated_pairs": [], + "failed_pairs": [], + "error": error_msg + } + + #@track_llm_operation("evaluate_results") + def evaluate_results(self, request: EvaluationRequest, job_name=None,is_demo: bool = True, request_id=None) -> Dict: + """Evaluate all QA pairs with parallel processing""" + try: + self.logger.info(f"Starting evaluation process - Demo Mode: {is_demo}") + + model_params = request.model_params or ModelParameters() + + self.logger.info(f"Creating model handler for model: {request.model_id}") + model_handler = create_handler( + request.model_id, + self.bedrock_client, + model_params=model_params, + inference_type = request.inference_type, + caii_endpoint = request.caii_endpoint + ) + + self.logger.info(f"Loading QA pairs from: {request.import_path}") + with open(request.import_path, 'r') as file: + data = json.load(file) + + evaluated_results = {} + all_scores = [] + + transformed_data = { + "results": {}, + } + for item in data: + topic = item.get('Seeds') + + # Create topic list if it doesn't exist + if topic not in transformed_data['results']: + transformed_data['results'][topic] = [] + + # Create QA pair + qa_pair = { + request.output_key: item.get(request.output_key, ''), # Use get() with default value + request.output_value: item.get(request.output_value, '') # Use get() with default value + } + + # Add to appropriate topic list + transformed_data['results'][topic].append(qa_pair) + + self.logger.info(f"Processing {len(transformed_data['results'])} topics with {self.max_workers} workers") + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_topic = { + executor.submit( + self.evaluate_topic, + topic, + qa_pairs, + model_handler, + request, request_id=request_id + ): topic + for topic, qa_pairs in transformed_data['results'].items() + } + + for future in as_completed(future_to_topic): + try: + topic = future_to_topic[future] + topic_stats = future.result() + evaluated_results[topic] = topic_stats + all_scores.extend([ + pair["evaluation"]["score"] + for pair in topic_stats["evaluated_pairs"] + ]) + except ModelHandlerError as e: + self.logger.error(f"ModelHandlerError in future processing: {str(e)}") + raise APIError(f"Model evaluation failed: {str(e)}") + + + overall_average = sum(all_scores) / len(all_scores) if all_scores else 0 + overall_average = round(overall_average, 2) + evaluated_results['Overall_Average'] = overall_average + + self.logger.info(f"Evaluation completed. 
Overall average score: {overall_average:.2f}") + + + timestamp = datetime.now(timezone.utc).isoformat() + time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] + model_name = get_model_family(request.model_id).split('.')[-1] + output_path = f"qa_pairs_{model_name}_{time_file}_evaluated.json" + + self.logger.info(f"Saving evaluation results to: {output_path}") + with open(output_path, 'w') as f: + json.dump(evaluated_results, f, indent=2) + + custom_prompt_str = PromptHandler.get_default_custom_eval_prompt( + request.use_case, + request.custom_prompt + ) + + + examples_value = ( + PromptHandler.get_default_eval_example(request.use_case, request.examples) + if hasattr(request, 'examples') + else None + ) + examples_str = self.safe_json_dumps(examples_value) + #print(examples_value, '\n',examples_str) + + metadata = { + 'timestamp': timestamp, + 'model_id': request.model_id, + 'inference_type': request.inference_type, + 'caii_endpoint':request.caii_endpoint, + 'use_case': request.use_case, + 'custom_prompt': custom_prompt_str, + 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, + 'generate_file_name': os.path.basename(request.import_path), + 'evaluate_file_name': os.path.basename(output_path), + 'display_name': request.display_name, + 'local_export_path': output_path, + 'examples': examples_str, + 'Overall_Average': overall_average + } + + self.logger.info("Saving evaluation metadata to database") + + + if is_demo: + self.db.save_evaluation_metadata(metadata) + return { + "status": "completed", + "result": evaluated_results, + "output_path": output_path + } + else: + + + job_status = "ENGINE_SUCCEEDED" + evaluate_file_name = os.path.basename(output_path) + self.db.update_job_evaluate(job_name, evaluate_file_name, output_path, timestamp, overall_average, job_status) + self.db.backup_and_restore_db() + return { + "status": "completed", + "output_path": output_path + } + except APIError: + raise + except ModelHandlerError as e: + # Add this specific handler + self.logger.error(f"ModelHandlerError in evaluation: {str(e)}") + raise APIError(str(e)) + except Exception as e: + error_msg = f"Error in evaluation process: {str(e)}" + self.logger.error(error_msg, exc_info=True) + if is_demo: + raise APIError(str(e)) + else: + time_stamp = datetime.now(timezone.utc).isoformat() + job_status = "ENGINE_FAILED" + file_name = '' + output_path = '' + overall_average = '' + self.db.update_job_evaluate(job_name,file_name, output_path, time_stamp, job_status) + + raise + + def safe_json_dumps(self, value): + """Convert value to JSON string only if it's not None""" + return json.dumps(value) if value is not None else None diff --git a/app/services/evaluator_service.py b/app/services/evaluator_service.py index 2b094b15..a556f2ae 100644 --- a/app/services/evaluator_service.py +++ b/app/services/evaluator_service.py @@ -19,7 +19,7 @@ from functools import partial class EvaluatorService: - """Service for evaluating generated QA pairs using Claude with parallel processing""" + """Service for evaluating freeform data rows using Claude with parallel processing (Freeform technique only)""" def __init__(self, max_workers: int = 4): self.bedrock_client = get_bedrock_client() @@ -56,351 +56,6 @@ def _setup_logging(self): error_handler.setFormatter(formatter) self.logger.addHandler(error_handler) - - #@track_llm_operation("evaluate_single_pair") - def evaluate_single_pair(self, qa_pair: Dict, model_handler, request: EvaluationRequest, request_id=None) -> Dict: - 
"""Evaluate a single QA pair""" - try: - # Default error response - error_response = { - request.output_key: qa_pair.get(request.output_key, "Unknown"), - request.output_value: qa_pair.get(request.output_value, "Unknown"), - "evaluation": { - "score": 0, - "justification": "Error during evaluation" - } - } - - try: - self.logger.info(f"Evaluating QA pair: {qa_pair.get(request.output_key, '')[:50]}...") - except Exception as e: - self.logger.error(f"Error logging QA pair: {str(e)}") - - try: - # Validate input qa_pair structure - if not all(key in qa_pair for key in [request.output_key, request.output_value]): - error_msg = "Missing required keys in qa_pair" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - prompt = PromptBuilder.build_eval_prompt( - request.model_id, - request.use_case, - qa_pair[request.output_key], - qa_pair[request.output_value], - request.examples, - request.custom_prompt - ) - #print(prompt) - except Exception as e: - error_msg = f"Error building evaluation prompt: {str(e)}" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - try: - response = model_handler.generate_response(prompt, request_id=request_id) - except ModelHandlerError as e: - self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") - raise - except Exception as e: - error_msg = f"Error generating model response: {str(e)}" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - if not response: - error_msg = "Failed to parse model response" - self.logger.warning(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - try: - score = response[0].get('score', "no score key") - justification = response[0].get('justification', 'No justification provided') - if score== "no score key": - self.logger.info(f"Unsuccessful QA pair evaluation with score: {score}") - justification = "The evaluated pair did not generate valid score and justification" - score = 0 - else: - self.logger.info(f"Successfully evaluated QA pair with score: {score}") - - return { - "question": qa_pair[request.output_key], - "solution": qa_pair[request.output_value], - "evaluation": { - "score": score, - "justification": justification - } - } - except Exception as e: - error_msg = f"Error processing model response: {str(e)}" - self.logger.error(error_msg) - error_response["evaluation"]["justification"] = error_msg - return error_response - - except ModelHandlerError: - raise - except Exception as e: - self.logger.error(f"Critical error in evaluate_single_pair: {str(e)}") - return error_response - - #@track_llm_operation("evaluate_topic") - def evaluate_topic(self, topic: str, qa_pairs: List[Dict], model_handler, request: EvaluationRequest, request_id=None) -> Dict: - """Evaluate all QA pairs for a given topic in parallel""" - try: - self.logger.info(f"Starting evaluation for topic: {topic} with {len(qa_pairs)} QA pairs") - evaluated_pairs = [] - failed_pairs = [] - - try: - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - try: - evaluate_func = partial( - self.evaluate_single_pair, - model_handler=model_handler, - request=request, request_id=request_id - ) - - future_to_pair = { - executor.submit(evaluate_func, pair): pair - for pair in qa_pairs - } - - for future in as_completed(future_to_pair): - try: - result = future.result() - evaluated_pairs.append(result) - except 
ModelHandlerError: - raise - except Exception as e: - error_msg = f"Error processing future result: {str(e)}" - self.logger.error(error_msg) - failed_pairs.append({ - "error": error_msg, - "pair": future_to_pair[future] - }) - - except Exception as e: - error_msg = f"Error in parallel execution: {str(e)}" - self.logger.error(error_msg) - raise - - except ModelHandlerError: - raise - except Exception as e: - error_msg = f"Error in ThreadPoolExecutor setup: {str(e)}" - self.logger.error(error_msg) - raise - - try: - # Calculate statistics only from successful evaluations - scores = [pair["evaluation"]["score"] for pair in evaluated_pairs if pair.get("evaluation", {}).get("score") is not None] - - if scores: - average_score = sum(scores) / len(scores) - average_score = round(average_score, 2) - min_score = min(scores) - max_score = max(scores) - else: - average_score = min_score = max_score = 0 - - topic_stats = { - "average_score": average_score, - "min_score": min_score, - "max_score": max_score, - "evaluated_pairs": evaluated_pairs, - "failed_pairs": failed_pairs, - "total_evaluated": len(evaluated_pairs), - "total_failed": len(failed_pairs) - } - - self.logger.info(f"Completed evaluation for topic: {topic}. Average score: {topic_stats['average_score']:.2f}") - return topic_stats - - except Exception as e: - error_msg = f"Error calculating topic statistics: {str(e)}" - self.logger.error(error_msg) - return { - "average_score": 0, - "min_score": 0, - "max_score": 0, - "evaluated_pairs": evaluated_pairs, - "failed_pairs": failed_pairs, - "error": error_msg - } - except ModelHandlerError: - raise - except Exception as e: - error_msg = f"Critical error in evaluate_topic: {str(e)}" - self.logger.error(error_msg) - return { - "average_score": 0, - "min_score": 0, - "max_score": 0, - "evaluated_pairs": [], - "failed_pairs": [], - "error": error_msg - } - #@track_llm_operation("evaluate_results") - def evaluate_results(self, request: EvaluationRequest, job_name=None,is_demo: bool = True, request_id=None) -> Dict: - """Evaluate all QA pairs with parallel processing""" - try: - self.logger.info(f"Starting evaluation process - Demo Mode: {is_demo}") - - model_params = request.model_params or ModelParameters() - - self.logger.info(f"Creating model handler for model: {request.model_id}") - model_handler = create_handler( - request.model_id, - self.bedrock_client, - model_params=model_params, - inference_type = request.inference_type, - caii_endpoint = request.caii_endpoint - ) - - self.logger.info(f"Loading QA pairs from: {request.import_path}") - with open(request.import_path, 'r') as file: - data = json.load(file) - - evaluated_results = {} - all_scores = [] - - transformed_data = { - "results": {}, - } - for item in data: - topic = item.get('Seeds') - - # Create topic list if it doesn't exist - if topic not in transformed_data['results']: - transformed_data['results'][topic] = [] - - # Create QA pair - qa_pair = { - request.output_key: item.get(request.output_key, ''), # Use get() with default value - request.output_value: item.get(request.output_value, '') # Use get() with default value - } - - # Add to appropriate topic list - transformed_data['results'][topic].append(qa_pair) - - self.logger.info(f"Processing {len(transformed_data['results'])} topics with {self.max_workers} workers") - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - future_to_topic = { - executor.submit( - self.evaluate_topic, - topic, - qa_pairs, - model_handler, - request, request_id=request_id - ): topic - 
for topic, qa_pairs in transformed_data['results'].items() - } - - for future in as_completed(future_to_topic): - try: - topic = future_to_topic[future] - topic_stats = future.result() - evaluated_results[topic] = topic_stats - all_scores.extend([ - pair["evaluation"]["score"] - for pair in topic_stats["evaluated_pairs"] - ]) - except ModelHandlerError as e: - self.logger.error(f"ModelHandlerError in future processing: {str(e)}") - raise APIError(f"Model evaluation failed: {str(e)}") - - - overall_average = sum(all_scores) / len(all_scores) if all_scores else 0 - overall_average = round(overall_average, 2) - evaluated_results['Overall_Average'] = overall_average - - self.logger.info(f"Evaluation completed. Overall average score: {overall_average:.2f}") - - - timestamp = datetime.now(timezone.utc).isoformat() - time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] - model_name = get_model_family(request.model_id).split('.')[-1] - output_path = f"qa_pairs_{model_name}_{time_file}_evaluated.json" - - self.logger.info(f"Saving evaluation results to: {output_path}") - with open(output_path, 'w') as f: - json.dump(evaluated_results, f, indent=2) - - custom_prompt_str = PromptHandler.get_default_custom_eval_prompt( - request.use_case, - request.custom_prompt - ) - - - examples_value = ( - PromptHandler.get_default_eval_example(request.use_case, request.examples) - if hasattr(request, 'examples') - else None - ) - examples_str = self.safe_json_dumps(examples_value) - #print(examples_value, '\n',examples_str) - - metadata = { - 'timestamp': timestamp, - 'model_id': request.model_id, - 'inference_type': request.inference_type, - 'caii_endpoint':request.caii_endpoint, - 'use_case': request.use_case, - 'custom_prompt': custom_prompt_str, - 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, - 'generate_file_name': os.path.basename(request.import_path), - 'evaluate_file_name': os.path.basename(output_path), - 'display_name': request.display_name, - 'local_export_path': output_path, - 'examples': examples_str, - 'Overall_Average': overall_average - } - - self.logger.info("Saving evaluation metadata to database") - - - if is_demo: - self.db.save_evaluation_metadata(metadata) - return { - "status": "completed", - "result": evaluated_results, - "output_path": output_path - } - else: - - - job_status = "ENGINE_SUCCEEDED" - evaluate_file_name = os.path.basename(output_path) - self.db.update_job_evaluate(job_name, evaluate_file_name, output_path, timestamp, overall_average, job_status) - self.db.backup_and_restore_db() - return { - "status": "completed", - "output_path": output_path - } - except APIError: - raise - except ModelHandlerError as e: - # Add this specific handler - self.logger.error(f"ModelHandlerError in evaluation: {str(e)}") - raise APIError(str(e)) - except Exception as e: - error_msg = f"Error in evaluation process: {str(e)}" - self.logger.error(error_msg, exc_info=True) - if is_demo: - raise APIError(str(e)) - else: - time_stamp = datetime.now(timezone.utc).isoformat() - job_status = "ENGINE_FAILED" - file_name = '' - output_path = '' - overall_average = '' - self.db.update_job_evaluate(job_name,file_name, output_path, time_stamp, job_status) - - raise - def evaluate_single_row(self, row: Dict[str, Any], model_handler, request: EvaluationRequest, request_id = None) -> Dict: """Evaluate a single data row""" try: @@ -690,4 +345,4 @@ def evaluate_row_data(self, request: EvaluationRequest, job_name=None, is_demo: def safe_json_dumps(self, 
value): """Convert value to JSON string only if it's not None""" - return json.dumps(value) if value is not None else None \ No newline at end of file + return json.dumps(value) if value is not None else None diff --git a/app/services/model_alignment.py b/app/services/model_alignment.py index 650005a3..2b3fcfe1 100644 --- a/app/services/model_alignment.py +++ b/app/services/model_alignment.py @@ -7,8 +7,8 @@ import asyncio from datetime import datetime, timezone from typing import Dict, Optional -from app.services.synthesis_service import SynthesisService -from app.services.evaluator_service import EvaluatorService +from app.services.synthesis_legacy_service import SynthesisLegacyService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import SynthesisRequest, EvaluationRequest from app.models.request_models import ModelParameters from app.services.aws_bedrock import get_bedrock_client @@ -21,8 +21,8 @@ class ModelAlignment: """Service for aligning model outputs through synthesis and evaluation""" def __init__(self): - self.synthesis_service = SynthesisService() - self.evaluator_service = EvaluatorService() + self.synthesis_service = SynthesisLegacyService() + self.evaluator_service = EvaluatorLegacyService() self.db = DatabaseManager() self.bedrock_client = get_bedrock_client() # Add this line self._setup_logging() diff --git a/app/services/synthesis_job.py b/app/services/synthesis_job.py index 323fa7f0..8ea42010 100644 --- a/app/services/synthesis_job.py +++ b/app/services/synthesis_job.py @@ -3,9 +3,9 @@ import uuid import os from typing import Dict, Any, Optional -from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import SynthesisRequest, EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, JsonDataSize, RelativePath -from app.services.synthesis_service import SynthesisService +from app.services.synthesis_legacy_service import SynthesisLegacyService from app.services.export_results import Export_Service from app.core.prompt_templates import PromptBuilder, PromptHandler from app.core.config import UseCase, USE_CASE_CONFIGS @@ -22,8 +22,8 @@ import cmlapi # Initialize services -synthesis_service = SynthesisService() -evaluator_service = EvaluatorService() +synthesis_service = SynthesisLegacyService() +evaluator_service = EvaluatorLegacyService() export_service = Export_Service() db_manager = DatabaseManager() diff --git a/app/services/synthesis_legacy_service.py b/app/services/synthesis_legacy_service.py new file mode 100644 index 00000000..7f10927a --- /dev/null +++ b/app/services/synthesis_legacy_service.py @@ -0,0 +1,700 @@ +import boto3 +import json +import uuid +import time +import csv +from typing import List, Dict, Optional, Tuple +import uuid +from datetime import datetime, timezone +import os +from huggingface_hub import HfApi, HfFolder, Repository +from concurrent.futures import ThreadPoolExecutor +from functools import partial +import math +import asyncio +from fastapi import FastAPI, BackgroundTasks, HTTPException +from app.core.exceptions import APIError, InvalidModelError, ModelHandlerError, JSONParsingError +from app.core.data_loader import DataLoader +import pandas as pd +import numpy as np + +from app.models.request_models import SynthesisRequest, Example, ModelParameters +from app.core.model_handlers import create_handler +from app.core.prompt_templates import PromptBuilder, 
PromptHandler +from app.core.config import UseCase, Technique, get_model_family +from app.services.aws_bedrock import get_bedrock_client +from app.core.database import DatabaseManager +from app.services.check_guardrail import ContentGuardrail +from app.services.doc_extraction import DocumentProcessor +import logging +from logging.handlers import RotatingFileHandler +import traceback +from app.core.telemetry_integration import track_llm_operation +import uuid + + +class SynthesisLegacyService: + """Legacy service for generating synthetic QA pairs (SFT and Custom_Workflow only)""" + QUESTIONS_PER_BATCH = 5 # Maximum questions per batch + MAX_CONCURRENT_TOPICS = 5 # Limit concurrent I/O operations + + + def __init__(self): + self.bedrock_client = get_bedrock_client() + self.db = DatabaseManager() + self._setup_logging() + self.guard = ContentGuardrail() + + + def _setup_logging(self): + """Set up logging configuration""" + os.makedirs('logs', exist_ok=True) + + self.logger = logging.getLogger('synthesis_legacy_service') + self.logger.setLevel(logging.INFO) + + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + + # File handler for general logs + file_handler = RotatingFileHandler( + 'logs/synthesis_legacy_service.log', + maxBytes=10*1024*1024, # 10MB + backupCount=5 + ) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + # File handler for errors + error_handler = RotatingFileHandler( + 'logs/synthesis_legacy_service_errors.log', + maxBytes=10*1024*1024, + backupCount=5 + ) + error_handler.setLevel(logging.ERROR) + error_handler.setFormatter(formatter) + self.logger.addHandler(error_handler) + + + #@track_llm_operation("process_single_topic") + def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]: + """ + Process a single topic to generate questions and solutions. + Attempts batch processing first (default 5 questions), falls back to single question processing if batch fails. 
+ + Args: + topic: The topic to generate questions for + model_handler: Handler for the AI model + request: The synthesis request object + num_questions: Total number of questions to generate + + Returns: + Tuple containing: + - topic (str) + - list of validated QA pairs + - list of error messages + - list of output dictionaries with topic information + + Raises: + ModelHandlerError: When there's an error in model generation that should stop processing + """ + topic_results = [] + topic_output = [] + topic_errors = [] + questions_remaining = num_questions + omit_questions = [] + + try: + # Process questions in batches + for batch_idx in range(0, num_questions, self.QUESTIONS_PER_BATCH): + if questions_remaining <= 0: + break + + batch_size = min(self.QUESTIONS_PER_BATCH, questions_remaining) + self.logger.info(f"Processing topic: {topic}, attempting batch {batch_idx+1}-{batch_idx+batch_size}") + + try: + # Attempt batch processing + prompt = PromptBuilder.build_prompt( + model_id=request.model_id, + use_case=request.use_case, + topic=topic, + num_questions=batch_size, + omit_questions=omit_questions, + examples=request.examples or [], + technique=request.technique, + schema=request.schema, + custom_prompt=request.custom_prompt, + ) + # print("prompt :", prompt) + batch_qa_pairs = None + try: + batch_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) + except ModelHandlerError as e: + self.logger.warning(f"Batch processing failed: {str(e)}") + if isinstance(e, JSONParsingError): + # For JSON parsing errors, fall back to single processing + self.logger.info("JSON parsing failed, falling back to single processing") + continue + else: + # For other model errors, propagate up + raise + + if batch_qa_pairs: + # Process batch results + valid_pairs = [] + valid_outputs = [] + invalid_count = 0 + + for pair in batch_qa_pairs: + if self._validate_qa_pair(pair): + valid_pairs.append({ + "question": pair["question"], + "solution": pair["solution"] + }) + valid_outputs.append({ + "Topic": topic, + "question": pair["question"], + "solution": pair["solution"] + }) + omit_questions.append(pair["question"]) + #else: + invalid_count = batch_size - len(valid_pairs) + + if valid_pairs: + topic_results.extend(valid_pairs) + topic_output.extend(valid_outputs) + questions_remaining -= len(valid_pairs) + omit_questions = omit_questions[-100:] # Keep last 100 questions + self.logger.info(f"Successfully generated {len(valid_pairs)} questions in batch for topic {topic}") + print("invalid_count:", invalid_count, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) + # If all pairs were valid, skip fallback + if invalid_count <= 0: + continue + + else: + # Fall back to single processing for remaining or failed questions + self.logger.info(f"Falling back to single processing for remaining questions in topic {topic}") + remaining_batch = invalid_count + print("remaining_batch:", remaining_batch, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) + for _ in range(remaining_batch): + if questions_remaining <= 0: + break + + try: + # Single question processing + prompt = PromptBuilder.build_prompt( + model_id=request.model_id, + use_case=request.use_case, + topic=topic, + num_questions=1, + omit_questions=omit_questions, + examples=request.examples or [], + technique=request.technique, + schema=request.schema, + custom_prompt=request.custom_prompt, + ) + + try: + single_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) + except 
ModelHandlerError as e: + self.logger.warning(f"Batch processing failed: {str(e)}") + if isinstance(e, JSONParsingError): + # For JSON parsing errors, fall back to single processing + self.logger.info("JSON parsing failed, falling back to single processing") + continue + else: + # For other model errors, propagate up + raise + + if single_qa_pairs and len(single_qa_pairs) > 0: + pair = single_qa_pairs[0] + if self._validate_qa_pair(pair): + validated_pair = { + "question": pair["question"], + "solution": pair["solution"] + } + validated_output = { + "Topic": topic, + "question": pair["question"], + "solution": pair["solution"] + } + + topic_results.append(validated_pair) + topic_output.append(validated_output) + omit_questions.append(pair["question"]) + omit_questions = omit_questions[-100:] + questions_remaining -= 1 + + self.logger.info(f"Successfully generated single question for topic {topic}") + else: + error_msg = f"Invalid QA pair structure in single processing for topic {topic}" + self.logger.warning(error_msg) + topic_errors.append(error_msg) + else: + error_msg = f"No QA pair generated in single processing for topic {topic}" + self.logger.warning(error_msg) + topic_errors.append(error_msg) + + except ModelHandlerError as e: + # Don't raise - add to errors and continue + error_msg = f"ModelHandlerError in single processing for topic {topic}: {str(e)}" + self.logger.error(error_msg) + topic_errors.append(error_msg) + continue + + except ModelHandlerError: + # Re-raise ModelHandlerError to propagate up + raise + except Exception as e: + error_msg = f"Error processing batch for topic {topic}: {str(e)}" + self.logger.error(error_msg) + topic_errors.append(error_msg) + continue + + except ModelHandlerError: + # Re-raise ModelHandlerError to propagate up + raise + except Exception as e: + error_msg = f"Critical error processing topic {topic}: {str(e)}" + self.logger.error(error_msg) + topic_errors.append(error_msg) + + return topic, topic_results, topic_errors, topic_output + + + async def generate_examples(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id= None) -> Dict: + """Generate examples based on request parameters (SFT technique)""" + try: + output_key = request.output_key + output_value = request.output_value + st = time.time() + self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") + + # Use default parameters if none provided + model_params = request.model_params or ModelParameters() + + # Create model handler + self.logger.info("Creating model handler") + model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint) + + # Limit topics and questions in demo mode + if request.doc_paths: + processor = DocumentProcessor(chunk_size=1000, overlap=100) + paths = request.doc_paths + topics = [] + for path in paths: + chunks = processor.process_document(path) + topics.extend(chunks) + #topics = topics[0:5] + print("total chunks: ", len(topics)) + if request.num_questions<=len(topics): + topics = topics[0:request.num_questions] + num_questions = 1 + print("num_questions :", num_questions) + else: + num_questions = math.ceil(request.num_questions/len(topics)) + #print(num_questions) + total_count = request.num_questions + else: + if request.topics: + topics = request.topics + num_questions = request.num_questions + total_count = request.num_questions*len(request.topics) + + else: + self.logger.error("Generation failed: 
No topics provided") + raise RuntimeError("Invalid input: No topics provided") + + + # Track results for each topic + results = {} + all_errors = [] + final_output = [] + + # Create thread pool + loop = asyncio.get_event_loop() + with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor: + topic_futures = [ + loop.run_in_executor( + executor, + self.process_single_topic, + topic, + model_handler, + request, + num_questions, + request_id + ) + for topic in topics + ] + + # Wait for all futures to complete + try: + completed_topics = await asyncio.gather(*topic_futures) + except ModelHandlerError as e: + self.logger.error(f"Model generation failed: {str(e)}") + raise APIError(f"Failed to generate content: {str(e)}") + + # Process results + + for topic, topic_results, topic_errors, topic_output in completed_topics: + if topic_errors: + all_errors.extend(topic_errors) + if topic_results and is_demo: + results[topic] = topic_results + if topic_output: + final_output.extend(topic_output) + + generation_time = time.time() - st + self.logger.info(f"Generation completed in {generation_time:.2f} seconds") + + timestamp = datetime.now(timezone.utc).isoformat() + time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] + mode_suffix = "test" if is_demo else "final" + model_name = get_model_family(request.model_id).split('.')[-1] + file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" + if request.doc_paths: + final_output = [{ + 'Generated_From': item['Topic'], + output_key: item['question'], + output_value: item['solution'] } + for item in final_output] + else: + final_output = [{ + 'Seeds': item['Topic'], + output_key: item['question'], + output_value: item['solution'] } + for item in final_output] + output_path = {} + try: + with open(file_path, "w") as f: + json.dump(final_output, indent=2, fp=f) + except Exception as e: + self.logger.error(f"Error saving results: {str(e)}", exc_info=True) + + output_path['local']= file_path + + + + + # Handle custom prompt, examples and schema + custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) + # For examples + examples_value = ( + PromptHandler.get_default_example(request.use_case, request.examples) + if hasattr(request, 'examples') + else None + ) + examples_str = self.safe_json_dumps(examples_value) + + # For schema + schema_value = ( + PromptHandler.get_default_schema(request.use_case, request.schema) + if hasattr(request, 'schema') + else None + ) + schema_str = self.safe_json_dumps(schema_value) + + # For topics + topics_value = topics if not getattr(request, 'doc_paths', None) else None + topic_str = self.safe_json_dumps(topics_value) + + # For doc_paths and input_path + doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) + input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) + + metadata = { + 'timestamp': timestamp, + 'technique': request.technique, + 'model_id': request.model_id, + 'inference_type': request.inference_type, + 'caii_endpoint':request.caii_endpoint, + 'use_case': request.use_case, + 'final_prompt': custom_prompt_str, + 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, + 'generate_file_name': os.path.basename(output_path['local']), + 'display_name': request.display_name, + 'output_path': output_path, + 'num_questions':getattr(request, 'num_questions', None), + 'topics': topic_str, + 'examples': examples_str, + "total_count":total_count, + 'schema': schema_str, + 
'doc_paths': doc_paths_str, + 'input_path':input_path_str, + 'input_key': request.input_key, + 'output_key':request.output_key, + 'output_value':request.output_value + } + + #print("metadata: ",metadata) + if is_demo: + + self.db.save_generation_metadata(metadata) + return { + "status": "completed" if results else "failed", + "results": results, + "errors": all_errors if all_errors else None, + "export_path": output_path + } + else: + # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) + # time_stamp = extract_timestamp(metadata.get('generate_file_name')) + job_status = "ENGINE_SUCCEEDED" + generate_file_name = os.path.basename(output_path['local']) + + self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) + self.db.backup_and_restore_db() + return { + "status": "completed" if final_output else "failed", + "export_path": output_path + } + except APIError: + raise + + except Exception as e: + self.logger.error(f"Generation failed: {str(e)}", exc_info=True) + if is_demo: + raise APIError(str(e)) # Let middleware decide status code + else: + time_stamp = datetime.now(timezone.utc).isoformat() + job_status = "ENGINE_FAILED" + file_name = '' + output_path = '' + self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) + raise # Just re-raise the original exception + + + def _validate_qa_pair(self, pair: Dict) -> bool: + """Validate a question-answer pair""" + return ( + isinstance(pair, dict) and + "question" in pair and + "solution" in pair and + isinstance(pair["question"], str) and + isinstance(pair["solution"], str) and + len(pair["question"].strip()) > 0 and + len(pair["solution"].strip()) > 0 + ) + + #@track_llm_operation("process_single_input") + async def process_single_input(self, input, model_handler, request, request_id=None): + try: + prompt = PromptBuilder.build_generate_result_prompt( + model_id=request.model_id, + use_case=request.use_case, + input=input, + examples=request.examples or [], + schema=request.schema, + custom_prompt=request.custom_prompt, + ) + try: + result = model_handler.generate_response(prompt, request_id=request_id) + except ModelHandlerError as e: + self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") + raise + + return {"question": input, "solution": result} + + except ModelHandlerError: + raise + except Exception as e: + self.logger.error(f"Error processing input: {str(e)}") + raise APIError(f"Failed to process input: {str(e)}") + + async def generate_result(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id=None) -> Dict: + """Generate results based on request parameters (Custom_Workflow technique)""" + try: + self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") + + + # Use default parameters if none provided + model_params = request.model_params or ModelParameters() + + # Create model handler + self.logger.info("Creating model handler") + model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint, custom_p = True) + + inputs = [] + file_paths = request.input_path + for path in file_paths: + try: + with open(path) as f: + data = json.load(f) + inputs.extend(item.get(request.input_key, '') for item in data) + except Exception as e: + print(f"Error processing {path}: {str(e)}") + MAX_WORKERS = 5 + + + # Create thread pool + loop = asyncio.get_event_loop() + with 
ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + # Create futures for each input + input_futures = [ + loop.run_in_executor( + executor, + lambda x: asyncio.run(self.process_single_input(x, model_handler, request, request_id)), + input + ) + for input in inputs + ] + + # Wait for all futures to complete + try: + final_output = await asyncio.gather(*input_futures) + except ModelHandlerError as e: + self.logger.error(f"Model generation failed: {str(e)}") + raise APIError(f"Failed to generate content: {str(e)}") + + + + + timestamp = datetime.now(timezone.utc).isoformat() + time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] + mode_suffix = "test" if is_demo else "final" + model_name = get_model_family(request.model_id).split('.')[-1] + file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" + input_key = request.output_key or request.input_key + result = [{ + + input_key: item['question'], + request.output_value: item['solution'] } + for item in final_output] + output_path = {} + try: + with open(file_path, "w") as f: + json.dump(result, indent=2, fp=f) + except Exception as e: + self.logger.error(f"Error saving results: {str(e)}", exc_info=True) + + + + + output_path['local']= file_path + + + # Handle custom prompt, examples and schema + custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) + # For examples + examples_value = ( + PromptHandler.get_default_example(request.use_case, request.examples) + if hasattr(request, 'examples') + else None + ) + examples_str = self.safe_json_dumps(examples_value) + + # For schema + schema_value = ( + PromptHandler.get_default_schema(request.use_case, request.schema) + if hasattr(request, 'schema') + else None + ) + schema_str = self.safe_json_dumps(schema_value) + + # For topics + topics_value = None + topic_str = self.safe_json_dumps(topics_value) + + # For doc_paths and input_path + doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) + input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) + + + + metadata = { + 'timestamp': timestamp, + 'technique': request.technique, + 'model_id': request.model_id, + 'inference_type': request.inference_type, + 'caii_endpoint':request.caii_endpoint, + 'use_case': request.use_case, + 'final_prompt': custom_prompt_str, + 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, + 'generate_file_name': os.path.basename(output_path['local']), + 'display_name': request.display_name, + 'output_path': output_path, + 'num_questions':getattr(request, 'num_questions', None), + 'topics': topic_str, + 'examples': examples_str, + "total_count":len(inputs), + 'schema': schema_str, + 'doc_paths': doc_paths_str, + 'input_path':input_path_str, + 'input_key': request.input_key, + 'output_key':request.output_key, + 'output_value':request.output_value + } + + + if is_demo: + + self.db.save_generation_metadata(metadata) + return { + "status": "completed" if final_output else "failed", + "results": final_output, + "export_path": output_path + } + else: + # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) + # time_stamp = extract_timestamp(metadata.get('generate_file_name')) + job_status = "success" + generate_file_name = os.path.basename(output_path['local']) + + self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) + self.db.backup_and_restore_db() + return { + "status": "completed" if final_output else 
"failed", + "export_path": output_path + } + + except APIError: + raise + except Exception as e: + self.logger.error(f"Generation failed: {str(e)}", exc_info=True) + if is_demo: + raise APIError(str(e)) # Let middleware decide status code + else: + time_stamp = datetime.now(timezone.utc).isoformat() + job_status = "failure" + file_name = '' + output_path = '' + self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) + raise # Just re-raise the original exception + + def get_health_check(self) -> Dict: + """Get service health status""" + try: + test_body = { + "prompt": "\n\nHuman: test\n\nAssistant: ", + "max_tokens_to_sample": 1, + "temperature": 0 + } + + self.bedrock_client.invoke_model( + modelId="anthropic.claude-instant-v1", + body=json.dumps(test_body) + ) + + status = { + "status": "healthy", + "timestamp": datetime.now().isoformat(), + "service": "SynthesisLegacyService", + "aws_region": self.bedrock_client.meta.region_name + } + self.logger.info("Health check passed", extra=status) + return status + + except Exception as e: + status = { + "status": "unhealthy", + "error": str(e), + "timestamp": datetime.now().isoformat(), + "service": "SynthesisLegacyService", + "aws_region": self.bedrock_client.meta.region_name + } + self.logger.error("Health check failed", extra=status, exc_info=True) + return status + + def safe_json_dumps(self, value): + """Convert value to JSON string only if it's not None""" + return json.dumps(value) if value is not None else None diff --git a/app/services/synthesis_service.py b/app/services/synthesis_service.py index e42ac3c8..413de18f 100644 --- a/app/services/synthesis_service.py +++ b/app/services/synthesis_service.py @@ -34,19 +34,13 @@ class SynthesisService: - """Service for generating synthetic QA pairs""" + """Service for generating synthetic freeform data (Freeform technique only)""" QUESTIONS_PER_BATCH = 5 # Maximum questions per batch MAX_CONCURRENT_TOPICS = 5 # Limit concurrent I/O operations def __init__(self): - # self.bedrock_client = boto3.Session(profile_name='cu_manowar_dev').client( - # 'bedrock-runtime', - # region_name='us-west-2' - # ) - self.bedrock_client = get_bedrock_client() - - + self.bedrock_client = get_bedrock_client() self.db = DatabaseManager() self._setup_logging() self.guard = ContentGuardrail() @@ -80,592 +74,6 @@ def _setup_logging(self): error_handler.setFormatter(formatter) self.logger.addHandler(error_handler) - - #@track_llm_operation("process_single_topic") - def process_single_topic(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]: - """ - Process a single topic to generate questions and solutions. - Attempts batch processing first (default 5 questions), falls back to single question processing if batch fails. 
- - Args: - topic: The topic to generate questions for - model_handler: Handler for the AI model - request: The synthesis request object - num_questions: Total number of questions to generate - - Returns: - Tuple containing: - - topic (str) - - list of validated QA pairs - - list of error messages - - list of output dictionaries with topic information - - Raises: - ModelHandlerError: When there's an error in model generation that should stop processing - """ - topic_results = [] - topic_output = [] - topic_errors = [] - questions_remaining = num_questions - omit_questions = [] - - try: - # Process questions in batches - for batch_idx in range(0, num_questions, self.QUESTIONS_PER_BATCH): - if questions_remaining <= 0: - break - - batch_size = min(self.QUESTIONS_PER_BATCH, questions_remaining) - self.logger.info(f"Processing topic: {topic}, attempting batch {batch_idx+1}-{batch_idx+batch_size}") - - try: - # Attempt batch processing - prompt = PromptBuilder.build_prompt( - model_id=request.model_id, - use_case=request.use_case, - topic=topic, - num_questions=batch_size, - omit_questions=omit_questions, - examples=request.examples or [], - technique=request.technique, - schema=request.schema, - custom_prompt=request.custom_prompt, - ) - # print("prompt :", prompt) - batch_qa_pairs = None - try: - batch_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) - except ModelHandlerError as e: - self.logger.warning(f"Batch processing failed: {str(e)}") - if isinstance(e, JSONParsingError): - # For JSON parsing errors, fall back to single processing - self.logger.info("JSON parsing failed, falling back to single processing") - continue - else: - # For other model errors, propagate up - raise - - if batch_qa_pairs: - # Process batch results - valid_pairs = [] - valid_outputs = [] - invalid_count = 0 - - for pair in batch_qa_pairs: - if self._validate_qa_pair(pair): - valid_pairs.append({ - "question": pair["question"], - "solution": pair["solution"] - }) - valid_outputs.append({ - "Topic": topic, - "question": pair["question"], - "solution": pair["solution"] - }) - omit_questions.append(pair["question"]) - #else: - invalid_count = batch_size - len(valid_pairs) - - if valid_pairs: - topic_results.extend(valid_pairs) - topic_output.extend(valid_outputs) - questions_remaining -= len(valid_pairs) - omit_questions = omit_questions[-100:] # Keep last 100 questions - self.logger.info(f"Successfully generated {len(valid_pairs)} questions in batch for topic {topic}") - print("invalid_count:", invalid_count, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) - # If all pairs were valid, skip fallback - if invalid_count <= 0: - continue - - else: - # Fall back to single processing for remaining or failed questions - self.logger.info(f"Falling back to single processing for remaining questions in topic {topic}") - remaining_batch = invalid_count - print("remaining_batch:", remaining_batch, '\n', "batch_size: ", batch_size, '\n', "valid_pairs: ", len(valid_pairs)) - for _ in range(remaining_batch): - if questions_remaining <= 0: - break - - try: - # Single question processing - prompt = PromptBuilder.build_prompt( - model_id=request.model_id, - use_case=request.use_case, - topic=topic, - num_questions=1, - omit_questions=omit_questions, - examples=request.examples or [], - technique=request.technique, - schema=request.schema, - custom_prompt=request.custom_prompt, - ) - - try: - single_qa_pairs = model_handler.generate_response(prompt, request_id=request_id) - except 
ModelHandlerError as e: - self.logger.warning(f"Batch processing failed: {str(e)}") - if isinstance(e, JSONParsingError): - # For JSON parsing errors, fall back to single processing - self.logger.info("JSON parsing failed, falling back to single processing") - continue - else: - # For other model errors, propagate up - raise - - if single_qa_pairs and len(single_qa_pairs) > 0: - pair = single_qa_pairs[0] - if self._validate_qa_pair(pair): - validated_pair = { - "question": pair["question"], - "solution": pair["solution"] - } - validated_output = { - "Topic": topic, - "question": pair["question"], - "solution": pair["solution"] - } - - topic_results.append(validated_pair) - topic_output.append(validated_output) - omit_questions.append(pair["question"]) - omit_questions = omit_questions[-100:] - questions_remaining -= 1 - - self.logger.info(f"Successfully generated single question for topic {topic}") - else: - error_msg = f"Invalid QA pair structure in single processing for topic {topic}" - self.logger.warning(error_msg) - topic_errors.append(error_msg) - else: - error_msg = f"No QA pair generated in single processing for topic {topic}" - self.logger.warning(error_msg) - topic_errors.append(error_msg) - - except ModelHandlerError as e: - # Don't raise - add to errors and continue - error_msg = f"ModelHandlerError in single processing for topic {topic}: {str(e)}" - self.logger.error(error_msg) - topic_errors.append(error_msg) - continue - - except ModelHandlerError: - # Re-raise ModelHandlerError to propagate up - raise - except Exception as e: - error_msg = f"Error processing batch for topic {topic}: {str(e)}" - self.logger.error(error_msg) - topic_errors.append(error_msg) - continue - - except ModelHandlerError: - # Re-raise ModelHandlerError to propagate up - raise - except Exception as e: - error_msg = f"Critical error processing topic {topic}: {str(e)}" - self.logger.error(error_msg) - topic_errors.append(error_msg) - - return topic, topic_results, topic_errors, topic_output - - - async def generate_examples(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id= None) -> Dict: - """Generate examples based on request parameters""" - try: - output_key = request.output_key - output_value = request.output_value - st = time.time() - self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") - - # Use default parameters if none provided - model_params = request.model_params or ModelParameters() - - # Create model handler - self.logger.info("Creating model handler") - model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint) - - # Limit topics and questions in demo mode - if request.doc_paths: - processor = DocumentProcessor(chunk_size=1000, overlap=100) - paths = request.doc_paths - topics = [] - for path in paths: - chunks = processor.process_document(path) - topics.extend(chunks) - #topics = topics[0:5] - print("total chunks: ", len(topics)) - if request.num_questions<=len(topics): - topics = topics[0:request.num_questions] - num_questions = 1 - print("num_questions :", num_questions) - else: - num_questions = math.ceil(request.num_questions/len(topics)) - #print(num_questions) - total_count = request.num_questions - else: - if request.topics: - topics = request.topics - num_questions = request.num_questions - total_count = request.num_questions*len(request.topics) - - else: - self.logger.error("Generation failed: No topics 
provided") - raise RuntimeError("Invalid input: No topics provided") - - - # Track results for each topic - results = {} - all_errors = [] - final_output = [] - - # Create thread pool - loop = asyncio.get_event_loop() - with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor: - topic_futures = [ - loop.run_in_executor( - executor, - self.process_single_topic, - topic, - model_handler, - request, - num_questions, - request_id - ) - for topic in topics - ] - - # Wait for all futures to complete - try: - completed_topics = await asyncio.gather(*topic_futures) - except ModelHandlerError as e: - self.logger.error(f"Model generation failed: {str(e)}") - raise APIError(f"Failed to generate content: {str(e)}") - - # Process results - - for topic, topic_results, topic_errors, topic_output in completed_topics: - if topic_errors: - all_errors.extend(topic_errors) - if topic_results and is_demo: - results[topic] = topic_results - if topic_output: - final_output.extend(topic_output) - - generation_time = time.time() - st - self.logger.info(f"Generation completed in {generation_time:.2f} seconds") - - timestamp = datetime.now(timezone.utc).isoformat() - time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] - mode_suffix = "test" if is_demo else "final" - model_name = get_model_family(request.model_id).split('.')[-1] - file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" - if request.doc_paths: - final_output = [{ - 'Generated_From': item['Topic'], - output_key: item['question'], - output_value: item['solution'] } - for item in final_output] - else: - final_output = [{ - 'Seeds': item['Topic'], - output_key: item['question'], - output_value: item['solution'] } - for item in final_output] - output_path = {} - try: - with open(file_path, "w") as f: - json.dump(final_output, indent=2, fp=f) - except Exception as e: - self.logger.error(f"Error saving results: {str(e)}", exc_info=True) - - output_path['local']= file_path - - - - - # Handle custom prompt, examples and schema - custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) - # For examples - examples_value = ( - PromptHandler.get_default_example(request.use_case, request.examples) - if hasattr(request, 'examples') - else None - ) - examples_str = self.safe_json_dumps(examples_value) - - # For schema - schema_value = ( - PromptHandler.get_default_schema(request.use_case, request.schema) - if hasattr(request, 'schema') - else None - ) - schema_str = self.safe_json_dumps(schema_value) - - # For topics - topics_value = topics if not getattr(request, 'doc_paths', None) else None - topic_str = self.safe_json_dumps(topics_value) - - # For doc_paths and input_path - doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) - input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) - - metadata = { - 'timestamp': timestamp, - 'technique': request.technique, - 'model_id': request.model_id, - 'inference_type': request.inference_type, - 'caii_endpoint':request.caii_endpoint, - 'use_case': request.use_case, - 'final_prompt': custom_prompt_str, - 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, - 'generate_file_name': os.path.basename(output_path['local']), - 'display_name': request.display_name, - 'output_path': output_path, - 'num_questions':getattr(request, 'num_questions', None), - 'topics': topic_str, - 'examples': examples_str, - "total_count":total_count, - 'schema': schema_str, - 
'doc_paths': doc_paths_str, - 'input_path':input_path_str, - 'input_key': request.input_key, - 'output_key':request.output_key, - 'output_value':request.output_value - } - - #print("metadata: ",metadata) - if is_demo: - - self.db.save_generation_metadata(metadata) - return { - "status": "completed" if results else "failed", - "results": results, - "errors": all_errors if all_errors else None, - "export_path": output_path - } - else: - # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) - # time_stamp = extract_timestamp(metadata.get('generate_file_name')) - job_status = "ENGINE_SUCCEEDED" - generate_file_name = os.path.basename(output_path['local']) - - self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) - self.db.backup_and_restore_db() - return { - "status": "completed" if final_output else "failed", - "export_path": output_path - } - except APIError: - raise - - except Exception as e: - self.logger.error(f"Generation failed: {str(e)}", exc_info=True) - if is_demo: - raise APIError(str(e)) # Let middleware decide status code - else: - time_stamp = datetime.now(timezone.utc).isoformat() - job_status = "ENGINE_FAILED" - file_name = '' - output_path = '' - self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) - raise # Just re-raise the original exception - - - def _validate_qa_pair(self, pair: Dict) -> bool: - """Validate a question-answer pair""" - return ( - isinstance(pair, dict) and - "question" in pair and - "solution" in pair and - isinstance(pair["question"], str) and - isinstance(pair["solution"], str) and - len(pair["question"].strip()) > 0 and - len(pair["solution"].strip()) > 0 - ) - #@track_llm_operation("process_single_input") - async def process_single_input(self, input, model_handler, request, request_id=None): - try: - prompt = PromptBuilder.build_generate_result_prompt( - model_id=request.model_id, - use_case=request.use_case, - input=input, - examples=request.examples or [], - schema=request.schema, - custom_prompt=request.custom_prompt, - ) - try: - result = model_handler.generate_response(prompt, request_id=request_id) - except ModelHandlerError as e: - self.logger.error(f"ModelHandlerError in generate_response: {str(e)}") - raise - - return {"question": input, "solution": result} - - except ModelHandlerError: - raise - except Exception as e: - self.logger.error(f"Error processing input: {str(e)}") - raise APIError(f"Failed to process input: {str(e)}") - - async def generate_result(self, request: SynthesisRequest , job_name = None, is_demo: bool = True, request_id=None) -> Dict: - try: - self.logger.info(f"Starting example generation - Demo Mode: {is_demo}") - - - # Use default parameters if none provided - model_params = request.model_params or ModelParameters() - - # Create model handler - self.logger.info("Creating model handler") - model_handler = create_handler(request.model_id, self.bedrock_client, model_params = model_params, inference_type = request.inference_type, caii_endpoint = request.caii_endpoint, custom_p = True) - - inputs = [] - file_paths = request.input_path - for path in file_paths: - try: - with open(path) as f: - data = json.load(f) - inputs.extend(item.get(request.input_key, '') for item in data) - except Exception as e: - print(f"Error processing {path}: {str(e)}") - MAX_WORKERS = 5 - - - # Create thread pool - loop = asyncio.get_event_loop() - with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: - # Create futures for each input - 
input_futures = [ - loop.run_in_executor( - executor, - lambda x: asyncio.run(self.process_single_input(x, model_handler, request, request_id)), - input - ) - for input in inputs - ] - - # Wait for all futures to complete - try: - final_output = await asyncio.gather(*input_futures) - except ModelHandlerError as e: - self.logger.error(f"Model generation failed: {str(e)}") - raise APIError(f"Failed to generate content: {str(e)}") - - - - - timestamp = datetime.now(timezone.utc).isoformat() - time_file = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%f')[:-3] - mode_suffix = "test" if is_demo else "final" - model_name = get_model_family(request.model_id).split('.')[-1] - file_path = f"qa_pairs_{model_name}_{time_file}_{mode_suffix}.json" - input_key = request.output_key or request.input_key - result = [{ - - input_key: item['question'], - request.output_value: item['solution'] } - for item in final_output] - output_path = {} - try: - with open(file_path, "w") as f: - json.dump(result, indent=2, fp=f) - except Exception as e: - self.logger.error(f"Error saving results: {str(e)}", exc_info=True) - - - - - - output_path['local']= file_path - - - # Handle custom prompt, examples and schema - custom_prompt_str = PromptHandler.get_default_custom_prompt(request.use_case, request.custom_prompt) - # For examples - examples_value = ( - PromptHandler.get_default_example(request.use_case, request.examples) - if hasattr(request, 'examples') - else None - ) - examples_str = self.safe_json_dumps(examples_value) - - # For schema - schema_value = ( - PromptHandler.get_default_schema(request.use_case, request.schema) - if hasattr(request, 'schema') - else None - ) - schema_str = self.safe_json_dumps(schema_value) - - # For topics - topics_value = None - topic_str = self.safe_json_dumps(topics_value) - - # For doc_paths and input_path - doc_paths_str = self.safe_json_dumps(getattr(request, 'doc_paths', None)) - input_path_str = self.safe_json_dumps(getattr(request, 'input_path', None)) - - - - metadata = { - 'timestamp': timestamp, - 'technique': request.technique, - 'model_id': request.model_id, - 'inference_type': request.inference_type, - 'caii_endpoint':request.caii_endpoint, - 'use_case': request.use_case, - 'final_prompt': custom_prompt_str, - 'model_parameters': json.dumps(model_params.model_dump()) if model_params else None, - 'generate_file_name': os.path.basename(output_path['local']), - 'display_name': request.display_name, - 'output_path': output_path, - 'num_questions':getattr(request, 'num_questions', None), - 'topics': topic_str, - 'examples': examples_str, - "total_count":len(inputs), - 'schema': schema_str, - 'doc_paths': doc_paths_str, - 'input_path':input_path_str, - 'input_key': request.input_key, - 'output_key':request.output_key, - 'output_value':request.output_value - } - - - if is_demo: - - self.db.save_generation_metadata(metadata) - return { - "status": "completed" if final_output else "failed", - "results": final_output, - "export_path": output_path - } - else: - # extract_timestamp = lambda filename: '_'.join(filename.split('_')[-3:-1]) - # time_stamp = extract_timestamp(metadata.get('generate_file_name')) - job_status = "success" - generate_file_name = os.path.basename(output_path['local']) - - self.db.update_job_generate(job_name,generate_file_name, output_path['local'], timestamp, job_status) - self.db.backup_and_restore_db() - return { - "status": "completed" if final_output else "failed", - "export_path": output_path - } - - except APIError: - raise - except Exception as e: 
- self.logger.error(f"Generation failed: {str(e)}", exc_info=True) - if is_demo: - raise APIError(str(e)) # Let middleware decide status code - else: - time_stamp = datetime.now(timezone.utc).isoformat() - job_status = "failure" - file_name = '' - output_path = '' - self.db.update_job_generate(job_name, file_name, output_path, time_stamp, job_status) - raise # Just re-raise the original exception - #@track_llm_operation("process_single_freeform") def process_single_freeform(self, topic: str, model_handler: any, request: SynthesisRequest, num_questions: int, request_id=None) -> Tuple[str, List[Dict], List[str], List[Dict]]: """ @@ -1209,4 +617,4 @@ def get_health_check(self) -> Dict: def safe_json_dumps(self, value): """Convert value to JSON string only if it's not None""" - return json.dumps(value) if value is not None else None \ No newline at end of file + return json.dumps(value) if value is not None else None diff --git a/tests/integration/test_evaluate_api.py b/tests/integration/test_evaluate_api.py index a842cd55..2ff664a9 100644 --- a/tests/integration/test_evaluate_api.py +++ b/tests/integration/test_evaluate_api.py @@ -3,7 +3,7 @@ import json from fastapi.testclient import TestClient from pathlib import Path -from app.main import app, db_manager, evaluator_service # global instance created on import +from app.main import app, db_manager, evaluator_legacy_service # global instance created on import client = TestClient(app) # Create a dummy bedrock client that simulates the Converse/invoke_model response. @@ -37,8 +37,8 @@ def mock_qa_file(tmp_path, mock_qa_data): # Patch the global evaluator_service's AWS client before tests run. @pytest.fixture(autouse=True) def patch_evaluator_bedrock_client(): - from app.main import evaluator_service - evaluator_service.bedrock_client = DummyBedrockClient() + from app.main import evaluator_legacy_service + evaluator_legacy_service.bedrock_client = DummyBedrockClient() def test_evaluate_endpoint(mock_qa_file): request_data = { @@ -51,7 +51,7 @@ def test_evaluate_endpoint(mock_qa_file): "output_value": "Completion" } # Optionally, patch create_handler to return a dummy handler that returns a dummy evaluation. - with patch('app.services.evaluator_service.create_handler') as mock_handler: + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] response = client.post("/synthesis/evaluate", json=request_data) # In demo mode, our endpoint returns a dict with "status", "result", and "output_path". 
@@ -71,7 +71,7 @@ def test_job_handling(mock_qa_file): "output_key": "Prompt", "output_value": "Completion" } - with patch('app.services.evaluator_service.create_handler') as mock_handler: + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] response = client.post("/synthesis/evaluate", json=request_data) assert response.status_code == 200 diff --git a/tests/integration/test_evaluate_legacy_api.py b/tests/integration/test_evaluate_legacy_api.py new file mode 100644 index 00000000..2ff664a9 --- /dev/null +++ b/tests/integration/test_evaluate_legacy_api.py @@ -0,0 +1,113 @@ +import pytest +from unittest.mock import patch, Mock +import json +from fastapi.testclient import TestClient +from pathlib import Path +from app.main import app, db_manager, evaluator_legacy_service # global instance created on import +client = TestClient(app) + +# Create a dummy bedrock client that simulates the Converse/invoke_model response. +class DummyBedrockClient: + def invoke_model(self, modelId, body): + # Return a dummy response structure (adjust if your handler expects a different format) + return [{ + "score": 1.0, + "justification": "Dummy response from invoke_model" + }] + @property + def meta(self): + class Meta: + region_name = "us-west-2" + return Meta() + +@pytest.fixture +def mock_qa_data(): + return [ + {"Seeds": "python_basics", "Prompt": "What is Python?", "Completion": "Python is a programming language"}, + {"Seeds": "python_basics", "Prompt": "How do you define a function?", "Completion": "Use the def keyword followed by function name"} + ] + +@pytest.fixture +def mock_qa_file(tmp_path, mock_qa_data): + file_path = tmp_path / "qa_pairs.json" + with open(file_path, "w") as f: + json.dump(mock_qa_data, f) + return str(file_path) + +# Patch the global evaluator_service's AWS client before tests run. +@pytest.fixture(autouse=True) +def patch_evaluator_bedrock_client(): + from app.main import evaluator_legacy_service + evaluator_legacy_service.bedrock_client = DummyBedrockClient() + +def test_evaluate_endpoint(mock_qa_file): + request_data = { + "use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "import_path": mock_qa_file, + "is_demo": True, + "output_key": "Prompt", + "output_value": "Completion" + } + # Optionally, patch create_handler to return a dummy handler that returns a dummy evaluation. + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] + response = client.post("/synthesis/evaluate", json=request_data) + # In demo mode, our endpoint returns a dict with "status", "result", and "output_path". 
+ assert response.status_code == 200 + res_json = response.json() + assert res_json["status"] == "completed" + assert "output_path" in res_json + assert "result" in res_json + +def test_job_handling(mock_qa_file): + request_data = { + "use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "import_path": mock_qa_file, + "is_demo": True, + "output_key": "Prompt", + "output_value": "Completion" + } + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 1.0, "justification": "Dummy evaluation"}] + response = client.post("/synthesis/evaluate", json=request_data) + assert response.status_code == 200 + res_json = response.json() + # In demo mode, we don't expect a "job_id" key; we check for "output_path" and "result". + assert "output_path" in res_json + # Now simulate history by patching db_manager.get_all_evaluate_metadata + db_manager.get_all_evaluate_metadata = lambda: [{"evaluate_file_name": "test.json", "average_score": 0.9}] + response = client.get("/evaluations/history") + assert response.status_code == 200 + history = response.json() + assert len(history) > 0 + +def test_evaluate_with_invalid_model(mock_qa_file): + request_data = { + "use_case": "custom", + "model_id": "invalid.model", + "inference_type": "aws_bedrock", + "import_path": mock_qa_file, + "is_demo": True, + "output_key": "Prompt", + "output_value": "Completion" + } + + from app.core.exceptions import ModelHandlerError + + # Patch create_handler to raise ModelHandlerError + with patch('app.services.evaluator_service.create_handler') as mock_create: + mock_create.side_effect = ModelHandlerError("Invalid model identifier: invalid.model") + response = client.post("/synthesis/evaluate", json=request_data) + + # Print for debugging + print(f"Response status: {response.status_code}") + print(f"Response content: {response.json()}") + + # Expect a 400 or 500 error response + assert response.status_code in [400, 500] + res_json = response.json() + assert "error" in res_json diff --git a/tests/integration/test_synthesis_api.py b/tests/integration/test_synthesis_api.py index cbd60e9f..db54d9c7 100644 --- a/tests/integration/test_synthesis_api.py +++ b/tests/integration/test_synthesis_api.py @@ -14,7 +14,7 @@ def test_generate_endpoint_with_topics(): "topics": ["python_basics"], "is_demo": True } - with patch('app.main.SynthesisService.generate_examples') as mock_generate: + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: mock_generate.return_value = { "status": "completed", "export_path": {"local": "test.json"}, @@ -35,7 +35,7 @@ def test_generate_endpoint_with_doc_paths(): "doc_paths": ["test.pdf"], "is_demo": True } - with patch('app.main.SynthesisService.generate_examples') as mock_generate: + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: mock_generate.return_value = { "status": "completed", "export_path": {"local": "test.json"}, diff --git a/tests/integration/test_synthesis_legacy_api.py b/tests/integration/test_synthesis_legacy_api.py new file mode 100644 index 00000000..db54d9c7 --- /dev/null +++ b/tests/integration/test_synthesis_legacy_api.py @@ -0,0 +1,91 @@ +import pytest +from unittest.mock import patch +import json +from fastapi.testclient import TestClient +from app.main import app, db_manager +client = TestClient(app) + +def test_generate_endpoint_with_topics(): + request_data = { + 
"use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "num_questions": 2, + "topics": ["python_basics"], + "is_demo": True + } + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: + mock_generate.return_value = { + "status": "completed", + "export_path": {"local": "test.json"}, + "results": {"python_basics": [{"question": "test?", "solution": "test!"}]} + } + response = client.post("/synthesis/generate", json=request_data) + assert response.status_code == 200 + res_json = response.json() + assert res_json.get("status") == "completed" + assert "export_path" in res_json + +def test_generate_endpoint_with_doc_paths(): + request_data = { + "use_case": "custom", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "inference_type": "aws_bedrock", + "num_questions": 2, + "doc_paths": ["test.pdf"], + "is_demo": True + } + with patch('app.main.synthesis_legacy_service.generate_examples') as mock_generate: + mock_generate.return_value = { + "status": "completed", + "export_path": {"local": "test.json"}, + "results": {"chunk1": [{"question": "test?", "solution": "test!"}]} + } + response = client.post("/synthesis/generate", json=request_data) + assert response.status_code == 200 + res_json = response.json() + assert res_json.get("status") == "completed" + assert "export_path" in res_json + +def test_generation_history(): + # Patch db_manager.get_paginated_generate_metadata to return dummy metadata with pagination info + db_manager.get_paginated_generate_metadata_light = lambda page, page_size: ( + 1, # total_count + [{"generate_file_name": "qa_pairs_claude_20250210T170521148_test.json", + "timestamp": "2024-02-10T12:00:00", + "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0"}] + ) + + # Since get_pending_generate_job_ids might be called, we should patch it too + db_manager.get_pending_generate_job_ids = lambda: [] + + response = client.get("/generations/history?page=1&page_size=10") + assert response.status_code == 200 + res_json = response.json() + + # Check that the response contains the expected structure + assert "data" in res_json + assert "pagination" in res_json + + # Check pagination metadata + assert res_json["pagination"]["total"] == 1 + assert res_json["pagination"]["page"] == 1 + assert res_json["pagination"]["page_size"] == 10 + assert res_json["pagination"]["total_pages"] == 1 + + # Check data contents + assert len(res_json["data"]) > 0 + # Instead of expecting exactly "test.json", assert the filename contains "_test.json" + assert "_test.json" in res_json["data"][0]["generate_file_name"] + +def test_error_handling(): + request_data = { + "use_case": "custom", + "model_id": "invalid.model", + "is_demo": True + } + response = client.post("/synthesis/generate", json=request_data) + # Expect an error with status code in [400,503] and key "error" + assert response.status_code in [400, 503] + res_json = response.json() + assert "error" in res_json diff --git a/tests/unit/test_evaluator_freeform_service.py b/tests/unit/test_evaluator_freeform_service.py new file mode 100644 index 00000000..38c8691d --- /dev/null +++ b/tests/unit/test_evaluator_freeform_service.py @@ -0,0 +1,80 @@ +import pytest +from unittest.mock import patch, Mock +import json +from app.services.evaluator_service import EvaluatorService +from app.models.request_models import EvaluationRequest +from tests.mocks.mock_db import MockDatabaseManager + +@pytest.fixture +def mock_freeform_data(): + return 
[{"field1": "value1", "field2": "value2", "field3": "value3"}] + +@pytest.fixture +def mock_freeform_file(tmp_path, mock_freeform_data): + file_path = tmp_path / "freeform_data.json" + with open(file_path, "w") as f: + json.dump(mock_freeform_data, f) + return str(file_path) + +@pytest.fixture +def evaluator_freeform_service(): + service = EvaluatorService() + service.db = MockDatabaseManager() + return service + +def test_evaluate_row_data(evaluator_freeform_service, mock_freeform_file): + request = EvaluationRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + use_case="custom", + import_path=mock_freeform_file, + is_demo=True, + output_key="field1", + output_value="field2" + ) + with patch('app.services.evaluator_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good freeform data"}] + result = evaluator_freeform_service.evaluate_row_data(request) + assert result["status"] == "completed" + assert "output_path" in result + assert len(evaluator_freeform_service.db.evaluation_metadata) == 1 + +def test_evaluate_single_row(evaluator_freeform_service): + with patch('app.services.evaluator_service.create_handler') as mock_handler: + mock_response = [{"score": 4, "justification": "Good freeform row"}] + mock_handler.return_value.generate_response.return_value = mock_response + + row = {"field1": "value1", "field2": "value2"} + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + is_demo=True, + output_key="field1", + output_value="field2" + ) + result = evaluator_freeform_service.evaluate_single_row(row, mock_handler.return_value, request) + assert result["evaluation"]["score"] == 4 + assert "justification" in result["evaluation"] + assert result["row"] == row + +def test_evaluate_rows(evaluator_freeform_service): + rows = [ + {"field1": "value1", "field2": "value2"}, + {"field1": "value3", "field2": "value4"} + ] + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + is_demo=True, + output_key="field1", + output_value="field2" + ) + + with patch('app.services.evaluator_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good row"}] + result = evaluator_freeform_service.evaluate_rows(rows, mock_handler.return_value, request) + + assert result["total_evaluated"] == 2 + assert result["average_score"] == 4 + assert len(result["evaluated_rows"]) == 2 diff --git a/tests/unit/test_evaluator_legacy_service.py b/tests/unit/test_evaluator_legacy_service.py new file mode 100644 index 00000000..c7dded81 --- /dev/null +++ b/tests/unit/test_evaluator_legacy_service.py @@ -0,0 +1,83 @@ +import pytest +from io import StringIO +from unittest.mock import patch +import json +from app.services.evaluator_legacy_service import EvaluatorLegacyService +from app.models.request_models import EvaluationRequest +from tests.mocks.mock_db import MockDatabaseManager +from app.core.exceptions import ModelHandlerError, APIError + +@pytest.fixture +def mock_qa_data(): + return [{"question": "test question?", "solution": "test solution"}] + +@pytest.fixture +def mock_qa_file(tmp_path, mock_qa_data): + file_path = tmp_path / "qa_pairs.json" + with open(file_path, "w") as f: + json.dump(mock_qa_data, f) + return str(file_path) + +@pytest.fixture +def evaluator_service(): + service = EvaluatorLegacyService() + service.db = 
MockDatabaseManager() + return service + +def test_evaluate_results(evaluator_service, mock_qa_file): + request = EvaluationRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + use_case="custom", + import_path=mock_qa_file, + is_demo=True, + output_key="Prompt", + output_value="Completion" + ) + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good answer"}] + result = evaluator_service.evaluate_results(request) + assert result["status"] == "completed" + assert "output_path" in result + assert len(evaluator_service.db.evaluation_metadata) == 1 + +def test_evaluate_single_pair(): + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: + mock_response = [{"score": 4, "justification": "Good explanation"}] + mock_handler.return_value.generate_response.return_value = mock_response + service = EvaluatorLegacyService() + qa_pair = {"Prompt": "What is Python?", "Completion": "Python is a programming language"} + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + is_demo=True, + output_key="Prompt", + output_value="Completion" + ) + result = service.evaluate_single_pair(qa_pair, mock_handler.return_value, request) + assert result["evaluation"]["score"] == 4 + assert "justification" in result["evaluation"] + +def test_evaluate_results_with_error(): + fake_json = '[{"Seeds": "python_basics", "Prompt": "What is Python?", "Completion": "Python is a programming language"}]' + class DummyHandler: + def generate_response(self, prompt, **kwargs): # Accept any keyword arguments + raise ModelHandlerError("Test error") + with patch('app.services.evaluator_legacy_service.os.path.exists', return_value=True), \ + patch('builtins.open', new=lambda f, mode, *args, **kwargs: StringIO(fake_json)), \ + patch('app.services.evaluator_legacy_service.create_handler', return_value=DummyHandler()), \ + patch('app.services.evaluator_legacy_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"): + service = EvaluatorLegacyService() + request = EvaluationRequest( + use_case="custom", + model_id="test.model", + inference_type="aws_bedrock", + import_path="test.json", + is_demo=True, + output_key="Prompt", + output_value="Completion", + caii_endpoint="dummy_endpoint", + display_name="dummy" + ) + with pytest.raises(APIError, match="Test error"): + service.evaluate_results(request) diff --git a/tests/unit/test_evaluator_service.py b/tests/unit/test_evaluator_service.py index 629b80a3..c7dded81 100644 --- a/tests/unit/test_evaluator_service.py +++ b/tests/unit/test_evaluator_service.py @@ -2,7 +2,7 @@ from io import StringIO from unittest.mock import patch import json -from app.services.evaluator_service import EvaluatorService +from app.services.evaluator_legacy_service import EvaluatorLegacyService from app.models.request_models import EvaluationRequest from tests.mocks.mock_db import MockDatabaseManager from app.core.exceptions import ModelHandlerError, APIError @@ -20,7 +20,7 @@ def mock_qa_file(tmp_path, mock_qa_data): @pytest.fixture def evaluator_service(): - service = EvaluatorService() + service = EvaluatorLegacyService() service.db = MockDatabaseManager() return service @@ -33,7 +33,7 @@ def test_evaluate_results(evaluator_service, mock_qa_file): output_key="Prompt", output_value="Completion" ) - with patch('app.services.evaluator_service.create_handler') as mock_handler: + 
with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"score": 4, "justification": "Good answer"}] result = evaluator_service.evaluate_results(request) assert result["status"] == "completed" @@ -41,10 +41,10 @@ def test_evaluate_results(evaluator_service, mock_qa_file): assert len(evaluator_service.db.evaluation_metadata) == 1 def test_evaluate_single_pair(): - with patch('app.services.evaluator_service.create_handler') as mock_handler: + with patch('app.services.evaluator_legacy_service.create_handler') as mock_handler: mock_response = [{"score": 4, "justification": "Good explanation"}] mock_handler.return_value.generate_response.return_value = mock_response - service = EvaluatorService() + service = EvaluatorLegacyService() qa_pair = {"Prompt": "What is Python?", "Completion": "Python is a programming language"} request = EvaluationRequest( use_case="custom", @@ -63,11 +63,11 @@ def test_evaluate_results_with_error(): class DummyHandler: def generate_response(self, prompt, **kwargs): # Accept any keyword arguments raise ModelHandlerError("Test error") - with patch('app.services.evaluator_service.os.path.exists', return_value=True), \ + with patch('app.services.evaluator_legacy_service.os.path.exists', return_value=True), \ patch('builtins.open', new=lambda f, mode, *args, **kwargs: StringIO(fake_json)), \ - patch('app.services.evaluator_service.create_handler', return_value=DummyHandler()), \ - patch('app.services.evaluator_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"): - service = EvaluatorService() + patch('app.services.evaluator_legacy_service.create_handler', return_value=DummyHandler()), \ + patch('app.services.evaluator_legacy_service.PromptBuilder.build_eval_prompt', return_value="dummy prompt"): + service = EvaluatorLegacyService() request = EvaluationRequest( use_case="custom", model_id="test.model", diff --git a/tests/unit/test_synthesis_freeform_service.py b/tests/unit/test_synthesis_freeform_service.py new file mode 100644 index 00000000..f0bcb861 --- /dev/null +++ b/tests/unit/test_synthesis_freeform_service.py @@ -0,0 +1,70 @@ +import pytest +from unittest.mock import patch, Mock +import json +from app.services.synthesis_service import SynthesisService +from app.models.request_models import SynthesisRequest +from tests.mocks.mock_db import MockDatabaseManager + +@pytest.fixture +def mock_json_data(): + return [{"topic": "test_topic", "example_field": "test_value"}] + +@pytest.fixture +def mock_file(tmp_path, mock_json_data): + file_path = tmp_path / "test.json" + with open(file_path, "w") as f: + json.dump(mock_json_data, f) + return str(file_path) + +@pytest.fixture +def synthesis_freeform_service(): + service = SynthesisService() + service.db = MockDatabaseManager() + return service + +@pytest.mark.asyncio +async def test_generate_freeform_with_topics(synthesis_freeform_service): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=2, + topics=["test_topic"], + is_demo=True, + use_case="custom", + technique="freeform" + ) + with patch('app.services.synthesis_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"field1": "value1", "field2": "value2"}] + result = await synthesis_freeform_service.generate_freeform(request) + assert result["status"] == "completed" + assert len(synthesis_freeform_service.db.generation_metadata) == 1 + assert 
synthesis_freeform_service.db.generation_metadata[0]["model_id"] == request.model_id + +@pytest.mark.asyncio +async def test_generate_freeform_with_custom_examples(synthesis_freeform_service): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=1, + topics=["test_topic"], + is_demo=True, + use_case="custom", + technique="freeform", + example_custom=[{"example_field": "example_value"}] + ) + with patch('app.services.synthesis_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"generated_field": "generated_value"}] + result = await synthesis_freeform_service.generate_freeform(request) + assert result["status"] == "completed" + assert "export_path" in result + +def test_validate_freeform_item(synthesis_freeform_service): + # Valid freeform item + valid_item = {"field1": "value1", "field2": "value2"} + assert synthesis_freeform_service._validate_freeform_item(valid_item) == True + + # Invalid freeform item (empty dict) + invalid_item = {} + assert synthesis_freeform_service._validate_freeform_item(invalid_item) == False + + # Invalid freeform item (not a dict) + invalid_item = "not a dict" + assert synthesis_freeform_service._validate_freeform_item(invalid_item) == False diff --git a/tests/unit/test_synthesis_legacy_service.py b/tests/unit/test_synthesis_legacy_service.py new file mode 100644 index 00000000..120fad7c --- /dev/null +++ b/tests/unit/test_synthesis_legacy_service.py @@ -0,0 +1,56 @@ +import pytest +from unittest.mock import patch, Mock +import json +from app.services.synthesis_legacy_service import SynthesisLegacyService +from app.models.request_models import SynthesisRequest +from tests.mocks.mock_db import MockDatabaseManager + +@pytest.fixture +def mock_json_data(): + return [{"input": "test question?"}] + +@pytest.fixture +def mock_file(tmp_path, mock_json_data): + file_path = tmp_path / "test.json" + with open(file_path, "w") as f: + json.dump(mock_json_data, f) + return str(file_path) + +@pytest.fixture +def synthesis_service(): + service = SynthesisLegacyService() + service.db = MockDatabaseManager() + return service + +@pytest.mark.asyncio +async def test_generate_examples_with_topics(synthesis_service): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=1, + topics=["test_topic"], + is_demo=True, + use_case="custom" + ) + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler: + mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] + result = await synthesis_service.generate_examples(request) + assert result["status"] == "completed" + assert len(synthesis_service.db.generation_metadata) == 1 + assert synthesis_service.db.generation_metadata[0]["model_id"] == request.model_id + +@pytest.mark.asyncio +async def test_generate_examples_with_doc_paths(synthesis_service, mock_file): + request = SynthesisRequest( + model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0", + num_questions=1, + doc_paths=[mock_file], + is_demo=True, + use_case="custom" + ) + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler, \ + patch('app.services.synthesis_legacy_service.DocumentProcessor') as mock_processor: + mock_processor.return_value.process_document.return_value = ["chunk1"] + mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] + result = await 
synthesis_service.generate_examples(request) + assert result["status"] == "completed" + assert len(synthesis_service.db.generation_metadata) == 1 diff --git a/tests/unit/test_synthesis_service.py b/tests/unit/test_synthesis_service.py index 6e9ca410..120fad7c 100644 --- a/tests/unit/test_synthesis_service.py +++ b/tests/unit/test_synthesis_service.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, Mock import json -from app.services.synthesis_service import SynthesisService +from app.services.synthesis_legacy_service import SynthesisLegacyService from app.models.request_models import SynthesisRequest from tests.mocks.mock_db import MockDatabaseManager @@ -18,7 +18,7 @@ def mock_file(tmp_path, mock_json_data): @pytest.fixture def synthesis_service(): - service = SynthesisService() + service = SynthesisLegacyService() service.db = MockDatabaseManager() return service @@ -31,7 +31,7 @@ async def test_generate_examples_with_topics(synthesis_service): is_demo=True, use_case="custom" ) - with patch('app.services.synthesis_service.create_handler') as mock_handler: + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler: mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] result = await synthesis_service.generate_examples(request) assert result["status"] == "completed" @@ -47,8 +47,8 @@ async def test_generate_examples_with_doc_paths(synthesis_service, mock_file): is_demo=True, use_case="custom" ) - with patch('app.services.synthesis_service.create_handler') as mock_handler, \ - patch('app.services.synthesis_service.DocumentProcessor') as mock_processor: + with patch('app.services.synthesis_legacy_service.create_handler') as mock_handler, \ + patch('app.services.synthesis_legacy_service.DocumentProcessor') as mock_processor: mock_processor.return_value.process_document.return_value = ["chunk1"] mock_handler.return_value.generate_response.return_value = [{"question": "test?", "solution": "test!"}] result = await synthesis_service.generate_examples(request) From 06537b4be68ffab1c6dfd6d018167984560f8d61 Mon Sep 17 00:00:00 2001 From: Khauneesh Saigal Date: Wed, 17 Sep 2025 01:12:21 +0530 Subject: [PATCH 04/12] Add configurable concurrency parameters to synthesis and evaluation services - Added max_concurrent_topics field to SynthesisRequest (1-100, default: 5) - Added max_workers field to EvaluationRequest (1-100, default: 4) - Updated all synthesis services to use request.max_concurrent_topics - Updated all evaluator services to use request.max_workers - Added validation constraints to prevent invalid values - Updated example payloads in main.py to include new parameters - Services now respect API-configurable concurrency limits while maintaining defaults This allows users to optimize performance based on their infrastructure and workload requirements via API parameters. 
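Illustrative usage of the new fields (a sketch only: it reuses the payload shapes from the existing integration tests and assumes a FastAPI TestClient named `client`; all concrete values are placeholders):

    # Synthesis request carrying the new max_concurrent_topics field
    request_data = {
        "use_case": "custom",
        "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
        "inference_type": "aws_bedrock",
        "num_questions": 2,
        "topics": ["python_basics"],
        "is_demo": True,
        "max_concurrent_topics": 5,  # 1-100; service falls back to its default (5) when omitted
    }
    client.post("/synthesis/generate", json=request_data)

    # Evaluation request carrying the new max_workers field
    eval_data = {
        "use_case": "custom",
        "model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
        "inference_type": "aws_bedrock",
        "import_path": "qa_pairs_claude_20250210T170521148_test.json",
        "is_demo": True,
        "output_key": "Prompt",
        "output_value": "Completion",
        "max_workers": 4,  # 1-100; service falls back to its default when omitted
    }
    client.post("/synthesis/evaluate", json=eval_data)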
--- app/main.py | 3 +++ app/models/request_models.py | 19 ++++++++++++++++--- app/services/evaluator_legacy_service.py | 10 ++++++---- app/services/evaluator_service.py | 7 ++++--- app/services/synthesis_legacy_service.py | 5 +++-- app/services/synthesis_service.py | 5 +++-- 6 files changed, 35 insertions(+), 14 deletions(-) diff --git a/app/main.py b/app/main.py index e82bd296..2becafb8 100644 --- a/app/main.py +++ b/app/main.py @@ -1262,6 +1262,7 @@ async def get_example_payloads(use_case:UseCase): "technique": "sft", "topics": ["python_basics", "data_structures"], "is_demo": True, + "max_concurrent_topics": 5, "examples": [ { "question": "How do you create a list in Python and add elements to it?", @@ -1288,6 +1289,7 @@ async def get_example_payloads(use_case:UseCase): "technique": "sft", "topics": ["basic_queries", "joins"], "is_demo": True, + "max_concurrent_topics": 5, "schema": "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100), email VARCHAR(255));\nCREATE TABLE orders (id INT PRIMARY KEY, user_id INT, amount DECIMAL(10,2), FOREIGN KEY (user_id) REFERENCES users(id));", "examples":[ { @@ -1316,6 +1318,7 @@ async def get_example_payloads(use_case:UseCase): "topics": ["topic 1", "topic 2"], "custom_prompt": "Give your instructions here", "is_demo": True, + "max_concurrent_topics": 5, "examples":[ { diff --git a/app/models/request_models.py b/app/models/request_models.py index 7cf9e61d..bf22343f 100644 --- a/app/models/request_models.py +++ b/app/models/request_models.py @@ -138,7 +138,13 @@ class SynthesisRequest(BaseModel): example_path: Optional[str] = None schema: Optional[str] = None # Added schema field custom_prompt: Optional[str] = None - display_name: Optional[str] = None + display_name: Optional[str] = None + max_concurrent_topics: Optional[int] = Field( + default=5, + ge=1, + le=100, + description="Maximum number of concurrent topics to process (1-100)" + ) # Optional model parameters with defaults model_params: Optional[ModelParameters] = Field( @@ -156,7 +162,7 @@ class SynthesisRequest(BaseModel): "technique": "sft", "topics": ["python_basics", "data_structures"], "is_demo": True, - + "max_concurrent_topics": 5 } } @@ -209,6 +215,12 @@ class EvaluationRequest(BaseModel): display_name: Optional[str] = None output_key: Optional[str] = 'Prompt' output_value: Optional[str] = 'Completion' + max_workers: Optional[int] = Field( + default=4, + ge=1, + le=100, + description="Maximum number of worker threads for parallel evaluation (1-100)" + ) # Export configuration export_type: str = "local" # "local" or "s3" @@ -227,7 +239,8 @@ class EvaluationRequest(BaseModel): "inference_type": "aws_bedrock", "import_path": "qa_pairs_llama3-1-70b-instruct-v1:0_20241114_212837_test.json", "import_type": "local", - "export_type":"local" + "export_type":"local", + "max_workers": 4 } } diff --git a/app/services/evaluator_legacy_service.py b/app/services/evaluator_legacy_service.py index 9d0656a4..328c1f83 100644 --- a/app/services/evaluator_legacy_service.py +++ b/app/services/evaluator_legacy_service.py @@ -24,7 +24,7 @@ class EvaluatorLegacyService: def __init__(self, max_workers: int = 4): self.bedrock_client = get_bedrock_client() self.db = DatabaseManager() - self.max_workers = max_workers + self.max_workers = max_workers # Default max workers (configurable via request) self.guard = ContentGuardrail() self._setup_logging() @@ -155,7 +155,8 @@ def evaluate_topic(self, topic: str, qa_pairs: List[Dict], model_handler, reques failed_pairs = [] try: - with 
ThreadPoolExecutor(max_workers=self.max_workers) as executor: + max_workers = request.max_workers or self.max_workers + with ThreadPoolExecutor(max_workers=max_workers) as executor: try: evaluate_func = partial( self.evaluate_single_pair, @@ -287,8 +288,9 @@ def evaluate_results(self, request: EvaluationRequest, job_name=None,is_demo: bo # Add to appropriate topic list transformed_data['results'][topic].append(qa_pair) - self.logger.info(f"Processing {len(transformed_data['results'])} topics with {self.max_workers} workers") - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + max_workers = request.max_workers or self.max_workers + self.logger.info(f"Processing {len(transformed_data['results'])} topics with {max_workers} workers") + with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_topic = { executor.submit( self.evaluate_topic, diff --git a/app/services/evaluator_service.py b/app/services/evaluator_service.py index a556f2ae..f3fbbcad 100644 --- a/app/services/evaluator_service.py +++ b/app/services/evaluator_service.py @@ -21,10 +21,10 @@ class EvaluatorService: """Service for evaluating freeform data rows using Claude with parallel processing (Freeform technique only)""" - def __init__(self, max_workers: int = 4): + def __init__(self, max_workers: int = 5): self.bedrock_client = get_bedrock_client() self.db = DatabaseManager() - self.max_workers = max_workers + self.max_workers = max_workers # Default max workers (configurable via request) self.guard = ContentGuardrail() self._setup_logging() @@ -143,7 +143,8 @@ def evaluate_rows(self, rows: List[Dict[str, Any]], model_handler, request: Eval failed_rows = [] try: - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + max_workers = request.max_workers or self.max_workers + with ThreadPoolExecutor(max_workers=max_workers) as executor: try: evaluate_func = partial( self.evaluate_single_row, diff --git a/app/services/synthesis_legacy_service.py b/app/services/synthesis_legacy_service.py index 7f10927a..62df587c 100644 --- a/app/services/synthesis_legacy_service.py +++ b/app/services/synthesis_legacy_service.py @@ -36,7 +36,7 @@ class SynthesisLegacyService: """Legacy service for generating synthetic QA pairs (SFT and Custom_Workflow only)""" QUESTIONS_PER_BATCH = 5 # Maximum questions per batch - MAX_CONCURRENT_TOPICS = 5 # Limit concurrent I/O operations + MAX_CONCURRENT_TOPICS = 5 # Default limit for concurrent I/O operations (configurable via request) def __init__(self): @@ -313,7 +313,8 @@ async def generate_examples(self, request: SynthesisRequest , job_name = None, i # Create thread pool loop = asyncio.get_event_loop() - with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor: + max_workers = request.max_concurrent_topics or self.MAX_CONCURRENT_TOPICS + with ThreadPoolExecutor(max_workers=max_workers) as executor: topic_futures = [ loop.run_in_executor( executor, diff --git a/app/services/synthesis_service.py b/app/services/synthesis_service.py index 413de18f..6f8e643e 100644 --- a/app/services/synthesis_service.py +++ b/app/services/synthesis_service.py @@ -36,7 +36,7 @@ class SynthesisService: """Service for generating synthetic freeform data (Freeform technique only)""" QUESTIONS_PER_BATCH = 5 # Maximum questions per batch - MAX_CONCURRENT_TOPICS = 5 # Limit concurrent I/O operations + MAX_CONCURRENT_TOPICS = 5 # Default limit for concurrent I/O operations (configurable via request) def __init__(self): @@ -368,7 +368,8 @@ async def generate_freeform(self, 
request: SynthesisRequest, job_name=None, is_d # Create thread pool loop = asyncio.get_event_loop() - with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_TOPICS) as executor: + max_workers = request.max_concurrent_topics or self.MAX_CONCURRENT_TOPICS + with ThreadPoolExecutor(max_workers=max_workers) as executor: topic_futures = [ loop.run_in_executor( executor, From 09431f34fe738018a22acfe6fe0482e5b88b6028 Mon Sep 17 00:00:00 2001 From: Khauneesh Saigal Date: Wed, 24 Sep 2025 11:33:02 +0530 Subject: [PATCH 05/12] Add custom model endpoint management system - Add CustomEndpointManager for CRUD operations on custom endpoints - Support for all providers: CAII, Bedrock, OpenAI, OpenAI Compatible, Gemini - Complete REST API for endpoint management (POST/GET/PUT/DELETE) - Integration with existing model handlers for automatic credential lookup - Move custom_endpoint_manager to app/core/ for better organization - Add OpenAI_Endpoint_Compatible_Key to environment variables - Fix ImportError: replace Example_eval with EvaluationExample - Restore missing fields: max_concurrent_topics, max_workers, Example_eval, etc. - Add comprehensive API documentation and curl examples --- .project-metadata.yaml | 4 + app/core/custom_endpoint_manager.py | 256 ++++++++++++++++++++++++++++ app/core/model_endpoints.py | 55 +++++- app/core/model_handlers.py | 99 +++++++++-- app/core/prompt_templates.py | 12 +- app/main.py | 154 ++++++++++++++++- app/models/request_models.py | 89 +++++++++- custom_model_endpoints.json | 17 ++ 8 files changed, 667 insertions(+), 19 deletions(-) create mode 100644 app/core/custom_endpoint_manager.py create mode 100644 custom_model_endpoints.json diff --git a/.project-metadata.yaml b/.project-metadata.yaml index 616b8e94..88c78a4d 100644 --- a/.project-metadata.yaml +++ b/.project-metadata.yaml @@ -42,6 +42,10 @@ environment_variables: default: null description: >- Gemini API Key. Check the Google Gemini documentation for information about role access + OpenAI_Endpoint_Compatible_Key: + default: null + description: >- + API Key for OpenAI Compatible endpoints. Used for custom OpenAI-compatible model endpoints. 
# runtimes runtimes: - editor: JupyterLab diff --git a/app/core/custom_endpoint_manager.py b/app/core/custom_endpoint_manager.py new file mode 100644 index 00000000..c6206722 --- /dev/null +++ b/app/core/custom_endpoint_manager.py @@ -0,0 +1,256 @@ +import json +import os +import uuid +from datetime import datetime, timezone +from typing import List, Dict, Optional, Any +from pathlib import Path + +from app.models.request_models import ( + CustomEndpoint, CustomCAIIEndpoint, CustomBedrockEndpoint, + CustomOpenAIEndpoint, CustomOpenAICompatibleEndpoint, CustomGeminiEndpoint +) +from app.core.exceptions import APIError + + +class CustomEndpointManager: + """Manager for custom model endpoints""" + + def __init__(self, config_file: str = "custom_model_endpoints.json"): + """ + Initialize the custom endpoint manager + + Args: + config_file: Path to the JSON file storing custom endpoints + """ + self.config_file = Path(config_file) + self._ensure_config_file_exists() + + def _ensure_config_file_exists(self): + """Ensure the configuration file exists with proper structure""" + if not self.config_file.exists(): + initial_config = { + "version": "1.0", + "endpoints": {}, + "created_at": datetime.now(timezone.utc).isoformat(), + "last_updated": datetime.now(timezone.utc).isoformat() + } + with open(self.config_file, 'w') as f: + json.dump(initial_config, f, indent=2) + + def _load_config(self) -> Dict[str, Any]: + """Load configuration from file""" + try: + with open(self.config_file, 'r') as f: + return json.load(f) + except (json.JSONDecodeError, FileNotFoundError) as e: + raise APIError(f"Failed to load custom endpoints configuration: {str(e)}", 500) + + def _save_config(self, config: Dict[str, Any]): + """Save configuration to file""" + try: + config["last_updated"] = datetime.now(timezone.utc).isoformat() + with open(self.config_file, 'w') as f: + json.dump(config, f, indent=2) + except Exception as e: + raise APIError(f"Failed to save custom endpoints configuration: {str(e)}", 500) + + def add_endpoint(self, endpoint: CustomEndpoint) -> str: + """ + Add a new custom endpoint + + Args: + endpoint: Custom endpoint configuration + + Returns: + endpoint_id: The ID of the added endpoint + + Raises: + APIError: If endpoint already exists or validation fails + """ + config = self._load_config() + + # Check if endpoint_id already exists + if endpoint.endpoint_id in config["endpoints"]: + raise APIError(f"Endpoint with ID '{endpoint.endpoint_id}' already exists", 400) + + # Add timestamps + now = datetime.now(timezone.utc).isoformat() + endpoint.created_at = now + endpoint.updated_at = now + + # Store endpoint configuration + config["endpoints"][endpoint.endpoint_id] = endpoint.model_dump() + + self._save_config(config) + return endpoint.endpoint_id + + def get_endpoint(self, endpoint_id: str) -> Optional[CustomEndpoint]: + """ + Get a specific custom endpoint by ID + + Args: + endpoint_id: The endpoint ID to retrieve + + Returns: + CustomEndpoint or None if not found + """ + config = self._load_config() + endpoint_data = config["endpoints"].get(endpoint_id) + + if not endpoint_data: + return None + + return self._parse_endpoint(endpoint_data) + + def get_all_endpoints(self) -> List[CustomEndpoint]: + """ + Get all custom endpoints + + Returns: + List of all custom endpoints + """ + config = self._load_config() + endpoints = [] + + for endpoint_data in config["endpoints"].values(): + try: + endpoint = self._parse_endpoint(endpoint_data) + endpoints.append(endpoint) + except Exception as e: + 
print(f"Warning: Failed to parse endpoint {endpoint_data.get('endpoint_id', 'unknown')}: {e}") + continue + + return endpoints + + def get_endpoints_by_provider(self, provider_type: str) -> List[CustomEndpoint]: + """ + Get all endpoints for a specific provider + + Args: + provider_type: The provider type to filter by + + Returns: + List of endpoints for the specified provider + """ + all_endpoints = self.get_all_endpoints() + return [ep for ep in all_endpoints if ep.provider_type == provider_type] + + def update_endpoint(self, endpoint_id: str, updated_endpoint: CustomEndpoint) -> bool: + """ + Update an existing custom endpoint + + Args: + endpoint_id: The endpoint ID to update + updated_endpoint: Updated endpoint configuration + + Returns: + True if updated successfully, False if endpoint not found + + Raises: + APIError: If validation fails + """ + config = self._load_config() + + if endpoint_id not in config["endpoints"]: + return False + + # Preserve original created_at timestamp + original_created_at = config["endpoints"][endpoint_id].get("created_at") + + # Update timestamps + updated_endpoint.created_at = original_created_at + updated_endpoint.updated_at = datetime.now(timezone.utc).isoformat() + updated_endpoint.endpoint_id = endpoint_id # Ensure ID consistency + + # Update endpoint configuration + config["endpoints"][endpoint_id] = updated_endpoint.model_dump() + + self._save_config(config) + return True + + def delete_endpoint(self, endpoint_id: str) -> bool: + """ + Delete a custom endpoint + + Args: + endpoint_id: The endpoint ID to delete + + Returns: + True if deleted successfully, False if endpoint not found + """ + config = self._load_config() + + if endpoint_id not in config["endpoints"]: + return False + + del config["endpoints"][endpoint_id] + self._save_config(config) + return True + + def _parse_endpoint(self, endpoint_data: Dict[str, Any]) -> CustomEndpoint: + """ + Parse endpoint data into appropriate CustomEndpoint subclass + + Args: + endpoint_data: Raw endpoint data from config + + Returns: + Parsed CustomEndpoint instance + + Raises: + APIError: If parsing fails + """ + provider_type = endpoint_data.get("provider_type") + + try: + if provider_type == "caii": + return CustomCAIIEndpoint(**endpoint_data) + elif provider_type == "bedrock": + return CustomBedrockEndpoint(**endpoint_data) + elif provider_type == "openai": + return CustomOpenAIEndpoint(**endpoint_data) + elif provider_type == "openai_compatible": + return CustomOpenAICompatibleEndpoint(**endpoint_data) + elif provider_type == "gemini": + return CustomGeminiEndpoint(**endpoint_data) + else: + raise APIError(f"Unknown provider type: {provider_type}", 400) + except Exception as e: + raise APIError(f"Failed to parse endpoint configuration: {str(e)}", 500) + + def validate_endpoint_id(self, endpoint_id: str) -> bool: + """ + Validate endpoint ID format + + Args: + endpoint_id: The endpoint ID to validate + + Returns: + True if valid, False otherwise + """ + if not endpoint_id or not isinstance(endpoint_id, str): + return False + + # Allow alphanumeric, hyphens, and underscores + import re + return bool(re.match(r'^[a-zA-Z0-9_-]+$', endpoint_id)) + + def get_endpoint_stats(self) -> Dict[str, Any]: + """ + Get statistics about custom endpoints + + Returns: + Dictionary with endpoint statistics + """ + endpoints = self.get_all_endpoints() + + provider_counts = {} + for endpoint in endpoints: + provider_type = endpoint.provider_type + provider_counts[provider_type] = provider_counts.get(provider_type, 0) + 1 
+ + return { + "total_endpoints": len(endpoints), + "provider_counts": provider_counts, + "endpoint_ids": [ep.endpoint_id for ep in endpoints] + } diff --git a/app/core/model_endpoints.py b/app/core/model_endpoints.py index 8d652f34..c9e0a091 100644 --- a/app/core/model_endpoints.py +++ b/app/core/model_endpoints.py @@ -324,7 +324,10 @@ async def bound(p: _CaiiPair): # Single orchestrator used by the api endpoint # ──────────────────────────────────────────────────────────────── async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]: - """Collect and health-check models from all providers.""" + """Collect and health-check models from all providers, including custom endpoints.""" + + # Import here to avoid circular imports + from app.core.custom_endpoint_manager import CustomEndpointManager # Bedrock bedrock_all = list_bedrock_models() @@ -350,6 +353,10 @@ async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]: "google_gemini": { "enabled": gemini_enabled, "disabled": gemini_disabled, + }, + "openai_compatible": { + "enabled": [], + "disabled": [], } } @@ -364,5 +371,51 @@ async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]: else: catalog["CAII"] = {} + # Add custom endpoints + try: + custom_manager = CustomEndpointManager() + custom_endpoints = custom_manager.get_all_endpoints() + + for endpoint in custom_endpoints: + provider_key = _get_catalog_key_for_provider(endpoint.provider_type) + + if provider_key not in catalog: + catalog[provider_key] = {"enabled": [], "disabled": []} + + # For now, assume custom endpoints are enabled (we could add health checks later) + if endpoint.provider_type in ["caii"]: + # CAII format: {"model": name, "endpoint": url} + catalog[provider_key]["enabled"].append({ + "model": endpoint.model_id, + "endpoint": getattr(endpoint, 'endpoint_url', ''), + "custom": True, + "endpoint_id": endpoint.endpoint_id, + "display_name": endpoint.display_name + }) + else: + # Other providers: just the model name with custom metadata + catalog[provider_key]["enabled"].append({ + "model": endpoint.model_id, + "custom": True, + "endpoint_id": endpoint.endpoint_id, + "display_name": endpoint.display_name, + "provider_type": endpoint.provider_type + }) + + except Exception as e: + print(f"Warning: Failed to load custom endpoints: {e}") + return catalog + +def _get_catalog_key_for_provider(provider_type: str) -> str: + """Map provider types to catalog keys""" + mapping = { + "bedrock": "aws_bedrock", + "openai": "openai", + "openai_compatible": "openai_compatible", + "gemini": "google_gemini", + "caii": "CAII" + } + return mapping.get(provider_type, provider_type) + diff --git a/app/core/model_handlers.py b/app/core/model_handlers.py index b61a0ec7..6514d514 100644 --- a/app/core/model_handlers.py +++ b/app/core/model_handlers.py @@ -18,6 +18,33 @@ import google.generativeai as genai +def get_custom_endpoint_config(model_id: str, provider_type: str): + """ + Get custom endpoint configuration for a model if it exists + + Args: + model_id: The model identifier + provider_type: The provider type + + Returns: + Custom endpoint configuration or None + """ + try: + from app.core.custom_endpoint_manager import CustomEndpointManager + + custom_manager = CustomEndpointManager() + custom_endpoints = custom_manager.get_endpoints_by_provider(provider_type) + + # Find endpoint matching the model_id + for endpoint in custom_endpoints: + if endpoint.model_id == model_id: + return endpoint + + return None + except Exception as e: + print(f"Warning: Failed 
to get custom endpoint config: {e}") + return None + class UnifiedModelHandler: """Unified handler for all model types using Bedrock's converse API""" @@ -310,6 +337,19 @@ def _handle_openai_request(self, prompt: str): import httpx from openai import OpenAI + # Check for custom endpoint configuration + custom_config = get_custom_endpoint_config(self.model_id, "openai") + + # Get API key from custom config or environment + if custom_config: + api_key = custom_config.api_key + print(f"Using custom OpenAI endpoint for model: {self.model_id}") + else: + api_key = os.getenv('OPENAI_API_KEY') + + if not api_key: + raise ModelHandlerError("OpenAI API key not available", 500) + # Configure timeout for OpenAI client (OpenAI v1.57.2) timeout_config = httpx.Timeout( connect=self.OPENAI_CONNECT_TIMEOUT, @@ -328,7 +368,7 @@ def _handle_openai_request(self, prompt: str): http_client = httpx.Client(timeout=timeout_config) client = OpenAI( - api_key=os.getenv('OPENAI_API_KEY'), + api_key=api_key, http_client=http_client ) completion = client.chat.completions.create( @@ -351,13 +391,22 @@ def _handle_openai_compatible_request(self, prompt: str): import httpx from openai import OpenAI - # Get API key from environment variable (only credential needed) - api_key = os.getenv('OpenAI_Endpoint_Compatible_Key') + # Check for custom endpoint configuration + custom_config = get_custom_endpoint_config(self.model_id, "openai_compatible") + + if custom_config: + # Use custom endpoint configuration + api_key = custom_config.api_key + openai_compatible_endpoint = custom_config.endpoint_url + print(f"Using custom OpenAI compatible endpoint for model: {self.model_id}") + else: + # Fallback to environment variables and initialization parameters + api_key = os.getenv('OpenAI_Endpoint_Compatible_Key') + openai_compatible_endpoint = self.caii_endpoint + if not api_key: - raise ModelHandlerError("OpenAI_Endpoint_Compatible_Key environment variable not set", 500) + raise ModelHandlerError("OpenAI compatible API key not available", 500) - # Base URL comes from caii_endpoint parameter (passed during initialization) - openai_compatible_endpoint = self.caii_endpoint if not openai_compatible_endpoint: raise ModelHandlerError("OpenAI compatible endpoint not provided", 500) @@ -412,7 +461,20 @@ def _handle_gemini_request(self, prompt: str): 500, ) try: - genai.configure(api_key=os.getenv("GEMINI_API_KEY")) + # Check for custom endpoint configuration + custom_config = get_custom_endpoint_config(self.model_id, "gemini") + + # Get API key from custom config or environment + if custom_config: + api_key = custom_config.api_key + print(f"Using custom Gemini endpoint for model: {self.model_id}") + else: + api_key = os.getenv("GEMINI_API_KEY") + + if not api_key: + raise ModelHandlerError("Gemini API key not available", 500) + + genai.configure(api_key=api_key) model = genai.GenerativeModel(self.model_id) # e.g. 
'gemini-1.5-pro-latest' resp = model.generate_content( prompt, @@ -437,9 +499,26 @@ def _handle_caii_request(self, prompt: str): import httpx from openai import OpenAI - API_KEY = _get_caii_token() - MODEL_ID = self.model_id - caii_endpoint = self.caii_endpoint + # Check for custom endpoint configuration + custom_config = get_custom_endpoint_config(self.model_id, "caii") + + if custom_config: + # Use custom endpoint configuration + API_KEY = custom_config.cdp_token + MODEL_ID = self.model_id + caii_endpoint = custom_config.endpoint_url + print(f"Using custom CAII endpoint for model: {self.model_id}") + else: + # Fallback to environment variables and initialization parameters + API_KEY = _get_caii_token() + MODEL_ID = self.model_id + caii_endpoint = self.caii_endpoint + + if not API_KEY: + raise ModelHandlerError("CAII API key not available", 500) + + if not caii_endpoint: + raise ModelHandlerError("CAII endpoint not provided", 500) caii_endpoint = caii_endpoint.removesuffix('/chat/completions') diff --git a/app/core/prompt_templates.py b/app/core/prompt_templates.py index 7500177f..04ede677 100644 --- a/app/core/prompt_templates.py +++ b/app/core/prompt_templates.py @@ -4,7 +4,7 @@ import os import pandas as pd import numpy as np -from app.models.request_models import Example, Example_eval +from app.models.request_models import Example, EvaluationExample from app.core.config import UseCase, Technique, ModelFamily, get_model_family,USE_CASE_CONFIGS, LENDING_DATA_PROMPT, USE_CASE_CONFIGS_EVALS from app.core.data_loader import DataLoader from app.core.data_analyser import DataAnalyser @@ -179,7 +179,7 @@ def format_examples(examples: List[Example]) -> str: ] @staticmethod - def format_examples_eval(examples: List[Example_eval]) -> str: + def format_examples_eval(examples: List[EvaluationExample]) -> str: """Format examples as JSON string""" return [ {"score": example.score, "justification": example.justification} @@ -506,7 +506,7 @@ def get_eval_prompt(model_id: str, use_case: UseCase, question: str, solution: str, - examples: List[Example_eval], + examples: List[EvaluationExample], custom_prompt = Optional[str] ) -> str: custom_prompt_str = PromptHandler.get_default_custom_eval_prompt(use_case, custom_prompt) @@ -562,7 +562,7 @@ def get_eval_prompt(model_id: str, def get_freeform_eval_prompt(model_id: str, use_case: UseCase, row: Dict[str, Any], - examples: List[Example_eval], + examples: List[EvaluationExample], custom_prompt = Optional[str] ) -> str: custom_prompt_str = PromptHandler.get_default_custom_eval_prompt(use_case, custom_prompt) @@ -1110,7 +1110,7 @@ def build_eval_prompt(model_id: str, use_case: UseCase, question: str, solution: str, - examples: List[Example_eval], + examples: List[EvaluationExample], custom_prompt = Optional[str] ) -> str: @@ -1154,7 +1154,7 @@ def build_freeform_prompt(model_id: str, def build_freeform_eval_prompt(model_id: str, use_case: UseCase, row: Dict[str, Any], - examples: List[Example_eval], + examples: List[EvaluationExample], custom_prompt = Optional[str] ) -> str: diff --git a/app/main.py b/app/main.py index 2becafb8..22e4dc21 100644 --- a/app/main.py +++ b/app/main.py @@ -43,7 +43,7 @@ from app.services.evaluator_service import EvaluatorService from app.services.evaluator_legacy_service import EvaluatorLegacyService -from app.models.request_models import SynthesisRequest, EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, JsonDataSize, RelativePath, Technique +from app.models.request_models import SynthesisRequest, 
EvaluationRequest, Export_synth, ModelParameters, CustomPromptRequest, JsonDataSize, RelativePath, Technique, AddCustomEndpointRequest, CustomEndpointListResponse from app.services.synthesis_service import SynthesisService from app.services.synthesis_legacy_service import SynthesisLegacyService from app.services.export_results import Export_Service @@ -59,6 +59,7 @@ from app.core.config import responses, caii_check from app.core.path_manager import PathManager from app.core.model_endpoints import collect_model_catalog, sort_unique_models, list_bedrock_models +from app.core.custom_endpoint_manager import CustomEndpointManager # from app.core.telemetry_middleware import TelemetryMiddleware # from app.routes.telemetry_routes import router as telemetry_router @@ -74,6 +75,7 @@ evaluator_legacy_service = EvaluatorLegacyService() # SFT and Custom_Workflow export_service = Export_Service() db_manager = DatabaseManager() +custom_endpoint_manager = CustomEndpointManager() #Initialize path manager @@ -1475,6 +1477,156 @@ async def perform_upgrade(): except Exception as e: raise HTTPException(status_code=500, detail=f"Upgrade failed: {str(e)}") + + +# ──────────────────────────────────────────────────────────────── +# Custom Model Endpoint Management APIs +# ──────────────────────────────────────────────────────────────── + +@app.post("/add_model_endpoint", include_in_schema=True, responses=responses, + description="Add a custom model endpoint") +async def add_custom_model_endpoint(request: AddCustomEndpointRequest): + """Add a new custom model endpoint""" + try: + # Validate endpoint ID format + if not custom_endpoint_manager.validate_endpoint_id(request.endpoint_config.endpoint_id): + raise HTTPException( + status_code=400, + detail="Invalid endpoint ID. Use only alphanumeric characters, hyphens, and underscores." 
+ ) + + endpoint_id = custom_endpoint_manager.add_endpoint(request.endpoint_config) + + return { + "status": "success", + "message": f"Custom endpoint '{endpoint_id}' added successfully", + "endpoint_id": endpoint_id + } + + except APIError as e: + raise HTTPException(status_code=e.status_code, detail=e.message) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to add custom endpoint: {str(e)}") + + +@app.get("/custom_model_endpoints", include_in_schema=True, responses=responses, + description="List all custom model endpoints", response_model=CustomEndpointListResponse) +async def list_custom_model_endpoints(provider_type: Optional[str] = None): + """List all custom model endpoints, optionally filtered by provider type""" + try: + if provider_type: + endpoints = custom_endpoint_manager.get_endpoints_by_provider(provider_type) + else: + endpoints = custom_endpoint_manager.get_all_endpoints() + + return CustomEndpointListResponse( + endpoints=endpoints, + total=len(endpoints) + ) + + except APIError as e: + raise HTTPException(status_code=e.status_code, detail=e.message) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to list custom endpoints: {str(e)}") + + +@app.get("/custom_model_endpoints/{endpoint_id}", include_in_schema=True, responses=responses, + description="Get a specific custom model endpoint") +async def get_custom_model_endpoint(endpoint_id: str): + """Get details of a specific custom model endpoint""" + try: + endpoint = custom_endpoint_manager.get_endpoint(endpoint_id) + + if not endpoint: + raise HTTPException( + status_code=404, + detail=f"Custom endpoint '{endpoint_id}' not found" + ) + + return { + "status": "success", + "endpoint": endpoint + } + + except APIError as e: + raise HTTPException(status_code=e.status_code, detail=e.message) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get custom endpoint: {str(e)}") + + +@app.put("/custom_model_endpoints/{endpoint_id}", include_in_schema=True, responses=responses, + description="Update a custom model endpoint") +async def update_custom_model_endpoint(endpoint_id: str, request: AddCustomEndpointRequest): + """Update an existing custom model endpoint""" + try: + # Ensure the endpoint ID in the request matches the URL parameter + if request.endpoint_config.endpoint_id != endpoint_id: + raise HTTPException( + status_code=400, + detail="Endpoint ID in request body must match the URL parameter" + ) + + success = custom_endpoint_manager.update_endpoint(endpoint_id, request.endpoint_config) + + if not success: + raise HTTPException( + status_code=404, + detail=f"Custom endpoint '{endpoint_id}' not found" + ) + + return { + "status": "success", + "message": f"Custom endpoint '{endpoint_id}' updated successfully" + } + + except APIError as e: + raise HTTPException(status_code=e.status_code, detail=e.message) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update custom endpoint: {str(e)}") + + +@app.delete("/custom_model_endpoints/{endpoint_id}", include_in_schema=True, responses=responses, + description="Delete a custom model endpoint") +async def delete_custom_model_endpoint(endpoint_id: str): + """Delete a custom model endpoint""" + try: + success = custom_endpoint_manager.delete_endpoint(endpoint_id) + + if not success: + raise HTTPException( + status_code=404, + detail=f"Custom endpoint '{endpoint_id}' not found" + ) + + return { + "status": "success", + "message": f"Custom endpoint '{endpoint_id}' 
deleted successfully" + } + + except APIError as e: + raise HTTPException(status_code=e.status_code, detail=e.message) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to delete custom endpoint: {str(e)}") + + +@app.get("/custom_model_endpoints_stats", include_in_schema=True, responses=responses, + description="Get statistics about custom model endpoints") +async def get_custom_model_endpoints_stats(): + """Get statistics about custom model endpoints""" + try: + stats = custom_endpoint_manager.get_endpoint_stats() + + return { + "status": "success", + "stats": stats + } + + except APIError as e: + raise HTTPException(status_code=e.status_code, detail=e.message) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get endpoint statistics: {str(e)}") + + #****** comment below for testing just backend************** current_directory = os.path.dirname(os.path.abspath(__file__)) client_build_path = os.path.join(current_directory, "client", "dist") diff --git a/app/models/request_models.py b/app/models/request_models.py index bf22343f..57ec9c92 100644 --- a/app/models/request_models.py +++ b/app/models/request_models.py @@ -45,6 +45,12 @@ class Example_eval(BaseModel): } ) +class EvaluationExample(BaseModel): + question: str + answer: str + score: float + justification: str + # In app/models/request_models.py class S3Config(BaseModel): @@ -275,4 +281,85 @@ class CustomPromptRequest(BaseModel): } } ) - \ No newline at end of file + + +# Custom Endpoint Models +class CustomEndpointBase(BaseModel): + """Base model for custom endpoints""" + endpoint_id: str = Field(..., description="Unique identifier for the custom endpoint") + display_name: str = Field(..., description="Human-readable name for the endpoint") + model_id: str = Field(..., description="Model identifier") + provider_type: str = Field(..., description="Provider type: caii, bedrock, openai, openai_compatible, gemini") + created_at: Optional[str] = Field(default=None, description="Creation timestamp") + updated_at: Optional[str] = Field(default=None, description="Last update timestamp") + + +class CustomCAIIEndpoint(CustomEndpointBase): + """Custom CAII endpoint configuration""" + provider_type: str = Field(default="caii", description="Provider type") + endpoint_url: str = Field(..., description="CAII endpoint URL") + cdp_token: str = Field(..., description="CDP token for authentication") + + +class CustomBedrockEndpoint(CustomEndpointBase): + """Custom Bedrock endpoint configuration""" + provider_type: str = Field(default="bedrock", description="Provider type") + endpoint_url: str = Field(..., description="Custom Bedrock endpoint URL") + aws_access_key_id: str = Field(..., description="AWS Access Key ID") + aws_secret_access_key: str = Field(..., description="AWS Secret Access Key") + aws_region: str = Field(default="us-west-2", description="AWS region") + + +class CustomOpenAIEndpoint(CustomEndpointBase): + """Custom OpenAI endpoint configuration""" + provider_type: str = Field(default="openai", description="Provider type") + api_key: str = Field(..., description="OpenAI API key") + + +class CustomOpenAICompatibleEndpoint(CustomEndpointBase): + """Custom OpenAI Compatible endpoint configuration""" + provider_type: str = Field(default="openai_compatible", description="Provider type") + endpoint_url: str = Field(..., description="OpenAI compatible endpoint URL") + api_key: str = Field(..., description="API key for authentication") + + +class 
CustomGeminiEndpoint(CustomEndpointBase): + """Custom Gemini endpoint configuration""" + provider_type: str = Field(default="gemini", description="Provider type") + api_key: str = Field(..., description="Gemini API key") + + +# Union type for all custom endpoint types +CustomEndpoint = Union[ + CustomCAIIEndpoint, + CustomBedrockEndpoint, + CustomOpenAIEndpoint, + CustomOpenAICompatibleEndpoint, + CustomGeminiEndpoint +] + + +class AddCustomEndpointRequest(BaseModel): + """Request model for adding custom endpoints""" + endpoint_config: CustomEndpoint = Field(..., description="Custom endpoint configuration") + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "endpoint_config": { + "endpoint_id": "my-custom-claude", + "display_name": "My Custom Claude Instance", + "model_id": "claude-3-sonnet-20240229", + "provider_type": "openai_compatible", + "endpoint_url": "https://my-endpoint.com/v1", + "api_key": "sk-..." + } + } + } + ) + + +class CustomEndpointListResponse(BaseModel): + """Response model for listing custom endpoints""" + endpoints: List[CustomEndpoint] = Field(default=[], description="List of custom endpoints") + total: int = Field(..., description="Total number of custom endpoints") \ No newline at end of file diff --git a/custom_model_endpoints.json b/custom_model_endpoints.json new file mode 100644 index 00000000..0e1b2efb --- /dev/null +++ b/custom_model_endpoints.json @@ -0,0 +1,17 @@ +{ + "version": "1.0", + "endpoints": { + "my-caii-llama": { + "endpoint_id": "my-caii-llama", + "display_name": "My CAII Llama Model", + "model_id": "llama-3.1-70b-instruct", + "provider_type": "caii", + "created_at": "2025-09-24T05:46:04.854909+00:00", + "updated_at": "2025-09-24T05:46:04.854909+00:00", + "endpoint_url": "https://modelservice.ml.my-cluster.com/model/llamav3/invocations", + "cdp_token": "abc123def456" + } + }, + "created_at": "2025-09-24T05:24:03.451632+00:00", + "last_updated": "2025-09-24T05:46:04.854973+00:00" +} \ No newline at end of file From 3fa8d23f1a4c2772712d51189382bc1b2e5f3c29 Mon Sep 17 00:00:00 2001 From: Khauneesh Saigal Date: Wed, 24 Sep 2025 13:17:17 +0530 Subject: [PATCH 06/12] display name optional --- app/models/request_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/models/request_models.py b/app/models/request_models.py index 57ec9c92..f906eb35 100644 --- a/app/models/request_models.py +++ b/app/models/request_models.py @@ -287,7 +287,7 @@ class CustomPromptRequest(BaseModel): class CustomEndpointBase(BaseModel): """Base model for custom endpoints""" endpoint_id: str = Field(..., description="Unique identifier for the custom endpoint") - display_name: str = Field(..., description="Human-readable name for the endpoint") + display_name: Optional[str] = Field(default=None, description="Human-readable name for the endpoint (optional)") model_id: str = Field(..., description="Model identifier") provider_type: str = Field(..., description="Provider type: caii, bedrock, openai, openai_compatible, gemini") created_at: Optional[str] = Field(default=None, description="Creation timestamp") updated_at: Optional[str] = Field(default=None, description="Last update timestamp") From 12e41d666df0809f52f1dcb49251daaaa2ec08b2 Mon Sep 17 00:00:00 2001 From: Khauneesh Saigal Date: Wed, 24 Sep 2025 22:31:23 +0530 Subject: [PATCH 07/12] Simplify custom endpoint management system - Remove unnecessary endpoint_id field (use model_id + provider_type as natural unique key) - Remove display_name, created_at, updated_at metadata fields (not needed for server operations) - Remove endpoint_url from Bedrock/OpenAI/Gemini 
(they use standard APIs) - Keep endpoint_url only for CAII and OpenAI Compatible (custom deployments) - Update API endpoints to use /model_id/provider_type pattern - Simplify JSON storage structure to use model_id:provider_type keys - Test confirmed: server starts and endpoints work correctly --- app/core/custom_endpoint_manager.py | 90 +++++++++++------------------ app/core/model_endpoints.py | 6 +- app/core/model_handlers.py | 8 +-- app/main.py | 49 +++++++--------- app/models/request_models.py | 38 +++++------- custom_model_endpoints.json | 5 ++ 6 files changed, 79 insertions(+), 117 deletions(-) diff --git a/app/core/custom_endpoint_manager.py b/app/core/custom_endpoint_manager.py index c6206722..7e55aa45 100644 --- a/app/core/custom_endpoint_manager.py +++ b/app/core/custom_endpoint_manager.py @@ -13,7 +13,7 @@ class CustomEndpointManager: - """Manager for custom model endpoints""" + """Manager for custom model endpoints - simplified""" def __init__(self, config_file: str = "custom_model_endpoints.json"): """ @@ -30,9 +30,7 @@ def _ensure_config_file_exists(self): if not self.config_file.exists(): initial_config = { "version": "1.0", - "endpoints": {}, - "created_at": datetime.now(timezone.utc).isoformat(), - "last_updated": datetime.now(timezone.utc).isoformat() + "endpoints": {} } with open(self.config_file, 'w') as f: json.dump(initial_config, f, indent=2) @@ -48,7 +46,6 @@ def _load_config(self) -> Dict[str, Any]: def _save_config(self, config: Dict[str, Any]): """Save configuration to file""" try: - config["last_updated"] = datetime.now(timezone.utc).isoformat() with open(self.config_file, 'w') as f: json.dump(config, f, indent=2) except Exception as e: @@ -62,40 +59,40 @@ def add_endpoint(self, endpoint: CustomEndpoint) -> str: endpoint: Custom endpoint configuration Returns: - endpoint_id: The ID of the added endpoint + unique_key: The unique key for the added endpoint (model_id:provider_type) Raises: APIError: If endpoint already exists or validation fails """ config = self._load_config() - # Check if endpoint_id already exists - if endpoint.endpoint_id in config["endpoints"]: - raise APIError(f"Endpoint with ID '{endpoint.endpoint_id}' already exists", 400) + # Use model_id + provider_type as unique key + unique_key = f"{endpoint.model_id}:{endpoint.provider_type}" - # Add timestamps - now = datetime.now(timezone.utc).isoformat() - endpoint.created_at = now - endpoint.updated_at = now + # Check if endpoint already exists + if unique_key in config["endpoints"]: + raise APIError(f"Endpoint for model '{endpoint.model_id}' with provider '{endpoint.provider_type}' already exists", 400) # Store endpoint configuration - config["endpoints"][endpoint.endpoint_id] = endpoint.model_dump() + config["endpoints"][unique_key] = endpoint.model_dump() self._save_config(config) - return endpoint.endpoint_id + return unique_key - def get_endpoint(self, endpoint_id: str) -> Optional[CustomEndpoint]: + def get_endpoint(self, model_id: str, provider_type: str) -> Optional[CustomEndpoint]: """ - Get a specific custom endpoint by ID + Get a specific custom endpoint by model_id and provider_type Args: - endpoint_id: The endpoint ID to retrieve + model_id: The model identifier + provider_type: The provider type Returns: CustomEndpoint or None if not found """ config = self._load_config() - endpoint_data = config["endpoints"].get(endpoint_id) + unique_key = f"{model_id}:{provider_type}" + endpoint_data = config["endpoints"].get(unique_key) if not endpoint_data: return None @@ -112,12 +109,12 @@ def 
get_all_endpoints(self) -> List[CustomEndpoint]: config = self._load_config() endpoints = [] - for endpoint_data in config["endpoints"].values(): + for unique_key, endpoint_data in config["endpoints"].items(): try: endpoint = self._parse_endpoint(endpoint_data) endpoints.append(endpoint) except Exception as e: - print(f"Warning: Failed to parse endpoint {endpoint_data.get('endpoint_id', 'unknown')}: {e}") + print(f"Warning: Failed to parse endpoint {unique_key}: {e}") continue return endpoints @@ -135,12 +132,13 @@ def get_endpoints_by_provider(self, provider_type: str) -> List[CustomEndpoint]: all_endpoints = self.get_all_endpoints() return [ep for ep in all_endpoints if ep.provider_type == provider_type] - def update_endpoint(self, endpoint_id: str, updated_endpoint: CustomEndpoint) -> bool: + def update_endpoint(self, model_id: str, provider_type: str, updated_endpoint: CustomEndpoint) -> bool: """ Update an existing custom endpoint Args: - endpoint_id: The endpoint ID to update + model_id: The model identifier + provider_type: The provider type updated_endpoint: Updated endpoint configuration Returns: @@ -151,39 +149,36 @@ def update_endpoint(self, endpoint_id: str, updated_endpoint: CustomEndpoint) -> """ config = self._load_config() - if endpoint_id not in config["endpoints"]: - return False - - # Preserve original created_at timestamp - original_created_at = config["endpoints"][endpoint_id].get("created_at") + unique_key = f"{model_id}:{provider_type}" - # Update timestamps - updated_endpoint.created_at = original_created_at - updated_endpoint.updated_at = datetime.now(timezone.utc).isoformat() - updated_endpoint.endpoint_id = endpoint_id # Ensure ID consistency + if unique_key not in config["endpoints"]: + return False # Update endpoint configuration - config["endpoints"][endpoint_id] = updated_endpoint.model_dump() + config["endpoints"][unique_key] = updated_endpoint.model_dump() self._save_config(config) return True - def delete_endpoint(self, endpoint_id: str) -> bool: + def delete_endpoint(self, model_id: str, provider_type: str) -> bool: """ Delete a custom endpoint Args: - endpoint_id: The endpoint ID to delete + model_id: The model identifier + provider_type: The provider type Returns: True if deleted successfully, False if endpoint not found """ config = self._load_config() - if endpoint_id not in config["endpoints"]: + unique_key = f"{model_id}:{provider_type}" + + if unique_key not in config["endpoints"]: return False - del config["endpoints"][endpoint_id] + del config["endpoints"][unique_key] self._save_config(config) return True @@ -218,23 +213,6 @@ def _parse_endpoint(self, endpoint_data: Dict[str, Any]) -> CustomEndpoint: except Exception as e: raise APIError(f"Failed to parse endpoint configuration: {str(e)}", 500) - def validate_endpoint_id(self, endpoint_id: str) -> bool: - """ - Validate endpoint ID format - - Args: - endpoint_id: The endpoint ID to validate - - Returns: - True if valid, False otherwise - """ - if not endpoint_id or not isinstance(endpoint_id, str): - return False - - # Allow alphanumeric, hyphens, and underscores - import re - return bool(re.match(r'^[a-zA-Z0-9_-]+$', endpoint_id)) - def get_endpoint_stats(self) -> Dict[str, Any]: """ Get statistics about custom endpoints @@ -245,12 +223,14 @@ def get_endpoint_stats(self) -> Dict[str, Any]: endpoints = self.get_all_endpoints() provider_counts = {} + model_list = [] for endpoint in endpoints: provider_type = endpoint.provider_type provider_counts[provider_type] = provider_counts.get(provider_type, 
0) + 1 + model_list.append(f"{endpoint.model_id} ({endpoint.provider_type})") return { "total_endpoints": len(endpoints), "provider_counts": provider_counts, - "endpoint_ids": [ep.endpoint_id for ep in endpoints] + "models": model_list } diff --git a/app/core/model_endpoints.py b/app/core/model_endpoints.py index c9e0a091..a58ad2d9 100644 --- a/app/core/model_endpoints.py +++ b/app/core/model_endpoints.py @@ -388,17 +388,13 @@ async def collect_model_catalog() -> Dict[str, Dict[str, List[str]]]: catalog[provider_key]["enabled"].append({ "model": endpoint.model_id, "endpoint": getattr(endpoint, 'endpoint_url', ''), - "custom": True, - "endpoint_id": endpoint.endpoint_id, - "display_name": endpoint.display_name + "custom": True }) else: # Other providers: just the model name with custom metadata catalog[provider_key]["enabled"].append({ "model": endpoint.model_id, "custom": True, - "endpoint_id": endpoint.endpoint_id, - "display_name": endpoint.display_name, "provider_type": endpoint.provider_type }) diff --git a/app/core/model_handlers.py b/app/core/model_handlers.py index 6514d514..3d4fd434 100644 --- a/app/core/model_handlers.py +++ b/app/core/model_handlers.py @@ -33,14 +33,8 @@ def get_custom_endpoint_config(model_id: str, provider_type: str): from app.core.custom_endpoint_manager import CustomEndpointManager custom_manager = CustomEndpointManager() - custom_endpoints = custom_manager.get_endpoints_by_provider(provider_type) - - # Find endpoint matching the model_id - for endpoint in custom_endpoints: - if endpoint.model_id == model_id: - return endpoint + return custom_manager.get_endpoint(model_id, provider_type) - return None except Exception as e: print(f"Warning: Failed to get custom endpoint config: {e}") return None diff --git a/app/main.py b/app/main.py index 22e4dc21..4385eb48 100644 --- a/app/main.py +++ b/app/main.py @@ -1488,19 +1488,14 @@ async def perform_upgrade(): async def add_custom_model_endpoint(request: AddCustomEndpointRequest): """Add a new custom model endpoint""" try: - # Validate endpoint ID format - if not custom_endpoint_manager.validate_endpoint_id(request.endpoint_config.endpoint_id): - raise HTTPException( - status_code=400, - detail="Invalid endpoint ID. Use only alphanumeric characters, hyphens, and underscores." 
- ) - - endpoint_id = custom_endpoint_manager.add_endpoint(request.endpoint_config) + unique_key = custom_endpoint_manager.add_endpoint(request.endpoint_config) return { "status": "success", - "message": f"Custom endpoint '{endpoint_id}' added successfully", - "endpoint_id": endpoint_id + "message": f"Custom endpoint for '{request.endpoint_config.model_id}' ({request.endpoint_config.provider_type}) added successfully", + "model_id": request.endpoint_config.model_id, + "provider_type": request.endpoint_config.provider_type, + "unique_key": unique_key } except APIError as e: @@ -1530,17 +1525,17 @@ async def list_custom_model_endpoints(provider_type: Optional[str] = None): raise HTTPException(status_code=500, detail=f"Failed to list custom endpoints: {str(e)}") -@app.get("/custom_model_endpoints/{endpoint_id}", include_in_schema=True, responses=responses, +@app.get("/custom_model_endpoints/{model_id}/{provider_type}", include_in_schema=True, responses=responses, description="Get a specific custom model endpoint") -async def get_custom_model_endpoint(endpoint_id: str): +async def get_custom_model_endpoint(model_id: str, provider_type: str): """Get details of a specific custom model endpoint""" try: - endpoint = custom_endpoint_manager.get_endpoint(endpoint_id) + endpoint = custom_endpoint_manager.get_endpoint(model_id, provider_type) if not endpoint: raise HTTPException( status_code=404, - detail=f"Custom endpoint '{endpoint_id}' not found" + detail=f"Custom endpoint for model '{model_id}' with provider '{provider_type}' not found" ) return { @@ -1554,29 +1549,29 @@ async def get_custom_model_endpoint(endpoint_id: str): raise HTTPException(status_code=500, detail=f"Failed to get custom endpoint: {str(e)}") -@app.put("/custom_model_endpoints/{endpoint_id}", include_in_schema=True, responses=responses, +@app.put("/custom_model_endpoints/{model_id}/{provider_type}", include_in_schema=True, responses=responses, description="Update a custom model endpoint") -async def update_custom_model_endpoint(endpoint_id: str, request: AddCustomEndpointRequest): +async def update_custom_model_endpoint(model_id: str, provider_type: str, request: AddCustomEndpointRequest): """Update an existing custom model endpoint""" try: - # Ensure the endpoint ID in the request matches the URL parameter - if request.endpoint_config.endpoint_id != endpoint_id: + # Ensure the model_id and provider_type in the request match the URL parameters + if request.endpoint_config.model_id != model_id or request.endpoint_config.provider_type != provider_type: raise HTTPException( status_code=400, - detail="Endpoint ID in request body must match the URL parameter" + detail="Model ID and provider type in request body must match the URL parameters" ) - success = custom_endpoint_manager.update_endpoint(endpoint_id, request.endpoint_config) + success = custom_endpoint_manager.update_endpoint(model_id, provider_type, request.endpoint_config) if not success: raise HTTPException( status_code=404, - detail=f"Custom endpoint '{endpoint_id}' not found" + detail=f"Custom endpoint for model '{model_id}' with provider '{provider_type}' not found" ) return { "status": "success", - "message": f"Custom endpoint '{endpoint_id}' updated successfully" + "message": f"Custom endpoint for '{model_id}' ({provider_type}) updated successfully" } except APIError as e: @@ -1585,22 +1580,22 @@ async def update_custom_model_endpoint(endpoint_id: str, request: AddCustomEndpo raise HTTPException(status_code=500, detail=f"Failed to update custom endpoint: {str(e)}") 
-@app.delete("/custom_model_endpoints/{endpoint_id}", include_in_schema=True, responses=responses, +@app.delete("/custom_model_endpoints/{model_id}/{provider_type}", include_in_schema=True, responses=responses, description="Delete a custom model endpoint") -async def delete_custom_model_endpoint(endpoint_id: str): +async def delete_custom_model_endpoint(model_id: str, provider_type: str): """Delete a custom model endpoint""" try: - success = custom_endpoint_manager.delete_endpoint(endpoint_id) + success = custom_endpoint_manager.delete_endpoint(model_id, provider_type) if not success: raise HTTPException( status_code=404, - detail=f"Custom endpoint '{endpoint_id}' not found" + detail=f"Custom endpoint for model '{model_id}' with provider '{provider_type}' not found" ) return { "status": "success", - "message": f"Custom endpoint '{endpoint_id}' deleted successfully" + "message": f"Custom endpoint for '{model_id}' ({provider_type}) deleted successfully" } except APIError as e: diff --git a/app/models/request_models.py b/app/models/request_models.py index f906eb35..e5f2f202 100644 --- a/app/models/request_models.py +++ b/app/models/request_models.py @@ -283,48 +283,42 @@ class CustomPromptRequest(BaseModel): ) -# Custom Endpoint Models -class CustomEndpointBase(BaseModel): - """Base model for custom endpoints""" - endpoint_id: str = Field(..., description="Unique identifier for the custom endpoint") - display_name: Optional[str] = Field(default=None, description="Human-readable name for the endpoint (optional)") +# Custom Endpoint Models - Ultra Simplified +class CustomCAIIEndpoint(BaseModel): + """Custom CAII endpoint - needs custom URL""" model_id: str = Field(..., description="Model identifier") - provider_type: str = Field(..., description="Provider type: caii, bedrock, openai, openai_compatible, gemini") - created_at: Optional[str] = Field(default=None, description="Creation timestamp") - updated_at: Optional[str] = Field(default=None, description="Last update timestamp") - - -class CustomCAIIEndpoint(CustomEndpointBase): - """Custom CAII endpoint configuration""" provider_type: str = Field(default="caii", description="Provider type") endpoint_url: str = Field(..., description="CAII endpoint URL") cdp_token: str = Field(..., description="CDP token for authentication") -class CustomBedrockEndpoint(CustomEndpointBase): - """Custom Bedrock endpoint configuration""" +class CustomBedrockEndpoint(BaseModel): + """Custom Bedrock endpoint - uses standard AWS Bedrock API""" + model_id: str = Field(..., description="Model identifier") provider_type: str = Field(default="bedrock", description="Provider type") - endpoint_url: str = Field(..., description="Custom Bedrock endpoint URL") aws_access_key_id: str = Field(..., description="AWS Access Key ID") aws_secret_access_key: str = Field(..., description="AWS Secret Access Key") aws_region: str = Field(default="us-west-2", description="AWS region") -class CustomOpenAIEndpoint(CustomEndpointBase): - """Custom OpenAI endpoint configuration""" +class CustomOpenAIEndpoint(BaseModel): + """Custom OpenAI endpoint - uses standard OpenAI API""" + model_id: str = Field(..., description="Model identifier") provider_type: str = Field(default="openai", description="Provider type") api_key: str = Field(..., description="OpenAI API key") -class CustomOpenAICompatibleEndpoint(CustomEndpointBase): - """Custom OpenAI Compatible endpoint configuration""" +class CustomOpenAICompatibleEndpoint(BaseModel): + """Custom OpenAI Compatible endpoint - needs custom URL""" + 
model_id: str = Field(..., description="Model identifier") provider_type: str = Field(default="openai_compatible", description="Provider type") endpoint_url: str = Field(..., description="OpenAI compatible endpoint URL") api_key: str = Field(..., description="API key for authentication") -class CustomGeminiEndpoint(CustomEndpointBase): - """Custom Gemini endpoint configuration""" +class CustomGeminiEndpoint(BaseModel): + """Custom Gemini endpoint - uses standard Gemini API""" + model_id: str = Field(..., description="Model identifier") provider_type: str = Field(default="gemini", description="Provider type") api_key: str = Field(..., description="Gemini API key") @@ -347,8 +341,6 @@ class AddCustomEndpointRequest(BaseModel): json_schema_extra={ "example": { "endpoint_config": { - "endpoint_id": "my-custom-claude", - "display_name": "My Custom Claude Instance", "model_id": "claude-3-sonnet-20240229", "provider_type": "openai_compatible", "endpoint_url": "https://my-endpoint.com/v1", diff --git a/custom_model_endpoints.json b/custom_model_endpoints.json index 0e1b2efb..86d0c3db 100644 --- a/custom_model_endpoints.json +++ b/custom_model_endpoints.json @@ -10,6 +10,11 @@ "updated_at": "2025-09-24T05:46:04.854909+00:00", "endpoint_url": "https://modelservice.ml.my-cluster.com/model/llamav3/invocations", "cdp_token": "abc123def456" + }, + "gpt-4o-test:openai": { + "model_id": "gpt-4o-test", + "provider_type": "openai", + "api_key": "sk-test123" } }, "created_at": "2025-09-24T05:24:03.451632+00:00", From e86f8e8c6eb1010807c3eda158bce6b32e68f910 Mon Sep 17 00:00:00 2001 From: Khauneesh Saigal Date: Wed, 24 Sep 2025 22:57:00 +0530 Subject: [PATCH 08/12] feat: Add custom AWS credential support for Bedrock handler - Implement custom credential lookup for Bedrock provider - Use custom AWS credentials from JSON file when available - Fallback to default AWS credential chain when no custom config found - Update client creation on retry/connection errors - Now all providers (OpenAI, Gemini, CAII, OpenAI Compatible, Bedrock) support custom credentials - Maintain consistent credential priority: JSON > Environment Variables > Error --- app/core/model_handlers.py | 57 +++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/app/core/model_handlers.py b/app/core/model_handlers.py index 3d4fd434..b3f64e67 100644 --- a/app/core/model_handlers.py +++ b/app/core/model_handlers.py @@ -209,6 +209,31 @@ def generate_response( def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool): """Handle Bedrock requests with retry logic""" + + # Check for custom endpoint configuration + custom_config = get_custom_endpoint_config(self.model_id, "bedrock") + + if custom_config: + # Use custom AWS credentials + from botocore.config import Config + retry_config = Config( + region_name=custom_config.aws_region, + retries={"max_attempts": 2, "mode": "standard"}, + connect_timeout=5, + read_timeout=3600 + ) + bedrock_client = boto3.client( + 'bedrock-runtime', + aws_access_key_id=custom_config.aws_access_key_id, + aws_secret_access_key=custom_config.aws_secret_access_key, + region_name=custom_config.aws_region, + config=retry_config + ) + print(f"Using custom Bedrock endpoint for model: {self.model_id}") + else: + # Fallback to default bedrock client (environment/IAM credentials) + bedrock_client = self.bedrock_client + retries = 0 last_exception = None new_max_tokens = 8192 @@ -228,7 +253,7 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: 
bool): "stopSequences": ["\n\nHuman:"], } - response = self.bedrock_client.converse( + response = bedrock_client.converse( modelId=self.model_id, messages=conversation, inferenceConfig=inference_config, @@ -242,7 +267,7 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool): "stopSequences": [] } print(inference_config) - response = self.bedrock_client.converse( + response = bedrock_client.converse( modelId=self.model_id, messages=conversation, inferenceConfig=inference_config @@ -270,11 +295,29 @@ def _handle_bedrock_request(self, prompt: str, retry_with_reduced_tokens: bool): self._exponential_backoff(retries) retries += 1 - # Create a new client on connection errors - self.bedrock_client = boto3.client( - service_name="bedrock-runtime", - config=self.bedrock_client.meta.config - ) + # Create a new client on connection errors + if custom_config: + # Recreate with custom credentials + from botocore.config import Config + retry_config = Config( + region_name=custom_config.aws_region, + retries={"max_attempts": 2, "mode": "standard"}, + connect_timeout=5, + read_timeout=3600 + ) + bedrock_client = boto3.client( + 'bedrock-runtime', + aws_access_key_id=custom_config.aws_access_key_id, + aws_secret_access_key=custom_config.aws_secret_access_key, + region_name=custom_config.aws_region, + config=retry_config + ) + else: + # Recreate default client + bedrock_client = boto3.client( + service_name="bedrock-runtime", + config=self.bedrock_client.meta.config + ) continue # Handle other AWS errors From 1377563fbe55674b58839e04fecb08eaecc17119 Mon Sep 17 00:00:00 2001 From: Keivan Vosoughi Date: Tue, 23 Sep 2025 21:23:03 -0700 Subject: [PATCH 09/12] Adding Settings Page Add Settings Page Add Models Table Add Model Provider Modal Add Delete Model Add Edit Modal Change Configure Model ID Refetch After Delete --- app/client/src/Container.tsx | 6 + .../components/JobStatus/jobStatusIcon.tsx | 10 +- app/client/src/constants.ts | 1 + .../src/pages/DataGenerator/Configure.tsx | 59 ++++- .../src/pages/DataGenerator/Examples.tsx | 3 +- .../DataGenerator/FreeFormExampleTable.tsx | 1 + .../src/pages/DataGenerator/Success.tsx | 2 +- .../src/pages/DataGenerator/constants.ts | 2 + app/client/src/pages/DataGenerator/types.ts | 2 + .../src/pages/Datasets/DatasetsPage.tsx | 1 - app/client/src/pages/Home/DatasetsTab.tsx | 10 +- .../pages/Settings/AddModelProviderButton.tsx | 228 +++++++++++++++++ .../src/pages/Settings/EditModelProvider.tsx | 236 ++++++++++++++++++ .../src/pages/Settings/SettingsPage.tsx | 233 +++++++++++++++++ app/client/src/pages/Settings/Toolbar.tsx | 55 ++++ app/client/src/pages/Settings/hooks.ts | 80 ++++++ app/client/src/routes.tsx | 8 + app/client/src/types.ts | 3 +- 18 files changed, 923 insertions(+), 17 deletions(-) create mode 100644 app/client/src/pages/Settings/AddModelProviderButton.tsx create mode 100644 app/client/src/pages/Settings/EditModelProvider.tsx create mode 100644 app/client/src/pages/Settings/SettingsPage.tsx create mode 100644 app/client/src/pages/Settings/Toolbar.tsx create mode 100644 app/client/src/pages/Settings/hooks.ts diff --git a/app/client/src/Container.tsx b/app/client/src/Container.tsx index 7f8cfda5..a35a5fa1 100644 --- a/app/client/src/Container.tsx +++ b/app/client/src/Container.tsx @@ -95,6 +95,12 @@ const pages: MenuItem[] = [ {LABELS[Pages.EXPORTS]} ), }, + { + key: Pages.SETTINGS, + label: ( + {LABELS[Pages.SETTINGS]} + ), + }, // { // key: Pages.TELEMETRY, diff --git a/app/client/src/components/JobStatus/jobStatusIcon.tsx 
b/app/client/src/components/JobStatus/jobStatusIcon.tsx index c9511a33..42f822b6 100644 --- a/app/client/src/components/JobStatus/jobStatusIcon.tsx +++ b/app/client/src/components/JobStatus/jobStatusIcon.tsx @@ -24,6 +24,12 @@ const IconWrapper = styled.div` } ` +const StyledIconWrapper = styled(IconWrapper)` + svg { + color: #008cff; + } +`; + export default function JobStatusIcon({ status, customTooltipTitles }: JobStatusProps) { const tooltipTitles = {...defaultTooltipTitles, ...customTooltipTitles}; @@ -44,11 +50,11 @@ export default function JobStatusIcon({ status, customTooltipTitles }: JobStatus ; case 'ENGINE_SCHEDULING': return - + ; case 'ENGINE_RUNNING': return - + ; case null: return diff --git a/app/client/src/constants.ts b/app/client/src/constants.ts index b5f201bc..37436c28 100644 --- a/app/client/src/constants.ts +++ b/app/client/src/constants.ts @@ -12,6 +12,7 @@ export const LABELS = { [Pages.EXPORTS]: 'Exports', [Pages.HISTORY]: 'History', [Pages.FEEDBACK]: 'Feedback', + [Pages.SETTINGS]: 'Settings', //[Pages.TELEMETRY]: 'Telemetry', [ModelParameters.TEMPERATURE]: 'Temperature', [ModelParameters.TOP_K]: 'Top K', diff --git a/app/client/src/pages/DataGenerator/Configure.tsx b/app/client/src/pages/DataGenerator/Configure.tsx index 880abe17..1e6dc0d6 100644 --- a/app/client/src/pages/DataGenerator/Configure.tsx +++ b/app/client/src/pages/DataGenerator/Configure.tsx @@ -2,17 +2,23 @@ import endsWith from 'lodash/endsWith'; import isEmpty from 'lodash/isEmpty'; import isFunction from 'lodash/isFunction'; import { FunctionComponent, useEffect, useState } from 'react'; -import { Flex, Form, FormInstance, Input, Select, Typography } from 'antd'; +import { Flex, Form, Input, Select, Typography } from 'antd'; import styled from 'styled-components'; import { File, WorkflowType } from './types'; import { useFetchModels } from '../../api/api'; import { MODEL_PROVIDER_LABELS } from './constants'; import { ModelProviders, ModelProvidersDropdownOpts } from './types'; -import { getWizardModel, getWizardModeType, useWizardCtx } from './utils'; +import { getWizardModeType, useWizardCtx } from './utils'; import FileSelectorButton from './FileSelectorButton'; import UseCaseSelector from './UseCaseSelector'; import { useLocation, useParams } from 'react-router-dom'; import { WizardModeType } from '../../types'; +import get from 'lodash/get'; +import forEach from 'lodash/forEach'; +import { useModelProviders } from '../Settings/hooks'; +import { ModelProviderType } from '../Settings/AddModelProviderButton'; +import { CustomModel } from '../Settings/SettingsPage'; +import filter from 'lodash/filter'; const StepContainer = styled(Flex)` @@ -47,6 +53,8 @@ export const WORKFLOW_OPTIONS = [ export const MODEL_TYPE_OPTIONS: ModelProvidersDropdownOpts = [ { label: MODEL_PROVIDER_LABELS[ModelProviders.BEDROCK], value: ModelProviders.BEDROCK}, { label: MODEL_PROVIDER_LABELS[ModelProviders.CAII], value: ModelProviders.CAII }, + { label: MODEL_PROVIDER_LABELS[ModelProviders.OPENAI], value: ModelProviders.OPENAI }, + { label: MODEL_PROVIDER_LABELS[ModelProviders.GEMINI], value: ModelProviders.GEMINI }, ]; const Configure: FunctionComponent = () => { @@ -54,7 +62,12 @@ const Configure: FunctionComponent = () => { const formData = Form.useWatch((values) => values, form); const location = useLocation(); const { template_name, generate_file_name } = useParams(); + const [models, setModels] = useState([]) const [wizardModeType, setWizardModeType] = useState(getWizardModeType(location)); + const { data } = 
useFetchModels(); + const customModelPrividersReq = useModelProviders(); + const customModels = get(customModelPrividersReq, 'data.endpoints', []); + console.log('customModels', customModels); useEffect(() => { if (wizardModeType === WizardModeType.DATA_AUGMENTATION) { @@ -77,10 +90,18 @@ const Configure: FunctionComponent = () => { } }, [template_name]); + useEffect(() => { + // set model providers + // set model ids + if (formData && (formData?.inference_type === ModelProviderType.OPENAI || formData?.inference_type === ModelProviderType.GEMINI) && isEmpty(generate_file_name)) { + form.setFieldValue('inference_type', ModelProviders.OPENAI); + } + + }, [customModels, formData]); + // let formData = Form.useWatch((values) => values, form); const { setIsStepValid } = useWizardCtx(); - const { data } = useFetchModels(); const [selectedFiles, setSelectedFiles] = useState( !isEmpty(form.getFieldValue('doc_paths')) ? form.getFieldValue('doc_paths') : []); @@ -104,7 +125,6 @@ const Configure: FunctionComponent = () => { useEffect(() => { - console.log('useEffect 1'); if (formData && formData?.inference_type === undefined && isEmpty(generate_file_name)) { form.setFieldValue('inference_type', ModelProviders.CAII); setTimeout(() => { @@ -155,6 +175,20 @@ const Configure: FunctionComponent = () => { } } + const onModelProviderChange = (value: string) => { + form.setFieldValue('model_id', undefined) + console.log('value', value); + if (ModelProviderType.OPENAI === value) { + const _models = filter(customModels, (model: CustomModel) => model.provider_type === ModelProviderType.OPENAI); + setModels(_models.map((_model: CustomModel) => _model.model_id)); + } else if (ModelProviderType.GEMINI === value) { + const _models = filter(customModels, (model: CustomModel) => model.provider_type === ModelProviderType.GEMINI); + setModels(_models.map((_model: CustomModel) => _model.model_id)); + } + } + console.log('models', models); + + return ( @@ -178,7 +212,7 @@ const Configure: FunctionComponent = () => { > ) : ( - + {formData?.inference_type === ModelProviders.BEDROCK && data?.models?.[ModelProviders.BEDROCK]?.map((model, i) => ( {model} - )} + ))} + {(formData?.inference_type === ModelProviders.OPENAI || formData?.inference_type === ModelProviders.GEMINI) && models?.map((model, i) => ( + + {model} + + ))} )} - {formData?.inference_type === ModelProviders.CAII && ( <> diff --git a/app/client/src/pages/DataGenerator/Examples.tsx b/app/client/src/pages/DataGenerator/Examples.tsx index 9b46f389..4bfbe449 100644 --- a/app/client/src/pages/DataGenerator/Examples.tsx +++ b/app/client/src/pages/DataGenerator/Examples.tsx @@ -132,12 +132,13 @@ const Examples: FunctionComponent = () => { }; const showEmptyState = (workflowType === WorkflowType.FREE_FORM_DATA_GENERATION && - isEmpty(mutation.data) && + isEmpty(mutation.data) && Array.isArray(records) && records.length === 0) || (form.getFieldValue('use_case') === 'custom' && isEmpty(form.getFieldValue('examples'))); + console.log('records', records); return ( {mutation?.isPending || restore_mutation.isPending && } diff --git a/app/client/src/pages/DataGenerator/FreeFormExampleTable.tsx b/app/client/src/pages/DataGenerator/FreeFormExampleTable.tsx index c93bceba..f905cce2 100644 --- a/app/client/src/pages/DataGenerator/FreeFormExampleTable.tsx +++ b/app/client/src/pages/DataGenerator/FreeFormExampleTable.tsx @@ -48,6 +48,7 @@ interface Props { const FreeFormExampleTable: FunctionComponent = ({ data }) => { const [colDefs, setColDefs] = useState([]); const [rowData, 
setRowData] = useState([]); + console.log('FreeFormExampleTable', data); useEffect(() => { if (!isEmpty(data)) { diff --git a/app/client/src/pages/DataGenerator/Success.tsx b/app/client/src/pages/DataGenerator/Success.tsx index f7c8e504..e942e3c9 100644 --- a/app/client/src/pages/DataGenerator/Success.tsx +++ b/app/client/src/pages/DataGenerator/Success.tsx @@ -122,7 +122,7 @@ const Success: FC = ({ formData, isDemo = true }) => { <Flex align='center' gap={10}> <CheckCircleIcon style={{ color: '#178718' }}/> - {'Success'} + {isDemo ? 'Success' : 'Job successfully started.'} </Flex> {isDemo ? ( diff --git a/app/client/src/pages/DataGenerator/constants.ts b/app/client/src/pages/DataGenerator/constants.ts index b90946bc..b4f1f058 100644 --- a/app/client/src/pages/DataGenerator/constants.ts +++ b/app/client/src/pages/DataGenerator/constants.ts @@ -5,6 +5,8 @@ export const MODEL_PROVIDER_LABELS = { [ModelProviders.CAII]: 'Cloudera AI Inference Service', [ModelProviders.GOOGLE_GEMINI]: 'Google Gemini', [ModelProviders.AZURE_OPENAI]: 'Azure OpenAI', + [ModelProviders.GEMINI]: 'Gemini', + [ModelProviders.OPENAI]: 'OpenAI' }; export const MIN_SEED_INSTRUCTIONS = 1 diff --git a/app/client/src/pages/DataGenerator/types.ts b/app/client/src/pages/DataGenerator/types.ts index 73c64b20..027bb054 100644 --- a/app/client/src/pages/DataGenerator/types.ts +++ b/app/client/src/pages/DataGenerator/types.ts @@ -19,6 +19,8 @@ export enum ModelProviders { CAII = 'CAII', AZURE_OPENAI = 'AZURE_OPENAI', GOOGLE_GEMINI = 'GOOGLE_GEMINI', + OPENAI = 'openai', + GEMINI = 'gemini', } export type ModelProvidersDropdownOpts = { label: string, value: ModelProviders }[]; diff --git a/app/client/src/pages/Datasets/DatasetsPage.tsx b/app/client/src/pages/Datasets/DatasetsPage.tsx index b36c85d7..2ff5954c 100644 --- a/app/client/src/pages/Datasets/DatasetsPage.tsx +++ b/app/client/src/pages/Datasets/DatasetsPage.tsx @@ -4,7 +4,6 @@ import { Col, Flex, Input, Layout, Row, Table, TableProps, Tooltip, notification import styled from 'styled-components'; import Paragraph from 'antd/es/typography/Paragraph'; import { useDatasets } from '../Home/hooks'; -import { ExportResult } from '../../components/Export/ExportModal'; import { SearchProps } from 'antd/es/input'; import Loading from '../Evaluator/Loading'; import { Dataset } from '../Evaluator/types'; diff --git a/app/client/src/pages/Home/DatasetsTab.tsx b/app/client/src/pages/Home/DatasetsTab.tsx index 7ce2d040..6512dc1a 100644 --- a/app/client/src/pages/Home/DatasetsTab.tsx +++ b/app/client/src/pages/Home/DatasetsTab.tsx @@ -106,10 +106,16 @@ const DatasetsTab: React.FC = ({ hideSearch = false }) => { key: 'job_status', title: 'Status', dataIndex: 'job_status', - width: 80, + width: 140, sorter: sortItemsByKey('job_status'), - render: (status: JobStatus) => + render: (status: JobStatus) => + + {status === 'ENGINE_SCHEDULING' &&
{'Scheduling'}
} + {status === 'ENGINE_RUNNING' &&
{'Running'}
} + {status === 'ENGINE_STOPPED' &&
{'Stopped'}
} + {status === 'ENGINE_SUCCEEDED' &&
{'Success'}
} + {status === 'ENGINE_TIMEDOUT' &&
{'Timeout'}
}
}, { diff --git a/app/client/src/pages/Settings/AddModelProviderButton.tsx b/app/client/src/pages/Settings/AddModelProviderButton.tsx new file mode 100644 index 00000000..0ba8b780 --- /dev/null +++ b/app/client/src/pages/Settings/AddModelProviderButton.tsx @@ -0,0 +1,228 @@ +import { useEffect, useState } from 'react'; +import { PlusCircleOutlined } from '@ant-design/icons'; +import { Alert, Button, Form, Input, Modal, notification, Radio, Select } from 'antd'; +import type { CheckboxGroupProps } from 'antd/es/checkbox'; +import get from 'lodash/get'; +import isEqual from 'lodash/isEqual'; +import { useMutation } from "@tanstack/react-query"; +import { addModelProvider } from './hooks'; +import Loading from '../Evaluator/Loading'; + +export enum ModelProviderType { + OPENAI = 'openai', + GEMINI = 'gemini', + CAII = 'caii' +} + + +const modelProviderTypeOptions: CheckboxGroupProps['options'] = [ + { label: 'OpenAI', value: 'openai' }, + // { label: 'CAII', value: 'caii' }, + { label: 'Gemini', value: 'gemini' }, +]; + +const OPENAI_MODELS = [ + "gpt-4.1", // Latest GPT-4.1 series (April 2025) + "gpt-4.1-mini", + "gpt-4.1-nano", + "o3", // Latest reasoning models (April 2025) + "o4-mini", + "o3-mini", // January 2025 + "o1", // December 2024 + "gpt-4o", // November 2024 + "gpt-4o-mini", // July 2024 + "gpt-4-turbo", // April 2024 + "gpt-3.5-turbo" // Legacy but still widely used +]; + +const OPENAI_MODELS_OPTIONS = OPENAI_MODELS.map((model: string) => ({ + label: model, + value: model +})); + +const GEMINI_MODELS = [ + "gemini-2.5-pro", // June 2025 - most powerful thinking model + "gemini-2.5-flash", // June 2025 - best price-performance + "gemini-2.5-flash-lite", // June 2025 - cost-efficient + "gemini-2.0-flash", // February 2025 - next-gen features + "gemini-2.0-flash-lite", // February 2025 - low latency + "gemini-1.5-pro", // September 2024 - complex reasoning + "gemini-1.5-flash", // September 2024 - fast & versatile + "gemini-1.5-flash-8b" // October 2024 - lightweight +]; + +const GEMINI_MODELS_OPTIONS = GEMINI_MODELS.map((model: string) => ({ + label: model, + value: model +})); + +interface Props { + refetch: () => void; +} + +const AddModelProviderButton: React.FC = ({ refetch }) => { + const [form] = Form.useForm(); + const [showModal, setShowModal] = useState(false); + const [models, setModels] = useState(OPENAI_MODELS_OPTIONS); + const mutation = useMutation({ + mutationFn: addModelProvider + }); + + + useEffect(() => { + if (mutation.isError) { + notification.error({ + message: 'Error', + description: `An error occurred while fetching the prompt.\n ${mutation.error}` + }); + } + if (mutation.isSuccess) { + notification.success({ + message: 'Success', + description: `THe model provider has been created successfully!.` + }); + form.resetFields(); + setShowModal(false); + refetch(); + } + }, [mutation.error, mutation.isSuccess]); + + const onCancel = () => { + form.resetFields(); + setShowModal(false); + } + + const onSubmit = async () => { + try { + await form.validateFields(); + const values = form.getFieldsValue(); + + mutation.mutate({ + endpoint_config: { + display_name: values.display_name, + endpoint_id: values.endpoint_id, + model_id: values.model_id, + provider_type: values.provider_type, + api_key: values.api_key, + endpoint_url: values.endpoint_url + } + }); + } catch (error) { + console.error(error); + } + }; + + + const initialValues = { + provider_type: 'openai' + }; + + const onChange = (e: any) => { + const value = get(e, 'target.value'); + if (value === 'openai' 
&& !isEqual(OPENAI_MODELS_OPTIONS, models)) { + setModels(OPENAI_MODELS_OPTIONS); + } else if (value === 'gemini' && !isEqual(GEMINI_MODELS_OPTIONS, models)) { + setModels(GEMINI_MODELS_OPTIONS); + } + } + + return ( + <> + + {showModal && +
+
+
+ {mutation.isPending && } + {mutation.error && ( + {mutation.error instanceof Error ? mutation.error.message : String(mutation.error)} + } + /> + )} + + + + + + + + + + + + + + + + +
} + + ); +} + +export default AddModelProviderButton; + diff --git a/app/client/src/pages/Settings/EditModelProvider.tsx b/app/client/src/pages/Settings/EditModelProvider.tsx new file mode 100644 index 00000000..0c786a46 --- /dev/null +++ b/app/client/src/pages/Settings/EditModelProvider.tsx @@ -0,0 +1,236 @@ +import { useEffect, useState } from 'react'; +import { PlusCircleOutlined } from '@ant-design/icons'; +import { Alert, Button, Form, Input, Modal, notification, Radio, Select } from 'antd'; +import type { CheckboxGroupProps } from 'antd/es/checkbox'; +import get from 'lodash/get'; +import isEqual from 'lodash/isEqual'; +import { useMutation } from "@tanstack/react-query"; +import { addModelProvider, useGetModelProvider } from './hooks'; +import Loading from '../Evaluator/Loading'; +import { CustomModel } from './SettingsPage'; +import isEmpty from 'lodash/isEmpty'; + +export enum ModelProviderType { + OPENAI = 'openai', + GEMINIE = 'gemini', + CAII = 'caii' +} + + +const modelProviderTypeOptions: CheckboxGroupProps['options'] = [ + { label: 'OpenAI', value: 'openai' }, + // { label: 'CAII', value: 'caii' }, + { label: 'Gemini', value: 'gemini' }, +]; + +const OPENAI_MODELS = [ + "gpt-4.1", // Latest GPT-4.1 series (April 2025) + "gpt-4.1-mini", + "gpt-4.1-nano", + "o3", // Latest reasoning models (April 2025) + "o4-mini", + "o3-mini", // January 2025 + "o1", // December 2024 + "gpt-4o", // November 2024 + "gpt-4o-mini", // July 2024 + "gpt-4-turbo", // April 2024 + "gpt-3.5-turbo" // Legacy but still widely used +]; + +const OPENAI_MODELS_OPTIONS = OPENAI_MODELS.map((model: string) => ({ + label: model, + value: model +})); + +const GEMINI_MODELS = [ + "gemini-2.5-pro", // June 2025 - most powerful thinking model + "gemini-2.5-flash", // June 2025 - best price-performance + "gemini-2.5-flash-lite", // June 2025 - cost-efficient + "gemini-2.0-flash", // February 2025 - next-gen features + "gemini-2.0-flash-lite", // February 2025 - low latency + "gemini-1.5-pro", // September 2024 - complex reasoning + "gemini-1.5-flash", // September 2024 - fast & versatile + "gemini-1.5-flash-8b" // October 2024 - lightweight +]; + +const GEMINI_MODELS_OPTIONS = GEMINI_MODELS.map((model: string) => ({ + label: model, + value: model +})); + +interface Props { + refetch: () => void; + onClose: () => void; + model: CustomModel; +} + +const EditModelProvider: React.FC = ({ model, refetch, onClose }) => { + const [form] = Form.useForm(); + const modelProviderReq = useGetModelProvider(model.endpoint_id); + const [models, setModels] = useState(OPENAI_MODELS_OPTIONS); + const mutation = useMutation({ + mutationFn: addModelProvider + }); + + useEffect(() => { + if (!isEmpty(modelProviderReq.data)) { + const endpoint = get(modelProviderReq, 'data.endpoint'); + form.setFieldsValue({ + ...endpoint + }); + } + }, [modelProviderReq.data]); + + + useEffect(() => { + if (mutation.isError) { + notification.error({ + message: 'Error', + description: `An error occurred while fetching the model.\n ${mutation.error}` + }); + } + if (mutation.isSuccess) { + notification.success({ + message: 'Success', + description: `THe model provider has been edited successfully!.` + }); + refetch(); + } + }, [mutation.error, mutation.isSuccess]); + + const onCancel = () => { + form.resetFields(); + onClose(); + } + + const onSubmit = async () => { + try { + await form.validateFields(); + const values = form.getFieldsValue(); + + mutation.mutate({ + endpoint_config: { + display_name: values.display_name, + endpoint_id: 
values.endpoint_id, + model_id: values.model_id, + provider_type: values.provider_type, + api_key: values.api_key, + endpoint_url: values.endpoint_url + } + }); + } catch (error) { + console.error(error); + } + }; + + + const initialValues = { + provider_type: 'openai' + }; + + const onChange = (e: any) => { + const value = get(e, 'target.value'); + if (value === 'openai' && !isEqual(OPENAI_MODELS_OPTIONS, models)) { + setModels(OPENAI_MODELS_OPTIONS); + } else if (value === 'gemini' && !isEqual(GEMINI_MODELS_OPTIONS, models)) { + setModels(GEMINI_MODELS_OPTIONS); + } + } + + return ( + <> + +
+
+
+ {(mutation.isPending || modelProviderReq.isLoading) && } + {mutation.error && ( + {mutation.error instanceof Error ? mutation.error.message : String(mutation.error)} + } + /> + )} + + + + + + + + + + + + + + + + +
+ + ); +} + +export default EditModelProvider; + diff --git a/app/client/src/pages/Settings/SettingsPage.tsx b/app/client/src/pages/Settings/SettingsPage.tsx new file mode 100644 index 00000000..4913ffd6 --- /dev/null +++ b/app/client/src/pages/Settings/SettingsPage.tsx @@ -0,0 +1,233 @@ +import { Button, Col, Flex, Layout, Modal, notification, Row, Table, Tooltip, Tooltip } from "antd"; +import { Content } from "antd/es/layout/layout"; +import styled from "styled-components"; +import { deleteModelProvider, useModelProviders } from "./hooks"; +import get from "lodash/get"; +import { sortItemsByKey } from "../../utils/sortutils"; +import Paragraph from "antd/es/typography/Paragraph"; +import StyledTitle from "../Evaluator/StyledTitle"; +import Toolbar from "./Toolbar"; +import AddModelProviderButton, { ModelProviderType } from "./AddModelProviderButton"; +import DateTime from "../../components/DateTime/DateTime"; +import { + EditOutlined, + DeleteOutlined + } from '@ant-design/icons'; +import { useMutation } from "@tanstack/react-query"; +import { useState } from "react"; +import EditModelProvider from "./EditModelProvider"; + + + +const StyledContent = styled(Content)` + padding: 24px; + background-color: #f5f7f8; +`; + + +export interface CustomModel { + endpoint_id: string; + display_name: string; + model_id: string; + provider_type: string; + api_key?: string; + cdp_token?: string; + created_at: string +} + +const Container = styled.div` + background-color: #ffffff; + padding: 1rem; + overflow-x: auto; +`; + +const StyledTable = styled(Table)` + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-table-thead > tr > th { + color: #5a656d; + border-bottom: 1px solid #eaebec; + font-weight: 500; + text-align: left; + // background: #ffffff; + border-bottom: 1px solid #eaebec; + transition: background 0.3s ease; + } + .ant-table-row > td.ant-table-cell { + padding: 8px; + padding-left: 16px; + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; + .ant-typography { + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + } + } +`; + +const StyledParagraph = styled(Paragraph)` + font-size: 13px; + font-family: Roboto, -apple-system, 'Segoe UI', sans-serif; + color: #5a656d; +`; + +const StyledButton = styled(Button)` + margin-left: 8px; +`; + +const SettingsPage: React.FC = () => { + const [showModal, setShowModal] = useState(false); + const [model, setModel] = useState(null); + const filteredModelsReq = useModelProviders(); + const customModels = get(filteredModelsReq, 'data.endpoints', []); + + const mutation = useMutation({ + mutationFn: deleteModelProvider + }); + + const onDelete = (model: CustomModel) => { + Modal.confirm({ + content: ( + {`Are you sure you want to delete the model \'${model.display_name}\'`} + ), + onOk: async () => { + try { + mutation.mutate({ + endpoint_id: model.endpoint_id + }) + } catch (error) { + notification.error({ + message: "Error", + description: error instanceof Error ? 
error.message : String(error), + }); + } + filteredModelsReq.refetch(); + }, + title: 'Confirm' + }); + }; + + const onEdit = (_model: CustomModel) => { + setShowModal(true); + setModel(_model) + + + } + + const modelProvidersColumns = [{ + key: 'display_name', + title: 'Display Name', + dataIndex: 'display_name', + width: 200, + sorter: sortItemsByKey('display_name') + + }, { + key: 'provider_type', + title: 'Provider Type', + dataIndex: 'provider_type', + width: 150, + sorter: sortItemsByKey('provider_type'), + render: (provider_type: string) => { + if (provider_type === 'openai') { + return 'OpenAI'; + } else if (provider_type === ModelProviderType.GEMINI) { + return 'Gemini'; + } else if (provider_type === ModelProviderType.CAII) { + return 'CAII'; + } + return 'N/A' + } + }, { + key: 'model_id', + title: 'Model ID', + dataIndex: 'model_id', + width: 200, + sorter: sortItemsByKey('model_id') + + }, { + key: 'created_at', + title: 'Created At', + dataIndex: 'created_at', + width: 200, + sorter: sortItemsByKey('created_at'), + render: (timestamp: string) => <>{timestamp == null ? 'N/A' : } + + }, { + key: 'endpoint_url', + title: 'Endpoint', + dataIndex: 'endpoint_url', + width: 300, + sorter: sortItemsByKey('endpoint_url'), + render: (endpoint_url: string) => {endpoint_url} + }, { + title: 'Actions', + width: 100, + render: (model: CustomModel) => { + return ( + + + + + + onEdit(model)} + data-event-category="User Action" + data-event="Edit" + > + + + + + ); + + } + }]; + + return ( + + + + {'Settings'} + +
+
+ + + {'Custom Models'} + } + right={ + + + + } + /> + `${row?.endpoint_id}`} + tableLayout="fixed" + columns={modelProvidersColumns} + dataSource={customModels || [] as CustomModel[]} /> + + +
+ {showModal && + setShowModal(false)} />} +
+
+ ); + +} + +export default SettingsPage; \ No newline at end of file diff --git a/app/client/src/pages/Settings/Toolbar.tsx b/app/client/src/pages/Settings/Toolbar.tsx new file mode 100644 index 00000000..579591de --- /dev/null +++ b/app/client/src/pages/Settings/Toolbar.tsx @@ -0,0 +1,55 @@ +import React from 'react'; +import { Row } from 'antd'; +import classNames from 'classnames'; +import styled from 'styled-components'; + +interface Props { + /** Content for the left content area. */ + left?: React.ReactNode; + /** Content for the center content area. */ + center?: React.ReactNode; + /** Content for the right content area. */ + right?: React.ReactNode; + /** Additional class name to add to the underlying Row component. */ + className?: string; + style?: React.CSSProperties; +} + +const StyledRow = styled(Row)` + .altus-toolbar { + & .altus-toolbar-group:not(.ant-row-start) { + flex-grow: 1; + } +} + +`; + +/** + * @deprecated + */ +export default function Toolbar({ left, center, right, className = '', ...otherProps }: Props) { + return ( + + {left && ( + + {left} + + )} + {center && ( + + {center} + + )} + {right && ( + + {right} + + )} + + ); +} diff --git a/app/client/src/pages/Settings/hooks.ts b/app/client/src/pages/Settings/hooks.ts new file mode 100644 index 00000000..4269a73a --- /dev/null +++ b/app/client/src/pages/Settings/hooks.ts @@ -0,0 +1,80 @@ +import { useQuery } from "@tanstack/react-query"; + +const BASE_API_URL = import.meta.env.VITE_AMP_URL; + + +const fetchFilteredModels = async () => { + // const model_filtered_resp = await fetch(`${BASE_API_URL}/model/model_id_filter`, { + const model_filtered_resp = await fetch(`/custom_model_endpoints`, { + method: 'GET', + }); + return await model_filtered_resp.json(); +}; + + +export const deleteModelProvider = async ({ endpoint_id }) => { + const delete_resp = await fetch(`/custom_model_endpoints/${endpoint_id}`, { + method: 'DELETE' + }); + return await delete_resp.json(); +} + +export const getModelProvider = async ({ endpoint_id }) => { + const get_model_resp = await fetch(`/custom_model_endpoints/${endpoint_id}`, { + method: 'GET' + }); + return await get_model_resp.json(); +} + +export const updateModelProvider = async ({ endpoint_id }) => { + const update_model_resp = await fetch(`/custom_model_endpoints/${endpoint_id}`, { + method: 'PUT' + }); + return await update_model_resp.json(); +} + + +export const useModelProviders = () => { + + const { data, isLoading, isError, refetch } = useQuery( + { + queryKey: ['fetchFilteredModels'], + queryFn: () => fetchFilteredModels(), + refetchInterval: 15000 + } + ); + + return { + data, + isLoading, + isError, + refetch + }; +} + +export const addModelProvider = async (params: any) => { + const model_filtered_resp = await fetch(`/add_model_endpoint`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(params) + }); + return await model_filtered_resp.json(); +} + +export const useGetModelProvider = (endpoint_id) => { + + const { data, isLoading, isError, refetch } = useQuery( + { + queryKey: ['getModelProvider'], + queryFn: () => getModelProvider({ endpoint_id }), + refetchInterval: 15000 + } + ); + + return { + data, + isLoading, + isError, + refetch + }; +} \ No newline at end of file diff --git a/app/client/src/routes.tsx b/app/client/src/routes.tsx index 9a0e4cc2..db393328 100644 --- a/app/client/src/routes.tsx +++ b/app/client/src/routes.tsx @@ -12,6 +12,7 @@ import EvaluationDetailsPage from 
"./pages/EvaluationDetails/EvaluationDetailsPa import DatasetsPage from "./pages/Datasets/DatasetsPage"; import EvaluationsPage from "./pages/Evaluations/EvaluationsPage"; import ExportsPage from "./pages/Exports/ExportsPage"; +import SettingsPage from "./pages/Settings/SettingsPage"; //import TelemetryDashboard from "./components/TelemetryDashboard"; @@ -108,6 +109,13 @@ const router = createBrowserRouter([ errorElement: , loader: async () => null }, + { + path: Pages.SETTINGS, + element: , + errorElement: , + loader: async () => null + }, + // { // path: `telemetry`, // element: , diff --git a/app/client/src/types.ts b/app/client/src/types.ts index 900b81ab..414fdbd0 100644 --- a/app/client/src/types.ts +++ b/app/client/src/types.ts @@ -10,7 +10,8 @@ export enum Pages { EXPORTS = 'exports', WELCOME = 'welcome', FEEDBACK = 'feedback', - UPGRADE = 'upgrade' + UPGRADE = 'upgrade', + SETTINGS = 'settings' //TELEMETRY = 'telemetry' } From 303b6f70fe80cc02aaa643ce5e1bbb382d82fbf5 Mon Sep 17 00:00:00 2001 From: Keivan Vosoughi Date: Thu, 2 Oct 2025 17:11:42 -0700 Subject: [PATCH 10/12] Remove redundant Gemini Model --- .../pages/Settings/AddModelProviderButton.tsx | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/app/client/src/pages/Settings/AddModelProviderButton.tsx b/app/client/src/pages/Settings/AddModelProviderButton.tsx index 0ba8b780..4b6b5587 100644 --- a/app/client/src/pages/Settings/AddModelProviderButton.tsx +++ b/app/client/src/pages/Settings/AddModelProviderButton.tsx @@ -24,15 +24,7 @@ const modelProviderTypeOptions: CheckboxGroupProps['options'] = [ const OPENAI_MODELS = [ "gpt-4.1", // Latest GPT-4.1 series (April 2025) "gpt-4.1-mini", - "gpt-4.1-nano", - "o3", // Latest reasoning models (April 2025) - "o4-mini", - "o3-mini", // January 2025 - "o1", // December 2024 - "gpt-4o", // November 2024 - "gpt-4o-mini", // July 2024 - "gpt-4-turbo", // April 2024 - "gpt-3.5-turbo" // Legacy but still widely used + "gpt-4.1-nano" ]; const OPENAI_MODELS_OPTIONS = OPENAI_MODELS.map((model: string) => ({ @@ -43,12 +35,7 @@ const OPENAI_MODELS_OPTIONS = OPENAI_MODELS.map((model: string) => ({ const GEMINI_MODELS = [ "gemini-2.5-pro", // June 2025 - most powerful thinking model "gemini-2.5-flash", // June 2025 - best price-performance - "gemini-2.5-flash-lite", // June 2025 - cost-efficient - "gemini-2.0-flash", // February 2025 - next-gen features - "gemini-2.0-flash-lite", // February 2025 - low latency - "gemini-1.5-pro", // September 2024 - complex reasoning - "gemini-1.5-flash", // September 2024 - fast & versatile - "gemini-1.5-flash-8b" // October 2024 - lightweight + "gemini-2.5-flash-lite" // June 2025 - cost-efficient ]; const GEMINI_MODELS_OPTIONS = GEMINI_MODELS.map((model: string) => ({ From 0867ee13609485c26f330e880c99f6f6f52dc89b Mon Sep 17 00:00:00 2001 From: Keivan Vosoughi Date: Thu, 2 Oct 2025 20:56:37 -0700 Subject: [PATCH 11/12] Fix the message --- app/client/src/pages/DataGenerator/Finish.tsx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/app/client/src/pages/DataGenerator/Finish.tsx b/app/client/src/pages/DataGenerator/Finish.tsx index 60f7d4de..54d4ef4b 100644 --- a/app/client/src/pages/DataGenerator/Finish.tsx +++ b/app/client/src/pages/DataGenerator/Finish.tsx @@ -322,12 +322,14 @@ const Finish = () => { const rawData = genDatasetResp !== null && hasTopics(genDatasetResp) ? getRawData(genDatasetResp) : genDatasetResp?.results + console.log('Finish >> ', isDemo); + return (
<Flex align='center' gap={10}> <CheckCircleIcon style={{ color: '#178718' }}/> - {'Success'} + {isDemo ? 'Success' : 'Job Successfully Started'} </Flex> {isDemo ? ( From 62421d03a468344ba3c0e3304ffed2421c2f6472 Mon Sep 17 00:00:00 2001 From: Keivan Vosoughi Date: Thu, 9 Oct 2025 14:38:05 -0700 Subject: [PATCH 12/12] Add CAII Model Provider Add AWS Bedrock Model Provider Fix edit custom model issue --- .../pages/Settings/AddModelProviderButton.tsx | 131 +++++++++++++++--- .../src/pages/Settings/EditModelProvider.tsx | 122 +++++++--------- .../src/pages/Settings/SettingsPage.tsx | 50 ++++--- app/client/src/pages/Settings/hooks.ts | 21 +-- 4 files changed, 203 insertions(+), 121 deletions(-) diff --git a/app/client/src/pages/Settings/AddModelProviderButton.tsx b/app/client/src/pages/Settings/AddModelProviderButton.tsx index 4b6b5587..a1207a6f 100644 --- a/app/client/src/pages/Settings/AddModelProviderButton.tsx +++ b/app/client/src/pages/Settings/AddModelProviderButton.tsx @@ -1,6 +1,6 @@ import { useEffect, useState } from 'react'; import { PlusCircleOutlined } from '@ant-design/icons'; -import { Alert, Button, Form, Input, Modal, notification, Radio, Select } from 'antd'; +import { Alert, AutoComplete, Button, Form, Input, Modal, notification, Radio, Select } from 'antd'; import type { CheckboxGroupProps } from 'antd/es/checkbox'; import get from 'lodash/get'; import isEqual from 'lodash/isEqual'; @@ -10,15 +10,19 @@ import Loading from '../Evaluator/Loading'; export enum ModelProviderType { OPENAI = 'openai', + OPENAI_COMPATIBLE = 'openai_compatible', GEMINI = 'gemini', - CAII = 'caii' + CAII = 'caii', + AWS_BEDROCK = 'aws_bedrock' } -const modelProviderTypeOptions: CheckboxGroupProps['options'] = [ +export const modelProviderTypeOptions: CheckboxGroupProps['options'] = [ { label: 'OpenAI', value: 'openai' }, - // { label: 'CAII', value: 'caii' }, + { label: 'OpenAI Compatible', value: 'openai_compatible' }, { label: 'Gemini', value: 'gemini' }, + { label: 'AWS Bedrock', value: 'aws_bedrock' }, + { label: 'CAII', value: 'caii' }, ]; const OPENAI_MODELS = [ @@ -27,7 +31,7 @@ const OPENAI_MODELS = [ "gpt-4.1-nano" ]; -const OPENAI_MODELS_OPTIONS = OPENAI_MODELS.map((model: string) => ({ +export const OPENAI_MODELS_OPTIONS = OPENAI_MODELS.map((model: string) => ({ label: model, value: model })); @@ -38,7 +42,42 @@ const GEMINI_MODELS = [ "gemini-2.5-flash-lite" // June 2025 - cost-efficient ]; -const GEMINI_MODELS_OPTIONS = GEMINI_MODELS.map((model: string) => ({ +export const AWS_BEDROCK_MODELS = [ + "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "us.anthropic.claude-opus-4-1-20250805-v1:0", + "us.anthropic.claude-opus-4-20250514-v1:0", + "global.anthropic.claude-sonnet-4-20250514-v1:0", + "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "us.anthropic.claude-3-opus-20240229-v1:0", + "meta.llama3-8b-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", + "mistral.mistral-large-2402-v1:0", + "mistral.mistral-small-2402-v1:0", + "us.meta.llama3-2-11b-instruct-v1:0", + "us.meta.llama3-2-3b-instruct-v1:0", + "us.meta.llama3-2-90b-instruct-v1:0", + "us.meta.llama3-2-1b-instruct-v1:0", + "us.meta.llama3-1-8b-instruct-v1:0", + "us.meta.llama3-1-70b-instruct-v1:0", + "us.meta.llama3-3-70b-instruct-v1:0", + "us.mistral.pixtral-large-2502-v1:0", + 
"us.meta.llama4-scout-17b-instruct-v1:0", + "us.meta.llama4-maverick-17b-instruct-v1:0", + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1" +]; + +export const AWS_BEDROCK_MODELS_OPTIONS = AWS_BEDROCK_MODELS.map((model: string) => ({ + label: model, + value: model +})); + +export const GEMINI_MODELS_OPTIONS = GEMINI_MODELS.map((model: string) => ({ label: model, value: model })); @@ -50,6 +89,7 @@ interface Props { const AddModelProviderButton: React.FC = ({ refetch }) => { const [form] = Form.useForm(); const [showModal, setShowModal] = useState(false); + const [modelProviderType, setModelProviderType] = useState(ModelProviderType.OPENAI); const [models, setModels] = useState(OPENAI_MODELS_OPTIONS); const mutation = useMutation({ mutationFn: addModelProvider @@ -106,10 +146,13 @@ const AddModelProviderButton: React.FC = ({ refetch }) => { const onChange = (e: any) => { const value = get(e, 'target.value'); + setModelProviderType(value as ModelProviderType); if (value === 'openai' && !isEqual(OPENAI_MODELS_OPTIONS, models)) { setModels(OPENAI_MODELS_OPTIONS); } else if (value === 'gemini' && !isEqual(GEMINI_MODELS_OPTIONS, models)) { setModels(GEMINI_MODELS_OPTIONS); + } else if (value === 'aws_bedrock' && !isEqual(GEMINI_MODELS_OPTIONS, models)) { + setModels(AWS_BEDROCK_MODELS_OPTIONS); } } @@ -146,7 +189,7 @@ const AddModelProviderButton: React.FC = ({ refetch }) => { defaultValue="openai" optionType="button" buttonStyle="solid" - style={{ width: '40%' }} + style={{ width: '100%', whiteSpace: 'nowrap' }} onChange={onChange} /> @@ -162,49 +205,97 @@ const AddModelProviderButton: React.FC = ({ refetch }) => { - + {'Enter Model Name '}, value: '' } + ].concat( + models + )} + placeholder={'Select Model'} + /> - + + } + {modelProviderType !== ModelProviderType.AWS_BEDROCK && modelProviderType !== ModelProviderType.CAII && - + - + + + } } diff --git a/app/client/src/pages/Settings/EditModelProvider.tsx b/app/client/src/pages/Settings/EditModelProvider.tsx index 0c786a46..c6dda9c2 100644 --- a/app/client/src/pages/Settings/EditModelProvider.tsx +++ b/app/client/src/pages/Settings/EditModelProvider.tsx @@ -9,54 +9,8 @@ import { addModelProvider, useGetModelProvider } from './hooks'; import Loading from '../Evaluator/Loading'; import { CustomModel } from './SettingsPage'; import isEmpty from 'lodash/isEmpty'; +import { GEMINI_MODELS_OPTIONS, ModelProviderType, modelProviderTypeOptions, OPENAI_MODELS_OPTIONS } from './AddModelProviderButton'; -export enum ModelProviderType { - OPENAI = 'openai', - GEMINIE = 'gemini', - CAII = 'caii' -} - - -const modelProviderTypeOptions: CheckboxGroupProps['options'] = [ - { label: 'OpenAI', value: 'openai' }, - // { label: 'CAII', value: 'caii' }, - { label: 'Gemini', value: 'gemini' }, -]; - -const OPENAI_MODELS = [ - "gpt-4.1", // Latest GPT-4.1 series (April 2025) - "gpt-4.1-mini", - "gpt-4.1-nano", - "o3", // Latest reasoning models (April 2025) - "o4-mini", - "o3-mini", // January 2025 - "o1", // December 2024 - "gpt-4o", // November 2024 - "gpt-4o-mini", // July 2024 - "gpt-4-turbo", // April 2024 - "gpt-3.5-turbo" // Legacy but still widely used -]; - -const OPENAI_MODELS_OPTIONS = OPENAI_MODELS.map((model: string) => ({ - label: model, - value: model -})); - -const GEMINI_MODELS = [ - "gemini-2.5-pro", // June 2025 - most powerful thinking model - "gemini-2.5-flash", // June 2025 - best price-performance - "gemini-2.5-flash-lite", // June 2025 - cost-efficient - "gemini-2.0-flash", // February 2025 - next-gen features - 
"gemini-2.0-flash-lite", // February 2025 - low latency - "gemini-1.5-pro", // September 2024 - complex reasoning - "gemini-1.5-flash", // September 2024 - fast & versatile - "gemini-1.5-flash-8b" // October 2024 - lightweight -]; - -const GEMINI_MODELS_OPTIONS = GEMINI_MODELS.map((model: string) => ({ - label: model, - value: model -})); interface Props { refetch: () => void; @@ -66,7 +20,8 @@ interface Props { const EditModelProvider: React.FC = ({ model, refetch, onClose }) => { const [form] = Form.useForm(); - const modelProviderReq = useGetModelProvider(model.endpoint_id); + const [modelProviderType, setModelProviderType] = useState(ModelProviderType.OPENAI); + const modelProviderReq = useGetModelProvider(model); const [models, setModels] = useState(OPENAI_MODELS_OPTIONS); const mutation = useMutation({ mutationFn: addModelProvider @@ -75,9 +30,12 @@ const EditModelProvider: React.FC = ({ model, refetch, onClose }) => { useEffect(() => { if (!isEmpty(modelProviderReq.data)) { const endpoint = get(modelProviderReq, 'data.endpoint'); - form.setFieldsValue({ - ...endpoint - }); + if (!isEmpty(endpoint)) { + form.setFieldsValue({ + ...endpoint + }); + setModelProviderType(endpoint?.provider_type as ModelProviderType); + } } }, [modelProviderReq.data]); @@ -167,24 +125,24 @@ const EditModelProvider: React.FC = ({ model, refetch, onClose }) => { defaultValue="openai" optionType="button" buttonStyle="solid" - style={{ width: '40%' }} + style={{ width: '100%', whiteSpace: 'nowrap' }} onChange={onChange} /> - + - - } + {modelProviderType !== ModelProviderType.AWS_BEDROCK && modelProviderType !== ModelProviderType.CAII && - + - + + + } diff --git a/app/client/src/pages/Settings/SettingsPage.tsx b/app/client/src/pages/Settings/SettingsPage.tsx index 4913ffd6..81fbd328 100644 --- a/app/client/src/pages/Settings/SettingsPage.tsx +++ b/app/client/src/pages/Settings/SettingsPage.tsx @@ -16,6 +16,7 @@ import { import { useMutation } from "@tanstack/react-query"; import { useState } from "react"; import EditModelProvider from "./EditModelProvider"; +import isEmpty from "lodash/isEmpty"; @@ -116,13 +117,12 @@ const SettingsPage: React.FC = () => { } const modelProvidersColumns = [{ - key: 'display_name', - title: 'Display Name', - dataIndex: 'display_name', + key: 'model_id', + title: 'Model ID', + dataIndex: 'model_id', width: 200, - sorter: sortItemsByKey('display_name') - - }, { + sorter: sortItemsByKey('model_id') + },{ key: 'provider_type', title: 'Provider Type', dataIndex: 'provider_type', @@ -131,6 +131,8 @@ const SettingsPage: React.FC = () => { render: (provider_type: string) => { if (provider_type === 'openai') { return 'OpenAI'; + } else if (provider_type === 'openai_compatible') { + return 'OpenAI Compatible'; } else if (provider_type === ModelProviderType.GEMINI) { return 'Gemini'; } else if (provider_type === ModelProviderType.CAII) { @@ -139,27 +141,30 @@ const SettingsPage: React.FC = () => { return 'N/A' } }, { - key: 'model_id', - title: 'Model ID', - dataIndex: 'model_id', - width: 200, - sorter: sortItemsByKey('model_id') - - }, { - key: 'created_at', - title: 'Created At', - dataIndex: 'created_at', - width: 200, - sorter: sortItemsByKey('created_at'), - render: (timestamp: string) => <>{timestamp == null ? 'N/A' : } - - }, { + // key: 'created_at', + // title: 'Created At', + // dataIndex: 'created_at', + // width: 200, + // sorter: sortItemsByKey('created_at'), + // render: (timestamp: string) => <>{timestamp == null ? 
'N/A' : } + // }, { key: 'endpoint_url', title: 'Endpoint', dataIndex: 'endpoint_url', width: 300, sorter: sortItemsByKey('endpoint_url'), - render: (endpoint_url: string) => {endpoint_url} + render: (endpoint_url: string) => { + if (isEmpty(endpoint_url)) { + return 'N/A'; + } + + return ( + + {endpoint_url} + + ) + + } }, { title: 'Actions', width: 100, @@ -180,6 +185,7 @@ const SettingsPage: React.FC = () => { onEdit(model)} data-event-category="User Action" diff --git a/app/client/src/pages/Settings/hooks.ts b/app/client/src/pages/Settings/hooks.ts index 4269a73a..951bbe5a 100644 --- a/app/client/src/pages/Settings/hooks.ts +++ b/app/client/src/pages/Settings/hooks.ts @@ -1,4 +1,5 @@ import { useQuery } from "@tanstack/react-query"; +import { CustomModel } from "./SettingsPage"; const BASE_API_URL = import.meta.env.VITE_AMP_URL; @@ -12,22 +13,22 @@ const fetchFilteredModels = async () => { }; -export const deleteModelProvider = async ({ endpoint_id }) => { - const delete_resp = await fetch(`/custom_model_endpoints/${endpoint_id}`, { +export const deleteModelProvider = async ({ model }: { model: CustomModel }) => { + const delete_resp = await fetch(`/custom_model_endpoints/${model.model_id}/${model.provider_type}`, { method: 'DELETE' }); return await delete_resp.json(); } -export const getModelProvider = async ({ endpoint_id }) => { - const get_model_resp = await fetch(`/custom_model_endpoints/${endpoint_id}`, { +export const getModelProvider = async ({ model }: { model: CustomModel }) => { + const get_model_resp = await fetch(`/custom_model_endpoints/${model.model_id}/${model.provider_type}`, { method: 'GET' - }); - return await get_model_resp.json(); + }); + return await get_model_resp.json(); } -export const updateModelProvider = async ({ endpoint_id }) => { - const update_model_resp = await fetch(`/custom_model_endpoints/${endpoint_id}`, { +export const updateModelProvider = async ({ model }: { model: CustomModel }) => { + const update_model_resp = await fetch(`/custom_model_endpoints/${model.model_id}/${model.provider_type}`, { method: 'PUT' }); return await update_model_resp.json(); @@ -61,12 +62,12 @@ export const addModelProvider = async (params: any) => { return await model_filtered_resp.json(); } -export const useGetModelProvider = (endpoint_id) => { +export const useGetModelProvider = (model: CustomModel) => { const { data, isLoading, isError, refetch } = useQuery( { queryKey: ['getModelProvider'], - queryFn: () => getModelProvider({ endpoint_id }), + queryFn: () => getModelProvider({ model }), refetchInterval: 15000 } );
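
A usage sketch of the provider-scoped routes introduced above, assuming the CustomModel shape from SettingsPage.tsx and the /custom_model_endpoints/{model_id}/{provider_type} paths added in this series; the helper names and the per-model cache key are illustrative, not part of the patches:

// Sketch only: per-model cache keys for the provider-scoped endpoints above.
// Assumes @tanstack/react-query and the CustomModel fields used in SettingsPage.tsx.
import { useQuery } from "@tanstack/react-query";

interface CustomModel {
  model_id: string;
  provider_type: string; // 'openai' | 'openai_compatible' | 'gemini' | 'aws_bedrock' | 'caii'
  endpoint_url?: string;
}

// One place to build the composite path shared by GET, PUT and DELETE.
const providerPath = (m: Pick<CustomModel, 'model_id' | 'provider_type'>) =>
  `/custom_model_endpoints/${encodeURIComponent(m.model_id)}/${encodeURIComponent(m.provider_type)}`;

const getModelProvider = async (m: CustomModel) => {
  const resp = await fetch(providerPath(m), { method: 'GET' });
  if (!resp.ok) throw new Error(`Failed to fetch model provider (${resp.status})`);
  return resp.json();
};

// Including model_id and provider_type in the queryKey keeps edits of different
// table rows from sharing one cached response.
export const useGetModelProvider = (m: CustomModel) =>
  useQuery({
    queryKey: ['getModelProvider', m.provider_type, m.model_id],
    queryFn: () => getModelProvider(m),
  });

The same composite-path pattern applies to deleteModelProvider and updateModelProvider, since all three operate on the model_id plus provider_type pair rather than a separate endpoint_id.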
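
And a minimal call sketch for registering an OpenAI-compatible model through the same /add_model_endpoint route used by addModelProvider in hooks.ts; the field names mirror the endpoint_config example in request_models.py, and the concrete URL and key below are placeholders:

// Sketch: registering an OpenAI-compatible model through the settings API.
// Field names follow the AddCustomEndpointRequest example in request_models.py;
// the endpoint URL and API key here are placeholders.
const registerOpenAICompatibleModel = async () => {
  const resp = await fetch('/add_model_endpoint', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      endpoint_config: {
        model_id: 'claude-3-sonnet-20240229',
        provider_type: 'openai_compatible',
        endpoint_url: 'https://my-endpoint.com/v1',
        api_key: 'replace-with-real-key',
      },
    }),
  });
  if (!resp.ok) throw new Error(`add_model_endpoint failed (${resp.status})`);
  return resp.json();
};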