@@ -49,8 +49,8 @@ def _sklearn_to_dict(model):
4949 'RandomForestClassifier' : 'Forest' ,
5050 'DecisionTreeClassifier' : 'Decision tree' ,
5151 'DecisionTreeRegressor' : 'Decision tree' ,
52- 'classifier' : 'Classification ' ,
53- 'regressor' : 'Prediction ' }
52+ 'classifier' : 'classification ' ,
53+ 'regressor' : 'prediction ' }
5454
5555 if hasattr (model , '_final_estimator' ):
5656 estimator = type (model ._final_estimator )
@@ -207,10 +207,26 @@ def get_version(x):
207207 # If model is a CASTable then assume it holds an ASTORE model.
208208 # Import these via a ZIP file.
209209 if 'swat.cas.table.CASTable' in str (type (model )):
210- zipfile = utils .create_package (model )
210+ zipfile = utils .create_package (model , input = input )
211211
212212 if create_project :
213- project = mr .create_project (project , repo_obj )
213+ outvar = []
214+ invar = []
215+ import zipfile as zp
216+ import copy
217+ zipfilecopy = copy .deepcopy (zipfile )
218+ tmpzip = zp .ZipFile (zipfilecopy )
219+ if "outputVar.json" in tmpzip .namelist ():
220+ outvar = json .loads (tmpzip .read ("outputVar.json" ).decode ('utf=8' )) #added decode for 3.5 and older
221+ for tmp in outvar :
222+ tmp .update ({'role' :'output' })
223+ if "inputVar.json" in tmpzip .namelist ():
224+ invar = json .loads (tmpzip .read ("inputVar.json" ).decode ('utf-8' )) #added decode for 3.5 and older
225+ for tmp in invar :
226+ if tmp ['role' ] != 'input' :
227+ tmp ['role' ]= 'input'
228+ vars = invar + outvar
229+ project = mr .create_project (project , repo_obj , variables = vars )
214230
215231 model = mr .import_model_from_zip (name , project , zipfile ,
216232 version = version )
@@ -302,17 +318,27 @@ def get_version(x):
302318 else :
303319 prediction_variable = None
304320
305- project = mr .create_project (project , repo_obj ,
321+ # As of Viya 3.4 the 'predictionVariable' parameter is not set during
322+ # project creation. Update the project if necessary.
323+ if function == 'prediction' : #Predications require predictionVariable
324+ project = mr .create_project (project , repo_obj ,
306325 variables = vars ,
307326 function = model .get ('function' ),
308327 targetLevel = target_level ,
309328 predictionVariable = prediction_variable )
310329
311- # As of Viya 3.4 the 'predictionVariable' parameter is not set during
312- # project creation. Update the project if necessary.
313- if project .get ('predictionVariable' ) != prediction_variable :
314- project ['predictionVariable' ] = prediction_variable
315- mr .update_project (project )
330+ if project .get ('predictionVariable' ) != prediction_variable :
331+ project ['predictionVariable' ] = prediction_variable
332+ mr .update_project (project )
333+ else : #Classifications require eventProbabilityVariable
334+ project = mr .create_project (project , repo_obj ,
335+ variables = vars ,
336+ function = model .get ('function' ),
337+ targetLevel = target_level ,
338+ eventProbabilityVariable = prediction_variable )
339+ if project .get ('eventProbabilityVariable' ) != prediction_variable :
340+ project ['eventProbabilityVariable' ] = prediction_variable
341+ mr .update_project (project )
316342
317343 model = mr .create_model (model , project )
318344
@@ -506,9 +532,12 @@ def update_model_performance(data, model, label, refresh=True):
506532 "regression and binary classification projects. "
507533 "Received project with '%s' target level. Should be "
508534 "'Interval' or 'Binary'." , project .get ('targetLevel' ))
509- elif project .get ('predictionVariable' , '' ) == '' :
535+ elif project .get ('predictionVariable' , '' ) == '' and project . get ( 'function' , '' ). lower () == 'prediction' :
510536 raise ValueError ("Project '%s' does not have a prediction variable "
511537 "specified." % project )
538+ elif project .get ('eventProbabilityVariable' , '' ) == '' and project .get ('function' , '' ).lower () == 'classification' :
539+ raise ValueError ("Project '%s' does not have an Event Probability variable "
540+ "specified." % project )
512541
513542 # Find the performance definition for the model
514543 # As of Viya 3.4, no way to search by model or project
0 commit comments