Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 99df35d

Browse files
authored
[NeuralChat] Support language detection & translation for RAG chat (#1361)
1 parent e8c77e7 commit 99df35d

File tree

3 files changed

+192
-57
lines changed

3 files changed

+192
-57
lines changed

.github/workflows/unit-test-neuralchat.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ env:
3434
CONTAINER_NAME: "utTest"
3535
EXTRA_CONTAINER_NAME: "modelTest"
3636
CONTAINER_SCAN: "codeScan"
37+
GOOGLE_API_KEY: ${{ vars.GOOGLE_API_KEY }}
3738

3839
jobs:
3940
neuralchat-unit-test:
@@ -84,6 +85,7 @@ jobs:
8485
-v /home/itrex-docker/models:/models \
8586
-v /dataset/media:/media \
8687
-v /dataset/tf_dataset2:/tf_dataset2 \
88+
-e "GOOGLE_API_KEY=${{ vars.GOOGLE_API_KEY }}" \
8789
${{ env.REPO_NAME }}:${{ env.REPO_TAG }}
8890
8991
- name: Binary build

intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py

Lines changed: 188 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@
2020
import os
2121
import re
2222
import csv
23+
import shutil
2324
import datetime
25+
import requests
2426
from pathlib import Path
2527
from datetime import timedelta, timezone
26-
from typing import Optional, Dict
27-
from fastapi import APIRouter, UploadFile, File, Request, Response, Form
28+
from typing import Optional, Dict, List
29+
from fastapi import APIRouter, UploadFile, File, Request, Response, Form, status, HTTPException
2830
from ...config import GenerationConfig
2931
from ...cli.log import logger
3032
from ...server.restful.request import RetrievalRequest, FeedbackRequest
3133
from ...server.restful.response import RetrievalResponse
32-
from fastapi.responses import StreamingResponse
34+
from fastapi.responses import StreamingResponse, JSONResponse
3335
from ...utils.database.mysqldb import MysqlDb
3436
from ...plugins import plugins
3537

@@ -51,6 +53,68 @@ def get_current_beijing_time():
5153
return beijing_time
5254

5355

56+
def language_detect(text: str):
    """Detect the language of ``text`` with the Google Translation v2 detect endpoint.

    The API key is read from the ``GOOGLE_API_KEY`` environment variable.

    Args:
        text: The text whose language should be detected.

    Returns:
        The first detection entry of the API response
        (``res["data"]["detections"][0][0]``); callers in this file read its
        ``'language'`` field. ``None`` is returned when the key is missing,
        the request fails, or the API answers with a non-200 status.
    """
    url = "https://translation.googleapis.com/language/translate/v2/detect"
    # os.getenv never raises; a missing variable simply yields None.
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        # Never log the key itself -- it is a credential.
        logger.info("[ language_detect ] No GOOGLE_API_KEY found.")
        return None
    params = {
        'key': api_key,
        'q': text
    }

    try:
        # A timeout keeps a stalled upstream from blocking the request handler forever.
        response = requests.post(url, params=params, timeout=10)
    except requests.RequestException as e:
        # Log only the exception type: the message can embed the request URL,
        # which carries the API key in its query string.
        logger.info(f"[ language_detect ] request failed: {type(e).__name__}")
        return None

    if response.status_code == 200:
        res = response.json()
        return res["data"]["detections"][0][0]

    # Use the raw body: an error response is not guaranteed to be valid JSON.
    logger.info(f"[ language_detect ] Error status: {response.status_code}")
    logger.info(f"[ language_detect ] Error content: {response.text}")
    return None
76+
77+
78+
def language_translate(text: str, target: str='en'):
    """Translate ``text`` into ``target`` with the Google Translation v2 endpoint.

    The API key is read from the ``GOOGLE_API_KEY`` environment variable.

    Args:
        text: The text to translate.
        target: Language code to translate into; defaults to ``'en'``.

    Returns:
        The first translation entry of the API response
        (``res["data"]["translations"][0]``); callers in this file read its
        ``'translatedText'`` field. ``None`` is returned when the key is
        missing, the request fails, or the API answers with a non-200 status.
    """
    url = "https://translation.googleapis.com/language/translate/v2"
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        # Never log the key itself -- it is a credential.
        logger.info("[ language_translate ] No GOOGLE_API_KEY found.")
        return None
    params = {
        'key': api_key,
        'q': text,
        'target': target
    }

    try:
        # A timeout keeps a stalled upstream from blocking the request handler forever.
        response = requests.post(url, params=params, timeout=10)
    except requests.RequestException as e:
        # Log only the exception type: the message can embed the request URL,
        # which carries the API key in its query string.
        logger.info(f"[ language_translate ] request failed: {type(e).__name__}")
        return None

    if response.status_code == 200:
        res = response.json()
        return res["data"]["translations"][0]

    # Use the raw body: an error response is not guaranteed to be valid JSON.
    logger.info(f"[ language_translate ] Error status: {response.status_code}")
    logger.info(f"[ language_translate ] Error content: {response.text}")
    return None
96+
97+
98+
def create_upload_dir(knowledge_base_id: str, user_id: str):
    """Resolve the upload and persist directories of a knowledge base.

    The ``'default'`` knowledge base lives under ``RETRIEVAL_FILE_PATH + 'default'``
    and is created on demand. Any other knowledge base lives under
    ``RETRIEVAL_FILE_PATH + user_id + '-' + knowledge_base_id`` and must
    already exist.

    Args:
        knowledge_base_id: ``'default'`` or the id of a user knowledge base.
        user_id: Identifier used to build the path of a non-default knowledge base.

    Returns:
        A ``(upload_path, persist_path)`` tuple of string paths.

    Raises:
        Exception: If a non-default knowledge base is missing either directory.
    """
    is_default = knowledge_base_id == 'default'
    if is_default:
        path_prefix = RETRIEVAL_FILE_PATH + 'default'
    else:
        path_prefix = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id
    upload_path = path_prefix + '/upload_dir'
    persist_path = path_prefix + '/persist_dir'

    # Fast path: both directories are already in place.
    if os.path.exists(upload_path) and os.path.exists(persist_path):
        return upload_path, persist_path

    # Only the default knowledge base may be created implicitly.
    if not is_default:
        logger.info(f"kbid [{knowledge_base_id}] does not exist for user {user_id}")
        raise Exception(
            f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, "
            "Please check kb_id and save path again.")

    base_dir = Path(path_prefix)
    os.makedirs(base_dir, exist_ok=True)
    for sub_dir in ('upload_dir', 'persist_dir'):
        os.makedirs(base_dir / sub_dir, exist_ok=True)
    logger.info(f"Default kb {path_prefix} does not exist, create.")
    return upload_path, persist_path
116+
117+
54118
class RetrievalAPIRouter(APIRouter):
55119

56120
def __init__(self) -> None:
@@ -93,20 +157,18 @@ async def retrieval_upload_link(request: Request):
93157
if 'knowledge_base_id' in params.keys():
94158
print(f"[askdoc - upload_link] append")
95159
knowledge_base_id = params['knowledge_base_id']
96-
persist_path = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id + '/persist_dir'
97-
if not os.path.exists(persist_path):
98-
return f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \
99-
Please check kb_id and save path again."
160+
upload_path, persist_path = create_upload_dir(knowledge_base_id, user_id)
100161

101162
try:
102163
print("[askdoc - upload_link] starting to append local db...")
103164
instance = plugins['retrieval']["instance"]
165+
print(f"[askdoc - upload_link] persist_path: {persist_path}")
104166
instance.append_localdb(append_path=link_list, persist_directory=persist_path)
105167
print(f"[askdoc - upload_link] kb appended successfully")
106168
except Exception as e: # pragma: no cover
107169
logger.info(f"[askdoc - upload_link] create knowledge base fails! {e}")
108170
return Response(content="Error occurred while uploading links.", status_code=500)
109-
return {"Succeed"}
171+
return {"status": True}
110172
# create new kb with link
111173
else:
112174
print(f"[askdoc - upload_link] create")
@@ -185,51 +247,41 @@ async def retrieval_create(request: Request,
185247

186248
@router.post("/v1/askdoc/append")
187249
async def retrieval_append(request: Request,
188-
file: UploadFile = File(...),
250+
files: List[UploadFile] = File(...),
251+
# file: UploadFile = File(...),
189252
knowledge_base_id: str = Form(...)):
190253
global plugins
191-
filename = file.filename
192-
if '/' in filename:
193-
filename = filename.split('/')[-1]
194-
logger.info(f"[askdoc - append] received file: {filename}, kb_id: {knowledge_base_id}")
254+
for file in files:
255+
filename = file.filename
256+
if '/' in filename:
257+
filename = filename.split('/')[-1]
258+
logger.info(f"[askdoc - append] received file: {filename}, kb_id: {knowledge_base_id}")
259+
260+
user_id = request.client.host
261+
logger.info(f'[askdoc - append] user id is: {user_id}')
262+
263+
# create local upload dir
264+
upload_path, persist_path = create_upload_dir(knowledge_base_id, user_id)
265+
print(f"[askdoc - upload_link] persist_path: {persist_path}")
266+
cur_time = get_current_beijing_time()
267+
logger.info(f"[askdoc - append] upload path: {upload_path}")
268+
269+
# save file to local path
270+
save_file_name = upload_path + '/' + cur_time + '-' + filename
271+
with open(save_file_name, 'wb') as fout:
272+
content = await file.read()
273+
fout.write(content)
274+
logger.info(f"[askdoc - append] file saved to local path: {save_file_name}")
195275

196-
user_id = request.client.host
197-
logger.info(f'[askdoc - append] user id is: {user_id}')
198-
if knowledge_base_id == 'default':
199-
path_prefix = RETRIEVAL_FILE_PATH + 'default'
200-
else:
201-
path_prefix = RETRIEVAL_FILE_PATH+user_id+'-'+knowledge_base_id
202-
upload_path = path_prefix + '/upload_dir'
203-
persist_path = path_prefix + '/persist_dir'
204-
if ( not os.path.exists(upload_path) ) or ( not os.path.exists(persist_path) ):
205-
if knowledge_base_id == 'default':
206-
os.makedirs(Path(path_prefix), exist_ok=True)
207-
os.makedirs(Path(path_prefix) / 'upload_dir', exist_ok=True)
208-
os.makedirs(Path(path_prefix) / 'persist_dir', exist_ok=True)
209-
logger.info(f"Default kb {path_prefix} does not exist, create.")
210-
else:
211-
logger.info(f"kbid [{knowledge_base_id}] does not exist for user {user_id}")
212-
return f"Knowledge base id [{knowledge_base_id}] does not exist for user {user_id}, \
213-
Please check kb_id and save path again."
214-
cur_time = get_current_beijing_time()
215-
logger.info(f"[askdoc - append] upload path: {upload_path}")
216-
217-
# save file to local path
218-
save_file_name = upload_path + '/' + cur_time + '-' + filename
219-
with open(save_file_name, 'wb') as fout:
220-
content = await file.read()
221-
fout.write(content)
222-
logger.info(f"[askdoc - append] file saved to local path: {save_file_name}")
223-
224-
try:
225-
# get retrieval instance and reload db with new knowledge base
226-
logger.info("[askdoc - append] starting to append to local db...")
227-
instance = plugins['retrieval']["instance"]
228-
instance.append_localdb(append_path=save_file_name, persist_directory=persist_path)
229-
logger.info(f"[askdoc - append] new file successfully appended to kb")
230-
except Exception as e: # pragma: no cover
231-
logger.info(f"[askdoc - append] create knowledge base fails! {e}")
232-
return "Error occurred while uploading files."
276+
try:
277+
# get retrieval instance and reload db with new knowledge base
278+
logger.info("[askdoc - append] starting to append to local db...")
279+
instance = plugins['retrieval']["instance"]
280+
instance.append_localdb(append_path=save_file_name, persist_directory=persist_path)
281+
logger.info(f"[askdoc - append] new file successfully appended to kb")
282+
except Exception as e: # pragma: no cover
283+
logger.info(f"[askdoc - append] create knowledge base fails! {e}")
284+
return "Error occurred while uploading files."
233285
return "Succeed"
234286

235287

@@ -244,18 +296,24 @@ async def retrieval_chat(request: Request):
244296

245297
# parse parameters
246298
params = await request.json()
247-
query = params['query']
248-
origin_query = params['translated']
299+
origin_query = params['query']
249300
kb_id = params['knowledge_base_id']
250301
stream = params['stream']
251302
max_new_tokens = params['max_new_tokens']
252303
return_link = params['return_link']
253-
logger.info(f"[askdoc - chat] kb_id: '{kb_id}', query: '{query}', \
304+
logger.info(f"[askdoc - chat] kb_id: '{kb_id}', \
254305
origin_query: '{origin_query}', stream mode: '{stream}', \
255306
max_new_tokens: '{max_new_tokens}', \
256307
return_link: '{return_link}'")
257308
config = GenerationConfig(max_new_tokens=max_new_tokens)
258309

310+
# detect and translate query
311+
detect_res = language_detect(origin_query)
312+
if detect_res['language'] == 'en':
313+
query = origin_query
314+
else:
315+
query = language_translate(origin_query)['translatedText']
316+
259317
path_prefix = RETRIEVAL_FILE_PATH
260318
cur_path = Path(path_prefix) / "default" / "persist_dir"
261319
os.makedirs(path_prefix, exist_ok=True)
@@ -350,6 +408,27 @@ def stream_generator():
350408
return StreamingResponse(stream_generator(), media_type="text/event-stream")
351409

352410

411+
@router.post("/v1/askdoc/translate")
async def retrieval_translate(request: Request):
    """Translate the posted ``content`` between English and Simplified Chinese.

    The JSON body must contain a ``content`` field. Content detected as
    English is translated to ``zh-CN``; anything else is translated to ``en``.

    Returns:
        A dict with the translated text under the ``"tranlated_content"`` key.

    Raises:
        HTTPException: With status 500 when language detection or translation
            yields no result.
    """
    user_id = request.client.host
    logger.info(f'[askdoc - translate] user id is: {user_id}')

    # parse parameters
    params = await request.json()
    content = params['content']
    logger.info(f'[askdoc - translate] origin content: {content}')

    # language_detect returns None on failure; report that explicitly instead
    # of failing with a TypeError on the subscript below.
    detect_res = language_detect(content)
    if detect_res is None:
        raise HTTPException(status_code=500, detail="Language detection failed.")
    logger.info(f'[askdoc - translate] detected language: {detect_res["language"]}')

    target = 'zh-CN' if detect_res['language'] == 'en' else 'en'
    translation = language_translate(content, target=target)
    if translation is None:
        raise HTTPException(status_code=500, detail="Translation failed.")
    translate_res = translation['translatedText']

    logger.info(f'[askdoc - translate] translated result: {translate_res}')
    # NOTE(review): the key is misspelt ("tranlated"); kept unchanged because
    # existing clients may already read it.
    return {"tranlated_content": translate_res}
430+
431+
353432
@router.post("/v1/askdoc/feedback")
354433
def save_chat_feedback_to_db(request: FeedbackRequest) -> None:
355434
logger.info(f'[askdoc - feedback] fastrag feedback received.')
@@ -433,3 +512,59 @@ def data_generator():
433512
data_generator(),
434513
media_type='text/csv',
435514
headers={"Content-Disposition": f"attachment;filename=feedback{cur_time_str}.csv"})
515+
516+
517+
@router.post("/v1/askdoc/verify_upload")
async def verify_upload(request: Request):
    """Report whether the default knowledge base has an upload directory.

    The JSON body must contain ``user_id``. Only the literal value ``"admin"``
    is answered; every other user id receives an HTTP 400 JSON message.

    Returns:
        ``{"is_uploaded": bool}`` for the admin user, otherwise a
        ``JSONResponse`` carrying an error message.
    """
    payload = await request.json()
    user_id = payload['user_id']
    logger.info(f'[askdoc - verify_upload] current user: {user_id}')

    # Guard clause: reject everyone but the admin user up front.
    # NOTE(review): user_id comes straight from the request body -- confirm
    # this check is not relied on as real authorization.
    if user_id != "admin":
        return JSONResponse(
            content={"message": f"Current user {user_id} is not allowed to access /verify_upload api."},
            status_code=status.HTTP_400_BAD_REQUEST)

    default_upload_dir = RETRIEVAL_FILE_PATH + 'default/upload_dir'
    is_uploaded = os.path.exists(default_upload_dir)
    if is_uploaded:
        logger.info(f'[askdoc - verify_upload] currently ALREADY uploaded')
    else:
        logger.info(f'[askdoc - verify_upload] currently NOT uploaded')
    return {"is_uploaded": is_uploaded}
535+
536+
537+
@router.delete("/v1/askdoc/delete_all")
async def delete_all_files():
    """Delete every upload of the default knowledge base and reload the original kb.

    When the default knowledge base directory does not exist there is nothing
    to clear and the call succeeds immediately. Otherwise the contents of its
    ``upload_dir`` and the directory itself are removed (if present), and the
    retrieval plugin is reloaded from the original persist directory.

    Returns:
        ``{"status": True}`` on success.

    Raises:
        HTTPException: With status 500 when an entry or the upload directory
            cannot be deleted.
    """
    delete_path = RETRIEVAL_FILE_PATH + 'default'
    if not os.path.exists(delete_path):
        logger.info(f'[askdoc - delete_all] No file/link uploaded. Clear.')
        return {"status": True}

    upload_dir = delete_path + '/upload_dir'
    # upload_dir is removed by a previous delete_all while 'default' remains,
    # so its absence must not be treated as an error.
    if os.path.isdir(upload_dir):
        # delete all upload files
        for filename in os.listdir(upload_dir):
            # Join with upload_dir: the entries were listed from there, not
            # from the knowledge-base root.
            file_path = os.path.join(upload_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                raise HTTPException(
                    status_code=500,
                    detail=f'Failed to delete {filename}. Reason: {e}'
                ) from e
        try:
            shutil.rmtree(upload_dir)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f'Failed to delete {delete_path}/upload_dir. Reason: {e}'
            ) from e

    # reload default kb
    # NOTE(review): hard-coded, deployment-specific path -- consider making it configurable.
    origin_persist_dir = "/home/sdp/askgm_persist_new"
    instance = plugins['retrieval']["instance"]
    instance.reload_localdb(local_persist_dir = origin_persist_dir)
    print(f"[askdoc - delete_all] Original kb loaded from: {origin_persist_dir}")

    return {"status": True}

intel_extension_for_transformers/neural_chat/tests/ci/server/test_askdoc_server.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def test_append_existing_kb_with_links(self):
118118
json={**sample_link_list, "knowledge_base_id": gaudi2_kb_id},
119119
)
120120
assert response.status_code == 200
121-
assert "Succeed" in response.json()
121+
assert response.json()['status'] == True
122122

123123
async def test_append_existing_kb(self):
124124
# create oneapi knowledge base
@@ -133,7 +133,7 @@ async def test_append_existing_kb(self):
133133
with open("./gaudi2.txt", "rb") as file:
134134
response = client.post(
135135
"/v1/askdoc/append",
136-
files={"file": ("./gaudi2.txt", file, "multipart/form-data")},
136+
files={"files": ("./gaudi2.txt", file, "multipart/form-data")},
137137
data={"knowledge_base_id": oneapi_kb_id},
138138
)
139139
assert response.status_code == 200
@@ -151,7 +151,6 @@ async def test_non_stream_chat(self):
151151
gaudi2_kb_id = response.json()["knowledge_base_id"]
152152
query_params = {
153153
"query": "How about the benchmark test of Habana Gaudi2?",
154-
"translated": "How about the benchmark test of Habana Gaudi2?",
155154
"knowledge_base_id": gaudi2_kb_id,
156155
"stream": False,
157156
"max_new_tokens": 64,
@@ -172,7 +171,6 @@ async def test_stream_chat(self):
172171
gaudi2_kb_id = response.json()["knowledge_base_id"]
173172
query_params = {
174173
"query": "How about the benchmark test of Habana Gaudi2?",
175-
"translated": "How about the benchmark test of Habana Gaudi2?",
176174
"knowledge_base_id": gaudi2_kb_id,
177175
"stream": True,
178176
"max_new_tokens": 64,

0 commit comments

Comments
 (0)