Skip to content

Commit a78ae5a

Browse files
committed
Implement batch_size parameter for Observations.get_product_list
1 parent 9b2f3f1 commit a78ae5a

File tree

5 files changed

+93
-55
lines changed

5 files changed

+93
-55
lines changed

astroquery/mast/missions.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,14 @@ def _extract_products(self, response):
9797
list
9898
A list of products extracted from the response.
9999
"""
100-
def normalize_products(products):
101-
"""
102-
Normalize the products list to ensure it is flat and not nested.
103-
"""
100+
combined = []
101+
for resp in response:
102+
products = resp.json().get('products', [])
103+
# Flatten if nested
104104
if products and isinstance(products[0], list):
105-
return products[0]
106-
return products
107-
108-
if isinstance(response, list): # multiple async responses from batching
109-
combined = []
110-
for resp in response:
111-
products = normalize_products(resp.json().get('products', []))
112-
combined.extend(products)
113-
return combined
114-
else: # single response
115-
return normalize_products(response.json().get('products', []))
105+
products = products[0]
106+
combined.extend(products)
107+
return combined
116108

117109
def _parse_result(self, response, *, verbose=False): # Used by the async_to_sync decorator functionality
118110
"""
@@ -417,7 +409,7 @@ def query_object_async(self, objectname, *, radius=3*u.arcmin, limit=5000, offse
417409
select_cols=select_cols, **criteria)
418410

419411
@class_or_instance
420-
def get_product_list_async(self, datasets):
412+
def get_product_list_async(self, datasets, batch_size=1000):
421413
"""
422414
Given a dataset ID or list of dataset IDs, returns a list of associated data products.
423415
@@ -428,6 +420,9 @@ def get_product_list_async(self, datasets):
428420
datasets : str, list, `~astropy.table.Row`, `~astropy.table.Column`, `~astropy.table.Table`
429421
Row/Table of MastMissions query results (e.g. output from `query_object`)
430422
or single/list of dataset ID(s).
423+
batch_size : int, optional
424+
Default 1000. Number of dataset IDs to include in each batch request to the server.
425+
If you experience timeouts or connection errors, consider lowering this value.
431426
432427
Returns
433428
-------
@@ -439,7 +434,8 @@ def get_product_list_async(self, datasets):
439434
if isinstance(datasets, Table) or isinstance(datasets, Row):
440435
dataset_kwd = self.get_dataset_kwd()
441436
if not dataset_kwd:
442-
log.warning('Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.')
437+
log.warning(f'Dataset keyword not found for mission {self.mission}. '
438+
'Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.')
443439
return None
444440

445441
# Extract dataset IDs based on input type and mission
@@ -466,15 +462,15 @@ def get_product_list_async(self, datasets):
466462
results = utils._batched_request(
467463
datasets,
468464
params={},
469-
max_batch=1000,
465+
max_batch=batch_size,
470466
param_key="dataset_ids",
471467
request_func=lambda p: self._service_api_connection.missions_request_async(self.service, p),
472468
extract_func=lambda r: [r], # missions_request_async already returns one result
473469
desc=f"Fetching products for {len(datasets)} unique datasets"
474470
)
475471

476-
# Return a list of responses only if multiple requests were made
477-
return results[0] if len(results) == 1 else results
472+
# Return a list of responses
473+
return results
478474

479475
def get_unique_product_list(self, datasets):
480476
"""

astroquery/mast/observations.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def _filter_ffi_observations(self, observations):
504504
return obs_table[mask]
505505

