1111
1212import openai
1313
14+ from typing import Dict , Union , Optional
15+ from collections import OrderedDict
1416from flask import Flask , request , jsonify , abort
1517from sentence_transformers import SentenceTransformer
1618
2224
2325
2426class Config :
25- def __init__ (self , config_file ):
27+ def __init__ (self , config_file : str ):
2628 self .config_file = config_file
2729 self .config = configparser .ConfigParser ()
2830 if not os .path .exists (self .config_file ):
@@ -32,7 +34,7 @@ def __init__(self, config_file):
3234 logging .info (f'Loading config file: { self .config_file } ' )
3335 self .config .read (config_file )
3436
35- def get_val (self , section , key ) :
37+ def get_val (self , section : str , key : str ) -> Optional [ str ] :
3638 answer = None
3739
3840 try :
@@ -42,7 +44,7 @@ def get_val(self, section, key):
4244
4345 return answer
4446
45- def get_bool (self , section , key , default = False ):
47+ def get_bool (self , section : str , key : str , default : bool = False ) -> bool :
4648 try :
4749 return self .config .getboolean (section , key )
4850 except Exception as err :
@@ -51,23 +53,28 @@ def get_bool(self, section, key, default=False):
5153
5254
class EmbeddingCache:
    """Bounded in-memory LRU cache mapping (text, model_type) -> embedding.

    Entries are keyed by a sha256 digest of text+model_type. When the
    cache exceeds ``max_size`` the least-recently-used entry is evicted.
    """

    def __init__(self, max_size: int = 500):
        logger.info(f'Created in-memory cache; max size={max_size}')
        # OrderedDict keeps insertion/recency order: oldest entries first,
        # so popitem(last=False) evicts the least-recently-used key.
        self.cache = OrderedDict()
        self.max_size = max_size

    def get_cache_key(self, text: str, model_type: str) -> str:
        # Digest keeps keys fixed-length regardless of input text size.
        return hashlib.sha256((text + model_type).encode()).hexdigest()

    def get(self, text: str, model_type: str):
        """Return the cached embedding or None on a miss."""
        key = self.get_cache_key(text, model_type)
        if key not in self.cache:
            return None
        # Refresh recency so frequently-read entries survive eviction
        # (without this the eviction policy degrades to FIFO).
        self.cache.move_to_end(key)
        return self.cache[key]

    def set(self, text: str, model_type: str, embedding):
        """Store an embedding, evicting the LRU entry if over capacity."""
        key = self.get_cache_key(text, model_type)
        self.cache[key] = embedding
        self.cache.move_to_end(key)  # treat overwrite as most-recent
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)  # drop least-recently-used
6672
6773
6874class EmbeddingGenerator :
69- def __init__ (self , sbert_model = None , openai_key = None ):
75+ def __init__ (self , sbert_model : Optional [ str ] = None , openai_key : Optional [ str ] = None ):
7076 self .sbert_model = sbert_model
77+ self .openai_key = openai_key
7178 if self .sbert_model is not None :
7279 try :
7380 self .model = SentenceTransformer (self .sbert_model )
@@ -77,48 +84,55 @@ def __init__(self, sbert_model=None, openai_key=None):
7784 sys .exit (1 )
7885
7986 if openai_key is not None :
80- openai .api_key = openai_key
87+ openai .api_key = self . openai_key
8188 logger .info ('enabled model: text-embedding-ada-002' )
8289
83- def get_openai_embeddings (self , text ) :
90+ def generate (self , text : str , model_type : str ) -> Dict [ str , Union [ str , float , list ]] :
8491 start_time = time .time ()
92+ result = {'status' : 'success' }
8593
86- try :
87- response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
88- elapsed_time = (time .time () - start_time ) * 1000
89- data = {
90- "embedding" : response ['data' ][0 ]['embedding' ],
91- "status" : "success" ,
92- "elapsed" : elapsed_time ,
93- "model" : "text-embedding-ada-002"
94- }
95- return data
96- except Exception as err :
97- logger .error (f'Failed to get OpenAI embeddings: { err } ' )
98- return {"status" : "error" , "message" : str (err ), "model" : "text-embedding-ada-002" }
99-
100- def get_transformers_embeddings (self , text ):
101- start_time = time .time ()
102-
103- try :
104- embedding = self .model .encode (text ).tolist ()
105- elapsed_time = (time .time () - start_time ) * 1000
106- data = {
107- "embedding" : embedding ,
108- "status" : "success" ,
109- "elapsed" : elapsed_time ,
110- "model" : self .sbert_model
111- }
112- return data
113- except Exception as err :
114- logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
115- return {"status" : "error" , "message" : str (err ), "model" : self .sbert_model }
116-
117- def generate (self , text , model_type ):
11894 if model_type == 'openai' :
119- return self .get_openai_embeddings (text )
95+ try :
96+ response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
97+ result ['embedding' ] = response ['data' ][0 ]['embedding' ]
98+ result ['model' ] = 'text-embedding-ada-002'
99+ except Exception as err :
100+ logger .error (f'Failed to get OpenAI embeddings: { err } ' )
101+ result ['status' ] = 'error'
102+ result ['message' ] = str (err )
103+
120104 else :
121- return self .get_transformers_embeddings (text )
105+ try :
106+ embedding = self .model .encode (text ).tolist ()
107+ result ['embedding' ] = embedding
108+ result ['model' ] = self .sbert_model
109+ except Exception as err :
110+ logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
111+ result ['status' ] = 'error'
112+ result ['message' ] = str (err )
113+
114+ result ['elapsed' ] = (time .time () - start_time ) * 1000
115+ return result
116+
117+
@app.route('/health', methods=['GET'])
def health_check():
    """Report which embedding models are enabled and cache statistics.

    Returns a JSON document describing model availability and, when the
    cache is enabled, its current size and configured maximum size.
    """
    sbert_on = embedding_generator.sbert_model if embedding_generator.sbert_model else 'disabled'
    # NOTE(review): openai reports True while sbert reports the model name —
    # inconsistent payload types, preserved here for API compatibility.
    openai_on = True if embedding_generator.openai_key else 'disabled'

    health_status = {
        "models": {
            "openai": openai_on,
            'sentence-transformers': sbert_on
        },
        "cache": {
            "enabled": embedding_cache is not None,
            "size": len(embedding_cache.cache) if embedding_cache else None,
            # Report the configured bound instead of a hard-coded None.
            "max_size": embedding_cache.max_size if embedding_cache else None
        }
    }

    return jsonify(health_status)
