from pathlib import Path

import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp

from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
    get_max_prompt_length,
    get_reply_from_output_ids
)


class TensorRTLLMModel:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):

        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        runtime_rank = tensorrt_llm.mpi_rank()

        # Define model settings
        runner_kwargs = dict(
            engine_dir=str(path_to_model),
            lora_dir=None,
            rank=runtime_rank,
            debug_mode=False,
            lora_ckpt_source="hf",
        )

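        # The C++ runner (ModelRunnerCpp) requires fixed shape budgets at load time:
        # the context window is split into up to (max_seq_len - 512) prompt tokens
        # and up to 512 generated tokens per call.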
        if shared.args.cpp_runner:
            logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
            runner_kwargs.update(
                max_batch_size=1,
                max_input_len=shared.args.max_seq_len - 512,
                max_output_len=512,
                max_beam_width=1,
                max_attention_window_size=None,
                sink_token_length=None,
            )
        else:
            logger.info("TensorRT-LLM: Using \"ModelRunner\"")

        # Load the model
        runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
        runner = runner_cls.from_dir(**runner_kwargs)

        result = cls()
        result.model = runner
        result.runtime_rank = runtime_rank

        return result

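    # Generator that yields the cumulative reply after each decoding step.
    # With ModelRunner the engine streams tokens as they are produced; with
    # ModelRunnerCpp the whole completion is generated first and yielded once.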
    def generate_with_streaming(self, prompt, state):
        batch_input_ids = []
        input_ids = shared.tokenizer.encode(
            prompt,
            add_special_tokens=True,
            truncation=False,
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int32)
        input_ids = input_ids[-get_max_prompt_length(state):]  # Apply truncation_length
        batch_input_ids.append(input_ids)

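        # Pick the generation budget: the C++ runner is capped at 512 tokens to
        # match the max_output_len configured at load time, auto_max_new_tokens
        # fills whatever context remains after the prompt, and otherwise the
        # user-selected max_new_tokens is used as-is.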
        if shared.args.cpp_runner:
            max_new_tokens = min(512, state['max_new_tokens'])
        elif state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        with torch.no_grad():
            generator = self.model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                max_attention_window_size=None,
                sink_token_length=None,
                end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
                pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
                temperature=state['temperature'],
                top_k=state['top_k'],
                top_p=state['top_p'],
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=state['repetition_penalty'],
                presence_penalty=state['presence_penalty'],
                frequency_penalty=state['frequency_penalty'],
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=not shared.args.cpp_runner,
                output_sequence_lengths=True,
                return_dict=True,
                medusa_choices=None
            )

        torch.cuda.synchronize()

        cumulative_reply = ''
        starting_from = batch_input_ids[0].shape[-1]

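        # Decode only the newly generated tokens. With the C++ runner, generation
        # has already run to completion, so the full output is decoded and yielded
        # once; with the Python runner, each streamed step yields the tokens added
        # since `starting_from`.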
        if shared.args.cpp_runner:
            sequence_length = generator['sequence_lengths'][0].item()
            output_ids = generator['output_ids'][0][0][:sequence_length].tolist()

            cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
            starting_from = sequence_length
            yield cumulative_reply
        else:
            for curr_outputs in generator:
                if shared.stop_everything:
                    break

                sequence_length = curr_outputs['sequence_lengths'][0].item()
                output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()

                cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
                starting_from = sequence_length
                yield cumulative_reply

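    # Non-streaming wrapper: exhaust the streaming generator and return the
    # final cumulative reply.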
    def generate(self, prompt, state):
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output
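

# A minimal usage sketch (not part of the loader itself), assuming the webui has
# already populated shared.tokenizer and that `state` is the usual generation
# settings dict with keys such as 'max_new_tokens' and 'temperature'. The engine
# directory name below is hypothetical and must live under shared.args.model_dir:
#
#     model = TensorRTLLMModel.from_pretrained('my-trtllm-engine')
#     for partial_reply in model.generate_with_streaming("Hello,", state):
#         print(partial_reply)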