
Commit cc3a5fe

extra files pre-release to dev

1 parent: e9b9c51

File tree: 3 files changed, +136 −0 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -130,3 +130,4 @@ start-webui.sh
 files-to-sync.txt
 rsync-exclude.lst
 .editorconfig
+rsynclist.txt

modules/tensorrt_llm.py

Lines changed: 131 additions & 0 deletions (new file)

@@ -0,0 +1,131 @@
from pathlib import Path

import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp

from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
    get_max_prompt_length,
    get_reply_from_output_ids
)


class TensorRTLLMModel:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        runtime_rank = tensorrt_llm.mpi_rank()

        # Define model settings
        runner_kwargs = dict(
            engine_dir=str(path_to_model),
            lora_dir=None,
            rank=runtime_rank,
            debug_mode=False,
            lora_ckpt_source="hf",
        )

        if shared.args.cpp_runner:
            logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
            runner_kwargs.update(
                max_batch_size=1,
                max_input_len=shared.args.max_seq_len - 512,
                max_output_len=512,
                max_beam_width=1,
                max_attention_window_size=None,
                sink_token_length=None,
            )
        else:
            logger.info("TensorRT-LLM: Using \"ModelRunner\"")

        # Load the model
        runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
        runner = runner_cls.from_dir(**runner_kwargs)

        result = cls()
        result.model = runner
        result.runtime_rank = runtime_rank

        return result

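    # ModelRunnerCpp is configured above with max_output_len=512, so the
    # cpp_runner path in generate_with_streaming() below caps max_new_tokens
    # at 512 and returns the full completion in a single result dict; only
    # the Python ModelRunner path streams incrementally
    # (streaming=not shared.args.cpp_runner).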
    def generate_with_streaming(self, prompt, state):
        batch_input_ids = []
        input_ids = shared.tokenizer.encode(
            prompt,
            add_special_tokens=True,
            truncation=False,
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int32)
        input_ids = input_ids[-get_max_prompt_length(state):]  # Apply truncation_length
        batch_input_ids.append(input_ids)

        if shared.args.cpp_runner:
            max_new_tokens = min(512, state['max_new_tokens'])
        elif state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        with torch.no_grad():
            generator = self.model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                max_attention_window_size=None,
                sink_token_length=None,
                end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
                pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
                temperature=state['temperature'],
                top_k=state['top_k'],
                top_p=state['top_p'],
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=state['repetition_penalty'],
                presence_penalty=state['presence_penalty'],
                frequency_penalty=state['frequency_penalty'],
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=not shared.args.cpp_runner,
                output_sequence_lengths=True,
                return_dict=True,
                medusa_choices=None
            )

        torch.cuda.synchronize()

        cumulative_reply = ''
        starting_from = batch_input_ids[0].shape[-1]

        if shared.args.cpp_runner:
            sequence_length = generator['sequence_lengths'][0].item()
            output_ids = generator['output_ids'][0][0][:sequence_length].tolist()

            cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
            starting_from = sequence_length
            yield cumulative_reply
        else:
            for curr_outputs in generator:
                if shared.stop_everything:
                    break

                sequence_length = curr_outputs['sequence_lengths'][0].item()
                output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()

                cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
                starting_from = sequence_length
                yield cumulative_reply

    def generate(self, prompt, state):
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output
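
For context, a minimal usage sketch (not part of the diff): it assumes the web UI has already populated `shared.args` and loaded a tokenizer into `shared.tokenizer` before this point, and the engine directory name and `state` values below are illustrative placeholders, not values from the commit.

# Hypothetical driver code; the real web UI builds `state` from its
# generation-settings UI and loads the tokenizer elsewhere.
from modules.tensorrt_llm import TensorRTLLMModel

# Illustrative engine folder name, resolved under shared.args.model_dir
model = TensorRTLLMModel.from_pretrained('my-trt-engine')

state = {
    'max_new_tokens': 256,
    'auto_max_new_tokens': False,
    'truncation_length': 4096,
    'ban_eos_token': False,
    'temperature': 0.7,
    'top_k': 40,
    'top_p': 0.9,
    'repetition_penalty': 1.15,
    'presence_penalty': 0.0,
    'frequency_penalty': 0.0,
}

# Each yield is the cumulative reply so far
for partial in model.generate_with_streaming('Write a haiku about GPUs.', state):
    print(partial)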
Lines changed: 4 additions & 0 deletions (new file)

@@ -0,0 +1,4 @@
{
    "instruction,output": "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n%instruction%<|im_end|>\n<|im_start|>assistant\n%output%<|im_end|>",
    "instruction,input,output": "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n%instruction%: %input%<|im_end|>\n<|im_start|>assistant\n%output%<|im_end|>"
}
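
This JSON is a ChatML-style format template: each key lists the fields a dataset sample provides, and the value is the prompt with %field% placeholders. A hypothetical substitution sketch follows; the filename is assumed (the diff above does not show it) and the helper logic is illustrative, not code from the commit.

import json

# Filename assumed for illustration only
with open('chatml-format.json') as f:
    formats = json.load(f)

sample = {'instruction': 'Translate to French', 'input': 'Good morning', 'output': 'Bonjour'}

# Pick the template whose key matches the fields present in this sample,
# then replace each %field% placeholder with the sample's value.
key = ','.join(k for k in ('instruction', 'input', 'output') if k in sample)
text = formats[key]
for field, value in sample.items():
    text = text.replace(f'%{field}%', value)

print(text)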
