Skip to content

Commit 3f9ac68

Browse files
committed
initialize_collection_if_empty 함수가 범용적으로 동작하도록 수정
1 parent 24ec4e0 commit 3f9ac68

File tree

1 file changed

+43
-20
lines changed

1 file changed

+43
-20
lines changed

utils/llm/vectordb/qdrant_db.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from qdrant_client import QdrantClient, models
2-
from typing import List, Dict, Any, Optional, Union
2+
from typing import List, Dict, Any, Optional, Union, Callable
33
import os
44
from dotenv import load_dotenv
55

@@ -112,7 +112,7 @@ def search(
112112
)
113113

114114
def similarity_search(
115-
self, query: str, k: int = 5, collection_name: str = "table_info"
115+
self, query: str, k: int = 5, collection_name: str = "lang2sql_table_schema"
116116
) -> List[Any]:
117117
"""
118118
LangChain 호환성을 위한 유사도 검색 메서드.
@@ -179,28 +179,13 @@ def invoke(self, query: str):
179179
# 여기서는 간단하게 구현
180180
return self.similarity_search(query)
181181

182-
def initialize_collection_if_empty(self, collection_name: str = "table_info"):
182+
def _get_table_schema_points(self) -> List[Dict[str, Any]]:
183183
"""
184-
컬렉션이 비어있거나 없으면 스키마 정보를 가져와서 채웁니다.
184+
기본 테이블 스키마 정보를 가져와서 포인트 리스트로 변환합니다.
185185
"""
186186
from utils.llm.tools.datahub import get_table_schema
187187
from utils.llm.core import get_embeddings
188188

189-
# 컬렉션 존재 여부 확인 및 생성
190-
if not self.client.collection_exists(collection_name):
191-
self.create_collection(collection_name)
192-
193-
# 데이터 존재 여부 확인
194-
count_result = self.client.count(collection_name=collection_name)
195-
if count_result.count > 0:
196-
print(
197-
f"Collection '{collection_name}' is not empty. Skipping initialization."
198-
)
199-
return
200-
201-
print(f"Initializing collection '{collection_name}' with table schema...")
202-
203-
# 스키마 정보 가져오기
204189
raw_data = get_table_schema()
205190
embeddings = get_embeddings()
206191

@@ -229,8 +214,46 @@ def initialize_collection_if_empty(self, collection_name: str = "table_info"):
229214
"payload": payload,
230215
}
231216
)
217+
return points
218+
219+
def initialize_collection_if_empty(
220+
self,
221+
collection_name: str = "lang2sql_table_schema",
222+
force_update: bool = False,
223+
data_loader: Optional[Callable[[], List[Dict[str, Any]]]] = None,
224+
):
225+
"""
226+
컬렉션이 비어있거나 없으면 데이터를 채웁니다.
227+
228+
Args:
229+
collection_name: 초기화할 컬렉션 이름.
230+
force_update: 데이터가 있어도 강제로 업데이트할지 여부.
231+
data_loader: 데이터를 가져오는 함수. 포인트 리스트(id, vector, payload)를 반환해야 합니다.
232+
None인 경우 기본 테이블 스키마 로더를 사용합니다.
233+
"""
234+
# 컬렉션 존재 여부 확인 및 생성
235+
if not self.client.collection_exists(collection_name):
236+
self.create_collection(collection_name)
237+
238+
# 데이터 존재 여부 확인
239+
if not force_update:
240+
count_result = self.client.count(collection_name=collection_name)
241+
if count_result.count > 0:
242+
print(
243+
f"Collection '{collection_name}' is not empty. Skipping initialization."
244+
)
245+
return
246+
247+
print(f"Initializing collection '{collection_name}'...")
248+
249+
# 데이터 로드
250+
if data_loader is None:
251+
# 기본 동작: 테이블 스키마 정보 사용
252+
points = self._get_table_schema_points()
253+
else:
254+
points = data_loader()
232255

233256
if points:
234257
self.upsert(collection_name, points)
235258
else:
236-
print("No table schema found to initialize.")
259+
print("No data found to initialize.")

0 commit comments

Comments
 (0)