@@ -14,24 +14,23 @@ def mean_pooling(model_output, attention_mask):
 
 
 class Data2VecAudio(BaseEmbedding):
-    def __init__(self, model: str = "model/text2vec-base-chinese/"):
+    def __init__(self, model):
         current_dir = os.path.dirname(os.path.abspath(__file__))
         parent_dir = os.path.dirname(current_dir)
         model_dir = os.path.dirname(parent_dir)
-        model = os.path.join(model_dir, model)
+        model_path = os.path.join(model_dir, model)
+
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True)
+        self.model = BertModel.from_pretrained(model_path, local_files_only=True)
 
         try:
             self.__dimension = self.model.config.hidden_size
         except Exception:
             from transformers import AutoConfig
-
             config = AutoConfig.from_pretrained(model)
             self.__dimension = config.hidden_size
 
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.tokenizer = BertTokenizer.from_pretrained(model, local_files_only=True)
-        self.model = BertModel.from_pretrained(model, local_files_only=True)
-
     def to_embeddings(self, data, **_):
         encoded_input = self.tokenizer(data, padding=True, truncation=True, return_tensors='pt')
         num_tokens = sum(map(len, encoded_input['input_ids']))
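
After this change the constructor no longer defaults to "model/text2vec-base-chinese/", so callers must pass the model directory explicitly, and the tokenizer/model are loaded from model_path before the hidden-size lookup, so self.model.config.hidden_size is available inside the try block. A minimal usage sketch under those assumptions; the relative path and the sample sentences below are illustrative, not taken from the commit:

    # Illustrative only: the path is assumed to match the old default checkpoint location,
    # resolved by __init__ relative to the repository root.
    embedder = Data2VecAudio(model="model/text2vec-base-chinese/")

    # to_embeddings tokenizes with padding/truncation; it is expected to pool token
    # states (see mean_pooling above), so the output size follows the model's hidden size.
    sentences = ["Cache the response of the last query.", "How do I configure the embedding model?"]
    vectors = embedder.to_embeddings(sentences)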