Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 254e2aa

Browse files
committed
enable table and table summary for rag pdf
Signed-off-by: Manxin Xu <manxin.xu@intel.com>
1 parent 8fdde06 commit 254e2aa

File tree

9 files changed

+205
-17
lines changed

9 files changed

+205
-17
lines changed

.github/workflows/script/unitTest/run_unit_test_neuralchat.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ function main() {
8787
apt-get install -y libgl1-mesa-glx
8888
apt-get install -y libgl1-mesa-dev
8989
apt-get install libsm6 libxext6 -y
90+
brew install tesseract
91+
brew install poppler
9092
wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
9193
dpkg -i libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
9294
python -m pip install --upgrade --force-reinstall torch==2.2.0
Binary file not shown.

intel_extension_for_transformers/neural_chat/chatbot.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def build_chatbot(config: PipelineConfig=None):
178178
return
179179
from .models.base_model import register_model_adapter
180180
register_model_adapter(adapter)
181-
# register plugin instance in model adaptor
181+
# register plugin instance in model adapter
182+
use_retrieval_plugin = False
182183
if config.plugins:
183184
for plugin_name, plugin_value in config.plugins.items():
184185
enable_plugin = plugin_value.get('enable', False)
@@ -243,6 +244,10 @@ def build_chatbot(config: PipelineConfig=None):
243244
elif plugin_name == "retrieval":
244245
from .pipeline.plugins.retrieval.retrieval_agent import Agent_QA
245246
plugins[plugin_name]['class'] = Agent_QA
247+
use_retrieval_plugin = True
248+
retrieval_plugin_value = plugin_value
249+
retrieval_plugin_value['args']['table_summary_model_name_or_path'] = config.model_name_or_path
250+
continue
246251
elif plugin_name == "cache":
247252
from .pipeline.plugins.caching.cache import ChatCache
248253
plugins[plugin_name]['class'] = ChatCache
@@ -267,15 +272,12 @@ def build_chatbot(config: PipelineConfig=None):
267272
try:
268273
plugins[plugin_name]["instance"] = plugins[plugin_name]['class'](**plugin_value['args'])
269274
except Exception as e:
270-
if "[Rereieval ERROR] Document format not supported" in str(e):
271-
set_latest_error(ErrorCodes.ERROR_RETRIEVAL_DOC_FORMAT_NOT_SUPPORTED)
272-
logger.error("build_chatbot: retrieval plugin init failed")
273-
elif "[SafetyChecker ERROR] Sensitive check file not found" in str(e):
275+
if "[SafetyChecker ERROR] Sensitive check file not found" in str(e):
274276
set_latest_error(ErrorCodes.ERROR_SENSITIVE_CHECK_FILE_NOT_FOUND)
275-
logger.error("build_chatbot: safety checker plugin init failed")
277+
logging.error("build_chatbot: safety checker plugin init failed")
276278
else:
277279
set_latest_error(ErrorCodes.ERROR_GENERIC)
278-
logger.error("build_chatbot: plugin init failed")
280+
logging.error("build_chatbot: plugin init failed")
279281
return
280282
adapter.register_plugin_instance(plugin_name, plugins[plugin_name]["instance"])
281283

@@ -306,6 +308,22 @@ def build_chatbot(config: PipelineConfig=None):
306308
if config.hf_endpoint_url:
307309
return adapter
308310
adapter.load_model(parameters)
311+
312+
if use_retrieval_plugin:
313+
print(f"create retrieval plugin instance...")
314+
print(f"plugin parameters: ", retrieval_plugin_value['args'])
315+
try:
316+
plugins["retrieval"]["instance"] = plugins["retrieval"]['class'](**retrieval_plugin_value['args'])
317+
except Exception as e:
318+
if "[Rereieval ERROR] Document format not supported" in str(e):
319+
set_latest_error(ErrorCodes.ERROR_RETRIEVAL_DOC_FORMAT_NOT_SUPPORTED)
320+
logging.error("build_chatbot: retrieval plugin init failed")
321+
else:
322+
set_latest_error(ErrorCodes.ERROR_GENERIC)
323+
logging.error("build_chatbot: plugin init failed")
324+
return
325+
adapter.register_plugin_instance(plugin_name, plugins[plugin_name]["instance"])
326+
309327
if get_latest_error():
310328
return
311329
else:

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/context_utils.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,13 @@ def uni_pro(text):
4848
return filtered_text
4949

5050

51-
def read_pdf(pdf_path):
51+
def read_pdf(pdf_path, table_summary_mode, table_summary_model_name_or_path):
5252
"""Read the pdf file."""
53+
from unstructured.partition.pdf import partition_pdf
54+
from unstructured.documents.elements import FigureCaption
55+
from intel_extension_for_transformers.neural_chat.models.model_utils import predict
56+
from intel_extension_for_transformers.neural_chat.prompts.prompt import TABLESUMMARY_PROMPT
57+
5358
doc = fitz.open(pdf_path)
5459
reader = easyocr.Reader(['en'])
5560
result =''
@@ -77,7 +82,71 @@ def read_pdf(pdf_path):
7782
else:
7883
pageimg=pageimg+'.'
7984
result=result+pageimg
80-
return result
85+
86+
tables_result = []
87+
def get_relation(table_coords, caption_coords, table_page_number, caption_page_number, threshold=100):
88+
same_page = table_page_number == caption_page_number
89+
x_overlap = (min(table_coords[2][0], caption_coords[2][0]) - max(table_coords[0][0], caption_coords[0][0])) > 0
90+
if table_coords[0][1] - caption_coords[1][1] >= 0:
91+
y_distance = table_coords[0][1] - caption_coords[1][1]
92+
elif caption_coords[0][1] - table_coords[1][1] >= 0:
93+
y_distance = caption_coords[0][1] - table_coords[1][1]
94+
else:
95+
y_distance = 0
96+
y_close = y_distance < threshold
97+
return same_page and x_overlap and y_close, y_distance
98+
99+
raw_pdf_elements = partition_pdf(
100+
filename=pdf_path,
101+
infer_table_structure=True,
102+
)
103+
104+
tables = [el for el in raw_pdf_elements if el.category == "Table"]
105+
for table in tables:
106+
table_coords = table.metadata.coordinates.points
107+
content = table.metadata.text_as_html
108+
table_page_number = table.metadata.page_number
109+
min_distance = float('inf')
110+
table_summary = None
111+
if table_summary_mode == 'title':
112+
for element in raw_pdf_elements:
113+
if isinstance(element, FigureCaption) or element.text.startswith('Tab'):
114+
caption_page_number = element.metadata.page_number
115+
caption_coords = element.metadata.coordinates.points
116+
related, y_distance = get_relation(table_coords, caption_coords, table_page_number, caption_page_number)
117+
if related:
118+
if y_distance < min_distance:
119+
min_distance = y_distance
120+
table_summary = element.text
121+
if table_summary is None:
122+
parent_id = table.metadata.parent_id
123+
for element in raw_pdf_elements:
124+
if element.id == parent_id:
125+
table_summary = element.text
126+
break
127+
elif table_summary_mode == 'llm':
128+
prompt = TABLESUMMARY_PROMPT.format(table_content=content)
129+
params = {}
130+
params["model_name"] = table_summary_model_name_or_path
131+
params["prompt"] = prompt
132+
params["temperature"] = 0.8
133+
params["top_p"] = 0.9
134+
params["top_k"] = 40
135+
params["max_new_tokens"] = 1000
136+
params["num_beams"] = 2
137+
params["num_return_sequences"] = 2
138+
params["use_cache"] = True
139+
table_summary = predict(**params)
140+
table_summary = table_summary[table_summary.find('### Generated Summary:\n'):]
141+
table_summary = re.sub('### Generated Summary:\n', '', table_summary)
142+
elif table_summary_mode == 'none':
143+
table_summary = None
144+
if table_summary is None:
145+
text = f'[Table: {content}]'
146+
else:
147+
text = f'|Table: [Summary: {table_summary}], [Content: {content}]|'
148+
tables_result.append([text, pdf_path])
149+
return result, tables_result
81150

82151

83152
def read_html(html_path):
@@ -214,10 +283,11 @@ def load_structured_data(input, process, max_length, min_length):
214283
content = load_csv(input)
215284
return content
216285

217-
def load_unstructured_data(input):
286+
def load_unstructured_data(input, table_summary_mode, table_summary_model_name_or_path):
218287
"""Load unstructured context."""
288+
tables = None
219289
if input.endswith("pdf"):
220-
text = read_pdf(input)
290+
text, tables = read_pdf(input, table_summary_mode, table_summary_model_name_or_path)
221291
elif input.endswith("docx"):
222292
text = read_docx(input)
223293
elif input.endswith("html"):
@@ -231,7 +301,7 @@ def load_unstructured_data(input):
231301
text = text.replace('\n\n', ' ')
232302
text = uni_pro(text)
233303
text = re.sub(r'\s+', ' ', text)
234-
return text
304+
return text, tables
235305

236306
def get_chuck_data(content, max_length, min_length, input):
237307
"""Process the context to make it maintain a suitable length for the generation."""

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/parser/parser.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def load(self, input, **kwargs):
4949
self.min_chuck_size = kwargs['min_chuck_size']
5050
if 'process' in kwargs:
5151
self.process = kwargs['process']
52-
52+
self.table_summary_model_name_or_path = kwargs['table_summary_model_name_or_path']
53+
self.table_summary_mode = kwargs['table_summary_mode'] if 'table_summary_mode' in kwargs else 'none'
54+
5355
if isinstance(input, str):
5456
if os.path.isfile(input):
5557
data_collection = self.parse_document(input)
@@ -74,11 +76,13 @@ def parse_document(self, input):
7476
"""
7577
if input.endswith("pdf") or input.endswith("docx") or input.endswith("html") \
7678
or input.endswith("txt") or input.endswith("md"):
77-
content = load_unstructured_data(input)
79+
content, tables = load_unstructured_data(input, self.table_summary_mode, self.table_summary_model_name_or_path)
7880
if self.process:
7981
chuck = get_chuck_data(content, self.max_chuck_size, self.min_chuck_size, input)
8082
else:
8183
chuck = [[content.strip(),input]]
84+
if tables is not None:
85+
chuck = chuck + tables
8286
elif input.endswith("jsonl") or input.endswith("xlsx") or input.endswith("csv") or \
8387
input.endswith("json"):
8488
chuck = load_structured_data(input, self.process, \
@@ -118,11 +122,13 @@ def batch_parse_document(self, input):
118122
for filename in filenames:
119123
if filename.endswith("pdf") or filename.endswith("docx") or filename.endswith("html") \
120124
or filename.endswith("txt") or filename.endswith("md"):
121-
content = load_unstructured_data(os.path.join(dirpath, filename))
125+
content, tables = load_unstructured_data(os.path.join(dirpath, filename), self.table_summary_mode, self.table_summary_model_name_or_path)
122126
if self.process:
123127
chuck = get_chuck_data(content, self.max_chuck_size, self.min_chuck_size, input)
124128
else:
125129
chuck = [[content.strip(),input]]
130+
if tables is not None:
131+
chuck = chuck + tables
126132
paragraphs += chuck
127133
elif filename.endswith("jsonl") or filename.endswith("xlsx") or filename.endswith("csv") or \
128134
filename.endswith("json"):

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ qdrant-client
1414
rank_bm25
1515
scikit-learn
1616
sentence-transformers==2.3.1
17-
unstructured
17+
unstructured[all-docs]

intel_extension_for_transformers/neural_chat/prompts/prompt.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,14 @@ def generate_sqlcoder_prompt(qurey, metadata_file):
321321
qurey=qurey, table_metadata_string=table_metadata_string
322322
)
323323
return prompt
324+
325+
TABLESUMMARY_PROMPT = """
326+
Task: Your task is to give a concise summary of the table. \
327+
The summary should cover the overall table structure and all detailed information of the table. \
328+
The table will be given in html format. Summarize the table below.
329+
---
330+
### Table:
331+
{table_content}
332+
---
333+
### Generated Summary:
334+
"""

intel_extension_for_transformers/neural_chat/tests/ci/plugins/retrieval/test_parameters.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,5 +714,86 @@ def test_embedding_precision_fp32(self):
714714
plugins.retrieval.args = {}
715715
plugins.retrieval.enable = False
716716

717+
class TestTableSummaryTitleMode(unittest.TestCase):
718+
def setUp(self):
719+
if os.path.exists("./table_summary_title_mode"):
720+
shutil.rmtree("./table_summary_title_mode", ignore_errors=True)
721+
return super().setUp()
722+
723+
def tearDown(self) -> None:
724+
if os.path.exists("./table_summary_title_mode"):
725+
shutil.rmtree("./table_summary_title_mode", ignore_errors=True)
726+
return super().tearDown()
727+
728+
def test_table_summary_title_mode(self):
729+
plugins.retrieval.args = {}
730+
plugins.retrieval.enable = True
731+
plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
732+
plugins.retrieval.args["persist_directory"] = "./table_summary_title_mode"
733+
plugins.retrieval.args["retrieval_type"] = 'default'
734+
plugins.retrieval.args["table_summary_mode"] = 'title'
735+
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
736+
plugins=plugins)
737+
chatbot = build_chatbot(config)
738+
response = chatbot.predict("What is the number of training tokens for LLaMA2?")
739+
print(response)
740+
self.assertIsNotNone(response)
741+
plugins.retrieval.args = {}
742+
plugins.retrieval.enable = False
743+
744+
class TestTableSummaryLLMMode(unittest.TestCase):
745+
def setUp(self):
746+
if os.path.exists("./table_summary_llm_mode"):
747+
shutil.rmtree("./table_summary_llm_mode", ignore_errors=True)
748+
return super().setUp()
749+
750+
def tearDown(self) -> None:
751+
if os.path.exists("./table_summary_llm_mode"):
752+
shutil.rmtree("./table_summary_llm_mode", ignore_errors=True)
753+
return super().tearDown()
754+
755+
def test_table_summary_llm_mode(self):
756+
plugins.retrieval.args = {}
757+
plugins.retrieval.enable = True
758+
plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
759+
plugins.retrieval.args["persist_directory"] = "./table_summary_llm_mode"
760+
plugins.retrieval.args["retrieval_type"] = 'default'
761+
plugins.retrieval.args["table_summary_mode"] = 'llm'
762+
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
763+
plugins=plugins)
764+
chatbot = build_chatbot(config)
765+
response = chatbot.predict("What is the number of training tokens for LLaMA2?")
766+
print(response)
767+
self.assertIsNotNone(response)
768+
plugins.retrieval.args = {}
769+
plugins.retrieval.enable = False
770+
771+
class TestTableSummaryNoneMode(unittest.TestCase):
772+
def setUp(self):
773+
if os.path.exists("./table_summary_none_mode"):
774+
shutil.rmtree("./table_summary_none_mode", ignore_errors=True)
775+
return super().setUp()
776+
777+
def tearDown(self) -> None:
778+
if os.path.exists("./table_summary_none_mode"):
779+
shutil.rmtree("./table_summary_none_mode", ignore_errors=True)
780+
return super().tearDown()
781+
782+
def test_table_summary_none_mode(self):
783+
plugins.retrieval.args = {}
784+
plugins.retrieval.enable = True
785+
plugins.retrieval.args["input_path"] = "../assets/docs/LLAMA2_short.pdf"
786+
plugins.retrieval.args["persist_directory"] = "./table_summary_none_mode"
787+
plugins.retrieval.args["retrieval_type"] = 'default'
788+
plugins.retrieval.args["table_summary_mode"] = 'none'
789+
config = PipelineConfig(model_name_or_path="facebook/opt-125m",
790+
plugins=plugins)
791+
chatbot = build_chatbot(config)
792+
response = chatbot.predict("What is the number of training tokens for LLaMA2?")
793+
print(response)
794+
self.assertIsNotNone(response)
795+
plugins.retrieval.args = {}
796+
plugins.retrieval.enable = False
797+
717798
if __name__ == '__main__':
718799
unittest.main()

intel_extension_for_transformers/neural_chat/tests/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ torchvision==0.17.0
8484
tqdm
8585
transformers==4.36.2
8686
transformers_stream_generator
87-
unstructured
87+
unstructured[all-docs]
8888
urllib3
8989
uvicorn
9090
vector_quantize_pytorch

0 commit comments

Comments
 (0)