|
3 | 3 |
|
4 | 4 | """Utility methods.""" |
5 | 5 |
|
| 6 | +from typing import Union |
6 | 7 | import numpy as np |
7 | 8 | import pandas as pd |
8 | 9 | import scipy.sparse |
|
21 | 22 | from statsmodels.iolib.table import SimpleTable |
22 | 23 | from statsmodels.iolib.summary import summary_return |
23 | 24 | from inspect import signature |
| 25 | +from packaging.version import parse |
| 26 | + |
24 | 27 |
|
25 | 28 | MAX_RAND_SEED = np.iinfo(np.int32).max |
26 | 29 |
|
@@ -455,6 +458,13 @@ def reshape_Y_T(Y, T): |
455 | 458 | T = T.reshape(-1, 1) |
456 | 459 | return Y, T |
457 | 460 |
|
| 461 | +def _get_ensure_finite_arg(ensure_all_finite: Union[str, bool]) -> dict[str, Union[str, bool]]:
| 462 | +    if parse(sklearn.__version__) < parse("1.6"):
| 463 | +        # Build the version-appropriate kwarg for sklearn validators: `force_all_finite` was deprecated and renamed to `ensure_all_finite` in sklearn 1.6, and will be removed in 1.8
| 464 | +        return {'force_all_finite': ensure_all_finite}
| 465 | +    else:
| 466 | +        return {'ensure_all_finite': ensure_all_finite}
| 467 | +
458 | 468 |
|
459 | 469 | def check_inputs(Y, T, X, W=None, multi_output_T=True, multi_output_Y=True, |
460 | 470 | force_all_finite_X=True, force_all_finite_W=True): |
@@ -511,16 +521,19 @@ def check_inputs(Y, T, X, W=None, multi_output_T=True, multi_output_Y=True, |
511 | 521 | Converted and validated W. |
512 | 522 |
|
513 | 523 | """ |
514 | | - X, T = check_X_y(X, T, multi_output=multi_output_T, y_numeric=True, force_all_finite=force_all_finite_X) |
| 524 | + X, T = check_X_y(X, T, multi_output=multi_output_T, y_numeric=True, |
| 525 | + **_get_ensure_finite_arg(force_all_finite_X)) |
515 | 526 | if force_all_finite_X == 'allow-nan': |
516 | 527 | try: |
517 | 528 | assert_all_finite(X) |
518 | 529 | except ValueError: |
519 | 530 | warnings.warn("X contains NaN. Causal identification strategy can be erroneous" |
520 | 531 | " in the presence of missing values.") |
521 | | - _, Y = check_X_y(X, Y, multi_output=multi_output_Y, y_numeric=True, force_all_finite=force_all_finite_X) |
| 532 | + _, Y = check_X_y(X, Y, multi_output=multi_output_Y, y_numeric=True, |
| 533 | + **_get_ensure_finite_arg(force_all_finite_X)) |
522 | 534 | if W is not None: |
523 | | - W, _ = check_X_y(W, Y, multi_output=multi_output_Y, y_numeric=True, force_all_finite=force_all_finite_W) |
| 535 | + W, _ = check_X_y(W, Y, multi_output=multi_output_Y, y_numeric=True, |
| 536 | + **_get_ensure_finite_arg(force_all_finite_W)) |
524 | 537 | if force_all_finite_W == 'allow-nan': |
525 | 538 | try: |
526 | 539 | assert_all_finite(W) |
@@ -567,7 +580,7 @@ def check_input_arrays(*args, validate_len=True, force_all_finite=True, dtype=No |
567 | 580 | for i, arg in enumerate(args): |
568 | 581 | if np.ndim(arg) > 0: |
569 | 582 | new_arg = check_array(arg, dtype=dtype, ensure_2d=False, accept_sparse=True, |
570 | | - force_all_finite=force_all_finite) |
| 583 | + **_get_ensure_finite_arg(force_all_finite)) |
571 | 584 | if not force_all_finite: |
572 | 585 | # For when checking input values is disabled |
573 | 586 | try: |
@@ -1531,7 +1544,6 @@ def one_hot_encoder(sparse=False, **kwargs): |
1531 | 1544 | This handles the breaking name change from `sparse` to `sparse_output` |
1532 | 1545 | between sklearn versions 1.1 and 1.2. |
1533 | 1546 | """ |
1534 | | - from packaging.version import parse |
1535 | 1547 | if parse(sklearn.__version__) < parse("1.2"): |
1536 | 1548 | return OneHotEncoder(sparse=sparse, **kwargs) |
1537 | 1549 | else: |
|
0 commit comments