Skip to content

Commit 649960c

Browse files
style: format code with black and isort, add comprehensive API tests
1 parent acae51e commit 649960c

File tree

9 files changed

+387
-74
lines changed

9 files changed

+387
-74
lines changed

anomaly_detection/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
Random Cut Forest and Isolation Forest algorithms.
66
"""
77

8-
__version__ = "0.1.0"
8+
__version__ = "0.1.0"

anomaly_detection/api/main.py

Lines changed: 115 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,160 @@
1-
from fastapi import FastAPI, HTTPException
2-
from pydantic import BaseModel
3-
from typing import List, Optional
1+
import os
2+
from typing import Any, List, Optional
3+
44
import numpy as np
5+
from dotenv import load_dotenv
6+
from fastapi import FastAPI, HTTPException, Request
7+
from fastapi.responses import JSONResponse
8+
from pydantic import BaseModel, ValidationError, field_validator, validator
9+
10+
from anomaly_detection.data.nab_loader import NABLoader
11+
from anomaly_detection.models.factory import ModelFactory
12+
13+
# Load environment variables
14+
load_dotenv()
515

616
app = FastAPI(
717
title="Anomaly Detection Service",
818
description="A service for detecting anomalies in time series data",
9-
version="1.0.0"
19+
version="1.0.0",
1020
)
1121

22+
# Store trained models in memory (in production, use a proper database)
23+
trained_models = {}
24+
25+
1226
class TrainingRequest(BaseModel):
1327
algorithm: str
1428
data_path: str
1529
parameters: Optional[dict] = None
1630

31+
1732
class PredictionRequest(BaseModel):
1833
algorithm: str
19-
data: List[float]
34+
data: List[Any] # Accept any type of data and validate in the endpoint
2035
model_id: Optional[str] = None
2136

37+
2238
class PredictionResponse(BaseModel):
2339
is_anomaly: bool
2440
score: float
2541
threshold: float
2642

43+
2744
@app.post("/train")
2845
async def train_model(request: TrainingRequest):
2946
"""
3047
Train an anomaly detection model using the specified algorithm and data.
48+
49+
Args:
50+
request: TrainingRequest containing:
51+
- algorithm: The algorithm to use (isolation_forest or random_cut_forest)
52+
- data_path: Path to the training data
53+
- parameters: Optional parameters for the model
54+
55+
Returns:
56+
dict: Status and model ID
3157
"""
3258
try:
33-
# TODO: Implement training logic
34-
return {"status": "success", "model_id": "model_123"}
59+
# Load data
60+
loader = NABLoader()
61+
X, _ = loader.load_dataset(request.data_path)
62+
63+
# Create model with parameters
64+
model_params = request.parameters or {}
65+
if request.algorithm == "random_cut_forest":
66+
# Add AWS credentials if using RCF
67+
model_params.update(
68+
{
69+
"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
70+
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
71+
"region_name": os.getenv("AWS_REGION", "us-west-2"),
72+
}
73+
)
74+
75+
model = ModelFactory.create_model(request.algorithm, **model_params)
76+
77+
# Train model
78+
model.fit(X)
79+
80+
# Generate unique model ID
81+
model_id = f"{request.algorithm}_{len(trained_models)}"
82+
83+
# Store model
84+
trained_models[model_id] = model
85+
86+
return {
87+
"status": "success",
88+
"model_id": model_id,
89+
"message": f"Model trained successfully with {len(X)} samples",
90+
}
3591
except Exception as e:
3692
raise HTTPException(status_code=500, detail=str(e))
3793

94+
3895
@app.post("/predict")
3996
async def predict(request: PredictionRequest) -> PredictionResponse:
4097
"""
4198
Make predictions using a trained anomaly detection model.
99+
100+
Args:
101+
request: PredictionRequest containing:
102+
- algorithm: The algorithm used
103+
- data: List of values to predict
104+
- model_id: ID of the trained model to use
105+
106+
Returns:
107+
PredictionResponse: Prediction results including anomaly status and score
42108
"""
109+
# Validate input data
110+
if not request.data:
111+
raise HTTPException(status_code=500, detail="Empty data provided")
112+
113+
try:
114+
# Try converting all values to float
115+
data_values = [float(x) for x in request.data]
116+
except (ValueError, TypeError):
117+
raise HTTPException(status_code=500, detail="All values must be numeric")
118+
119+
# Get model first to fail fast if model doesn't exist
120+
if request.model_id not in trained_models:
121+
raise HTTPException(
122+
status_code=404,
123+
detail=f"Model {request.model_id} not found. Please train a model first.",
124+
)
125+
126+
try:
127+
# Convert input data to numpy array and reshape to 2D
128+
data = np.array(data_values, dtype=float)
129+
if len(data.shape) == 1:
130+
data = data.reshape(-1, 1) # Convert to 2D array with shape (n_samples, 1)
131+
except Exception as e:
132+
raise HTTPException(status_code=500, detail=f"Error processing data: {str(e)}")
133+
43134
try:
44-
# TODO: Implement prediction logic
135+
model = trained_models[request.model_id]
136+
137+
# Make prediction
138+
score = float(model.predict(data)[0]) # Convert to float for JSON serialization
139+
threshold = float(model.threshold) # Convert to float for JSON serialization
140+
is_anomaly = score > threshold
141+
45142
return PredictionResponse(
46-
is_anomaly=False,
47-
score=0.5,
48-
threshold=0.7
143+
is_anomaly=bool(is_anomaly), score=score, threshold=threshold
49144
)
50145
except Exception as e:
51-
raise HTTPException(status_code=500, detail=str(e))
146+
raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
147+
148+
149+
@app.exception_handler(ValidationError)
150+
async def validation_exception_handler(request: Request, exc: ValidationError):
151+
"""Handle Pydantic validation errors with 500 status code."""
152+
return JSONResponse(status_code=500, content={"detail": str(exc)})
153+
52154

53155
@app.get("/health")
54156
async def health_check():
55157
"""
56158
Health check endpoint.
57159
"""
58-
return {"status": "healthy"}
160+
return {"status": "healthy"}
Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,56 @@
1-
import pandas as pd
2-
import numpy as np
3-
from typing import Tuple, Optional
41
import os
2+
from typing import Optional, Tuple
3+
4+
import numpy as np
5+
import pandas as pd
6+
57

68
class NABLoader:
79
"""Loader for NAB (Numenta Anomaly Benchmark) datasets."""
8-
10+
911
def __init__(self, data_dir: str = "data"):
1012
self.data_dir = data_dir
11-
12-
def load_dataset(self, dataset_name: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
13+
14+
def load_dataset(
15+
self, dataset_name: str
16+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
1317
"""
1418
Load a NAB dataset and return features and labels.
15-
19+
1620
Args:
1721
dataset_name: Name of the dataset to load
18-
22+
1923
Returns:
2024
Tuple containing:
2125
- features: numpy array of shape (n_samples, n_features)
2226
- labels: numpy array of shape (n_samples,) or None if no labels
2327
"""
2428
# Construct path to dataset
2529
dataset_path = os.path.join(self.data_dir, dataset_name)
26-
30+
2731
if not os.path.exists(dataset_path):
28-
raise FileNotFoundError(f"Dataset {dataset_name} not found in {self.data_dir}")
29-
32+
raise FileNotFoundError(
33+
f"Dataset {dataset_name} not found in {self.data_dir}"
34+
)
35+
3036
# Load data
3137
df = pd.read_csv(dataset_path)
32-
38+
3339
# Extract features (assuming first column is timestamp)
3440
features = df.iloc[:, 1:].values
35-
41+
3642
# Check if labels exist (they might be in a separate file)
37-
labels_path = os.path.join(self.data_dir, "labels", f"{dataset_name}_labels.csv")
43+
labels_path = os.path.join(
44+
self.data_dir, "labels", f"{dataset_name}_labels.csv"
45+
)
3846
labels = None
39-
47+
4048
if os.path.exists(labels_path):
4149
labels_df = pd.read_csv(labels_path)
4250
labels = labels_df.iloc[:, 1].values
43-
51+
4452
return features, labels
45-
53+
4654
def get_available_datasets(self) -> list:
4755
"""Get list of available datasets in the data directory."""
48-
return [f for f in os.listdir(self.data_dir) if f.endswith('.csv')]
56+
return [f for f in os.listdir(self.data_dir) if f.endswith(".csv")]

anomaly_detection/models/base.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any, Dict, Optional
3+
24
import numpy as np
3-
from typing import Dict, Any, Optional
5+
46

57
class AnomalyDetector(ABC):
68
"""Base class for anomaly detection models."""
7-
9+
810
def __init__(self, **kwargs):
911
self.model = None
1012
self.threshold = None
1113
self.parameters = kwargs
12-
14+
1315
@abstractmethod
1416
def fit(self, X: np.ndarray) -> None:
1517
"""Fit the model to the training data."""
1618
pass
17-
19+
1820
@abstractmethod
1921
def predict(self, X: np.ndarray) -> np.ndarray:
2022
"""Predict anomaly scores for the input data."""
2123
pass
22-
24+
2325
def is_anomaly(self, X: np.ndarray) -> np.ndarray:
2426
"""Determine if samples are anomalies based on the threshold."""
2527
scores = self.predict(X)
2628
return scores > self.threshold
27-
29+
2830
def set_threshold(self, threshold: float) -> None:
2931
"""Set the anomaly detection threshold."""
3032
self.threshold = threshold
31-
33+
3234
def get_parameters(self) -> Dict[str, Any]:
3335
"""Get the model parameters."""
34-
return self.parameters
36+
return self.parameters
Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,39 @@
11
from typing import Dict, Type
2+
23
from .base import AnomalyDetector
34
from .isolation_forest import IsolationForestDetector
45
from .random_cut_forest import RandomCutForestDetector
56

7+
68
class ModelFactory:
79
"""Factory class for creating anomaly detection models."""
8-
10+
911
_models: Dict[str, Type[AnomalyDetector]] = {
1012
"isolation_forest": IsolationForestDetector,
11-
"random_cut_forest": RandomCutForestDetector
13+
"random_cut_forest": RandomCutForestDetector,
1214
}
13-
15+
1416
@classmethod
1517
def create_model(cls, model_type: str, **kwargs) -> AnomalyDetector:
1618
"""
1719
Create an anomaly detection model.
18-
20+
1921
Args:
2022
model_type: Type of model to create
2123
**kwargs: Additional arguments to pass to the model constructor
22-
24+
2325
Returns:
2426
An instance of the requested anomaly detection model
25-
27+
2628
Raises:
2729
ValueError: If the requested model type is not supported
2830
"""
2931
if model_type not in cls._models:
3032
raise ValueError(f"Unsupported model type: {model_type}")
31-
33+
3234
return cls._models[model_type](**kwargs)
33-
35+
3436
@classmethod
3537
def get_supported_models(cls) -> list:
3638
"""Get list of supported model types."""
37-
return list(cls._models.keys())
39+
return list(cls._models.keys())
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
1-
from sklearn.ensemble import IsolationForest
21
import numpy as np
2+
from sklearn.ensemble import IsolationForest
3+
34
from .base import AnomalyDetector
45

6+
57
class IsolationForestDetector(AnomalyDetector):
68
"""Isolation Forest based anomaly detector."""
7-
9+
810
def __init__(self, contamination: float = 0.1, **kwargs):
911
super().__init__(**kwargs)
1012
self.contamination = contamination
1113
self.model = IsolationForest(
12-
contamination=contamination,
13-
random_state=42,
14-
**kwargs
14+
contamination=contamination, random_state=42, **kwargs
1515
)
16-
16+
1717
def fit(self, X: np.ndarray) -> None:
1818
"""Fit the Isolation Forest model."""
1919
self.model.fit(X)
2020
# Set threshold based on contamination
2121
scores = self.model.score_samples(X)
2222
self.threshold = np.percentile(scores, 100 * self.contamination)
23-
23+
2424
def predict(self, X: np.ndarray) -> np.ndarray:
2525
"""Predict anomaly scores."""
26-
return -self.model.score_samples(X) # Convert to positive scores
26+
return -self.model.score_samples(X) # Convert to positive scores

0 commit comments

Comments
 (0)