From 903da912dee3cdfd3d68cbb8f1c90dd500131562 Mon Sep 17 00:00:00 2001 From: lixfz Date: Thu, 15 Feb 2024 19:41:51 +0800 Subject: [PATCH] Update column selector --- hypernets/tabular/column_selector.py | 15 ++++++++------- .../tabular/tb_dask/dask_transofromer_test.py | 4 +++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/hypernets/tabular/column_selector.py b/hypernets/tabular/column_selector.py index 1d7f6807..690e821c 100644 --- a/hypernets/tabular/column_selector.py +++ b/hypernets/tabular/column_selector.py @@ -104,7 +104,7 @@ def __init__(self, pattern=None, *, dtype_include=None, dtype_exclude=None, assert isinstance(word_count_threshold, int) and word_count_threshold >= 1 if dtype_include is None: - dtype_include = ['object'] + dtype_include = ['object', 'string'] super(TextColumnSelector, self).__init__(pattern, dtype_include=dtype_include, @@ -241,19 +241,20 @@ def __call__(self, df): column_all = ColumnSelector() -column_object_category_bool = ColumnSelector(dtype_include=['object', 'category', 'bool']) -column_object_category_bool_with_auto = AutoCategoryColumnSelector(dtype_include=['object', 'category', 'bool'], - cat_exponent=0.5) -column_text = TextColumnSelector(dtype_include=['object']) +column_object_category_bool = ColumnSelector(dtype_include=['object', 'string', 'category', 'bool']) +column_object_category_bool_with_auto = AutoCategoryColumnSelector( + dtype_include=['object', 'string', 'category', 'bool'], + cat_exponent=0.5) +column_text = TextColumnSelector(dtype_include=['object', 'string']) column_latlong = LatLongColumnSelector() -column_object = ColumnSelector(dtype_include=['object']) +column_object = ColumnSelector(dtype_include=['object', 'string']) column_category = ColumnSelector(dtype_include=['category']) column_bool = ColumnSelector(dtype_include=['bool']) column_number = ColumnSelector(dtype_include='number') column_number_exclude_timedelta = ColumnSelector(dtype_include='number', dtype_exclude='timedelta') column_object_category_bool_int = ColumnSelector( - dtype_include=['object', 'category', 'bool', + dtype_include=['object', 'string', 'category', 'bool', 'int', 'int8', 'int16', 'int32', 'int64', 'uint', 'uint8', 'uint16', 'uint32', 'uint64']) diff --git a/hypernets/tests/tabular/tb_dask/dask_transofromer_test.py b/hypernets/tests/tabular/tb_dask/dask_transofromer_test.py index 6b4a7e07..5fbcf855 100644 --- a/hypernets/tests/tabular/tb_dask/dask_transofromer_test.py +++ b/hypernets/tests/tabular/tb_dask/dask_transofromer_test.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import pytest from hypernets.tabular.datasets import dsutils from hypernets.utils import const @@ -130,10 +131,11 @@ def test_varlen_encoder_with_customized_data(self): print(d_result_df) assert all(d_result_df.values == result.values) + @pytest.mark.xfail # see: dask_ml ColumnTransformer def test_dataframe_wrapper(self): X = self.bank_data.copy() - cats = X.select_dtypes(['object', ]).columns.to_list() + cats = X.select_dtypes(['object', 'string']).columns.to_list() continous = X.select_dtypes(['float', 'float64', 'int', 'int64']).columns.to_list() transformers = [('cats', dex.SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=''),