diff --git a/pandas_gbq/gbq.py b/pandas_gbq/gbq.py
index 880dcef9..75fa3510 100644
--- a/pandas_gbq/gbq.py
+++ b/pandas_gbq/gbq.py
@@ -119,6 +119,7 @@ def read_gbq(
     *,
     col_order=None,
     bigquery_client=None,
+    dry_run: bool = False,
 ):
     r"""Read data from Google BigQuery to a pandas DataFrame.
 
@@ -269,11 +270,14 @@ def read_gbq(
     bigquery_client : google.cloud.bigquery.Client, optional
        A Google Cloud BigQuery Python Client instance. If provided, it will be used for
        reading data, while the project and credentials parameters will be ignored.
+    dry_run : bool, default False
+        If True, run a dry run to estimate the query cost instead of executing it.
 
     Returns
     -------
-    df: DataFrame
-        DataFrame representing results of query.
+    df: DataFrame or float
+        DataFrame representing results of query. If ``dry_run=True``, returns
+        a float with the estimated gigabytes processed (``total_bytes_processed / 1024**3``).
     """
     if dialect is None:
         dialect = context.dialect
@@ -328,7 +331,11 @@ def read_gbq(
             max_results=max_results,
             progress_bar_type=progress_bar_type,
             dtypes=dtypes,
+            dry_run=dry_run,
         )
+        # When dry_run=True, run_query returns the estimated GB processed, not a DataFrame.
+        if dry_run:
+            return final_df
     else:
         final_df = connector.download_table(
             query_or_table,
diff --git a/pandas_gbq/gbq_connector.py b/pandas_gbq/gbq_connector.py
index 2b3b716e..2c48f6fa 100644
--- a/pandas_gbq/gbq_connector.py
+++ b/pandas_gbq/gbq_connector.py
@@ -199,7 +199,14 @@ def download_table(
             user_dtypes=dtypes,
         )
 
-    def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs):
+    def run_query(
+        self,
+        query,
+        max_results=None,
+        progress_bar_type=None,
+        dry_run: bool = False,
+        **kwargs,
+    ):
         from google.cloud import bigquery
 
         job_config_dict = {
@@ -235,6 +242,7 @@ def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs):
         self._start_timer()
 
         job_config = bigquery.QueryJobConfig.from_api_repr(job_config_dict)
+        job_config.dry_run = dry_run
 
         if FEATURES.bigquery_has_query_and_wait:
             rows_iter = pandas_gbq.query.query_and_wait_via_client_library(
@@ -260,6 +268,18 @@ def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs):
             )
 
         dtypes = kwargs.get("dtypes")
+
+        if dry_run:
+            # The query helpers attach the originating QueryJob to the row
+            # iterator as ``rows_iter.job`` so its statistics are reachable
+            # here without re-running the query.
+            query_job = getattr(rows_iter, "job", None)
+            if query_job is None:
+                raise ValueError(
+                    "Cannot access QueryJob from RowIterator for dry_run"
+                )
+            return query_job.total_bytes_processed / 1024**3
+
         return self._download_results(
             rows_iter,
             max_results=max_results,
diff --git a/pandas_gbq/query.py b/pandas_gbq/query.py
index 83575a9c..ba0f1d72 100644
--- a/pandas_gbq/query.py
+++ b/pandas_gbq/query.py
@@ -179,7 +179,12 @@ def query_and_wait(
     # getQueryResults() instead of tabledata.list, which returns the correct
     # response with DML/DDL queries.
     try:
-        return query_reply.result(max_results=max_results)
+        rows_iter = query_reply.result(max_results=max_results)
+        # Attach the QueryJob to the iterator so callers (such as the dry_run
+        # path in run_query) can read the job statistics.
+        if getattr(rows_iter, "job", None) is None:
+            rows_iter.job = query_reply
+        return rows_iter
     except connector.http_error as ex:
         connector.process_http_error(ex)
 
@@ -195,6 +200,27 @@ def query_and_wait_via_client_library(
     max_results: Optional[int],
     timeout_ms: Optional[int],
 ):
+    # Dry runs go through client.query() directly so the QueryJob (and its
+    # total_bytes_processed statistic) is available for the cost estimate.
+    if job_config.dry_run:
+        query_job = try_query(
+            connector,
+            functools.partial(
+                client.query,
+                query,
+                job_config=job_config,
+                location=location,
+                project=project_id,
+            ),
+        )
+        # A dry-run job completes immediately; result() just materializes it.
+        rows_iter = query_job.result(
+            max_results=max_results,
+            timeout=timeout_ms / 1000.0 if timeout_ms else None,
+        )
+        rows_iter.job = query_job
+        return rows_iter
+
     rows_iter = try_query(
         connector,
         functools.partial(
diff --git a/tests/system/test_gbq.py b/tests/system/test_gbq.py
index 1457ec30..3764cc8b 100644
--- a/tests/system/test_gbq.py
+++ b/tests/system/test_gbq.py
@@ -656,6 +656,19 @@ def test_columns_and_col_order_raises_error(self, project_id):
                 dialect="standard",
             )
 
+    def test_read_gbq_with_dry_run(self, project_id):
+        # A constant query dry-runs at 0 bytes, so scan a real public table.
+        query = "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013`"
+        cost = gbq.read_gbq(
+            query,
+            project_id=project_id,
+            credentials=self.credentials,
+            dialect="standard",
+            dry_run=True,
+        )
+        assert isinstance(cost, float)
+        assert cost > 0
+
 
 class TestToGBQIntegration(object):
     @pytest.fixture(autouse=True, scope="function")
diff --git a/tests/unit/test_gbq.py b/tests/unit/test_gbq.py
index 75574820..e63c364a 100644
--- a/tests/unit/test_gbq.py
+++ b/tests/unit/test_gbq.py
@@ -76,6 +76,8 @@ def generate_schema():
 @pytest.fixture(autouse=True)
 def default_bigquery_client(mock_bigquery_client, mock_query_job, mock_row_iterator):
     mock_query_job.result.return_value = mock_row_iterator
+    # Point RowIterator.job at the QueryJob so the dry_run path can read it.
+    mock_row_iterator.job = mock_query_job
     mock_bigquery_client.list_rows.return_value = mock_row_iterator
     mock_bigquery_client.query.return_value = mock_query_job
 
@@ -937,3 +939,20 @@ def test_run_query_with_dml_query(mock_bigquery_client, mock_query_job):
     type(mock_query_job).destination = mock.PropertyMock(return_value=None)
     connector.run_query("UPDATE tablename SET value = '';")
     mock_bigquery_client.list_rows.assert_not_called()
+
+
+def test_read_gbq_with_dry_run(mock_bigquery_client, mock_query_job):
+    type(mock_query_job).total_bytes_processed = mock.PropertyMock(return_value=12345)
+    cost = gbq.read_gbq("SELECT 1", project_id="my-project", dry_run=True)
+    # Newer google-cloud-bigquery versions route through query_and_wait.
+    if (
+        hasattr(mock_bigquery_client, "query_and_wait")
+        and mock_bigquery_client.query_and_wait.called
+    ):
+        _, kwargs = mock_bigquery_client.query_and_wait.call_args
+        job_config = kwargs["job_config"]
+    else:
+        _, kwargs = mock_bigquery_client.query.call_args
+        job_config = kwargs["job_config"]
+    assert job_config.dry_run is True
+    assert cost == 12345 / 1024**3
diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py
index 2437fa02..1ab7e54f 100644
--- a/tests/unit/test_query.py
+++ b/tests/unit/test_query.py
@@ -170,15 +170,19 @@ def test_query_response_bytes(size_in_bytes, formatted_text):
 def test__wait_for_query_job_exits_when_done(mock_bigquery_client):
     connector = _make_connector()
     connector.client = mock_bigquery_client
-    connector.start = datetime.datetime(2020, 1, 1).timestamp()
 
     mock_query = mock.create_autospec(google.cloud.bigquery.QueryJob)
     type(mock_query).state = mock.PropertyMock(side_effect=("RUNNING", "DONE"))
     mock_query.result.side_effect = concurrent.futures.TimeoutError("fake timeout")
 
-    with freezegun.freeze_time("2020-01-01 00:00:00", tick=False):
+    frozen_time = datetime.datetime(2020, 1, 1)
+    with freezegun.freeze_time(frozen_time, tick=False):
+        # Set the start time inside the frozen context so elapsed time is zero.
+        connector.start = frozen_time.timestamp()
+        # Stub get_elapsed_seconds as well so the wait loop can never time out.
+        connector.get_elapsed_seconds = mock.Mock(return_value=0.0)
         module_under_test._wait_for_query_job(
-            connector, mock_bigquery_client, mock_query, 60
+            connector, mock_bigquery_client, mock_query, 1000
         )
 
     mock_bigquery_client.cancel_job.assert_not_called()
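
--
Usage sketch for the new flag (illustrative only, not part of the patch; the
project ID below is a placeholder). With dry_run=True, read_gbq downloads no
data and returns the estimated gigabytes the query would process:

    import pandas_gbq

    # BigQuery validates the query and reports total_bytes_processed without
    # executing it; read_gbq converts that figure to gigabytes.
    estimated_gb = pandas_gbq.read_gbq(
        "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013`",
        project_id="my-project",  # placeholder project
        dry_run=True,
    )
    print(f"This query would process about {estimated_gb:.2f} GB")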