1- from fastapi import FastAPI , HTTPException
2- from pydantic import BaseModel
3- from typing import List , Optional
1+ import os
2+ from typing import Any , List , Optional
3+
44import numpy as np
5+ from dotenv import load_dotenv
6+ from fastapi import FastAPI , HTTPException , Request
7+ from fastapi .responses import JSONResponse
8+ from pydantic import BaseModel , ValidationError , field_validator , validator
9+
10+ from anomaly_detection .data .nab_loader import NABLoader
11+ from anomaly_detection .models .factory import ModelFactory
12+
13+ # Load environment variables
14+ load_dotenv ()
515
616app = FastAPI (
717 title = "Anomaly Detection Service" ,
818 description = "A service for detecting anomalies in time series data" ,
9- version = "1.0.0"
19+ version = "1.0.0" ,
1020)
1121
22+ # Store trained models in memory (in production, use a proper database)
23+ trained_models = {}
24+
25+
1226class TrainingRequest (BaseModel ):
1327 algorithm : str
1428 data_path : str
1529 parameters : Optional [dict ] = None
1630
31+
1732class PredictionRequest (BaseModel ):
1833 algorithm : str
19- data : List [float ]
34+ data : List [Any ] # Accept any type of data and validate in the endpoint
2035 model_id : Optional [str ] = None
2136
37+
2238class PredictionResponse (BaseModel ):
2339 is_anomaly : bool
2440 score : float
2541 threshold : float
2642
43+
2744@app .post ("/train" )
2845async def train_model (request : TrainingRequest ):
2946 """
3047 Train an anomaly detection model using the specified algorithm and data.
48+
49+ Args:
50+ request: TrainingRequest containing:
51+ - algorithm: The algorithm to use (isolation_forest or random_cut_forest)
52+ - data_path: Path to the training data
53+ - parameters: Optional parameters for the model
54+
55+ Returns:
56+ dict: Status and model ID
3157 """
3258 try :
33- # TODO: Implement training logic
34- return {"status" : "success" , "model_id" : "model_123" }
59+ # Load data
60+ loader = NABLoader ()
61+ X , _ = loader .load_dataset (request .data_path )
62+
63+ # Create model with parameters
64+ model_params = request .parameters or {}
65+ if request .algorithm == "random_cut_forest" :
66+ # Add AWS credentials if using RCF
67+ model_params .update (
68+ {
69+ "aws_access_key_id" : os .getenv ("AWS_ACCESS_KEY_ID" ),
70+ "aws_secret_access_key" : os .getenv ("AWS_SECRET_ACCESS_KEY" ),
71+ "region_name" : os .getenv ("AWS_REGION" , "us-west-2" ),
72+ }
73+ )
74+
75+ model = ModelFactory .create_model (request .algorithm , ** model_params )
76+
77+ # Train model
78+ model .fit (X )
79+
80+ # Generate unique model ID
81+ model_id = f"{ request .algorithm } _{ len (trained_models )} "
82+
83+ # Store model
84+ trained_models [model_id ] = model
85+
86+ return {
87+ "status" : "success" ,
88+ "model_id" : model_id ,
89+ "message" : f"Model trained successfully with { len (X )} samples" ,
90+ }
3591 except Exception as e :
3692 raise HTTPException (status_code = 500 , detail = str (e ))
3793
94+
3895@app .post ("/predict" )
3996async def predict (request : PredictionRequest ) -> PredictionResponse :
4097 """
4198 Make predictions using a trained anomaly detection model.
99+
100+ Args:
101+ request: PredictionRequest containing:
102+ - algorithm: The algorithm used
103+ - data: List of values to predict
104+ - model_id: ID of the trained model to use
105+
106+ Returns:
107+ PredictionResponse: Prediction results including anomaly status and score
42108 """
109+ # Validate input data
110+ if not request .data :
111+ raise HTTPException (status_code = 500 , detail = "Empty data provided" )
112+
113+ try :
114+ # Try converting all values to float
115+ data_values = [float (x ) for x in request .data ]
116+ except (ValueError , TypeError ):
117+ raise HTTPException (status_code = 500 , detail = "All values must be numeric" )
118+
119+ # Get model first to fail fast if model doesn't exist
120+ if request .model_id not in trained_models :
121+ raise HTTPException (
122+ status_code = 404 ,
123+ detail = f"Model { request .model_id } not found. Please train a model first." ,
124+ )
125+
126+ try :
127+ # Convert input data to numpy array and reshape to 2D
128+ data = np .array (data_values , dtype = float )
129+ if len (data .shape ) == 1 :
130+ data = data .reshape (- 1 , 1 ) # Convert to 2D array with shape (n_samples, 1)
131+ except Exception as e :
132+ raise HTTPException (status_code = 500 , detail = f"Error processing data: { str (e )} " )
133+
43134 try :
44- # TODO: Implement prediction logic
135+ model = trained_models [request .model_id ]
136+
137+ # Make prediction
138+ score = float (model .predict (data )[0 ]) # Convert to float for JSON serialization
139+ threshold = float (model .threshold ) # Convert to float for JSON serialization
140+ is_anomaly = score > threshold
141+
45142 return PredictionResponse (
46- is_anomaly = False ,
47- score = 0.5 ,
48- threshold = 0.7
143+ is_anomaly = bool (is_anomaly ), score = score , threshold = threshold
49144 )
50145 except Exception as e :
51- raise HTTPException (status_code = 500 , detail = str (e ))
146+ raise HTTPException (status_code = 500 , detail = f"Prediction error: { str (e )} " )
147+
148+
149+ @app .exception_handler (ValidationError )
150+ async def validation_exception_handler (request : Request , exc : ValidationError ):
151+ """Handle Pydantic validation errors with 500 status code."""
152+ return JSONResponse (status_code = 500 , content = {"detail" : str (exc )})
153+
52154
53155@app .get ("/health" )
54156async def health_check ():
55157 """
56158 Health check endpoint.
57159 """
58- return {"status" : "healthy" }
160+ return {"status" : "healthy" }
0 commit comments