diff --git a/awswrangler/_data_types.py b/awswrangler/_data_types.py index 19f0c57ec..0020fd12d 100644 --- a/awswrangler/_data_types.py +++ b/awswrangler/_data_types.py @@ -213,6 +213,8 @@ def pyarrow2postgresql( # noqa: PLR0911 return pyarrow2postgresql(dtype=dtype.value_type, string_type=string_type) if pa.types.is_binary(dtype): return "BYTEA" + if pa.types.is_list(dtype): + return pyarrow2postgresql(dtype=dtype.value_type, string_type=string_type) + "[]" raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}") diff --git a/awswrangler/_databases.py b/awswrangler/_databases.py index 0f72cb90f..2eab26d15 100644 --- a/awswrangler/_databases.py +++ b/awswrangler/_databases.py @@ -359,6 +359,8 @@ def generate_placeholder_parameter_pairs( """Extract Placeholder and Parameter pairs.""" def convert_value_to_native_python_type(value: Any) -> Any: + if isinstance(value, list): + return value if pd.isna(value): return None if hasattr(value, "to_pydatetime"): diff --git a/tests/unit/test_postgresql.py b/tests/unit/test_postgresql.py index bc964a267..ca8fb7270 100644 --- a/tests/unit/test_postgresql.py +++ b/tests/unit/test_postgresql.py @@ -100,6 +100,8 @@ def test_unknown_overwrite_method_error(postgresql_table, postgresql_con): def test_sql_types(postgresql_table, postgresql_con): table = postgresql_table df = get_df() + df["arrint"] = pd.Series([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + df["arrstr"] = pd.Series([["a", "b", "c"], ["d", "e", "f"], ["g", "h", "i"]]) df.drop(["binary"], axis=1, inplace=True) wr.postgresql.to_sql( df=df, @@ -108,7 +110,7 @@ def test_sql_types(postgresql_table, postgresql_con): schema="public", mode="overwrite", index=True, - dtype={"iint32": "INTEGER"}, + dtype={"iint32": "INTEGER", "arrint": "INTEGER[]", "arrstr": "VARCHAR[]"}, ) df = wr.postgresql.read_sql_query(f"SELECT * FROM public.{table}", postgresql_con) ensure_data_types(df, has_list=False) @@ -130,6 +132,8 @@ def test_sql_types(postgresql_table, postgresql_con): "timestamp": pa.timestamp(unit="ns"), "binary": pa.binary(), "category": pa.float64(), + "arrint": pa.list_(pa.int64()), + "arrstr": pa.list_(pa.string()), }, ) for df in dfs: