11"""Item crud client."""
2- import json
32import logging
43from base64 import urlsafe_b64encode
4+ import re
55from datetime import datetime as datetime_type
66from datetime import timezone
77from typing import Any , Dict , List , Optional , Set , Type , Union
8- from urllib .parse import urljoin
8+ from urllib .parse import unquote_plus , urljoin
99
1010import attr
11+ import orjson
1112import stac_pydantic
12- from fastapi import HTTPException
13+ from fastapi import HTTPException , Request
1314from overrides import overrides
1415from pydantic import ValidationError
16+ from pygeofilter .backends .cql2_json import to_cql2
17+ from pygeofilter .parsers .cql2_text import parse as parse_cql2_text
1518from stac_pydantic .links import Relations
1619from stac_pydantic .shared import MimeTypes
17- from starlette .requests import Request
1820
1921from stac_fastapi .elasticsearch import serializers
2022from stac_fastapi .elasticsearch .config import ElasticsearchSettings
@@ -303,9 +305,9 @@ def _return_date(interval_str):
303305
304306 return {"lte" : end_date , "gte" : start_date }
305307
306- @overrides
307308 async def get_search (
308309 self ,
310+ request : Request ,
309311 collections : Optional [List [str ]] = None ,
310312 ids : Optional [List [str ]] = None ,
311313 bbox : Optional [List [NumType ]] = None ,
@@ -316,8 +318,8 @@ async def get_search(
316318 fields : Optional [List [str ]] = None ,
317319 sortby : Optional [str ] = None ,
318320 intersects : Optional [str ] = None ,
319- # filter: Optional[str] = None, # todo: requires fastapi > 2.3 unreleased
320- # filter_lang: Optional[str] = None, # todo: requires fastapi > 2.3 unreleased
321+ filter : Optional [str ] = None ,
322+ filter_lang : Optional [str ] = None ,
321323 ** kwargs ,
322324 ) -> ItemCollection :
323325 """Get search results from the database.
@@ -347,17 +349,24 @@ async def get_search(
347349 "bbox" : bbox ,
348350 "limit" : limit ,
349351 "token" : token ,
350- "query" : json .loads (query ) if query else query ,
352+ "query" : orjson .loads (query ) if query else query ,
351353 }
352354
355+ # this is borrowed from stac-fastapi-pgstac
356+ # Kludgy fix because using factory does not allow alias for filter-lan
357+ query_params = str (request .query_params )
358+ if filter_lang is None :
359+ match = re .search (r"filter-lang=([a-z0-9-]+)" , query_params , re .IGNORECASE )
360+ if match :
361+ filter_lang = match .group (1 )
362+
353363 if datetime :
354364 base_args ["datetime" ] = datetime
355365
356366 if intersects :
357- base_args ["intersects" ] = intersects
367+ base_args ["intersects" ] = orjson . loads ( unquote_plus ( intersects ))
358368
359369 if sortby :
360- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
361370 sort_param = []
362371 for sort in sortby :
363372 sort_param .append (
@@ -368,12 +377,13 @@ async def get_search(
368377 )
369378 base_args ["sortby" ] = sort_param
370379
371- # todo: requires fastapi > 2.3 unreleased
372- # if filter:
373- # if filter_lang == "cql2-text":
374- # base_args["filter-lang"] = "cql2-json"
375- # base_args["filter"] = orjson.loads(to_cql2(parse_cql2_text(filter)))
376- # print(f'>>> {base_args["filter"]}')
380+ if filter :
381+ if filter_lang == "cql2-text" :
382+ base_args ["filter-lang" ] = "cql2-json"
383+ base_args ["filter" ] = orjson .loads (to_cql2 (parse_cql2_text (filter )))
384+ else :
385+ base_args ["filter-lang" ] = "cql2-json"
386+ base_args ["filter" ] = orjson .loads (unquote_plus (filter ))
377387
378388 if fields :
379389 includes = set ()
@@ -392,13 +402,12 @@ async def get_search(
392402 search_request = self .post_request_model (** base_args )
393403 except ValidationError :
394404 raise HTTPException (status_code = 400 , detail = "Invalid parameters provided" )
395- resp = await self .post_search (search_request , request = kwargs [ " request" ] )
405+ resp = await self .post_search (search_request = search_request , request = request )
396406
397407 return resp
398408
399- @overrides
400409 async def post_search (
401- self , search_request : BaseSearchPostRequest , ** kwargs
410+ self , search_request : BaseSearchPostRequest , request : Request
402411 ) -> ItemCollection :
403412 """
404413 Perform a POST search on the catalog.
@@ -413,7 +422,6 @@ async def post_search(
413422 Raises:
414423 HTTPException: If there is an error with the cql2_json filter.
415424 """
416- request : Request = kwargs ["request" ]
417425 base_url = str (request .base_url )
418426
419427 search = self .database .make_search ()
@@ -500,7 +508,7 @@ async def post_search(
500508 filter_kwargs = search_request .fields .filter_fields
501509
502510 items = [
503- json .loads (stac_pydantic .Item (** feat ).json (** filter_kwargs ))
511+ orjson .loads (stac_pydantic .Item (** feat ).json (** filter_kwargs ))
504512 for feat in items
505513 ]
506514
0 commit comments