# SPDX-License-Identifier: Apache-2.0

"""
TfIdf, SVC and sparse matrices
==============================

.. index:: sparse

This example is for anyone who wants to convert a pipeline
chaining a TfidfVectorizer and an SVC when the features are sparse.

The pipeline
++++++++++++
"""
import os
import pickle
import numpy as np
import scipy
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from onnxruntime import InferenceSession
from skl2onnx import to_onnx, update_registered_converter
from skl2onnx.common.data_types import StringTensorType
from skl2onnx.common._topology import Scope, Operator
from skl2onnx.common._container import ModelComponentContainer
from skl2onnx.common.data_types import (
    DoubleTensorType,
    FloatTensorType,
    guess_proto_type,
)


# A tiny corpus: one document per row, shaped (4, 1) because the ONNX graph
# is later declared with a ``[None, 1]`` string input.
X_train = np.array(
    [
        "This is the first document",
        "This document is the second document.",
        "And this is the third one",
        "Is this the first document?",
    ]
).reshape((4, 1))
y_train = np.array([0, 1, 0, 1])

# Reference pipeline: TF-IDF features fed straight into an SVC.
model_pipeline = Pipeline(
    [
        (
            "vectorizer",
            TfidfVectorizer(
                lowercase=True,
                use_idf=True,
                ngram_range=(1, 3),
                max_features=30000,
            ),
        ),
        (
            "classifier",
            SVC(
                class_weight="balanced",
                kernel="rbf",
                gamma="scale",
                probability=True,
            ),
        ),
    ]
)
# scikit-learn text vectorizers expect a 1-D sequence of documents.
model_pipeline.fit(X_train.ravel(), y_train)

# Is the output of the vectorizer sparse?
# ``scipy.sparse.issparse`` is the public check; it avoids reaching into the
# private ``scipy.sparse._csr`` module.
out0 = model_pipeline.steps[0][-1].transform(X_train.ravel())
is_sparse = scipy.sparse.issparse(out0)
print(f"Output type for TfIdfVectorizier is {'sparse' if is_sparse else 'dense'}.")

# Are the support vectors stored by the SVC sparse?
# The result must be bound to ``is_sparse`` (it used to be ``is_parse``):
# the statement printing the SVC diagnostic right after this one reads
# ``is_sparse`` and would otherwise repeat the vectorizer's answer.
svc_coef = model_pipeline.steps[1][-1].support_vectors_
is_sparse = scipy.sparse.issparse(svc_coef)
# Evaluate the check on ``svc_coef`` itself so the message always describes
# the support vectors, whatever an earlier flag was computed from.
svc_is_sparse = scipy.sparse.issparse(svc_coef)
print(f"Supports for SVC is {'sparse' if svc_is_sparse else 'dense'}.")
sparsity = 1 - (svc_coef != 0).sum() / np.prod(svc_coef.shape)
print(f"sparsity={sparsity} and shape={svc_coef.shape}")


######################################
# Size Comparison
# +++++++++++++++

pkl_name = "model.pkl"
with open(pkl_name, "wb") as f:
    pickle.dump(model_pipeline, f)

onx_name = "model.onnx"
onx = to_onnx(
    model_pipeline,
    initial_types=[("input", StringTensorType([None, 1]))],
    options={SVC: {"zipmap": False}},
    target_opset=18,
)
with open(onx_name, "wb") as f:
    f.write(onx.SerializeToString())

print(f"pickle size={os.stat(pkl_name).st_size}")
print(f"onnx size={os.stat(onx_name).st_size}")

#######################################
# On such a small model, it does not show that SVC is using a sparse matrix
# whereas ONNX SVMClassifier is using a dense one. If the matrix is 90% sparse,
# this part becomes 10 times bigger once converted into ONNX.
#
# Tweak
# +++++
#
# The idea is to take the matrix of coefficients out of SVC by
# reducing the number of dimensions.
# We could apply a PCA but it does not support sparse features.
# TruncatedSVD does, but the matrix it produces to reduce the dimension
# is dense. SparsePCA does not support sparse features either.
# Let's try something custom: a TruncatedSVD whose smallest coefficients
# are then set to zero.


