Skip to content

Commit 73d6b3a

Browse files
committed
Eliminate sklearn 'force_all_finite' warning
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
1 parent 4463a54 commit 73d6b3a

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

econml/utilities.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
"""Utility methods."""
55

6+
from typing import Union
67
import numpy as np
78
import pandas as pd
89
import scipy.sparse
@@ -21,6 +22,8 @@
2122
from statsmodels.iolib.table import SimpleTable
2223
from statsmodels.iolib.summary import summary_return
2324
from inspect import signature
25+
from packaging.version import parse
26+
2427

2528
MAX_RAND_SEED = np.iinfo(np.int32).max
2629

@@ -455,6 +458,13 @@ def reshape_Y_T(Y, T):
455458
T = T.reshape(-1, 1)
456459
return Y, T
457460

461+
def _get_ensure_finite_arg(ensure_all_finite: Union[str, bool]) -> dict[str, Union[str, bool]]:
462+
if parse(sklearn.__version__) < parse("1.6"):
463+
# `force_all_finite` was renamed to `ensure_all_finite` in sklearn 1.6 and will be deprecated in 1.8
464+
return {'force_all_finite': ensure_all_finite}
465+
else:
466+
return {'ensure_all_finite': ensure_all_finite}
467+
458468

459469
def check_inputs(Y, T, X, W=None, multi_output_T=True, multi_output_Y=True,
460470
force_all_finite_X=True, force_all_finite_W=True):
@@ -511,16 +521,19 @@ def check_inputs(Y, T, X, W=None, multi_output_T=True, multi_output_Y=True,
511521
Converted and validated W.
512522
513523
"""
514-
X, T = check_X_y(X, T, multi_output=multi_output_T, y_numeric=True, force_all_finite=force_all_finite_X)
524+
X, T = check_X_y(X, T, multi_output=multi_output_T, y_numeric=True,
525+
**_get_ensure_finite_arg(force_all_finite_X))
515526
if force_all_finite_X == 'allow-nan':
516527
try:
517528
assert_all_finite(X)
518529
except ValueError:
519530
warnings.warn("X contains NaN. Causal identification strategy can be erroneous"
520531
" in the presence of missing values.")
521-
_, Y = check_X_y(X, Y, multi_output=multi_output_Y, y_numeric=True, force_all_finite=force_all_finite_X)
532+
_, Y = check_X_y(X, Y, multi_output=multi_output_Y, y_numeric=True,
533+
**_get_ensure_finite_arg(force_all_finite_X))
522534
if W is not None:
523-
W, _ = check_X_y(W, Y, multi_output=multi_output_Y, y_numeric=True, force_all_finite=force_all_finite_W)
535+
W, _ = check_X_y(W, Y, multi_output=multi_output_Y, y_numeric=True,
536+
**_get_ensure_finite_arg(force_all_finite_W))
524537
if force_all_finite_W == 'allow-nan':
525538
try:
526539
assert_all_finite(W)
@@ -567,7 +580,7 @@ def check_input_arrays(*args, validate_len=True, force_all_finite=True, dtype=No
567580
for i, arg in enumerate(args):
568581
if np.ndim(arg) > 0:
569582
new_arg = check_array(arg, dtype=dtype, ensure_2d=False, accept_sparse=True,
570-
force_all_finite=force_all_finite)
583+
**_get_ensure_finite_arg(force_all_finite))
571584
if not force_all_finite:
572585
# For when checking input values is disabled
573586
try:
@@ -1531,7 +1544,6 @@ def one_hot_encoder(sparse=False, **kwargs):
15311544
This handles the breaking name change from `sparse` to `sparse_output`
15321545
between sklearn versions 1.1 and 1.2.
15331546
"""
1534-
from packaging.version import parse
15351547
if parse(sklearn.__version__) < parse("1.2"):
15361548
return OneHotEncoder(sparse=sparse, **kwargs)
15371549
else:

0 commit comments

Comments
 (0)