|
1 | 1 | from qdrant_client import QdrantClient, models |
2 | | -from typing import List, Dict, Any, Optional, Union |
| 2 | +from typing import List, Dict, Any, Optional, Union, Callable |
3 | 3 | import os |
4 | 4 | from dotenv import load_dotenv |
5 | 5 |
|
@@ -112,7 +112,7 @@ def search( |
112 | 112 | ) |
113 | 113 |
|
114 | 114 | def similarity_search( |
115 | | - self, query: str, k: int = 5, collection_name: str = "table_info" |
| 115 | + self, query: str, k: int = 5, collection_name: str = "lang2sql_table_schema" |
116 | 116 | ) -> List[Any]: |
117 | 117 | """ |
118 | 118 | LangChain 호환성을 위한 유사도 검색 메서드. |
@@ -179,28 +179,13 @@ def invoke(self, query: str): |
179 | 179 | # 여기서는 간단하게 구현 |
180 | 180 | return self.similarity_search(query) |
181 | 181 |
|
182 | | - def initialize_collection_if_empty(self, collection_name: str = "table_info"): |
| 182 | + def _get_table_schema_points(self) -> List[Dict[str, Any]]: |
183 | 183 | """ |
184 | | - 컬렉션이 비어있거나 없으면 스키마 정보를 가져와서 채웁니다. |
| 184 | + 기본 테이블 스키마 정보를 가져와서 포인트 리스트로 변환합니다. |
185 | 185 | """ |
186 | 186 | from utils.llm.tools.datahub import get_table_schema |
187 | 187 | from utils.llm.core import get_embeddings |
188 | 188 |
|
189 | | - # 컬렉션 존재 여부 확인 및 생성 |
190 | | - if not self.client.collection_exists(collection_name): |
191 | | - self.create_collection(collection_name) |
192 | | - |
193 | | - # 데이터 존재 여부 확인 |
194 | | - count_result = self.client.count(collection_name=collection_name) |
195 | | - if count_result.count > 0: |
196 | | - print( |
197 | | - f"Collection '{collection_name}' is not empty. Skipping initialization." |
198 | | - ) |
199 | | - return |
200 | | - |
201 | | - print(f"Initializing collection '{collection_name}' with table schema...") |
202 | | - |
203 | | - # 스키마 정보 가져오기 |
204 | 189 | raw_data = get_table_schema() |
205 | 190 | embeddings = get_embeddings() |
206 | 191 |
|
@@ -229,8 +214,46 @@ def initialize_collection_if_empty(self, collection_name: str = "table_info"): |
229 | 214 | "payload": payload, |
230 | 215 | } |
231 | 216 | ) |
| 217 | + return points |
| 218 | + |
| 219 | + def initialize_collection_if_empty( |
| 220 | + self, |
| 221 | + collection_name: str = "lang2sql_table_schema", |
| 222 | + force_update: bool = False, |
| 223 | + data_loader: Optional[Callable[[], List[Dict[str, Any]]]] = None, |
| 224 | + ): |
| 225 | + """ |
| 226 | + 컬렉션이 비어있거나 없으면 데이터를 채웁니다. |
| 227 | +
|
| 228 | + Args: |
| 229 | + collection_name: 초기화할 컬렉션 이름. |
| 230 | + force_update: 데이터가 있어도 강제로 업데이트할지 여부. |
| 231 | + data_loader: 데이터를 가져오는 함수. 포인트 리스트(id, vector, payload)를 반환해야 합니다. |
| 232 | + None인 경우 기본 테이블 스키마 로더를 사용합니다. |
| 233 | + """ |
| 234 | + # 컬렉션 존재 여부 확인 및 생성 |
| 235 | + if not self.client.collection_exists(collection_name): |
| 236 | + self.create_collection(collection_name) |
| 237 | + |
| 238 | + # 데이터 존재 여부 확인 |
| 239 | + if not force_update: |
| 240 | + count_result = self.client.count(collection_name=collection_name) |
| 241 | + if count_result.count > 0: |
| 242 | + print( |
| 243 | + f"Collection '{collection_name}' is not empty. Skipping initialization." |
| 244 | + ) |
| 245 | + return |
| 246 | + |
| 247 | + print(f"Initializing collection '{collection_name}'...") |
| 248 | + |
| 249 | + # 데이터 로드 |
| 250 | + if data_loader is None: |
| 251 | + # 기본 동작: 테이블 스키마 정보 사용 |
| 252 | + points = self._get_table_schema_points() |
| 253 | + else: |
| 254 | + points = data_loader() |
232 | 255 |
|
233 | 256 | if points: |
234 | 257 | self.upsert(collection_name, points) |
235 | 258 | else: |
236 | | - print("No table schema found to initialize.") |
| 259 | + print("No data found to initialize.") |
0 commit comments