diff --git a/ai/pipelines/__init__.py b/ai/pipelines/__init__.py index e3eece2..6b523a3 100644 --- a/ai/pipelines/__init__.py +++ b/ai/pipelines/__init__.py @@ -2,5 +2,6 @@ from .comment_summarizer import generate_comment_summaries from .keyword_cluster import cluster_keywords +from .trading_prompt import build_trading_prompt -__all__ = ["generate_comment_summaries", "cluster_keywords"] +__all__ = ["generate_comment_summaries", "cluster_keywords", "build_trading_prompt"] diff --git a/ai/pipelines/trading_prompt.py b/ai/pipelines/trading_prompt.py new file mode 100644 index 0000000..473d030 --- /dev/null +++ b/ai/pipelines/trading_prompt.py @@ -0,0 +1,158 @@ +"""Prompt builder for trading agent market state.""" +from __future__ import annotations + +from typing import Any, Iterable, Mapping, Sequence + +PROMPT_TEMPLATE_HEADER = "🧭 SYSTEM STATE" +PROMPT_MARKET_STATE_TITLE = "⚙️ CURRENT MARKET STATE" +PROMPT_ACCOUNT_TITLE = "💰 ACCOUNT INFORMATION" +PROMPT_PERFORMANCE_TITLE = "📊 PERFORMANCE METRICS" + + +def _format_value(value: Any) -> str: + if value is None: + return "N/A" + if isinstance(value, float): + formatted = f"{value:.8f}".rstrip("0").rstrip(".") + return formatted or "0" + return str(value) + + +def _format_series(series: Iterable[Any]) -> str: + items = [item for item in series] + if not items: + return "[]" + return "[" + ", ".join(_format_value(item) for item in items) + "]" + + +def build_trading_prompt( + system_state: Mapping[str, Any], + market_states: Sequence[Mapping[str, Any]], + account_info: Mapping[str, Any], + *, + performance_metrics: Mapping[str, Any] | None = None, +) -> str: + """Render a textual prompt describing the trading environment.""" + + performance_metrics = performance_metrics or {} + lines: list[str] = [] + + lines.append(PROMPT_TEMPLATE_HEADER) + minutes = _format_value(system_state.get("minutes_since_start")) + current_time = _format_value(system_state.get("current_time")) + invocation_count = _format_value(system_state.get("invocation_count")) + lines.append( + f"It has been {minutes} minutes since you started trading." + ) + lines.append( + f"The current time is {current_time} and you've been invoked {invocation_count} times." + ) + lines.append("Below you are provided with price data, indicators, and account info for decision making.") + lines.append("") + + lines.append(PROMPT_MARKET_STATE_TITLE) + lines.append("For each coin:") + for state in market_states: + symbol = state.get("symbol", "UNKNOWN") + lines.append(f"SYMBOL: {symbol}") + lines.append( + "current_price = {price}, current_ema20 = {ema}, current_macd = {macd}, current_rsi(7) = {rsi7}".format( + price=_format_value(state.get("current_price")), + ema=_format_value(state.get("current_ema20")), + macd=_format_value(state.get("current_macd")), + rsi7=_format_value(state.get("current_rsi_7")), + ) + ) + lines.append("") + lines.append( + "Open Interest: Latest={latest}, Average={avg}".format( + latest=_format_value(state.get("open_interest_latest")), + avg=_format_value(state.get("open_interest_avg")), + ) + ) + lines.append(f"Funding Rate: {_format_value(state.get('funding_rate'))}") + lines.append("") + + short_term = state.get("short_term", {}) + short_interval = short_term.get("interval", "-") + lines.append(f"Intraday Series ({short_interval} intervals, oldest → latest):") + lines.append(f"Mid Prices: {_format_series(short_term.get('mid_prices', []))}") + lines.append(f"EMA(20): {_format_series(short_term.get('ema_series_20', []))}") + lines.append(f"MACD: {_format_series(short_term.get('macd_series', []))}") + lines.append(f"RSI(7): {_format_series(short_term.get('rsi_series_7', []))}") + lines.append(f"RSI(14): {_format_series(short_term.get('rsi_series_14', []))}") + lines.append("") + + long_term = state.get("long_term", {}) + long_interval = long_term.get("interval", "-") + lines.append(f"Longer-term ({long_interval} timeframe):") + lines.append( + "EMA(20): {ema20}, EMA(50): {ema50}".format( + ema20=_format_value(long_term.get("ema_20_4h")), + ema50=_format_value(long_term.get("ema_50_4h")), + ) + ) + lines.append( + "ATR(3): {atr3}, ATR(14): {atr14}".format( + atr3=_format_value(long_term.get("atr_3_4h")), + atr14=_format_value(long_term.get("atr_14_4h")), + ) + ) + lines.append( + "Volume: Current={current}, Average={avg}".format( + current=_format_value(long_term.get("current_volume")), + avg=_format_value(long_term.get("avg_volume")), + ) + ) + lines.append(f"MACD(4h): {_format_series(long_term.get('macd_series_4h', []))}") + lines.append(f"RSI(14,4h): {_format_series(long_term.get('rsi_series_14_4h', []))}") + lines.append("") + + lines.append(PROMPT_ACCOUNT_TITLE) + lines.append( + f"Total Return: {_format_value(account_info.get('total_return_percent'))}%" + ) + lines.append(f"Available Cash: {_format_value(account_info.get('available_cash'))}") + lines.append(f"Account Value: {_format_value(account_info.get('account_value'))}") + lines.append("") + + lines.append("Positions:") + positions = account_info.get("positions", []) + if positions: + for position in positions: + symbol = position.get("symbol", "UNKNOWN") + lines.append( + "- {symbol}: qty={qty}, entry={entry}, current={current}, PnL={pnl}, leverage={leverage}".format( + symbol=symbol, + qty=_format_value(position.get("quantity")), + entry=_format_value(position.get("entry_price")), + current=_format_value(position.get("current_price")), + pnl=_format_value(position.get("unrealized_pnl")), + leverage=_format_value(position.get("leverage")), + ) + ) + exit_plan = position.get("exit_plan", {}) + lines.append( + " TP={tp}, SL={sl}, invalidation={invalid}".format( + tp=_format_value(exit_plan.get("profit_target")), + sl=_format_value(exit_plan.get("stop_loss")), + invalid=_format_value(exit_plan.get("invalidation_condition")), + ) + ) + lines.append( + " confidence={confidence}, risk={risk}, notional={notional}".format( + confidence=_format_value(position.get("confidence")), + risk=_format_value(position.get("risk_usd")), + notional=_format_value(position.get("notional_usd")), + ) + ) + else: + lines.append("- None") + lines.append("") + + lines.append(PROMPT_PERFORMANCE_TITLE) + lines.append( + f"Sharpe Ratio: {_format_value(performance_metrics.get('sharpe_ratio'))}" + ) + + return "\n".join(lines) diff --git a/config/secrets.example.env b/config/secrets.example.env index 093e34e..a558620 100644 --- a/config/secrets.example.env +++ b/config/secrets.example.env @@ -6,6 +6,8 @@ UNIFIED_API_TOKEN=replace-with-token DB_PASSWORD=replace-with-password # Keepa 开发者密钥,订阅 https://keepa.com/#!api 后在仪表盘复制 KEEPA_API_KEY=optional-keepa-key +# TAAPI API 密钥,登录 https://taapi.io/ 控制台获取 +TAAPI_API_KEY=optional-taapi-key # Amazon Selling Partner API 刷新令牌,登录 https://developer.amazon.com/sp-api/ 获取 SPAPI_REFRESH_TOKEN=optional-refresh-token # Qwen(阿里云百炼 DashScope)API Key,登录 https://dashscope.console.aliyun.com/management/app 获取 diff --git a/config/settings.yaml b/config/settings.yaml index 8d3037f..aec83e0 100644 --- a/config/settings.yaml +++ b/config/settings.yaml @@ -30,6 +30,10 @@ rate_limits: qpm: 30 burst: 30 timeout_sec: 10 + taapi: + qpm: 120 + burst: 120 + timeout_sec: 10 retry_policy: max_attempts: 5 base_delay_ms: 500 diff --git a/connectors/__init__.py b/connectors/__init__.py index 1c94812..65fc5d8 100644 --- a/connectors/__init__.py +++ b/connectors/__init__.py @@ -4,6 +4,7 @@ from .paapi_client import PAAPIClient from .h10_client import Helium10Client from .js_client import JungleScoutClient +from .taapi_client import TaapiClient __all__ = [ "KeepaClient", @@ -11,4 +12,5 @@ "PAAPIClient", "Helium10Client", "JungleScoutClient", + "TaapiClient", ] diff --git a/connectors/taapi_client.py b/connectors/taapi_client.py new file mode 100644 index 0000000..e137fa8 --- /dev/null +++ b/connectors/taapi_client.py @@ -0,0 +1,319 @@ +"""TAAPI connector for retrieving technical analysis metrics.""" +from __future__ import annotations + +from typing import Any, Dict, Iterable, List, Sequence + +from connectors.base import BaseConnector, ConnectorError, retry_and_rate_limit + + +class TaapiClient(BaseConnector): + """Client wrapper for TAAPI bulk indicator endpoints.""" + + def __init__( + self, + api_key: str, + *, + base_url: str = "https://api.taapi.io", + settings: Dict[str, Any] | None = None, + session: Any | None = None, + ) -> None: + super().__init__(service_name="taapi", base_url=base_url, settings=settings, session=session) + self.api_key = api_key + + @retry_and_rate_limit(source="taapi") + def _request_bulk(self, construct: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + payload = { + "secret": self.api_key, + "construct": list(construct), + } + headers = {"Content-Type": "application/json"} + return self._http_request("POST", "/bulk", json_payload=payload, headers=headers) + + def get_market_state( + self, + symbol: str, + exchange: str, + *, + short_interval: str = "5m", + long_interval: str = "4h", + short_backtrack: int = 20, + long_backtrack: int = 20, + ) -> Dict[str, Any]: + """Fetch a consolidated market snapshot for ``symbol`` from TAAPI.""" + + construct = [ + self._indicator("price", symbol, exchange, "1m", "current_price"), + self._indicator( + "ema", + symbol, + exchange, + short_interval, + "current_ema20", + optInTimePeriod=20, + ), + self._indicator( + "macd", + symbol, + exchange, + short_interval, + "current_macd", + ), + self._indicator( + "rsi", + symbol, + exchange, + short_interval, + "current_rsi_7", + optInTimePeriod=7, + ), + self._indicator( + "rsi", + symbol, + exchange, + short_interval, + "current_rsi_14", + optInTimePeriod=14, + ), + self._indicator( + "midprice", + symbol, + exchange, + short_interval, + "short_mid_prices", + backtrack=short_backtrack, + ), + self._indicator( + "ema", + symbol, + exchange, + short_interval, + "short_ema20_series", + optInTimePeriod=20, + backtrack=short_backtrack, + ), + self._indicator( + "macd", + symbol, + exchange, + short_interval, + "short_macd_series", + backtrack=short_backtrack, + ), + self._indicator( + "rsi", + symbol, + exchange, + short_interval, + "short_rsi7_series", + optInTimePeriod=7, + backtrack=short_backtrack, + ), + self._indicator( + "rsi", + symbol, + exchange, + short_interval, + "short_rsi14_series", + optInTimePeriod=14, + backtrack=short_backtrack, + ), + self._indicator( + "ema", + symbol, + exchange, + long_interval, + "long_ema20", + optInTimePeriod=20, + ), + self._indicator( + "ema", + symbol, + exchange, + long_interval, + "long_ema50", + optInTimePeriod=50, + ), + self._indicator( + "atr", + symbol, + exchange, + long_interval, + "long_atr3", + optInTimePeriod=3, + ), + self._indicator( + "atr", + symbol, + exchange, + long_interval, + "long_atr14", + optInTimePeriod=14, + ), + self._indicator( + "volume", + symbol, + exchange, + long_interval, + "long_volume_current", + ), + self._indicator( + "sma", + symbol, + exchange, + long_interval, + "long_volume_avg", + optInTimePeriod=long_backtrack, + ), + self._indicator( + "macd", + symbol, + exchange, + long_interval, + "long_macd_series", + backtrack=long_backtrack, + ), + self._indicator( + "rsi", + symbol, + exchange, + long_interval, + "long_rsi14_series", + optInTimePeriod=14, + backtrack=long_backtrack, + ), + self._indicator( + "oi", + symbol, + exchange, + short_interval, + "open_interest_latest", + ), + self._indicator( + "oi", + symbol, + exchange, + short_interval, + "open_interest_avg", + backtrack=short_backtrack, + ), + self._indicator( + "fundingrate", + symbol, + exchange, + short_interval, + "funding_rate", + ), + ] + + response = self._request_bulk(construct) + return self._normalise_market_state( + response, + symbol=symbol, + short_interval=short_interval, + long_interval=long_interval, + ) + + def _indicator( + self, + indicator: str, + symbol: str, + exchange: str, + interval: str, + identifier: str, + **params: Any, + ) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "indicator": indicator, + "symbol": symbol, + "exchange": exchange, + "interval": interval, + "id": identifier, + } + payload.update(params) + return payload + + def _normalise_market_state( + self, + payload: Dict[str, Any], + *, + symbol: str, + short_interval: str, + long_interval: str, + ) -> Dict[str, Any]: + entries = payload.get("data", []) + by_id: Dict[str, Dict[str, Any]] = { + entry.get("id", ""): entry.get("result", {}) for entry in entries if entry.get("id") + } + + def value(key: str) -> Any: + return self._extract_single(by_id.get(key)) + + def series(key: str) -> List[float]: + return self._extract_series(by_id.get(key)) + + market_state = { + "symbol": symbol, + "current_price": value("current_price"), + "current_ema20": value("current_ema20"), + "current_macd": value("current_macd"), + "current_rsi_7": value("current_rsi_7"), + "current_rsi_14": value("current_rsi_14"), + "open_interest_latest": value("open_interest_latest"), + "open_interest_avg": value("open_interest_avg"), + "funding_rate": value("funding_rate"), + "short_term": { + "interval": short_interval, + "mid_prices": series("short_mid_prices"), + "ema_series_20": series("short_ema20_series"), + "macd_series": series("short_macd_series"), + "rsi_series_7": series("short_rsi7_series"), + "rsi_series_14": series("short_rsi14_series"), + }, + "long_term": { + "interval": long_interval, + "ema_20_4h": value("long_ema20"), + "ema_50_4h": value("long_ema50"), + "atr_3_4h": value("long_atr3"), + "atr_14_4h": value("long_atr14"), + "current_volume": value("long_volume_current"), + "avg_volume": value("long_volume_avg"), + "macd_series_4h": series("long_macd_series"), + "rsi_series_14_4h": series("long_rsi14_series"), + }, + } + return market_state + + @staticmethod + def _extract_single(result: Dict[str, Any] | None) -> Any: + if not isinstance(result, dict): + return None + for key in ("value", "valueMACD", "valueMacd", "value1", "close"): + if key in result and isinstance(result[key], (int, float)): + return result[key] + for val in result.values(): + if isinstance(val, (int, float)): + return val + return None + + @staticmethod + def _extract_series(result: Dict[str, Any] | None) -> List[float]: + if not isinstance(result, dict): + return [] + values: Iterable[Any] = result.get("values", []) + extracted: List[float] = [] + for item in values: + numeric = TaapiClient._extract_single(item if isinstance(item, dict) else {"value": item}) + if isinstance(numeric, (int, float)): + extracted.append(float(numeric)) + if not extracted: + single = TaapiClient._extract_single(result) + if isinstance(single, (int, float)): + extracted.append(float(single)) + return extracted + + def healthcheck(self) -> bool: # pragma: no cover - simple passthrough + try: + self._request_bulk([{"indicator": "price", "symbol": "BTC/USDT", "exchange": "binance", "interval": "1m", "id": "ping"}]) + return True + except ConnectorError: + return False diff --git a/tests/test_ai_pipelines.py b/tests/test_ai_pipelines.py index eae6d73..b79cbcd 100644 --- a/tests/test_ai_pipelines.py +++ b/tests/test_ai_pipelines.py @@ -6,7 +6,7 @@ from pathlib import Path from ai.models import EmbeddingRequest, EmbeddingResponse, GenerationRequest, GenerationResponse -from ai.pipelines import cluster_keywords, generate_comment_summaries +from ai.pipelines import build_trading_prompt, cluster_keywords, generate_comment_summaries from ai.clients.base import EmbeddingClient, LLMClient @@ -120,3 +120,70 @@ def test_cluster_keywords(tmp_path: Path) -> None: expected_dir = tmp_path / "20230207" / "ai" assert output_path.parent == expected_dir + + +def test_build_trading_prompt() -> None: + system_state = {"minutes_since_start": 42, "current_time": "2024-04-01T12:00:00Z", "invocation_count": 5} + market_state = [ + { + "symbol": "BTC", + "current_price": 114775.5, + "current_ema20": 114457.686, + "current_macd": 53.823, + "current_rsi_7": 86.769, + "open_interest_latest": 1234.0, + "open_interest_avg": 987.0, + "funding_rate": 0.0005, + "short_term": { + "interval": "5min", + "mid_prices": [114339.5, 114500.0], + "ema_series_20": [114329.3, 114480.5], + "macd_series": [-60.9, -12.1], + "rsi_series_7": [61.9, 65.2], + "rsi_series_14": [51.1, 54.0], + }, + "long_term": { + "interval": "4h", + "ema_20_4h": 113453.764, + "ema_50_4h": 112031.947, + "atr_3_4h": 319.999, + "atr_14_4h": 523.419, + "current_volume": 8.146, + "avg_volume": 4618.058, + "macd_series_4h": [1082.2, 950.1], + "rsi_series_14_4h": [72.327, 70.1], + }, + } + ] + account_info = { + "total_return_percent": 12.3, + "available_cash": 1000.0, + "account_value": 12345.6, + "positions": [ + { + "symbol": "BTC", + "quantity": 0.5, + "entry_price": 100000.0, + "current_price": 114775.5, + "unrealized_pnl": 7387.75, + "leverage": 3, + "exit_plan": { + "profit_target": 120000.0, + "stop_loss": 95000.0, + "invalidation_condition": "Breaks below 90k", + }, + "confidence": "high", + "risk_usd": 500.0, + "notional_usd": 57387.75, + } + ], + } + performance = {"sharpe_ratio": 1.23} + + prompt = build_trading_prompt(system_state, market_state, account_info, performance_metrics=performance) + + assert "🧭 SYSTEM STATE" in prompt + assert "SYMBOL: BTC" in prompt + assert "current_price = 114775.5" in prompt + assert "Positions:" in prompt + assert "Sharpe Ratio: 1.23" in prompt diff --git a/tests/test_connectors.py b/tests/test_connectors.py index b3089a2..4152075 100644 --- a/tests/test_connectors.py +++ b/tests/test_connectors.py @@ -11,6 +11,7 @@ retry_and_rate_limit, ) from connectors.keepa_client import KeepaClient +from connectors.taapi_client import TaapiClient from tests.mock_keepa_server import sample_product_payload TEST_SETTINGS = { @@ -26,6 +27,7 @@ "paapi": {"qpm": 1000, "burst": 1000, "timeout_sec": 1}, "helium10": {"qpm": 1000, "burst": 1000, "timeout_sec": 1}, "junglescout": {"qpm": 1000, "burst": 1000, "timeout_sec": 1}, + "taapi": {"qpm": 1000, "burst": 1000, "timeout_sec": 1}, }, } @@ -122,3 +124,72 @@ def test_get_products_handles_dead_letter(monkeypatch): products = fake.get_products(["B0001", "B0002"], "US") assert len(products) == 1 assert fake.dead_letter == [{"asin": "B0002", "site": "US", "reason": "keepa_fetch_failed"}] + + +def test_taapi_market_state_normalisation(monkeypatch): + client = TaapiClient("secret", settings=TEST_SETTINGS) + + payload = { + "data": [ + {"id": "current_price", "result": {"value": 114775.5}}, + {"id": "current_ema20", "result": {"value": 114457.686}}, + {"id": "current_macd", "result": {"valueMACD": 53.823}}, + {"id": "current_rsi_7", "result": {"value": 86.769}}, + {"id": "current_rsi_14", "result": {"value": 74.835}}, + { + "id": "short_mid_prices", + "result": {"values": [{"value": 114339.5}, {"value": 114500.0}]}, + }, + { + "id": "short_ema20_series", + "result": {"values": [{"value": 114329.3}, {"value": 114480.5}]}, + }, + { + "id": "short_macd_series", + "result": {"values": [{"valueMACD": -60.9}, {"valueMACD": -12.1}]}, + }, + { + "id": "short_rsi7_series", + "result": {"values": [{"value": 61.9}, {"value": 65.2}]}, + }, + { + "id": "short_rsi14_series", + "result": {"values": [{"value": 51.1}, {"value": 54.0}]}, + }, + {"id": "long_ema20", "result": {"value": 113453.764}}, + {"id": "long_ema50", "result": {"value": 112031.947}}, + {"id": "long_atr3", "result": {"value": 319.999}}, + {"id": "long_atr14", "result": {"value": 523.419}}, + {"id": "long_volume_current", "result": {"value": 8.146}}, + {"id": "long_volume_avg", "result": {"value": 4618.058}}, + { + "id": "long_macd_series", + "result": {"values": [{"valueMACD": 1082.2}, {"valueMACD": 950.1}]}, + }, + { + "id": "long_rsi14_series", + "result": {"values": [{"value": 72.327}, {"value": 70.1}]}, + }, + {"id": "open_interest_latest", "result": {"value": 1234.0}}, + {"id": "open_interest_avg", "result": {"value": 987.0}}, + {"id": "funding_rate", "result": {"value": 0.0005}}, + ] + } + + captured: dict[str, Any] = {} + + def fake_request(self, construct): # type: ignore[override] + captured["construct"] = construct + return payload + + monkeypatch.setattr(TaapiClient, "_request_bulk", fake_request) + + state = client.get_market_state("BTC", "binance") + + assert state["symbol"] == "BTC" + assert state["current_macd"] == 53.823 + assert state["short_term"]["mid_prices"] == [114339.5, 114500.0] + assert state["long_term"]["macd_series_4h"] == [1082.2, 950.1] + ids = {entry["id"] for entry in captured["construct"]} + assert "current_price" in ids + assert "long_rsi14_series" in ids diff --git a/utils/config.py b/utils/config.py index ef63a58..8838831 100644 --- a/utils/config.py +++ b/utils/config.py @@ -19,6 +19,7 @@ }, "rate_limits": { "unified_api": {"qpm": 200, "burst": 200, "timeout_sec": 10}, + "taapi": {"qpm": 120, "burst": 120, "timeout_sec": 10}, }, "features": {"rolling": [7, 14, 30]}, "scoring": {