From 759cf83a90038d24214e86b6473499c734450c29 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Fri, 26 Jul 2024 08:44:34 -0400 Subject: [PATCH] feat(api): add `length` to string type information --- .../clickhouse/tests/test_datatypes.py | 2 +- ibis/backends/mssql/tests/test_client.py | 12 +- ibis/backends/postgres/tests/test_client.py | 11 +- ibis/backends/sql/datatypes.py | 94 ++++++++++---- ibis/backends/trino/tests/test_datatypes.py | 6 +- ibis/expr/datatypes/core.py | 14 +++ ibis/expr/datatypes/parse.py | 9 +- ibis/expr/datatypes/tests/test_parse.py | 9 +- ibis/tests/strategies.py | 115 +++++++++--------- 9 files changed, 171 insertions(+), 101 deletions(-) diff --git a/ibis/backends/clickhouse/tests/test_datatypes.py b/ibis/backends/clickhouse/tests/test_datatypes.py index 6f78d4d50156..89bec9bdbf0d 100644 --- a/ibis/backends/clickhouse/tests/test_datatypes.py +++ b/ibis/backends/clickhouse/tests/test_datatypes.py @@ -125,7 +125,7 @@ def test_array_discovery_clickhouse(con): ), param( "Array(FixedString(32))", - dt.Array(dt.String(nullable=False), nullable=False), + dt.Array(dt.String(length=32, nullable=False), nullable=False), id="array_fixed_string", ), param( diff --git a/ibis/backends/mssql/tests/test_client.py b/ibis/backends/mssql/tests/test_client.py index b11030dfe875..2ec744936392 100644 --- a/ibis/backends/mssql/tests/test_client.py +++ b/ibis/backends/mssql/tests/test_client.py @@ -49,12 +49,14 @@ ("DATETIMEOFFSET", dt.timestamp(scale=7, timezone="UTC")), ("SMALLDATETIME", dt.Timestamp(scale=0)), ("DATETIME", dt.Timestamp(scale=3)), - # Characters strings - ("CHAR", dt.string), - ("VARCHAR", dt.string), + # Character strings + ("CHAR", dt.String(length=1)), + ("TEXT", dt.string), + ("VARCHAR", dt.String(length=1)), # Unicode character strings - ("NCHAR", dt.string), - ("NVARCHAR", dt.string), + ("NCHAR", dt.String(length=1)), + ("NTEXT", dt.string), + ("NVARCHAR", dt.String(length=1)), # Binary strings ("BINARY", dt.binary), ("VARBINARY", dt.binary), diff --git a/ibis/backends/postgres/tests/test_client.py b/ibis/backends/postgres/tests/test_client.py index c996c33f33ac..c4667c80d6c0 100644 --- a/ibis/backends/postgres/tests/test_client.py +++ b/ibis/backends/postgres/tests/test_client.py @@ -148,7 +148,8 @@ def test_create_and_drop_table(con, temp_table, params): for (pg_type, ibis_type) in [ ("boolean", dt.boolean), ("bytea", dt.binary), - ("char", dt.string), + ("char", dt.String(length=1)), + ("char(42)", dt.String(length=42)), ("bigint", dt.int64), ("smallint", dt.int16), ("integer", dt.int32), @@ -162,8 +163,10 @@ def test_create_and_drop_table(con, temp_table, params): ("macaddr", dt.macaddr), ("macaddr8", dt.macaddr), ("inet", dt.inet), - ("character", dt.string), + ("character", dt.String(length=1)), ("character varying", dt.string), + ("character varying(73)", dt.String(length=73)), + ("varchar(37)", dt.String(length=37)), ("date", dt.date), ("time", dt.time), ("time without time zone", dt.time), @@ -305,13 +308,13 @@ def test_pgvector_type_load(con, vector_size): def test_name_dtype(con): expected_schema = ibis.schema( { - "f_table_catalog": dt.String(nullable=True), + "f_table_catalog": dt.String(length=256, nullable=True), "f_table_schema": dt.String(nullable=True), "f_table_name": dt.String(nullable=True), "f_geometry_column": dt.String(nullable=True), "coord_dimension": dt.Int32(nullable=True), "srid": dt.Int32(nullable=True), - "type": dt.String(nullable=True), + "type": dt.String(length=30, nullable=True), } ) diff --git a/ibis/backends/sql/datatypes.py b/ibis/backends/sql/datatypes.py index b76630b75900..55ecb6999957 100644 --- a/ibis/backends/sql/datatypes.py +++ b/ibis/backends/sql/datatypes.py @@ -19,7 +19,6 @@ typecode.BIGINT: dt.Int64, typecode.BINARY: dt.Binary, typecode.BOOLEAN: dt.Boolean, - typecode.CHAR: dt.String, typecode.DATE: dt.Date, typecode.DATETIME: dt.Timestamp, typecode.DATE32: dt.Date, @@ -28,7 +27,6 @@ typecode.ENUM8: dt.String, typecode.ENUM16: dt.String, typecode.FLOAT: dt.Float32, - typecode.FIXEDSTRING: dt.String, typecode.HSTORE: partial(dt.Map, dt.string, dt.string), typecode.INET: dt.INET, typecode.INT128: partial(dt.Decimal, 38, 0), @@ -43,11 +41,9 @@ typecode.MEDIUMINT: dt.Int32, typecode.MEDIUMTEXT: dt.String, typecode.MONEY: dt.Decimal(19, 4), - typecode.NCHAR: dt.String, typecode.UUID: dt.UUID, typecode.NAME: dt.String, typecode.NULL: dt.Null, - typecode.NVARCHAR: dt.String, typecode.OBJECT: partial(dt.Map, dt.string, dt.json), typecode.ROWVERSION: partial(dt.Binary, nullable=False), typecode.SMALLINT: dt.Int16, @@ -64,7 +60,6 @@ typecode.UTINYINT: dt.UInt8, typecode.UUID: dt.UUID, typecode.VARBINARY: dt.Binary, - typecode.VARCHAR: dt.String, typecode.VARIANT: dt.JSON, typecode.UNIQUEIDENTIFIER: dt.UUID, typecode.SET: partial(dt.Array, dt.string), @@ -116,7 +111,6 @@ dt.Float16: typecode.FLOAT, dt.Float32: typecode.FLOAT, dt.Float64: typecode.DOUBLE, - dt.String: typecode.VARCHAR, dt.Binary: typecode.VARBINARY, dt.INET: typecode.INET, dt.UUID: typecode.UUID, @@ -220,6 +214,19 @@ def _from_sqlglot_ARRAY( ) -> dt.Array: return dt.Array(cls.to_ibis(value_type), nullable=nullable) + @classmethod + def _from_sqlglot_VARCHAR( + cls, length: sge.DataTypeParam | None = None, nullable: bool | None = None + ) -> dt.String: + return dt.String( + length=int(length.this.this) if length is not None else None, + nullable=nullable, + ) + + _from_sqlglot_NVARCHAR = _from_sqlglot_NCHAR = _from_sqlglot_CHAR = ( + _from_sqlglot_FIXEDSTRING + ) = _from_sqlglot_VARCHAR + @classmethod def _from_sqlglot_MAP( cls, @@ -359,6 +366,17 @@ def _from_sqlglot_GEOGRAPHY( srid = int(srid.this.this) return typeclass(geotype="geography", nullable=nullable, srid=srid) + @classmethod + def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: + return sge.DataType( + this=typecode.VARCHAR, + expressions=( + None + if (length := dtype.length) is None + else [sge.DataTypeParam(this=sge.convert(length))] + ), + ) + @classmethod def _from_ibis_JSON(cls, dtype: dt.JSON) -> sge.DataType: return sge.DataType(this=typecode.JSONB if dtype.binary else typecode.JSON) @@ -584,10 +602,6 @@ def _from_sqlglot_DATETIME( def _from_sqlglot_TIMESTAMP(cls, nullable: bool | None = None) -> dt.Timestamp: return dt.Timestamp(timezone="UTC", nullable=nullable) - @classmethod - def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: - return sge.DataType(this=typecode.TEXT) - class DuckDBType(SqlglotType): dialect = "duckdb" @@ -768,19 +782,19 @@ def _from_sqlglot_ARRAY( @classmethod def _from_ibis_JSON(cls, dtype: dt.JSON) -> sge.DataType: - return sge.DataType(this=sge.DataType.Type.VARIANT) + return sge.DataType(this=typecode.VARIANT) @classmethod def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType: - return sge.DataType(this=sge.DataType.Type.ARRAY, nested=True) + return sge.DataType(this=typecode.ARRAY, nested=True) @classmethod def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: - return sge.DataType(this=sge.DataType.Type.OBJECT, nested=True) + return sge.DataType(this=typecode.OBJECT, nested=True) @classmethod def _from_ibis_Struct(cls, dtype: dt.Struct) -> sge.DataType: - return sge.DataType(this=sge.DataType.Type.OBJECT, nested=True) + return sge.DataType(this=typecode.OBJECT, nested=True) class SQLiteType(SqlglotType): @@ -899,9 +913,9 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> NoReturn: @classmethod def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType: if dtype.timezone is None: - return sge.DataType(this=sge.DataType.Type.DATETIME) + return sge.DataType(this=typecode.DATETIME) elif dtype.timezone == "UTC": - return sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ) + return sge.DataType(this=typecode.TIMESTAMPTZ) else: raise com.UnsupportedBackendType( "BigQuery does not support timestamps with timezones other than 'UTC'" @@ -912,9 +926,9 @@ def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType: precision = dtype.precision scale = dtype.scale if (precision, scale) == (76, 38): - return sge.DataType(this=sge.DataType.Type.BIGDECIMAL) + return sge.DataType(this=typecode.BIGDECIMAL) elif (precision, scale) in ((38, 9), (None, None)): - return sge.DataType(this=sge.DataType.Type.DECIMAL) + return sge.DataType(this=typecode.DECIMAL) else: raise com.UnsupportedBackendType( "BigQuery only supports decimal types with precision of 38 and " @@ -930,14 +944,14 @@ def _from_ibis_UInt64(cls, dtype: dt.UInt64) -> NoReturn: @classmethod def _from_ibis_UInt32(cls, dtype: dt.UInt32) -> sge.DataType: - return sge.DataType(this=sge.DataType.Type.BIGINT) + return sge.DataType(this=typecode.BIGINT) _from_ibis_UInt8 = _from_ibis_UInt16 = _from_ibis_UInt32 @classmethod def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial) -> sge.DataType: if (dtype.geotype, dtype.srid) == ("geography", 4326): - return sge.DataType(this=sge.DataType.Type.GEOGRAPHY) + return sge.DataType(this=typecode.GEOGRAPHY) else: raise com.UnsupportedBackendType( "BigQuery geography uses points on WGS84 reference ellipsoid." @@ -981,9 +995,14 @@ class ExasolType(SqlglotType): @classmethod def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: + length = dtype.length return sge.DataType( - this=sge.DataType.Type.VARCHAR, - expressions=[sge.DataTypeParam(this=sge.convert(2_000_000))], + this=typecode.VARCHAR, + expressions=[ + sge.DataTypeParam( + this=sge.convert(length if length is not None else 2_000_000) + ) + ], ) @classmethod @@ -1075,9 +1094,27 @@ def _from_sqlglot_TIMESTAMP(cls): def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: return sge.DataType( this=typecode.VARCHAR, - expressions=[sge.DataTypeParam(this=sge.Var(this="max"))], + expressions=[ + sge.DataTypeParam( + this=( + sge.Var(this="max") + if (length := dtype.length) is None + else sge.convert(length) + ) + ) + ], ) + @classmethod + def _from_sqlglot_VARCHAR( + cls, length: sge.DataTypeParam | None = None + ) -> dt.String: + if length is not None and (bound := length.this.this).isdigit(): + bound = int(bound) + else: + bound = None + return dt.String(length=bound, nullable=cls.default_nullable) + @classmethod def _from_ibis_Array(cls, dtype: dt.String) -> sge.DataType: raise com.UnsupportedBackendType("SQL Server does not support arrays") @@ -1206,6 +1243,15 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: this=typecode.MAP, expressions=[key_type, value_type], nested=True ) + @classmethod + def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType: + if (length := dtype.length) is None: + return super()._from_ibis_String(dtype) + return sge.DataType( + this=typecode.FIXEDSTRING, + expressions=[sge.DataTypeParam(this=sge.convert(length))], + ) + class FlinkType(SqlglotType): dialect = "flink" @@ -1214,7 +1260,7 @@ class FlinkType(SqlglotType): @classmethod def _from_ibis_Binary(cls, dtype: dt.Binary) -> sge.DataType: - return sge.DataType(this=sge.DataType.Type.VARBINARY) + return sge.DataType(this=typecode.VARBINARY) @classmethod def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType: diff --git a/ibis/backends/trino/tests/test_datatypes.py b/ibis/backends/trino/tests/test_datatypes.py index 5e3bc14a5403..b88f0b8c9674 100644 --- a/ibis/backends/trino/tests/test_datatypes.py +++ b/ibis/backends/trino/tests/test_datatypes.py @@ -28,11 +28,11 @@ ("integer", dt.int32), ("uuid", dt.uuid), ("char", dt.string), - ("char(42)", dt.string), + ("char(42)", dt.String(length=42)), ("json", dt.json), ("ipaddress", dt.inet), ("varchar", dt.string), - ("varchar(7)", dt.string), + ("varchar(7)", dt.String(length=7)), ("decimal", dt.Decimal(18, 3)), ("decimal(15, 0)", dt.Decimal(15, 0)), ("decimal(23, 5)", dt.Decimal(23, 5)), @@ -44,7 +44,7 @@ ("array(array(decimal(42, 23)))", dt.Array(dt.Array(dt.Decimal(42, 23)))), ( "array(row(xYz map(varchar(3), double)))", - dt.Array(dt.Struct(dict(xYz=dt.Map(dt.string, dt.float64)))), + dt.Array(dt.Struct(dict(xYz=dt.Map(dt.String(length=3), dt.float64)))), ), ("map(varchar, array(double))", dt.Map(dt.string, dt.Array(dt.float64))), ( diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index f7c4324ceb1a..e2a75c64bafe 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -535,6 +535,12 @@ def nbytes(self) -> int: class String(Variadic, Singleton): """A type representing a string. + ::: {.callout-note} + ## The `length` attribute has **no** effect on the end-user API. + + `length` is supported so that fixed-length strings' metadata is preserved. + ::: + Notes ----- Because of differences in the way different backends handle strings, we @@ -542,9 +548,17 @@ class String(Variadic, Singleton): """ + length: int | None = None + scalar = "StringScalar" column = "StringColumn" + @property + def _pretty_piece(self) -> str: + if (length := self.length) is not None: + return f"({length:d})" + return "" + @public class Binary(Variadic, Singleton): diff --git a/ibis/expr/datatypes/parse.py b/ibis/expr/datatypes/parse.py index dc709d64e31d..7684c6bc407f 100644 --- a/ibis/expr/datatypes/parse.py +++ b/ibis/expr/datatypes/parse.py @@ -113,7 +113,6 @@ def geotype_parser(typ: type[dt.DataType]) -> dt.DataType: "uint16", "uint32", "uint64", - "string", "binary", "timestamp", "time", @@ -141,9 +140,11 @@ def geotype_parser(typ: type[dt.DataType]) -> dt.DataType: ) varchar_or_char = ( - spaceless_string("varchar", "char") - .then(LPAREN.then(RAW_NUMBER).skip(RPAREN).optional()) - .result(dt.string) + spaceless_string("varchar", "string", "char") + .then( + LPAREN.then(parsy.seq(length=spaceless(LENGTH))).skip(RPAREN).optional({}) + ) + .combine_dict(dt.String) ) decimal = spaceless_string("decimal").then( diff --git a/ibis/expr/datatypes/tests/test_parse.py b/ibis/expr/datatypes/tests/test_parse.py index b020f96d4ca3..5b5fe5c3663d 100644 --- a/ibis/expr/datatypes/tests/test_parse.py +++ b/ibis/expr/datatypes/tests/test_parse.py @@ -77,9 +77,12 @@ def test_parse_decimal_failure(case): dt.dtype(case) -@pytest.mark.parametrize("spec", ["varchar", "varchar(10)", "char", "char(10)"]) -def test_parse_char_varchar(spec): - assert dt.dtype(spec) == dt.string +@pytest.mark.parametrize( + ("spec", "length"), + [("varchar", None), ("varchar(10)", 10), ("char", None), ("char(10)", 10)], +) +def test_parse_char_varchar(spec, length): + assert dt.dtype(spec) == dt.String(length=length) @pytest.mark.parametrize( diff --git a/ibis/tests/strategies.py b/ibis/tests/strategies.py index 982289977d3b..ba8f983a283b 100644 --- a/ibis/tests/strategies.py +++ b/ibis/tests/strategies.py @@ -24,35 +24,34 @@ def boolean_dtype(nullable=_nullable): def signed_integer_dtypes(nullable=_nullable): - return st.one_of( - st.builds(dt.Int8, nullable=nullable), - st.builds(dt.Int16, nullable=nullable), - st.builds(dt.Int32, nullable=nullable), - st.builds(dt.Int64, nullable=nullable), + return ( + st.builds(dt.Int8, nullable=nullable) + | st.builds(dt.Int16, nullable=nullable) + | st.builds(dt.Int32, nullable=nullable) + | st.builds(dt.Int64, nullable=nullable) ) def unsigned_integer_dtypes(nullable=_nullable): - return st.one_of( - st.builds(dt.UInt8, nullable=nullable), - st.builds(dt.UInt16, nullable=nullable), - st.builds(dt.UInt32, nullable=nullable), - st.builds(dt.UInt64, nullable=nullable), + return ( + st.builds(dt.UInt8, nullable=nullable) + | st.builds(dt.UInt16, nullable=nullable) + | st.builds(dt.UInt32, nullable=nullable) + | st.builds(dt.UInt64, nullable=nullable) ) def integer_dtypes(nullable=_nullable): - return st.one_of( - signed_integer_dtypes(nullable=nullable), - unsigned_integer_dtypes(nullable=nullable), + return signed_integer_dtypes(nullable=nullable) | unsigned_integer_dtypes( + nullable=nullable ) def floating_dtypes(nullable=_nullable): - return st.one_of( - st.builds(dt.Float16, nullable=nullable), - st.builds(dt.Float32, nullable=nullable), - st.builds(dt.Float64, nullable=nullable), + return ( + st.builds(dt.Float16, nullable=nullable) + | st.builds(dt.Float32, nullable=nullable) + | st.builds(dt.Float64, nullable=nullable) ) @@ -65,15 +64,17 @@ def decimal_dtypes(draw, nullable=_nullable): def numeric_dtypes(nullable=_nullable): - return st.one_of( - integer_dtypes(nullable=nullable), - floating_dtypes(nullable=nullable), - decimal_dtypes(nullable=nullable), + return ( + integer_dtypes(nullable=nullable) + | floating_dtypes(nullable=nullable) + | decimal_dtypes(nullable=nullable) ) def string_dtype(nullable=_nullable): - return st.builds(dt.String, nullable=nullable) + return st.builds( + dt.String, length=st.none() | st.integers(min_value=0), nullable=nullable + ) def binary_dtype(nullable=_nullable): @@ -97,13 +98,13 @@ def uuid_dtype(nullable=_nullable): def string_like_dtypes(nullable=_nullable): - return st.one_of( - string_dtype(nullable=nullable), - binary_dtype(nullable=nullable), - json_dtype(nullable=nullable), - inet_dtype(nullable=nullable), - macaddr_dtype(nullable=nullable), - uuid_dtype(nullable=nullable), + return ( + string_dtype(nullable=nullable) + | binary_dtype(nullable=nullable) + | json_dtype(nullable=nullable) + | inet_dtype(nullable=nullable) + | macaddr_dtype(nullable=nullable) + | uuid_dtype(nullable=nullable) ) @@ -128,22 +129,22 @@ def interval_dtype(interval=_interval, nullable=_nullable): return st.builds(dt.Interval, unit=interval, nullable=nullable) -def temporal_dtypes(timezone=_timezone, interval=_interval, nullable=_nullable): - return st.one_of( - date_dtype(nullable=nullable), - time_dtype(nullable=nullable), - timestamp_dtype(timezone=timezone, nullable=nullable), +def temporal_dtypes(timezone=_timezone, nullable=_nullable): + return ( + date_dtype(nullable=nullable) + | time_dtype(nullable=nullable) + | timestamp_dtype(timezone=timezone, nullable=nullable) ) def primitive_dtypes(nullable=_nullable): - return st.one_of( - null_dtype, - boolean_dtype(nullable=nullable), - integer_dtypes(nullable=nullable), - floating_dtypes(nullable=nullable), - date_dtype(nullable=nullable), - time_dtype(nullable=nullable), + return ( + null_dtype + | boolean_dtype(nullable=nullable) + | integer_dtypes(nullable=nullable) + | floating_dtypes(nullable=nullable) + | date_dtype(nullable=nullable) + | time_dtype(nullable=nullable) ) @@ -179,26 +180,26 @@ def struct_dtypes( def geospatial_dtypes(nullable=_nullable): - geotype = st.one_of(st.just("geography"), st.just("geometry")) - srid = st.one_of(st.just(None), st.integers(min_value=0)) - return st.one_of( - st.builds(dt.Point, geotype=geotype, nullable=nullable, srid=srid), - st.builds(dt.LineString, geotype=geotype, nullable=nullable, srid=srid), - st.builds(dt.Polygon, geotype=geotype, nullable=nullable, srid=srid), - st.builds(dt.MultiPoint, geotype=geotype, nullable=nullable, srid=srid), - st.builds(dt.MultiLineString, geotype=geotype, nullable=nullable, srid=srid), - st.builds(dt.MultiPolygon, geotype=geotype, nullable=nullable, srid=srid), - st.builds(dt.GeoSpatial, geotype=geotype, nullable=nullable, srid=srid), + geotype = st.just("geography") | st.just("geometry") + srid = st.none() | st.integers(min_value=0) + return ( + st.builds(dt.Point, geotype=geotype, nullable=nullable, srid=srid) + | st.builds(dt.LineString, geotype=geotype, nullable=nullable, srid=srid) + | st.builds(dt.Polygon, geotype=geotype, nullable=nullable, srid=srid) + | st.builds(dt.MultiPoint, geotype=geotype, nullable=nullable, srid=srid) + | st.builds(dt.MultiLineString, geotype=geotype, nullable=nullable, srid=srid) + | st.builds(dt.MultiPolygon, geotype=geotype, nullable=nullable, srid=srid) + | st.builds(dt.GeoSpatial, geotype=geotype, nullable=nullable, srid=srid) ) def variadic_dtypes(nullable=_nullable): - return st.one_of( - string_dtype(nullable=nullable), - binary_dtype(nullable=nullable), - json_dtype(nullable=nullable), - array_dtypes(nullable=nullable), - map_dtypes(nullable=nullable), + return ( + string_dtype(nullable=nullable) + | binary_dtype(nullable=nullable) + | json_dtype(nullable=nullable) + | array_dtypes(nullable=nullable) + | map_dtypes(nullable=nullable) )