|
| 1 | +# Licensed under the MIT: https://mit-license.org/ |
| 2 | +# For details: https://github.com/pylint-dev/pylint-ml/LICENSE |
| 3 | +# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt |
| 4 | + |
| 5 | +"""Check for proper usage of Pandas functions with required parameters.""" |
| 6 | + |
| 7 | +from astroid import nodes |
| 8 | +from pylint.checkers import BaseChecker |
| 9 | +from pylint.checkers.utils import only_required_for_messages |
| 10 | +from pylint.interfaces import HIGH |
| 11 | + |
| 12 | + |
| 13 | +class PandasParameterChecker(BaseChecker): |
| 14 | + name = "pandas-parameter" |
| 15 | + msgs = { |
| 16 | + "W8111": ( |
| 17 | + "Ensure that required parameters %s are explicitly specified in Pandas method %s.", |
| 18 | + "pandas-parameter", |
| 19 | + "Explicitly specifying required parameters improves model performance and prevents unintended behavior.", |
| 20 | + ), |
| 21 | + } |
| 22 | + |
| 23 | + # Define required parameters for specific Pandas classes and methods |
| 24 | + REQUIRED_PARAMS = { |
| 25 | + # DataFrame creation |
| 26 | + 'DataFrame': ['data'], # The primary input data for DataFrame creation |
| 27 | + |
| 28 | + # Concatenation |
| 29 | + 'concat': ['objs'], # The list or dictionary of DataFrames/Series to concatenate |
| 30 | + |
| 31 | + # DataFrame I/O (Input/Output) |
| 32 | + 'read_csv': ['filepath_or_buffer', 'dtype'], # Path to CSV file or file-like object; column data types |
| 33 | + 'read_excel': ['io', 'dtype'], # Path to Excel file or file-like object; column data types |
| 34 | + 'read_table': ['filepath_or_buffer', 'dtype'], # Path to delimited text-file or file object; column data types |
| 35 | + 'to_csv': ['path_or_buf'], # File path or buffer to write the DataFrame to |
| 36 | + 'to_excel': ['excel_writer'], # File path or ExcelWriter object to write the data to |
| 37 | + |
| 38 | + # Merging and Joining |
| 39 | + 'merge': ['right', 'how', 'on', 'validate'], # The DataFrame or Serie to merge with |
| 40 | + 'join': ['other'], # The DataFrame or Series to join |
| 41 | + |
| 42 | + # DataFrame Operations |
| 43 | + 'pivot_table': ['index'], # The column to pivot on (values and columns have defaults) |
| 44 | + 'groupby': ['by'], # The key or list of keys to group by |
| 45 | + 'resample': ['rule'], # The frequency rule to resample by |
| 46 | + |
| 47 | + # Data Cleaning and Transformation |
| 48 | + 'fillna': ['value'], # Value to use to fill NA/NaN values |
| 49 | + 'drop': ['labels'], # Labels to drop |
| 50 | + 'drop_duplicates': ['subset'], # Subset of columns to consider when dropping duplicates |
| 51 | + 'replace': ['to_replace'], # Values to replace |
| 52 | + |
| 53 | + # Plotting |
| 54 | + 'plot': ['x'], # x-values or index for plotting |
| 55 | + 'hist': ['column'], # Column to plot the histogram for |
| 56 | + 'boxplot': ['column'], # Column(s) to plot boxplot for |
| 57 | + |
| 58 | + # DataFrame Sorting |
| 59 | + 'sort_values': ['by'], # Column(s) to sort by |
| 60 | + 'sort_index': ['axis'], # Axis to sort along (index=0, columns=1) |
| 61 | + |
| 62 | + # Statistical Functions |
| 63 | + 'corr': ['method'], # Method to use for correlation ('pearson', 'kendall', 'spearman') |
| 64 | + 'describe': [], # No required parameters, but additional ones could be specified |
| 65 | + |
| 66 | + # Windowing/Resampling Functions |
| 67 | + 'rolling': ['window'], # Size of the moving window |
| 68 | + 'ewm': ['span'], # Span for exponentially weighted calculations |
| 69 | + |
| 70 | + # Miscellaneous Functions |
| 71 | + 'apply': ['func'], # Function to apply to the data |
| 72 | + 'agg': ['func'], # Function or list of functions for aggregation |
| 73 | + } |
| 74 | + |
| 75 | + @only_required_for_messages("pandas-parameter") |
| 76 | + def visit_call(self, node: nodes.Call) -> None: |
| 77 | + method_name = self._get_method_name(node) |
| 78 | + if method_name in self.REQUIRED_PARAMS: |
| 79 | + provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} |
| 80 | + # Collect all missing parameters |
| 81 | + missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] |
| 82 | + if missing_params: |
| 83 | + self.add_message( |
| 84 | + "pandas-parameter", |
| 85 | + node=node, |
| 86 | + confidence=HIGH, |
| 87 | + args=(", ".join(missing_params), method_name), |
| 88 | + ) |
| 89 | + |
| 90 | + @staticmethod |
| 91 | + def _get_method_name(node: nodes.Call) -> str: |
| 92 | + """Extracts the method name from a Call node, including handling chained calls.""" |
| 93 | + func = node.func |
| 94 | + while isinstance(func, nodes.Attribute): |
| 95 | + func = func.expr |
| 96 | + return ( |
| 97 | + node.func.attrname |
| 98 | + if isinstance(node.func, nodes.Attribute) |
| 99 | + else func.name if isinstance(func, nodes.Name) else "" |
| 100 | + ) |
0 commit comments