11"""Database logic."""
22import logging
3- from typing import List , Type , Union
3+ from typing import List , Optional , Tuple , Type , Union
44
55import attr
66import elasticsearch
77from elasticsearch import helpers
88from elasticsearch_dsl import Q , Search
9+ from geojson_pydantic .geometries import (
10+ GeometryCollection ,
11+ LineString ,
12+ MultiLineString ,
13+ MultiPoint ,
14+ MultiPolygon ,
15+ Point ,
16+ Polygon ,
17+ )
918
1019from stac_fastapi .elasticsearch import serializers
1120from stac_fastapi .elasticsearch .config import ElasticsearchSettings
1221from stac_fastapi .types .errors import ConflictError , ForeignKeyError , NotFoundError
13- from stac_fastapi .types .stac import Collection , Collections , Item , ItemCollection
22+ from stac_fastapi .types .stac import Collection , Item
1423
1524logger = logging .getLogger (__name__ )
1625
@@ -31,10 +40,10 @@ class DatabaseLogic:
3140
3241 settings = ElasticsearchSettings ()
3342 client = settings .create_client
34- item_serializer : Type [serializers .Serializer ] = attr .ib (
43+ item_serializer : Type [serializers .ItemSerializer ] = attr .ib (
3544 default = serializers .ItemSerializer
3645 )
37- collection_serializer : Type [serializers .Serializer ] = attr .ib (
46+ collection_serializer : Type [serializers .CollectionSerializer ] = attr .ib (
3847 default = serializers .CollectionSerializer
3948 )
4049
@@ -46,7 +55,7 @@ def bbox2poly(b0, b1, b2, b3):
4655
4756 """CORE LOGIC"""
4857
49- def get_all_collections (self , base_url : str ) -> Collections :
58+ def get_all_collections (self , base_url : str ) -> List [ Collection ] :
5059 """Database logic to retrieve a list of all collections."""
5160 try :
5261 collections = self .client .search (
@@ -66,9 +75,10 @@ def get_all_collections(self, base_url: str) -> Collections:
6675
6776 def get_item_collection (
6877 self , collection_id : str , limit : int , base_url : str
69- ) -> ItemCollection :
78+ ) -> Tuple [ List [ Item ], Optional [ int ]] :
7079 """Database logic to retrieve an ItemCollection and a count of items contained."""
71- search = Search (using = self .client , index = "stac_items" )
80+ search = self .create_search_object ()
81+ search = self .search_collections (search , [collection_id ])
7282
7383 collection_filter = Q (
7484 "bool" , should = [Q ("match_phrase" , ** {"collection" : collection_id })]
@@ -79,7 +89,11 @@ def get_item_collection(
7989
8090 # search = search.sort({"id.keyword" : {"order" : "asc"}})
8191 search = search .query ()[0 :limit ]
82- collection_children = search .execute ().to_dict ()
92+
93+ body = search .to_dict ()
94+ collection_children = self .client .search (
95+ index = ITEMS_INDEX , query = body ["query" ], sort = body .get ("sort" )
96+ )
8397
8498 serialized_children = [
8599 self .item_serializer .db_to_stac (item ["_source" ], base_url = base_url )
@@ -100,21 +114,17 @@ def get_one_item(self, collection_id: str, item_id: str) -> Item:
100114 )
101115 return item ["_source" ]
102116
103- def create_search_object (self ):
117+ @staticmethod
118+ def create_search_object ():
104119 """Database logic to create a nosql Search instance."""
105- search = (
106- Search ()
107- .using (self .client )
108- .index (ITEMS_INDEX )
109- .sort (
110- {"properties.datetime" : {"order" : "desc" }},
111- {"id" : {"order" : "desc" }},
112- {"collection" : {"order" : "desc" }},
113- )
120+ return Search ().sort (
121+ {"properties.datetime" : {"order" : "desc" }},
122+ {"id" : {"order" : "desc" }},
123+ {"collection" : {"order" : "desc" }},
114124 )
115- return search
116125
117- def create_query_filter (self , search , op : str , field : str , value : float ):
126+ @staticmethod
127+ def create_query_filter (search : Search , op : str , field : str , value : float ):
118128 """Database logic to perform query for search endpoint."""
119129 if op != "eq" :
120130 key_filter = {field : {f"{ op } " : value }}
@@ -124,7 +134,8 @@ def create_query_filter(self, search, op: str, field: str, value: float):
124134
125135 return search
126136
127- def search_ids (self , search , item_ids : List ):
137+ @staticmethod
138+ def search_ids (search : Search , item_ids : List ):
128139 """Database logic to search a list of STAC item ids."""
129140 id_list = []
130141 for item_id in item_ids :
@@ -134,17 +145,14 @@ def search_ids(self, search, item_ids: List):
134145
135146 return search
136147
137- def search_collections (self , search , collection_ids : List ):
148+ @staticmethod
149+ def search_collections (search : Search , collection_ids : List ):
138150 """Database logic to search a list of STAC collection ids."""
139- collection_list = []
140- for collection_id in collection_ids :
141- collection_list .append (Q ("match_phrase" , ** {"collection" : collection_id }))
142- collection_filter = Q ("bool" , should = collection_list )
143- search = search .query (collection_filter )
144-
145- return search
151+ collections_query = [Q ("term" , ** {"collection" : cid }) for cid in collection_ids ]
152+ return search .query (Q ("bool" , should = collections_query ))
146153
147- def search_datetime (self , search , datetime_search ):
154+ @staticmethod
155+ def search_datetime (search : Search , datetime_search ):
148156 """Database logic to search datetime field."""
149157 if "eq" in datetime_search :
150158 search = search .query (
@@ -159,9 +167,10 @@ def search_datetime(self, search, datetime_search):
159167 )
160168 return search
161169
162- def search_bbox (self , search , bbox : List ):
170+ @staticmethod
171+ def search_bbox (search : Search , bbox : List ):
163172 """Database logic to search on bounding box."""
164- poly = self .bbox2poly (bbox [0 ], bbox [1 ], bbox [2 ], bbox [3 ])
173+ poly = DatabaseLogic .bbox2poly (bbox [0 ], bbox [1 ], bbox [2 ], bbox [3 ])
165174 bbox_filter = Q (
166175 {
167176 "geo_shape" : {
@@ -175,7 +184,19 @@ def search_bbox(self, search, bbox: List):
175184 search = search .query (bbox_filter )
176185 return search
177186
178- def search_intersects (self , search , intersects : dict ):
187+ @staticmethod
188+ def search_intersects (
189+ search : Search ,
190+ intersects : Union [
191+ Point ,
192+ MultiPoint ,
193+ LineString ,
194+ MultiLineString ,
195+ Polygon ,
196+ MultiPolygon ,
197+ GeometryCollection ,
198+ ],
199+ ):
179200 """Database logic to search a geojson object."""
180201 intersect_filter = Q (
181202 {
@@ -193,24 +214,27 @@ def search_intersects(self, search, intersects: dict):
193214 search = search .query (intersect_filter )
194215 return search
195216
196- def sort_field ( self , search , field , direction ):
197- """Database logic to sort nosql search instance."""
198- search = search . sort ({ field : { "order" : direction }})
199- return search
217+ @ staticmethod
218+ def sort_field ( search : Search , field , direction ):
219+ """Database logic to sort search instance."""
220+ return search . sort ({ field : { "order" : direction }})
200221
201- def search_count (self , search ) -> int :
222+ def search_count (self , search : Search ) -> int :
202223 """Database logic to count search results."""
203224 try :
204- count = search .count ()
225+ return self .client .count (
226+ index = ITEMS_INDEX , body = search .to_dict (count = True )
227+ ).get ("count" )
205228 except elasticsearch .exceptions .NotFoundError :
206229 raise NotFoundError ("No items exist" )
207230
208- return count
209-
210231 def execute_search (self , search , limit : int , base_url : str ) -> List :
211232 """Database logic to execute search with limit."""
212233 search = search .query ()[0 :limit ]
213- response = search .execute ().to_dict ()
234+ body = search .to_dict ()
235+ response = self .client .search (
236+ index = ITEMS_INDEX , query = body ["query" ], sort = body .get ("sort" )
237+ )
214238
215239 if len (response ["hits" ]["hits" ]) > 0 :
216240 response_features = [
@@ -242,30 +266,35 @@ def prep_create_item(self, item: Item, base_url: str) -> Item:
242266
243267 return self .item_serializer .stac_to_db (item , base_url )
244268
245- def create_item (self , item : Item , base_url : str ):
269+ def create_item (self , item : Item , refresh : bool = False ):
246270 """Database logic for creating one item."""
247271 # todo: check if collection exists, but cache
248272 es_resp = self .client .index (
249273 index = ITEMS_INDEX ,
250274 id = mk_item_id (item ["id" ], item ["collection" ]),
251275 document = item ,
276+ refresh = refresh ,
252277 )
253278
254279 if (meta := es_resp .get ("meta" )) and meta .get ("status" ) == 409 :
255280 raise ConflictError (
256281 f"Item { item ['id' ]} in collection { item ['collection' ]} already exists"
257282 )
258283
259- def delete_item (self , item_id : str , collection_id : str ):
284+ def delete_item (self , item_id : str , collection_id : str , refresh : bool = False ):
260285 """Database logic for deleting one item."""
261286 try :
262- self .client .delete (index = ITEMS_INDEX , id = mk_item_id (item_id , collection_id ))
287+ self .client .delete (
288+ index = ITEMS_INDEX ,
289+ id = mk_item_id (item_id , collection_id ),
290+ refresh = refresh ,
291+ )
263292 except elasticsearch .exceptions .NotFoundError :
264293 raise NotFoundError (
265294 f"Item { item_id } in collection { collection_id } not found"
266295 )
267296
268- def create_collection (self , collection : Collection ):
297+ def create_collection (self , collection : Collection , refresh : bool = False ):
269298 """Database logic for creating one collection."""
270299 if self .client .exists (index = COLLECTIONS_INDEX , id = collection ["id" ]):
271300 raise ConflictError (f"Collection { collection ['id' ]} already exists" )
@@ -274,6 +303,7 @@ def create_collection(self, collection: Collection):
274303 index = COLLECTIONS_INDEX ,
275304 id = collection ["id" ],
276305 document = collection ,
306+ refresh = refresh ,
277307 )
278308
279309 def find_collection (self , collection_id : str ) -> Collection :
@@ -285,12 +315,12 @@ def find_collection(self, collection_id: str) -> Collection:
285315
286316 return collection ["_source" ]
287317
288- def delete_collection (self , collection_id : str ):
318+ def delete_collection (self , collection_id : str , refresh : bool = False ):
289319 """Database logic for deleting one collection."""
290320 _ = self .find_collection (collection_id = collection_id )
291- self .client .delete (index = COLLECTIONS_INDEX , id = collection_id )
321+ self .client .delete (index = COLLECTIONS_INDEX , id = collection_id , refresh = refresh )
292322
293- def bulk_sync (self , processed_items ):
323+ def bulk_sync (self , processed_items , refresh : bool = False ):
294324 """Database logic for bulk item insertion."""
295325 actions = [
296326 {
@@ -300,4 +330,22 @@ def bulk_sync(self, processed_items):
300330 }
301331 for item in processed_items
302332 ]
303- helpers .bulk (self .client , actions )
333+ helpers .bulk (self .client , actions , refresh = refresh )
334+
335+ # DANGER
336+ def delete_items (self ) -> None :
337+ """Danger. this is only for tests."""
338+ self .client .delete_by_query (
339+ index = ITEMS_INDEX ,
340+ body = {"query" : {"match_all" : {}}},
341+ wait_for_completion = True ,
342+ )
343+
344+ # DANGER
345+ def delete_collections (self ) -> None :
346+ """Danger. this is only for tests."""
347+ self .client .delete_by_query (
348+ index = COLLECTIONS_INDEX ,
349+ body = {"query" : {"match_all" : {}}},
350+ wait_for_completion = True ,
351+ )
0 commit comments