    project.status()
    project.push()
"""
+ import copy
import os
import shutil
import tarfile
import urllib.parse
import uuid
import warnings
- from typing import Dict, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
import yaml
@@ -1073,74 +1074,50 @@ def upload_reference_dataframe(
            dataset_config_file_path=dataset_config_file_path,
            task_type=task_type,
        )
-
-     def send_stream_data(
+
+     def stream_data(
        self,
        inference_pipeline_id: str,
        task_type: TaskType,
-         stream_df: pd.DataFrame,
+         stream_data: Union[Dict[str, any], List[Dict[str, any]]],
        stream_config: Optional[Dict[str, any]] = None,
        stream_config_file_path: Optional[str] = None,
-         verbose: bool = True,
    ) -> None:
-         """Publishes a batch of production data to the Openlayer platform."""
-         if stream_config is None and stream_config_file_path is None:
+         """Streams production data to the Openlayer platform."""
+         if not isinstance(stream_data, (dict, list)):
            raise ValueError(
-                 "Either `batch_config` or `batch_config_file_path` must be" " provided."
+                 "stream_data must be a dictionary or a list of dictionaries."
            )
-         if stream_config_file_path is not None and not os.path.exists(
-             stream_config_file_path
-         ):
-             raise exceptions.OpenlayerValidationError(
-                 f"Stream config file path {stream_config_file_path} does not exist."
-             ) from None
-         elif stream_config_file_path is not None:
-             stream_config = utils.read_yaml(stream_config_file_path)
-
-         stream_config_to_validate = dict(stream_config)
-         stream_config_to_validate["label"] = "production"
+         if isinstance(stream_data, dict):
+             stream_data = [stream_data]

-         # Validate stream of data
-         stream_validator = dataset_validators.get_validator(
+         stream_df = pd.DataFrame(stream_data)
+         stream_config = self._validate_production_data_and_load_config(
            task_type=task_type,
-             dataset_config=stream_config_to_validate,
-             dataset_config_file_path=stream_config_file_path,
-             dataset_df=stream_df,
+             config=stream_config,
+             config_file_path=stream_config_file_path,
+             df=stream_df,
        )
-         failed_validations = stream_validator.validate()
-
-         if failed_validations:
-             raise exceptions.OpenlayerValidationError(
-                 "There are issues with the stream of data and its config.\n"
-                 "Make sure to fix all of the issues listed above before the upload.",
-             ) from None
-
-         # Load dataset config and augment with defaults
-         stream_data = dict(stream_config)
-
-         # Add default columns if not present
-         columns_to_add = {"timestampColumnName", "inferenceIdColumnName"}
-         for column in columns_to_add:
-             if stream_data.get(column) is None:
-                 stream_data, stream_df = self._add_default_column(
-                     config=stream_data, df=stream_df, column_name=column
-                 )
-
-
+         stream_config, stream_df = self._add_default_columns(
+             config=stream_config, df=stream_df
+         )
+         stream_config = self._strip_read_only_fields(stream_config)
        body = {
-             "config": stream_data,
+             "config": stream_config,
            "rows": stream_df.to_dict(orient="records"),
        }
-
-         print("This is the body!")
-         print(body)
        self.api.post_request(
            endpoint=f"inference-pipelines/{inference_pipeline_id}/data-stream",
            body=body,
        )
+         print("Stream published!")

-         if verbose:
-             print("Stream published!")
+     def _strip_read_only_fields(self, config: Dict[str, any]) -> Dict[str, any]:
+         """Strips read-only fields from the config."""
+         stripped_config = copy.deepcopy(config)
+         for field in {"columnNames", "label"}:
+             stripped_config.pop(field, None)
+         return stripped_config

    def publish_batch_data(
        self,
@@ -1151,54 +1128,29 @@ def publish_batch_data(
        batch_config_file_path: Optional[str] = None,
    ) -> None:
        """Publishes a batch of production data to the Openlayer platform."""
-         if batch_config is None and batch_config_file_path is None:
-             raise ValueError(
-                 "Either `batch_config` or `batch_config_file_path` must be" " provided."
-             )
-         if batch_config_file_path is not None and not os.path.exists(
-             batch_config_file_path
-         ):
-             raise exceptions.OpenlayerValidationError(
-                 f"Batch config file path {batch_config_file_path} does not exist."
-             ) from None
-         elif batch_config_file_path is not None:
-             batch_config = utils.read_yaml(batch_config_file_path)
-
-         batch_config["label"] = "production"
-
-         # Validate batch of data
-         batch_validator = dataset_validators.get_validator(
+         batch_config = self._validate_production_data_and_load_config(
            task_type=task_type,
-             dataset_config=batch_config,
-             dataset_config_file_path=batch_config_file_path,
-             dataset_df=batch_df,
+             config=batch_config,
+             config_file_path=batch_config_file_path,
+             df=batch_df,
+         )
+         batch_config, batch_df = self._add_default_columns(
+             config=batch_config, df=batch_df
        )
-         failed_validations = batch_validator.validate()

-         if failed_validations:
-             raise exceptions.OpenlayerValidationError(
-                 "There are issues with the batch of data and its config.\n"
-                 "Make sure to fix all of the issues listed above before the upload.",
-             ) from None
+         # Add column names if missing
+         if batch_config.get("columnNames") is None:
+             batch_config["columnNames"] = list(batch_df.columns)

-         # Add default columns if not present
-         if batch_data.get("columnNames") is None:
-             batch_data["columnNames"] = list(batch_df.columns)
-         columns_to_add = {"timestampColumnName", "inferenceIdColumnName"}
-         for column in columns_to_add:
-             if batch_data.get(column) is None:
-                 batch_data, batch_df = self._add_default_column(
-                     config=batch_data, df=batch_df, column_name=column
-                 )
        # Get min and max timestamps
-         earliest_timestamp = batch_df[batch_data["timestampColumnName"]].min()
-         latest_timestamp = batch_df[batch_data["timestampColumnName"]].max()
+         earliest_timestamp = batch_df[batch_config["timestampColumnName"]].min()
+         latest_timestamp = batch_df[batch_config["timestampColumnName"]].max()
        batch_row_count = len(batch_df)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Copy save files to tmp dir
            batch_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
-             utils.write_yaml(batch_data, f"{tmp_dir}/dataset_config.yaml")
+             utils.write_yaml(batch_config, f"{tmp_dir}/dataset_config.yaml")

            tar_file_path = os.path.join(tmp_dir, "tarfile")
            with tarfile.open(tar_file_path, mode="w:gz") as tar:
@@ -1234,9 +1186,64 @@ def publish_batch_data(
                ),
                presigned_url_query_params=presigned_url_query_params,
            )
-
        print("Data published!")

+     def _validate_production_data_and_load_config(
+         self,
+         task_type: tasks.TaskType,
+         config: Dict[str, any],
+         config_file_path: str,
+         df: pd.DataFrame,
+     ) -> Dict[str, any]:
+         """Validates the production data and its config and returns a valid config
+         populated with the default values."""
+         if config is None and config_file_path is None:
+             raise ValueError(
+                 "Either the config or the config file path must be provided."
+             )
+         if config_file_path is not None and not os.path.exists(config_file_path):
+             raise exceptions.OpenlayerValidationError(
+                 f"The file specified by the config file path {config_file_path} does"
+                 " not exist."
+             ) from None
+         elif config_file_path is not None:
+             config = utils.read_yaml(config_file_path)
+
+         # Force label to be production
+         config["label"] = "production"
+
+         # Validate batch of data
+         validator = dataset_validators.get_validator(
+             task_type=task_type,
+             dataset_config=config,
+             dataset_config_file_path=config_file_path,
+             dataset_df=df,
+         )
+         failed_validations = validator.validate()
+
+         if failed_validations:
+             raise exceptions.OpenlayerValidationError(
+                 "There are issues with the data and its config.\n"
+                 "Make sure to fix all of the issues listed above before the upload.",
+             ) from None
+
+         config = DatasetSchema().load({"task_type": task_type.value, **config})
+
+         return config
+
+     def _add_default_columns(
+         self, config: Dict[str, any], df: pd.DataFrame
+     ) -> Tuple[Dict[str, any], pd.DataFrame]:
+         """Adds the default columns if not present and returns the updated config and
+         dataframe."""
+         columns_to_add = {"timestampColumnName", "inferenceIdColumnName"}
+         for column in columns_to_add:
+             if config.get(column) is None:
+                 config, df = self._add_default_column(
+                     config=config, df=df, column_name=column
+                 )
+         return config, df
+
    def _add_default_column(
        self, config: Dict[str, any], df: pd.DataFrame, column_name: str
    ) -> Tuple[Dict[str, any], pd.DataFrame]:
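
For context, here is a minimal usage sketch of the new `stream_data` method introduced in this diff. It is illustrative only: the `client` object, the `TaskType` import path, the config keys, and the pipeline id are assumed placeholders rather than values taken from this commit.

    # Hypothetical usage of the new stream_data method; all names below are placeholders.
    from openlayer.tasks import TaskType  # assumed import path for the TaskType enum

    # The config follows the same dataset-config format that publish_batch_data expects.
    stream_config = {
        "inputVariableNames": ["user_query"],  # assumed column names, for illustration only
        "outputColumnName": "output",
    }

    # A single dict or a list of dicts is accepted; a single dict is wrapped into a
    # one-element list, converted to a DataFrame, and posted to the data-stream endpoint.
    client.stream_data(
        inference_pipeline_id="YOUR_INFERENCE_PIPELINE_ID",
        task_type=TaskType.LLM,
        stream_data={"user_query": "What is the meaning of life?", "output": "42"},
        stream_config=stream_config,
    )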