diff --git a/CHANGES.rst b/CHANGES.rst index f1842b0821..2553e4f810 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -52,6 +52,10 @@ mast - Raise informative error if ``MastMissions`` query radius is too large. [#3447] +- Add ``batch_size`` parameter to ``MastMissions.get_product_list``, ``Observations.get_product_list``, + and ``utils.resolve_object`` to allow controlling the number of items sent in each batch request to the server. + This can help avoid timeouts or connection errors for large requests. [#3454] + jplspec ^^^^^^^ diff --git a/astroquery/mast/missions.py b/astroquery/mast/missions.py index 38af7370a7..7a473accba 100644 --- a/astroquery/mast/missions.py +++ b/astroquery/mast/missions.py @@ -97,22 +97,14 @@ def _extract_products(self, response): list A list of products extracted from the response. """ - def normalize_products(products): - """ - Normalize the products list to ensure it is flat and not nested. - """ + combined = [] + for resp in response: + products = resp.json().get('products', []) + # Flatten if nested if products and isinstance(products[0], list): - return products[0] - return products - - if isinstance(response, list): # multiple async responses from batching - combined = [] - for resp in response: - products = normalize_products(resp.json().get('products', [])) - combined.extend(products) - return combined - else: # single response - return normalize_products(response.json().get('products', [])) + products = products[0] + combined.extend(products) + return combined def _parse_result(self, response, *, verbose=False): # Used by the async_to_sync decorator functionality """ @@ -417,7 +409,7 @@ def query_object_async(self, objectname, *, radius=3*u.arcmin, limit=5000, offse select_cols=select_cols, **criteria) @class_or_instance - def get_product_list_async(self, datasets): + def get_product_list_async(self, datasets, *, batch_size=1000): """ Given a dataset ID or list of dataset IDs, returns a list of associated data products. 
@@ -428,6 +420,9 @@ def get_product_list_async(self, datasets): datasets : str, list, `~astropy.table.Row`, `~astropy.table.Column`, `~astropy.table.Table` Row/Table of MastMissions query results (e.g. output from `query_object`) or single/list of dataset ID(s). + batch_size : int, optional + Default 1000. Number of dataset IDs to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. Returns ------- @@ -439,8 +434,8 @@ def get_product_list_async(self, datasets): if isinstance(datasets, Table) or isinstance(datasets, Row): dataset_kwd = self.get_dataset_kwd() if not dataset_kwd: - log.warning('Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.') - return None + raise InvalidQueryError(f'Dataset keyword not found for mission "{self.mission}". Please input ' + 'dataset IDs as a string, list of strings, or `~astropy.table.Column`.') # Extract dataset IDs based on input type and mission if isinstance(datasets, Table): @@ -466,17 +461,17 @@ def get_product_list_async(self, datasets): results = utils._batched_request( datasets, params={}, - max_batch=1000, + max_batch=batch_size, param_key="dataset_ids", request_func=lambda p: self._service_api_connection.missions_request_async(self.service, p), extract_func=lambda r: [r], # missions_request_async already returns one result desc=f"Fetching products for {len(datasets)} unique datasets" ) - # Return a list of responses only if multiple requests were made - return results[0] if len(results) == 1 else results + # Return a list of responses + return results - def get_unique_product_list(self, datasets): + def get_unique_product_list(self, datasets, *, batch_size=1000): """ Given a dataset ID or list of dataset IDs, returns a list of associated data products with unique filenames. 
@@ -486,13 +481,16 @@ def get_unique_product_list(self, datasets): datasets : str, list, `~astropy.table.Row`, `~astropy.table.Column`, `~astropy.table.Table` Row/Table of MastMissions query results (e.g. output from `query_object`) or single/list of dataset ID(s). + batch_size : int, optional + Default 1000. Number of dataset IDs to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. Returns ------- unique_products : `~astropy.table.Table` Table containing products with unique URIs. """ - products = self.get_product_list(datasets) + products = self.get_product_list(datasets, batch_size=batch_size) unique_products = utils.remove_duplicate_products(products, 'filename') if len(unique_products) < len(products): log.info("To return all products, use `MastMissions.get_product_list`") diff --git a/astroquery/mast/observations.py b/astroquery/mast/observations.py index 2a0d56b179..b6d701099a 100644 --- a/astroquery/mast/observations.py +++ b/astroquery/mast/observations.py @@ -504,7 +504,7 @@ def _filter_ffi_observations(self, observations): return obs_table[mask] @class_or_instance - def get_product_list_async(self, observations): + def get_product_list_async(self, observations, *, batch_size=500): """ Given a "Product Group Id" (column name obsid) returns a list of associated data products. Note that obsid is NOT the same as obs_id, and inputting obs_id values will result in @@ -518,31 +518,50 @@ def get_product_list_async(self, observations): Row/Table of MAST query results (e.g. output from `query_object`) or single/list of MAST Product Group Id(s) (obsid). See description `here `__. + batch_size : int, optional + Default 500. Number of obsids to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. Returns ------- response : list of `~requests.Response` + A list of asynchronous response objects for each batch request. 
""" - - # getting the obsid list + # Getting the obsids as a list if np.isscalar(observations): - observations = np.array([observations]) - if isinstance(observations, Table) or isinstance(observations, Row): + observations = [observations] + elif isinstance(observations, (Row, Table)): # Filter out TESS FFIs and TICA FFIs # Can only perform filtering on Row or Table because of access to `target_name` field observations = self._filter_ffi_observations(observations) - observations = observations['obsid'] - if isinstance(observations, list): - observations = np.array(observations) - - observations = observations[observations != ""] - if observations.size == 0: - raise InvalidQueryError("Observation list is empty, no associated products.") - - service = self._caom_products - params = {'obsid': ','.join(observations)} - - return self._portal_api_connection.service_request_async(service, params) + observations = observations['obsid'].tolist() + + # Clean and validate + observations = [str(obs).strip() for obs in observations] + observations = [obs for obs in observations if obs] + if not observations: + raise InvalidQueryError('Observation list is empty, no associated products.') + + # Define a helper to join obsids for each batch request + def _request_joined_obsid(params): + """Join batched obsid list into comma-separated string and send async request.""" + pp = dict(params) + vals = pp.get('obsid', []) + pp['obsid'] = ','.join(map(str, vals)) + return self._portal_api_connection.service_request_async(self._caom_products, pp)[0] + + # Perform batched requests + results = utils._batched_request( + items=observations, + params={}, + max_batch=batch_size, + param_key='obsid', + request_func=_request_joined_obsid, + extract_func=lambda r: [r], + desc=f'Fetching products for {len(observations)} unique observations' + ) + + return results def filter_products(self, products, *, mrp_only=False, extension=None, **filters): """ @@ -1029,7 +1048,7 @@ def get_cloud_uri(self, 
data_product, *, include_bucket=True, full_url=False): # Query for product URIs return self._cloud_connection.get_cloud_uri(data_product, include_bucket, full_url) - def get_unique_product_list(self, observations): + def get_unique_product_list(self, observations, *, batch_size=500): """ Given a "Product Group Id" (column name obsid), returns a list of associated data products with unique dataURIs. Note that obsid is NOT the same as obs_id, and inputting obs_id values will result in @@ -1041,13 +1060,16 @@ def get_unique_product_list(self, observations): Row/Table of MAST query results (e.g. output from `query_object`) or single/list of MAST Product Group Id(s) (obsid). See description `here `__. + batch_size : int, optional + Default 500. Number of obsids to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. Returns ------- unique_products : `~astropy.table.Table` Table containing products with unique dataURIs. """ - products = self.get_product_list(observations) + products = self.get_product_list(observations, batch_size=batch_size) unique_products = utils.remove_duplicate_products(products, 'dataURI') if len(unique_products) < len(products): log.info("To return all products, use `Observations.get_product_list`") diff --git a/astroquery/mast/tests/test_mast.py b/astroquery/mast/tests/test_mast.py index 7929afb378..7f9790b9af 100644 --- a/astroquery/mast/tests/test_mast.py +++ b/astroquery/mast/tests/test_mast.py @@ -305,21 +305,21 @@ def test_missions_query_criteria(patch_post): def test_missions_get_product_list_async(patch_post): # String input result = mast.MastMissions.get_product_list_async('Z14Z0104T') - assert isinstance(result, MockResponse) + assert isinstance(result, list) # List input in_datasets = ['Z14Z0104T', 'Z14Z0102T'] result = mast.MastMissions.get_product_list_async(in_datasets) - assert isinstance(result, MockResponse) + assert isinstance(result, list) # Row input 
datasets = mast.MastMissions.query_object("M101", radius=".002 deg") result = mast.MastMissions.get_product_list_async(datasets[:3]) - assert isinstance(result, MockResponse) + assert isinstance(result, list) # Table input result = mast.MastMissions.get_product_list_async(datasets[0]) - assert isinstance(result, MockResponse) + assert isinstance(result, list) # Unsupported data type for datasets with pytest.raises(TypeError) as err_type: @@ -331,6 +331,11 @@ def test_missions_get_product_list_async(patch_post): mast.MastMissions.get_product_list_async([' ']) assert 'Dataset list is empty' in str(err_empty.value) + # No dataset keyword + with pytest.raises(InvalidQueryError, match='Dataset keyword not found for mission "invalid"'): + missions = mast.MastMissions(mission='invalid') + missions.get_product_list_async(Table({'a': [1, 2, 3]})) + def test_missions_get_product_list(patch_post): # String input @@ -825,6 +830,10 @@ def test_observations_get_product_list(patch_post): result = mast.Observations.get_product_list(in_obsids) assert isinstance(result, Table) + # Error if no valid obsids are found + with pytest.raises(InvalidQueryError, match='Observation list is empty'): + mast.Observations.get_product_list([' ']) + def test_observations_filter_products(patch_post): products = mast.Observations.get_product_list('2003738726') diff --git a/astroquery/mast/tests/test_mast_remote.py b/astroquery/mast/tests/test_mast_remote.py index 0d4be9e961..18195502d5 100644 --- a/astroquery/mast/tests/test_mast_remote.py +++ b/astroquery/mast/tests/test_mast_remote.py @@ -199,19 +199,25 @@ def test_missions_get_product_list_async(self): # Table as input responses = MastMissions.get_product_list_async(datasets[:3]) - assert isinstance(responses, Response) + assert isinstance(responses, list) # Row as input responses = MastMissions.get_product_list_async(datasets[0]) - assert isinstance(responses, Response) + assert isinstance(responses, list) # String as input responses = 
MastMissions.get_product_list_async(datasets[0]['sci_data_set_name']) - assert isinstance(responses, Response) + assert isinstance(responses, list) # Column as input responses = MastMissions.get_product_list_async(datasets[:3]['sci_data_set_name']) - assert isinstance(responses, Response) + assert isinstance(responses, list) + + # Batching + responses = MastMissions.get_product_list_async(datasets[:4], batch_size=2) + assert isinstance(responses, list) + assert len(responses) == 2 + assert isinstance(responses[0], Response) # Unsupported data type for datasets with pytest.raises(TypeError) as err_type: @@ -248,14 +254,13 @@ def test_missions_get_product_list(self, capsys): assert isinstance(result, Table) assert (result['dataset'] == 'IBKH03020').all() - # Test batching by creating a list of 1001 different strings - # This won't return any results, but will test the batching - dataset_list = [f'{i}' for i in range(1001)] - result = MastMissions.get_product_list(dataset_list) + # Test batching + result_batch = MastMissions.get_product_list(datasets[:2], batch_size=1) out, _ = capsys.readouterr() - assert isinstance(result, Table) - assert len(result) == 0 - assert 'Fetching products for 1001 unique datasets in 2 batches' in out + assert isinstance(result_batch, Table) + assert len(result_batch) == len(result_table) + assert set(result_batch['filename']) == set(result_table['filename']) + assert 'Fetching products for 2 unique datasets in 2 batches' in out def test_missions_get_unique_product_list(self, caplog): # Check that no rows are filtered out when all products are unique @@ -593,7 +598,11 @@ def test_observations_get_product_list_async(self): responses = Observations.get_product_list_async(observations[0:4]) assert isinstance(responses, list) - def test_observations_get_product_list(self): + # Batching + responses = Observations.get_product_list_async(observations[0:4], batch_size=2) + assert isinstance(responses, list) + + def 
test_observations_get_product_list(self, capsys): observations = Observations.query_criteria(objectname='M8', obs_collection=['K2', 'IUE']) test_obs_id = str(observations[0]['obsid']) mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid']) @@ -626,6 +635,14 @@ def test_observations_get_product_list(self): assert len(obs_collection) == 1 assert obs_collection[0] == 'IUE' + # Test batching + result_batch = Observations.get_product_list(observations[:2], batch_size=1) + out, _ = capsys.readouterr() + assert isinstance(result_batch, Table) + assert len(result_batch) == len(result1) + assert set(result_batch['productFilename']) == set(filenames1) + assert 'Fetching products for 2 unique observations in 2 batches' in out + def test_observations_get_product_list_tess_tica(self, caplog): # Get observations and products with both TESS and TICA FFIs obs = Observations.query_criteria(target_name=['TESS FFI', 'TICA FFI', '429031146']) diff --git a/astroquery/mast/utils.py b/astroquery/mast/utils.py index 0b4d666192..aba37af17f 100644 --- a/astroquery/mast/utils.py +++ b/astroquery/mast/utils.py @@ -148,7 +148,7 @@ def _batched_request( return extract_func(resp) -def resolve_object(objectname, *, resolver=None, resolve_all=False): +def resolve_object(objectname, *, resolver=None, resolve_all=False, batch_size=30): """ Resolves one or more object names to a position on the sky. @@ -164,6 +164,9 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False): resolve_all : bool, optional If True, will try to resolve the object name using all available resolvers ("NED", "SIMBAD"). Default is False. + batch_size : int, optional + Default 30. Number of object names to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. 
Returns ------- @@ -230,7 +233,7 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False): results = _batched_request( object_names, params, - max_batch=30, + max_batch=batch_size, param_key="name", request_func=lambda p: _simple_request("http://mastresolver.stsci.edu/Santa-war/query", p), extract_func=lambda r: r.json().get("resolvedCoordinate") or [], diff --git a/docs/mast/mast_missions.rst b/docs/mast/mast_missions.rst index 147aa6dd0d..d1ebfdab35 100644 --- a/docs/mast/mast_missions.rst +++ b/docs/mast/mast_missions.rst @@ -203,11 +203,16 @@ Each observation returned from a MAST query can have one or more associated data one or more datasets or dataset IDs, the `~astroquery.mast.MastMissionsClass.get_product_list` function will return a `~astropy.table.Table` containing the associated data products. +`~astroquery.mast.MastMissionsClass.get_product_list` also includes an optional ``batch_size`` parameter, +which controls how many datasets are sent to the MAST service per request. This can be useful for managing +memory usage or avoiding timeouts when requesting product lists for large numbers of datasets. +If not provided, ``batch_size`` defaults to 1000. + .. doctest-remote-data:: >>> datasets = missions.query_criteria(sci_pep_id=12451, ... sci_instrume='ACS', ... sci_hlsp='>1') - >>> products = missions.get_product_list(datasets[:2]) + >>> products = missions.get_product_list(datasets[:2], batch_size=1000) >>> print(products[:5]) # doctest: +IGNORE_OUTPUT product_key access dataset ... category size type ---------------------------- ------ --------- ... ---------- --------- ------- diff --git a/docs/mast/mast_obsquery.rst b/docs/mast/mast_obsquery.rst index 8cfa1acbf1..a7d928cfc0 100644 --- a/docs/mast/mast_obsquery.rst +++ b/docs/mast/mast_obsquery.rst @@ -214,17 +214,22 @@ Getting Product Lists --------------------- Each observation returned from a MAST query can have one or more associated data products.
-Given one or more observations or MAST Product Group IDs ("obsid") +Given one or more observations or MAST Product Group IDs ("obsid"), `~astroquery.mast.ObservationsClass.get_product_list` will return a `~astropy.table.Table` containing the associated data products. The product fields are documented `here `__. +`~astroquery.mast.ObservationsClass.get_product_list` also includes an optional ``batch_size`` parameter, +which controls how many observations are sent to the MAST service per request. This can be useful for managing +memory usage or avoiding timeouts when requesting product lists for large numbers of observations. +If not provided, ``batch_size`` defaults to 500. + .. doctest-remote-data:: >>> from astroquery.mast import Observations ... >>> obs_table = Observations.query_criteria(objectname="M8", obs_collection=["K2", "IUE"]) - >>> data_products_by_obs = Observations.get_product_list(obs_table[0:2]) + >>> data_products_by_obs = Observations.get_product_list(obs_table[0:2], batch_size=500) >>> print(data_products_by_obs) # doctest: +IGNORE_OUTPUT obsID obs_collection dataproduct_type ... dataRights calib_level filters ------ -------------- ---------------- ... ---------- ----------- -------