Skip to content

Commit 74dc571

Browse files
predict_proba check only for non-custom models
1 parent 7245b3e commit 74dc571

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

unboxapi/__init__.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -464,31 +464,33 @@ def add_model(
464464
message=f"Your function's additional args {user_args} do not match the kwargs you specifed {kwarg_keys}. \n",
465465
mitigation=f"Make sure to include all of the required kwargs to run inference with your `function`.",
466466
) from None
467-
try:
468-
if task_type in [
469-
TaskType.TabularClassification,
470-
TaskType.TabularRegression,
471-
]:
472-
test_input = train_sample_df[:3][feature_names].to_numpy()
473-
with HidePrints():
474-
function(model, test_input, **kwargs)
475-
else:
476-
test_input = [
477-
"Unbox is great!",
478-
"Let's see if this function is ready for some error analysis",
479-
]
480-
with HidePrints():
481-
function(model, test_input, **kwargs)
482-
except Exception as e:
483-
exception_stack = "".join(
484-
traceback.format_exception(type(e), e, e.__traceback__)
485-
)
486-
raise UnboxResourceError(
487-
context="There is an issue with the specified `function`. \n",
488-
message=f"It is failing with the following error: \n{exception_stack}",
489-
mitigation="Make sure your function receives the model and the input as arguments, plus the additional kwargs. Additionally,"
490-
+ "you may find it useful to debug it on the Jupyter notebook, to ensure it is working correctly before uploading it.",
491-
) from None
467+
468+
if model_type != ModelType.custom:
469+
try:
470+
if task_type in [
471+
TaskType.TabularClassification,
472+
TaskType.TabularRegression,
473+
]:
474+
test_input = train_sample_df[:3][feature_names].to_numpy()
475+
with HidePrints():
476+
function(model, test_input, **kwargs)
477+
else:
478+
test_input = [
479+
"Unbox is great!",
480+
"Let's see if this function is ready for some error analysis",
481+
]
482+
with HidePrints():
483+
function(model, test_input, **kwargs)
484+
except Exception as e:
485+
exception_stack = "".join(
486+
traceback.format_exception(type(e), e, e.__traceback__)
487+
)
488+
raise UnboxResourceError(
489+
context="There is an issue with the specified `function`. \n",
490+
message=f"It is failing with the following error: \n{exception_stack}",
491+
mitigation="Make sure your function receives the model and the input as arguments, plus the additional kwargs. Additionally,"
492+
+ "you may find it useful to debug it on the Jupyter notebook, to ensure it is working correctly before uploading it.",
493+
) from None
492494

493495
# Transformers resources
494496
if model_type is ModelType.transformers:

0 commit comments

Comments
 (0)