import os
import re

import torch
from huggingface_hub import hf_hub_download  # NOTE(review): no longer used after the local-path refactor; kept for compatibility
from transformers import AutoTokenizer, T5EncoderModel


class T5Embedder:
    """Text encoder wrapping a locally stored T5 checkpoint.

    Loads the tokenizer and the T5 encoder weights from a single local
    directory (no hub downloads) and places the encoder on the requested
    device at the requested dtype.
    """

    # Historical list of hub model names; retained for callers that inspect it.
    available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']

    # Runs of punctuation/symbol characters stripped during caption cleaning.
    # Single raw string is equivalent to the original concatenation of
    # non-raw strings ('\)', '\(' , ...) but avoids their invalid escape
    # sequences (SyntaxWarning on Python 3.12+).
    bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~)(\]\[}{|\\/*]{1,}')

    def __init__(
        self,
        device,
        model_path,
        *,
        torch_dtype=None,
        model_max_length=120,
    ):
        """Load tokenizer and T5 encoder from a local directory.

        Args:
            device: Device spec accepted by ``torch.device`` (e.g. 'cuda:0');
                the encoder's 'shared' and 'encoder' modules are mapped to it.
            model_path: Local directory containing both the tokenizer files
                and the T5 encoder weights (loaded with local_files_only).
            torch_dtype: Model weight dtype; defaults to ``torch.bfloat16``.
            model_max_length: Tokenizer max sequence length (default 120).
        """
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.model_max_length = model_max_length

        # Tokenizer and model weights live in the same directory; resolve once
        # instead of aliasing two names to the identical path.
        path = os.path.abspath(model_path)

        self.tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
        self.model = T5EncoderModel.from_pretrained(
            path,
            local_files_only=True,
            low_cpu_mem_usage=True,
            torch_dtype=self.torch_dtype,
            device_map={'shared': self.device, 'encoder': self.device},
        ).eval()