diff --git a/inference.py b/inference.py
index 4492aed..9524cc0 100644
--- a/inference.py
+++ b/inference.py
@@ -1,3 +1,4 @@
+import fire
 from typing import Optional
 import torch
 import time
@@ -42,8 +43,6 @@ def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_l
 
         if device == "cuda":
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
-        else:
-            torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args).to(device)
 
@@ -56,6 +55,7 @@ def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_l
         return LLaMA(model, tokenizer, model_args)
 
     def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):
+        device = self.args.device
         if max_gen_len is None:
             max_gen_len = self.args.max_seq_len - 1
         # Convert each prompt into tokens
@@ -128,11 +128,15 @@ def _sample_top_p(self, probs, p):
         return next_token
 
-
-if __name__ == '__main__':
+def main(
+    checkpoints_dir: str = 'llama-2-7b/',
+    tokenizer_path: str = 'tokenizer.model',
+    max_seq_len: int = 128,
+    max_batch_size: int = 4,
+    allow_cuda: bool = False
+):
     torch.manual_seed(0)
 
-    allow_cuda = False
     device = 'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu'
 
     prompts = [
@@ -153,11 +157,11 @@ def _sample_top_p(self, probs, p):
     ]
 
     model = LLaMA.build(
-        checkpoints_dir='llama-2-7b/',
-        tokenizer_path='tokenizer.model',
+        checkpoints_dir=checkpoints_dir,
+        tokenizer_path=tokenizer_path,
+        max_seq_len=max_seq_len,
+        max_batch_size=max_batch_size,
         load_model=True,
-        max_seq_len=1024,
-        max_batch_size=len(prompts),
         device=device
     )
 
@@ -167,3 +171,6 @@ def _sample_top_p(self, probs, p):
         print(f'{out_texts[i]}')
         print('-' * 50)
 
+
+if __name__ == '__main__':
+    fire.Fire(main)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 6e76252..90887d7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 torch
 sentencepiece
-tqdm
\ No newline at end of file
+tqdm
+fire
\ No newline at end of file
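
For context, a minimal usage sketch of the new fire-based entry point. This is not part of the patch itself and only assumes what the diff shows: the flag names mirror the keyword arguments of the main() function added above, and the paths are just its defaults.

# Sketch: driving the CLI added in this patch (assumes inference.py and the
# default checkpoint/tokenizer paths from the diff are present locally).
#
# fire.Fire(main) exposes main()'s keyword arguments as command-line flags, e.g.:
#   python inference.py --max_seq_len 128 --max_batch_size 4 --allow_cuda True
#
# The equivalent direct call from Python:
from inference import main

main(
    checkpoints_dir='llama-2-7b/',    # defaults taken from the diff above
    tokenizer_path='tokenizer.model',
    max_seq_len=128,
    max_batch_size=4,
    allow_cuda=False,                 # pass True to run on CUDA when available
)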