122136
123137
124138@app .route ('/submit' , methods = ['POST' ])
@@ -134,22 +148,32 @@ def submit_text():
134148 if model_type not in ['local' , 'openai' ]:
135149 abort (400 , 'model field must be one of: local, openai' )
136150
137- if embedding_cache :
138- result = embedding_cache .get (text_data , model_type )
139- if result :
140- logger .info ('found embedding in cache!' )
141- result = {'embedding' : result , 'cache' : True , "status" : 'success' }
142- else :
151+ if isinstance (text_data , str ):
152+ text_data = [text_data ]
153+
154+ if not all (isinstance (text , str ) for text in text_data ):
155+ abort (400 , 'all data must be text strings' )
156+
157+ results = []
158+ for text in text_data :
143159 result = None
144160
145- if result is None :
146- result = embedding_generator .generate (text_data , model_type )
161+ if embedding_cache :
162+ result = embedding_cache .get (text , model_type )
163+ if result :
164+ logger .info ('found embedding in cache!' )
165+ result = {'embedding' : result , 'cache' : True , "status" : 'success' }
166+
167+ if result is None :
168+ result = embedding_generator .generate (text , model_type )
169+
170+ if embedding_cache and result ['status' ] == 'success' :
171+ embedding_cache .set (text , model_type , result ['embedding' ])
172+ logger .info ('added to cache' )
147173
148- if embedding_cache and result ['status' ] == 'success' :
149- embedding_cache .set (text_data , model_type , result ['embedding' ])
150- logger .info ('added to cache' )
174+ results .append (result )
151175
152- return jsonify (result )
176+ return jsonify (results )
153177
154178
155179if __name__ == '__main__' :
@@ -168,6 +192,8 @@ def submit_text():
168192 openai_key = conf .get_val ('main' , 'openai_api_key' )
169193 sbert_model = conf .get_val ('main' , 'sent_transformers_model' )
170194 use_cache = conf .get_bool ('main' , 'use_cache' , default = False )
195+ if use_cache :
196+ max_cache_size = int (conf .get_val ('main' , 'cache_max' ))
171197
172198 if openai_key is None :
173199 logger .warn ('No OpenAI API key set in configuration file: server.conf' )
@@ -179,7 +205,7 @@ def submit_text():
179205 logger .error ('No sbert model set *and* no openAI key set; exiting' )
180206 sys .exit (1 )
181207
182- embedding_cache = EmbeddingCache () if use_cache else None
208+ embedding_cache = EmbeddingCache (max_cache_size ) if use_cache else None
183209 embedding_generator = EmbeddingGenerator (sbert_model , openai_key )
184210
185211 app .run (debug = True )
0 commit comments