class SparseTruncatedSVD(TruncatedSVD):
    """
    TruncatedSVD whose projection matrix is made sparse after fitting.

    The decomposition is fitted exactly as :class:`TruncatedSVD` does, then a
    fraction *sparsity* of the entries of ``components_`` (those with the
    smallest absolute value) is set to zero and ``components_`` is stored as
    a :class:`scipy.sparse.coo_matrix`.

    :param sparsity: fraction, in ``[0, 1]``, of the coefficients of
        ``components_`` to set to zero
    :raises ValueError: at fit time, if *sparsity* is outside ``[0, 1]``

    Every other parameter is forwarded unchanged to ``TruncatedSVD.__init__``.
    """

    def __init__(
        self,
        n_components=2,
        *,
        algorithm="randomized",
        n_iter=5,
        n_oversamples=10,
        power_iteration_normalizer="auto",
        random_state=None,
        tol=0.0,
        sparsity=0.9,
    ):
        TruncatedSVD.__init__(
            self,
            n_components,
            algorithm=algorithm,
            n_iter=n_iter,
            n_oversamples=n_oversamples,
            power_iteration_normalizer=power_iteration_normalizer,
            random_state=random_state,
            tol=tol,
        )
        self.sparsity = sparsity

    def fit_transform(self, X, y=None):
        """
        Fits the decomposition, sparsifies ``components_`` and returns
        ``self.transform(X)`` computed with the sparsified matrix.
        """
        # A negative value would become a negative slice bound below and
        # silently zero almost every coefficient, so reject it explicitly.
        if not 0 <= self.sparsity <= 1:
            raise ValueError(
                f"sparsity must be in [0, 1] but is {self.sparsity!r}."
            )

        # Only the fitted attributes matter here; the dense projection
        # returned by the parent is discarded and recomputed at the end.
        TruncatedSVD.fit_transform(self, X, y)

        # The matrix. We could choose the coefficients to set to zero
        # by minimizing `(X @ M.T - X @ M0.T) ** 2`
        # where M is the original matrix and M0 the new one.
        # In a first approach, we just sort the coefficients by absolute value.
        dense = np.asarray(self.components_)
        flat = dense.ravel().copy()
        n_zeroed = int(self.sparsity * flat.size)
        if n_zeroed:
            # A stable sort breaks ties by position, which is the order a
            # sort of (abs(value), index) pairs would give.
            order = np.argsort(np.abs(flat), kind="stable")
            flat[order[:n_zeroed]] = 0
        self.components_ = scipy.sparse.coo_matrix(flat.reshape(dense.shape))
        return self.transform(X)


sparse_pipeline = Pipeline(
    [
        (
            "vectorizer",
            TfidfVectorizer(
                lowercase=True,
                use_idf=True,
                ngram_range=(1, 3),
                max_features=30000,
            ),
        ),
        ("sparse", SparseTruncatedSVD(10, sparsity=0.6)),
        (
            "classifier",
            SVC(
                class_weight="balanced",
                kernel="rbf",
                gamma="scale",
                probability=True,
            ),
        ),
    ]
)
sparse_pipeline.fit(X_train.ravel(), y_train)

# Compare both pipelines on the training data.
expected = model_pipeline.predict(X_train.ravel())
got = sparse_pipeline.predict(X_train.ravel())
# Count the labels that differ (the previous expression counted equal ones).
print(f"Number of different predicted labels: {(expected != got).sum()}")

expected = model_pipeline.predict_proba(X_train.ravel())
got = sparse_pipeline.predict_proba(X_train.ravel())
diff = np.abs(expected - got)
# The message announces an average, so report the mean over all entries
# (the previous expression printed the maximum of every row).
print(f"Average absolute difference for the probabilities: {diff.mean()}")

######################################
# Conversion to ONNX
# ++++++++++++++++++
#
# The new transformer cannot be converted because sklearn-onnx does not have any
# registered converter for it. We must implement it.
# We use the converter for TruncatedSVD as a base and a sparse matrix multiplication
# implemented in onnxruntime (see `OperatorKernels.md
# <https://github.com/microsoft/onnxruntime/blob/main/docs/OperatorKernels.md>`_).
def calculate_sparse_sklearn_truncated_svd_output_shapes(operator):
    """
    Shape calculator: sets the type of the single output to ``[N, K]``.

    ``N`` is the first dimension of the first input and ``K`` is
    ``n_components`` of the fitted estimator. The output is a double tensor
    only when the input type is exactly ``DoubleTensorType``; it is a float
    tensor in every other case.
    """
    source = operator.inputs[0]
    if source.type.__class__ == DoubleTensorType:
        tensor_cls = DoubleTensorType
    else:
        tensor_cls = FloatTensorType
    n_rows = source.get_first_dimension()
    n_components = operator.raw_operator.n_components
    operator.outputs[0].type = tensor_cls([n_rows, n_components])


def convert_sparse_truncated_svd(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Converter: emits ``Transpose -> SparseToDenseMatMul -> Transpose``.

    ``components_`` is registered as an initializer as it is, without being
    transposed. The input is transposed instead, multiplied on the left by
    that matrix, and the product is transposed back into the output.
    """
    fitted = operator.raw_operator
    source = operator.inputs[0]

    # Double inputs keep their own element type, anything else becomes float.
    if isinstance(source.type, DoubleTensorType):
        element_type = source.type
    else:
        element_type = FloatTensorType()
    proto_dtype = guess_proto_type(element_type)

    # The projection matrix, shaped [K, C] where C/K is the
    # input/transformed feature dimension.
    matrix = fitted.components_
    matrix_name = scope.get_unique_variable_name("transform_matrix")
    container.add_initializer(matrix_name, proto_dtype, matrix.shape, matrix)

    # The matrix is the first operand of the product, hence the input is
    # transposed before it and the result transposed after it.
    flipped_input = scope.get_unique_variable_name("transposed_inputs")
    container.add_node("Transpose", source.full_name, flipped_input, perm=[1, 0])

    flipped_output = scope.get_unique_variable_name("transposed_outputs")
    container.add_node(
        "SparseToDenseMatMul",
        [matrix_name, flipped_input],
        flipped_output,
        op_domain="com.microsoft",
        op_version=1,
    )
    container.add_node(
        "Transpose", flipped_output, operator.outputs[0].full_name, perm=[1, 0]
    )


# Tell sklearn-onnx how to handle the custom transformer.
update_registered_converter(
    SparseTruncatedSVD,
    "SparseTruncatedSVD",
    calculate_sparse_sklearn_truncated_svd_output_shapes,
    convert_sparse_truncated_svd,
)

sparse_onx_name = "model_sparse.onnx"
sparse_onx = to_onnx(
    sparse_pipeline,
    initial_types=[("input", StringTensorType([None, 1]))],
    options={SVC: {"zipmap": False}},
    target_opset=18,
)
print(sparse_onx)
with open(sparse_onx_name, "wb") as stream:
    stream.write(sparse_onx.SerializeToString())

# Sizes on disk of the three serialized models.
print(f"pickle size={os.stat(pkl_name).st_size}")
print(f"onnx size={os.stat(onx_name).st_size}")
print(f"sparse onnx size={os.stat(sparse_onx_name).st_size}")

############################################
# Let's check it is working with onnxruntime.

sess = InferenceSession(sparse_onx_name, providers=["CPUExecutionProvider"])
got = sess.run(None, {"input": X_train})
print(got)


######################################
# Conclusion
# ++++++++++
#
# This option decreases the size of the onnx model by using one
# sparse matrix in the converted pipeline. It may bring an accuracy loss.
) values_tensor = make_tensor( - name + "_v", + name, data_type=onnx_type, dims=(len(content.data),), vals=content.data, @@ -547,14 +547,7 @@ def add_initializer(self, name, onnx_type, shape, content): cached_name = self.initializers_strings.get(content, None) if cached_name is None: self.initializers_strings[content] = name - self.add_node( - "Constant", - [], - [name], - sparse_value=sparse_tensor, - op_version=self.target_opset, - name=name + "_op", - ) + self.initializers.append(sparse_tensor) return sparse_tensor self.add_node( @@ -872,8 +865,10 @@ def ensure_topological_order(self): name = inp.name order[name] = 0 for inp in self.initializers: - name = inp.name + name = inp.name if hasattr(inp, "name") else inp.values.name order[name] = 0 + print("#", type(inp), name) + print("---", order) n_iter = 0 missing_ops = [] @@ -891,6 +886,7 @@ def ensure_topological_order(self): else: maxi = None missing_names.add(name) + print("***", name, order, node.input) break if maxi is None: missing_ops.append(node) diff --git a/skl2onnx/common/_onnx_optimisation_common.py b/skl2onnx/common/_onnx_optimisation_common.py index e45753f9c..2f665afd3 100644 --- a/skl2onnx/common/_onnx_optimisation_common.py +++ b/skl2onnx/common/_onnx_optimisation_common.py @@ -181,7 +181,14 @@ def _rename_graph_output(graph, old_name, new_name): outputs.append(value_info) nodes = list(graph.node) nodes.append(_make_node("Identity", [old_name], [new_name])) - new_graph = make_graph(nodes, graph.name, graph.input, outputs, graph.initializer) + new_graph = make_graph( + nodes, + graph.name, + graph.input, + outputs, + graph.initializer, + sparse_initializer=graph.sparse_initializer, + ) new_graph.value_info.extend(graph.value_info) return new_graph @@ -207,7 +214,14 @@ def _rename_graph_input(graph, old_name, new_name): inputs.append(value_info) nodes = list(graph.node) nodes.append(_make_node("Identity", [new_name], [old_name])) - new_graph = make_graph(nodes, graph.name, inputs, graph.output, 
graph.initializer) + new_graph = make_graph( + nodes, + graph.name, + inputs, + graph.output, + graph.initializer, + sparse_initializer=graph.sparse_initializer, + ) new_graph.value_info.extend(graph.value_info) return new_graph diff --git a/skl2onnx/common/_topology.py b/skl2onnx/common/_topology.py index 1be1db166..b8c2e3427 100644 --- a/skl2onnx/common/_topology.py +++ b/skl2onnx/common/_topology.py @@ -7,7 +7,7 @@ from logging import getLogger from collections import OrderedDict import numpy as np -from onnx import onnx_pb as onnx_proto +from onnx import onnx_pb as onnx_proto, TensorProto, SparseTensorProto from onnx.helper import make_graph, make_model, make_tensor_value_info from onnxconverter_common.data_types import ( # noqa DataType, @@ -1358,9 +1358,10 @@ def _check_variable_out_(variable, operator): fed_variables[variable.onnx_name] = variable fed_variables.update( { - i.name: i + (i.name if hasattr(i, "name") else i.values.name): i for i in container.initializers - if i.name not in fed_variables + if (i.name if hasattr(i, "name") else i.values.name) + not in fed_variables } ) self._propagate_status( @@ -1576,7 +1577,10 @@ def convert_topology( model_name, container.inputs + extra_inputs, container.outputs, - container.initializers, + [i for i in container.initializers if isinstance(i, TensorProto)], + sparse_initializer=[ + i for i in container.initializers if isinstance(i, SparseTensorProto) + ], ) else: # In ONNX opset 9 and above, initializers are included as @@ -1587,7 +1591,10 @@ def convert_topology( model_name, container.inputs, container.outputs, - container.initializers, + [i for i in container.initializers if isinstance(i, TensorProto)], + sparse_initializer=[ + i for i in container.initializers if isinstance(i, SparseTensorProto) + ], ) # Add extra information related to the graph diff --git a/skl2onnx/common/onnx_optimisation_identity.py b/skl2onnx/common/onnx_optimisation_identity.py index f90cdda03..ba9870d3b 100644 --- 
def from_coo_matrix(content, name):
    """
    Converts a 2-D COO sparse matrix into an ONNX ``SparseTensorProto``.

    :param content: matrix exposing ``row``, ``col``, ``data``, ``shape``
        and ``dtype`` the way :class:`scipy.sparse.coo_matrix` does
    :param name: prefix of the tensor names, the values are stored in
        tensor ``name + "_v"`` and the indices in tensor ``name + "_i"``
    :return: the value returned by ``onnx.helper.make_sparse_tensor``

    The position of every stored value is encoded as a single row-major
    linearized index ``row * n_columns + col``.
    """
    onnx_type = helper.np_dtype_to_tensor_dtype(content.dtype)

    # Do the arithmetic in int64: ``row`` and ``col`` may be int32 arrays and
    # ``row * n_columns`` would wrap around for large matrices.
    n_columns = int(content.shape[1])
    linear = np.asarray(content.row, dtype=np.int64) * n_columns + np.asarray(
        content.col, dtype=np.int64
    )
    values = np.asarray(content.data)

    # The ONNX specification asks for indices in ascending order while a COO
    # matrix gives no guarantee on the order of its entries. Reorder only
    # when needed so that already sorted inputs are passed through as is.
    # NOTE(review): duplicated coordinates are kept as they are; the
    # specification forbids them, confirm callers never produce any.
    if linear.size > 1 and np.any(linear[1:] < linear[:-1]):
        order = np.argsort(linear, kind="stable")
        linear = linear[order]
        values = values[order]

    values_tensor = helper.make_tensor(
        name + "_v",
        data_type=onnx_type,
        dims=(len(values),),
        vals=values,
    )
    indices_tensor = helper.make_tensor(
        name=name + "_i",
        data_type=TensorProto.INT64,
        dims=(len(linear),),
        # plain Python integers, the safest input for a protobuf int64 field
        vals=linear.tolist(),
    )
    dense_shape = list(content.shape)
    return helper.make_sparse_tensor(values_tensor, indices_tensor, dense_shape)
+ cvt = from_array if isinstance(value, np.ndarray) else from_coo_matrix try: - onx = from_array(value) + onx = cvt(value, "") except (AttributeError, TypeError) as e: # sparse matrix for example if hasattr(value, "tocoo"): diff --git a/skl2onnx/helpers/onnx_helper.py b/skl2onnx/helpers/onnx_helper.py index 93f13461f..cadba48f5 100644 --- a/skl2onnx/helpers/onnx_helper.py +++ b/skl2onnx/helpers/onnx_helper.py @@ -10,6 +10,7 @@ make_tensor, make_node, make_tensor_value_info, + make_sparse_tensor_value_info, make_graph, make_model, ) @@ -156,6 +157,7 @@ def select_model_inputs_outputs(model, outputs=None, inputs=None): model.graph.input, var_out, model.graph.initializer, + sparse_initializer=model.graph.sparse_initializer, ) onnx_model = make_model(graph) onnx_model.ir_version = model.ir_version @@ -237,6 +239,13 @@ def infer_outputs( input.name, input.data_type.real, list(d for d in input.dims) ) onnx_inputs.append(v) + elif isinstance(input, onnx.SparseTensorProto): + v = make_sparse_tensor_value_info( + input.values.name, + input.values.data_type.real, + list(d for d in input.dims), + ) + onnx_inputs.append(v) elif isinstance(input, onnx.AttributeProto): value_info = ValueInfoProto() value_info.name = input.name @@ -312,6 +321,7 @@ def change_onnx_domain(model, ops): model.graph.input, model.graph.output, model.graph.initializer, + sparse_initializer=model.graph.sparse_initializer, ) onnx_model = make_model(graph) onnx_model.ir_version = model.ir_version @@ -426,7 +436,12 @@ def add_output_initializer(model_onnx, name, value, suffix="_init"): nodes.append(make_node("Identity", [name_init], [name_output])) graph = make_graph( - nodes, model_onnx.graph.name, model_onnx.graph.input, outputs, inits + nodes, + model_onnx.graph.name, + model_onnx.graph.input, + outputs, + inits, + sparse_initializer=model_onnx.graph.sparse_initializer, ) onnx_model = make_model(graph) diff --git a/skl2onnx/helpers/onnx_rare_helper.py b/skl2onnx/helpers/onnx_rare_helper.py index 
# SPDX-License-Identifier: Apache-2.0

# coding: utf-8
import unittest
import packaging.version as pv
import numpy
from numpy.testing import assert_almost_equal
import scipy
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import SparsePCA
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.feature_extraction.text import TfidfVectorizer
from onnxruntime import InferenceSession, __version__ as ort_version
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import StringTensorType
from test_utils import TARGET_OPSET


class DensityTransformer(BaseEstimator, TransformerMixin):
    """Stateless transformer turning a sparse matrix into a dense array."""

    def fit(self, X, y=None):
        # Nothing to learn.
        return self

    def transform(self, X):
        return numpy.asarray(X.todense())


class TestSklearnTfidfTransformerConverterSparseOption(unittest.TestCase):
    def common_test_model_tfidf_vectorizer_pipeline_cls(self, verbose=False):
        """
        Converts a pipeline TfidfVectorizer + DensityTransformer + SparsePCA
        + SVC and compares the labels and probabilities computed by
        onnxruntime with the ones computed by scikit-learn.

        :param verbose: prints intermediate results when true
        """
        if pv.Version(ort_version) >= pv.Version("1.4.0"):
            # regression with stopwords in onnxruntime 1.4+
            stopwords = ["theh"]
        else:
            stopwords = ["the", "and", "is"]

        X_train = numpy.array(
            [
                "This is the first document",
                "This document is the second document.",
                "And this is the third one",
                "Is this the first document?",
            ]
        ).reshape((4, 1))
        y_train = numpy.array([0, 1, 0, 1])

        model_pipeline = Pipeline(
            [
                (
                    "vectorizer",
                    TfidfVectorizer(
                        stop_words=stopwords,
                        lowercase=True,
                        use_idf=True,
                        ngram_range=(1, 3),
                        max_features=30000,
                    ),
                ),
                ("density", DensityTransformer()),
                ("feature_selector", SparsePCA(10, alpha=10)),
                (
                    "classifier",
                    SVC(
                        class_weight="balanced",
                        kernel="rbf",
                        gamma="scale",
                        probability=True,
                    ),
                ),
            ]
        )
        model_pipeline.fit(X_train.ravel(), y_train)

        # The vectorizer must produce a sparse CSR matrix. A unittest
        # assertion is used (a bare ``assert`` is stripped by ``python -O``)
        # together with the public scipy class.
        step0 = model_pipeline.steps[0][-1].transform(X_train.ravel())
        self.assertIsInstance(step0, scipy.sparse.csr_matrix)

        # The pipeline above always has four steps and the third one is
        # SparsePCA, so only its coefficients are inspected, and only when
        # diagnostics are requested.
        if verbose:
            pca_coef = model_pipeline.steps[2][-1].components_
            print(type(pca_coef))
            sparsity = (pca_coef == 0).sum() / numpy.prod(pca_coef.shape)
            print(f"sparsity={sparsity}|{pca_coef.shape}")

        initial_type = [("input", StringTensorType([None, 1]))]
        model_onnx = convert_sklearn(
            model_pipeline,
            "cv",
            initial_types=initial_type,
            options={SVC: {"zipmap": False}},
            target_opset=TARGET_OPSET,
        )

        exp = [
            model_pipeline.predict(X_train.ravel()),
            model_pipeline.predict_proba(X_train.ravel()),
        ]

        sess = InferenceSession(
            model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        got = sess.run(None, {"input": X_train})
        if verbose:
            voc = model_pipeline.steps[0][-1].vocabulary_
            voc = list(sorted([(v, k) for k, v in voc.items()]))
            for kv in voc:
                print(kv)
        for a, b in zip(exp, got):
            if verbose:
                print(stopwords)
                print(a)
                print(b)
            assert_almost_equal(a, b)

    def test_sparse(self):
        # Verbose only when the file is run as a script.
        self.common_test_model_tfidf_vectorizer_pipeline_cls(__name__ == "__main__")


if __name__ == "__main__":
    unittest.main()