1- from typing import Any
2-
31import fastapi
42from fastapi import HTTPException
53from sqlalchemy import select
64
7- from fastapi_app .api_models import ChatRequest , RetrievalResponse
5+ from fastapi_app .api_models import ChatRequest , ItemPublic , ItemWithDistance , RetrievalResponse
86from fastapi_app .dependencies import ChatClient , CommonDeps , DBSession , EmbeddingsClient
97from fastapi_app .postgres_models import Item
108from fastapi_app .postgres_searcher import PostgresSearcher
1412router = fastapi .APIRouter ()
1513
1614
17- @router .get ("/items/{id}" , response_model = dict [ str , Any ] )
18- async def item_handler (id : int , database_session : DBSession ) -> dict [ str , Any ] :
15+ @router .get ("/items/{id}" , response_model = ItemPublic )
16+ async def item_handler (id : int , database_session : DBSession ) -> ItemPublic :
1917 """A simple API to get an item by ID."""
2018 item = (await database_session .scalars (select (Item ).where (Item .id == id ))).first ()
2119 if not item :
2220 raise HTTPException (detail = f"Item with ID { id } not found." , status_code = 404 )
23- return item .to_dict ()
21+ return ItemPublic . model_validate ( item .to_dict () )
2422
2523
26- @router .get ("/similar" , response_model = list [dict [ str , Any ] ])
27- async def similar_handler (database_session : DBSession , id : int , n : int = 5 ) -> list [dict [ str , Any ] ]:
24+ @router .get ("/similar" , response_model = list [ItemWithDistance ])
25+ async def similar_handler (database_session : DBSession , id : int , n : int = 5 ) -> list [ItemWithDistance ]:
2826 """A similarity API to find items similar to items with given ID."""
2927 item = (await database_session .scalars (select (Item ).where (Item .id == id ))).first ()
3028 if not item :
@@ -35,10 +33,12 @@ async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> l
3533 .order_by (Item .embedding .l2_distance (item .embedding ))
3634 .limit (n )
3735 )
38- return [item .to_dict () | {"distance" : round (distance , 2 )} for item , distance in closest ]
36+ return [
37+ ItemWithDistance .model_validate (item .to_dict () | {"distance" : round (distance , 2 )}) for item , distance in closest
38+ ]
3939
4040
41- @router .get ("/search" , response_model = list [dict [ str , Any ] ])
41+ @router .get ("/search" , response_model = list [ItemPublic ])
4242async def search_handler (
4343 context : CommonDeps ,
4444 database_session : DBSession ,
@@ -47,7 +47,7 @@ async def search_handler(
4747 top : int = 5 ,
4848 enable_vector_search : bool = True ,
4949 enable_text_search : bool = True ,
50- ) -> list [dict [ str , Any ] ]:
50+ ) -> list [ItemPublic ]:
5151 """A search API to find items based on a query."""
5252 searcher = PostgresSearcher (
5353 db_session = database_session ,
@@ -59,7 +59,7 @@ async def search_handler(
5959 results = await searcher .search_and_embed (
6060 query , top = top , enable_vector_search = enable_vector_search , enable_text_search = enable_text_search
6161 )
62- return [item .to_dict () for item in results ]
62+ return [ItemPublic . model_validate ( item .to_dict () ) for item in results ]
6363
6464
6565@router .post ("/chat" , response_model = RetrievalResponse )
0 commit comments