diff --git a/cli/utils/env_loader.py b/cli/utils/env_loader.py index 768fdce..6d03c3f 100644 --- a/cli/utils/env_loader.py +++ b/cli/utils/env_loader.py @@ -71,14 +71,14 @@ def set_vectordb( """VectorDB 타입과 위치를 설정합니다. Args: - vectordb_type (str): VectorDB 타입 ("faiss" 또는 "pgvector"). + vectordb_type (str): VectorDB 타입 ("faiss" 또는 "pgvector" 또는 "qdrant"). vectordb_location (Optional[str]): 경로 또는 연결 URL. Raises: ValueError: 잘못된 타입이나 경로/URL일 경우. """ - if vectordb_type not in ("faiss", "pgvector"): + if vectordb_type not in ("faiss", "pgvector", "qdrant"): raise ValueError(f"지원하지 않는 VectorDB 타입: {vectordb_type}") os.environ["VECTORDB_TYPE"] = vectordb_type diff --git a/docker/docker-compose-pgvector.yml b/docker/docker-compose-pgvector.yml index 8ad5e16..443baf9 100644 --- a/docker/docker-compose-pgvector.yml +++ b/docker/docker-compose-pgvector.yml @@ -1,7 +1,13 @@ -# docker compose -f docker-compose-pgvector.yml up -# docker compose -f docker-compose-pgvector.yml down +# docker compose -f docker/docker-compose.yml -f docker/docker-compose-pgvector.yml up +# docker compose -f docker/docker-compose.yml -f docker/docker-compose-pgvector.yml down services: + streamlit: + environment: + - DATABASE_URL=postgresql://pgvector:pgvector@pgvector:5432/streamlit + depends_on: + - pgvector + pgvector: image: pgvector/pgvector:pg17 hostname: pgvector @@ -12,7 +18,7 @@ services: environment: POSTGRES_USER: pgvector POSTGRES_PASSWORD: pgvector - POSTGRES_DB: pgvector + POSTGRES_DB: streamlit TZ: Asia/Seoul LANG: en_US.utf8 volumes: diff --git a/docker/docker-compose-postgres.yml b/docker/docker-compose-postgres.yml index 696f7e1..b8b4903 100644 --- a/docker/docker-compose-postgres.yml +++ b/docker/docker-compose-postgres.yml @@ -1,7 +1,13 @@ -# docker compose -f docker-compose-postgres.yml up -# docker compose -f docker-compose-postgres.yml down +# docker compose -f docker/docker-compose.yml -f docker/docker-compose-postgres.yml up +# docker compose -f docker/docker-compose.yml -f docker/docker-compose-postgres.yml down services: + streamlit: + environment: + - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/streamlit + depends_on: + - postgres + postgres: image: postgres:15 hostname: postgres @@ -12,7 +18,7 @@ services: environment: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres - POSTGRES_DB: postgres + POSTGRES_DB: streamlit TZ: Asia/Seoul LANG: en_US.utf8 volumes: diff --git a/docker/docker-compose-qdrant.yml b/docker/docker-compose-qdrant.yml new file mode 100644 index 0000000..98caca5 --- /dev/null +++ b/docker/docker-compose-qdrant.yml @@ -0,0 +1,23 @@ +# docker compose -f docker/docker-compose.yml -f docker/docker-compose-qdrant.yml up +# docker compose -f docker/docker-compose.yml -f docker/docker-compose-qdrant.yml down + +services: + streamlit: + environment: + - QDRANT_HOST=qdrant + - QDRANT_PORT=6333 + depends_on: + - qdrant + + qdrant: + image: qdrant/qdrant:latest + hostname: qdrant + container_name: qdrant + restart: always + ports: + - "6333:6333" + volumes: + - qdrant_data:/qdrant/storage + +volumes: + qdrant_data: diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 115575a..66a83ce 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -13,22 +13,4 @@ services: - ../.env environment: - STREAMLIT_SERVER_PORT=8501 - - DATABASE_URL=postgresql://pgvector:pgvector@localhost:5432/streamlit - depends_on: - - pgvector - pgvector: - image: pgvector/pgvector:pg17 - hostname: pgvector - container_name: pgvector - environment: - POSTGRES_USER: pgvector - POSTGRES_PASSWORD: pgvector - POSTGRES_DB: streamlit - ports: - - "5432:5432" - volumes: - - pgdata:/var/lib/postgresql/data - -volumes: - pgdata: diff --git a/interface/app_pages/settings_sections/data_source_section.py b/interface/app_pages/settings_sections/data_source_section.py index 2f63881..e1b3383 100644 --- a/interface/app_pages/settings_sections/data_source_section.py +++ b/interface/app_pages/settings_sections/data_source_section.py @@ -103,10 +103,36 @@ def render_data_source_section(config: Config | None = None) -> None: new_url = st.text_input( "URL", value=existing.url, key="dh_edit_url" ) - new_faiss = st.text_input( - "FAISS 저장 경로(선택)", - value=existing.faiss_path or "", - key="dh_edit_faiss", + new_vdb_type = st.selectbox( + "VectorDB 타입", + options=["faiss", "pgvector", "qdrant"], + index=( + 0 + if existing.vectordb_type == "faiss" + else (1 if existing.vectordb_type == "pgvector" else 2) + ), + key="dh_edit_vdb_type", + ) + new_vdb_loc_placeholder = ( + "FAISS 디렉토리 경로 (예: ./dev/table_info_db)" + if new_vdb_type == "faiss" + else ( + "pgvector 연결 문자열 (postgresql://...)" + if new_vdb_type == "pgvector" + else "Qdrant URL (예: http://localhost:6333)" + ) + ) + new_vdb_location = st.text_input( + "VectorDB 위치", + value=existing.vectordb_location or existing.faiss_path or "", + key="dh_edit_vdb_loc", + placeholder=new_vdb_loc_placeholder, + ) + new_vdb_api_key = st.text_input( + "VectorDB API Key (선택)", + value=existing.vectordb_api_key or "", + type="password", + key="dh_edit_vdb_key", ) new_note = st.text_input( "메모", value=existing.note or "", key="dh_edit_note" @@ -128,7 +154,14 @@ def render_data_source_section(config: Config | None = None) -> None: update_datahub_source( name=edit_dh, url=new_url, - faiss_path=(new_faiss or None), + faiss_path=( + new_vdb_location + if new_vdb_type == "faiss" + else None + ), + vectordb_type=new_vdb_type, + vectordb_location=(new_vdb_location or None), + vectordb_api_key=(new_vdb_api_key or None), note=(new_note or None), ) st.success("저장되었습니다.") @@ -147,10 +180,29 @@ def render_data_source_section(config: Config | None = None) -> None: dh_url = st.text_input( "URL", key="dh_url", placeholder="http://localhost:8080" ) - dh_faiss = st.text_input( - "FAISS 저장 경로(선택)", - key="dh_faiss", - placeholder="예: ./dev/table_info_db", + dh_vdb_type = st.selectbox( + "VectorDB 타입", + options=["faiss", "pgvector", "qdrant"], + key="dh_new_vdb_type", + ) + dh_vdb_loc_placeholder = ( + "FAISS 디렉토리 경로 (예: ./dev/table_info_db)" + if dh_vdb_type == "faiss" + else ( + "pgvector 연결 문자열 (postgresql://...)" + if dh_vdb_type == "pgvector" + else "Qdrant URL (예: http://localhost:6333)" + ) + ) + dh_vdb_location = st.text_input( + "VectorDB 위치", + key="dh_new_vdb_loc", + placeholder=dh_vdb_loc_placeholder, + ) + dh_vdb_api_key = st.text_input( + "VectorDB API Key (선택)", + type="password", + key="dh_new_vdb_key", ) dh_note = st.text_input("메모", key="dh_note", placeholder="선택") @@ -174,7 +226,12 @@ def render_data_source_section(config: Config | None = None) -> None: add_datahub_source( name=dh_name, url=dh_url, - faiss_path=(dh_faiss or None), + faiss_path=( + dh_vdb_location if dh_vdb_type == "faiss" else None + ), + vectordb_type=dh_vdb_type, + vectordb_location=(dh_vdb_location or None), + vectordb_api_key=(dh_vdb_api_key or None), note=dh_note or None, ) st.success("추가되었습니다.") @@ -216,14 +273,22 @@ def render_data_source_section(config: Config | None = None) -> None: if existing: new_type = st.selectbox( "타입", - options=["faiss", "pgvector"], - index=(0 if existing.type == "faiss" else 1), + options=["faiss", "pgvector", "qdrant"], + index=( + 0 + if existing.type == "faiss" + else (1 if existing.type == "pgvector" else 2) + ), key="vdb_edit_type", ) new_loc_placeholder = ( "FAISS 디렉토리 경로 (예: ./dev/table_info_db)" if new_type == "faiss" - else "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)" + else ( + "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)" + if new_type == "pgvector" + else "Qdrant URL (예: http://localhost:6333)" + ) ) new_location = st.text_input( "위치", @@ -231,6 +296,12 @@ def render_data_source_section(config: Config | None = None) -> None: key="vdb_edit_location", placeholder=new_loc_placeholder, ) + new_api_key = st.text_input( + "API Key (선택)", + value=existing.api_key or "", + type="password", + key="vdb_edit_key", + ) new_prefix = st.text_input( "컬렉션 접두사(선택)", value=existing.collection_prefix or "", @@ -258,6 +329,7 @@ def render_data_source_section(config: Config | None = None) -> None: name=edit_vdb, vtype=new_type, location=new_location, + api_key=(new_api_key or None), collection_prefix=(new_prefix or None), note=(new_note or None), ) @@ -275,16 +347,23 @@ def render_data_source_section(config: Config | None = None) -> None: st.write("VectorDB 추가") vdb_name = st.text_input("이름", key="vdb_name") vdb_type = st.selectbox( - "타입", options=["faiss", "pgvector"], key="vdb_type" + "타입", options=["faiss", "pgvector", "qdrant"], key="vdb_type" ) vdb_loc_placeholder = ( "FAISS 디렉토리 경로 (예: ./dev/table_info_db)" if vdb_type == "faiss" - else "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)" + else ( + "pgvector 연결 문자열 (postgresql://user:pass@host:port/db)" + if vdb_type == "pgvector" + else "Qdrant URL (예: http://localhost:6333)" + ) ) vdb_location = st.text_input( "위치", key="vdb_location", placeholder=vdb_loc_placeholder ) + vdb_api_key = st.text_input( + "API Key (선택)", type="password", key="vdb_new_key" + ) vdb_prefix = st.text_input( "컬렉션 접두사(선택)", key="vdb_prefix", placeholder="예: app1_" ) @@ -312,6 +391,7 @@ def render_data_source_section(config: Config | None = None) -> None: name=vdb_name, vtype=vdb_type, location=vdb_location, + api_key=(vdb_api_key or None), collection_prefix=(vdb_prefix or None), note=(vdb_note or None), ) diff --git a/interface/app_pages/sidebar_components/data_source_selector.py b/interface/app_pages/sidebar_components/data_source_selector.py index b32cf3f..fe54540 100644 --- a/interface/app_pages/sidebar_components/data_source_selector.py +++ b/interface/app_pages/sidebar_components/data_source_selector.py @@ -39,8 +39,18 @@ def render_sidebar_data_source_selector(config=None) -> None: return try: update_datahub_server(config, selected.url) - # DataHub 선택 시, FAISS 경로가 정의되어 있으면 기본 VectorDB 로케이션으로도 반영 - if selected.faiss_path: + # DataHub 선택 시, VectorDB 설정이 정의되어 있으면 기본 VectorDB 로케이션으로도 반영 + if selected.vectordb_location: + try: + update_vectordb_settings( + config, + vectordb_type=selected.vectordb_type or "faiss", + vectordb_location=selected.vectordb_location, + ) + except Exception as e: + st.sidebar.warning(f"VectorDB 설정 적용 경고: {e}") + elif selected.faiss_path: + # Backward compatibility try: update_vectordb_settings( config, diff --git a/interface/core/config/models.py b/interface/core/config/models.py index 9ec02af..7be3d14 100644 --- a/interface/core/config/models.py +++ b/interface/core/config/models.py @@ -17,14 +17,18 @@ class DataHubSource: name: str url: str faiss_path: Optional[str] = None + vectordb_type: str = "faiss" + vectordb_location: Optional[str] = None + vectordb_api_key: Optional[str] = None note: Optional[str] = None @dataclass class VectorDBSource: name: str - type: str # 'faiss' | 'pgvector' + type: str # 'faiss' | 'pgvector' | 'qdrant' location: str + api_key: Optional[str] = None collection_prefix: Optional[str] = None note: Optional[str] = None diff --git a/interface/core/config/persist.py b/interface/core/config/persist.py index 81ba144..ee077ba 100644 --- a/interface/core/config/persist.py +++ b/interface/core/config/persist.py @@ -63,11 +63,22 @@ def _parse_datahub_list(items: List[Dict[str, Any]]) -> List[DataHubSource]: name = str(item.get("name", "")).strip() url = str(item.get("url", "")).strip() faiss_path = item.get("faiss_path") + vectordb_type = item.get("vectordb_type", "faiss") + vectordb_location = item.get("vectordb_location") + vectordb_api_key = item.get("vectordb_api_key") note = item.get("note") if not name or not url: continue parsed.append( - DataHubSource(name=name, url=url, faiss_path=faiss_path, note=note) + DataHubSource( + name=name, + url=url, + faiss_path=faiss_path, + vectordb_type=vectordb_type, + vectordb_location=vectordb_location, + vectordb_api_key=vectordb_api_key, + note=note, + ) ) return parsed @@ -81,12 +92,14 @@ def _parse_vectordb_list(items: List[Dict[str, Any]]) -> List[VectorDBSource]: if not name or not vtype or not location: continue collection_prefix = item.get("collection_prefix") + api_key = item.get("api_key") note = item.get("note") parsed.append( VectorDBSource( name=name, type=vtype, location=location, + api_key=api_key, collection_prefix=collection_prefix, note=note, ) diff --git a/interface/core/config/registry_data_sources.py b/interface/core/config/registry_data_sources.py index 8e9b646..7f393f8 100644 --- a/interface/core/config/registry_data_sources.py +++ b/interface/core/config/registry_data_sources.py @@ -41,25 +41,53 @@ def _save_registry(registry: DataSourcesRegistry) -> None: def add_datahub_source( - *, name: str, url: str, faiss_path: Optional[str] = None, note: Optional[str] = None + *, + name: str, + url: str, + faiss_path: Optional[str] = None, + vectordb_type: str = "faiss", + vectordb_location: Optional[str] = None, + vectordb_api_key: Optional[str] = None, + note: Optional[str] = None, ) -> None: registry = get_data_sources_registry() if any(s.name == name for s in registry.datahub): raise ValueError(f"이미 존재하는 DataHub 이름입니다: {name}") registry.datahub.append( - DataHubSource(name=name, url=url, faiss_path=faiss_path, note=note) + DataHubSource( + name=name, + url=url, + faiss_path=faiss_path, + vectordb_type=vectordb_type, + vectordb_location=vectordb_location, + vectordb_api_key=vectordb_api_key, + note=note, + ) ) _save_registry(registry) def update_datahub_source( - *, name: str, url: str, faiss_path: Optional[str], note: Optional[str] + *, + name: str, + url: str, + faiss_path: Optional[str], + vectordb_type: str = "faiss", + vectordb_location: Optional[str] = None, + vectordb_api_key: Optional[str] = None, + note: Optional[str], ) -> None: registry = get_data_sources_registry() for idx, s in enumerate(registry.datahub): if s.name == name: registry.datahub[idx] = DataHubSource( - name=name, url=url, faiss_path=faiss_path, note=note + name=name, + url=url, + faiss_path=faiss_path, + vectordb_type=vectordb_type, + vectordb_location=vectordb_location, + vectordb_api_key=vectordb_api_key, + note=note, ) _save_registry(registry) return @@ -77,12 +105,15 @@ def add_vectordb_source( name: str, vtype: str, location: str, + api_key: Optional[str] = None, collection_prefix: Optional[str] = None, note: Optional[str] = None, ) -> None: vtype = (vtype or "").lower() - if vtype not in ("faiss", "pgvector"): - raise ValueError("VectorDB 타입은 'faiss' 또는 'pgvector'여야 합니다") + if vtype not in ("faiss", "pgvector", "qdrant"): + raise ValueError( + "VectorDB 타입은 'faiss', 'pgvector', 'qdrant' 중 하나여야 합니다" + ) registry = get_data_sources_registry() if any(s.name == name for s in registry.vectordb): raise ValueError(f"이미 존재하는 VectorDB 이름입니다: {name}") @@ -91,6 +122,7 @@ def add_vectordb_source( name=name, type=vtype, location=location, + api_key=api_key, collection_prefix=collection_prefix, note=note, ) @@ -103,12 +135,15 @@ def update_vectordb_source( name: str, vtype: str, location: str, + api_key: Optional[str] = None, collection_prefix: Optional[str], note: Optional[str], ) -> None: vtype = (vtype or "").lower() - if vtype not in ("faiss", "pgvector"): - raise ValueError("VectorDB 타입은 'faiss' 또는 'pgvector'여야 합니다") + if vtype not in ("faiss", "pgvector", "qdrant"): + raise ValueError( + "VectorDB 타입은 'faiss', 'pgvector', 'qdrant' 중 하나여야 합니다" + ) registry = get_data_sources_registry() for idx, s in enumerate(registry.vectordb): if s.name == name: @@ -116,6 +151,7 @@ def update_vectordb_source( name=name, type=vtype, location=location, + api_key=api_key, collection_prefix=collection_prefix, note=note, ) diff --git a/interface/core/config/settings.py b/interface/core/config/settings.py index 9b4eeb2..b0a02ee 100644 --- a/interface/core/config/settings.py +++ b/interface/core/config/settings.py @@ -153,12 +153,13 @@ def update_vectordb_settings( """Validate and update VectorDB settings into env and session. Basic validation rules follow CLI's behavior: - - vectordb_type must be 'faiss' or 'pgvector' + - vectordb_type must be 'faiss' or 'pgvector' or 'qdrant' - if type == 'faiss' and location provided: must be an existing directory - if type == 'pgvector' and location provided: must start with 'postgresql://' + - if type == 'qdrant' and location provided: must start with 'http://' """ vtype = (vectordb_type or "").lower() - if vtype not in ("faiss", "pgvector"): + if vtype not in ("faiss", "pgvector", "qdrant"): raise ValueError(f"지원하지 않는 VectorDB 타입: {vectordb_type}") vloc = vectordb_location or "" diff --git a/utils/llm/README.md b/utils/llm/README.md index 993ee98..260e8f7 100644 --- a/utils/llm/README.md +++ b/utils/llm/README.md @@ -10,7 +10,10 @@ utils/llm/ ├── chains.py # LangChain 체인 생성 모듈 ├── retrieval.py # 테이블 메타 검색 및 재순위화 ├── llm_response_parser.py # LLM 응답에서 SQL 블록 추출 -├── chatbot.py # LangGraph ChatBot 구현 +├── chatbot/ # LangGraph ChatBot 패키지 +│ ├── __init__.py +│ ├── core.py # ChatBot 핵심 로직 +│ └── README.md # [상세 문서](./chatbot/README.md) ├── core/ # LLM/Embedding 팩토리 모듈 │ ├── __init__.py │ ├── factory.py # LLM 및 Embedding 모델 생성 팩토리 @@ -143,7 +146,7 @@ utils/llm/ **목적**: DataHub 메타데이터 수집 및 LangGraph ChatBot용 Tool 함수 제공 **주요 기능:** -- `get_info_from_db()`: DataHub에서 테이블 메타데이터를 LangChain Document로 수집 +- `get_table_schema()`: DataHub에서 테이블 메타데이터를 dictionary 형태로 반환 - `get_metadata_from_db()`: 전체 메타데이터 딕셔너리 반환 - `search_database_tables()`: 벡터 검색 기반 테이블 정보 검색 Tool - `get_glossary_terms()`: 용어집 정보 조회 Tool @@ -152,7 +155,7 @@ utils/llm/ **사용처:** - `utils/llm/vectordb/faiss_db.py`: 벡터DB 초기화 시 메타데이터 수집 - `utils/llm/vectordb/pgvector_db.py`: 벡터DB 초기화 시 메타데이터 수집 -- `utils/llm/chatbot.py`: ChatBot 도구로 사용 +- `utils/llm/chatbot/`: ChatBot 도구로 사용 **상세 문서**: [tools/README.md](./tools/README.md) @@ -298,7 +301,7 @@ engine/query_executor.py │ └── utils/llm/retrieval.py │ └── utils/llm/vectordb/get_vector_db() │ ├── utils/llm/core/get_embeddings() -│ └── utils/llm/tools/get_info_from_db() +│ └── utils/llm/tools/get_table_schema() └── utils/llm/llm_response_parser.py ``` @@ -316,8 +319,8 @@ engine/query_executor.py - `retrieval.py` → `vectordb/get_vector_db()` 사용 **vectordb 모듈:** -- `vectordb/faiss_db.py` → `core/get_embeddings()`, `tools/get_info_from_db()` 사용 -- `vectordb/pgvector_db.py` → `core/get_embeddings()`, `tools/get_info_from_db()` 사용 +- `vectordb/faiss_db.py` → `core/get_embeddings()`, `tools/get_table_schema()` 사용 +- `vectordb/pgvector_db.py` → `core/get_embeddings()`, `tools/get_table_schema()` 사용 **tools 모듈:** - `tools/datahub.py` → DataHub 메타데이터 수집 diff --git a/utils/llm/chatbot.py b/utils/llm/chatbot.py deleted file mode 100644 index 51bcab0..0000000 --- a/utils/llm/chatbot.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -LangGraph 기반 ChatBot 모델 -OpenAI의 ChatGPT 모델을 사용하여 대화 기록을 유지하는 챗봇 구현 -""" - -from typing import Annotated, Sequence, TypedDict - -from langchain_core.messages import BaseMessage, SystemMessage -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from langgraph.graph import START, StateGraph -from langgraph.graph.message import add_messages -from langgraph.prebuilt import ToolNode - -from utils.llm.tools import ( - search_database_tables, - get_glossary_terms, - get_query_examples, -) - - -class ChatBotState(TypedDict): - """ - 챗봇 상태 - 사용자 질문을 SQL로 변환 가능한 구체적인 질문으로 만들어가는 과정 추적 - """ - - # 기본 메시지 (MessagesState와 동일) - messages: Annotated[Sequence[BaseMessage], add_messages] - - # datahub 서버 정보 - gms_server: str - - -class ChatBot: - """ - LangGraph를 사용한 대화형 챗봇 클래스 - OpenAI API를 통해 다양한 GPT 모델을 사용할 수 있으며, - MemorySaver를 통해 대화 기록을 관리합니다. - """ - - def __init__( - self, - openai_api_key: str, - model_name: str = "gpt-4o-mini", - gms_server: str = "http://localhost:8080", - ): - """ - ChatBot 인스턴스 초기화 - - Args: - openai_api_key: OpenAI API 키 - model_name: 사용할 모델명 (기본값: gpt-4o-mini) - gms_server: DataHub GMS 서버 URL (기본값: http://localhost:8080) - """ - self.openai_api_key = openai_api_key - self.model_name = model_name - self.gms_server = gms_server - # SQL 생성을 위한 데이터베이스 메타데이터 조회 도구 - self.tools = [ - search_database_tables, # 데이터베이스 테이블 정보 검색 - get_glossary_terms, # 용어집 조회 도구 - get_query_examples, # 쿼리 예제 조회 도구 - ] - self.llm = self._setup_llm() # LLM 인스턴스 설정 - self.app = self._setup_workflow() # LangGraph 워크플로우 설정 - - def _setup_llm(self): - """ - OpenAI ChatGPT LLM 인스턴스 생성 - Tool을 바인딩하여 LLM이 필요시 tool을 호출할 수 있도록 설정합니다. - - Returns: - ChatOpenAI: Tool이 바인딩된 LLM 인스턴스 - """ - llm = ChatOpenAI( - temperature=0.0, # SQL 생성은 정확성이 중요하므로 0으로 설정 - openai_api_key=self.openai_api_key, - model_name=self.model_name, - ) - # Tool을 LLM에 바인딩하여 함수 호출 기능 활성화 - return llm.bind_tools(self.tools) - - def _setup_workflow(self): - """ - LangGraph 워크플로우 설정 - 대화 기록을 관리하고 LLM과 통신하는 그래프 구조를 생성합니다. - Tool 호출 기능을 포함하여 LLM이 필요시 도구를 사용할 수 있도록 합니다. - - Returns: - CompiledGraph: 컴파일된 LangGraph 워크플로우 - """ - # ChatBotState를 사용하는 StateGraph 생성 - workflow = StateGraph(state_schema=ChatBotState) - - def call_model(state: ChatBotState): - """ - LLM 모델을 호출하는 노드 함수 - LLM이 응답을 생성하거나 tool 호출을 결정합니다. - - Args: - state: 현재 메시지 상태 - - Returns: - dict: LLM 응답이 포함된 상태 업데이트 - """ - # 질문 구체화 전문 어시스턴트 시스템 메시지 - sys_msg = SystemMessage( - content="""# 역할 -당신은 사용자의 모호한 질문을 명확하고 구체적인 질문으로 만드는 전문 AI 어시스턴트입니다. - -# 주요 임무 -- 사용자의 자연어 질문을 이해하고 의도를 정확히 파악합니다 -- 대화를 통해 날짜, 지표, 필터 조건 등 구체적인 정보를 수집합니다 -- 단계별로 사용자와 대화하며 명확하고 구체적인 질문으로 다듬어갑니다 - -# 작업 프로세스 -1. 사용자의 최초 질문에서 의도 파악 -2. 질문을 명확히 하기 위해 필요한 정보 식별 (날짜, 지표, 대상, 조건 등) -3. **도구를 적극 활용하여 데이터베이스 스키마, 테이블 정보, 용어집 등을 확인** -4. 부족한 정보를 자연스럽게 질문하여 수집 -5. 수집된 정보를 바탕으로 질문을 점진적으로 구체화 -6. 충분히 구체화되면 최종 질문 확정 - -# 도구 사용 가이드 -- **search_database_tables**: 사용자와의 대화를 데이터와 연관짓기 위해 관련 테이블을 적극적으로 확인할 수 있는 도구 -- **get_glossary_terms**: 사용자가 사용한 용어의 정확한 의미를 확인할 때 사용가능한 도구 -- **get_query_examples**: 조직내 저장된 쿼리 예제를 조회하여 참고할 수 있는 도구 -- 답변하기 전에 최대한 많은 도구를 적극 활용하여 정보를 수집하세요 -- 불확실한 정보가 있다면 추측하지 말고 도구를 사용하여 확인하세요 - -# 예시 -- 모호한 질문: "KPI가 궁금해" -- 대화 후 구체화: "2025-01-02 날짜의 신규 유저가 발생시킨 매출이 궁금해" - -# 주의사항 -- 항상 친절하고 명확하게 대화합니다 -- 이전 대화 맥락을 고려하여 일관성 있게 응답합니다 -- 한 번에 너무 많은 것을 물어보지 않고 단계적으로 진행합니다 -- **중요: 사용자가 말한 내용이 충분히 구체화되지 않거나 의도가 명확히 파악되지 않을 경우, 추측하지 말고 모든 도구(get_glossary_terms, get_query_examples, search_database_tables)를 적극적으로 사용하여 맥락을 파악하세요** -- 도구를 통해 수집한 정보를 바탕으로 사용자에게 구체적인 방향성과 옵션을 제안하세요 -- 불확실한 정보가 있다면 추측하지 말고 도구를 사용하여 확인한 후 답변하세요 - ---- -다음은 사용자와의 대화입니다:""" - ) - # 시스템 메시지를 대화의 맨 앞에 추가 - messages = [sys_msg] + state["messages"] - response = self.llm.invoke(messages) - return {"messages": response} - - def route_model_output(state: ChatBotState): - """ - LLM 출력에 따라 다음 노드를 결정하는 라우팅 함수 - Tool 호출이 필요한 경우 'tools' 노드로, 아니면 대화를 종료합니다. - - Args: - state: 현재 메시지 상태 - - Returns: - str: 다음에 실행할 노드 이름 ('tools' 또는 '__end__') - """ - messages = state["messages"] - last_message = messages[-1] - # LLM이 tool을 호출하려고 하는 경우 (tool_calls가 있는 경우) - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - return "tools" - # Tool 호출이 없으면 대화 종료 - return "__end__" - - # 워크플로우 구조 정의 - workflow.add_edge(START, "model") # 시작 -> model 노드 - workflow.add_node("model", call_model) # LLM 호출 노드 - workflow.add_node("tools", ToolNode(self.tools)) # Tool 실행 노드 - - # model 노드 이후 조건부 라우팅 - workflow.add_conditional_edges("model", route_model_output) - # Tool 실행 후 다시 model로 돌아가서 최종 응답 생성 - workflow.add_edge("tools", "model") - - # MemorySaver를 사용하여 대화 기록 저장 기능 추가 - return workflow.compile(checkpointer=MemorySaver()) - - def chat(self, message: str, thread_id: str): - """ - 사용자 메시지에 대한 응답 생성 - - Args: - message: 사용자 입력 메시지 - thread_id: 대화 세션을 구분하는 고유 ID - - Returns: - dict: LLM 응답을 포함한 결과 딕셔너리 - """ - config = {"configurable": {"thread_id": thread_id}} - - # 상태 준비 - input_state = { - "messages": [{"role": "user", "content": message}], - "gms_server": self.gms_server, # DataHub 서버 URL을 상태에 포함 - } - - return self.app.invoke(input_state, config) - - def update_model(self, model_name: str): - """ - 사용 중인 LLM 모델 변경 - 모델 변경 시 LLM 인스턴스와 워크플로우를 재설정합니다. - - Args: - model_name: 변경할 모델명 - """ - self.model_name = model_name - self.llm = self._setup_llm() # 새 모델로 LLM 재설정 - self.app = self._setup_workflow() # 워크플로우 재생성 diff --git a/utils/llm/chatbot/README.md b/utils/llm/chatbot/README.md new file mode 100644 index 0000000..8d52722 --- /dev/null +++ b/utils/llm/chatbot/README.md @@ -0,0 +1,57 @@ +# ChatBot Module + +LangGraph 기반의 대화형 챗봇 모듈입니다. 사용자의 자연어 질문을 이해하고, 적절한 가이드라인과 도구를 선택하여 답변을 생성합니다. + +## 구조 + +``` +utils/llm/chatbot/ +├── __init__.py # 패키지 초기화 및 ChatBot 클래스 export +├── core.py # ChatBot 클래스 및 LangGraph 워크플로우 정의 +├── guidelines.py # 가이드라인 및 툴 래퍼 함수 정의 +├── matcher.py # LLM 기반 가이드라인 매칭 로직 +└── types.py # 데이터 타입 및 구조 정의 +``` + +## 주요 컴포넌트 + +### `ChatBot` (`core.py`) +챗봇의 메인 클래스입니다. LangGraph를 사용하여 대화 흐름을 제어합니다. +- **초기화**: OpenAI API 키, 모델명, GMS 서버 URL 등을 설정합니다. +- **워크플로우**: `select_guidelines` -> `call_model` 순서로 실행됩니다. +- **chat 메서드**: 사용자 메시지를 입력받아 응답을 생성합니다. + +### `LLMGuidelineMatcher` (`matcher.py`) +사용자의 메시지를 분석하여 가장 적절한 가이드라인을 선택하는 클래스입니다. +- LLM을 사용하여 사용자 의도를 파악하고, 미리 정의된 가이드라인 중 하나 이상을 매칭합니다. +- JSON Schema를 사용하여 구조화된 출력을 보장합니다. + +### `Guideline` (`types.py`) +챗봇이 따를 규칙과 도구를 정의하는 데이터 클래스입니다. +- `id`: 가이드라인 식별자 +- `description`: 가이드라인 설명 +- `example_phrases`: 매칭에 사용될 예시 문구 +- `tools`: 해당 가이드라인에서 사용할 도구 함수 목록 +- `priority`: 매칭 우선순위 + +### `GUIDELINES` (`guidelines.py`) +기본적으로 제공되는 가이드라인 목록입니다. +- `db_search`: 데이터베이스 테이블 정보 검색 +- `glossary`: 용어집 조회 +- `query_examples`: 쿼리 예제 조회 + +## 사용 예시 + +```python +from utils.llm.chatbot import ChatBot + +# 챗봇 인스턴스 생성 +bot = ChatBot( + openai_api_key="sk-...", + gms_server="http://localhost:8080" +) + +# 대화하기 +response = bot.chat("매출 테이블 정보 알려줘", thread_id="session_1") +print(response["messages"][-1].content) +``` diff --git a/utils/llm/chatbot/__init__.py b/utils/llm/chatbot/__init__.py new file mode 100644 index 0000000..d816b62 --- /dev/null +++ b/utils/llm/chatbot/__init__.py @@ -0,0 +1,7 @@ +""" +ChatBot 패키지 초기화 모듈 +""" + +from utils.llm.chatbot.core import ChatBot + +__all__ = ["ChatBot"] diff --git a/utils/llm/chatbot/core.py b/utils/llm/chatbot/core.py new file mode 100644 index 0000000..829a2ad --- /dev/null +++ b/utils/llm/chatbot/core.py @@ -0,0 +1,186 @@ +""" +ChatBot 핵심 로직 및 LangGraph 워크플로우 정의 +""" + +from typing import Any, Dict, List, Optional + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import END, START, StateGraph +from openai import OpenAI + +from utils.llm.chatbot.guidelines import GUIDELINES +from utils.llm.chatbot.matcher import LLMGuidelineMatcher +from utils.llm.chatbot.types import ChatBotState, Guideline + + +class ChatBot: + """ + LangGraph를 사용한 대화형 챗봇 클래스 (Guideline 기반) + """ + + def __init__( + self, + openai_api_key: str, + model_name: str = "gpt-4o-mini", + gms_server: str = "http://localhost:8080", + guidelines: Optional[List[Guideline]] = None, + ): + """ + ChatBot 인스턴스 초기화 + + Args: + openai_api_key: OpenAI API 키 + model_name: 사용할 모델명 (기본값: gpt-4o-mini) + gms_server: DataHub GMS 서버 URL (기본값: http://localhost:8080) + guidelines: 사용할 가이드라인 목록 (없으면 기본값 사용) + """ + self.openai_api_key = openai_api_key + self.model_name = model_name + self.gms_server = gms_server + self.guidelines = guidelines or GUIDELINES + self.guideline_map = {g.id: g for g in self.guidelines} + + self._client = OpenAI(api_key=openai_api_key) + self.matcher = LLMGuidelineMatcher( + self.guidelines, + model=self.model_name, + client_obj=self._client, + ) + self.llm = ChatOpenAI( + temperature=0.0, + model_name=self.model_name, + openai_api_key=openai_api_key, + ) + self.app = self._setup_workflow() + + def _setup_workflow(self): + """ + LangGraph 워크플로우 설정 + """ + workflow = StateGraph(state_schema=ChatBotState) + + def select_guidelines(state: ChatBotState): + user_text = "" + # 마지막 사용자 메시지 찾기 + for msg in reversed(state["messages"]): + if isinstance(msg, HumanMessage) or ( + hasattr(msg, "type") and msg.type == "human" + ): + user_text = msg.content + break + + # 만약 메시지 객체 구조가 달라서 못 찾았을 경우를 대비해 마지막 메시지 내용 사용 + if not user_text and state["messages"]: + user_text = state["messages"][-1].content + + matched = self.matcher.match(str(user_text)) + + # 컨텍스트 업데이트 (현재 사용자 메시지 추가) + ctx = state.get("context") or {} + ctx["last_user_message"] = user_text + ctx["gms_server"] = self.gms_server + # search_database_tables_tool을 위해 query 키도 설정 + ctx["query"] = user_text + + outs: List[str] = [] + for g in matched: + for tool in g.tools or []: + try: + # tool 실행 + result = tool(ctx) + outs.append(f"[{g.id}] {result}") + except Exception as exc: + outs.append(f"[tool_error] {tool.__name__}: {exc}") + + return { + "selected_ids": [g.id for g in matched], + "tool_outputs": outs, + "context": ctx, + } + + def call_model(state: ChatBotState): + selected_ids = state.get("selected_ids", []) + tool_outs = state.get("tool_outputs", []) + + guideline_lines = [ + f"- {gid}: {self.guideline_map[gid].description}" + for gid in selected_ids + if gid in self.guideline_map + ] or ["- 적용 가능한 가이드라인 없음 (일반 대화 진행)"] + + tool_lines = tool_outs or ["(툴 실행 결과 없음)"] + + sys_msg = SystemMessage( + content=( + "# 역할\n" + "당신은 사용자의 모호한 질문을 명확하고 구체적인 질문으로 만드는 전문 AI 어시스턴트입니다.\n" + "제공된 툴 실행 결과와 가이드라인을 바탕으로 사용자에게 답변하세요.\n\n" + "# 적용된 가이드라인\n" + + "\n".join(guideline_lines) + + "\n\n# 툴 실행 결과 (참고 정보)\n" + + "\n".join(f"- {line}" for line in tool_lines) + + "\n\n# 지침\n" + "- 툴 실행 결과에 유용한 정보가 있다면 적극적으로 인용하여 답변하세요.\n" + "- 정보가 부족하다면 추가 질문을 통해 구체화하세요.\n" + "- 항상 친절하고 명확하게 대화하세요." + ) + ) + + # 시스템 메시지를 대화의 맨 앞에 추가 (또는 매번 컨텍스트로 주입) + # LangGraph에서는 메시지 리스트가 계속 쌓이므로, + # 이번 턴의 시스템 메시지를 앞에 붙여서 invoke 하는 방식 사용 + messages = [sys_msg] + list(state["messages"]) + response = self.llm.invoke(messages) + return {"messages": response} + + workflow.add_node("select", select_guidelines) + workflow.add_node("respond", call_model) + + workflow.add_edge(START, "select") + workflow.add_edge("select", "respond") + workflow.add_edge("respond", END) + + return workflow.compile(checkpointer=MemorySaver()) + + def chat(self, message: str, thread_id: str): + """ + 사용자 메시지에 대한 응답 생성 + + Args: + message: 사용자 입력 메시지 + thread_id: 대화 세션을 구분하는 고유 ID + + Returns: + dict: LLM 응답을 포함한 결과 딕셔너리 + """ + config = {"configurable": {"thread_id": thread_id}} + + # 초기 상태 설정 + # add_messages 리듀서가 있으므로 messages에는 새 메시지만 넣으면 됨 + input_state = { + "messages": [HumanMessage(content=message)], + "context": {"gms_server": self.gms_server}, + "selected_ids": [], + "tool_outputs": [], + } + + return self.app.invoke(input_state, config) + + def update_model(self, model_name: str): + """ + 사용 중인 LLM 모델 변경 + """ + self.model_name = model_name + self._client = OpenAI(api_key=self.openai_api_key) + self.matcher = LLMGuidelineMatcher( + self.guidelines, + model=self.model_name, + client_obj=self._client, + ) + self.llm = ChatOpenAI( + temperature=0.0, + model_name=self.model_name, + openai_api_key=self.openai_api_key, + ) diff --git a/utils/llm/chatbot/guidelines.py b/utils/llm/chatbot/guidelines.py new file mode 100644 index 0000000..f12b48c --- /dev/null +++ b/utils/llm/chatbot/guidelines.py @@ -0,0 +1,67 @@ +""" +ChatBot 가이드라인 및 툴 정의 +""" + +from typing import Any, Dict, List + +from utils.llm.tools import ( + search_database_tables, + get_glossary_terms, + get_query_examples, +) +from utils.llm.chatbot.types import Guideline + + +def search_database_tables_tool(ctx: Dict[str, Any]) -> str: + query = ctx.get("query") or ctx.get("last_user_message", "") + return str(search_database_tables.invoke({"query": query})) + + +def get_glossary_terms_tool(ctx: Dict[str, Any]) -> str: + query = ctx.get("query") or ctx.get("last_user_message", "") + return str(get_glossary_terms.invoke({"query": query})) + + +def get_query_examples_tool(ctx: Dict[str, Any]) -> str: + query = ctx.get("query") or ctx.get("last_user_message", "") + return str(get_query_examples.invoke({"query": query})) + + +GUIDELINES: List[Guideline] = [ + Guideline( + id="db_search", + description="데이터베이스 테이블 정보나 스키마 확인이 필요할 때 사용", + example_phrases=[ + "테이블 정보 알려줘", + "어떤 컬럼이 있어?", + "스키마 보여줘", + "데이터 구조가 궁금해", + ], + tools=[search_database_tables_tool], + priority=10, + ), + Guideline( + id="glossary", + description="용어의 정의나 비즈니스 의미 확인이 필요할 때 사용", + example_phrases=[ + "용어집 보여줘", + "이 단어 뜻이 뭐야?", + "비즈니스 용어 설명해줘", + "KPI 정의가 뭐야?", + ], + tools=[get_glossary_terms_tool], + priority=8, + ), + Guideline( + id="query_examples", + description="쿼리 예제나 SQL 작성 패턴 확인이 필요할 때 사용", + example_phrases=[ + "쿼리 예제 보여줘", + "비슷한 쿼리 있어?", + "SQL 어떻게 짜야해?", + "다른 사람들은 어떻게 쿼리했어?", + ], + tools=[get_query_examples_tool], + priority=9, + ), +] diff --git a/utils/llm/chatbot/matcher.py b/utils/llm/chatbot/matcher.py new file mode 100644 index 0000000..ab1ba92 --- /dev/null +++ b/utils/llm/chatbot/matcher.py @@ -0,0 +1,84 @@ +""" +LLM 기반 가이드라인 매칭 로직 +""" + +import json +from typing import Any, Dict, List, Optional + +from openai import OpenAI + +from utils.llm.chatbot.types import Guideline + + +class LLMGuidelineMatcher: + def __init__( + self, + guidelines: List[Guideline], + model: str, + client_obj: Optional[OpenAI] = None, + ): + self.guidelines = guidelines + self.model = model + self.client = client_obj or OpenAI() + self._id_set = {g.id for g in guidelines} + + def _build_messages(self, message: str) -> List[Dict[str, str]]: + sys = ( + "You are a strict GuidelineMatcher.\n" + "Return ONLY a JSON object that matches the provided JSON schema." + ) + lines = [ + "아래 사용자 메시지에 해당하는 모든 가이드라인 id를 선택하세요.", + f"[USER MESSAGE]\n{message}\n", + "[GUIDELINES]", + ] + for g in self.guidelines: + examples = ", ".join(g.example_phrases) if g.example_phrases else "-" + lines.append( + f"- id: {g.id}\n desc: {g.description}\n examples: {examples}" + ) + return [ + {"role": "system", "content": sys}, + {"role": "user", "content": "\n".join(lines)}, + ] + + def _json_schema_spec(self) -> Dict[str, Any]: + return { + "name": "guideline_matches", + "schema": { + "type": "object", + "properties": { + "matches": { + "type": "array", + "items": {"type": "string", "enum": list(self._id_set)}, + } + }, + "required": ["matches"], + "additionalProperties": False, + }, + "strict": True, + } + + def match(self, message: str) -> List[Guideline]: + ids: List[str] = [] + try: + completion = self.client.chat.completions.create( + model=self.model, + temperature=0, + messages=self._build_messages(message), + response_format={ + "type": "json_schema", + "json_schema": self._json_schema_spec(), + }, + ) + raw = completion.choices[0].message.content + data = json.loads(raw) if isinstance(raw, str) else raw + ids = [i for i in (data.get("matches") or []) if i in self._id_set] + except Exception: + # LLM 호출 실패 시 빈 리스트 반환 (일반 대화로 처리) + ids = [] + + id_to_g = {g.id: g for g in self.guidelines} + selected = [id_to_g[i] for i in ids if i in id_to_g] + selected.sort(key=lambda g: g.priority, reverse=True) + return selected diff --git a/utils/llm/chatbot/types.py b/utils/llm/chatbot/types.py new file mode 100644 index 0000000..ab23240 --- /dev/null +++ b/utils/llm/chatbot/types.py @@ -0,0 +1,33 @@ +""" +ChatBot 관련 데이터 타입 및 구조 정의 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Sequence, TypedDict, Annotated + +from langchain_core.messages import BaseMessage +from langgraph.graph.message import add_messages + +ToolFn = Callable[[Dict[str, Any]], Any] + + +@dataclass +class Guideline: + id: str + description: str + example_phrases: List[str] + tools: Optional[List[ToolFn]] = None + priority: int = 0 + + +class ChatBotState(TypedDict): + """ + 챗봇 상태 + """ + + messages: Annotated[Sequence[BaseMessage], add_messages] + context: Dict[str, Any] + selected_ids: List[str] + tool_outputs: List[str] diff --git a/utils/llm/retrieval.py b/utils/llm/retrieval.py index 0b5d916..5f85715 100644 --- a/utils/llm/retrieval.py +++ b/utils/llm/retrieval.py @@ -6,6 +6,8 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer from utils.llm.vectordb import get_vector_db +from utils.llm.tools.datahub import get_glossary_vector_data, get_query_vector_data +from utils.llm.core import get_embeddings def load_reranker_model(device: str = "cpu"): @@ -102,3 +104,86 @@ def search_tables( } return documents_dict + + +def _prepare_vector_data(data_fetcher, text_fields): + """ + 데이터를 가져와서 임베딩을 생성하는 헬퍼 함수 + """ + points = data_fetcher() + embeddings = get_embeddings() + + for point in points: + payload = point["payload"] + # 텍스트 필드들을 결합하여 임베딩 생성 + text_to_embed = " ".join([str(payload.get(field, "")) for field in text_fields]) + vector = embeddings.embed_query(text_to_embed) + point["vector"] = {"dense": vector} + + return points + + +def search_glossary(query: str, force_update: bool = False, top_n: int = 5) -> list: + """ + 용어집 검색 함수 + """ + collection_name = "lang2sql_glossary" + db = get_vector_db() + + # 데이터 로더 정의 (임베딩 생성 포함) + def data_loader(): + return _prepare_vector_data(get_glossary_vector_data, ["name", "description"]) + + # 컬렉션 초기화 (필요시) + db.initialize_collection_if_empty( + collection_name=collection_name, + force_update=force_update, + data_loader=data_loader, + ) + + # 검색 수행 + embeddings = get_embeddings() + query_vector = embeddings.embed_query(query) + + results = db.search( + collection_name=collection_name, + query_vector=("dense", query_vector), + limit=top_n, + ) + + # 결과 포맷팅 + return [res.payload for res in results] + + +def search_query_examples( + query: str, force_update: bool = False, top_n: int = 5 +) -> list: + """ + 쿼리 예제 검색 함수 + """ + collection_name = "lang2sql_query_example" + db = get_vector_db() + + # 데이터 로더 정의 (임베딩 생성 포함) + def data_loader(): + return _prepare_vector_data(get_query_vector_data, ["name", "description"]) + + # 컬렉션 초기화 (필요시) + db.initialize_collection_if_empty( + collection_name=collection_name, + force_update=force_update, + data_loader=data_loader, + ) + + # 검색 수행 + embeddings = get_embeddings() + query_vector = embeddings.embed_query(query) + + results = db.search( + collection_name=collection_name, + query_vector=("dense", query_vector), + limit=top_n, + ) + + # 결과 포맷팅 + return [res.payload for res in results] diff --git a/utils/llm/tools/README.md b/utils/llm/tools/README.md index fa7153d..9747422 100644 --- a/utils/llm/tools/README.md +++ b/utils/llm/tools/README.md @@ -21,7 +21,7 @@ utils/llm/tools/ **datahub 모듈에서**: - `set_gms_server`: GMS 서버 설정 -- `get_info_from_db`: LangChain Document 리스트로 테이블/컬럼 정보 반환 +- `get_table_schema`: LangChain Document 리스트로 테이블/컬럼 정보 반환 - `get_metadata_from_db`: 전체 메타데이터 딕셔너리 리스트 반환 **chatbot_tool 모듈에서**: @@ -39,7 +39,7 @@ utils/llm/tools/ - 환경변수 `DATAHUB_SERVER`를 설정하고 DatahubMetadataFetcher 초기화 - 유효하지 않은 서버 URL 시 ValueError 발생 -2. **`get_info_from_db(max_workers: int = 8) -> List[Document]`** +2. **`get_table_schema(max_workers: int = 8) -> List[Document]`** - DataHub에서 모든 테이블 메타데이터를 수집하여 LangChain Document 리스트 반환 - 각 Document에는 테이블명, 설명, 컬럼 정보가 포함 - 형식: `"{테이블명}: {설명}\nColumns:\n {컬럼명}: {컬럼설명}"` @@ -157,10 +157,10 @@ utils/llm/tools/ #### 1. DataHub 메타데이터 수집 (vectorDB 초기화) ```python -from utils.llm.tools import get_info_from_db +from utils.llm.tools import get_table_schema # 모든 테이블 메타데이터를 LangChain Document로 수집 -documents = get_info_from_db(max_workers=8) +documents = get_table_schema(max_workers=8) # 각 document는 다음과 같은 형식: # "테이블명: 설명\nColumns:\n 컬럼1: 설명1\n 컬럼2: 설명2" @@ -224,8 +224,8 @@ queries = get_query_examples( **import하는 파일**: - `utils/llm/chatbot.py`: `from utils.llm.tools import search_database_tables, get_glossary_terms, get_query_examples` -- `utils/llm/vectordb/faiss_db.py`: `from utils.llm.tools import get_info_from_db` -- `utils/llm/vectordb/pgvector_db.py`: `from utils.llm.tools import get_info_from_db` +- `utils/llm/vectordb/faiss_db.py`: `from utils.llm.tools import get_table_schema` +- `utils/llm/vectordb/pgvector_db.py`: `from utils.llm.tools import get_table_schema` - `interface/core/config/settings.py`: `from utils.llm.tools import set_gms_server` **내부 의존성**: @@ -258,7 +258,7 @@ queries = get_query_examples( #### 메타데이터 수집 흐름 (벡터DB 초기화 시) -1. `get_info_from_db()` 호출 +1. `get_table_schema()` 호출 2. `_get_fetcher()`로 DatahubMetadataFetcher 인스턴스 생성 3. `parallel_process()`로 병렬 테이블 정보 수집 4. 각 테이블별로 컬럼 정보 추가 수집 diff --git a/utils/llm/tools/__init__.py b/utils/llm/tools/__init__.py index f0dcb9d..db87052 100644 --- a/utils/llm/tools/__init__.py +++ b/utils/llm/tools/__init__.py @@ -1,5 +1,5 @@ from utils.llm.tools.datahub import ( - get_info_from_db, + get_table_schema, get_metadata_from_db, set_gms_server, ) @@ -12,7 +12,7 @@ __all__ = [ "set_gms_server", - "get_info_from_db", + "get_table_schema", "get_metadata_from_db", "search_database_tables", "get_glossary_terms", diff --git a/utils/llm/tools/chatbot_tool.py b/utils/llm/tools/chatbot_tool.py index 9c496f0..85a125b 100644 --- a/utils/llm/tools/chatbot_tool.py +++ b/utils/llm/tools/chatbot_tool.py @@ -6,6 +6,7 @@ from utils.data.datahub_services.base_client import DataHubBaseClient from utils.data.datahub_services.glossary_service import GlossaryService from utils.data.datahub_services.query_service import QueryService +from utils.llm.retrieval import search_glossary, search_query_examples @tool @@ -105,11 +106,11 @@ def _simplify_glossary_data(glossary_data): @tool -def get_glossary_terms(gms_server: str = "http://35.222.65.99:8080") -> list: +def get_glossary_terms(query: str, force_update: bool = False) -> list: """ - DataHub에서 용어집(Glossary) 정보를 조회합니다. + DataHub에서 용어집(Glossary) 정보를 검색합니다. - 이 함수는 DataHub 서버에 연결하여 전체 용어집 데이터를 가져옵니다. + 이 함수는 사용자의 질문과 관련된 용어 정의를 찾기 위해 Vector Search를 수행합니다. 용어집은 비즈니스 용어, 도메인 지식, 데이터 정의 등을 표준화하여 관리하는 곳입니다. **중요**: 사용자의 질문이나 대화에서 다음과 같은 상황이 발생하면 반드시 이 도구를 사용하세요: @@ -120,40 +121,26 @@ def get_glossary_terms(gms_server: str = "http://35.222.65.99:8080") -> list: 5. 표준 정의가 필요한 비즈니스 용어가 나왔을 때 Args: - gms_server (str, optional): DataHub GMS 서버 URL입니다. - 기본값은 "http://35.222.65.99:8080" + query (str): 검색할 용어 또는 관련 질문입니다. + force_update (bool, optional): True일 경우 데이터를 새로고침하여 검색 인덱스를 재생성합니다. 기본값은 False. Returns: - list: 간소화된 용어집 데이터 리스트입니다. - 각 항목은 name, description, children(선택적) 필드를 포함합니다. + list: 검색된 용어집 데이터 리스트입니다. + 각 항목은 name, description 등을 포함합니다. 예시 형태: [ { "name": "가짜연구소", "description": "스터디 단체 가짜연구소를 의미하며...", - "children": [ - { - "name": "빌더", - "description": "가짜연구소 스터디 리더를 지칭..." - } - ] + "type": "term" }, - { - "name": "PII", - "description": "개인 식별 정보...", - "children": [ - { - "name": "identifier", - "description": "개인식별정보중 github 아이디..." - } - ] - } + ... ] Examples: - >>> get_glossary_terms() - [{'name': '가짜연구소', 'description': '...', 'children': [...]}] + >>> get_glossary_terms("PII가 뭐야?") + [{'name': 'PII', 'description': '개인 식별 정보...', ...}] Note: 이 도구는 다음과 같은 경우에 **반드시** 사용하세요: @@ -178,37 +165,22 @@ def get_glossary_terms(gms_server: str = "http://35.222.65.99:8080") -> list: 있는지 확인하고, 있다면 먼저 이 도구를 호출하여 정확한 정의를 파악하세요. """ try: - # DataHub 클라이언트 초기화 - client = DataHubBaseClient(gms_server=gms_server) - - # GlossaryService 초기화 - glossary_service = GlossaryService(client) - - # 전체 용어집 데이터 가져오기 - glossary_data = glossary_service.get_glossary_data() + return search_glossary(query=query, force_update=force_update) - # 간소화된 데이터 반환 - simplified_data = _simplify_glossary_data(glossary_data) - - return simplified_data - - except ValueError as e: - return {"error": True, "message": f"DataHub 서버 연결 실패: {str(e)}"} except Exception as e: return {"error": True, "message": f"용어집 조회 중 오류 발생: {str(e)}"} @tool def get_query_examples( - gms_server: str = "http://35.222.65.99:8080", - start: int = 0, - count: int = 10, - query: str = "*", + query: str, + force_update: bool = False, + count: int = 5, ) -> list: """ - DataHub에서 저장된 쿼리 예제들을 조회합니다. + DataHub에서 저장된 쿼리 예제들을 검색합니다. - 이 함수는 DataHub 서버에 연결하여 저장된 SQL 쿼리 목록을 가져옵니다. + 이 함수는 사용자의 질문과 관련된 SQL 쿼리 예제를 찾기 위해 Vector Search를 수행합니다. 조직에서 실제로 사용되고 검증된 쿼리 패턴을 참고하여 더 정확한 SQL을 생성할 수 있습니다. **중요**: 사용자의 질문이나 대화에서 다음과 같은 상황이 발생하면 반드시 이 도구를 사용하세요: @@ -220,11 +192,9 @@ def get_query_examples( 6. 조직 내에서 검증된 쿼리 작성 방식을 확인해야 할 때 Args: - gms_server (str, optional): DataHub GMS 서버 URL입니다. - 기본값은 "http://35.222.65.99:8080" - start (int, optional): 조회 시작 위치입니다. 기본값은 0 - count (int, optional): 조회할 쿼리 개수입니다. 기본값은 10 - query (str, optional): 검색 쿼리입니다. 기본값은 "*" (모든 쿼리) + query (str): 검색할 쿼리 관련 질문이나 키워드입니다. + force_update (bool, optional): True일 경우 데이터를 새로고침하여 검색 인덱스를 재생성합니다. 기본값은 False. + count (int, optional): 반환할 검색 결과 개수입니다. 기본값은 5. Returns: list: 쿼리 정보 리스트입니다. @@ -237,19 +207,12 @@ def get_query_examples( "description": "각 고객별 주문 건수를 집계하는 쿼리", "statement": "SELECT customer_id, COUNT(*) as order_count FROM orders GROUP BY customer_id" }, - { - "name": "월별 매출 현황", - "description": "월별 총 매출을 계산하는 쿼리", - "statement": "SELECT DATE_TRUNC('month', order_date) as month, SUM(amount) FROM orders GROUP BY month" - } + ... ] Examples: - >>> get_query_examples() - [{'name': '고객별 주문 수 조회', 'description': '...', 'statement': 'SELECT ...'}] - - >>> get_query_examples(count=5) - # 5개의 쿼리 예제만 조회 + >>> get_query_examples("매출 집계 쿼리 보여줘") + [{'name': '월별 매출 현황', 'description': '...', 'statement': 'SELECT ...'}] Note: 이 도구는 다음과 같은 경우에 **반드시** 사용하세요: @@ -280,33 +243,10 @@ def get_query_examples( SQL을 생성하는 데 큰 도움이 됩니다. """ try: - # DataHub 클라이언트 초기화 - client = DataHubBaseClient(gms_server=gms_server) - - # QueryService 초기화 - query_service = QueryService(client) - - # 쿼리 데이터 가져오기 - result = query_service.get_query_data(start=start, count=count, query=query) - - # 오류 체크 - if "error" in result and result["error"]: - return {"error": True, "message": result.get("message")} - - # name, description, statement만 추출하여 리스트 생성 - simplified_queries = [] - for query_item in result.get("queries", []): - simplified_query = { - "name": query_item.get("name"), - "description": query_item.get("description", ""), - "statement": query_item.get("statement", ""), - } - simplified_queries.append(simplified_query) - - return simplified_queries + return search_query_examples( + query=query, force_update=force_update, top_n=count + ) - except ValueError as e: - return {"error": True, "message": f"DataHub 서버 연결 실패: {str(e)}"} except Exception as e: return { "error": True, diff --git a/utils/llm/tools/datahub.py b/utils/llm/tools/datahub.py index 42e564d..7fb87cb 100644 --- a/utils/llm/tools/datahub.py +++ b/utils/llm/tools/datahub.py @@ -1,12 +1,16 @@ import os import re +import uuid from concurrent.futures import ThreadPoolExecutor from typing import Callable, Dict, Iterable, List, Optional, TypeVar from langchain.schema import Document from tqdm import tqdm +from utils.data.datahub_services.glossary_service import GlossaryService +from utils.data.datahub_services.query_service import QueryService from utils.data.datahub_source import DatahubMetadataFetcher +from utils.data.datahub_services.base_client import DataHubBaseClient T = TypeVar("T") R = TypeVar("R") @@ -76,7 +80,7 @@ def _get_table_info(max_workers: int = 8) -> Dict[str, str]: def _get_column_info( - table_name: str, urn_table_mapping: Dict[str, str], max_workers: int = 8 + table_name: str, urn_table_mapping: Dict[str, str] ) -> List[Dict[str, str]]: target_urn = urn_table_mapping.get(table_name) if not target_urn: @@ -103,7 +107,21 @@ def _extract_dataset_name_from_urn(urn: str) -> Optional[str]: return None -def get_info_from_db(max_workers: int = 8) -> List[Document]: +def get_metadata_from_db() -> List[Dict]: + fetcher = _get_fetcher() + urns = list(fetcher.get_urns()) + + metadata = [] + total = len(urns) + for idx, urn in enumerate(urns, 1): + print(f"[{idx}/{total}] Processing URN: {urn}") + table_metadata = fetcher.build_table_metadata(urn) + metadata.append(table_metadata) + + return metadata + + +def _prepare_datahub_metadata_mappings(max_workers: int = 8): table_info = _get_table_info(max_workers=max_workers) fetcher = _get_fetcher() @@ -118,20 +136,31 @@ def get_info_from_db(max_workers: int = 8) -> List[Document]: if parsed_name: display_name_by_table[original_name] = parsed_name - def process_table_info(item: tuple[str, str, str]) -> str: - original_table_name, table_description, display_table_name = item - # 컬럼 조회는 기존 테이블 이름으로 수행 (urn_table_mapping과 일치) - column_info = _get_column_info( - original_table_name, urn_table_mapping, max_workers=max_workers - ) - column_info_str = "\n".join( - [ - f"{col['column_name']}: {col['column_description']}" - for col in column_info - ] - ) - used_name = display_table_name or original_table_name - return f"{used_name}: {table_description}\nColumns:\n {column_info_str}" + return table_info, urn_table_mapping, display_name_by_table + + +def _format_datahub_table_info( + item: tuple[str, str, str], urn_table_mapping: Dict[str, str] +) -> Dict: + original_table_name, table_description, display_table_name = item + # 컬럼 조회는 기존 테이블 이름으로 수행 (urn_table_mapping과 일치) + column_info = _get_column_info(original_table_name, urn_table_mapping) + + columns = {col["column_name"]: col["column_description"] for col in column_info} + + used_name = display_table_name or original_table_name + return { + used_name: { + "table_description": table_description, + "columns": columns, + } + } + + +def get_table_schema(max_workers: int = 8) -> List[Dict]: + table_info, urn_table_mapping, display_name_by_table = ( + _prepare_datahub_metadata_mappings(max_workers) + ) # 표시용 이름을 세 번째 파라미터로 함께 전달 items_with_display = [ @@ -143,25 +172,116 @@ def process_table_info(item: tuple[str, str, str]) -> str: for name, desc in table_info.items() ] - table_info_str_list = parallel_process( + # parallel_process에 전달할 함수 래핑 + def process_fn(item): + return _format_datahub_table_info(item, urn_table_mapping) + + table_info_list = parallel_process( items_with_display, - process_table_info, + process_fn, max_workers=max_workers, desc="컬럼 정보 수집 중", ) - return [Document(page_content=info) for info in table_info_str_list] - + return table_info_list -def get_metadata_from_db() -> List[Dict]: - fetcher = _get_fetcher() - urns = list(fetcher.get_urns()) - - metadata = [] - total = len(urns) - for idx, urn in enumerate(urns, 1): - print(f"[{idx}/{total}] Processing URN: {urn}") - table_metadata = fetcher.build_table_metadata(urn) - metadata.append(table_metadata) - return metadata +def get_glossary_vector_data() -> List[Dict]: + """ + Vector Search를 위한 용어집 데이터를 조회하고 포맷팅합니다. + """ + gms_server = os.getenv("DATAHUB_SERVER", "http://35.222.65.99:8080") + client = DataHubBaseClient(gms_server=gms_server) + glossary_service = GlossaryService(client) + + glossary_data = glossary_service.get_glossary_data() + + points = [] + if "error" in glossary_data: + print(f"Error fetching glossary data: {glossary_data.get('message')}") + return points + + # Flatten the glossary structure + def process_node(node): + # Current node + name = node.get("name") + description = node.get("description", "") + + # Create point for the node itself if it has meaningful content + if name: + # Generate deterministic UUID based on name + point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, name)) + points.append( + { + "id": point_id, + "vector": {}, # Placeholder, will be embedded later + "payload": { + "name": name, + "description": description, + "type": "term", # or node + }, + } + ) + + # Process children + if "details" in node and "children" in node["details"]: + for child in node["details"]["children"]: + child_name = child.get("name") + child_desc = child.get("description", "") + if child_name: + child_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, child_name)) + points.append( + { + "id": child_id, + "vector": {}, + "payload": { + "name": child_name, + "description": child_desc, + "type": "term", + }, + } + ) + + for node in glossary_data.get("nodes", []): + process_node(node) + + return points + + +def get_query_vector_data() -> List[Dict]: + """ + Vector Search를 위한 쿼리 예제 데이터를 조회하고 포맷팅합니다. + """ + gms_server = os.getenv("DATAHUB_SERVER", "http://35.222.65.99:8080") + client = DataHubBaseClient(gms_server=gms_server) + query_service = QueryService(client) + + # Fetch all queries (adjust count as needed) + query_data = query_service.get_query_data(count=1000) + + points = [] + if "error" in query_data: + print(f"Error fetching query data: {query_data.get('message')}") + return points + + for query in query_data.get("queries", []): + name = query.get("name") + description = query.get("description", "") + statement = query.get("statement", "") + + if name and statement: + # Generate deterministic UUID based on name + point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, name)) + points.append( + { + "id": point_id, + "vector": {}, + "payload": { + "name": name, + "description": description, + "statement": statement, + }, + } + ) + + return points diff --git a/utils/llm/vectordb/README.md b/utils/llm/vectordb/README.md index 356ffcd..2a31981 100644 --- a/utils/llm/vectordb/README.md +++ b/utils/llm/vectordb/README.md @@ -52,13 +52,13 @@ utils/llm/vectordb/ - `vectordb_path`: 저장 경로 (기본: `dev/table_info_db`) - 동작 방식: - 기존 DB가 있으면 `FAISS.load_local()`로 로드 - - 없으면 `get_info_from_db()`로 문서 수집 후 `FAISS.from_documents()` 생성 및 저장 + - 없으면 `get_table_schema()`로 문서 수집 후 `FAISS.from_documents()` 생성 및 저장 - 반환: FAISS 벡터스토어 인스턴스 **의존성**: - `langchain_community.vectorstores.FAISS`: LangChain FAISS 래퍼 - `utils.llm.core.get_embeddings`: 임베딩 모델 로드 -- `utils.llm.tools.get_info_from_db`: DataHub에서 테이블 메타데이터 수집 +- `utils.llm.tools.get_table_schema`: DataHub에서 테이블 메타데이터 수집 **특징**: - 로컬 디스크에 저장되어 네트워크 연결 불필요 @@ -84,7 +84,7 @@ utils/llm/vectordb/ - `PGVECTOR_COLLECTION`: "lang2sql_table_info_db" - 동작 방식: - 기존 컬렉션이 있고 비어있지 않으면 로드 - - 없거나 비어있으면 `get_info_from_db()`로 문서 수집 후 `PGVector.from_documents()` 생성 + - 없거나 비어있으면 `get_table_schema()`로 문서 수집 후 `PGVector.from_documents()` 생성 - 반환: PGVector 벡터스토어 인스턴스 2. **`_check_collection_exists(connection_string, collection_name)`** @@ -96,7 +96,7 @@ utils/llm/vectordb/ - `langchain_postgres.vectorstores.PGVector`: LangChain pgvector 래퍼 - `psycopg2`: PostgreSQL 연결 - `utils.llm.core.get_embeddings`: 임베딩 모델 로드 -- `utils.llm.tools.get_info_from_db`: DataHub에서 테이블 메타데이터 수집 +- `utils.llm.tools.get_table_schema`: DataHub에서 테이블 메타데이터 수집 **특징**: - PostgreSQL 데이터베이스에 저장되어 다중 서버 환경에 적합 @@ -181,7 +181,7 @@ export PGVECTOR_COLLECTION=lang2sql_table_info_db **내부 의존성**: - `utils/llm/core/factory.py`: `get_embeddings()` - 임베딩 모델 로드 -- `utils/llm/tools/datahub.py`: `get_info_from_db()` - DataHub 메타데이터 수집 +- `utils/llm/tools/datahub.py`: `get_table_schema()` - DataHub 메타데이터 수집 **외부 의존성**: - `langchain_community.vectorstores.FAISS`: FAISS 벡터스토어 diff --git a/utils/llm/vectordb/factory.py b/utils/llm/vectordb/factory.py index 942a443..68a13b1 100644 --- a/utils/llm/vectordb/factory.py +++ b/utils/llm/vectordb/factory.py @@ -7,6 +7,20 @@ from utils.llm.vectordb.faiss_db import get_faiss_vector_db from utils.llm.vectordb.pgvector_db import get_pgvector_db +from utils.llm.vectordb.qdrant_db import QdrantDB + + +def get_qdrant_vector_db(url: Optional[str] = None, api_key: Optional[str] = None): + """Qdrant VectorDB 인스턴스를 반환하고 초기화합니다.""" + if url is None: + url = os.getenv("QDRANT_URL", "http://localhost:6333") + + if api_key is None: + api_key = os.getenv("QDRANT_API_KEY") + + db = QdrantDB(url=url, api_key=api_key) + db.initialize_collection_if_empty() + return db def get_vector_db( @@ -16,11 +30,11 @@ def get_vector_db( VectorDB 타입과 위치에 따라 적절한 VectorDB 인스턴스를 반환합니다. Args: - vectordb_type: VectorDB 타입 ("faiss" 또는 "pgvector"). None인 경우 환경 변수에서 읽음. + vectordb_type: VectorDB 타입 ("faiss", "pgvector", "qdrant"). None인 경우 환경 변수에서 읽음. vectordb_location: VectorDB 위치 (FAISS: 디렉토리 경로, pgvector: 연결 문자열). None인 경우 환경 변수에서 읽음. Returns: - VectorDB 인스턴스 (FAISS 또는 PGVector) + VectorDB 인스턴스 (FAISS, PGVector, 또는 Qdrant) """ if vectordb_type is None: vectordb_type = os.getenv("VECTORDB_TYPE", "faiss").lower() @@ -32,7 +46,9 @@ def get_vector_db( return get_faiss_vector_db(vectordb_location) elif vectordb_type == "pgvector": return get_pgvector_db(vectordb_location) + elif vectordb_type == "qdrant": + return get_qdrant_vector_db(url=vectordb_location) else: raise ValueError( - f"지원하지 않는 VectorDB 타입: {vectordb_type}. 'faiss' 또는 'pgvector'를 사용하세요." + f"지원하지 않는 VectorDB 타입: {vectordb_type}. 'faiss', 'pgvector', 또는 'qdrant'를 사용하세요." ) diff --git a/utils/llm/vectordb/faiss_db.py b/utils/llm/vectordb/faiss_db.py index d4754a5..0b48d01 100644 --- a/utils/llm/vectordb/faiss_db.py +++ b/utils/llm/vectordb/faiss_db.py @@ -6,9 +6,10 @@ from typing import Optional from langchain_community.vectorstores import FAISS +from langchain.schema import Document from utils.llm.core import get_embeddings -from utils.llm.tools import get_info_from_db +from utils.llm.tools import get_table_schema def get_faiss_vector_db(vectordb_path: Optional[str] = None): @@ -26,7 +27,15 @@ def get_faiss_vector_db(vectordb_path: Optional[str] = None): allow_dangerous_deserialization=True, ) except: - documents = get_info_from_db() + raw_data = get_table_schema() + documents = [] + for item in raw_data: + for table_name, table_info in item.items(): + column_info_str = "\n".join( + [f"{k}: {v}" for k, v in table_info["columns"].items()] + ) + page_content = f"{table_name}: {table_info['table_description']}\nColumns:\n {column_info_str}" + documents.append(Document(page_content=page_content)) db = FAISS.from_documents(documents, embeddings) db.save_local(vectordb_path) print(f"VectorDB를 새로 생성했습니다: {vectordb_path}") diff --git a/utils/llm/vectordb/pgvector_db.py b/utils/llm/vectordb/pgvector_db.py index d03f034..edba041 100644 --- a/utils/llm/vectordb/pgvector_db.py +++ b/utils/llm/vectordb/pgvector_db.py @@ -7,9 +7,10 @@ import psycopg2 from langchain_postgres.vectorstores import PGVector +from langchain.schema import Document from utils.llm.core import get_embeddings -from utils.llm.tools import get_info_from_db +from utils.llm.tools import get_table_schema def _check_collection_exists(connection_string: str, collection_name: str) -> bool: @@ -71,7 +72,15 @@ def get_pgvector_db( except Exception as e: print(f"exception: {e}") # 컬렉션이 없거나 불러오기에 실패한 경우, 문서를 다시 인덱싱 - documents = get_info_from_db() + raw_data = get_table_schema() + documents = [] + for item in raw_data: + for table_name, table_info in item.items(): + column_info_str = "\n".join( + [f"{k}: {v}" for k, v in table_info["columns"].items()] + ) + page_content = f"{table_name}: {table_info['table_description']}\nColumns:\n {column_info_str}" + documents.append(Document(page_content=page_content)) vector_store = PGVector.from_documents( documents=documents, embedding=embeddings, diff --git a/utils/llm/vectordb/qdrant_db.py b/utils/llm/vectordb/qdrant_db.py new file mode 100644 index 0000000..6369ef8 --- /dev/null +++ b/utils/llm/vectordb/qdrant_db.py @@ -0,0 +1,262 @@ +from qdrant_client import QdrantClient, models +from typing import List, Dict, Any, Optional, Union, Callable +import os +import uuid +from dotenv import load_dotenv + +load_dotenv() + + +class QdrantDB: + def __init__( + self, url: str = "http://localhost:6333", api_key: Optional[str] = None + ): + """ + Qdrant 클라이언트를 초기화합니다. + + Args: + url: Qdrant 서버 URL. + api_key: Qdrant 클라우드 또는 인증된 인스턴스를 위한 API 키. + """ + self.client = QdrantClient(url=url, api_key=api_key) + + def create_collection( + self, collection_name: str, dense_dim: int = 1536, colbert_dim: int = 128 + ): + """ + Dense, ColBERT, Sparse 벡터 구성을 포함한 컬렉션을 생성합니다. + + Args: + collection_name: 생성할 컬렉션의 이름. + dense_dim: Dense 벡터의 차원 (기본값: OpenAI small 모델 기준 1536). + colbert_dim: ColBERT 벡터의 차원 (기본값: 128). + """ + if not self.client.collection_exists(collection_name): + self.client.create_collection( + collection_name=collection_name, + vectors_config={ + "dense": models.VectorParams( + size=dense_dim, distance=models.Distance.COSINE + ), + "colbert": models.VectorParams( + size=colbert_dim, + distance=models.Distance.COSINE, + multivector_config=models.MultiVectorConfig( + comparator=models.MultiVectorComparator.MAX_SIM + ), + hnsw_config=models.HnswConfigDiff(m=0), + ), + }, + sparse_vectors_config={"sparse": models.SparseVectorParams()}, + ) + print(f"Collection '{collection_name}' created.") + else: + print(f"Collection '{collection_name}' already exists.") + + def upsert(self, collection_name: str, points: List[Dict[str, Any]]): + """ + 컬렉션에 포인트들을 업서트(Upsert)합니다. + + Args: + collection_name: 컬렉션 이름. + points: 다음 항목들을 포함하는 딕셔너리 리스트: + - id: 고유 식별자 (int 또는 str) + - vector: 'dense', 'colbert', 'sparse' 키와 해당 벡터 값을 포함하는 딕셔너리. + - payload: 메타데이터를 포함하는 딕셔너리. + """ + point_structs = [] + for point in points: + if "id" not in point or "vector" not in point: + raise ValueError("Each point must contain 'id' and 'vector' keys.") + + point_structs.append( + models.PointStruct( + id=point["id"], + vector=point["vector"], + payload=point.get("payload", {}), + ) + ) + + self.client.upload_points(collection_name=collection_name, points=point_structs) + print( + f"Successfully upserted {len(point_structs)} points to '{collection_name}'." + ) + + def search( + self, + collection_name: str, + query_vector: Union[List[float], tuple], + query_filter: Optional[models.Filter] = None, + limit: int = 10, + with_payload: bool = True, + ) -> List[models.ScoredPoint]: + """ + 특정 컬렉션에서 벡터 검색을 수행합니다. + + Args: + collection_name: 검색할 컬렉션의 이름. + query_vector: 검색에 사용할 쿼리 벡터. 명명된 벡터를 사용하는 경우 ('vector_name', vector) 튜플로 전달해야 합니다. + query_filter: 검색 시 적용할 필터 (선택 사항). + limit: 반환할 결과의 최대 개수 (기본값: 10). + with_payload: 결과에 페이로드를 포함할지 여부 (기본값: True). + + Returns: + 검색 결과 리스트 (ScoredPoint 객체들의 리스트). + """ + return self.client.search( + collection_name=collection_name, + query_vector=query_vector, + query_filter=query_filter, + limit=limit, + with_payload=with_payload, + ) + + def similarity_search( + self, query: str, k: int = 5, collection_name: str = "lang2sql_table_schema" + ) -> List[Any]: + """ + LangChain 호환성을 위한 유사도 검색 메서드. + + Args: + query: 검색 쿼리 문자열. + k: 반환할 결과 개수. + collection_name: 검색할 컬렉션 이름. + + Returns: + LangChain Document 객체 리스트. + """ + from langchain.schema import Document + from utils.llm.core import get_embeddings + + embeddings = get_embeddings() + query_vector = embeddings.embed_query(query) + + results = self.search( + collection_name=collection_name, + query_vector=("dense", query_vector), + limit=k, + ) + + documents = [] + for res in results: + payload = res.payload + # payload를 page_content와 metadata로 변환 + # 여기서는 payload의 모든 내용을 metadata로 넣고, + # 특정 필드를 page_content로 구성하거나 payload 전체를 문자열로 변환 + + # 기존 faiss_db.py의 로직을 참고하여 page_content 구성 + # table_name: table_description + # Columns: + # col1: desc1 + + table_name = payload.get("table_name", "Unknown Table") + table_description = payload.get("table_description", "") + columns = payload.get("columns", {}) + + column_info_str = "\n".join( + [f"{key}: {val}" for key, val in columns.items()] + ) + page_content = ( + f"{table_name}: {table_description}\nColumns:\n {column_info_str}" + ) + + documents.append(Document(page_content=page_content, metadata=payload)) + + return documents + + def as_retriever(self, search_kwargs: Optional[Dict] = None): + """ + LangChain Retriever 인터페이스 호환 메서드. + """ + return self + + def invoke(self, query: str): + """ + Retriever 인터페이스의 invoke 메서드 구현. + """ + # search_kwargs에서 k 값 가져오기 (기본값 5) + # as_retriever 호출 시 저장된 설정이 있다면 그것을 사용해야 하지만, + # 여기서는 간단하게 구현 + return self.similarity_search(query) + + def _get_table_schema_points(self) -> List[Dict[str, Any]]: + """ + 기본 테이블 스키마 정보를 가져와서 포인트 리스트로 변환합니다. + """ + from utils.llm.tools.datahub import get_table_schema + from utils.llm.core import get_embeddings + + raw_data = get_table_schema() + embeddings = get_embeddings() + + points = [] + for idx, item in enumerate(raw_data): + for table_name, table_info in item.items(): + # 벡터 생성을 위한 텍스트 구성 + column_info_str = "\n".join( + [f"{k}: {v}" for k, v in table_info["columns"].items()] + ) + text_to_embed = f"{table_name}: {table_info['table_description']}" + + vector = embeddings.embed_query(text_to_embed) + + # payload 구성 + payload = { + "table_name": table_name, + "table_description": table_info["table_description"], + "columns": table_info["columns"], + } + + # Generate deterministic UUID based on table_name + point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name)) + + points.append( + { + "id": point_id, + "vector": {"dense": vector}, # dense vector only for now + "payload": payload, + } + ) + return points + + def initialize_collection_if_empty( + self, + collection_name: str = "lang2sql_table_schema", + force_update: bool = False, + data_loader: Optional[Callable[[], List[Dict[str, Any]]]] = None, + ): + """ + 컬렉션이 비어있거나 없으면 데이터를 채웁니다. + + Args: + collection_name: 초기화할 컬렉션 이름. + force_update: 데이터가 있어도 강제로 업데이트할지 여부. + data_loader: 데이터를 가져오는 함수. 포인트 리스트(id, vector, payload)를 반환해야 합니다. + None인 경우 기본 테이블 스키마 로더를 사용합니다. + """ + # 컬렉션 존재 여부 확인 및 생성 + if not self.client.collection_exists(collection_name): + self.create_collection(collection_name) + + # 데이터 존재 여부 확인 + if not force_update: + count_result = self.client.count(collection_name=collection_name) + if count_result.count > 0: + print( + f"Collection '{collection_name}' is not empty. Skipping initialization." + ) + return + + print(f"Initializing collection '{collection_name}'...") + + # 데이터 로드 + if data_loader is None: + # 기본 동작: 테이블 스키마 정보 사용 + points = self._get_table_schema_points() + else: + points = data_loader() + + if points: + self.upsert(collection_name, points) + else: + print("No data found to initialize.")