Skip to content

Commit

Permalink
chore(mssql): fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Oct 20, 2024
1 parent ec62bf3 commit 0150df5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
11 changes: 10 additions & 1 deletion ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def get_schema(
C.numeric_precision,
C.numeric_scale,
C.datetime_precision,
C.character_maximum_length,
)
.from_(
sg.table(
Expand Down Expand Up @@ -288,12 +289,20 @@ def get_schema(
numeric_precision,
numeric_scale,
datetime_precision,
character_maximum_length,
) in meta:
newtyp = self.compiler.type_mapper.from_string(
typ, nullable=is_nullable == "YES"
)

if typ == "float":
if (
typ.lower() != "hierarchyid"
and character_maximum_length is not None
and character_maximum_length != -1
and newtyp.is_string()
):
newtyp = newtyp.copy(length=character_maximum_length)
elif typ == "float":
newcls = dt.Float64 if numeric_precision == 53 else dt.Float32
newtyp = newcls(nullable=newtyp.nullable)
elif newtyp.is_decimal():
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/mssql/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@
("DATETIME", dt.Timestamp(scale=3)),
# Character strings
("CHAR", dt.String(length=1)),
("TEXT", dt.string),
("VARCHAR", dt.String(length=1)),
# Unicode character strings
("NCHAR", dt.String(length=1)),
("NTEXT", dt.string),
("NVARCHAR", dt.String(length=1)),
# Binary strings
("BINARY", dt.binary),
Expand Down Expand Up @@ -258,7 +256,7 @@ def test_dot_sql_with_unnamed_columns(con):

assert schema.types == (
dt.Timestamp(timezone="UTC", scale=7),
dt.String(nullable=False),
dt.String(nullable=False, length=2),
dt.Int32(nullable=False),
)

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,13 +1107,13 @@ def _from_ibis_String(cls, dtype: dt.String) -> sge.DataType:

@classmethod
def _from_sqlglot_VARCHAR(
cls, length: sge.DataTypeParam | None = None
cls, length: sge.DataTypeParam | None = None, nullable: bool | None = None
) -> dt.String:
if length is not None and (bound := length.this.this).isdigit():
bound = int(bound)
else:
bound = None
return dt.String(length=bound, nullable=cls.default_nullable)
return dt.String(length=bound, nullable=nullable)

@classmethod
def _from_ibis_Array(cls, dtype: dt.String) -> sge.DataType:
Expand Down

0 comments on commit 0150df5

Please sign in to comment.