Skip to content

Commit

Permalink
fix merge conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Aug 1, 2023
2 parents b268fe7 + 301973e commit 0ba3181
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions onnxmltools/convert/sparkml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def buildInitialTypesSimple(dataframe):


def getTensorTypeFromSpark(sparktype):
if sparktype == "StringType" or sparktype == "StringType()":
return StringTensorType([1, 1])
elif (
if sparktype in ("StringType", "StringType()"):
return StringTensorType([None, 1])
if (
sparktype == "DecimalType"
or sparktype == "DecimalType()"
or sparktype == "DoubleType"
Expand All @@ -34,17 +34,16 @@ def getTensorTypeFromSpark(sparktype):
or sparktype == "BooleanType"
or sparktype == "BooleanType()"
):
return FloatTensorType([1, 1])
else:
raise TypeError("Cannot map this type to Onnx types: " + sparktype)
return FloatTensorType([None, 1])
raise TypeError("Cannot map this type to Onnx types: " + sparktype)


def buildInputDictSimple(dataframe):
import numpy

result = {}
for field in dataframe.schema.fields:
if str(field.dataType) == 'StringType' or str(field.dataType) == 'StringType()':
if str(field.dataType) in ("StringType", "StringType()"):
result[field.name] = dataframe.select(field.name).toPandas().values
else:
result[field.name] = (
Expand Down

0 comments on commit 0ba3181

Please sign in to comment.