33Author: Dhar Rawal (@drawal1)
44"""
55
6+ from dsp .utils import dotdict
67from typing import Optional , List , Union
78import openai
89import dspy
@@ -30,6 +31,7 @@ class PineconeRM(dspy.Retrieve):
3031 pinecone_index_name (str): The name of the Pinecone index to query against.
3132 pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
3233 pinecone_env (str, optional): The Pinecone environment. Defaults to None.
34+ local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
3335 openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
3436 openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
3537 openai_org (str, optional): The organization for OpenAI. Defaults to None.
@@ -57,36 +59,62 @@ def __init__(
5759 pinecone_index_name : str ,
5860 pinecone_api_key : Optional [str ] = None ,
5961 pinecone_env : Optional [str ] = None ,
60- openai_embed_model : str = "text-embedding-ada-002" ,
62+ local_embed_model : Optional [str ] = None ,
63+ openai_embed_model : Optional [str ] = "text-embedding-ada-002" ,
6164 openai_api_key : Optional [str ] = None ,
6265 openai_org : Optional [str ] = None ,
6366 k : int = 3 ,
6467 ):
65- self ._openai_embed_model = openai_embed_model
68+ if local_embed_model is not None :
69+ try :
70+ import torch
71+ from transformers import AutoModel , AutoTokenizer
72+ except ImportError as exc :
73+ raise ModuleNotFoundError (
74+ "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM."
75+ ) from exc
76+
77+ self ._local_embed_model = AutoModel .from_pretrained (local_embed_model )
78+ self ._local_tokenizer = AutoTokenizer .from_pretrained (local_embed_model )
79+ self .use_local_model = True
80+ self .device = torch .device (
81+ 'cuda:0' if torch .cuda .is_available () else
82+ 'mps' if torch .backends .mps .is_available ()
83+ else 'cpu'
84+ )
85+ elif openai_embed_model is not None :
86+ self ._openai_embed_model = openai_embed_model
87+ self .use_local_model = False
88+ # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
89+ if openai_api_key :
90+ openai .api_key = openai_api_key
91+ if openai_org :
92+ openai .organization = openai_org
93+ else :
94+ raise ValueError (
95+ "Either local_embed_model or openai_embed_model must be provided."
96+ )
97+
6698 self ._pinecone_index = self ._init_pinecone (
6799 pinecone_index_name , pinecone_api_key , pinecone_env
68100 )
69101
70- # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
71- if openai_api_key :
72- openai .api_key = openai_api_key
73- if openai_org :
74- openai .organization = openai_org
75-
76102 super ().__init__ (k = k )
77103
78104 def _init_pinecone (
79105 self ,
80106 index_name : str ,
81107 api_key : Optional [str ] = None ,
82108 environment : Optional [str ] = None ,
109+ dimension : Optional [int ] = None ,
110+ distance_metric : Optional [str ] = None ,
83111 ) -> pinecone .Index :
84112 """Initialize pinecone and return the loaded index.
85113
86114 Args:
87- index_name (str): The name of the index to load.
115+ index_name (str): The name of the index to load. If the index does not exist, it will be created.
88116 api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
89- environment (str, optional): The environment (ie. `us-west1-gcp`. Defaults to env PINECONE_ENVIRONMENT.
117+ environment (str, optional): The environment (e.g. `us-west1-gcp` or `gcp-starter`). Defaults to env PINECONE_ENVIRONMENT.
90118
91119 Raises:
92120 ValueError: If api_key or environment is not provided and not set as an environment variable.
@@ -103,14 +131,46 @@ def _init_pinecone(
103131 kwargs ["environment" ] = environment
104132 pinecone .init (** kwargs )
105133
106- return pinecone .Index (index_name )
134+ active_indexes = pinecone .list_indexes ()
135+ if index_name not in active_indexes :
136+ if dimension is None and distance_metric is None :
137+ raise ValueError (
138+ "dimension and distance_metric must be provided since the index provided does not exist."
139+ )
107140
141+ pinecone .create_index (
142+ name = index_name ,
143+ dimension = dimension ,
144+ metric = distance_metric ,
145+ )
146+
147+ return pinecone .Index (index_name )
148+
149+ def _mean_pooling (
150+ self ,
151+ model_output ,
152+ attention_mask
153+ ):
154+ try :
155+ import torch
156+ except ImportError as exc :
157+ raise ModuleNotFoundError (
158+ "You need to install torch to use a local embedding model with PineconeRM."
159+ ) from exc
160+
161+ token_embeddings = model_output [0 ] # First element of model_output contains all token embeddings
162+ input_mask_expanded = attention_mask .unsqueeze (- 1 ).expand (token_embeddings .size ()).float ()
163+ return torch .sum (token_embeddings * input_mask_expanded , 1 ) / torch .clamp (input_mask_expanded .sum (1 ), min = 1e-9 )
164+
@backoff.on_exception(
    backoff.expo,
    (openai.error.RateLimitError, openai.error.ServiceUnavailableError),
    max_time=15,
)
def _get_embeddings(
    self,
    queries: List[str],
) -> List[List[float]]:
    """Return one embedding vector per query, using OpenAI or the local model.

    When ``self.use_local_model`` is false the queries are sent to the OpenAI
    embedding endpoint with ``self._openai_embed_model``. Otherwise they are
    tokenized and run through the local model, mean-pooled over tokens and
    L2-normalized. The call is retried with exponential backoff (up to 15
    seconds) on OpenAI rate-limit and service-unavailable errors.

    Args:
        queries (List[str]): Query strings to embed.

    Returns:
        List[List[float]]: List of embeddings corresponding to each query.

    Raises:
        ModuleNotFoundError: If the local model is selected and torch is not installed.
    """
    if not self.use_local_model:
        response = openai.Embedding.create(
            input=queries, model=self._openai_embed_model
        )
        return [item["embedding"] for item in response["data"]]

    # torch is only needed on the local-model path, so import it here rather
    # than forcing OpenAI-only users to install it.
    try:
        import torch
    except ImportError as exc:
        raise ModuleNotFoundError(
            "You need to install torch to use a local embedding model with PineconeRM."
        ) from exc

    # Use local model
    encoded_input = self._local_tokenizer(
        queries, padding=True, truncation=True, return_tensors="pt"
    ).to(self.device)
    # Keep the weights on the same device as the inputs; Module.to() is a
    # no-op when the model is already there.
    model = self._local_embed_model.to(self.device)
    with torch.no_grad():
        model_output = model(**encoded_input)

    # We need a pooling strategy to get a single vector representation of the
    # input, so the default is to take the mean of the hidden states.
    embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])
    normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return normalized_embeddings.cpu().numpy().tolist()
126206
127207 def forward (self , query_or_queries : Union [str , List [str ]]) -> dspy .Prediction :
128208 """Search with pinecone for self.k top passages for query
@@ -149,9 +229,10 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
149229
150230 # Sort results by score
151231 sorted_results = sorted (
152- results_dict ["matches" ], key = lambda x : x [ "score" ] , reverse = True
232+ results_dict ["matches" ], key = lambda x : x . get ( "scores" , 0.0 ) , reverse = True
153233 )
154234 passages = [result ["metadata" ]["text" ] for result in sorted_results ]
235+ passages = [dotdict ({"long_text" : passage for passage in passages })]
155236 return dspy .Prediction (passages = passages )
156237
157238 # For multiple queries, query each and return the highest scoring passages
@@ -170,4 +251,4 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
170251 sorted_passages = sorted (
171252 passage_scores .items (), key = lambda x : x [1 ], reverse = True
172253 )[: self .k ]
173- return dspy .Prediction (passages = [passage for passage , _ in sorted_passages ])
254+ return dspy .Prediction (passages = [dotdict ({ "long_text" : passage }) for passage , _ in sorted_passages ])
0 commit comments