import pandas as pd


def compare_tables(
    left_df: pd.DataFrame, right_df: pd.DataFrame, try_to_guess: bool = False
) -> None:
    """Check that two result tables hold the same rows, ignoring row order.

    Helper for cross-validating results of different frameworks once they
    have been converted to ``pandas.DataFrame``. Column order and row order
    are ignored; the caller's frames are never modified.

    Column names that exist in only one table are reconciled first:

    * ``try_to_guess=False`` - such columns are dropped from both tables and
      only the shared columns are compared.
    * ``try_to_guess=True`` - the unmatched names of both tables are paired
      in sorted order and the right-hand columns are renamed to the
      left-hand names, so their data is compared as well.

    Args:
        left_df: First table.
        right_df: Second table.
        try_to_guess: Rename unmatched right-hand columns instead of
            dropping unmatched columns from both tables.

    Raises:
        AssertionError: The tables have a different number of columns, or
            their column names still differ after reconciliation.
        RuntimeError: The tables have a different number of rows or at least
            one cell differs.
    """
    left_names = left_df.columns.to_list()
    right_names = right_df.columns.to_list()
    # Raised explicitly (not via ``assert``) so the check survives ``-O``.
    if len(left_names) != len(right_names):
        raise AssertionError("Table column names are different")

    # Match by name rather than by position in the sorted lists: positional
    # pairing would also discard/rename columns that both tables share.
    left_only = sorted(set(left_names) - set(right_names))
    right_only = sorted(set(right_names) - set(left_names))
    print("compare lists: ", list(zip(left_only, right_only)))

    if try_to_guess:
        right_df = right_df.rename(columns=dict(zip(right_only, left_only)))
    else:
        left_df = left_df.drop(columns=left_only)
        right_df = right_df.drop(columns=right_only)

    columns = sorted(left_df.columns)
    if columns != sorted(right_df.columns):
        raise AssertionError("Table column names are different")

    # Same column order on both sides (``==`` on frames requires identical
    # labels) and private copies so the caller's data stays untouched.
    left_df = left_df[columns].copy()
    right_df = right_df[columns].copy()

    # Categoricals with different category sets cannot be compared and sort
    # by category order, so turn them into strings *before* sorting.
    for col in columns:
        if isinstance(left_df[col].dtype, pd.CategoricalDtype) or isinstance(
            right_df[col].dtype, pd.CategoricalDtype
        ):
            left_df[col] = left_df[col].astype("str")
            right_df[col] = right_df[col].astype("str")

    if columns:
        left_df = left_df.sort_values(by=columns)
        right_df = right_df.sort_values(by=columns)
    left_df = left_df.reset_index(drop=True)
    right_df = right_df.reset_index(drop=True)

    if len(left_df) != len(right_df):
        print("Mismatched row count, left: ", len(left_df), " right: ", len(right_df))
        raise RuntimeError("Results mismatched")

    # A cell matches when the values are equal or both are missing
    # (NaN != NaN under plain ``==``).
    equal = (left_df == right_df) | (left_df.isna() & right_df.isna())
    if not equal.to_numpy().all():
        bad_rows = ~equal.all(axis=1)
        print("Mismatched left: ")
        print(left_df[bad_rows])
        print(" right: ")
        print(right_df[bad_rows])
        raise RuntimeError("Results mismatched")