506506
@class_or_instance
507-
def get_product_list_async(self, observations):
507+
def get_product_list_async(self, observations, batch_size=500):
508508
"""
509509
Given a "Product Group Id" (column name obsid) returns a list of associated data products.
510510
Note that obsid is NOT the same as obs_id, and inputting obs_id values will result in
@@ -518,31 +518,49 @@ def get_product_list_async(self, observations):
518518
Row/Table of MAST query results (e.g. output from `query_object`)
519519
or single/list of MAST Product Group Id(s) (obsid).
520520
See description `here <https://masttest.stsci.edu/api/v0/_c_a_o_mfields.html>`__.
521+
batch_size : int, optional
522+
Default 500. Number of obsids to include in each batch request to the server.
523+
If you experience timeouts or connection errors, consider lowering this value.
521524
522525
Returns
523526
-------
524527
response : list of `~requests.Response`
528+
A list of asynchronous response objects for each batch request.
525529
"""
526-
527-
# getting the obsid list
530+
# Getting the obsids as a list
528531
if np.isscalar(observations):
529-
observations = np.array([observations])
530-
if isinstance(observations, Table) or isinstance(observations, Row):
532+
observations = [observations]
533+
elif isinstance(observations, (Row, Table)):
531534
# Filter out TESS FFIs and TICA FFIs
532535
# Can only perform filtering on Row or Table because of access to `target_name` field
533536
observations = self._filter_ffi_observations(observations)
534-
observations = observations['obsid']
535-
if isinstance(observations, list):
536-
observations = np.array(observations)
537-
538-
observations = observations[observations != ""]
539-
if observations.size == 0:
540-
raise InvalidQueryError("Observation list is empty, no associated products.")
541-
542-
service = self._caom_products
543-
params = {'obsid': ','.join(observations)}
544-
545-
return self._portal_api_connection.service_request_async(service, params)
537+
observations = observations['obsid'].tolist()
538+
539+
# Clean and validate
540+
observations = [str(obs).strip() for obs in observations if str(obs).strip()]
541+
if not observations:
542+
raise InvalidQueryError('Observation list is empty, no associated products.')
543+
544+
# Define a helper to join obsids for each batch request
545+
def _request_joined_obsid(params):
546+
"""Join batched obsid list into comma-separated string and send async request."""
547+
pp = dict(params)
548+
vals = pp.get('obsid', [])
549+
pp['obsid'] = ','.join(map(str, vals))
550+
return self._portal_api_connection.service_request_async(self._caom_products, pp)[0]
551+
552+
# Perform batched requests
553+
results = utils._batched_request(
554+
items=observations,
555+
params={},
556+
max_batch=batch_size,
557+
param_key='obsid',
558+
request_func=_request_joined_obsid,
559+
extract_func=lambda r: [r],
560+
desc=f'Fetching products for {len(observations)} unique observations'
561+
)
562+
563+
return results
546564

