Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 37d4007

Browse files
authored
[NeuralChat] Refactor RAG code and structure (#913)
[NeuralChat] Refactor RAG code and structure Signed-off-by: XuhuiRen <xuhui.ren@intel.com>
1 parent 6e3a514 commit 37d4007

File tree

19 files changed

+819
-496
lines changed

19 files changed

+819
-496
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
from .vectorstore_retriever import VectorStoreRetriever
19+
from .child_parent_retriever import ChildParentRetriever
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
"""The wrapper for Child-Parent retriever based on langchain"""
19+
from langchain.retrievers import MultiVectorRetriever
20+
from langchain_core.vectorstores import VectorStore
21+
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
22+
from enum import Enum
23+
24+
class SearchType(str, Enum):
25+
"""Enumerator of the types of search to perform."""
26+
27+
similarity = "similarity"
28+
"""Similarity search."""
29+
mmr = "mmr"
30+
"""Maximal Marginal Relevance reranking of similarity search."""
31+
32+
33+
class ChildParentRetriever(MultiVectorRetriever):
34+
"""Retrieve from a set of multiple embeddings for the same document."""
35+
36+
vectorstore: VectorStore
37+
"""The underlying vectorstore to use to store small chunks
38+
and their embedding vectors"""
39+
parentstore: VectorStore
40+
41+
def get_context(self, query:str, *, run_manager: CallbackManagerForRetrieverRun):
42+
"""Get documents relevant to a query.
43+
Args:
44+
query: String to find relevant documents for
45+
run_manager: The callbacks handler to use
46+
Returns:
47+
The concatation of the retrieved documents and the link
48+
"""
49+
if self.search_type == SearchType.mmr:
50+
sub_docs = self.vectorstore.max_marginal_relevance_search(
51+
query, **self.search_kwargs
52+
)
53+
else:
54+
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
55+
56+
ids = []
57+
for d in sub_docs:
58+
if d.metadata['doc_id'] not in ids:
59+
ids.append(d.metadata['doc_id'])
60+
retrieved_documents = self.parentstore.get(ids)
61+
context = ''
62+
links = []
63+
for doc in retrieved_documents:
64+
context = context + doc.page_content + " "
65+
links.append(doc.metadata['source'])
66+
return context.strip(), links
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
"""The wrapper for Retriever based on langchain"""
19+
from langchain_core.vectorstores import VectorStoreRetriever as VectorRetriever
20+
21+
22+
class VectorStoreRetriever(VectorRetriever):
23+
"""Retrieve the vector document stores using dense retrieval."""
24+
25+
def __init__(self, document_store=None, **kwargs):
26+
super().__init__(**kwargs)
27+
28+
def get_context(self, query):
29+
context = ''
30+
links = []
31+
retrieved_documents = self.get_relevant_documents(query)
32+
for doc in retrieved_documents:
33+
context = context + doc.page_content + " "
34+
links.append(doc.metadata['source'])
35+
return context.strip(), links
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
from .chroma import Chroma
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
"""The wrapper for Chroma retriever based on langchain"""
19+
from __future__ import annotations
20+
import base64
21+
import logging, os
22+
import uuid
23+
from typing import (
24+
TYPE_CHECKING,
25+
Any,
26+
Callable,
27+
Dict,
28+
Iterable,
29+
List,
30+
Optional,
31+
Tuple,
32+
Type,
33+
)
34+
import numpy as np
35+
from langchain_core.documents import Document
36+
from langchain.vectorstores.chroma import Chroma as Chroma_origin
37+
from langchain_core.embeddings import Embeddings
38+
from langchain_core.utils import xor_args
39+
from langchain_core.vectorstores import VectorStore
40+
import chromadb
41+
import chromadb.config
42+
_DEFAULT_PERSIST_DIR = './output'
43+
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
44+
logging.basicConfig(
45+
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
46+
datefmt="%d-%M-%Y %H:%M:%S",
47+
level=logging.INFO
48+
)
49+
50+
51+
class Chroma(Chroma_origin):
52+
def __init__(self, **kwargs):
53+
super().__init__(**kwargs)
54+
55+
@classmethod
56+
def from_texts(
57+
cls: Type[Chroma],
58+
texts: List[str],
59+
embedding: Optional[Embeddings] = None,
60+
metadatas: Optional[List[dict]] = None,
61+
ids: Optional[List[str]] = None,
62+
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
63+
persist_directory: Optional[str] = None,
64+
client_settings: Optional[chromadb.config.Settings] = None,
65+
client: Optional[chromadb.Client] = None,
66+
collection_metadata: Optional[Dict] = None,
67+
**kwargs: Any,
68+
) -> Chroma:
69+
"""Create a Chroma vectorstore from a raw documents.
70+
71+
If a persist_directory is specified, the collection will be persisted there.
72+
Otherwise, the data will be ephemeral in-memory.
73+
74+
Args:
75+
texts (List[str]): List of texts to add to the collection.
76+
collection_name (str): Name of the collection to create.
77+
persist_directory (Optional[str]): Directory to persist the collection.
78+
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
79+
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
80+
ids (Optional[List[str]]): List of document IDs. Defaults to None.
81+
client_settings (Optional[chromadb.config.Settings]): Chroma client settings
82+
collection_metadata (Optional[Dict]): Collection configurations.
83+
Defaults to None.
84+
85+
Returns:
86+
Chroma: Chroma vectorstore.
87+
"""
88+
chroma_collection = cls(
89+
collection_name=collection_name,
90+
embedding_function=embedding,
91+
persist_directory=persist_directory,
92+
client_settings=client_settings,
93+
client=client,
94+
collection_metadata=collection_metadata,
95+
)
96+
if ids is None:
97+
ids = [str(uuid.uuid1()) for _ in texts]
98+
if hasattr(
99+
chroma_collection._client, "max_batch_size"
100+
): # for Chroma 0.4.10 and above
101+
from chromadb.utils.batch_utils import create_batches
102+
103+
for batch in create_batches(
104+
api=chroma_collection._client,
105+
ids=ids,
106+
metadatas=metadatas,
107+
documents=texts,
108+
):
109+
chroma_collection.add_texts(
110+
texts=batch[3] if batch[3] else [],
111+
metadatas=batch[2] if batch[2] else None,
112+
ids=batch[0],
113+
)
114+
else:
115+
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
116+
return chroma_collection
117+
118+
@classmethod
119+
def from_documents(
120+
cls: Type[Chroma],
121+
documents: List[Document],
122+
sign: str = None,
123+
embedding: Optional[Embeddings] = None,
124+
ids: Optional[List[str]] = None,
125+
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
126+
persist_directory: Optional[str] = _DEFAULT_PERSIST_DIR,
127+
client_settings: Optional[chromadb.config.Settings] = None,
128+
client: Optional[chromadb.Client] = None, # Add this line
129+
collection_metadata: Optional[Dict] = None,
130+
**kwargs: Any,
131+
) -> Chroma:
132+
"""Create a Chroma vectorstore from a list of documents.
133+
134+
If a persist_directory is specified, the collection will be persisted there.
135+
Otherwise, the data will be ephemeral in-memory.
136+
137+
Args:
138+
collection_name (str): Name of the collection to create.
139+
persist_directory (Optional[str]): Directory to persist the collection.
140+
ids (Optional[List[str]]): List of document IDs. Defaults to None.
141+
documents (List[Document]): List of documents to add to the vectorstore.
142+
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
143+
client_settings (Optional[chromadb.config.Settings]): Chroma client settings
144+
collection_metadata (Optional[Dict]): Collection configurations.
145+
Defaults to None.
146+
147+
Returns:
148+
Chroma: Chroma vectorstore.
149+
"""
150+
texts = [doc.page_content for doc in documents]
151+
metadatas = [doc.metadata for doc in documents]
152+
if 'doc_id' in metadatas[0]:
153+
ids = [doc.metadata['doc_id'] for doc in documents]
154+
if sign == 'child':
155+
persist_directory = persist_directory + "_child"
156+
return cls.from_texts(
157+
texts=texts,
158+
embedding=embedding,
159+
metadatas=metadatas,
160+
ids=ids,
161+
collection_name=collection_name,
162+
persist_directory=persist_directory,
163+
client_settings=client_settings,
164+
client=client,
165+
collection_metadata=collection_metadata,
166+
**kwargs,
167+
)
168+
169+
170+
@classmethod
171+
def build(
172+
cls: Type[Chroma],
173+
documents: List[Document],
174+
sign: Optional[str] = None,
175+
embedding: Optional[Embeddings] = None,
176+
metadatas: Optional[List[dict]] = None,
177+
ids: Optional[List[str]] = None,
178+
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
179+
persist_directory: Optional[str] = None,
180+
client_settings: Optional[chromadb.config.Settings] = None,
181+
client: Optional[chromadb.Client] = None,
182+
collection_metadata: Optional[Dict] = None,
183+
**kwargs: Any,
184+
) -> Chroma:
185+
if not persist_directory:
186+
persist_directory = _DEFAULT_PERSIST_DIR
187+
if sign == "child":
188+
persist_directory = persist_directory + "_child"
189+
if os.path.exists(persist_directory):
190+
if bool(os.listdir(persist_directory)):
191+
logging.info("Load the existing database!")
192+
chroma_collection = cls(
193+
collection_name=collection_name,
194+
embedding_function=embedding,
195+
persist_directory=persist_directory,
196+
client_settings=client_settings,
197+
client=client,
198+
collection_metadata=collection_metadata,
199+
**kwargs,
200+
)
201+
return chroma_collection
202+
else:
203+
logging.info("Create a new knowledge base...")
204+
chroma_collection = cls.from_documents(
205+
documents=documents,
206+
embedding=embedding,
207+
ids=ids,
208+
collection_name=collection_name,
209+
persist_directory=persist_directory,
210+
client_settings=client_settings,
211+
client=client,
212+
collection_metadata=collection_metadata,
213+
**kwargs,
214+
)
215+
return chroma_collection
216+
217+
218+
@classmethod
219+
def reload(
220+
cls: Type[Chroma],
221+
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
222+
embedding: Optional[Embeddings] = None,
223+
persist_directory: Optional[str] = None,
224+
client_settings: Optional[chromadb.config.Settings] = None,
225+
collection_metadata: Optional[Dict] = None,
226+
client: Optional[chromadb.Client] = None,
227+
relevance_score_fn: Optional[Callable[[float], float]] = None,
228+
) -> Chroma:
229+
230+
if not persist_directory:
231+
persist_directory = _DEFAULT_PERSIST_DIR
232+
chroma_collection = cls(
233+
collection_name=collection_name,
234+
embedding_function=embedding,
235+
persist_directory=persist_directory,
236+
client_settings=client_settings,
237+
client=client,
238+
collection_metadata=collection_metadata,
239+
)
240+
return chroma_collection
241+
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"doc": "Effective Post-Training Quantization for Large Language Models\nIn this blog, we describe a post-training quantization technique for large language models with enhanced SmoothQuant approach. We also illustrate the usage and demonstrate the accuracy benefits. This method has been integrated into Intel Neural Compressor.\nIn this blog, we demonstrate an enhanced SmoothQuant approach to post-training quantization to improve large language models. This method has been integrated into Intel Neural Compressor, an open-source Python library of popular model compression techniques like quantization, pruning (sparsity), distillation, and neural architecture search. It is compatible with popular frameworks such as TensorFlow, the Intel Extension for TensorFlow, PyTorch, the Intel Extension for PyTorch, ONNX Runtime, and MXNet.", "doc_id": 0}
1+
{"content": "Effective Post-Training Quantization for Large Language Models\nIn this blog, we describe a post-training quantization technique for large language models with enhanced SmoothQuant approach. We also illustrate the usage and demonstrate the accuracy benefits. This method has been integrated into Intel Neural Compressor.\nIn this blog, we demonstrate an enhanced SmoothQuant approach to post-training quantization to improve large language models. This method has been integrated into Intel Neural Compressor, an open-source Python library of popular model compression techniques like quantization, pruning (sparsity), distillation, and neural architecture search. It is compatible with popular frameworks such as TensorFlow, the Intel Extension for TensorFlow, PyTorch, the Intel Extension for PyTorch, ONNX Runtime, and MXNet.", "link": 0}

0 commit comments

Comments
 (0)