Skip to content

Commit 9e22cc4

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Bigquery detect_anomalies tool results sort by timestamp for better visualization
Timestamp need to be ordered so that for better display and further visualization. PiperOrigin-RevId: 829548481
1 parent 51dee43 commit 9e22cc4

File tree

2 files changed

+68
-9
lines changed

2 files changed

+68
-9
lines changed

src/google/adk/tools/bigquery/query_tool.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,7 +1136,7 @@ def detect_anomalies(
11361136
history_data: str,
11371137
times_series_timestamp_col: str,
11381138
times_series_data_col: str,
1139-
horizon: Optional[int] = 10,
1139+
horizon: Optional[int] = 1000,
11401140
target_data: Optional[str] = None,
11411141
times_series_id_cols: Optional[list[str]] = None,
11421142
anomaly_prob_threshold: Optional[float] = 0.95,
@@ -1158,7 +1158,7 @@ def detect_anomalies(
11581158
times_series_data_col (str): The name of the column containing the
11591159
numerical values to be forecasted and anomaly detected.
11601160
horizon (int, optional): The number of time steps to forecast into the
1161-
future. Defaults to 10.
1161+
future. Defaults to 1000.
11621162
target_data (str, optional): The table id of the BigQuery table containing
11631163
the target time series data or a query statement that select the target
11641164
data.
@@ -1301,9 +1301,14 @@ def detect_anomalies(
13011301
OPTIONS ({options_str})
13021302
AS {history_data_source}
13031303
"""
1304+
order_by_id_cols = (
1305+
", ".join(col for col in times_series_id_cols) + ", "
1306+
if times_series_id_cols
1307+
else ""
1308+
)
13041309

13051310
anomaly_detection_query = f"""
1306-
SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold))
1311+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold)) ORDER BY {order_by_id_cols}{times_series_timestamp_col}
13071312
"""
13081313
if target_data:
13091314
trimmed_upper_target_data = target_data.strip().upper()
@@ -1312,10 +1317,10 @@ def detect_anomalies(
13121317
) or trimmed_upper_target_data.startswith("WITH"):
13131318
target_data_source = f"({target_data})"
13141319
else:
1315-
target_data_source = f"SELECT * FROM `{target_data}`"
1320+
target_data_source = f"(SELECT * FROM `{target_data}`)"
13161321

13171322
anomaly_detection_query = f"""
1318-
SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold), {target_data_source})
1323+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL {model_name}, STRUCT({anomaly_prob_threshold} AS anomaly_prob_threshold), {target_data_source}) ORDER BY {order_by_id_cols}{times_series_timestamp_col}
13191324
"""
13201325

13211326
# Create a session and run the create model query.

tests/unittests/tools/bigquery/test_bigquery_query_tool.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,12 +1436,12 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql):
14361436

14371437
expected_create_model_query = """
14381438
CREATE TEMP MODEL detect_anomalies_model_test_uuid
1439-
OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 10)
1439+
OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 1000)
14401440
AS (SELECT * FROM `test-dataset.test-table`)
14411441
"""
14421442

14431443
expected_anomaly_detection_query = """
1444-
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.95 AS anomaly_prob_threshold))
1444+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.95 AS anomaly_prob_threshold)) ORDER BY ts_timestamp
14451445
"""
14461446

14471447
assert mock_execute_sql.call_count == 2
@@ -1497,7 +1497,7 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
14971497
"""
14981498

14991499
expected_anomaly_detection_query = """
1500-
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.8 AS anomaly_prob_threshold))
1500+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.8 AS anomaly_prob_threshold)) ORDER BY dim1, dim2, ts_timestamp
15011501
"""
15021502

15031503
assert mock_execute_sql.call_count == 2
@@ -1555,7 +1555,61 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
15551555
"""
15561556

15571557
expected_anomaly_detection_query = """
1558-
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.8 AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`))
1558+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.8 AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`)) ORDER BY dim1, dim2, ts_timestamp
1559+
"""
1560+
1561+
assert mock_execute_sql.call_count == 2
1562+
mock_execute_sql.assert_any_call(
1563+
project_id="test-project",
1564+
query=expected_create_model_query,
1565+
credentials=mock_credentials,
1566+
settings=mock_settings,
1567+
tool_context=mock_tool_context,
1568+
caller_id="detect_anomalies",
1569+
)
1570+
mock_execute_sql.assert_any_call(
1571+
project_id="test-project",
1572+
query=expected_anomaly_detection_query,
1573+
credentials=mock_credentials,
1574+
settings=mock_settings,
1575+
tool_context=mock_tool_context,
1576+
caller_id="detect_anomalies",
1577+
)
1578+
1579+
1580+
# detect_anomalies calls execute_sql twice. We need to test that
1581+
# the queries are properly constructed and call execute_sql with the correct
1582+
# parameters exactly twice.
1583+
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
1584+
@mock.patch("uuid.uuid4", autospec=True)
1585+
def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql):
1586+
"""Test time series anomaly detection tool invocation with a table id."""
1587+
mock_credentials = mock.MagicMock(spec=Credentials)
1588+
mock_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
1589+
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
1590+
mock_uuid.return_value = "test_uuid"
1591+
mock_execute_sql.return_value = {"status": "SUCCESS"}
1592+
1593+
history_data_query = "SELECT * FROM `test-dataset.test-table`"
1594+
detect_anomalies(
1595+
project_id="test-project",
1596+
history_data=history_data_query,
1597+
times_series_timestamp_col="ts_timestamp",
1598+
times_series_data_col="ts_data",
1599+
target_data="test-dataset.target-table",
1600+
credentials=mock_credentials,
1601+
settings=mock_settings,
1602+
tool_context=mock_tool_context,
1603+
)
1604+
1605+
expected_create_model_query = """
1606+
CREATE TEMP MODEL detect_anomalies_model_test_uuid
1607+
OPTIONS (MODEL_TYPE = 'ARIMA_PLUS', TIME_SERIES_TIMESTAMP_COL = 'ts_timestamp', TIME_SERIES_DATA_COL = 'ts_data', HORIZON = 1000)
1608+
AS (SELECT * FROM `test-dataset.test-table`)
1609+
"""
1610+
1611+
expected_anomaly_detection_query = """
1612+
SELECT * FROM ML.DETECT_ANOMALIES(MODEL detect_anomalies_model_test_uuid, STRUCT(0.95 AS anomaly_prob_threshold), (SELECT * FROM `test-dataset.target-table`)) ORDER BY ts_timestamp
15591613
"""
15601614

15611615
assert mock_execute_sql.call_count == 2

0 commit comments

Comments
 (0)