547565
def filter_products(self, products, *, mrp_only=False, extension=None, **filters):
548566
"""

astroquery/mast/tests/test_mast.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,21 +305,21 @@ def test_missions_query_criteria(patch_post):
305305
def test_missions_get_product_list_async(patch_post):
306306
# String input
307307
result = mast.MastMissions.get_product_list_async('Z14Z0104T')
308-
assert isinstance(result, MockResponse)
308+
assert isinstance(result, list)
309309

310310
# List input
311311
in_datasets = ['Z14Z0104T', 'Z14Z0102T']
312312
result = mast.MastMissions.get_product_list_async(in_datasets)
313-
assert isinstance(result, MockResponse)
313+
assert isinstance(result, list)
314314

315315
# Row input
316316
datasets = mast.MastMissions.query_object("M101", radius=".002 deg")
317317
result = mast.MastMissions.get_product_list_async(datasets[:3])
318-
assert isinstance(result, MockResponse)
318+
assert isinstance(result, list)
319319

320320
# Table input
321321
result = mast.MastMissions.get_product_list_async(datasets[0])
322-
assert isinstance(result, MockResponse)
322+
assert isinstance(result, list)
323323

324324
# Unsupported data type for datasets
325325
with pytest.raises(TypeError) as err_type:
@@ -825,6 +825,10 @@ def test_observations_get_product_list(patch_post):
825825
result = mast.Observations.get_product_list(in_obsids)
826826
assert isinstance(result, Table)
827827

828+
# Error if no valid obsids are found
829+
with pytest.raises(InvalidQueryError, match='Observation list is empty'):
830+
mast.Observations.get_product_list([' '])
831+
828832

829833
def test_observations_filter_products(patch_post):
830834
products = mast.Observations.get_product_list('2003738726')

astroquery/mast/tests/test_mast_remote.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,25 @@ def test_missions_get_product_list_async(self):
199199

200200
# Table as input
201201
responses = MastMissions.get_product_list_async(datasets[:3])
202-
assert isinstance(responses, Response)
202+
assert isinstance(responses, list)
203203

204204
# Row as input
205205
responses = MastMissions.get_product_list_async(datasets[0])
206-
assert isinstance(responses, Response)
206+
assert isinstance(responses, list)
207207

208208
# String as input
209209
responses = MastMissions.get_product_list_async(datasets[0]['sci_data_set_name'])
210-
assert isinstance(responses, Response)
210+
assert isinstance(responses, list)
211211

212212
# Column as input
213213
responses = MastMissions.get_product_list_async(datasets[:3]['sci_data_set_name'])
214-
assert isinstance(responses, Response)
214+
assert isinstance(responses, list)
215+
216+
# Batching
217+
responses = MastMissions.get_product_list_async(datasets[:4], batch_size=2)
218+
assert isinstance(responses, list)
219+
assert len(responses) == 2
220+
assert isinstance(responses[0], Response)
215221

216222
# Unsupported data type for datasets
217223
with pytest.raises(TypeError) as err_type:
@@ -248,14 +254,13 @@ def test_missions_get_product_list(self, capsys):
248254
assert isinstance(result, Table)
249255
assert (result['dataset'] == 'IBKH03020').all()
250256

251-
# Test batching by creating a list of 1001 different strings
252-
# This won't return any results, but will test the batching
253-
dataset_list = [f'{i}' for i in range(1001)]
254-
result = MastMissions.get_product_list(dataset_list)
257+
# Test batching
258+
result_batch = MastMissions.get_product_list(datasets[:2], batch_size=1)
255259
out, _ = capsys.readouterr()
256-
assert isinstance(result, Table)
257-
assert len(result) == 0
258-
assert 'Fetching products for 1001 unique datasets in 2 batches' in out
260+
assert isinstance(result_batch, Table)
261+
assert len(result_batch) == len(result_table)
262+
assert set(result_batch['filename']) == set(result_table['filename'])
263+
assert 'Fetching products for 2 unique datasets in 2 batches' in out
259264

260265
def test_missions_get_unique_product_list(self, caplog):
261266
# Check that no rows are filtered out when all products are unique
@@ -593,7 +598,11 @@ def test_observations_get_product_list_async(self):
593598
responses = Observations.get_product_list_async(observations[0:4])
594599
assert isinstance(responses, list)
595600

596-
def test_observations_get_product_list(self):
601+
# Batching
602+
responses = Observations.get_product_list_async(observations[0:4], batch_size=2)
603+
assert isinstance(responses, list)
604+
605+
def test_observations_get_product_list(self, capsys):
597606
observations = Observations.query_criteria(objectname='M8', obs_collection=['K2', 'IUE'])
598607
test_obs_id = str(observations[0]['obsid'])
599608
mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid'])
@@ -626,6 +635,14 @@ def test_observations_get_product_list(self):
626635
assert len(obs_collection) == 1
627636
assert obs_collection[0] == 'IUE'
628637

638+
# Test batching
639+
result_batch = Observations.get_product_list(observations[:2], batch_size=1)
640+
out, _ = capsys.readouterr()
641+
assert isinstance(result_batch, Table)
642+
assert len(result_batch) == len(result1)
643+
assert set(result_batch['productFilename']) == set(filenames1)
644+
assert 'Fetching products for 2 unique observations in 2 batches' in out
645+
629646
def test_observations_get_product_list_tess_tica(self, caplog):
630647
# Get observations and products with both TESS and TICA FFIs
631648
obs = Observations.query_criteria(target_name=['TESS FFI', 'TICA FFI', '429031146'])

astroquery/mast/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _batched_request(
148148
return extract_func(resp)
149149

150150

151-
def resolve_object(objectname, *, resolver=None, resolve_all=False):
151+
def resolve_object(objectname, *, resolver=None, resolve_all=False, batch_size=30):
152152
"""
153153
Resolves one or more object names to a position on the sky.
154154
@@ -164,6 +164,9 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False):
164164
resolve_all : bool, optional
165165
If True, will try to resolve the object name using all available resolvers ("NED", "SIMBAD").
166166
Default is False.
167+
batch_size : int, optional
168+
Default 30. Number of object names to include in each batch request to the server.
169+
If you experience timeouts or connection errors, consider lowering this value.
167170
168171
Returns
169172
-------
@@ -230,7 +233,7 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False):
230233
results = _batched_request(
231234
object_names,
232235
params,
233-
max_batch=30,
236+
max_batch=batch_size,
234237
param_key="name",
235238
request_func=lambda p: _simple_request("http://mastresolver.stsci.edu/Santa-war/query", p),
236239
extract_func=lambda r: r.json().get("resolvedCoordinate") or [],

0 commit comments

Comments
 (0)