diff --git a/test/.gitignore b/test/.gitignore new file mode 100644 index 00000000..220d21ac --- /dev/null +++ b/test/.gitignore @@ -0,0 +1,13 @@ +reports/ +dataset/ +logs/ +result_outputs/ +results/ +.cache/ +backup/ +$null +*__pycache__/ +.* +*.log +start.bat +!.gitignore \ No newline at end of file diff --git a/test/README.md b/test/README.md new file mode 100644 index 00000000..1e11da7e --- /dev/null +++ b/test/README.md @@ -0,0 +1,179 @@ +# Pytest +[简体中文](README_zh.md) +A comprehensive Pytest testing framework featuring configuration management, database integration, performance testing, and HTML report generation. + +## 📋 Features + +- **Modern Testing Framework**: Complete test solution built on Pytest 7.0+ +- **Configuration Management**: YAML-based config with thread-safe singleton pattern +- **Database Integration**: Built-in MySQL support with automatic result storage +- **HTML Reports**: Auto-generated pytest HTML test reports +- **Tagging System**: Multi-dimensional test tags (stage, feature, platform, etc.) + +## 🗂️ Project Structure + +``` +pytest_demo/ +├── common/ # Common modules +│ ├── __init__.py +│ ├── config_utils.py # Configuration utilities +│ ├── db_utils.py # Database utilities +│ └── capture_utils # Return-value capture utilities +├── results/ # Result storage folder +├── suites/ # Test suites +│ ├── UnitTest # Unit tests +│ ├── Feature # Feature tests +│ └── E2E/ # End-to-end tests +│ └── test_demo_performance.py # Sample test file +├── config.yaml # Main config file +├── conftest.py # Pytest config +├── pytest.ini # Pytest settings +├── requirements.txt # Dependencies +└── README.md # This doc (CN) +``` + +## 🚀 Quick Start + +### Prerequisites + +- Python 3.8+ +- MySQL 5.7+ (optional, for DB features) +- Git + +### Installation + +1. **Install dependencies** + ```bash + pip install -r requirements.txt + ``` + +2. **Configure database** (optional) + + Edit `config.yaml`: + ```yaml + database: + backup: "results/" + host: "127.0.0.1" + port: 3306 + name: "ucm_pytest" + user: "root" + password: "123456" + charset: "utf8mb4" + ``` + +3. **Run tests** + ```bash + # Run all tests + pytest + + # Run tests by tag + pytest --stage=1 + pytest --feature=performance + ``` + +## ⚙️ Configuration + +### config.yaml + +Full YAML-based config. Key sections: + +- **reports**: Report settings (HTML, timestamp, etc.) 
+- **database**: MySQL connection details + +## 🧪 Test Examples + +### Basic functional test + +```python +# suites/E2E/test_demo_performance.py +import pytest + +@pytest.fixture(scope="module", name="calc") +def calculator(): + return Calculator() + +@pytest.mark.feature("mark") +class TestCalculator: + def test_add(self, calc): + assert calc.add(1, 2) == 3 + + def test_divide_by_zero(self, calc): + with pytest.raises(ZeroDivisionError): + calc.divide(6, 0) +``` + +## 🏷️ Tagging System + +Multi-dimensional tags supported: + +### Stage tags +- `stage(0)`: Unit tests +- `stage(1)`: Smoke tests +- `stage(2)`: Regression tests +- `stage(3)`: Release tests + +### Functional tags +- `feature`: Module tag +- `platform`: Platform tag (GPU/NPU) + +### Usage + +```bash +# Run smoke tests and above +pytest --stage=1+ + +# Run by feature +pytest --feature=performance +pytest --feature=performance,reliability + +# Run by platform +pytest --platform=gpu +``` + +### HTML Reports + +Auto-generated timestamped HTML reports: +- Location: `reports/pytest_YYYYMMDD_HHMMSS/report.html` +- Detailed results, errors, timing +- Customizable title & style + +### Database Storage + +If enabled, results are auto-saved to MySQL. +To add new record types, ask DB admin to create tables; otherwise only local files are used. + +Example: +```python +@pytest.mark.feature("capture") # Must be top decorator +@export_vars +def test_capture_mix(): + assert 1 == 1 + return { + '_name': 'demo', + '_data': { + 'length': 10086, # single value + 'accuracy': [0.1, 0.2, 0.3], # list + 'loss': [0.1, 0.2, 0.3], # list + } + } +``` + +### Config Access + +Read settings easily: +```python +from common.config_utils import config_utils +# Get config +db_config = config_utils.get_config("database") +api_config = config_utils.get_nested_config("easyPerf.api") +``` + +## 🛠️ Development Guide + +### Adding New Tests + +1. Create test files under `suites/` categories +2. Apply appropriate tags +3. Naming: `test_*.py` +4. Use fixtures & marks for data management +5. Keep custom marks concise and aligned with overall goals \ No newline at end of file diff --git a/test/README_zh.md b/test/README_zh.md new file mode 100644 index 00000000..26b0f393 --- /dev/null +++ b/test/README_zh.md @@ -0,0 +1,182 @@ +# Pytest 项目 + Pytest 测试框架,包括配置管理、数据库集成、性能测试和 HTML 报告生成。 + +## 📋 项目特性 + +- **现代化测试框架**: 基于 Pytest 7.0+ 的完整测试解决方案 +- **配置管理**: 支持 YAML 配置文件,线程安全的单例模式配置管理 +- **数据库集成**: 内置 MySQL 数据库支持,自动结果存储 +- **HTML 报告**: 自动生成pytest HTML 测试报告 +- **标记系统**: 支持多维度测试标记(阶段、功能、平台等) + +## 🗂️ 项目结构 + +``` +pytest_demo/ +├── common/ # 公共模块 +│ ├── __init__.py +│ ├── config_utils.py # 配置管理工具 +│ ├── db_utils.py # 数据库工具 +│ └── capture_utils # 返回值捕获工具 +├── results/ # 结果存储目录 +├── suites/ # 测试套件 +│ ├── UnitTest # 单元测试 +│ ├── Feature # 功能测试 +│ └── E2E/ # 端到端测试 +│ └── test_demo_performance.py # 示例测试文件 +├── config.yaml # 主配置文件 +├── conftest.py # Pytest 配置文件 +├── pytest.ini # Pytest 配置 +├── requirements.txt # 项目依赖 +└── README.md # 本文档 +``` + +## 🚀 快速开始 + +### 环境要求 + +- Python 3.8+ +- MySQL 5.7+ (可选,用于数据库功能) +- Git + +### 安装步骤 + +1. **安装依赖** + ```bash + pip install -r requirements.txt + ``` + +2. **配置数据库**(可选) + + 编辑 `config.yaml` 文件中的数据库配置: + ```yaml + database: + backup: "results/" + host: "127.0.0.1" + port: 3306 + name: "ucm_pytest" + user: "root" + password: "123456" + charset: "utf8mb4" + ``` + +3. 
**运行测试** + ```bash + # 运行所有测试 + pytest + + # 运行特定标记的测试 + pytest --stage=1 + pytest --feature=performance + ``` + +## ⚙️ 配置说明 + + +### config.yaml 配置 + +项目支持完整的 YAML 配置管理,主要配置项包括: + +- **reports**: 报告配置(HTML 报告、时间戳等) +- **database**: 数据库连接配置 + +## 🧪 测试示例 + +### 基础功能测试 + +```python +# suites/E2E/test_demo_performance.py +import pytest + +@pytest.fixture(scope="module", name="calc") +def calculator(): + return Calculator() + +@pytest.mark.feature("mark") +class TestCalculator: + def test_add(self, calc): + assert calc.add(1, 2) == 3 + + def test_divide_by_zero(self, calc): + with pytest.raises(ZeroDivisionError): + calc.divide(6, 0) +``` + +## 🏷️ 测试标记系统 + +项目支持多维度的测试标记: + +### 测试阶段标记 +- `stage(0)`: 单元测试 +- `stage(1)`: 冒烟测试 +- `stage(2)`: 回归测试 +- `stage(3)`: 发布测试 + +### 功能标记 +- `feature`: 功能模块标记 +- `platform`: 平台标记(GPU/NPU) + +### 使用示例 + +```bash +# 运行冒烟测试及以上的所有测试 +pytest --stage=1+ + +# 运行特定功能的测试 +pytest --feature=performance +pytest --feature=performance, reliability +# 运行特定平台的测试 +pytest --platform=gpu +``` + + +### HTML 报告 + +项目自动生成带时间戳的 HTML 测试报告: +- 报告位置:`reports/pytest_YYYYMMDD_HHMMSS/report.html` +- 包含详细的测试结果、错误信息和执行时间 +- 支持自定义报告标题和样式 + +### 数据库存储 + +如果启用数据库功能,测试结果会自动存储到 MySQL 数据库。 +若需要新增记录,请联系管理人员在数据库新增对应表;否则只能保存至本地文件。 +使用方式示例: +```python +@pytest.mark.feature("capture") # pytest 的标签必须在上面,否则无法正常使用标记功能 +@export_vars +def test_capture_mix(): + assert 1 == 1 + return { + '_name': 'demo', + '_data': { + 'length': 10086, # single value + 'accuracy': [0.1, 0.2, 0.3], # list + 'loss': [0.1, 0.2, 0.3], # list + } + } + +``` + + +### 配置管理 + +可以通过配置工具便捷读取参数: +```python +from common.config_utils import config_utils +# 获取配置 +db_config = config_utils.get_config("database") +api_config = config_utils.get_nested_config("easyPerf.api") +``` + + + +## 🛠️ 开发指南 + +### 添加新测试 + +1. 在 `suites/` 目录下的各个分类下创建新的测试文件 +2. 使用适当的测试标记 +3. 遵循命名规范:`test_*.py` +4. 使用 fixture 及mark 进行测试数据管理 +5. 自定义 mark 标签不易过细,应当与整体功能目标相符合 \ No newline at end of file diff --git a/test/common/__init__.py b/test/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/common/capture_utils.py b/test/common/capture_utils.py new file mode 100644 index 00000000..ee12ed2a --- /dev/null +++ b/test/common/capture_utils.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, List + +from common.db_utils import write_to_db + + +def _align_and_split(name: str, data: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Align a mixed data package (single values and/or lists) and split it into + """ + if not data: + return [] + + aligned: Dict[str, List[Any]] = {} + lengths: Dict[str, int] = {} + for k, v in data.items(): + if isinstance(v, (list, tuple)): + aligned[k] = list(v) + else: + aligned[k] = [v] + lengths[k] = len(aligned[k]) + + max_len = max(lengths.values()) + + for k, lst in aligned.items(): + if len(lst) < max_len: + lst.extend([lst[-1]] * (max_len - len(lst))) + + return [{k: aligned[k][i] for k in aligned} for i in range(max_len)] + + +def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]: + """ + Unified post-processing entry point. 
Supports two calling styles: + """ + results = [] + if "_data" in kwargs: + name = kwargs.get("_name", table_name) + results = _align_and_split(name, kwargs["_data"]) + for result in results: + write_to_db(name, result) + return results + return [] + + +# ---------------- decorator ---------------- +def export_vars(func): + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + # If the function returns a dict containing '_data' or 'data', post-process it + if isinstance(result, dict): + if "_data" in result or "data" in result: + return post_process(func.__name__, **result) + # Otherwise return unchanged + return result + + return wrapper + + +# ---------------- usage examples ---------------- +@export_vars +def capture(): + """All single values via 'name' + 'data'""" + return {"name": "demo", "_data": {"accuracy": 0.1, "loss": 0.3}} + + +@export_vars +def capture_list(): + """All lists via '_name' + '_data'""" + return { + "_name": "demo", + "_data": { + "accuracy": [0.1, 0.2, 0.3], + "loss": [0.1, 0.2, 0.3], + }, + } + + +@export_vars +def capture_mix(): + """Mixed single + lists via '_name' + '_data'""" + return { + "_name": "demo", + "_data": { + "length": 10086, # single value + "accuracy": [0.1, 0.2, 0.3], # list + "loss": [0.1, 0.2, 0.3], # list + }, + } + + +# quick test +if __name__ == "__main__": + print("capture(): ", capture()) + print("capture_list(): ", capture_list()) + print("capture_mix(): ", capture_mix()) diff --git a/test/common/config_utils.py b/test/common/config_utils.py new file mode 100644 index 00000000..106f783e --- /dev/null +++ b/test/common/config_utils.py @@ -0,0 +1,86 @@ +import os +import threading +from typing import Any, Dict + +import yaml + + +class ConfigUtils: + """ + Singleton Configuration Utility + Provides methods to read and access YAML configuration files. 
+ """ + + _instance = None + _lock = threading.Lock() # Ensure thread-safe singleton creation + + def __init__(self): + self._config = None + + def __new__(cls, config_file: str = None): + # Double-checked locking + if cls._instance is None: + with cls._lock: + if cls._instance is None: + instance = super().__new__(cls) + instance._init_config(config_file) + cls._instance = instance + return cls._instance + + def _init_config(self, config_file: str = None): + """Initialize configuration file path and load config""" + if config_file is None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_file = os.path.join(current_dir, "..", "config.yaml") + + self.config_file = os.path.abspath(config_file) + self._config = None # Lazy load + + def _load_config(self) -> Dict[str, Any]: + """Internal method to read configuration from file""" + try: + with open(self.config_file, "r", encoding="utf-8") as f: + return yaml.safe_load(f) or {} + except FileNotFoundError: + print(f"[WARN] Config file not found: {self.config_file}") + return {} + except yaml.YAMLError as e: + print(f"[ERROR] Failed to parse YAML config: {e}") + return {} + + def read_config(self) -> Dict[str, Any]: + """Read configuration file (lazy load)""" + if self._config is None: + self._config = self._load_config() + return self._config + + def reload_config(self): + """Force reload configuration file""" + self._config = self._load_config() + + def get_config(self, key: str, default: Any = None) -> Any: + """Get top-level configuration item""" + config = self.read_config() + return config.get(key, default) + + def get_nested_config(self, key_path: str, default: Any = None) -> Any: + """Get nested configuration, e.g., 'influxdb.host'""" + config = self.read_config() + keys = key_path.split(".") + value = config + try: + for k in keys: + value = value[k] + return value + except (KeyError, TypeError): + return default + + +# Global instance +config_utils = ConfigUtils() + +if __name__ == "__main__": + print("DataBase config:", config_utils.get_config("database")) + print( + "DataBase host:", config_utils.get_nested_config("database.host", "localhost") + ) diff --git a/test/common/db_utils.py b/test/common/db_utils.py new file mode 100644 index 00000000..089af43b --- /dev/null +++ b/test/common/db_utils.py @@ -0,0 +1,183 @@ +import json +import logging +import threading +from pathlib import Path +from typing import Any, Dict, Optional + +import peewee +from common.config_utils import config_utils as config_instance +from peewee import AutoField, Model, MySQLDatabase, TextField + +logger = logging.getLogger("db_handler") +logger.setLevel(logging.DEBUG) + +# Avoid adding handlers multiple times +if not logger.handlers: + logger.setLevel(logging.DEBUG) + +# Global DB instance and lock for thread-safe singleton +_db_instance: Optional[MySQLDatabase] = None +_db_lock = threading.Lock() +_test_build_id: Optional[str] = None +_backup_path: Optional[Path] = None +_db_enabled: bool = False # from config + + +def _get_db() -> Optional[MySQLDatabase]: + """Return a singleton MySQLDatabase instance based on YAML configuration.""" + global _db_instance, _backup_path, _db_enabled + + if _db_instance is None: + with _db_lock: + if _db_instance is None: + db_config = config_instance.get_config("database", {}) + _db_enabled = db_config.get("enabled", False) + + backup_str = db_config.get("backup", "results/") + _backup_path = Path(backup_str).resolve() + _backup_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Backup directory set to: 
{_backup_path}") + + if not _db_enabled: + return None + + try: + _db_instance = MySQLDatabase( + db_config.get("name", "test_db"), + user=db_config.get("user", "root"), + password=db_config.get("password", ""), + host=db_config.get("host", "localhost"), + port=db_config.get("port", 3306), + charset=db_config.get("charset", "utf8mb4"), + ) + logger.info( + f"Database instance created for: {_db_instance.database}" + ) + except Exception as e: + logger.error(f"Failed to create database instance: {e}") + _db_instance = None + + return _db_instance + + +def _set_test_build_id(build_id: Optional[str] = None) -> None: + """Set or generate a unique test build ID.""" + global _test_build_id + _test_build_id = build_id or "default_build_id" + logger.debug(f"Test build ID set to: {_test_build_id}") + + +def _get_test_build_id() -> str: + """Return the current test build ID, generating one if necessary.""" + global _test_build_id + if _test_build_id is None: + _set_test_build_id() + return _test_build_id + + +class BaseEntity(Model): + """Base PeeWee model class using the singleton database.""" + + class Meta: + database = _get_db() + + +def _backup_to_file(table_name: str, data: Dict[str, Any]) -> None: + """Write data to a JSON Lines (.jsonl) file in the backup directory.""" + if not _backup_path: + logger.warning("Backup path is not set. Skipping backup.") + return + + file_path = _backup_path / f"{table_name}.jsonl" + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + with file_path.open("a", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False) + f.write("\n") + logger.info(f"Data backed up to {file_path}") + except Exception as e: + logger.error(f"Failed to write backup file {file_path}: {e}") + + +def write_to_db(table_name: str, data: Dict[str, Any]) -> bool: + """ + Attempt to insert data into the specified database table. + If the table doesn't exist or an error occurs, back up to a JSONL file. + """ + db = _get_db() + data["test_build_id"] = _get_test_build_id() + + # Skip DB entirely if disabled + if not _db_enabled or db is None: + _backup_to_file(table_name, data) + return False + + try: + if not db.table_exists(table_name): + logger.warning(f"Table '{table_name}' does not exist. 
Writing to backup.") + _backup_to_file(table_name, data) + return False + + # Get existing columns and filter data + columns = db.get_columns(table_name) + col_names = {col.name for col in columns} + filtered_data = {k: v for k, v in data.items() if k in col_names} + + # Build dynamic model for insertion + fields = {"id": AutoField()} + for col in columns: + if col.name != "id": + fields[col.name] = TextField(null=True) + + DynamicEntity = type( + f"{table_name.capitalize()}DynamicModel", + (BaseEntity,), + { + "Meta": type("Meta", (), {"database": db, "table_name": table_name}), + **fields, + }, + ) + + with db.atomic(): + DynamicEntity.insert(filtered_data).execute() + logger.info(f"Successfully inserted data into table '{table_name}'.") + return True + + except peewee.PeeweeException as e: + logger.error( + f"Database write error for table '{table_name}': {e}", exc_info=True + ) + except Exception as e: + logger.critical( + f"Unexpected error during DB write for '{table_name}': {e}", exc_info=True + ) + + # Fallback to backup on any failure + _backup_to_file(table_name, data) + return False + + +def database_connection(build_id: str) -> None: + """Test database connection and set the build ID.""" + logger.info(f"Setting test build ID: {build_id}") + _set_test_build_id(build_id) + + db = _get_db() + if not _db_enabled: + logger.info("Database connection skipped because enabled=false.") + return + + if db is None: + logger.error("No database instance available.") + return + + logger.info(f"Attempting connection to database: {db.database}") + try: + db.connect(reuse_if_open=True) + logger.info("Database connection successful.") + except Exception as e: + logger.error(f"Database connection failed: {e}", exc_info=True) + finally: + if not db.is_closed(): + db.close() + logger.debug("Database connection closed.") diff --git a/test/common/llmperf/__init__.py b/test/common/llmperf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/common/llmperf/run_inference.py b/test/common/llmperf/run_inference.py new file mode 100644 index 00000000..d3df80c6 --- /dev/null +++ b/test/common/llmperf/run_inference.py @@ -0,0 +1,140 @@ +import json +import os +import random +from pathlib import Path +from typing import List, Dict, Any + +import yaml + +from common.llmperf.utils.token_benchmark import run_token_benchmark +from common.llmperf.utils.utils import reset_prefill_cache + + +def run_test_cases(test_cases, timestamp_dir, model, server_url, tokenizer_path): + """ + Execute all test cases and return the list of failed case indices and hit_rate mapping for each case. 
+ Parameters: + test_cases — List of test cases read from the configuration file + timestamp_dir — Directory Path to save results + model — Model name + server_url — Base URL of the service + tokenizer_path— Path to the tokenizer + Returns: + failed_cases — List of failed case indices + """ + print(f"[INFO] Total {len(test_cases)} test cases to be executed") + all_summaries = [] + failed_case = [] + + # Clear proxy environment variables + env = os.environ.copy() + env.pop('http_proxy', None) + env.pop('https_proxy', None) + + for i, case in enumerate(test_cases): + print(f"\n>>> Executing test case {i + 1} <<<") + reset_prefill_cache(env, server_url) + # Use a fixed random_seed for each test to control PC hit_rate + random_seed = random.randint(1, 100000) + summary = {} + + # Read parameters from configuration file + mean_input = case.get("mean_input_tokens", 5000) + stddev_input = case.get("stddev_input_tokens", 0) + mean_output = case.get("mean_output_tokens", 1000) + stddev_output = case.get("stddev_output_tokens", 0) + max_completed = case.get("max_num_completed_requests", 1) + concurrent = case.get("concurrent_requests", 1) + llm_api = case.get("llm_api", "openai") + additional_sampling_params = case.get("additional_sampling_params", "{}") + timeout = case.get("timeout", 60000) + hit_rate = case.get("hit_rate", 0) + + try: + # Determine if two runs are needed (PC hit_rate test) + if hit_rate == 0: + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i} + ) + else: + print(f"[INFO] hit_rate > 0 detected, entering prefill mode, PC hit rate: {hit_rate} %") + # hit_rate > 0: first prefill mode + prefill_mean_input = int(mean_input * hit_rate / 100) + print(f"[INFO] Prefill execution: mean_input_tokens={prefill_mean_input}") + run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=prefill_mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=2, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "prefill"} + ) + # Then run normal mode + print("[INFO] Prefill completed, switching to normal mode execution") + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"} + ) + all_summaries.append(summary) + except Exception as e: + failed_case.append(i) + + return all_summaries, failed_case + +def 
inference_results(): + config_file = Path(__file__).parent.parent.parent / "config.yaml" + all_smmaries = {} + print("[INFO] Initialization complete, starting main process") + print(f"[INFO] Reading configuration file: {config_file}") + with open(config_file, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + model = config.get("llm_connection", {}).get("model", "") + server_url = config.get("llm_connection", {}).get("server_url", "") + tokenizer_path = config.get("llm_connection", {}).get("tokenizer_path", "") + test_cases = config.get("llmperf_test_cases", []) + timestamp_dir = Path("result_outputs") + timestamp_dir.mkdir(parents=True, exist_ok=True) + print(f"[INFO] Created results directory: {timestamp_dir}") + + all_summaries, failed_cases = run_test_cases(test_cases, timestamp_dir, model, server_url, tokenizer_path) + total = len(test_cases) + print(f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}") + if failed_cases: + print(f"[WARN] Failed case indices: {failed_cases}") + return all_summaries \ No newline at end of file diff --git a/test/common/llmperf/utils/__init__.py b/test/common/llmperf/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/common/llmperf/utils/common_metrics.py b/test/common/llmperf/utils/common_metrics.py new file mode 100644 index 00000000..3b05b437 --- /dev/null +++ b/test/common/llmperf/utils/common_metrics.py @@ -0,0 +1,17 @@ +# TODO (Avnishn): compute metrics in class +INTER_TOKEN_LAT = "inter_token_latency_s" +TTFT = "ttft_s" +E2E_LAT = "end_to_end_latency_s" +NUM_INPUT_TOKENS = "number_input_tokens" +NUM_OUTPUT_TOKENS = "number_output_tokens" +NUM_TOTAL_TOKENS = "number_total_tokens" +REQ_OUTPUT_THROUGHPUT = "request_output_throughput_token_per_s" +ERROR_MSG = "error_msg" +ERROR_CODE = "error_code" +ERROR_CODE_FREQ = "error_code_frequency" +NUM_ERRORS = "number_errors" +OUTPUT_THROUGHPUT = "mean_output_throughput_token_per_s" +NUM_COMPLETED_REQUESTS = "num_completed_requests" +COMPLETED_REQUESTS_PER_MIN = "num_completed_requests_per_min" +ERROR_RATE = "error_rate" +NUM_REQ_STARTED = "num_requests_started" \ No newline at end of file diff --git a/test/common/llmperf/utils/models.py b/test/common/llmperf/utils/models.py new file mode 100644 index 00000000..f70e8a7e --- /dev/null +++ b/test/common/llmperf/utils/models.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, Optional, Tuple +from pydantic import BaseModel + + +class RequestConfig(BaseModel): + """The configuration for a request to the LLM API. + + Args: + model: The model to use. + prompt: The prompt to provide to the LLM API. + sampling_params: Additional sampling parameters to send with the request. + For more information see the Router app's documentation for the completions + llm_api: The name of the LLM API to send the request to. + metadata: Additional metadata to attach to the request for logging or validation purposes. 
+ """ + + model: str + prompt: Tuple[str, int] + sampling_params: Optional[Dict[str, Any]] = None + llm_api: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + openai_api_base: Optional[str] = "" \ No newline at end of file diff --git a/test/common/llmperf/utils/openai_chat_completions_client.py b/test/common/llmperf/utils/openai_chat_completions_client.py new file mode 100644 index 00000000..b24320d0 --- /dev/null +++ b/test/common/llmperf/utils/openai_chat_completions_client.py @@ -0,0 +1,122 @@ +import json +import os +import time +from typing import Any, Dict, Tuple + +import requests + +from common.llmperf.utils.models import RequestConfig + +from common.llmperf.utils import common_metrics + + +class OpenAIChatCompletionsClient(): + """ + used for sending HTTP requests, receiving token streams, measuring latency, etc. + """ + def llm_request(self, request_config: RequestConfig) -> Tuple[Dict[str, Any], str, RequestConfig]: + prompt, prompt_len = request_config.prompt + + message = [ + {"role": "system", "content": ""}, + {"role": "user", "content": prompt}, + ] + model = request_config.model + body = { + "model": model, + "messages": message, + "stream": True, + "ignore_eos": True, + } + sampling_params = request_config.sampling_params + body.update(sampling_params or {}) + + time_to_next_token = [] + tokens_received = 0 + ttft = 0.0 + error_response_code = None + generated_text = "" + error_msg = "" + output_throughput = 0.0 + total_request_time = 0.0 + flag = False + + metrics: Dict[str, Any] = {} + + metrics[common_metrics.ERROR_CODE] = None + metrics[common_metrics.ERROR_MSG] = "" + + start_time = time.monotonic() + most_recent_received_token_time = start_time + + address = request_config.openai_api_base + + if not address: + raise ValueError("the environment variable OPENAI_API_BASE must be set.") + key = os.environ.get("OPENAI_API_KEY", "secret_abcdefg") + if not key: + raise ValueError("the environment variable OPENAI_API_KEY must be set.") + headers = {"Authorization": f"Bearer {key}"} + if not address.endswith("/"): + address = address + "/" + address += "chat/completions" + try: + with requests.post( + address, + json=body, + stream=True, + timeout=180, + headers=headers, + ) as response: + if response.status_code != 200: + error_msg = response.text + error_response_code = response.status_code + response.raise_for_status() + + for chunk in response.iter_lines(chunk_size=None): + if not chunk: + continue + stem = b"data: " + if chunk.startswith(stem): + chunk = chunk[len(stem):] + # Data might already be bytes or str + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8", errors="ignore") + if chunk.strip() == "[DONE]": + continue + tokens_received += 1 + data = json.loads(chunk) + if "error" in data: + error_msg = data["error"]["message"] + error_response_code = data["error"]["code"] + raise RuntimeError(error_msg) + delta = data["choices"][0]["delta"] + content = delta.get("content", None) or delta.get("reasoning_content", "") + if content: + if tokens_received != 0 and flag == False: + ttft = time.monotonic() - start_time + flag = True + else: + time_to_next_token.append(time.monotonic() - most_recent_received_token_time) + most_recent_received_token_time = time.monotonic() + generated_text += content + + total_request_time = time.monotonic() - start_time + if total_request_time > 0: + output_throughput = tokens_received / total_request_time + + except Exception as e: + metrics[common_metrics.ERROR_MSG] = error_msg + 
metrics[common_metrics.ERROR_CODE] = error_response_code + print(f"Warning Or Error: {e}") + print(error_response_code) + + metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token) + metrics[common_metrics.TTFT] = ttft + metrics[common_metrics.E2E_LAT] = total_request_time + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput + metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len + metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received + metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len + + return metrics, generated_text, request_config \ No newline at end of file diff --git a/test/common/llmperf/utils/sonnet.txt b/test/common/llmperf/utils/sonnet.txt new file mode 100644 index 00000000..9f13ead4 --- /dev/null +++ b/test/common/llmperf/utils/sonnet.txt @@ -0,0 +1,84 @@ +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Where art thou, Muse, that thou forget'st so long +To speak of that which gives thee all thy might? +Spend'st thou thy fury on some worthless song, +Darkening thy power to lend base subjects light? +Return, forgetful Muse, and straight redeem +In gentle numbers time so idly spent; +Sing to the ear that doth thy lays esteem +And gives thy pen both skill and argument. +Rise, resty Muse, my love's sweet face survey, +If Time have any wrinkle graven there; +If any, be a satire to decay, +And make Time's spoils despised every where. +Give my love fame faster than Time wastes life; +So thou prevent'st his scythe and crooked knife. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +So am I as the rich, whose blessed key +Can bring him to his sweet up-locked treasure, +The which he will not every hour survey, +For blunting the fine point of seldom pleasure. 
+Therefore are feasts so solemn and so rare, +Since, seldom coming, in the long year set, +Like stones of worth they thinly placed are, +Or captain jewels in the carcanet. +So is the time that keeps you as my chest, +Or as the wardrobe which the robe doth hide, +To make some special instant special blest, +By new unfolding his imprison'd pride. +Blessed are you, whose worthiness gives scope, +Being had, to triumph, being lack'd, to hope. +If there be nothing new, but that which is +Hath been before, how are our brains beguiled, +Which, labouring for invention, bear amiss +The second burden of a former child! +O, that record could with a backward look, +Even of five hundred courses of the sun, +Show me your image in some antique book, +Since mind at first in character was done! +That I might see what the old world could say +To this composed wonder of your frame; +Whether we are mended, or whether better they, +Or whether revolution be the same. +O, sure I am, the wits of former days +To subjects worse have given admiring praise. \ No newline at end of file diff --git a/test/common/llmperf/utils/token_benchmark.py b/test/common/llmperf/utils/token_benchmark.py new file mode 100644 index 00000000..456e739c --- /dev/null +++ b/test/common/llmperf/utils/token_benchmark.py @@ -0,0 +1,362 @@ +import logging +from collections.abc import Iterable +import json +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +import re +import time +import random +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd + +from transformers import AutoTokenizer + +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig +from common.llmperf.utils.openai_chat_completions_client import OpenAIChatCompletionsClient +from common.llmperf.utils.utils import ( + randomly_sample_sonnet_lines_prompt, + LLMPerfResults, + sample_random_positive_int, ) + + +def get_token_throughput_latencies( + model: str, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: Optional[Dict[str, Any]] = None, + concurrent_requests: int = 1, + max_num_completed_requests: int = 500, + test_timeout_s=90, + llm_api="openai", + random_seed: int = None, + openai_api_base: str = "", + tokenizer_path: str = None, +) -> Tuple[Dict[str, Any], List[Dict[str, Any]], float, float]: + """Get the token throughput and latencies for the given model. + + Args: + model: The name of the model to query. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + test_timeout_s: The amount of time to run the test for before reporting results. + llm_api: The name of the llm api to use. Either "openai" or "litellm". + + Returns: + A summary of the performance metrics collected across all completed requests + (e.g. throughput, latencies, etc.) + The individual metrics for each request. 
+ """ + random.seed(random_seed) + + print(f"Using tokenizer:{tokenizer_path}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + get_token_length = lambda text: len(tokenizer.encode(text)) + + if not additional_sampling_params: + additional_sampling_params = {} + + # 1. create prompts + prompts: List[Tuple[str, int]] = [] + num_output_tokens_list: List[int] = [] + for i in range(max_num_completed_requests): + num_output = sample_random_positive_int(mean_output_tokens, stddev_output_tokens) + num_output_tokens_list.append(num_output) + prompts.append(randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean=mean_input_tokens, + prompt_tokens_stddev=stddev_input_tokens, + tokenizer=tokenizer + )) + start_time = time.monotonic() + completed_requests: List[Dict[str, Any]] = [] + incremental_time_delay = 0.0 + client = OpenAIChatCompletionsClient() + futures = [] + + # 2. Submitting tasks using a thread pool + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + for idx in range(max_num_completed_requests): + sampling = {"max_tokens": num_output_tokens_list[idx]} + sampling.update(additional_sampling_params) + cfg = RequestConfig( + model=model, + prompt=prompts[idx], + sampling_params=sampling, + llm_api=llm_api, + openai_api_base=openai_api_base + ) + futures.append(executor.submit(client.llm_request, cfg)) + # 3. Waiting for completion or timeout + for future in as_completed(futures, timeout=test_timeout_s): + try: + metrics, gen_text, req_cfg = future.result() + except Exception as e: + logging.warning(f"[WARN] Future raised exception: {e}") + continue + num_output_tokens = get_token_length(gen_text) + if num_output_tokens: + metrics[common_metrics.INTER_TOKEN_LAT] /= (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) if ( + metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) else 1 + metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens + metrics[common_metrics.NUM_TOTAL_TOKENS] = metrics[ + common_metrics.NUM_INPUT_TOKENS] + num_output_tokens + try: + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / metrics[ + common_metrics.E2E_LAT] + except ZeroDivisionError: + logging.error("Division by zero in throughput calculation.") + + completed_requests.append(metrics) + + incremental_time_delay += metrics.get(common_metrics.INTER_TOKEN_LAT, 0.0) + + end_time = time.monotonic() + + print(f"Results for token benchmark for {model} queried with the {llm_api} api.\n") + if mean_output_tokens == 2: + print(f"[INFO] First token sending pre-embedding completed\n") + return {}, [], 0.0, 0.0 + + ret = metrics_summary(completed_requests, start_time, end_time) + + metadata = { + "model": model, + "mean_input_tokens": mean_input_tokens, + "stddev_input_tokens": stddev_input_tokens, + "mean_output_tokens": mean_output_tokens, + "stddev_output_tokens": stddev_output_tokens, + "concurrent_requests": concurrent_requests, + "additional_sampling_params": additional_sampling_params, + } + + metadata["results"] = ret + elapsed_time = end_time - start_time + return metadata, completed_requests, elapsed_time, incremental_time_delay + + +def compute_throughput(summary: Dict[str, Any], + completed_requests: List[Dict[str, Any]], + elapsed_time: float, + incremental_time_delay: float) -> Tuple[float, float]: + """ + Compute total_throughput (token/s) based on the metrics in summary. + + Formula: (mean_output_tokens * num_completed_requests) / total_e2e_latency_s + + Args: + summary (Dict[str, Any]): A dictionary containing performance metrics. 
+ + Returns: + float: The computed total throughput in tokens per second. Returns 0.0 if latency is zero. + """ + mean_output_tokens = summary.get("mean_output_tokens", 0) + + total_throughput = ( + (mean_output_tokens * len(completed_requests)) / elapsed_time + if elapsed_time > 0 + else 0.0 + ) + incremental_throughput = ( + (mean_output_tokens * len(completed_requests)) / incremental_time_delay + if incremental_time_delay > 0 + else 0.0 + ) + return round(total_throughput, 4), round(incremental_throughput, 4) + + +def metrics_summary( + metrics: List[Dict[str, Any]], start_time: int, end_time: int +) -> Dict[str, Any]: + """Generate a summary over metrics generated from potentially multiple instances of this client. + + Args: + metrics: The metrics to summarize. + start_time: The time the test started. + end_time: The time the test ended. + + Returns: + A summary with the following information: + - Overall throughput (generated tokens / total test time) + - Number of completed requests + - Error rate + - Error code frequency + - Quantiles (p25-p99) for the following metrics: + - Inter token latency + - Time to first token + - User total request time + - Number of tokens processed per request + - Number of tokens generated per request + - User throughput (tokens / s) + """ + ret = {} + + def flatten(item): + for sub_item in item: + if isinstance(sub_item, Iterable) and not isinstance(sub_item, str): + yield from flatten(sub_item) + else: + yield sub_item + + df = pd.DataFrame(metrics) + df_without_errored_req = df[df[common_metrics.ERROR_CODE].isna()] + + for key in [ + common_metrics.INTER_TOKEN_LAT, + common_metrics.TTFT, + common_metrics.E2E_LAT, + common_metrics.REQ_OUTPUT_THROUGHPUT, + common_metrics.NUM_INPUT_TOKENS, + common_metrics.NUM_OUTPUT_TOKENS + ]: + print(key) + ret[key] = {} + series = pd.Series(list(flatten(df_without_errored_req[key]))).dropna() + series = series[series > 0] # Calculate non-zero values + quantiles = series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).to_dict() + quantiles_reformatted_keys = {} + for quantile, value in quantiles.items(): + reformatted_key = f"p{int(quantile * 100)}" + print(f" {reformatted_key} = {value}") + quantiles_reformatted_keys[reformatted_key] = value + ret[key]["quantiles"] = quantiles_reformatted_keys + mean = series.mean() + print(f" mean = {mean}") + ret[key]["mean"] = mean + print(f" min = {series.min()}") + ret[key]["min"] = series.min() + print(f" max = {series.max()}") + ret[key]["max"] = series.max() + print(f" stddev = {series.std()}") + ret[key]["stddev"] = series.std() + + ret[common_metrics.NUM_REQ_STARTED] = len(metrics) + + error_codes = df[common_metrics.ERROR_CODE].dropna() + num_errors = len(error_codes) + ret[common_metrics.ERROR_RATE] = num_errors / len(metrics) if len(metrics) else 0 + ret[common_metrics.NUM_ERRORS] = num_errors + print(f"Number Of Errored Requests: {num_errors}") + error_code_frequency = dict(error_codes.value_counts()) + if num_errors: + error_code_frequency = dict(error_codes.value_counts()) + print("Error Code Frequency") + print(error_code_frequency) + ret[common_metrics.ERROR_CODE_FREQ] = str(error_code_frequency) + + overall_output_throughput = df_without_errored_req[ + common_metrics.NUM_OUTPUT_TOKENS + ].sum() / (end_time - start_time) + + print(f"Overall Output Throughput: {overall_output_throughput}") + ret[common_metrics.OUTPUT_THROUGHPUT] = overall_output_throughput + + num_completed_requests = len(df_without_errored_req) + num_completed_requests_per_min = ( + 
num_completed_requests / (end_time - start_time) * 60 + ) + print(f"Number Of Completed Requests: {num_completed_requests}") + print(f"Completed Requests Per Minute: {num_completed_requests_per_min}") + + ret[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests + ret[common_metrics.COMPLETED_REQUESTS_PER_MIN] = num_completed_requests_per_min + + return ret + + +def run_token_benchmark( + llm_api: str, + model: str, + test_timeout_s: int, + max_num_completed_requests: int, + concurrent_requests: int, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: str, + results_dir: str, + random_seed: int, + openai_api_base: str, + tokenizer_path: str, + user_metadata: Dict[str, Any], +): + """ + Args: + llm_api: The name of the llm api to use. + model: The name of the model to query. + max_num_completed_requests: The number of requests to complete before finishing the test. + test_timeout_s: The amount of time to run the test for before reporting results. + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions. + results_dir: The directory to save the results to. + user_metadata: Additional metadata to include in the results. + """ + if mean_input_tokens < 40: + print( + "the minimum number of input tokens that will be sent is 41" + " because of the prompting logic right now" + ) + + summary, completed_requests, elapsed_time, incremental_time_delay = get_token_throughput_latencies( + model=model, + llm_api=llm_api, + test_timeout_s=test_timeout_s, + max_num_completed_requests=max_num_completed_requests, + mean_input_tokens=mean_input_tokens, + stddev_input_tokens=stddev_input_tokens, + mean_output_tokens=mean_output_tokens, + stddev_output_tokens=stddev_output_tokens, + concurrent_requests=concurrent_requests, + additional_sampling_params=json.loads(additional_sampling_params), + random_seed=random_seed, + openai_api_base=openai_api_base, + tokenizer_path=tokenizer_path, + ) + if mean_output_tokens == 2: + return summary, completed_requests, elapsed_time, incremental_time_delay + + timestamp = int(time.time() * 1000) + if results_dir: + filename = f"{model}_{mean_input_tokens}_{mean_output_tokens}_{timestamp}" + filename = re.sub(r"[^\w\d-]+", "-", filename) + filename = re.sub(r"-{2,}", "-", filename) + summary_filename = f"{filename}_summary" + + # Update to metadata. 
+ summary.update(user_metadata) + total_tp, req_tp = compute_throughput(summary, completed_requests, elapsed_time, incremental_time_delay) + summary["num_completed_requests"] = len(completed_requests) + summary["elapsed_time"] = elapsed_time + summary["incremental_time_delay"] = incremental_time_delay + summary["total_throughput"] = total_tp + summary["incremental_throughput"] = req_tp + + results = LLMPerfResults(name=summary_filename, metadata=summary) + results_dir = Path(results_dir) + if not results_dir.exists(): + results_dir.mkdir(parents=True) + elif not results_dir.is_dir(): + raise ValueError(f"{results_dir} is not a directory") + + try: + with open(results_dir / f"{summary_filename}.json", "w") as f: + json.dump(results.to_dict(), f, indent=4, default=str) + except Exception as e: + print(results.to_dict()) + raise e + return summary \ No newline at end of file diff --git a/test/common/llmperf/utils/utils.py b/test/common/llmperf/utils/utils.py new file mode 100644 index 00000000..e68078b4 --- /dev/null +++ b/test/common/llmperf/utils/utils.py @@ -0,0 +1,168 @@ +import json +import math +import os +import hashlib +import pathlib +import random +import subprocess +import time +from typing import Any, Dict, Tuple + +from transformers import LlamaTokenizerFast + + +RESULTS_VERSION = "2025-10-30" + + +class LLMPerfResults: + def __init__( + self, + name: str, + metadata: Dict[str, Any] = None, + ): + self.name = name + self.metadata = metadata or {} + self.timestamp = int(time.time()) + self.metadata["timestamp"] = self.timestamp + self.version = RESULTS_VERSION + + def to_dict(self): + data = { + "version": self.version, + "name": self.name, + } + data.update(self.metadata) + data = flatten_dict(data) + return data + + def json(self): + data = self.to_dict() + return json.dumps(data) + + +def upload_to_s3(results_path: str, s3_path: str) -> None: + """Upload the results to s3. + + Args: + results_path: The path to the results file. + s3_path: The s3 path to upload the results to. + + """ + + command = ["aws", "s3", "sync", results_path, f"{s3_path}/"] + result = subprocess.run(command) + if result.returncode == 0: + print("Files uploaded successfully!") + else: + print("An error occurred:") + print(result.stderr) + +def randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean: int = 550, + prompt_tokens_stddev: int = 250, + tokenizer: LlamaTokenizerFast = None, +) -> Tuple[str, int]: + """Generate a prompt that randomly samples lines from a the shakespeare sonnet at sonnet.txt. + + Args: + prompt_length_mean: The mean length of the prompt to generate. + prompt_len_stddev: The standard deviation of the length of the prompt to generate. + expect_output_tokens: The number of tokens to expect in the output. This is used to + determine the length of the prompt. The prompt will be generated such that the output + will be approximately this many tokens. + + Note: + tokens will be counted from the sonnet using the Llama tokenizer. Using one tokenizer + ensures a fairer comparison across different LLMs. For example, if gpt 3.5 tokenizes + a prompt in less tokens than Llama2, then this will be reflected in the results since + they will be fed identical prompts. + + Returns: + A tuple of the prompt and the length of the prompt. 
+ """ + get_token_length = lambda text: len(tokenizer.encode(text)) + + prompt = ( + "Randomly stream lines from the following text " + "Don't generate eos tokens:\n\n" + ) + # get a prompt length that is at least as long as the base + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + while num_prompt_tokens < get_token_length(prompt): + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + remaining_prompt_tokens = num_prompt_tokens - get_token_length(prompt) + sonnet_path = pathlib.Path(__file__).parent.resolve() / "sonnet.txt" + with open(sonnet_path, "r") as f: + sonnet_lines = f.readlines() + random.shuffle(sonnet_lines) + sampling_lines = True + while sampling_lines: + for line in sonnet_lines: + line_to_add = line + if remaining_prompt_tokens - get_token_length(line_to_add) < 0: + # This will cut off a line in the middle of a word, but that's ok since an + # llm should be able to handle that. + line_to_add = line_to_add[: int(math.ceil(remaining_prompt_tokens))] + sampling_lines = False + prompt += line_to_add + break + prompt += line_to_add + remaining_prompt_tokens -= get_token_length(line_to_add) + print(hashlib.sha256(prompt.encode("utf-8")).hexdigest()) + return (prompt, num_prompt_tokens) + + +def sample_random_positive_int(mean: int, stddev: int) -> int: + """Sample random numbers from a gaussian distribution until a positive number is sampled. + + Args: + mean: The mean of the gaussian distribution to sample from. + stddev: The standard deviation of the gaussian distribution to sample from. + + Returns: + A random positive integer sampled from the gaussian distribution. + """ + ret = -1 + while ret <= 0: + ret = int(random.gauss(mean, stddev)) + return ret + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + +def reset_prefill_cache(env, server_url): + """ + prefix cache / HBM + Param: + env + server_url + """ + reset_url = f"{server_url}/reset_prefix_cache" + print(f"[INFO] Resetting prefix cache: {reset_url}") + try: + result = subprocess.run( + ["curl", "-X", "POST", reset_url, "-s", "-f"], + env=env, + check=False, + capture_output=True, + text=True, + timeout=10 + ) + if result.returncode == 0: + print("[INFO] Prefix cache successfully reset") + else: + print(f"[ERROR] Unsuccessfully reset prefix cache,error code: {result.returncode}") + except Exception as e: + print(f"[ERROR] Exception in resetting prefix cache: {e}") \ No newline at end of file diff --git a/test/config.yaml b/test/config.yaml new file mode 100644 index 00000000..88d00a61 --- /dev/null +++ b/test/config.yaml @@ -0,0 +1,18 @@ +reports: + base_dir: "results/reports" + use_timestamp: true + directory_prefix: "pytest" + html: # pytest-html + enabled: true + filename: "report.html" + title: "UCM Pytest Test Report" + +database: + backup: "results/" + enabled: true + host: "127.0.0.1" + port: 3306 + name: "ucm_pytest" + user: "root" + password: "123456" + charset: "utf8mb4" \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..15025795 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import datetime as dt +import platform as pf +import sys +from functools import wraps +from pathlib 
import Path + +import pytest +from common.config_utils import config_utils as config_instance +from common.db_utils import database_connection, write_to_db + +# ---------------- Constants ---------------- +PRJ_ROOT = Path(__file__).resolve().parent +sys.path.insert(0, str(PRJ_ROOT)) + + +# ---------------- CLI Options ---------------- +def pytest_addoption(parser): + parser.addoption( + "--stage", action="store", default="", help="Filter by stage marker (1,2,3,+)" + ) + parser.addoption( + "--feature", action="store", default="", help="Filter by feature marker" + ) + parser.addoption( + "--platform", action="store", default="", help="Filter by platform marker" + ) + + +# ---------------- Test Filtering ---------------- +def pytest_collection_modifyitems(config, items): + kept = items[:] + + markers = [m.split(":", 1)[0].strip() for m in config.getini("markers")] + for name in markers: + opt = config.getoption(f"--{name}", "").strip() + if not opt: + continue + + if name == "stage" and opt.endswith("+"): + min_stage = int(opt[:-1]) + kept = [ + it + for it in kept + if any(int(v) >= min_stage for v in _get_marker_args(it, "stage")) + ] + else: + wanted = {x.strip() for x in opt.split(",") if x.strip()} + kept = [ + it + for it in kept + if any(v in wanted for v in _get_marker_args(it, name)) + ] + + config.hook.pytest_deselected(items=[i for i in items if i not in kept]) + items[:] = kept + + +def _get_marker_args(item, marker_name): + """Extract only args (not kwargs) from markers, as strings.""" + return [ + str(arg) for mark in item.iter_markers(name=marker_name) for arg in mark.args + ] + + +# ---------------- Report Setup ---------------- +def _prepare_report_dir(config: pytest.Config) -> Path: + cfg = config_instance.get_config("reports", {}) + base_dir = Path(cfg.get("base_dir", "reports")) + prefix = cfg.get("directory_prefix", "pytest") + if cfg.get("use_timestamp", False): + ts = dt.datetime.now().strftime("%Y%m%d_%H%M%S") + report_dir = base_dir / f"{prefix}_{ts}" + else: + report_dir = base_dir + report_dir.mkdir(parents=True, exist_ok=True) + return report_dir + + +def _setup_html_report(config: pytest.Config, report_dir: Path) -> None: + reports_config = config_instance.get_config("reports", {}) + html_cfg = reports_config.get("html", {}) + if not html_cfg.get("enabled", True): + if hasattr(config.option, "htmlpath"): + config.option.htmlpath = None + print("HTML report disabled according to config.yaml") + return + + html_filename = html_cfg.get("filename", "report.html") + config.option.htmlpath = str(report_dir / html_filename) + config.option.self_contained_html = True + print("HTML report enabled") + + +# ---------------- Build ID & Session Init ---------------- +def _generate_build_id(config: pytest.Config) -> str: + ts = dt.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + cli_parts = [] + markers = [m.split(":", 1)[0].strip() for m in config.getini("markers")] + for opt in markers: + val = config.getoption(opt, "") + if val: + cli_parts.append(f"{opt}={val}") + args_part = "_".join(cli_parts) if cli_parts else "all_cases" + return f"pytest_{ts}_{args_part}" + + +# ---------------- Pytest Hooks ---------------- +def pytest_configure(config: pytest.Config) -> None: + """The global configuration will be executed directly upon entering pytest.""" + print(f"Starting Test Session: {dt.datetime.now():%Y-%m-%d %H:%M:%S}") + + # Set up report directory + report_dir = _prepare_report_dir(config) + config._report_dir = report_dir # Attach to config for later use + 
_setup_html_report(config, report_dir) + + # Generate and register build ID into DB + build_id = _generate_build_id(config) + config._build_id = build_id + database_connection(build_id) + + +def pytest_sessionstart(session): + print("") + print("-" * 60) + print(f"{'Python':<10} │ {pf.python_version()}") + print(f"{'Platform':<10} │ {pf.system()} {pf.release()}") + print("-" * 60) + + +def pytest_sessionfinish(session, exitstatus): + report_dir = getattr(session.config, "_report_dir", "reports") + print("") + print("-" * 60) + print(f"{'Reports at':<10} │ {report_dir}") + print("Test session ended") + print("-" * 60) + + +# ---------------- Fixtures ---------------- + + +def pytest_runtest_logreport(report): + """ + Called after each test phase. We only care about 'call' (the actual test). + """ + if report.when != "call": + return + + status = report.outcome.upper() # 'passed', 'failed', 'skipped' → 'PASSED', etc. + test_result = { + "test_case": report.nodeid, + "status": status, + # "duration": report.duration, + "error": str(report.longrepr) if report.failed else None, + } + write_to_db("test_case_info", test_result) diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 00000000..4be3cf47 --- /dev/null +++ b/test/pytest.ini @@ -0,0 +1,25 @@ +[pytest] +testpaths = suites +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +addopts = + -ra + --strict-markers + --capture=no +filterwarnings = + ignore::pytest.PytestReturnNotNoneWarning + +log_cli = 1 +log_cli_level = INFO +log_cli_format = [%(levelname)s] %(name)s: %(message)s +norecursedirs = .git venv env __pycache__ *.egg + +markers = + # -------- Levels (Required) -------- + stage(n): Unit/Smoke/Regression/Release (0=Unit 1=Smoke 2=Regression 3=Release) + # -------- Features (Recommended) -------- + feature: Feature tag + platform(name): Platform tag(gpu/npu) +# end of markers \ No newline at end of file diff --git a/test/requirements.txt b/test/requirements.txt new file mode 100644 index 00000000..d26c4ec3 --- /dev/null +++ b/test/requirements.txt @@ -0,0 +1,8 @@ +pytest>=7.0.0 +pytest-html>=3.1.1 +PyYAML>=6.0 +pandas>=2.0.0 +pydantic>=2.0.0 +# MySQL +peewee>=3.14.5 +pymysql>=1.0.2 \ No newline at end of file diff --git a/test/suites/E2E/test_demo_function.py b/test/suites/E2E/test_demo_function.py new file mode 100644 index 00000000..67433ebb --- /dev/null +++ b/test/suites/E2E/test_demo_function.py @@ -0,0 +1,185 @@ +# tests/test_demo.py +import pytest +import allure + +@pytest.mark.stage(1) +@pytest.mark.feature("mark") +@pytest.mark.platform("gpu") +def test_gpu_smoke(): + assert 1 == 1 + +@pytest.mark.stage(1) +@pytest.mark.feature("mark") +def test_regress_accuracy(): + assert 2 + 2 <= 5 + +@pytest.mark.stage(1) +@pytest.mark.feature("mark") +@pytest.mark.platform("npu") +def test_performance_accuracy(): + assert 2 + 2 <= 5 + +# Example of new mark +@pytest.mark.feature("mark") +@pytest.mark.reliability("high") +def test_llm_reliability(): + assert True + + +# Example of importing configuration file parameters +from common.config_utils import config_utils as config_instance +@pytest.mark.feature("config") +def test_llm_config(): + llm_config = config_instance.get_config("llm_connection") + assert llm_config["type"] == "openai" + assert config_instance.get_nested_config("llm_connection.model") == "gpt-3.5-turbo" + assert config_instance.get_nested_config("llm_connection.models", "gpt-3.5-turbo") == "gpt-3.5-turbo" + + + +# Example of using allure +@pytest.mark.feature("allure1") 
+@allure.feature('test_success') +def test_success(): + """this test succeeds""" + assert True + +@allure.feature('test_failure') +@pytest.mark.feature("allure1") +def test_failure(): + """this test fails""" + assert False + +@allure.feature('test_skip') +@pytest.mark.feature("allure1") +def test_skip(): + """this test is skipped""" + pytest.skip('for a reason!') + +@allure.feature('test_broken') +@pytest.mark.feature("allure1") +def test_broken(): + raise Exception('oops') + +@pytest.mark.feature("allure2") +@pytest.mark.parametrize('param1', ["Hello", "World"]) +@pytest.mark.parametrize('param2', ['Hello', "Hello"]) +def test_parametrize_with_two_parameters(param1, param2): + assert param1 == param2 + +@pytest.mark.feature("allure3") +@allure.description_html(""" +

This is HTML description

<table style="width:100%">
  <tr>
    <th>Firstname</th>
    <th>Lastname</th>
    <th>Age</th>
  </tr>
  <tr align="center">
    <td>jade</td>
    <td>mr</td>
    <td>18</td>
  </tr>
  <tr align="center">
    <td>road</td>
    <td>Tester</td>
    <td>18</td>
  </tr>
</table>
+""") +def test_html_description(): + assert True + +@pytest.mark.feature("allure3") +@allure.description("""Multi-line description""") +def test_description_from_decorator(): + assert 42 == int(6 * 7) + +@pytest.mark.feature("allure3") +def test_unicode_in_docstring_description(): + """Description can also be below the function""" + assert 42 == int(6 * 7) + +@pytest.mark.feature("allure4") +@allure.title("Assert that 2+2=4") +def test_with_a_title(): + assert 2 + 2 == 4 + +@pytest.mark.feature("allure4") +@allure.title("Dynamic title: {param1} + {param2} = {expected}") +@pytest.mark.parametrize('param1,param2,expected', [(2, 2, 4),(1, 2, 5)]) +def test_with_parameterized_title(param1, param2, expected): + assert param1 + param2 == expected + +@pytest.mark.feature("allure4") +@allure.title("This is a dynamic title that will be replaced") +def test_with_dynamic_title(): + assert 2 + 2 == 4 + allure.dynamic.title('Test completed, used as title') + + +@pytest.mark.feature("allure5") +def test_with_steps(): + """Example test case with steps""" + with allure.step("Step 1: Initialize variables"): + a = 2 + b = 3 + + with allure.step("Step 2: Perform addition"): + result = a + b + + with allure.step("Step 3: Verify result"): + assert result == 5 + +import tempfile +import os +@pytest.mark.feature("allure6") +def test_with_attachment(): + """Example test case with attachment""" + # Create some data to attach + data = "This is sample data for attachment\nLine 2\nLine 3" + + # Attach text data + allure.attach(data, name="Sample Data", attachment_type=allure.attachment_type.TEXT) + + # Create and attach a simple file + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("Sample file content\nFor testing attachment feature") + temp_file_path = f.name + + # Attach the file + allure.attach.file(temp_file_path, name="Attached File", + attachment_type=allure.attachment_type.TEXT) + + # Clean up temporary file + os.unlink(temp_file_path) + + assert True + +@pytest.mark.feature("allure7") +def test_mixed_steps_and_attachments(): + """Example test case combining steps and attachments""" + with allure.step("Initialize test data"): + test_data = {"name": "John", "age": 30, "city": "New York"} + + with allure.step("Convert data to JSON string"): + import json + json_data = json.dumps(test_data, indent=2) + allure.attach(json_data, name="JSON Data", attachment_type=allure.attachment_type.JSON) + + with allure.step("Validate data"): + assert test_data["name"] == "John" + assert test_data["age"] == 30 + + with allure.step("Create and attach report"): + report_content = f""" + Test Report + =========== + Name: {test_data['name']} + Age: {test_data['age']} + City: {test_data['city']} + Status: PASSED + """ + allure.attach(report_content, name="Test Report", + attachment_type=allure.attachment_type.TEXT) \ No newline at end of file diff --git a/test/suites/E2E/test_uc_performance.py b/test/suites/E2E/test_uc_performance.py new file mode 100644 index 00000000..9bc26092 --- /dev/null +++ b/test/suites/E2E/test_uc_performance.py @@ -0,0 +1,121 @@ +import pytest + +from common.llmperf.run_inference import inference_results + +from common.capture_utils import export_vars + + +@pytest.mark.feature("uc_performance_test") +@export_vars +def test_performance(): + all_summaries = inference_results() + failed_cases = [] + + value_lists = { + 'mean_input_tokens': [], + 'mean_output_tokens': [], + 'results_inter_token_latency_s_quantiles_p50': [], + 
'results_inter_token_latency_s_quantiles_p90': [], + 'results_inter_token_latency_s_quantiles_p99': [], + 'results_inter_token_latency_s_mean': [], + 'results_ttft_s_quantiles_p50': [], + 'results_ttft_s_quantiles_p90': [], + 'results_ttft_s_quantiles_p99': [], + 'results_ttft_s_mean': [], + 'results_end_to_end_latency_s_quantiles_p50': [], + 'results_end_to_end_latency_s_quantiles_p90': [], + 'results_end_to_end_latency_s_quantiles_p99': [], + 'results_end_to_end_latency_s_mean': [], + 'num_completed_requests': [], + 'elapsed_time': [], + 'incremental_time_delay': [], + 'total_throughput': [], + 'incremental_throughput': [], + } + + for i, summary in enumerate(all_summaries): + mean_input_tokens = summary["mean_input_tokens"] + mean_output_tokens = summary["mean_output_tokens"] + + results_inter_token_latency_s_quantiles_p50 = summary["results"]["inter_token_latency_s"]["quantiles"]["p50"] + results_inter_token_latency_s_quantiles_p90 = summary["results"]["inter_token_latency_s"]["quantiles"]["p90"] + results_inter_token_latency_s_quantiles_p99 = summary["results"]["inter_token_latency_s"]["quantiles"]["p99"] + results_inter_token_latency_s_mean = summary["results"]["inter_token_latency_s"]["mean"] + + results_ttft_s_quantiles_p50 = summary["results"]["ttft_s"]["quantiles"]["p50"] + results_ttft_s_quantiles_p90 = summary["results"]["ttft_s"]["quantiles"]["p90"] + results_ttft_s_quantiles_p99 = summary["results"]["ttft_s"]["quantiles"]["p99"] + results_ttft_s_mean = summary["results"]["ttft_s"]["mean"] + + results_end_to_end_latency_s_quantiles_p50 = summary["results"]["end_to_end_latency_s"]["quantiles"]["p50"] + results_end_to_end_latency_s_quantiles_p90 = summary["results"]["end_to_end_latency_s"]["quantiles"]["p90"] + results_end_to_end_latency_s_quantiles_p99 = summary["results"]["end_to_end_latency_s"]["quantiles"]["p99"] + results_end_to_end_latency_s_mean = summary["results"]["end_to_end_latency_s"]["mean"] + + num_completed_requests = summary["num_completed_requests"] + elapsed_time = summary["elapsed_time"] + incremental_time_delay = summary["incremental_time_delay"] + total_throughput = summary["total_throughput"] + incremental_throughput = summary["incremental_throughput"] + + values = [ + mean_input_tokens, + mean_output_tokens, + results_inter_token_latency_s_quantiles_p50, + results_inter_token_latency_s_quantiles_p90, + results_inter_token_latency_s_quantiles_p99, + results_inter_token_latency_s_mean, + results_ttft_s_quantiles_p50, + results_ttft_s_quantiles_p90, + results_ttft_s_quantiles_p99, + results_ttft_s_mean, + results_end_to_end_latency_s_quantiles_p50, + results_end_to_end_latency_s_quantiles_p90, + results_end_to_end_latency_s_quantiles_p99, + results_end_to_end_latency_s_mean, + num_completed_requests, + elapsed_time, + incremental_time_delay, + total_throughput, + incremental_throughput + ] + + for var_name, val in zip([ + 'mean_input_tokens', + 'mean_output_tokens', + 'results_inter_token_latency_s_quantiles_p50', + 'results_inter_token_latency_s_quantiles_p90', + 'results_inter_token_latency_s_quantiles_p99', + 'results_inter_token_latency_s_mean', + 'results_ttft_s_quantiles_p50', + 'results_ttft_s_quantiles_p90', + 'results_ttft_s_quantiles_p99', + 'results_ttft_s_mean', + 'results_end_to_end_latency_s_quantiles_p50', + 'results_end_to_end_latency_s_quantiles_p90', + 'results_end_to_end_latency_s_quantiles_p99', + 'results_end_to_end_latency_s_mean', + 'num_completed_requests', + 'elapsed_time', + 'incremental_time_delay', + 'total_throughput', + 
'incremental_throughput' + ], values): + value_lists[var_name].append(val) + if val is None: + failed_cases.append((i, var_name, "missing")) + + try: + assert val > 0, f"value <= 0" + except AssertionError as e: + failed_cases.append((i, var_name, str(e))) + + # Output final result + if failed_cases: + print(f"\n[WARNING] Assertion failed: {len(failed_cases)} abnormal cases found") + for i, key, reason in failed_cases: + print(f" Iteration={i + 1}, key='{key}' -> {reason}") + else: + print("\n[INFO] All values are greater than 0. Assertion passed!") + + return value_lists \ No newline at end of file diff --git a/test/test_uc_connector.py b/test/test_uc_connector.py index 0c2261d8..d4a0caeb 100644 --- a/test/test_uc_connector.py +++ b/test/test_uc_connector.py @@ -25,7 +25,6 @@ import random import secrets import unittest -from collections import defaultdict from typing import List, Union from unittest.mock import MagicMock, Mock, patch @@ -107,14 +106,12 @@ def init_uc( ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {} ucconnector.total_tp_size = self.total_tp_size ucconnector._connector_metadata = metadata - ucconnector.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict( - dict - ) + ucconnector.layerwise_load_tasks: dict[ + str, dict[str, tuple[Task, Task]] + ] = {} ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {} ucconnector._load_failed_reqs: set[str] = set() ucconnector._load_req_to_blocks: dict[str, set[int]] = {} - ucconnector.num_layers = 48 - ucconnector.is_mla = False return ucconnector def test_get_num_new_matched_tokens_hit_all_on_storage(self): @@ -511,7 +508,6 @@ def test_wait_for_save_not_layerwise_invalid_para(self): ucconnector.block_size = self.block_size ucconnector.use_layerwise = False ucconnector._connector_metadata = Mock() - ucconnector.is_mla = False with self.assertRaises(AssertionError): ucconnector.wait_for_save() @@ -546,7 +542,6 @@ def mock_wait(task: Task) -> int: ) forward_context = Mock() ucconnector.start_load_kv(forward_context) - assert mock_connector.load.call_count == 1 def test_start_load_kv_invalid_para(self): with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None): @@ -564,7 +559,6 @@ def test_start_load_kv_layerwise_success(self): req_meta1.load_blocks = [ (secrets.token_hex(8), i) for i in range(self.block_number) ] - req_meta1.load_async = False metadata = UCConnectorV1Metadata() metadata.requests = [req_meta1] @@ -581,7 +575,7 @@ def mock_load( ucconnector = self.init_uc(mock_connector, metadata=metadata) forward_context = Mock() ucconnector.start_load_kv(forward_context) - assert mock_connector.load.call_count == self.num_layers + assert mock_connector.load.call_count == 2 * self.num_layers if __name__ == "__main__": diff --git a/test/test_ucm_dram.py b/test/test_ucm_dram.py new file mode 100644 index 00000000..020405d1 --- /dev/null +++ b/test/test_ucm_dram.py @@ -0,0 +1,250 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +import random +import unittest +import unittest.mock as mock +from contextlib import contextmanager +from typing import List +from unittest.mock import MagicMock + +import torch +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.sampling_params import SamplingParams +from vllm.utils import sha256 +from vllm.v1.core.kv_cache_utils import hash_request_tokens +from vllm.v1.request import Request + + +@contextmanager +def mock_stream_context(stream=None): + yield + + +class MockStream: + def __init__(self, device=None): + self.device = device or torch.device("cpu") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def synchronize(self): + pass + + def record_event(self, event=None): + return event or MockEvent() + + def wait_stream(self, stream): + pass + + +class MockEvent: + def __init__(self, enable_timing=False): + self.enable_timing = enable_timing + + def record(self, stream=None): + pass + + def wait(self, stream=None): + pass + + def synchronize(self): + pass + + +def patch_cuda_for_cpu(): + mock.patch("torch.cuda.Stream", MockStream).start() + mock.patch("torch.cuda.Event", MockEvent).start() + mock.patch("torch.cuda.current_stream", return_value=MockStream()).start() + mock.patch("torch.cuda.synchronize", side_effect=lambda *a, **k: None).start() + mock.patch("torch.cuda.is_available", return_value=True).start() + mock.patch("torch.cuda.stream", mock_stream_context).start() + + +patch_cuda_for_cpu() +from ucm.store.dramstore.dramstore_connector import ( # isort: skip + DramTask, + UcmDramStore, +) + + +def make_request( + request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, cache_salt=None +): + if mm_positions is None: + multi_modal_inputs = None + else: + multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) + + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_inputs=multi_modal_inputs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + arrival_time=0, + lora_request=None, + cache_salt=cache_salt, + ) + + +class TestUcmDram(unittest.TestCase): + + @classmethod + def setUpClass(cls): + print("===> Before all tests (setUpClass)") + + @classmethod + def tearDownClass(cls): + print("===> After all tests (setUpClass)") + + def setUp(self): + self.config = {"block_size": 4} + 
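        # Scheduler and worker roles share the same DRAM limits:
        # max_cache_size = 1073741824 (1 GiB) and kv_block_size = 262144 (256 KiB), in bytes.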
self.scheduler_config = { + "role": "scheduler", + "max_cache_size": 1073741824, + "kv_block_size": 262144, + } + self.worker_config = { + "role": "worker", + "max_cache_size": 1073741824, + "kv_block_size": 262144, + } + + self.block_number = 4 + self.block_size = int(self.config["block_size"]) + self.scheduler_dram = UcmDramStore(self.scheduler_config) + self.worker_dram = UcmDramStore(self.worker_config) + random.seed(20250728) + self.request = make_request( + request_id=1, + prompt_token_ids=random.sample( + range(0, 10000), self.block_number * self.block_size + ), + mm_positions=None, + mm_hashes=None, + ) + block_hash_types = hash_request_tokens(sha256, self.block_size, self.request) + self.block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types] + + def test_look_up_all_hit(self): + """ + Test for all blocks hitten in cache + """ + expected = [True] * len(self.block_hashes) + self.scheduler_dram.cached_blocks.update(self.block_hashes) + actual = self.scheduler_dram.lookup(self.block_hashes) + + self.assertEqual(actual, expected) + + def test_lookup_partial_hit(self): + """ + Test for part of the blocks hitten in cache + """ + partial_index = random.randint(0, 4) + partial_hashes = self.block_hashes[:partial_index] + self.scheduler_dram.cached_blocks.update(partial_hashes) + actual = self.scheduler_dram.lookup(self.block_hashes) + expected = [True] * partial_index + [False] * (self.block_size - partial_index) + self.assertEqual(actual, expected) + + def test_lookup_none_hit(self): + """ + Test for none of the blocks hitten in cache + """ + actual = self.scheduler_dram.lookup(self.block_hashes) + expected = [False] * len(self.block_hashes) + self.assertEqual(actual, expected) + + def test_load_success(self): + """ + Test for load from cache successfully + """ + src_tensors = [ + torch.randint(0, 100, (self.block_size,), dtype=torch.int8) + for _ in range(len(self.block_hashes)) + ] + offsets = [i for i in range(len(self.block_hashes))] + dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors) + self.worker_dram.wait(dump_task) + dst_tensors = [ + torch.zeros(self.block_size, dtype=torch.int8) + for _ in range(len(self.block_hashes)) + ] + load_task = self.worker_dram.load(self.block_hashes, offsets, dst_tensors) + + self.assertIsInstance(load_task, DramTask) + self.assertIsNotNone(load_task.event) + for i, (src_tensor, dst_tensor) in enumerate(zip(src_tensors, dst_tensors)): + self.assertEqual(dst_tensor.shape[0], self.block_size) + self.assertTrue( + torch.equal(src_tensor, dst_tensor), + f"Block {i} loaded data is different", + ) + + def test_dump_success(self): + """ + Test data dump successfully + """ + src_tensors = [ + torch.randint(0, 100, (self.block_size,), dtype=torch.int8) + for _ in range(len(self.block_hashes)) + ] + offsets = [i for i in range(len(self.block_hashes))] + original_data = [tensor.clone() for tensor in src_tensors] + dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors) + self.assertIsInstance(dump_task, DramTask) + self.assertIsNotNone(dump_task.event) + self.worker_dram.wait(dump_task) + for i, block_id in enumerate(self.block_hashes): + key = block_id + "_" + str(offsets[i]) + cached_data = self.worker_dram.dram_cache[key] + self.assertEqual(cached_data.shape[0], self.block_size) + self.assertTrue(torch.equal(cached_data, original_data[i])) + + def test_wait_success(self): + """ + Test wait for task successfully + """ + task = DramTask() + task.event = MagicMock() + result = 
self.worker_dram.wait(task) + self.assertEqual(result, 0) + task.event.synchronize.assert_called_once() + + def test_wait_failure(self): + task = DramTask() + task.event = None + result = self.worker_dram.wait(task) + self.assertEqual(result, -1) + + +if __name__ == "__main__": + unittest.main()