@@ -96,37 +96,36 @@ def _validate_attributes(self) -> None:
         if not self.openlayer_api_key:
             raise ValueError(
                 "An Openlayer API key is required for publishing."
-                " Please provide `openlayer_api_key` or set the"
-                " OPENLAYER_API_KEY environment variable."
+                " Please set it as environment variable named OPENLAYER_API_KEY."
             )
 
-        if not self.openlayer_project_name:
+        if (
+            not self.openlayer_project_name
+            and not self.openlayer_inference_pipeline_name
+            and not self.openlayer_inference_pipeline_id
+        ):
             raise ValueError(
-                "You must specify the name of the project on Openlayer"
-                " that you want to publish to. Please provide"
-                " `openlayer_project_name` or set the OPENLAYER_PROJECT_NAME"
-                " environment variable."
+                "You must provide more information about the project and"
+                " inference pipeline on Openlayer to publish data."
+                " Please provide either: "
+                " - the project name and inference pipeline name, or"
+                " - the inference pipeline id."
+                " You can set them as environment variables named"
+                " OPENLAYER_PROJECT_NAME, OPENLAYER_INFERENCE_PIPELINE_NAME, "
+                "and OPENLAYER_INFERENCE_PIPELINE_ID."
             )
 
-        if (
-            not self.openlayer_inference_pipeline_id
-            and not self.openlayer_inference_pipeline_name
-        ):
-            raise ValueError(
-                "Either an inference pipeline id or name is required."
-                " Please provide `openlayer_inference_pipeline_id` or"
-                " `openlayer_inference_pipeline_name`, "
-                "or set the OPENLAYER_INFERENCE_PIPELINE_ID or"
-                " OPENLAYER_INFERENCE_PIPELINE_NAME environment variables."
-            )
-        logger.info(
-            "Data will be streamed to Openlayer project %s and inference pipeline %s.",
-            self.openlayer_project_name,
-            (
-                self.openlayer_inference_pipeline_id
-                or self.openlayer_inference_pipeline_name
-            ),
-        )
+        if (
+            self.openlayer_inference_pipeline_name
+            and not self.openlayer_project_name
+            and not self.openlayer_inference_pipeline_id
+        ):
+            raise ValueError(
+                "You must provide the Openlayer project name where the inference"
+                " pipeline is located."
+                " Please set it as the environment variable"
+                " OPENLAYER_PROJECT_NAME."
+            )
 
     def stream_data(self, data: Dict[str, any], config: Dict[str, any]) -> None:
         """Stream data to the Openlayer platform.
@@ -157,32 +156,35 @@ def _load_inference_pipeline(self) -> None:
         If no platform/project information is provided, it is set to None.
         """
         inference_pipeline = None
-        if self.openlayer_api_key:
-            client = openlayer.OpenlayerClient(
-                api_key=self.openlayer_api_key, verbose=False
+        client = openlayer.OpenlayerClient(
+            api_key=self.openlayer_api_key, verbose=False
+        )
+
+        # Prioritize the inference pipeline id over the name
+        if self.openlayer_inference_pipeline_id:
+            inference_pipeline = inference_pipelines.InferencePipeline(
+                client=client,
+                upload=None,
+                json={"id": self.openlayer_inference_pipeline_id, "projectId": None},
+                task_type=tasks.TaskType.LLM,
             )
-            if self.openlayer_inference_pipeline_id:
-                # Load inference pipeline directly from the id
-                inference_pipeline = inference_pipelines.InferencePipeline(
-                    client=client,
-                    upload=None,
-                    json={
-                        "id": self.openlayer_inference_pipeline_id,
-                        "projectId": None,
-                    },
-                    task_type=tasks.TaskType.LLM,
+        elif self.openlayer_inference_pipeline_name:
+            with utils.HidePrints():
+                project = client.create_project(
+                    name=self.openlayer_project_name, task_type=tasks.TaskType.LLM
                 )
-            else:
-                if self.openlayer_project_name:
-                    with utils.HidePrints():
-                        project = client.create_project(
-                            name=self.openlayer_project_name,
-                            task_type=tasks.TaskType.LLM,
-                        )
-                    inference_pipeline = project.create_inference_pipeline(
-                        name=self.openlayer_inference_pipeline_name
-                    )
-
+            inference_pipeline = project.create_inference_pipeline(
+                name=self.openlayer_inference_pipeline_name
+            )
+        if inference_pipeline:
+            logger.info(
+                "Going to try to stream data to the inference pipeline with id %s.",
+                inference_pipeline.id,
+            )
+        else:
+            logger.warn(
+                "No inference pipeline found. Data will not be streamed to Openlayer."
+            )
         self.inference_pipeline = inference_pipeline
 
     def publish_batch_data(self, df: pd.DataFrame, config: Dict[str, any]) -> None:
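The rewritten loader resolves the pipeline in a fixed order: an explicit id wins, otherwise the project and pipeline are looked up by name through create_project and create_inference_pipeline. A standalone sketch of that resolution order, using only the calls that appear in the diff; the import paths and the helper name are assumptions:

from typing import Optional

import openlayer
from openlayer import inference_pipelines, tasks  # assumed import paths


def resolve_inference_pipeline(  # hypothetical helper mirroring the diff's logic
    api_key: str,
    pipeline_id: Optional[str] = None,
    project_name: Optional[str] = None,
    pipeline_name: Optional[str] = None,
):
    client = openlayer.OpenlayerClient(api_key=api_key, verbose=False)
    if pipeline_id:
        # An id is enough to build the pipeline object directly, without a project lookup.
        return inference_pipelines.InferencePipeline(
            client=client,
            upload=None,
            json={"id": pipeline_id, "projectId": None},
            task_type=tasks.TaskType.LLM,
        )
    if project_name and pipeline_name:
        # Otherwise resolve by name: the project first, then the inference pipeline.
        project = client.create_project(name=project_name, task_type=tasks.TaskType.LLM)
        return project.create_inference_pipeline(name=pipeline_name)
    return None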