|
| 1 | +# Licensed under the MIT: https://mit-license.org/ |
| 2 | +# For details: https://github.com/pylint-dev/pylint-ml/LICENSE |
| 3 | +# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt |
| 4 | + |
| 5 | +"""Check for proper usage of Pandas functions with required parameters.""" |
| 6 | + |
| 7 | +from astroid import nodes |
| 8 | +from pylint.checkers import BaseChecker |
| 9 | +from pylint.checkers.utils import only_required_for_messages |
| 10 | +from pylint.interfaces import HIGH |
| 11 | + |
| 12 | + |
| 13 | +class PandasParameterChecker(BaseChecker): |
| 14 | + name = "pandas-parameter" |
| 15 | + msgs = { |
| 16 | + "W8111": ( |
| 17 | + "Ensure that required parameters %s are explicitly specified in Pandas method %s.", |
| 18 | + "pandas-parameter", |
| 19 | + "Explicitly specifying required parameters improves model performance and prevents unintended behavior.", |
| 20 | + ), |
| 21 | + } |
| 22 | + |
| 23 | + # Define required parameters for specific Pandas classes and methods |
| 24 | + REQUIRED_PARAMS = { |
| 25 | + # DataFrame creation |
| 26 | + "DataFrame": ["data"], # The primary input data for DataFrame creation |
| 27 | + # Concatenation |
| 28 | + "concat": ["objs"], # The list or dictionary of DataFrames/Series to concatenate |
| 29 | + # DataFrame I/O (Input/Output) |
| 30 | + "read_csv": ["filepath_or_buffer", "dtype"], # Path to CSV file or file-like object; column data types |
| 31 | + "read_excel": ["io", "dtype"], # Path to Excel file or file-like object; column data types |
| 32 | + "read_table": ["filepath_or_buffer", "dtype"], # Path to delimited text-file or file object; column data types |
| 33 | + "to_csv": ["path_or_buf"], # File path or buffer to write the DataFrame to |
| 34 | + "to_excel": ["excel_writer"], # File path or ExcelWriter object to write the data to |
| 35 | + # Merging and Joining |
| 36 | + "merge": ["right", "how", "on", "validate"], # The DataFrame or Serie to merge with |
| 37 | + "join": ["other"], # The DataFrame or Series to join |
| 38 | + # DataFrame Operations |
| 39 | + "pivot_table": ["index"], # The column to pivot on (values and columns have defaults) |
| 40 | + "groupby": ["by"], # The key or list of keys to group by |
| 41 | + "resample": ["rule"], # The frequency rule to resample by |
| 42 | + # Data Cleaning and Transformation |
| 43 | + "fillna": ["value"], # Value to use to fill NA/NaN values |
| 44 | + "drop": ["labels"], # Labels to drop |
| 45 | + "drop_duplicates": ["subset"], # Subset of columns to consider when dropping duplicates |
| 46 | + "replace": ["to_replace"], # Values to replace |
| 47 | + # Plotting |
| 48 | + "plot": ["x"], # x-values or index for plotting |
| 49 | + "hist": ["column"], # Column to plot the histogram for |
| 50 | + "boxplot": ["column"], # Column(s) to plot boxplot for |
| 51 | + # DataFrame Sorting |
| 52 | + "sort_values": ["by"], # Column(s) to sort by |
| 53 | + "sort_index": ["axis"], # Axis to sort along (index=0, columns=1) |
| 54 | + # Statistical Functions |
| 55 | + "corr": ["method"], # Method to use for correlation ('pearson', 'kendall', 'spearman') |
| 56 | + "describe": [], # No required parameters, but additional ones could be specified |
| 57 | + # Windowing/Resampling Functions |
| 58 | + "rolling": ["window"], # Size of the moving window |
| 59 | + "ewm": ["span"], # Span for exponentially weighted calculations |
| 60 | + # Miscellaneous Functions |
| 61 | + "apply": ["func"], # Function to apply to the data |
| 62 | + "agg": ["func"], # Function or list of functions for aggregation |
| 63 | + } |
| 64 | + |
| 65 | + @only_required_for_messages("pandas-parameter") |
| 66 | + def visit_call(self, node: nodes.Call) -> None: |
| 67 | + method_name = self._get_method_name(node) |
| 68 | + if method_name in self.REQUIRED_PARAMS: |
| 69 | + provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} |
| 70 | + # Collect all missing parameters |
| 71 | + missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] |
| 72 | + if missing_params: |
| 73 | + self.add_message( |
| 74 | + "pandas-parameter", |
| 75 | + node=node, |
| 76 | + confidence=HIGH, |
| 77 | + args=(", ".join(missing_params), method_name), |
| 78 | + ) |
| 79 | + |
| 80 | + @staticmethod |
| 81 | + def _get_method_name(node: nodes.Call) -> str: |
| 82 | + """Extracts the method name from a Call node, including handling chained calls.""" |
| 83 | + func = node.func |
| 84 | + while isinstance(func, nodes.Attribute): |
| 85 | + func = func.expr |
| 86 | + return ( |
| 87 | + node.func.attrname |
| 88 | + if isinstance(node.func, nodes.Attribute) |
| 89 | + else func.name if isinstance(func, nodes.Name) else "" |
| 90 | + ) |
0 commit comments