diff --git a/datacompy/__init__.py b/datacompy/__init__.py index b47d27c3..b43027ae 100644 --- a/datacompy/__init__.py +++ b/datacompy/__init__.py @@ -19,6 +19,7 @@ from datacompy.fugue import ( all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, report, diff --git a/datacompy/fugue.py b/datacompy/fugue.py index 2ac4889a..8bc01d33 100644 --- a/datacompy/fugue.py +++ b/datacompy/fugue.py @@ -291,6 +291,101 @@ def all_rows_overlap( return all(overlap) +def count_matching_rows( + df1: AnyDataFrame, + df2: AnyDataFrame, + join_columns: Union[str, List[str]], + abs_tol: float = 0, + rel_tol: float = 0, + df1_name: str = "df1", + df2_name: str = "df2", + ignore_spaces: bool = False, + ignore_case: bool = False, + cast_column_names_lower: bool = True, + parallelism: Optional[int] = None, + strict_schema: bool = False, +) -> int: + """Count the number of matching rows (on overlapping fields) + + Parameters + ---------- + df1 : ``AnyDataFrame`` + First dataframe to check + df2 : ``AnyDataFrame`` + Second dataframe to check + join_columns : list or str, optional + Column(s) to join dataframes on. If a string is passed in, that one + column will be used. + abs_tol : float, optional + Absolute tolerance between two values. + rel_tol : float, optional + Relative tolerance between two values. + df1_name : str, optional + A string name for the first dataframe. This allows the reporting to + print out an actual name instead of "df1", and allows human users to + more easily track the dataframes. 
+ df2_name : str, optional + A string name for the second dataframe + ignore_spaces : bool, optional + Flag to strip whitespace (including newlines) from string columns (including any join + columns) + ignore_case : bool, optional + Flag to ignore the case of string columns + cast_column_names_lower: bool, optional + Boolean indicator that controls whether column names will be cast into lower case + parallelism: int, optional + An integer representing the amount of parallelism. Entering a value for this + will force the use of Fugue over just vanilla Pandas + strict_schema: bool, optional + The schema must match exactly if set to ``True``. This includes the names and types. Allows for a fast fail. + + Returns + ------- + int + Number of matching rows + """ + if ( + isinstance(df1, pd.DataFrame) + and isinstance(df2, pd.DataFrame) + and parallelism is None # user did not specify parallelism + and fa.get_current_parallelism() == 1 # currently on a local execution engine + ): + comp = Compare( + df1=df1, + df2=df2, + join_columns=join_columns, + abs_tol=abs_tol, + rel_tol=rel_tol, + df1_name=df1_name, + df2_name=df2_name, + ignore_spaces=ignore_spaces, + ignore_case=ignore_case, + cast_column_names_lower=cast_column_names_lower, + ) + return comp.count_matching_rows() + + try: + count_matching_rows = _distributed_compare( + df1=df1, + df2=df2, + join_columns=join_columns, + return_obj_func=lambda comp: comp.count_matching_rows(), + abs_tol=abs_tol, + rel_tol=rel_tol, + df1_name=df1_name, + df2_name=df2_name, + ignore_spaces=ignore_spaces, + ignore_case=ignore_case, + cast_column_names_lower=cast_column_names_lower, + parallelism=parallelism, + strict_schema=strict_schema, + ) + except _StrictSchemaError: + return 0 + + return sum(count_matching_rows) + + def report( df1: AnyDataFrame, df2: AnyDataFrame, @@ -460,7 +555,6 @@ def _any(col: str) -> int: any_mismatch = len(match_sample) > 0 # Column Matching - cnt_intersect = shape0("intersect_rows_shape") rpt += render( 
"column_comparison.txt", len([col for col in column_stats if col["unequal_cnt"] > 0]), diff --git a/tests/test_fugue/conftest.py b/tests/test_fugue/conftest.py index 6a5683d2..a2ca99b1 100644 --- a/tests/test_fugue/conftest.py +++ b/tests/test_fugue/conftest.py @@ -1,6 +1,6 @@ -import pytest import numpy as np import pandas as pd +import pytest @pytest.fixture @@ -24,7 +24,8 @@ def ref_df(): c=np.random.choice(["aaa", "b_c", "csd"], 100), ) ) - return [df1, df1_copy, df2, df3, df4] + df5 = df1.sample(frac=0.1) + return [df1, df1_copy, df2, df3, df4, df5] @pytest.fixture @@ -87,3 +88,16 @@ def large_diff_df2(): np.random.seed(0) data = np.random.randint(6, 11, size=10000) return pd.DataFrame({"x": data, "y": np.array([9] * 10000)}).convert_dtypes() + + +@pytest.fixture +def count_matching_rows_df(): + np.random.seed(0) + df1 = pd.DataFrame( + dict( + a=np.arange(0, 100), + b=np.arange(0, 100), + ) + ) + df2 = df1.sample(frac=0.1) + return [df1, df2] diff --git a/tests/test_fugue/test_duckdb.py b/tests/test_fugue/test_duckdb.py index daed1edd..3643f22d 100644 --- a/tests/test_fugue/test_duckdb.py +++ b/tests/test_fugue/test_duckdb.py @@ -20,6 +20,7 @@ from datacompy import ( all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, unq_columns, @@ -138,3 +139,40 @@ def test_all_rows_overlap_duckdb( duckdb.sql("SELECT 'a' AS a, 'b' AS b"), join_columns="a", ) + + +def test_count_matching_rows_duckdb(count_matching_rows_df): + with duckdb.connect(): + df1 = duckdb.from_df(count_matching_rows_df[0]) + df1_copy = duckdb.from_df(count_matching_rows_df[0]) + df2 = duckdb.from_df(count_matching_rows_df[1]) + + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + ) + == 100 + ) + assert count_matching_rows(df1, df2, join_columns="a") == 10 + # Fugue + + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + df1, + df2, + join_columns="a", + 
parallelism=2, + ) + == 10 + ) diff --git a/tests/test_fugue/test_fugue_pandas.py b/tests/test_fugue/test_fugue_pandas.py index 77884c2c..4fd74ce7 100644 --- a/tests/test_fugue/test_fugue_pandas.py +++ b/tests/test_fugue/test_fugue_pandas.py @@ -24,6 +24,7 @@ Compare, all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, report, @@ -144,7 +145,6 @@ def test_report_pandas( def test_unique_columns_native(ref_df): df1 = ref_df[0] - df1_copy = ref_df[1] df2 = ref_df[2] df3 = ref_df[3] @@ -192,3 +192,41 @@ def test_all_rows_overlap_native( # Fugue assert all_rows_overlap(ref_df[0], shuffle_df, join_columns="a", parallelism=2) assert not all_rows_overlap(ref_df[0], ref_df[4], join_columns="a", parallelism=2) + + +def test_count_matching_rows_native(count_matching_rows_df): + # defaults to Compare class + assert ( + count_matching_rows( + count_matching_rows_df[0], + count_matching_rows_df[0].copy(), + join_columns="a", + ) + == 100 + ) + assert ( + count_matching_rows( + count_matching_rows_df[0], count_matching_rows_df[1], join_columns="a" + ) + == 10 + ) + # Fugue + + assert ( + count_matching_rows( + count_matching_rows_df[0], + count_matching_rows_df[0].copy(), + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + count_matching_rows_df[0], + count_matching_rows_df[1], + join_columns="a", + parallelism=2, + ) + == 10 + ) diff --git a/tests/test_fugue/test_fugue_polars.py b/tests/test_fugue/test_fugue_polars.py index fdb2212a..dcd19a94 100644 --- a/tests/test_fugue/test_fugue_polars.py +++ b/tests/test_fugue/test_fugue_polars.py @@ -20,6 +20,7 @@ from datacompy import ( all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, unq_columns, @@ -122,3 +123,37 @@ def test_all_rows_overlap_polars( assert all_rows_overlap(rdf, rdf_copy, join_columns="a") assert all_rows_overlap(rdf, sdf, join_columns="a") assert not all_rows_overlap(rdf, rdf4, join_columns="a") + + +def 
test_count_matching_rows_polars(count_matching_rows_df): + df1 = pl.from_pandas(count_matching_rows_df[0]) + df2 = pl.from_pandas(count_matching_rows_df[1]) + assert ( + count_matching_rows( + df1, + df1.clone(), + join_columns="a", + ) + == 100 + ) + assert count_matching_rows(df1, df2, join_columns="a") == 10 + # Fugue + + assert ( + count_matching_rows( + df1, + df1.clone(), + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + df1, + df2, + join_columns="a", + parallelism=2, + ) + == 10 + ) diff --git a/tests/test_fugue/test_fugue_spark.py b/tests/test_fugue/test_fugue_spark.py index 99da708b..efc895ff 100644 --- a/tests/test_fugue/test_fugue_spark.py +++ b/tests/test_fugue/test_fugue_spark.py @@ -22,6 +22,7 @@ Compare, all_columns_match, all_rows_overlap, + count_matching_rows, intersect_columns, is_match, report, @@ -200,3 +201,44 @@ def test_all_rows_overlap_spark( spark_session.sql("SELECT 'a' AS a, 'b' AS b"), join_columns="a", ) + + +def test_count_matching_rows_spark(spark_session, count_matching_rows_df): + count_matching_rows_df[0].iteritems = count_matching_rows_df[ + 0 + ].items # pandas 2 compatibility + count_matching_rows_df[1].iteritems = count_matching_rows_df[ + 1 + ].items # pandas 2 compatibility + df1 = spark_session.createDataFrame(count_matching_rows_df[0]) + df1_copy = spark_session.createDataFrame(count_matching_rows_df[0]) + df2 = spark_session.createDataFrame(count_matching_rows_df[1]) + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + ) + == 100 + ) + assert count_matching_rows(df1, df2, join_columns="a") == 10 + # Fugue + + assert ( + count_matching_rows( + df1, + df1_copy, + join_columns="a", + parallelism=2, + ) + == 100 + ) + assert ( + count_matching_rows( + df1, + df2, + join_columns="a", + parallelism=2, + ) + == 10 + )