Commit e45086e

gustavocidornelas authored and whoseoyster committed
Completes OPEN-5839 Don't require OPENLAYER_PROJECT_NAME if inference pipeline id is present and Completes OPEN-5840 Make sure an inference pipeline is created if name is specified and it doesn’t exist
1 parent 8b9a55f commit e45086e
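
As a rough illustration of the two configurations this change accepts (not part of the diff itself; all values below are placeholders), either of the following environment setups should now satisfy the streamer's validation:

import os

# Option 1: reference an existing inference pipeline directly by id.
# OPENLAYER_PROJECT_NAME is no longer required in this case (OPEN-5839).
os.environ["OPENLAYER_API_KEY"] = "<api-key>"                    # placeholder
os.environ["OPENLAYER_INFERENCE_PIPELINE_ID"] = "<pipeline-id>"  # placeholder

# Option 2: give a project name and an inference pipeline name; the
# pipeline is created under that project if it does not exist yet (OPEN-5840).
os.environ["OPENLAYER_API_KEY"] = "<api-key>"                        # placeholder
os.environ["OPENLAYER_PROJECT_NAME"] = "<project-name>"              # placeholder
os.environ["OPENLAYER_INFERENCE_PIPELINE_NAME"] = "<pipeline-name>"  # placeholder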

File tree

1 file changed: +52 -50 lines changed


openlayer/services/data_streamer.py

Lines changed: 52 additions & 50 deletions
@@ -96,37 +96,36 @@ def _validate_attributes(self) -> None:
         if not self.openlayer_api_key:
             raise ValueError(
                 "An Openlayer API key is required for publishing."
-                " Please provide `openlayer_api_key` or set the"
-                " OPENLAYER_API_KEY environment variable."
+                " Please set it as environment variable named OPENLAYER_API_KEY."
             )
 
-        if not self.openlayer_project_name:
+        if (
+            not self.openlayer_project_name
+            and not self.openlayer_inference_pipeline_name
+            and not self.openlayer_inference_pipeline_id
+        ):
             raise ValueError(
-                "You must specify the name of the project on Openlayer"
-                " that you want to publish to. Please provide"
-                " `openlayer_project_name` or set the OPENLAYER_PROJECT_NAME"
-                " environment variable."
+                "You must provide more information about the project and"
+                " inference pipeline on Openlayer to publish data."
+                " Please provide either: "
+                " - the project name and inference pipeline name, or"
+                " - the inference pipeline id."
+                " You can set them as environment variables named"
+                " OPENLAYER_PROJECT_NAME, OPENLAYER_INFERENCE_PIPELINE_NAME, "
+                "and OPENLAYER_INFERENCE_PIPELINE_ID."
             )
 
-        if (
-            not self.openlayer_inference_pipeline_id
-            and not self.openlayer_inference_pipeline_name
-        ):
-            raise ValueError(
-                "Either an inference pipeline id or name is required."
-                " Please provide `openlayer_inference_pipeline_id` or"
-                " `openlayer_inference_pipeline_name`, "
-                "or set the OPENLAYER_INFERENCE_PIPELINE_ID or"
-                " OPENLAYER_INFERENCE_PIPELINE_NAME environment variables."
-            )
-        logger.info(
-            "Data will be streamed to Openlayer project %s and inference pipeline %s.",
-            self.openlayer_project_name,
-            (
-                self.openlayer_inference_pipeline_id
-                or self.openlayer_inference_pipeline_name
-            ),
-        )
+        if (
+            self.openlayer_inference_pipeline_name
+            and not self.openlayer_project_name
+            and not self.openlayer_inference_pipeline_id
+        ):
+            raise ValueError(
+                "You must provide the Openlayer project name where the inference"
+                " pipeline is located."
+                " Please set it as the environment variable"
+                " OPENLAYER_PROJECT_NAME."
+            )
 
     def stream_data(self, data: Dict[str, any], config: Dict[str, any]) -> None:
         """Stream data to the Openlayer platform.
@@ -157,32 +156,35 @@ def _load_inference_pipeline(self) -> None:
         If no platform/project information is provided, it is set to None.
         """
         inference_pipeline = None
-        if self.openlayer_api_key:
-            client = openlayer.OpenlayerClient(
-                api_key=self.openlayer_api_key, verbose=False
+        client = openlayer.OpenlayerClient(
+            api_key=self.openlayer_api_key, verbose=False
+        )
+
+        # Prioritize the inference pipeline id over the name
+        if self.openlayer_inference_pipeline_id:
+            inference_pipeline = inference_pipelines.InferencePipeline(
+                client=client,
+                upload=None,
+                json={"id": self.openlayer_inference_pipeline_id, "projectId": None},
+                task_type=tasks.TaskType.LLM,
             )
-            if self.openlayer_inference_pipeline_id:
-                # Load inference pipeline directly from the id
-                inference_pipeline = inference_pipelines.InferencePipeline(
-                    client=client,
-                    upload=None,
-                    json={
-                        "id": self.openlayer_inference_pipeline_id,
-                        "projectId": None,
-                    },
-                    task_type=tasks.TaskType.LLM,
+        elif self.openlayer_inference_pipeline_name:
+            with utils.HidePrints():
+                project = client.create_project(
+                    name=self.openlayer_project_name, task_type=tasks.TaskType.LLM
                 )
-            else:
-                if self.openlayer_project_name:
-                    with utils.HidePrints():
-                        project = client.create_project(
-                            name=self.openlayer_project_name,
-                            task_type=tasks.TaskType.LLM,
-                        )
-                    inference_pipeline = project.create_inference_pipeline(
-                        name=self.openlayer_inference_pipeline_name
-                    )
-
+            inference_pipeline = project.create_inference_pipeline(
+                name=self.openlayer_inference_pipeline_name
+            )
+        if inference_pipeline:
+            logger.info(
+                "Going to try to stream data to the inference pipeline with id %s.",
+                inference_pipeline.id,
+            )
+        else:
+            logger.warn(
+                "No inference pipeline found. Data will not be streamed to Openlayer."
+            )
         self.inference_pipeline = inference_pipeline
 
     def publish_batch_data(self, df: pd.DataFrame, config: Dict[str, any]) -> None: