[SPARK-13951][ML][PYTHON] Nested Pipeline persistence
Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they can still use the JavaMLWriter and JavaMLReader implementations.

Also:
* Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader.
* Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java
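
The interface/implementation split listed above can be sketched in plain Python. This is a minimal illustration, not the PySpark source: the class names MLWriter, MLWritable, JavaMLWriter and JavaMLWritable come from the commit message, but the method bodies, the `_to_java` hook, and the `FakeJavaObject`, `ToyEstimator` and `ToyPipeline` classes are assumptions made so the example runs without a JVM.

```python
class MLWriter(object):
    """Abstract writer interface; knows nothing about the JVM."""

    def save(self, path):
        raise NotImplementedError("save() not implemented for %s" % type(self))


class MLWritable(object):
    """Abstract mixin for objects that can provide an MLWriter."""

    def write(self):
        raise NotImplementedError("write() not implemented for %s" % type(self))

    def save(self, path):
        # Convenience shortcut: delegate to whatever writer the subclass returns.
        self.write().save(path)


class FakeJavaObject(object):
    """Hypothetical stand-in for a JVM-side object; records save calls."""

    def __init__(self):
        self.saved_paths = []

    def save(self, path):
        self.saved_paths.append(path)


class JavaMLWriter(MLWriter):
    """Writer that delegates to the Java-side object of an instance."""

    def __init__(self, instance):
        # Assumption: any instance passed in exposes a _to_java() hook.
        self._java_obj = instance._to_java()

    def save(self, path):
        self._java_obj.save(path)


class JavaMLWritable(MLWritable):
    """Mixin for Java-wrapper classes: write() returns a JavaMLWriter."""

    def write(self):
        return JavaMLWriter(self)


class ToyEstimator(JavaMLWritable):
    """Hypothetical Java-backed stage."""

    def __init__(self):
        self._java_obj = FakeJavaObject()

    def _to_java(self):
        return self._java_obj


class ToyPipeline(MLWritable):
    """Hypothetical pipeline: not a Java wrapper, yet it reuses JavaMLWriter."""

    def __init__(self, stages):
        self.stages = stages
        self._java_obj = FakeJavaObject()

    def _to_java(self):
        return self._java_obj

    def write(self):
        return JavaMLWriter(self)
```

In this sketch, `ToyPipeline` implements only the abstract `MLWritable` interface but still saves through `JavaMLWriter`, which is the arrangement the commit message describes for Pipeline and PipelineModel.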

Added new unit test for nested Pipelines.  Abstracted validity check into a helper method for the 2 unit tests.
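
A validity check for nested pipelines has to recurse into stages that are themselves pipelines. The helper added by this commit is not shown in this excerpt, so the following is only a sketch of the idea: the `Stage` class and the `pipelines_match` function are hypothetical names, and comparing `uid` plus stage structure is an assumed notion of "valid".

```python
class Stage(object):
    """Hypothetical stage: a leaf when stages is None, otherwise a nested pipeline."""

    def __init__(self, uid, stages=None):
        self.uid = uid
        self.stages = stages


def pipelines_match(a, b):
    """Return True if two (possibly nested) stages have the same type, uid and structure."""
    if type(a) is not type(b) or a.uid != b.uid:
        return False
    # Leaf stages: both sides must be leaves.
    if a.stages is None or b.stages is None:
        return a.stages is None and b.stages is None
    if len(a.stages) != len(b.stages):
        return False
    # Nested pipelines: compare child stages pairwise, recursively.
    return all(pipelines_match(x, y) for x, y in zip(a.stages, b.stages))
```

A single recursive helper like this can serve both a flat-pipeline test and a nested-pipeline test, since a flat pipeline is simply the case where every child is a leaf.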

Author: Joseph K. Bradley

Closes apache#11866 from jkbradley/nested-pipeline-io.
Closes apache#11835
jkbradley authored and mengxr committed Mar 22, 2016
1 parent 297c202 commit 7e3423b
Showing 9 changed files with 300 additions and 175 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/ml/classification.py
@@ -38,7 +38,7 @@
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
-HasWeightCol, MLWritable, MLReadable):
+HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -198,7 +198,7 @@ def _checkThresholdConsistency(self):
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))


-class LogisticRegressionModel(JavaModel, MLWritable, MLReadable):
+class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.
@@ -601,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels):

@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
-HasRawPredictionCol, MLWritable, MLReadable):
+HasRawPredictionCol, JavaMLWritable, JavaMLReadable):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. Multinomial NB
@@ -720,7 +720,7 @@ def getModelType(self):
return self.getOrDefault(self.modelType)


-class NaiveBayesModel(JavaModel, MLWritable, MLReadable):
+class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.
4 changes: 2 additions & 2 deletions python/pyspark/ml/clustering.py
@@ -25,7 +25,7 @@
'KMeans', 'KMeansModel']


-class KMeansModel(JavaModel, MLWritable, MLReadable):
+class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
@@ -48,7 +48,7 @@ def computeCost(self, dataset):

@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
-MLWritable, MLReadable):
+JavaMLWritable, JavaMLReadable):
"""
K-means clustering with support for multiple parallel runs and a k-means++ like initialization
mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,