From 773b3b170e894ea7e2f09347d80190359fc84adf Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 30 Aug 2017 12:32:06 -0700 Subject: [PATCH 01/98] initial commit with some ideas on how to do multilabel. --- .../aloha/models/MultilabelModel.scala | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala new file mode 100644 index 00000000..ccab20ac --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala @@ -0,0 +1,108 @@ +package com.eharmony.aloha.models + +import com.eharmony.aloha.audit.Auditor +import com.eharmony.aloha.factory._ +import com.eharmony.aloha.id.ModelIdentity +import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import com.eharmony.aloha.semantics.Semantics +import spray.json.{JsonFormat, JsonReader} + +sealed trait LabelExtraction[-A, K] +final case class KnownLabelExtraction[K](labels: Seq[K]) extends LabelExtraction[Any, K] +final case class PerExampleLabelExtraction[A, K](extractor: A => Seq[K]) extends LabelExtraction[A, K] + +/** + * Created by ryan.deak on 8/29/17. + * + * @param modelId + * @param labelExtraction + * @param labelDependentFeatures + * @param featureExtraction + * @param predictorProducer + * @param auditor + * @tparam U upper bound on model output type `B` + * @tparam F type of features produced + * @tparam K type of label or class + * @tparam A input type of the model + * @tparam B output type of the model. + */ +case class MultilabelModel[U, F, K, -A, +B <: U]( + modelId: ModelIdentity, + labelExtraction: LabelExtraction[A, K], + labelDependentFeatures: Seq[K => F], + featureExtraction: Seq[A => F], + predictorProducer: () => (Seq[F], Seq[K]) => Map[K, Double], + auditor: Auditor[U, Map[K, Double], B] +) +extends SubmodelBase[U, Map[K, Double], A, B] { + + @transient private[this] lazy val predictor = predictorProducer() + + { + // Force predictor eagerly + predictor + } + + private val globalLdf = labelExtraction match { + case KnownLabelExtraction(labels) => Option(applyLdf(labels)) + case _ => None + } + + private[this] def applyLdf(labels: Seq[K]): Seq[Seq[F]] = + labels.map(label => labelDependentFeatures.map(f => f(label))) + + override def subvalue(a: A): Subvalue[B, Map[K, Double]] = { + val labels = labelExtraction match { + case KnownLabelExtraction(ls) => ls + case PerExampleLabelExtraction(extractor) => extractor(a) + } + + // Get the label-dependent features. + val ldf = globalLdf getOrElse applyLdf(labels) + + val features = featureExtraction.map(f => f(a)) + val natural = predictor(features, labels) + val aud = auditor.success(modelId, natural) + Subvalue(aud, Option(natural)) + } +} + +object MultilabelModel extends ParserProviderCompanion { + + object Parser extends ModelParsingPlugin { + override val modelType: String = "multilabel" + override def modelJsonReader[U, N, A, B <: U]( + factory: SubmodelFactory[U, A], + semantics: Semantics[A], + auditor: Auditor[U, N, B] + )(implicit r: RefInfo[N], jf: JsonFormat[N]): Option[JsonReader[MultilabelModel[U, _, _, A, B]]] = { + + if (!RefInfoOps.isSubType[N, Map[_, Double]]) + None + else { + // Because N is a subtype of map, it "should" have two type parameters. + // This is obviously not true in all cases, like with LongMap + // http://scala-lang.org/files/archive/api/2.11.8/#scala.collection.immutable.LongMap + // TODO: Make this more robust. + val refInfoK = RefInfoOps.typeParams(r).head + + // To allow custom class (key) types, we'll need to create a custom ModelFactoryImpl instance + // with a specialized RefInfoToJsonFormat. + // + // type: Option[JsonFormat[_]] + val jsonFormatK = factory.jsonFormat(refInfoK) + + // TODO: parse the label extraction + + // TODO: parse the feature extraction + + // TODO: parse the native submodel from the wrapped ML library. This involves plugins + + + ??? + } + } + } + + override def parser: ModelParser = Parser +} From 66817f7bcdf6d9b6305b3d8c373dd8cb7e8aff0d Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 30 Aug 2017 13:54:45 -0700 Subject: [PATCH 02/98] Additional hacking on multilabel model. --- .../aloha/models/MultilabelModel.scala | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala index ccab20ac..7ec2f61d 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala @@ -5,11 +5,15 @@ import com.eharmony.aloha.factory._ import com.eharmony.aloha.id.ModelIdentity import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import com.eharmony.aloha.semantics.Semantics +import com.eharmony.aloha.semantics.func.GenAggFunc import spray.json.{JsonFormat, JsonReader} -sealed trait LabelExtraction[-A, K] +import scala.collection.{immutable, immutable => sci} + + +sealed trait LabelExtraction[-A, +K] final case class KnownLabelExtraction[K](labels: Seq[K]) extends LabelExtraction[Any, K] -final case class PerExampleLabelExtraction[A, K](extractor: A => Seq[K]) extends LabelExtraction[A, K] +final case class PerExampleLabelExtraction[A, K](extractor: GenAggFunc[A, sci.IndexedSeq[K]]) extends LabelExtraction[A, K] /** * Created by ryan.deak on 8/29/17. @@ -26,12 +30,19 @@ final case class PerExampleLabelExtraction[A, K](extractor: A => Seq[K]) extends * @tparam A input type of the model * @tparam B output type of the model. */ + +/* + featureNames: sci.IndexedSeq[String], + featureFunctions: sci.IndexedSeq[GenAggFunc[A, Iterable[(String, Double)]]], + */ case class MultilabelModel[U, F, K, -A, +B <: U]( modelId: ModelIdentity, labelExtraction: LabelExtraction[A, K], - labelDependentFeatures: Seq[K => F], - featureExtraction: Seq[A => F], - predictorProducer: () => (Seq[F], Seq[K]) => Map[K, Double], + labelDependentFeatures: sci.IndexedSeq[GenAggFunc[K, F]], + featureExtraction: sci.IndexedSeq[GenAggFunc[A, F]], + + // TODO: Make this a type alias or trait or something. + predictorProducer: () => (Seq[F], Seq[K], Seq[sci.IndexedSeq[F]]) => Map[K, Double], auditor: Auditor[U, Map[K, Double], B] ) extends SubmodelBase[U, Map[K, Double], A, B] { @@ -48,21 +59,29 @@ extends SubmodelBase[U, Map[K, Double], A, B] { case _ => None } - private[this] def applyLdf(labels: Seq[K]): Seq[Seq[F]] = + private[this] def applyLdf(labels: Seq[K]): Seq[sci.IndexedSeq[F]] = labels.map(label => labelDependentFeatures.map(f => f(label))) override def subvalue(a: A): Subvalue[B, Map[K, Double]] = { + + // Get the labels for which predictions should be produced. val labels = labelExtraction match { - case KnownLabelExtraction(ls) => ls + case KnownLabelExtraction(constLabels) => constLabels case PerExampleLabelExtraction(extractor) => extractor(a) } // Get the label-dependent features. - val ldf = globalLdf getOrElse applyLdf(labels) + // TODO: Handle missing data in the extraction process like in RegressionFeatures.constructFeatures + val ldf: Seq[sci.IndexedSeq[F]] = globalLdf getOrElse applyLdf(labels) + + // TODO: Handle missing data in the extraction process like in RegressionFeatures.constructFeatures + val features: sci.IndexedSeq[F] = featureExtraction.map(f => f(a)) - val features = featureExtraction.map(f => f(a)) - val natural = predictor(features, labels) - val aud = auditor.success(modelId, natural) + // predictor is responsible for getting the data into the correct type and applying + // it within the underlying ML library to produce a prediction. The mapping back to + // (K, Double) pairs is also its responsibility. + val natural: Map[K, Double] = predictor(features, labels, ldf) + val aud: B = auditor.success(modelId, natural) Subvalue(aud, Option(natural)) } } From c2516f5ae35adc26c70196f76f24801768e686b0 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 12:21:56 -0700 Subject: [PATCH 03/98] Added multilabel type aliases. Updated MultilabelModel but it's broken. --- .../aloha/models/MultilabelModel.scala | 43 +++++++------ .../aloha/models/multilabel/package.scala | 64 +++++++++++++++++++ 2 files changed, 88 insertions(+), 19 deletions(-) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala index 7ec2f61d..ae9afdc4 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala @@ -1,48 +1,46 @@ package com.eharmony.aloha.models import com.eharmony.aloha.audit.Auditor +import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.factory._ import com.eharmony.aloha.id.ModelIdentity +import com.eharmony.aloha.models.multilabel.SparsePredictorProducer import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.GenAggFunc import spray.json.{JsonFormat, JsonReader} -import scala.collection.{immutable, immutable => sci} +import scala.collection.{immutable => sci} -sealed trait LabelExtraction[-A, +K] -final case class KnownLabelExtraction[K](labels: Seq[K]) extends LabelExtraction[Any, K] -final case class PerExampleLabelExtraction[A, K](extractor: GenAggFunc[A, sci.IndexedSeq[K]]) extends LabelExtraction[A, K] - /** + * * Created by ryan.deak on 8/29/17. * * @param modelId - * @param labelExtraction + * @param labelInTrainingSet + * @param labelsOfInterest + * @param featureNames + * @param featureFunctions * @param labelDependentFeatures - * @param featureExtraction * @param predictorProducer * @param auditor * @tparam U upper bound on model output type `B` - * @tparam F type of features produced * @tparam K type of label or class * @tparam A input type of the model * @tparam B output type of the model. */ - -/* - featureNames: sci.IndexedSeq[String], - featureFunctions: sci.IndexedSeq[GenAggFunc[A, Iterable[(String, Double)]]], - */ -case class MultilabelModel[U, F, K, -A, +B <: U]( +case class MultilabelModel[U, K, -A, +B <: U]( modelId: ModelIdentity, - labelExtraction: LabelExtraction[A, K], - labelDependentFeatures: sci.IndexedSeq[GenAggFunc[K, F]], - featureExtraction: sci.IndexedSeq[GenAggFunc[A, F]], + labelInTrainingSet: Seq[K], - // TODO: Make this a type alias or trait or something. - predictorProducer: () => (Seq[F], Seq[K], Seq[sci.IndexedSeq[F]]) => Map[K, Double], + labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]], + + featureNames: sci.IndexedSeq[String], + featureFunctions: sci.IndexedSeq[GenAggFunc[A, Sparse]], + labelDependentFeatures: sci.IndexedSeq[GenAggFunc[A, sci.IndexedSeq[Sparse]]], + + predictorProducer: SparsePredictorProducer[K], auditor: Auditor[U, Map[K, Double], B] ) extends SubmodelBase[U, Map[K, Double], A, B] { @@ -54,6 +52,13 @@ extends SubmodelBase[U, Map[K, Double], A, B] { predictor } + // TODO: Create the label dependent features once and then select them based on the labels created from the example. + // This is going to be very difficult when the label dependent feature functions are produced by the factory + // because we don't have a Semantics[K], we only have a semantics[A]. So, if we want to cut down on computation + // time, we can cache the features using a scala.collection.concurrent.TrieMap or a + // java.util.concurrent.ConcurrentHashMap. This should be fine AS LONG AS the functions in labelDependentFeatures + // are idempotent. That should be a stated in the documentation for that parameter. + private val globalLdf = labelExtraction match { case KnownLabelExtraction(labels) => Option(applyLdf(labels)) case _ => None diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala new file mode 100644 index 00000000..84715d1b --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -0,0 +1,64 @@ +package com.eharmony.aloha.models + +import com.eharmony.aloha.dataset.density.Sparse +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 8/31/17. + */ +package object multilabel { + + /** + * Features about the input value (NOT including features based on labels). + */ + private type SparseFeatures = sci.IndexedSeq[Sparse] + + /** + * Indices of the labels for which predictions should be produced into the + * sequence of all labels. Indices will be sorted in ascending order. + */ + private type LabelIndices = sci.IndexedSeq[Int] + + /** + * Labels for which predictions should be produced. This can be an improper subset of all labels. + * `Labels` align with `LabelIndices` meaning `LabelIndices[i]` will give the index of `Labels[i]` + * into the sequence of all labels a model knows about. + * + * @tparam K the type of labels (or classes in the machine learning literature). + */ + private type Labels[K] = sci.IndexedSeq[K] + + /** + * Features related to the labels. Other outer sequence aligns with the `Labels` and `LabelIndices` + * sequences meaning `SparseLabelDepFeatures[i]` relates to the features of `Labels[i]`. + */ + private type SparseLabelDepFeatures = sci.IndexedSeq[sci.IndexedSeq[Sparse]] + + /** + * A sparse multi-label predictor takes: + * + - features + - labels for which a prediction should be produced + - indices of those labels into sequence of all of the labels the model knows about. + - label dependent-features + * + * and returns a Map from the labels passed in, to the prediction associated with the label. + * + * @tparam K the type of labels (or classes in the machine learning literature). + */ + private type SparseMultiLabelPredictor[K] = + (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double] + + /** + * A lazy version of a sparse multi-label predictor. It is a curried zero-arg function that + * produces a sparse multi-label predictor. + * + * This definition is "lazy" because we can't guarantee that the underlying predictor is + * `Serializable` so we pass around a function that can be cached in a ''transient'' + * `lazy val`. This function should however be `Serializable` and testing should be done + * to ensure that each predictor producer is `Serializable`. + * + * @tparam K the type of labels (or classes in the machine learning literature). + */ + type SparsePredictorProducer[K] = () => SparseMultiLabelPredictor[K] +} From 73d378b08fd06db109a3be6c8a2902a98407ef8d Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 12:42:49 -0700 Subject: [PATCH 04/98] Added additional commentst. --- .../aloha/models/MultilabelModel.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala index ae9afdc4..49476e07 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala @@ -38,6 +38,28 @@ case class MultilabelModel[U, K, -A, +B <: U]( featureNames: sci.IndexedSeq[String], featureFunctions: sci.IndexedSeq[GenAggFunc[A, Sparse]], + + // Here is a problem. B/C we only have a Semantics[A] (and not a Semantics[K]) in the factory, + // we can't produce sci.IndexedSeq[GenAggFunc[K, Sparse]] representing the many feature functions + // that will be applied to each label to produce the sparse values. We CAN produce a + // sci.IndexedSeq[GenAggFunc[A, sci.IndexedSeq[Sparse]]] where sci.IndexedSeq[Sparse] is the result + // of one feature applied to all applicable labels. Many feature definitions + // are possible and that's why there is a sci.IndexedSeq[GenAggFunc[...]]. But the issue arises + // of how to link the indices of sci.IndexedSeq[Sparse] to the indices of labelsOfInterest, + // especially if labelsOfInterest is an Option and not required. + // + // We can have an "honor system" based approach where we kindly ask the caller to behave and make + // the sequences line up, but that always seems like a recipe for disaster. + // + // We could do sci.IndexedSeq[GenAggFunc[A, sci.IndexedSeq[(K, Sparse)]]] and remove the honor + // system approach but that seems like a lot of computational overhead. + // + // It would be nice if we could use MorphableSemantics to change Semantics[A] into Semantics[K] + // but this is fraught with problems. For instance, with the Avro semantics, we have to supply + // parameters to create a CompiledSemanticsAvroPlugin. Therefore, we'd have to use reflection + // to see how K relates to A, and it's not guaranteed that K is embedded in A and that K is a + // GenericRecord. + // labelDependentFeatures: sci.IndexedSeq[GenAggFunc[A, sci.IndexedSeq[Sparse]]], predictorProducer: SparsePredictorProducer[K], From 48a61570518341fb45c8fb583a772cd6ffa4ca7c Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 14:04:52 -0700 Subject: [PATCH 05/98] Made SparseLabelDepFeatures type alias a little nicer but abused notation a little. --- .../scala/com/eharmony/aloha/models/multilabel/package.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index 84715d1b..85de5ecc 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -29,10 +29,10 @@ package object multilabel { private type Labels[K] = sci.IndexedSeq[K] /** - * Features related to the labels. Other outer sequence aligns with the `Labels` and `LabelIndices` + * Sparse features related to the labels. Other outer sequence aligns with the `Labels` and `LabelIndices` * sequences meaning `SparseLabelDepFeatures[i]` relates to the features of `Labels[i]`. */ - private type SparseLabelDepFeatures = sci.IndexedSeq[sci.IndexedSeq[Sparse]] + private type SparseLabelDepFeatures = Labels[SparseFeatures] /** * A sparse multi-label predictor takes: From 0ebc1089a67e9318ac9e62adf695de5e82bd0f72 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 17:53:36 -0700 Subject: [PATCH 06/98] New plan. No label-dependent features for now. --- .../aloha/models/MultilabelModel.scala | 134 ++++++++++-------- .../aloha/models/multilabel/package.scala | 5 +- 2 files changed, 79 insertions(+), 60 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala index 49476e07..95fb307c 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala @@ -18,11 +18,10 @@ import scala.collection.{immutable => sci} * Created by ryan.deak on 8/29/17. * * @param modelId - * @param labelInTrainingSet + * @param labelsInTrainingSet * @param labelsOfInterest * @param featureNames * @param featureFunctions - * @param labelDependentFeatures * @param predictorProducer * @param auditor * @tparam U upper bound on model output type `B` @@ -30,42 +29,23 @@ import scala.collection.{immutable => sci} * @tparam A input type of the model * @tparam B output type of the model. */ +// TODO: When adding label-dep features, a Seq[GenAggFunc[K, Sparse]] will be needed. +// TODO: To create a Seq[GenAggFunc[K, Sparse]], a Semantics[K] will need to derived a from Semantics[A]. +// TODO: MorphableSemantics provides this. If K is *embedded inside* A, it should be possible in some cases. case class MultilabelModel[U, K, -A, +B <: U]( modelId: ModelIdentity, - labelInTrainingSet: Seq[K], + labelsInTrainingSet: sci.IndexedSeq[K], labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]], featureNames: sci.IndexedSeq[String], featureFunctions: sci.IndexedSeq[GenAggFunc[A, Sparse]], - // Here is a problem. B/C we only have a Semantics[A] (and not a Semantics[K]) in the factory, - // we can't produce sci.IndexedSeq[GenAggFunc[K, Sparse]] representing the many feature functions - // that will be applied to each label to produce the sparse values. We CAN produce a - // sci.IndexedSeq[GenAggFunc[A, sci.IndexedSeq[Sparse]]] where sci.IndexedSeq[Sparse] is the result - // of one feature applied to all applicable labels. Many feature definitions - // are possible and that's why there is a sci.IndexedSeq[GenAggFunc[...]]. But the issue arises - // of how to link the indices of sci.IndexedSeq[Sparse] to the indices of labelsOfInterest, - // especially if labelsOfInterest is an Option and not required. - // - // We can have an "honor system" based approach where we kindly ask the caller to behave and make - // the sequences line up, but that always seems like a recipe for disaster. - // - // We could do sci.IndexedSeq[GenAggFunc[A, sci.IndexedSeq[(K, Sparse)]]] and remove the honor - // system approach but that seems like a lot of computational overhead. - // - // It would be nice if we could use MorphableSemantics to change Semantics[A] into Semantics[K] - // but this is fraught with problems. For instance, with the Avro semantics, we have to supply - // parameters to create a CompiledSemanticsAvroPlugin. Therefore, we'd have to use reflection - // to see how K relates to A, and it's not guaranteed that K is embedded in A and that K is a - // GenericRecord. - // - labelDependentFeatures: sci.IndexedSeq[GenAggFunc[A, sci.IndexedSeq[Sparse]]], - predictorProducer: SparsePredictorProducer[K], auditor: Auditor[U, Map[K, Double], B] ) extends SubmodelBase[U, Map[K, Double], A, B] { + import MultilabelModel._ @transient private[this] lazy val predictor = predictorProducer() @@ -74,55 +54,93 @@ extends SubmodelBase[U, Map[K, Double], A, B] { predictor } - // TODO: Create the label dependent features once and then select them based on the labels created from the example. - // This is going to be very difficult when the label dependent feature functions are produced by the factory - // because we don't have a Semantics[K], we only have a semantics[A]. So, if we want to cut down on computation - // time, we can cache the features using a scala.collection.concurrent.TrieMap or a - // java.util.concurrent.ConcurrentHashMap. This should be fine AS LONG AS the functions in labelDependentFeatures - // are idempotent. That should be a stated in the documentation for that parameter. - - private val globalLdf = labelExtraction match { - case KnownLabelExtraction(labels) => Option(applyLdf(labels)) - case _ => None - } + private[this] val labelToInd: Map[K, Int] = + labelsInTrainingSet.zipWithIndex.map { case (label, i) => label -> i }(collection.breakOut) - private[this] def applyLdf(labels: Seq[K]): Seq[sci.IndexedSeq[F]] = - labels.map(label => labelDependentFeatures.map(f => f(label))) override def subvalue(a: A): Subvalue[B, Map[K, Double]] = { - // Get the labels for which predictions should be produced. - val labels = labelExtraction match { - case KnownLabelExtraction(constLabels) => constLabels - case PerExampleLabelExtraction(extractor) => extractor(a) - } + // TODO: Is this good enough? Are we tracking enough missing information? Probably not. + val (indices, labelsToPredict, labelsWithNoPrediction) = + labelsOfInterest.map ( labelFn => + labelsForPrediction(a, labelFn, labelToInd) + ) getOrElse { + (labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty) + } - // Get the label-dependent features. - // TODO: Handle missing data in the extraction process like in RegressionFeatures.constructFeatures - val ldf: Seq[sci.IndexedSeq[F]] = globalLdf getOrElse applyLdf(labels) + // TODO: Do a better job of tracking missing features like in the RegressionFeatures trait. + val features = featureFunctions.map(f => f(a)) - // TODO: Handle missing data in the extraction process like in RegressionFeatures.constructFeatures - val features: sci.IndexedSeq[F] = featureExtraction.map(f => f(a)) + // TODO: If labelsToPredict is empty, don't run predictor. Return failure and report inside. // predictor is responsible for getting the data into the correct type and applying // it within the underlying ML library to produce a prediction. The mapping back to // (K, Double) pairs is also its responsibility. - val natural: Map[K, Double] = predictor(features, labels, ldf) - val aud: B = auditor.success(modelId, natural) + // + // TODO: When supporting label-dependent features, fill in the last parameter with a valid value. + // TODO: Consider wrapping in a Try + val natural = predictor(features, labelsToPredict, indices, sci.IndexedSeq.empty) + + val errors = + if (labelsWithNoPrediction.nonEmpty) + Seq(s"Labels provide for which a prediction could not be produced: ${labelsWithNoPrediction.mkString(", ")}.") + else Seq.empty + + // TODO: Incorporate missing data reporting. + val aud: B = auditor.success(modelId, natural, errorMsgs = errors) + Subvalue(aud, Option(natural)) } + + override def close(): Unit = predictor.close() } object MultilabelModel extends ParserProviderCompanion { - object Parser extends ModelParsingPlugin { - override val modelType: String = "multilabel" - override def modelJsonReader[U, N, A, B <: U]( + protected[models] def labelsForPrediction[A, K]( + a: A, + labelsOfInterest: GenAggFunc[A, sci.IndexedSeq[K]], + labelToInd: Map[K, Int] + ): (sci.IndexedSeq[Int], sci.IndexedSeq[K], Seq[K]) = { + + val labelsShouldPredict = labelsOfInterest(a) + + val unsorted = + for { + label <- labelsShouldPredict + ind <- labelToInd.get(label).toList + } yield (ind, label) + + val noPrediction = + if (unsorted.size == labelsShouldPredict.size) Seq.empty + else labelsShouldPredict.filterNot(labelToInd.contains) + + val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip + + (ind, lab, noPrediction) + } + + override def parser: ModelParser = Parser + + object Parser extends ModelSubmodelParsingPlugin { + override val modelType: String = "multilabel-sparse" + + // TODO: Figure if a Option[JsonReader[MultilabelModel[U, _, A, B]]] can be returned. + // See: parser that returns SegmentationModel[U, _, N, A, B] + // See: parser that returns RegressionModel[U, A, B] + // Seems like this should be possible but we get the error: + // + // [error] method commonJsonReader has incompatible type + // [error] override def commonJsonReader[U, N, A, B <: U]( + // [error] ^ + // + override def commonJsonReader[U, N, A, B <: U]( factory: SubmodelFactory[U, A], semantics: Semantics[A], - auditor: Auditor[U, N, B] - )(implicit r: RefInfo[N], jf: JsonFormat[N]): Option[JsonReader[MultilabelModel[U, _, _, A, B]]] = { - + auditor: Auditor[U, N, B])(implicit + r: RefInfo[N], + jf: JsonFormat[N] + ): Option[JsonReader[_ <: Model[A, B] with Submodel[_, A, B]]] = { if (!RefInfoOps.isSubType[N, Map[_, Double]]) None else { @@ -149,6 +167,4 @@ object MultilabelModel extends ParserProviderCompanion { } } } - - override def parser: ModelParser = Parser } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index 85de5ecc..dba89370 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -1,6 +1,9 @@ package com.eharmony.aloha.models +import java.io.Closeable + import com.eharmony.aloha.dataset.density.Sparse + import scala.collection.{immutable => sci} /** @@ -47,7 +50,7 @@ package object multilabel { * @tparam K the type of labels (or classes in the machine learning literature). */ private type SparseMultiLabelPredictor[K] = - (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double] + ((SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double]) with Closeable /** * A lazy version of a sparse multi-label predictor. It is a curried zero-arg function that From 8ba3e6eff7e8e8246126b73a430ef1c1ae1c2564 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 17:54:56 -0700 Subject: [PATCH 07/98] moved MultilabelModel to multilabel package. --- .../aloha/models/{ => multilabel}/MultilabelModel.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename aloha-core/src/main/scala/com/eharmony/aloha/models/{ => multilabel}/MultilabelModel.scala (97%) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala similarity index 97% rename from aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala rename to aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 95fb307c..0f82708d 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -1,10 +1,10 @@ -package com.eharmony.aloha.models +package com.eharmony.aloha.models.multilabel import com.eharmony.aloha.audit.Auditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.factory._ import com.eharmony.aloha.id.ModelIdentity -import com.eharmony.aloha.models.multilabel.SparsePredictorProducer +import com.eharmony.aloha.models._ import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.GenAggFunc @@ -97,7 +97,7 @@ extends SubmodelBase[U, Map[K, Double], A, B] { object MultilabelModel extends ParserProviderCompanion { - protected[models] def labelsForPrediction[A, K]( + protected[multilabel] def labelsForPrediction[A, K]( a: A, labelsOfInterest: GenAggFunc[A, sci.IndexedSeq[K]], labelToInd: Map[K, Int] From 5931ef0f74b74acff2c4b57d877073fa954c2aba Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 17:56:05 -0700 Subject: [PATCH 08/98] updated comments --- .../com/eharmony/aloha/models/multilabel/MultilabelModel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 0f82708d..875a4b3a 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -30,7 +30,7 @@ import scala.collection.{immutable => sci} * @tparam B output type of the model. */ // TODO: When adding label-dep features, a Seq[GenAggFunc[K, Sparse]] will be needed. -// TODO: To create a Seq[GenAggFunc[K, Sparse]], a Semantics[K] will need to derived a from Semantics[A]. +// TODO: To create a Seq[GenAggFunc[K, Sparse]], a Semantics[K] needs to be derived from a Semantics[A]. // TODO: MorphableSemantics provides this. If K is *embedded inside* A, it should be possible in some cases. case class MultilabelModel[U, K, -A, +B <: U]( modelId: ModelIdentity, From 5e1f1dcdec83b9cae0a4e6112f1961b4fa952cbd Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 17:59:36 -0700 Subject: [PATCH 09/98] Removed the requirements that SparseMultiLabelPredictor is Closeable. --- .../aloha/models/multilabel/MultilabelModel.scala | 8 +++++++- .../com/eharmony/aloha/models/multilabel/package.scala | 4 +--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 875a4b3a..6e9128a7 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -1,5 +1,7 @@ package com.eharmony.aloha.models.multilabel +import java.io.Closeable + import com.eharmony.aloha.audit.Auditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.factory._ @@ -92,7 +94,11 @@ extends SubmodelBase[U, Map[K, Double], A, B] { Subvalue(aud, Option(natural)) } - override def close(): Unit = predictor.close() + override def close(): Unit = + predictor match { + case closeable: Closeable => closeable.close() + case _ => + } } object MultilabelModel extends ParserProviderCompanion { diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index dba89370..06880bd7 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -1,7 +1,5 @@ package com.eharmony.aloha.models -import java.io.Closeable - import com.eharmony.aloha.dataset.density.Sparse import scala.collection.{immutable => sci} @@ -50,7 +48,7 @@ package object multilabel { * @tparam K the type of labels (or classes in the machine learning literature). */ private type SparseMultiLabelPredictor[K] = - ((SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double]) with Closeable + (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double] /** * A lazy version of a sparse multi-label predictor. It is a curried zero-arg function that From 0aea10ce4fc2d7d7e4a8c2b15a3249dae0417a16 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 31 Aug 2017 18:01:52 -0700 Subject: [PATCH 10/98] Added comment to predictor. --- .../eharmony/aloha/models/multilabel/MultilabelModel.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 6e9128a7..ab8073a1 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -49,6 +49,10 @@ case class MultilabelModel[U, K, -A, +B <: U]( extends SubmodelBase[U, Map[K, Double], A, B] { import MultilabelModel._ + /** + * predictory is transient lazy value because we don't need to worry about serialization. + * We don't care about the lazy property. It should be created eagerly. + */ @transient private[this] lazy val predictor = predictorProducer() { From 85229f988f8ffebb40962f3059cf3fdde9d37c9a Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 10:43:37 -0700 Subject: [PATCH 11/98] Updated skeleton of MultilabelModel and small change to RegressionFeatures. --- .../models/multilabel/MultilabelModel.scala | 266 ++++++++++++++---- .../aloha/models/multilabel/package.scala | 6 +- .../aloha/models/reg/RegressionFeatures.scala | 7 +- 3 files changed, 218 insertions(+), 61 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index ab8073a1..7d77f1e3 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -7,46 +7,66 @@ import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.factory._ import com.eharmony.aloha.id.ModelIdentity import com.eharmony.aloha.models._ +import com.eharmony.aloha.models.reg.RegressionFeatures import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import com.eharmony.aloha.semantics.Semantics -import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.semantics.func.{GenAggFunc, GenAggFuncAccessorProblems} import spray.json.{JsonFormat, JsonReader} -import scala.collection.{immutable => sci} +import scala.collection.{immutable => sci, mutable => scm} +import scala.util.{Failure, Success, Try} +// TODO: When adding label-dep features, a Seq[GenAggFunc[K, Sparse]] will be needed. +// TODO: To create a Seq[GenAggFunc[K, Sparse]], a Semantics[K] needs to be derived from a Semantics[A]. +// TODO: MorphableSemantics provides this. If K is *embedded inside* A, it should be possible in some cases. +// TODO: An alternative is to pass a Map[K, Sparse], Map[K, Option[Sparse]], Map[K, Seq[Sparse]] or something. +// TODO: Directly passing the map of LDFs avoids the need to derive a Semantics[K]. This is easier to code. +// TODO: Directly passing LDFs would however be more burdensome to the data scientists. + /** + * A multi-label predictor. * * Created by ryan.deak on 8/29/17. * - * @param modelId - * @param labelsInTrainingSet - * @param labelsOfInterest - * @param featureNames - * @param featureFunctions - * @param predictorProducer - * @param auditor + * @param modelId An identifier for the model. User in score and error reporting. + * @param featureNames feature names (parallel to featureFunctions) + * @param featureFunctions feature extracting functions. + * @param labelsInTrainingSet a sequence of all labels encountered during training. Note: the + * order of labels may relate to the predictor produced by + * predictorProducer. It is the caller's responsibility to ensure + * the order is correct. To mitigate such problems, both labels + * and indices into labelsInTrainingSet are passed to the predictor + * produced by predictorProducer. + * @param labelsOfInterest if provided, a sequence of labels will be extracted from the example + * for which a prediction is desired. The ''intersection'' of the + * extracted labels and the training labels will be the labels for which + * predictions will be produced. + * @param predictorProducer the function produced when calling this function is responsible for + * getting the data into the correct type and using it within an + * underlying ML library to produce a prediction. The mapping back to + * (K, Double) pairs is also its responsibility. + * @param numMissingThreshold if provided, we check whether the threshold is exceeded. If so, + * return an error instead of the computed score. This is for missing + * data situations. + * @param auditor transforms a `Map[K, Double]` to a `B`. Reports successes and errors. * @tparam U upper bound on model output type `B` * @tparam K type of label or class * @tparam A input type of the model * @tparam B output type of the model. */ -// TODO: When adding label-dep features, a Seq[GenAggFunc[K, Sparse]] will be needed. -// TODO: To create a Seq[GenAggFunc[K, Sparse]], a Semantics[K] needs to be derived from a Semantics[A]. -// TODO: MorphableSemantics provides this. If K is *embedded inside* A, it should be possible in some cases. case class MultilabelModel[U, K, -A, +B <: U]( modelId: ModelIdentity, - labelsInTrainingSet: sci.IndexedSeq[K], - - labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]], - featureNames: sci.IndexedSeq[String], featureFunctions: sci.IndexedSeq[GenAggFunc[A, Sparse]], - + labelsInTrainingSet: sci.IndexedSeq[K], + labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]], predictorProducer: SparsePredictorProducer[K], - auditor: Auditor[U, Map[K, Double], B] -) -extends SubmodelBase[U, Map[K, Double], A, B] { + numMissingThreshold: Option[Int], + auditor: Auditor[U, Map[K, Double], B]) +extends SubmodelBase[U, Map[K, Double], A, B] + with RegressionFeatures[A] { + import MultilabelModel._ /** @@ -54,48 +74,31 @@ extends SubmodelBase[U, Map[K, Double], A, B] { * We don't care about the lazy property. It should be created eagerly. */ @transient private[this] lazy val predictor = predictorProducer() - - { - // Force predictor eagerly - predictor - } + predictor // Force predictor eagerly private[this] val labelToInd: Map[K, Int] = labelsInTrainingSet.zipWithIndex.map { case (label, i) => label -> i }(collection.breakOut) - override def subvalue(a: A): Subvalue[B, Map[K, Double]] = { + val li = labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) - // TODO: Is this good enough? Are we tracking enough missing information? Probably not. - val (indices, labelsToPredict, labelsWithNoPrediction) = - labelsOfInterest.map ( labelFn => - labelsForPrediction(a, labelFn, labelToInd) - ) getOrElse { - (labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty) - } - - // TODO: Do a better job of tracking missing features like in the RegressionFeatures trait. - val features = featureFunctions.map(f => f(a)) - - // TODO: If labelsToPredict is empty, don't run predictor. Return failure and report inside. - - // predictor is responsible for getting the data into the correct type and applying - // it within the underlying ML library to produce a prediction. The mapping back to - // (K, Double) pairs is also its responsibility. - // - // TODO: When supporting label-dependent features, fill in the last parameter with a valid value. - // TODO: Consider wrapping in a Try - val natural = predictor(features, labelsToPredict, indices, sci.IndexedSeq.empty) - - val errors = - if (labelsWithNoPrediction.nonEmpty) - Seq(s"Labels provide for which a prediction could not be produced: ${labelsWithNoPrediction.mkString(", ")}.") - else Seq.empty + if (li.labels.isEmpty) + reportNoPrediction(modelId, li, auditor) + else { + val Features(x, missing, missingOk) = constructFeatures(a) - // TODO: Incorporate missing data reporting. - val aud: B = auditor.success(modelId, natural, errorMsgs = errors) + if (!missingOk) + reportTooManyMissing(modelId, li, missing, auditor) + else { + // TODO: To support label-dependent features, fill last parameter with a valid value. + val predictionTry = Try { predictor(x, li.labels, li.indices, sci.IndexedSeq.empty) } - Subvalue(aud, Option(natural)) + predictionTry match { + case Success(pred) => reportSuccess(modelId, li, missing, pred, auditor) + case Failure(ex) => reportPredictorError(modelId, li, missing, ex, auditor) + } + } + } } override def close(): Unit = @@ -107,11 +110,156 @@ extends SubmodelBase[U, Map[K, Double], A, B] { object MultilabelModel extends ParserProviderCompanion { + /** + * + * @param indices + * @param labels + * @param missingLabels + * @param problems + * @tparam K + */ + protected[multilabel] case class LabelsAndInfo[K]( + indices: sci.IndexedSeq[Int], + labels: sci.IndexedSeq[K], + missingLabels: Seq[K], + problems: Option[GenAggFuncAccessorProblems] + ) + + /** + * + * @param a + * @param labelsInTrainingSet + * @param labelsOfInterest + * @param labelToInd + * @tparam A + * @tparam K + * @return + */ + protected[multilabel] def labelsAndInfo[A, K]( + a: A, + labelsInTrainingSet: sci.IndexedSeq[K], + labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]], + labelToInd: Map[K, Int] + ): LabelsAndInfo[K] = { + // TODO: Is this good enough? Are we tracking enough missing information? Probably not. + labelsOfInterest.map ( labelFn => + labelsForPrediction(a, labelFn, labelToInd) + ) getOrElse { + LabelsAndInfo(labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty, None) + } + } + + /** + * + * @param modelId + * @param missing + * @param auditor + * @tparam U + * @tparam K + * @tparam B + * @return + */ + protected[multilabel] def reportTooManyMissing[U, K, B]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + missing: scm.Map[String, Seq[String]], + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Nothing] = { + // TODO: Fill in the errors. + val aud = auditor.failure(modelId, missingVarNames = missing.values.flatten.toSet) + Subvalue(aud, None) + } + + /** + * + * @param modelId + * @param labelInfo + * @param auditor + * @tparam U + * @tparam K + * @tparam B + * @return + */ + protected[multilabel] def reportNoPrediction[U, K, B]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Nothing] = { + // TODO: Fill in the errors. + val aud = auditor.failure(modelId) + Subvalue(aud, None) + } + + /** + * + * @param modelId + * @param labelInfo + * @param missing + * @param prediction + * @param auditor + * @tparam U + * @tparam K + * @tparam B + * @return + */ + protected[multilabel] def reportSuccess[U, K, B]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + missing: scm.Map[String, Seq[String]], + prediction: Map[K, Double], + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Map[K, Double]] = { + + val errors = + if (labelInfo.missingLabels.nonEmpty) + Seq(s"Labels provide for which a prediction could not be produced: ${labelInfo.missingLabels.mkString(", ")}.") + else Seq.empty + + // TODO: Incorporate missing data reporting. + val aud: B = auditor.success(modelId, prediction, errorMsgs = errors) + + Subvalue(aud, Option(prediction)) + } + + /** + * + * @param modelId + * @param labelInfo + * @param missing + * @param throwable + * @param auditor + * @tparam U + * @tparam K + * @tparam B + * @return + */ + protected[multilabel] def reportPredictorError[U, K, B]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + missing: scm.Map[String, Seq[String]], + throwable: Throwable, + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Nothing] = { + + // TODO: Fill in. + val aud = auditor.failure(modelId) + Subvalue(aud, None) + } + + /** + * + * @param a + * @param labelsOfInterest + * @param labelToInd + * @tparam A + * @tparam K + * @return + */ protected[multilabel] def labelsForPrediction[A, K]( a: A, labelsOfInterest: GenAggFunc[A, sci.IndexedSeq[K]], labelToInd: Map[K, Int] - ): (sci.IndexedSeq[Int], sci.IndexedSeq[K], Seq[K]) = { + ): LabelsAndInfo[K] = { val labelsShouldPredict = labelsOfInterest(a) @@ -121,15 +269,20 @@ object MultilabelModel extends ParserProviderCompanion { ind <- labelToInd.get(label).toList } yield (ind, label) + val problems = + if (labelsShouldPredict.nonEmpty) None + else Option(labelsOfInterest.accessorOutputProblems(a)) + val noPrediction = if (unsorted.size == labelsShouldPredict.size) Seq.empty else labelsShouldPredict.filterNot(labelToInd.contains) val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip - (ind, lab, noPrediction) + LabelsAndInfo(ind, lab, noPrediction, problems) } + override def parser: ModelParser = Parser object Parser extends ModelSubmodelParsingPlugin { @@ -172,7 +325,6 @@ object MultilabelModel extends ParserProviderCompanion { // TODO: parse the native submodel from the wrapped ML library. This involves plugins - ??? } } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index 06880bd7..b381d8f0 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -11,8 +11,12 @@ package object multilabel { /** * Features about the input value (NOT including features based on labels). + * This should probably be a `sci.IndexedSeq[Sparse]` but `RegressionFeatures` + * returns a `collection.IndexedSeq` and using + * [[com.eharmony.aloha.models.reg.RegressionFeatures.constructFeatures]] is + * preferable and will provide consistent results across many model types. */ - private type SparseFeatures = sci.IndexedSeq[Sparse] + private type SparseFeatures = IndexedSeq[Sparse] /** * Indices of the labels for which predictions should be produced into the diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala index 21be6fb0..da9bdb48 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala @@ -1,5 +1,6 @@ package com.eharmony.aloha.models.reg +import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.semantics.func.GenAggFunc import scala.collection.{immutable => sci, mutable => scm} @@ -20,7 +21,7 @@ trait RegressionFeatures[-A] { /** * Parallel to featureNames. This is the sequence of functions that extract data from the input value. */ - protected[this] val featureFunctions: sci.IndexedSeq[GenAggFunc[A, Iterable[(String, Double)]]] + protected[this] val featureFunctions: sci.IndexedSeq[GenAggFunc[A, Sparse]] /** * A threshold dictating how many missing features to allow before making the prediction fail. None means @@ -66,7 +67,7 @@ trait RegressionFeatures[-A] { * 1 the map of bad features to the missing values in the raw data that were needed to compute the feature * 1 whether the amount of missing data is acceptable to still continue */ - protected[this] final def constructFeatures(a: A): Features[IndexedSeq[Iterable[(String, Double)]]] = { + protected[this] final def constructFeatures(a: A): Features[IndexedSeq[Sparse]] = { // NOTE: Since this function is at the center of the regression process and will be called many times, it // needs to be efficient. Therefore, it uses some things that are not idiomatic scala. For instance, // there are mutable variables, while loops instead of for comprehensions or Range.foreach, etc. @@ -95,7 +96,7 @@ trait RegressionFeatures[-A] { i += 1 } - val numMissingOk = numMissingThreshold map { missing.size <= _ } getOrElse true + val numMissingOk = numMissingThreshold.forall(t => missing.size <= t) // If we are going to err out, allow a linear scan (with repeated work so that we can get richer error // diagnostics. Only include the values where the list of missing accessors variables is not empty. From 78120198255ac891744e8c0a76f2e2d675edc95d Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 10:49:42 -0700 Subject: [PATCH 12/98] Added B <: U in helper methods. --- .../aloha/models/multilabel/MultilabelModel.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 7d77f1e3..96d35ab6 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -159,7 +159,7 @@ object MultilabelModel extends ParserProviderCompanion { * @tparam B * @return */ - protected[multilabel] def reportTooManyMissing[U, K, B]( + protected[multilabel] def reportTooManyMissing[U, K, B <: U]( modelId: ModelIdentity, labelInfo: LabelsAndInfo[K], missing: scm.Map[String, Seq[String]], @@ -180,7 +180,7 @@ object MultilabelModel extends ParserProviderCompanion { * @tparam B * @return */ - protected[multilabel] def reportNoPrediction[U, K, B]( + protected[multilabel] def reportNoPrediction[U, K, B <: U]( modelId: ModelIdentity, labelInfo: LabelsAndInfo[K], auditor: Auditor[U, Map[K, Double], B] @@ -202,7 +202,7 @@ object MultilabelModel extends ParserProviderCompanion { * @tparam B * @return */ - protected[multilabel] def reportSuccess[U, K, B]( + protected[multilabel] def reportSuccess[U, K, B <: U]( modelId: ModelIdentity, labelInfo: LabelsAndInfo[K], missing: scm.Map[String, Seq[String]], @@ -233,7 +233,7 @@ object MultilabelModel extends ParserProviderCompanion { * @tparam B * @return */ - protected[multilabel] def reportPredictorError[U, K, B]( + protected[multilabel] def reportPredictorError[U, K, B <: U]( modelId: ModelIdentity, labelInfo: LabelsAndInfo[K], missing: scm.Map[String, Seq[String]], From f78e4bde719e40ba975423463a533a226a5bf09d Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 11:59:59 -0700 Subject: [PATCH 13/98] Added a few comments to model and added test skeleton. --- .../models/multilabel/MultilabelModel.scala | 4 +- .../multilabel/MultilabelModelTest.scala | 167 ++++++++++++++++++ 2 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 96d35ab6..737b7a38 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -45,7 +45,9 @@ import scala.util.{Failure, Success, Try} * @param predictorProducer the function produced when calling this function is responsible for * getting the data into the correct type and using it within an * underlying ML library to produce a prediction. The mapping back to - * (K, Double) pairs is also its responsibility. + * (K, Double) pairs is also its responsibility. If the predictor + * produced by predictorProducer is Closeable, it will be closed when + * MultilabelModel's close method is called. * @param numMissingThreshold if provided, we check whether the threshold is exceeded. If so, * return an error instead of the computed score. This is for missing * data situations. diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala new file mode 100644 index 00000000..3190a3cd --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -0,0 +1,167 @@ +package com.eharmony.aloha.models.multilabel + +import com.eharmony.aloha.ModelSerializationTestHelper +import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +/** + * Created by ryan.deak on 9/1/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class MultilabelModelTest extends ModelSerializationTestHelper { + + + // TODO: Fill in the test implementation and delete comments once done. + + @Test def testSerialization(): Unit = { + // The name of this test needs to be exactly 'testSerialization'. Don't change. + // Assuming all parameters passed to the MultilabelModel constructor are + // Serializable, MultilabelModel should also be Serializable. + // + // See com.eharmony.aloha.models.ConstantModelTest.testSerialization() + + fail() + } + + @Test def testModelCloseClosesPredictor(): Unit = { + // Make the predictorProducer passed to the constructor be a + // 'SparsePredictorProducer[K] with Closeable'. + // predictorProducer should track whether it is closed (using an AtomicBoolean or something). + // Call close on the MultilabelModel instance and ensure that the underlying predictor is + // also closed. + + fail() + } + + @Test def testLabelsOfInterestOmitted(): Unit = { + // Test labelsAndInfo[A, K] function. + // + // When labelsOfInterest = None, labelsAndInfo should return: + // LabelsAndInfo[K]( + // indices = labelsInTrainingSet.indices, + // labels = labelsInTrainingSet, + // missingLabels = Seq.empty[K], + // problems = None + // ) + + fail() + } + + @Test def testLabelsOfInterestProvided(): Unit = { + // Test labelsAndInfo[A, K] function. + // + // labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) == + // labelsForPrediction(a, labelsOfInterest.get, labelToInd) + + fail() + } + + @Test def testReportTooManyMissing(): Unit = { + // Make sure Subvalue.natural == None + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be None. Check the errors and missing values. + + fail() + } + + @Test def testReportNoPrediction(): Unit = { + // Make sure Subvalue.natural == None + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be None. Check the errors and missing values. + + fail() + } + + @Test def testReportPredictorError(): Unit = { + // Make sure Subvalue.natural == None + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be None. Check the errors and missing values. + + fail() + } + + @Test def testReportSuccess(): Unit = { + // Make sure Subvalue.natural == Some(value) + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be Some(value2). + // 'value' should equal 'value2'. + // Check the errors and missing values. + + fail() + } + + @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { + // Test this: + // val problems = + // if (labelsShouldPredict.nonEmpty) None + // else Option(labelsOfInterest.accessorOutputProblems(a)) + + fail() + } + + @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { + // Test this: + // val noPrediction = + // if (unsorted.size == labelsShouldPredict.size) Seq.empty + // else labelsShouldPredict.filterNot(labelToInd.contains) + + fail() + } + + @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { + // Test this: + // val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip + + fail() + } + + @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { + // Test this: + // if (li.labels.isEmpty) + // reportNoPrediction(modelId, li, auditor) + + fail() + } + + @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { + // When the amount of missing data exceeds the threshold, reportTooManyMissing should be + // called and its value should be returned. Instantiate a MultilabelModel and + // call apply with some missing data required by the features. + + fail() + } + + @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { + // Create a predictorProducer that throws. Check that the model still returns a value + // and that the error message is incorporated appropriately. + + fail() + } + + @Test def testSubvalueSuccess(): Unit = { + // Test the happy path by calling model.apply. Check the value, missing data, and error messages. + + fail() + } +} + +object MultilabelModelTest { + // TODO: Use this label type and Auditor. + + private type Label = String + private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]( + accumulateErrors = false, + accumulateMissingFeatures = false + ) + +// TODO: Access information returned in audited value by using the following functions: + // val aud: RootedTree[Any, Map[Label, Double]] = ??? + // aud.modelId // : ModelIdentity + // aud.value // : Option[Map[Label, Double]] // Should be missing on failure. + // aud.missingVarNames // : Set[String] + // aud.errorMsgs // : Seq[String] + // aud.prob // : Option[Float] (Shouldn't need this) +} \ No newline at end of file From f21fa4f181059f15bcb7d5dd70bbd32149767d04 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 12:05:48 -0700 Subject: [PATCH 14/98] removed parameters to Auditor. Just use defaults because the values don't matter because there are no sub models or super models. --- .../aloha/models/multilabel/MultilabelModelTest.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 3190a3cd..d8a8311e 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -152,10 +152,7 @@ object MultilabelModelTest { // TODO: Use this label type and Auditor. private type Label = String - private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]( - accumulateErrors = false, - accumulateMissingFeatures = false - ) + private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() // TODO: Access information returned in audited value by using the following functions: // val aud: RootedTree[Any, Map[Label, Double]] = ??? From a9a0f8bbb40922695e75794975131ec5cdad3106 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 12:07:46 -0700 Subject: [PATCH 15/98] Added import to companion object Label type and Auditor. --- .../eharmony/aloha/models/multilabel/MultilabelModelTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index d8a8311e..c5172a2c 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -12,7 +12,7 @@ import org.junit.runners.BlockJUnit4ClassRunner */ @RunWith(classOf[BlockJUnit4ClassRunner]) class MultilabelModelTest extends ModelSerializationTestHelper { - + import MultilabelModelTest._ // TODO: Fill in the test implementation and delete comments once done. From 2d6de6f68e2fc1e6c8d3660e487865c1bb35ea00 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 12:16:49 -0700 Subject: [PATCH 16/98] Added additional test. --- .../aloha/models/multilabel/MultilabelModelTest.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index c5172a2c..14897eb9 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -146,6 +146,12 @@ class MultilabelModelTest extends ModelSerializationTestHelper { fail() } + + @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { + // NOTE: This is by design. + + fail() + } } object MultilabelModelTest { From 6defbe07a1afbde1767d2578efd5f47101f3f8bc Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 14:11:36 -0700 Subject: [PATCH 17/98] SparseMultiLabelPredictor was made package private for testing. --- .../scala/com/eharmony/aloha/models/multilabel/package.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index b381d8f0..f27fa11a 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -49,9 +49,11 @@ package object multilabel { * * and returns a Map from the labels passed in, to the prediction associated with the label. * + * '''NOTE''': This is exposed as package private for testing. + * * @tparam K the type of labels (or classes in the machine learning literature). */ - private type SparseMultiLabelPredictor[K] = + private[multilabel] type SparseMultiLabelPredictor[K] = (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double] /** From 7661df256363c8deb8a6776bc755cfb74b4b5bd4 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 15:13:53 -0700 Subject: [PATCH 18/98] updated privacy of type aliases in multilabel package object --- .../eharmony/aloha/models/multilabel/package.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index f27fa11a..583ad6e3 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -9,6 +9,8 @@ import scala.collection.{immutable => sci} */ package object multilabel { + // All but the last type are package private, for testing. The last is public. + /** * Features about the input value (NOT including features based on labels). * This should probably be a `sci.IndexedSeq[Sparse]` but `RegressionFeatures` @@ -16,13 +18,13 @@ package object multilabel { * [[com.eharmony.aloha.models.reg.RegressionFeatures.constructFeatures]] is * preferable and will provide consistent results across many model types. */ - private type SparseFeatures = IndexedSeq[Sparse] + private[multilabel] type SparseFeatures = IndexedSeq[Sparse] /** * Indices of the labels for which predictions should be produced into the * sequence of all labels. Indices will be sorted in ascending order. */ - private type LabelIndices = sci.IndexedSeq[Int] + private[multilabel] type LabelIndices = sci.IndexedSeq[Int] /** * Labels for which predictions should be produced. This can be an improper subset of all labels. @@ -31,13 +33,13 @@ package object multilabel { * * @tparam K the type of labels (or classes in the machine learning literature). */ - private type Labels[K] = sci.IndexedSeq[K] + private[multilabel] type Labels[K] = sci.IndexedSeq[K] /** * Sparse features related to the labels. Other outer sequence aligns with the `Labels` and `LabelIndices` * sequences meaning `SparseLabelDepFeatures[i]` relates to the features of `Labels[i]`. */ - private type SparseLabelDepFeatures = Labels[SparseFeatures] + private[multilabel] type SparseLabelDepFeatures = Labels[SparseFeatures] /** * A sparse multi-label predictor takes: @@ -69,3 +71,4 @@ package object multilabel { */ type SparsePredictorProducer[K] = () => SparseMultiLabelPredictor[K] } + From 9fd15adc298e568032d9547492370098eb1b92bf Mon Sep 17 00:00:00 2001 From: amirziai Date: Fri, 1 Sep 2017 15:17:19 -0700 Subject: [PATCH 19/98] Serialization test --- .../aloha/models/multilabel/package.scala | 8 +- .../multilabel/MultilabelModelTest.scala | 286 ++++++++++-------- 2 files changed, 158 insertions(+), 136 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index f27fa11a..45a7a90e 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -16,13 +16,13 @@ package object multilabel { * [[com.eharmony.aloha.models.reg.RegressionFeatures.constructFeatures]] is * preferable and will provide consistent results across many model types. */ - private type SparseFeatures = IndexedSeq[Sparse] + private[multilabel] type SparseFeatures = IndexedSeq[Sparse] /** * Indices of the labels for which predictions should be produced into the * sequence of all labels. Indices will be sorted in ascending order. */ - private type LabelIndices = sci.IndexedSeq[Int] + private[multilabel] type LabelIndices = sci.IndexedSeq[Int] /** * Labels for which predictions should be produced. This can be an improper subset of all labels. @@ -31,13 +31,13 @@ package object multilabel { * * @tparam K the type of labels (or classes in the machine learning literature). */ - private type Labels[K] = sci.IndexedSeq[K] + private[multilabel] type Labels[K] = sci.IndexedSeq[K] /** * Sparse features related to the labels. Other outer sequence aligns with the `Labels` and `LabelIndices` * sequences meaning `SparseLabelDepFeatures[i]` relates to the features of `Labels[i]`. */ - private type SparseLabelDepFeatures = Labels[SparseFeatures] + private[multilabel] type SparseLabelDepFeatures = Labels[SparseFeatures] /** * A sparse multi-label predictor takes: diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 14897eb9..a604ecc4 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -2,11 +2,16 @@ package com.eharmony.aloha.models.multilabel import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.semantics.func.GenAggFunc import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner +import scala.collection.{immutable => sci} + /** * Created by ryan.deak on 9/1/17. */ @@ -17,141 +22,158 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // TODO: Fill in the test implementation and delete comments once done. @Test def testSerialization(): Unit = { - // The name of this test needs to be exactly 'testSerialization'. Don't change. // Assuming all parameters passed to the MultilabelModel constructor are // Serializable, MultilabelModel should also be Serializable. - // - // See com.eharmony.aloha.models.ConstantModelTest.testSerialization() - - fail() - } - - @Test def testModelCloseClosesPredictor(): Unit = { - // Make the predictorProducer passed to the constructor be a - // 'SparsePredictorProducer[K] with Closeable'. - // predictorProducer should track whether it is closed (using an AtomicBoolean or something). - // Call close on the MultilabelModel instance and ensure that the underlying predictor is - // also closed. - - fail() - } - - @Test def testLabelsOfInterestOmitted(): Unit = { - // Test labelsAndInfo[A, K] function. - // - // When labelsOfInterest = None, labelsAndInfo should return: - // LabelsAndInfo[K]( - // indices = labelsInTrainingSet.indices, - // labels = labelsInTrainingSet, - // missingLabels = Seq.empty[K], - // problems = None - // ) - - fail() - } - - @Test def testLabelsOfInterestProvided(): Unit = { - // Test labelsAndInfo[A, K] function. - // - // labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) == - // labelsForPrediction(a, labelsOfInterest.get, labelToInd) - - fail() - } - - @Test def testReportTooManyMissing(): Unit = { - // Make sure Subvalue.natural == None - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be None. Check the errors and missing values. - - fail() - } - - @Test def testReportNoPrediction(): Unit = { - // Make sure Subvalue.natural == None - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be None. Check the errors and missing values. - - fail() - } - - @Test def testReportPredictorError(): Unit = { - // Make sure Subvalue.natural == None - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be None. Check the errors and missing values. - - fail() - } - - @Test def testReportSuccess(): Unit = { - // Make sure Subvalue.natural == Some(value) - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be Some(value2). - // 'value' should equal 'value2'. - // Check the errors and missing values. - - fail() - } - @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { - // Test this: - // val problems = - // if (labelsShouldPredict.nonEmpty) None - // else Option(labelsOfInterest.accessorOutputProblems(a)) - - fail() - } - - @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { - // Test this: - // val noPrediction = - // if (unsorted.size == labelsShouldPredict.size) Seq.empty - // else labelsShouldPredict.filterNot(labelToInd.contains) - - fail() - } - - @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { - // Test this: - // val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip - - fail() - } - - @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { - // Test this: - // if (li.labels.isEmpty) - // reportNoPrediction(modelId, li, auditor) - - fail() - } - - @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { - // When the amount of missing data exceeds the threshold, reportTooManyMissing should be - // called and its value should be returned. Instantiate a MultilabelModel and - // call apply with some missing data required by the features. - - fail() - } - - @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { - // Create a predictorProducer that throws. Check that the model still returns a value - // and that the error message is incorporated appropriately. - - fail() - } - - @Test def testSubvalueSuccess(): Unit = { - // Test the happy path by calling model.apply. Check the value, missing data, and error messages. - - fail() - } - - @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { - // NOTE: This is by design. - - fail() - } + final case class ConstantMultiLabelPredictor[K](returnVal: Map[K, Double]) + extends SparseMultiLabelPredictor[K] { + override def apply(v1: SparseFeatures, + v2: Labels[K], + v3: LabelIndices, + v4: SparseLabelDepFeatures): Map[K, Double] = returnVal + } + + val model = MultilabelModel( + ModelId(), + sci.IndexedSeq(), + sci.IndexedSeq[GenAggFunc[Int, Sparse]](), + sci.IndexedSeq[Label](), + None, + () => ConstantMultiLabelPredictor(Map[Label, Double]()), + None, + Auditor + ) + + val modelRoundTrip = serializeDeserializeRoundTrip(model) + assertEquals(model, modelRoundTrip) + } + +// @Test def testModelCloseClosesPredictor(): Unit = { +// // Make the predictorProducer passed to the constructor be a +// // 'SparsePredictorProducer[K] with Closeable'. +// // predictorProducer should track whether it is closed (using an AtomicBoolean or something). +// // Call close on the MultilabelModel instance and ensure that the underlying predictor is +// // also closed. +// +// fail() +// } +// +// @Test def testLabelsOfInterestOmitted(): Unit = { +// // Test labelsAndInfo[A, K] function. +// // +// // When labelsOfInterest = None, labelsAndInfo should return: +// // LabelsAndInfo[K]( +// // indices = labelsInTrainingSet.indices, +// // labels = labelsInTrainingSet, +// // missingLabels = Seq.empty[K], +// // problems = None +// // ) +// +// fail() +// } +// +// @Test def testLabelsOfInterestProvided(): Unit = { +// // Test labelsAndInfo[A, K] function. +// // +// // labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) == +// // labelsForPrediction(a, labelsOfInterest.get, labelToInd) +// +// fail() +// } +// +// @Test def testReportTooManyMissing(): Unit = { +// // Make sure Subvalue.natural == None +// // Check the values of Subvalue.audited and make sure they are as expected. +// // Subvalue.audited.value should be None. Check the errors and missing values. +// +// fail() +// } +// +// @Test def testReportNoPrediction(): Unit = { +// // Make sure Subvalue.natural == None +// // Check the values of Subvalue.audited and make sure they are as expected. +// // Subvalue.audited.value should be None. Check the errors and missing values. +// +// fail() +// } +// +// @Test def testReportPredictorError(): Unit = { +// // Make sure Subvalue.natural == None +// // Check the values of Subvalue.audited and make sure they are as expected. +// // Subvalue.audited.value should be None. Check the errors and missing values. +// +// fail() +// } +// +// @Test def testReportSuccess(): Unit = { +// // Make sure Subvalue.natural == Some(value) +// // Check the values of Subvalue.audited and make sure they are as expected. +// // Subvalue.audited.value should be Some(value2). +// // 'value' should equal 'value2'. +// // Check the errors and missing values. +// +// fail() +// } +// +// @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { +// // Test this: +// // val problems = +// // if (labelsShouldPredict.nonEmpty) None +// // else Option(labelsOfInterest.accessorOutputProblems(a)) +// +// fail() +// } +// +// @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { +// // Test this: +// // val noPrediction = +// // if (unsorted.size == labelsShouldPredict.size) Seq.empty +// // else labelsShouldPredict.filterNot(labelToInd.contains) +// +// fail() +// } +// +// @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { +// // Test this: +// // val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip +// +// fail() +// } +// +// @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { +// // Test this: +// // if (li.labels.isEmpty) +// // reportNoPrediction(modelId, li, auditor) +// +// fail() +// } +// +// @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { +// // When the amount of missing data exceeds the threshold, reportTooManyMissing should be +// // called and its value should be returned. Instantiate a MultilabelModel and +// // call apply with some missing data required by the features. +// +// fail() +// } +// +// @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { +// // Create a predictorProducer that throws. Check that the model still returns a value +// // and that the error message is incorporated appropriately. +// +// fail() +// } +// +// @Test def testSubvalueSuccess(): Unit = { +// // Test the happy path by calling model.apply. Check the value, missing data, and error messages. +// +// fail() +// } +// +// @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { +// // NOTE: This is by design. +// +// fail() +// } } object MultilabelModelTest { From ce66b7c8a24a1498ae72ca2a3e47061f82ad98f6 Mon Sep 17 00:00:00 2001 From: amirziai Date: Fri, 1 Sep 2017 16:39:29 -0700 Subject: [PATCH 20/98] First test passing --- .../multilabel/MultilabelModelTest.scala | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index a604ecc4..3436e418 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -25,21 +25,13 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // Assuming all parameters passed to the MultilabelModel constructor are // Serializable, MultilabelModel should also be Serializable. - final case class ConstantMultiLabelPredictor[K](returnVal: Map[K, Double]) - extends SparseMultiLabelPredictor[K] { - override def apply(v1: SparseFeatures, - v2: Labels[K], - v3: LabelIndices, - v4: SparseLabelDepFeatures): Map[K, Double] = returnVal - } - val model = MultilabelModel( ModelId(), sci.IndexedSeq(), sci.IndexedSeq[GenAggFunc[Int, Sparse]](), sci.IndexedSeq[Label](), None, - () => ConstantMultiLabelPredictor(Map[Label, Double]()), + Lazy(ConstantPredictor[Label]()), None, Auditor ) @@ -182,6 +174,18 @@ object MultilabelModelTest { private type Label = String private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() + case class ConstantPredictor[K](prediction: Double = 0d) extends SparseMultiLabelPredictor[K] + with Serializable { + override def apply(v1: SparseFeatures, + v2: Labels[K], + v3: LabelIndices, + v4: SparseLabelDepFeatures): Map[K, Double] = v2.map(_ -> prediction).toMap + } + + case class Lazy[A](value: A) extends (() => A) { + override def apply(): A = value + } + // TODO: Access information returned in audited value by using the following functions: // val aud: RootedTree[Any, Map[Label, Double]] = ??? // aud.modelId // : ModelIdentity @@ -189,4 +193,4 @@ object MultilabelModelTest { // aud.missingVarNames // : Set[String] // aud.errorMsgs // : Seq[String] // aud.prob // : Option[Float] (Shouldn't need this) -} \ No newline at end of file +} From 75dee229cecdbd7f36898a9d99f2553e98fe27f6 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 16:39:29 -0700 Subject: [PATCH 21/98] Hopefully code complete for the case class. --- .../models/multilabel/MultilabelModel.scala | 224 +++++++++++------- 1 file changed, 137 insertions(+), 87 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 737b7a38..99e11bb8 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -1,6 +1,6 @@ package com.eharmony.aloha.models.multilabel -import java.io.Closeable +import java.io.{Closeable, PrintWriter, StringWriter} import com.eharmony.aloha.audit.Auditor import com.eharmony.aloha.dataset.density.Sparse @@ -29,7 +29,7 @@ import scala.util.{Failure, Success, Try} * * Created by ryan.deak on 8/29/17. * - * @param modelId An identifier for the model. User in score and error reporting. + * @param modelId An identifier for the model. Used in score and error reporting. * @param featureNames feature names (parallel to featureFunctions) * @param featureFunctions feature extracting functions. * @param labelsInTrainingSet a sequence of all labels encountered during training. Note: the @@ -78,11 +78,18 @@ extends SubmodelBase[U, Map[K, Double], A, B] @transient private[this] lazy val predictor = predictorProducer() predictor // Force predictor eagerly + /** + * Cache this in case labelsOfInterest is None. In that case, we don't want to repeatedly + * create this because it could create a GC burden for no real reason. + */ + @transient private[this] lazy val defaultLabelInfo = + LabelsAndInfo(labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty, None) + private[this] val labelToInd: Map[K, Int] = labelsInTrainingSet.zipWithIndex.map { case (label, i) => label -> i }(collection.breakOut) override def subvalue(a: A): Subvalue[B, Map[K, Double]] = { - val li = labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) + val li = labelsAndInfo(a, labelsOfInterest, labelToInd, defaultLabelInfo) if (li.labels.isEmpty) reportNoPrediction(modelId, li, auditor) @@ -113,53 +120,80 @@ extends SubmodelBase[U, Map[K, Double], A, B] object MultilabelModel extends ParserProviderCompanion { /** - * - * @param indices - * @param labels - * @param missingLabels - * @param problems - * @tparam K + * Contains information about the labels to be used for predictions, and problems encountered + * while trying to get those labels. + * @param indices indices into the sequence of all labels seen during training. These should + * be sorted in ascending order. + * @param labels labels for which a prediction should be produced. labels are parallel to + * indices so `indices(i)` is the index associated with `labels(i)`. + * @param missingLabels a sequence of labels derived from the input data that could not be + * found in the sequence of all labels seen during training. + * @param problems any problems encountered when trying to get the labels. This should only + * be present when the caller indicates labels should be embedded in the + * input data passed to the prediction function in the MultilabelModel. + * @tparam K type of label or class */ protected[multilabel] case class LabelsAndInfo[K]( indices: sci.IndexedSeq[Int], labels: sci.IndexedSeq[K], missingLabels: Seq[K], problems: Option[GenAggFuncAccessorProblems] - ) + ) { + def missingVarNames: Seq[String] = problems.map(p => p.missing).getOrElse(Nil) + def errorMsgs: Seq[String] = { + missingLabels.map { lab => s"Label not in training labels: $lab" } ++ + problems.map(p => p.errors).getOrElse(Nil) + } + } + + private[multilabel] val NumLinesToKeepInStackTrace = 20 + + private[multilabel] val TooManyMissingError = + "Too many missing features encountered to produce prediction." + + private[multilabel] val NoLabelsError = "No labels found in training set." /** - * - * @param a - * @param labelsInTrainingSet - * @param labelsOfInterest - * @param labelToInd - * @tparam A - * @tparam K - * @return + * Get the labels and information about the labels. + * @param a an input from which label information should be derived if labelsOfInterest is not empty. + * @param labelsOfInterest an optional function used to extract label information from the input `a`. + * @param labelToInd a mapping from label to index into the sequence of all labels seen during training. + * @param defaultLabelInfo label information related to all labels seen at training time. If + * `labelsOfInterest` is not provided, this information will be used. + * @tparam A input type of the model + * @tparam K type of label or class + * @return labels and information about the labels. */ protected[multilabel] def labelsAndInfo[A, K]( a: A, - labelsInTrainingSet: sci.IndexedSeq[K], labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]], - labelToInd: Map[K, Int] - ): LabelsAndInfo[K] = { - // TODO: Is this good enough? Are we tracking enough missing information? Probably not. - labelsOfInterest.map ( labelFn => - labelsForPrediction(a, labelFn, labelToInd) - ) getOrElse { - LabelsAndInfo(labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty, None) - } - } + labelToInd: Map[K, Int], + defaultLabelInfo: LabelsAndInfo[K] + ): LabelsAndInfo[K] = + labelsOfInterest.fold(defaultLabelInfo)(f => labelsForPrediction(a, f, labelToInd)) + + /** + * Combine the missing variables found into a set. + * @param labelInfo labels and information about the labels. + * @param missing missing features from + * @tparam K type of label or class + * @return a set of missing features + */ + protected[multilabel] def combineMissing[K]( + labelInfo: LabelsAndInfo[K], + missing: scm.Map[String, Seq[String]] + ): Set[String] = missing.values.flatten.toSet ++ labelInfo.missingVarNames /** - * - * @param modelId - * @param missing - * @param auditor - * @tparam U - * @tparam K - * @tparam B - * @return + * Report that a prediction could not be made because too many missing features were encountered. + * @param modelId An identifier for the model. Used in error reporting. + * @param labelInfo labels and information about the labels. + * @param missing missing features from + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating failure. */ protected[multilabel] def reportTooManyMissing[U, K, B <: U]( modelId: ModelIdentity, @@ -167,42 +201,50 @@ object MultilabelModel extends ParserProviderCompanion { missing: scm.Map[String, Seq[String]], auditor: Auditor[U, Map[K, Double], B] ): Subvalue[B, Nothing] = { - // TODO: Fill in the errors. - val aud = auditor.failure(modelId, missingVarNames = missing.values.flatten.toSet) + + // TODO: Check that missing.values.flatten.toSet AND labelInfo.missingFeatures have the same format. + val aud = auditor.failure( + modelId, + errorMsgs = TooManyMissingError +: labelInfo.errorMsgs, + missingVarNames = combineMissing(labelInfo, missing) + ) Subvalue(aud, None) } /** - * - * @param modelId - * @param labelInfo - * @param auditor - * @tparam U - * @tparam K - * @tparam B - * @return + * Report that no prediction attempt was made because of issues with the labels. + * @param modelId An identifier for the model. Used in error reporting. + * @param labelInfo labels and information about the labels. + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating failure. */ protected[multilabel] def reportNoPrediction[U, K, B <: U]( modelId: ModelIdentity, labelInfo: LabelsAndInfo[K], auditor: Auditor[U, Map[K, Double], B] ): Subvalue[B, Nothing] = { - // TODO: Fill in the errors. - val aud = auditor.failure(modelId) + val aud = auditor.failure( + modelId, + errorMsgs = NoLabelsError +: labelInfo.errorMsgs, + missingVarNames = labelInfo.missingVarNames.toSet + ) Subvalue(aud, None) } /** - * - * @param modelId - * @param labelInfo - * @param missing - * @param prediction - * @param auditor - * @tparam U - * @tparam K - * @tparam B - * @return + * Report that the model succeeded. + * @param modelId An identifier for the model. Used in score reporting. + * @param labelInfo labels and information about the labels. + * @param missing missing features from + * @param prediction the prediction(s) made by the embedded predictor. + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating success. */ protected[multilabel] def reportSuccess[U, K, B <: U]( modelId: ModelIdentity, @@ -212,28 +254,27 @@ object MultilabelModel extends ParserProviderCompanion { auditor: Auditor[U, Map[K, Double], B] ): Subvalue[B, Map[K, Double]] = { - val errors = - if (labelInfo.missingLabels.nonEmpty) - Seq(s"Labels provide for which a prediction could not be produced: ${labelInfo.missingLabels.mkString(", ")}.") - else Seq.empty - - // TODO: Incorporate missing data reporting. - val aud: B = auditor.success(modelId, prediction, errorMsgs = errors) + val aud = auditor.success( + modelId, + prediction, + errorMsgs = labelInfo.errorMsgs, + missingVarNames = combineMissing(labelInfo, missing) + ) Subvalue(aud, Option(prediction)) } /** - * - * @param modelId - * @param labelInfo - * @param missing - * @param throwable - * @param auditor - * @tparam U - * @tparam K - * @tparam B - * @return + * Report that a `Throwable` was thrown while invoking the predictor + * @param modelId An identifier for the model. Used in error reporting. + * @param labelInfo labels and information about the labels. + * @param missing missing features from + * @param throwable the error the occurred in the predictor. + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating failure. */ protected[multilabel] def reportPredictorError[U, K, B <: U]( modelId: ModelIdentity, @@ -243,27 +284,36 @@ object MultilabelModel extends ParserProviderCompanion { auditor: Auditor[U, Map[K, Double], B] ): Subvalue[B, Nothing] = { - // TODO: Fill in. - val aud = auditor.failure(modelId) + val pw = new PrintWriter(new StringWriter) + throwable.printStackTrace(pw) + val stackTrace = pw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + + val aud = auditor.failure( + modelId, + errorMsgs = stackTrace +: labelInfo.errorMsgs, + missingVarNames = combineMissing(labelInfo, missing) + ) Subvalue(aud, None) } /** - * - * @param a - * @param labelsOfInterest - * @param labelToInd - * @tparam A - * @tparam K - * @return + * Get labels from the input for which a prediction should be produced. + * @param example the example provided to the model + * @param labelsOfInterest a function used to extract labels for which a + * prediction should be produced. + * @param labelToInd mapping from Label to index into the sequence of all + * labels seen in the training set. + * @tparam A input type of the model + * @tparam K type of label or class + * @return labels and information about the labels. */ protected[multilabel] def labelsForPrediction[A, K]( - a: A, + example: A, labelsOfInterest: GenAggFunc[A, sci.IndexedSeq[K]], labelToInd: Map[K, Int] ): LabelsAndInfo[K] = { - val labelsShouldPredict = labelsOfInterest(a) + val labelsShouldPredict = labelsOfInterest(example) val unsorted = for { @@ -273,7 +323,7 @@ object MultilabelModel extends ParserProviderCompanion { val problems = if (labelsShouldPredict.nonEmpty) None - else Option(labelsOfInterest.accessorOutputProblems(a)) + else Option(labelsOfInterest.accessorOutputProblems(example)) val noPrediction = if (unsorted.size == labelsShouldPredict.size) Seq.empty From 2e0f4eb2fbc44734a194bc51ef3a403c07f99bfb Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 1 Sep 2017 16:42:13 -0700 Subject: [PATCH 22/98] Added comment. --- .../com/eharmony/aloha/models/multilabel/MultilabelModel.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 99e11bb8..26326002 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -85,6 +85,9 @@ extends SubmodelBase[U, Map[K, Double], A, B] @transient private[this] lazy val defaultLabelInfo = LabelsAndInfo(labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty, None) + /** + * Making from label to index into the sequence of all labels encountered during training. + */ private[this] val labelToInd: Map[K, Int] = labelsInTrainingSet.zipWithIndex.map { case (label, i) => label -> i }(collection.breakOut) From 3b2652b5b518d238fa93862ce4caa0e4ed79b3af Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 5 Sep 2017 11:05:06 -0700 Subject: [PATCH 23/98] lessened privileges. --- .../scala/com/eharmony/aloha/models/multilabel/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index 583ad6e3..04412828 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -55,7 +55,7 @@ package object multilabel { * * @tparam K the type of labels (or classes in the machine learning literature). */ - private[multilabel] type SparseMultiLabelPredictor[K] = + type SparseMultiLabelPredictor[K] = (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double] /** From 999348b7eccba259660f35b630643fb3602e87cb Mon Sep 17 00:00:00 2001 From: amirziai Date: Tue, 5 Sep 2017 16:21:29 -0700 Subject: [PATCH 24/98] Adding more multi-label tests --- .../multilabel/MultilabelModelTest.scala | 139 ++++++++++++++---- 1 file changed, 109 insertions(+), 30 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 3436e418..05876f4a 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -1,22 +1,25 @@ package com.eharmony.aloha.models.multilabel +import java.io.{PrintWriter, StringWriter} + import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId import com.eharmony.aloha.semantics.func.GenAggFunc -import org.junit.Assert._ import org.junit.Test +import org.junit.Assert._ import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner -import scala.collection.{immutable => sci} +import scala.collection.{immutable => sci, mutable => scm} /** * Created by ryan.deak on 9/1/17. */ @RunWith(classOf[BlockJUnit4ClassRunner]) class MultilabelModelTest extends ModelSerializationTestHelper { + import MultilabelModel._ import MultilabelModelTest._ // TODO: Fill in the test implementation and delete comments once done. @@ -25,19 +28,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // Assuming all parameters passed to the MultilabelModel constructor are // Serializable, MultilabelModel should also be Serializable. - val model = MultilabelModel( - ModelId(), - sci.IndexedSeq(), - sci.IndexedSeq[GenAggFunc[Int, Sparse]](), - sci.IndexedSeq[Label](), - None, - Lazy(ConstantPredictor[Label]()), - None, - Auditor - ) - - val modelRoundTrip = serializeDeserializeRoundTrip(model) - assertEquals(model, modelRoundTrip) + val modelRoundTrip = serializeDeserializeRoundTrip(modelNoFeatures) + assertEquals(modelNoFeatures, modelRoundTrip) } // @Test def testModelCloseClosesPredictor(): Unit = { @@ -81,21 +73,89 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // fail() // } // -// @Test def testReportNoPrediction(): Unit = { -// // Make sure Subvalue.natural == None -// // Check the values of Subvalue.audited and make sure they are as expected. -// // Subvalue.audited.value should be None. Check the errors and missing values. -// -// fail() -// } -// -// @Test def testReportPredictorError(): Unit = { -// // Make sure Subvalue.natural == None -// // Check the values of Subvalue.audited and make sure they are as expected. -// // Subvalue.audited.value should be None. Check the errors and missing values. -// -// fail() -// } + @Test def testReportNoPrediction(): Unit = { + // Make sure Subvalue.natural == None + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be None. Check the errors and missing values. + + val labelInfo = LabelsAndInfo( + indices = sci.IndexedSeq[Int](), + labels = sci.IndexedSeq[Label](), + missingLabels = Seq[Label](), + problems = None + ) + + val report = reportNoPrediction( + ModelId(1, "a"), + labelInfo, + Auditor + ) + + // TODO: check labelInfo values + assertEquals(Vector(NoLabelsError), report.audited.errorMsgs.take(1)) + assertEquals(None, report.audited.value) + } + + @Test def testReportNoPredictionMissingLabelsDoNotExist(): Unit = { + // Make sure Subvalue.natural == None + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be None. Check the errors and missing values. + + val labelInfo = LabelsAndInfo( + indices = sci.IndexedSeq[Int](), + labels = sci.IndexedSeq[Label](), + missingLabels = missingLabels, + problems = None + ) + + val report = reportNoPrediction( + ModelId(1, "a"), + labelInfo, + Auditor + ) + + // TODO: look into problems + assertEquals(Vector(NoLabelsError) ++ errorMessages, report.audited.errorMsgs) + assertEquals(None, report.audited.value) + } + + + @Test def testReportPredictorError(): Unit = { + // Make sure Subvalue.natural == None + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be None. Check the errors and missing values. + + val labelInfo = LabelsAndInfo( + indices = sci.IndexedSeq[Int](), + labels = sci.IndexedSeq[Label](), + missingLabels = missingLabels, + problems = None + ) + + // TODO: is this a good example of a throwable? + val throwable = new RuntimeException("error") + val pw = new PrintWriter(new StringWriter) + throwable.printStackTrace(pw) + val stackTrace = pw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + + // TODO: understand why this is a map + val missing = scm.Map("x" -> missingLabels) + + val report = reportPredictorError( + ModelId(-1, "x"), + labelInfo, + missing, + throwable, + Auditor + ) + + // TODO: stack trace seems to be a reference and therefore not the same + // assertEquals(Vector(stackTrace) ++ errorMessages, report.audited.errorMsgs) + + assertEquals(missingLabels.toSet, report.audited.missingVarNames) + assertEquals(None, report.natural) + assertEquals(None, report.audited.value) + } // // @Test def testReportSuccess(): Unit = { // // Make sure Subvalue.natural == Some(value) @@ -186,6 +246,25 @@ object MultilabelModelTest { override def apply(): A = value } + val missingLabels: Seq[Label] = Seq("a", "b") + + val modelNoFeatures = MultilabelModel( + ModelId(), + sci.IndexedSeq(), + sci.IndexedSeq[GenAggFunc[Int, Sparse]](), + sci.IndexedSeq[Label](), + None, + Lazy(ConstantPredictor[Label]()), + None, + Auditor + ) + + val aud = RootedTreeAuditor[Any, Map[Label, Double]]() + // private val failure = aud.failure() + + val baseErrorMessage = Stream.continually("Label not in training labels: ") + val errorMessages = baseErrorMessage.zip(missingLabels).map{ case(msg, label) => s"$msg$label" } + // TODO: Access information returned in audited value by using the following functions: // val aud: RootedTree[Any, Map[Label, Double]] = ??? // aud.modelId // : ModelIdentity From a38b5eab2f0c41983d0cec1fa5c649717aaadf98 Mon Sep 17 00:00:00 2001 From: amirziai Date: Tue, 5 Sep 2017 17:46:03 -0700 Subject: [PATCH 25/98] Success report test case --- .../multilabel/MultilabelModelTest.scala | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 05876f4a..2bbf20ac 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -156,17 +156,36 @@ class MultilabelModelTest extends ModelSerializationTestHelper { assertEquals(None, report.natural) assertEquals(None, report.audited.value) } -// -// @Test def testReportSuccess(): Unit = { -// // Make sure Subvalue.natural == Some(value) -// // Check the values of Subvalue.audited and make sure they are as expected. -// // Subvalue.audited.value should be Some(value2). -// // 'value' should equal 'value2'. -// // Check the errors and missing values. -// -// fail() -// } -// + + @Test def testReportSuccess(): Unit = { + // Make sure Subvalue.natural == Some(value) + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be Some(value2). + // 'value' should equal 'value2'. + // Check the errors and missing values. + + // TODO: refactor this? + val labelInfo = LabelsAndInfo( + indices = sci.IndexedSeq[Int](), + labels = sci.IndexedSeq[Label](), + missingLabels = missingLabels, + problems = None + ) + val predictions = Map("label1" -> 1.0) + + val report = reportSuccess( + ModelId(0, "ModelId"), + labelInfo, + scm.Map("x" -> missingLabels), + predictions, + Auditor + ) + + assertEquals(Some(predictions), report.natural) + assertEquals(Some(predictions), report.audited.value) + assertEquals(report.natural, report.audited.value) + } + // @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { // // Test this: // // val problems = From a5c5259e02775abed0ed3513222b9fafd7e74c6b Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 5 Sep 2017 18:30:27 -0700 Subject: [PATCH 26/98] Addressing JMorra's PR comments. --- .../models/multilabel/MultilabelModel.scala | 11 ++++++++++- .../multilabel/SerializabilityEvidence.scala | 17 +++++++++++++++++ .../aloha/models/multilabel/package.scala | 2 +- 3 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 26326002..28000989 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -11,6 +11,7 @@ import com.eharmony.aloha.models.reg.RegressionFeatures import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.{GenAggFunc, GenAggFuncAccessorProblems} + import spray.json.{JsonFormat, JsonReader} import scala.collection.{immutable => sci, mutable => scm} @@ -52,6 +53,7 @@ import scala.util.{Failure, Success, Try} * return an error instead of the computed score. This is for missing * data situations. * @param auditor transforms a `Map[K, Double]` to a `B`. Reports successes and errors. + * @param ev evidence that `K` is serializable. * @tparam U upper bound on model output type `B` * @tparam K type of label or class * @tparam A input type of the model @@ -66,6 +68,7 @@ case class MultilabelModel[U, K, -A, +B <: U]( predictorProducer: SparsePredictorProducer[K], numMissingThreshold: Option[Int], auditor: Auditor[U, Map[K, Double], B]) +(implicit ev: SerializabilityEvidence[K]) extends SubmodelBase[U, Map[K, Double], A, B] with RegressionFeatures[A] { @@ -82,7 +85,9 @@ extends SubmodelBase[U, Map[K, Double], A, B] * Cache this in case labelsOfInterest is None. In that case, we don't want to repeatedly * create this because it could create a GC burden for no real reason. */ - @transient private[this] lazy val defaultLabelInfo = + private[this] val defaultLabelInfo = + // Hopefully, making this non-transient doesn't increase the serialized size much. + // labelsInTrainingSet isn't serialized twice is it? LabelsAndInfo(labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty, None) /** @@ -113,6 +118,10 @@ extends SubmodelBase[U, Map[K, Double], A, B] } } + /** + * When the `predictor` passed to the constructor is a java.io.Closeable, its `close` + * method is called. + */ override def close(): Unit = predictor match { case closeable: Closeable => closeable.close() diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala new file mode 100644 index 00000000..281bde1a --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala @@ -0,0 +1,17 @@ +package com.eharmony.aloha.models.multilabel + +/** + * A type class used to indicate a parameter has a type that can be serialized in + * a larger Serializable object. + */ +sealed trait SerializabilityEvidence[A] + +object SerializabilityEvidence { + + implicit def anyValEvidence[A <: AnyVal]: SerializabilityEvidence[A] = + new SerializabilityEvidence[A]{} + + implicit def serializableEvidence[A <: Serializable]: SerializabilityEvidence[A] = + new SerializabilityEvidence[A]{} +} + diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index 04412828..94ff90b3 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -9,7 +9,7 @@ import scala.collection.{immutable => sci} */ package object multilabel { - // All but the last type are package private, for testing. The last is public. + // All but the last two types are package private, for testing. The others are public. /** * Features about the input value (NOT including features based on labels). From 26576725ade8f20f90c53368e4f7f0125eed9096 Mon Sep 17 00:00:00 2001 From: amirziai Date: Wed, 6 Sep 2017 13:20:12 -0700 Subject: [PATCH 27/98] More tests --- .../models/multilabel/MultilabelModel.scala | 5 ++-- .../multilabel/MultilabelModelTest.scala | 25 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 26326002..2c0015dd 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -287,9 +287,10 @@ object MultilabelModel extends ParserProviderCompanion { auditor: Auditor[U, Map[K, Double], B] ): Subvalue[B, Nothing] = { - val pw = new PrintWriter(new StringWriter) + val sw = new StringWriter + val pw = new PrintWriter(sw) throwable.printStackTrace(pw) - val stackTrace = pw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") val aud = auditor.failure( modelId, diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 2bbf20ac..e03df5c4 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -13,6 +13,7 @@ import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.collection.{immutable => sci, mutable => scm} +import scala.util.Try /** * Created by ryan.deak on 9/1/17. @@ -132,27 +133,26 @@ class MultilabelModelTest extends ModelSerializationTestHelper { problems = None ) - // TODO: is this a good example of a throwable? - val throwable = new RuntimeException("error") - val pw = new PrintWriter(new StringWriter) + val throwable = Try(throw new Exception("error")).failed.get + val sw = new StringWriter + val pw = new PrintWriter(sw) throwable.printStackTrace(pw) - val stackTrace = pw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") - // TODO: understand why this is a map - val missing = scm.Map("x" -> missingLabels) + // This is missing variables for a features + val missingVariables = Seq("a", "b") + val missingFeatureMap = scm.Map("x" -> missingVariables) val report = reportPredictorError( ModelId(-1, "x"), labelInfo, - missing, + missingFeatureMap, throwable, Auditor ) - // TODO: stack trace seems to be a reference and therefore not the same - // assertEquals(Vector(stackTrace) ++ errorMessages, report.audited.errorMsgs) - - assertEquals(missingLabels.toSet, report.audited.missingVarNames) + assertEquals(Vector(stackTrace) ++ errorMessages, report.audited.errorMsgs) + assertEquals(missingVariables.toSet, report.audited.missingVarNames) assertEquals(None, report.natural) assertEquals(None, report.audited.value) } @@ -253,8 +253,7 @@ object MultilabelModelTest { private type Label = String private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() - case class ConstantPredictor[K](prediction: Double = 0d) extends SparseMultiLabelPredictor[K] - with Serializable { + case class ConstantPredictor[K](prediction: Double = 0d) extends SparseMultiLabelPredictor[K] { override def apply(v1: SparseFeatures, v2: Labels[K], v3: LabelIndices, From 71e21aac02b7383ec196540e564fc3d51c30e230 Mon Sep 17 00:00:00 2001 From: amirziai Date: Wed, 6 Sep 2017 14:16:17 -0700 Subject: [PATCH 28/98] java Serializable needs to be here --- .../aloha/models/multilabel/SerializabilityEvidence.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala index 281bde1a..920605ae 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala @@ -11,7 +11,6 @@ object SerializabilityEvidence { implicit def anyValEvidence[A <: AnyVal]: SerializabilityEvidence[A] = new SerializabilityEvidence[A]{} - implicit def serializableEvidence[A <: Serializable]: SerializabilityEvidence[A] = + implicit def javaSerializableEvidence[A <: java.io.Serializable]: SerializabilityEvidence[A] = new SerializabilityEvidence[A]{} } - From ed082289bdedeff2611a6a5dec21a350f100a33e Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 6 Sep 2017 17:12:38 -0700 Subject: [PATCH 29/98] Adding MultilabelModel parsing stuff, plugins, VW version, etc. --- .../eharmony/aloha/factory/ModelFactory.scala | 26 +--- .../models/multilabel/MultilabelModel.scala | 116 ++++++++++++++---- .../MultilabelModelParserPlugin.scala | 41 +++++++ .../MultilabelPluginProviderCompanion.scala | 8 ++ .../multilabel/json/MultilabelModelAst.scala | 11 ++ .../com/eharmony/aloha/reflect/RefInfo.scala | 51 ++++---- .../reflect/RuntimeClasspathScanning.scala | 66 ++++++++++ .../SerializabilityEvidence.scala | 4 +- .../multilabel/VwMultilabelModelPlugin.scala | 110 +++++++++++++++++ 9 files changed, 362 insertions(+), 71 deletions(-) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelPluginProviderCompanion.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala rename aloha-core/src/main/scala/com/eharmony/aloha/{models/multilabel => util}/SerializabilityEvidence.scala (73%) create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala index 516f9146..1e608a78 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala @@ -1,19 +1,17 @@ package com.eharmony.aloha.factory -import com.eharmony.aloha import com.eharmony.aloha.audit.{Auditor, MorphableAuditor} import com.eharmony.aloha.factory.ModelFactory.{InlineReader, ModelInlineReader, SubmodelInlineReader} import com.eharmony.aloha.factory.ex.{AlohaFactoryException, RecursiveModelDefinitionException} import com.eharmony.aloha.factory.jsext.JsValueExtensions import com.eharmony.aloha.factory.ri2jf.{RefInfoToJsonFormat, StdRefInfoToJsonFormat} -import com.eharmony.aloha.io.{GZippedReadable, LocationLoggingReadable, ReadableByString} import com.eharmony.aloha.io.multiple.{MultipleAlohaReadable, SequenceMultipleReadable} import com.eharmony.aloha.io.sources.ReadableSource +import com.eharmony.aloha.io.{GZippedReadable, LocationLoggingReadable, ReadableByString} import com.eharmony.aloha.models.{Model, Submodel} -import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps, RuntimeClasspathScanning} import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.util.Logging -import org.reflections.Reflections import spray.json.DefaultJsonProtocol.{StringJsonFormat, jsonFormat2, optionFormat} import spray.json.{CompactPrinter, JsObject, JsValue, JsonFormat, JsonReader, RootJsonFormat, pimpString} @@ -268,7 +266,7 @@ case class ModelFactoryImpl[U, N, A, B <: U]( } } -object ModelFactory { +object ModelFactory extends RuntimeClasspathScanning { /** Provides a default factory capable of producing models defined in aloha-core. The list of models come from * the knownModelParsers method. @@ -286,23 +284,7 @@ object ModelFactory { /** Get the list of models on the classpath with parsers that can be used by a model factory. * @return */ - def knownModelParsers(): Seq[ModelParser] = { - val reflections = new Reflections(aloha.pkgName) - import scala.collection.JavaConversions.asScalaSet - val parserProviderCompanions = reflections.getSubTypesOf(classOf[ParserProviderCompanion]).toSeq - - parserProviderCompanions.flatMap { - case ppc if ppc.getCanonicalName.endsWith("$") => - Try { - val c = Class.forName(ppc.getCanonicalName.dropRight(1)) - c.getMethod("parser").invoke(null) match { - case mp: ModelParser => mp - case _ => throw new IllegalStateException() - } - }.toOption - case _ => None - } - } + def knownModelParsers(): Seq[ModelParser] = scanObjects[ParserProviderCompanion, ModelParser]("parser") private[factory] sealed trait InlineReader[U, N, -A, +B <: U, Y] { def jsonReader(parser: ModelParser): Try[JsonReader[_ <: Y]] diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 28000989..c5088921 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -5,14 +5,15 @@ import java.io.{Closeable, PrintWriter, StringWriter} import com.eharmony.aloha.audit.Auditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.factory._ +import com.eharmony.aloha.factory.ri2jf.RefInfoToJsonFormat import com.eharmony.aloha.id.ModelIdentity import com.eharmony.aloha.models._ import com.eharmony.aloha.models.reg.RegressionFeatures import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.{GenAggFunc, GenAggFuncAccessorProblems} - -import spray.json.{JsonFormat, JsonReader} +import com.eharmony.aloha.util.SerializabilityEvidence +import spray.json.{JsValue, JsonFormat, JsonReader} import scala.collection.{immutable => sci, mutable => scm} import scala.util.{Failure, Success, Try} @@ -280,7 +281,7 @@ object MultilabelModel extends ParserProviderCompanion { * Report that a `Throwable` was thrown while invoking the predictor * @param modelId An identifier for the model. Used in error reporting. * @param labelInfo labels and information about the labels. - * @param missing missing features from + * @param missingFeatureMap missing features from RegressionFeatures * @param throwable the error the occurred in the predictor. * @param auditor an auditor used to audit the output. * @tparam U upper bound on model output type `B` @@ -291,19 +292,20 @@ object MultilabelModel extends ParserProviderCompanion { protected[multilabel] def reportPredictorError[U, K, B <: U]( modelId: ModelIdentity, labelInfo: LabelsAndInfo[K], - missing: scm.Map[String, Seq[String]], + missingFeatureMap: scm.Map[String, Seq[String]], throwable: Throwable, auditor: Auditor[U, Map[K, Double], B] ): Subvalue[B, Nothing] = { - val pw = new PrintWriter(new StringWriter) + val sw = new StringWriter + val pw = new PrintWriter(sw) throwable.printStackTrace(pw) - val stackTrace = pw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") val aud = auditor.failure( modelId, errorMsgs = stackTrace +: labelInfo.errorMsgs, - missingVarNames = combineMissing(labelInfo, missing) + missingVarNames = combineMissing(labelInfo, missingFeatureMap) ) Subvalue(aud, None) } @@ -346,6 +348,8 @@ object MultilabelModel extends ParserProviderCompanion { LabelsAndInfo(ind, lab, noPrediction, problems) } + private[multilabel] lazy val plugins: Map[String, MultilabelModelParserPlugin] = + MultilabelModelParserPlugin.plugins().map(p => p.name -> p).toMap override def parser: ModelParser = Parser @@ -371,26 +375,94 @@ object MultilabelModel extends ParserProviderCompanion { if (!RefInfoOps.isSubType[N, Map[_, Double]]) None else { - // Because N is a subtype of map, it "should" have two type parameters. - // This is obviously not true in all cases, like with LongMap - // http://scala-lang.org/files/archive/api/2.11.8/#scala.collection.immutable.LongMap - // TODO: Make this more robust. - val refInfoK = RefInfoOps.typeParams(r).head + val readerAttempt = + for { + refInfoK <- getRefInfoK[N, Any](r).right + jsonFormatK <- getJsonFormatK(factory, refInfoK).right + serEvK <- getSerializableEvidenceK(refInfoK).right + } yield { + MultiLabelModelReader( + refInfoK, + jsonFormatK, + serEvK, + semantics, + auditor, + plugins + ) + } + + readerAttempt match { + case Left(err) => + // TODO: Log err + None + case Right(reader) => Option(reader) + } + } + } + } - // To allow custom class (key) types, we'll need to create a custom ModelFactoryImpl instance - // with a specialized RefInfoToJsonFormat. - // - // type: Option[JsonFormat[_]] - val jsonFormatK = factory.jsonFormat(refInfoK) + def getSerializableEvidenceK[K](refInfoK: RefInfo[K]): Either[String, SerializabilityEvidence[K]] = { + val serEv = + if (RefInfoOps.isSubType(refInfoK, RefInfo.JavaSerializable)) + Option(SerializabilityEvidence.serializableEvidence[java.io.Serializable]) + else if (RefInfoOps.isSubType(refInfoK, RefInfo.AnyVal)) + Option(SerializabilityEvidence.anyValEvidence[AnyVal]) + else None - // TODO: parse the label extraction + val serEvK = serEv.asInstanceOf[Option[SerializabilityEvidence[K]]] - // TODO: parse the feature extraction + serEvK.toRight(s"Couldn't produce evidence that ${RefInfoOps.toString(refInfoK)} is Serializable.") + } - // TODO: parse the native submodel from the wrapped ML library. This involves plugins + def getRefInfoK[N, K](rin: RefInfo[N]): Either[String, RefInfo[K]] = { + // Because N is a subtype of map, it "should" have two type parameters. + // This is obviously not true in all cases, like with LongMap + // http://scala-lang.org/files/archive/api/2.11.8/#scala.collection.immutable.LongMap + // TODO: Make this more robust. - ??? - } + Option(RefInfoOps.typeParams(rin).head).asInstanceOf[Option[RefInfo[K]]] + .toRight(s"Couldn't extract key type from natural type: ${RefInfoOps.toString(rin)}") + } + + def getJsonFormatK[U, A, K](factory: SubmodelFactory[U, A], refInfoK: RefInfo[K]): Either[String, JsonFormat[K]] = { + // To allow custom class (key) types, we'll need to create a custom ModelFactoryImpl instance + // with a specialized RefInfoToJsonFormat. + factory.jsonFormat(refInfoK) + .toRight(s"Couldn't find a JSON Format for ${RefInfoOps.toString(refInfoK)}. Consider using a different ${classOf[RefInfoToJsonFormat].getCanonicalName}.") + } + + case class MultiLabelModelReader[U, N, K, A, B <: U]( + refInfoK: RefInfo[K], + jsonFormatK: JsonFormat[K], + serEvK: SerializabilityEvidence[K], + semantics: Semantics[A], + auditor: Auditor[U, N, B], + plugins: Map[String, MultilabelModelParserPlugin] + ) extends JsonReader[MultilabelModel[U, K, A, B]] { + override def read(json: JsValue): MultilabelModel[U, K, A, B] = { + + /* + val aud = auditor.asInstanceOf[Auditor[U, Double, B]] + + // Get the metadata necessary to create the model. + val d = json.convertTo[RegData] + + // Turn the map of features into a Seq to fix the order for all subsequent operations because they + // need a common understanding of the indices for the features. + val featureMap: Seq[(String, Spec)] = d.features.toSeq + val featureNameToIndex: Map[String, Int] = featureMap.map(_._1).zipWithIndex.toMap + + // This is the weight vector. + val beta = getBeta(d.features.size, d.weights, higherOrderFeatures(d, featureNameToIndex)) + + val (featureNames, featureFns) = features(featureMap, semantics).fold(f => throw new DeserializationException(f.mkString("\n")), identity).toIndexedSeq.unzip + + val m = RegressionModel(d.modelId, featureNames, featureFns, beta, identity, d.spline, d.numMissingThreshold, aud) + m + + */ + + ??? } } } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala new file mode 100644 index 00000000..9f8f7ac1 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala @@ -0,0 +1,41 @@ +package com.eharmony.aloha.models.multilabel + +import com.eharmony.aloha.models.multilabel.json.MultilabelModelAst +import com.eharmony.aloha.reflect.{RefInfo, RuntimeClasspathScanning} +import spray.json.{JsonFormat, JsonReader} + +/** + * A plugin that will produce the + * Created by ryan.deak on 9/6/17. + */ +trait MultilabelModelParserPlugin { + + /** + * A globally unique name for the plugin. If there are name collisions, + * [[MultilabelModelParserPlugin.plugins()]] will error out. + * @return the plugin name. + */ + def name: String + + /** + * Provide a JSON reader that can translate JSON ASTs to a `SparsePredictorProducer`. + * @param ast information about the multi-label model + * @param ri reflection information about the label type + * @param jf a JSON format representing a bidirectional mapping between instances of + * `K` and JSON ASTs. + * @tparam K the label or class type to be produced by the multi-label model. + * @return a JSON reader that can create `SparsePredictorProducer[K]` from JSON ASTs. + */ + def parser[K](ast: MultilabelModelAst) + (implicit ri: RefInfo[K], jf: JsonFormat[K]): JsonReader[SparsePredictorProducer[K]] +} + +object MultilabelModelParserPlugin extends RuntimeClasspathScanning { + + /** + * Finds the plugins in the `com.eharmony.aloha` namespace. + */ + // TODO: Consider making this implicit so that we can parameterize the parser implicitly on a sequence of plugins. + protected[multilabel] def plugins(): Seq[MultilabelModelParserPlugin] = + scanObjects[MultilabelPluginProviderCompanion, MultilabelModelParserPlugin]("multilabelPlugin") +} \ No newline at end of file diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelPluginProviderCompanion.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelPluginProviderCompanion.scala new file mode 100644 index 00000000..652590ec --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelPluginProviderCompanion.scala @@ -0,0 +1,8 @@ +package com.eharmony.aloha.models.multilabel + +/** + * Created by ryan.deak on 9/6/17. + */ +trait MultilabelPluginProviderCompanion { + def multilabelPlugin: MultilabelModelParserPlugin +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala new file mode 100644 index 00000000..84e44571 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala @@ -0,0 +1,11 @@ +package com.eharmony.aloha.models.multilabel.json + +import com.eharmony.aloha.models.reg.json.Spec + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 9/6/17. + */ +case class MultilabelModelAst(features: ListMap[String, Spec]) + diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala index 7290b1dc..703bead4 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala @@ -1,34 +1,35 @@ package com.eharmony.aloha.reflect -import java.{lang => jl} +import java.{lang => jl, io => ji} import deaktator.reflect.runtime.manifest.ManifestParser object RefInfo { - val Any: RefInfo[Any] = RefInfoOps.refInfo[Any] - val AnyRef: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] - val AnyVal: RefInfo[AnyVal] = RefInfoOps.refInfo[AnyVal] - val Boolean: RefInfo[Boolean] = RefInfoOps.refInfo[Boolean] - val Byte: RefInfo[Byte] = RefInfoOps.refInfo[Byte] - val Char: RefInfo[Char] = RefInfoOps.refInfo[Char] - val Double: RefInfo[Double] = RefInfoOps.refInfo[Double] - val Float: RefInfo[Float] = RefInfoOps.refInfo[Float] - val Int: RefInfo[Int] = RefInfoOps.refInfo[Int] - val JavaBoolean: RefInfo[jl.Boolean] = RefInfoOps.refInfo[jl.Boolean] - val JavaByte: RefInfo[jl.Byte] = RefInfoOps.refInfo[jl.Byte] - val JavaCharacter: RefInfo[jl.Character] = RefInfoOps.refInfo[jl.Character] - val JavaDouble: RefInfo[jl.Double] = RefInfoOps.refInfo[jl.Double] - val JavaFloat: RefInfo[jl.Float] = RefInfoOps.refInfo[jl.Float] - val JavaInteger: RefInfo[jl.Integer] = RefInfoOps.refInfo[jl.Integer] - val JavaLong: RefInfo[jl.Long] = RefInfoOps.refInfo[jl.Long] - val JavaShort: RefInfo[jl.Short] = RefInfoOps.refInfo[jl.Short] - val Long: RefInfo[Long] = RefInfoOps.refInfo[Long] - val Nothing: RefInfo[Nothing] = RefInfoOps.refInfo[Nothing] - val Null: RefInfo[Null] = RefInfoOps.refInfo[Null] - val Object: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] - val Short: RefInfo[Short] = RefInfoOps.refInfo[Short] - val Unit: RefInfo[Unit] = RefInfoOps.refInfo[Unit] - val String: RefInfo[String] = RefInfoOps.refInfo[String] + val Any: RefInfo[Any] = RefInfoOps.refInfo[Any] + val AnyRef: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] + val AnyVal: RefInfo[AnyVal] = RefInfoOps.refInfo[AnyVal] + val Boolean: RefInfo[Boolean] = RefInfoOps.refInfo[Boolean] + val Byte: RefInfo[Byte] = RefInfoOps.refInfo[Byte] + val Char: RefInfo[Char] = RefInfoOps.refInfo[Char] + val Double: RefInfo[Double] = RefInfoOps.refInfo[Double] + val Float: RefInfo[Float] = RefInfoOps.refInfo[Float] + val Int: RefInfo[Int] = RefInfoOps.refInfo[Int] + val JavaBoolean: RefInfo[jl.Boolean] = RefInfoOps.refInfo[jl.Boolean] + val JavaByte: RefInfo[jl.Byte] = RefInfoOps.refInfo[jl.Byte] + val JavaCharacter: RefInfo[jl.Character] = RefInfoOps.refInfo[jl.Character] + val JavaDouble: RefInfo[jl.Double] = RefInfoOps.refInfo[jl.Double] + val JavaFloat: RefInfo[jl.Float] = RefInfoOps.refInfo[jl.Float] + val JavaInteger: RefInfo[jl.Integer] = RefInfoOps.refInfo[jl.Integer] + val JavaLong: RefInfo[jl.Long] = RefInfoOps.refInfo[jl.Long] + val JavaShort: RefInfo[jl.Short] = RefInfoOps.refInfo[jl.Short] + val JavaSerializable: RefInfo[ji.Serializable] = RefInfoOps.refInfo[ji.Serializable] + val Long: RefInfo[Long] = RefInfoOps.refInfo[Long] + val Nothing: RefInfo[Nothing] = RefInfoOps.refInfo[Nothing] + val Null: RefInfo[Null] = RefInfoOps.refInfo[Null] + val Object: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] + val Short: RefInfo[Short] = RefInfoOps.refInfo[Short] + val Unit: RefInfo[Unit] = RefInfoOps.refInfo[Unit] + val String: RefInfo[String] = RefInfoOps.refInfo[String] def apply[A: RefInfo]: RefInfo[A] = RefInfoOps.refInfo[A] diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala new file mode 100644 index 00000000..43520318 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala @@ -0,0 +1,66 @@ +package com.eharmony.aloha.reflect + +import com.eharmony.aloha +import org.reflections.Reflections + +import scala.reflect.ClassTag +import scala.util.Try + +/** + * Created by ryan.deak on 9/6/17. + */ +trait RuntimeClasspathScanning { + + private[this] val objectSuffix = "$" + + /** + * Determine if the class is a Scala object by looking at the class name. If it + * ends in the `objectSuffix` + * @param c a `Class` instance. + * @tparam A type of Class. + * @return + */ + private[this] def isObject[A](c: Class[A]) = c.getCanonicalName.endsWith(objectSuffix) + + /** + * Scan the classpath within subpackages of `packageToSearch` to find Scala objects that + * contain extend `OBJ` and contain method called `methodName` that should return an `A`. + * Call the method and get the result. If the result is indeed an `A`, add to it to + * the resulting sequence. + * + * This function makes some assumptions. It assumes that Scala objects have classes + * ending with '$'. It also assumes that methods in objects have static forwarders. + * + * @param methodName the name of a method in an object that should be found and called. + * @param packageToSearch the package to search for candidates + * @tparam OBJ the super type of the objects the search should find + * @tparam A the output type of elements that should be returned by the method named + * `methodName` + * @return a sequence of `A` instances that could be found. + */ + protected[this] def scanObjects[OBJ: ClassTag, A: ClassTag]( + methodName: String, + packageToSearch: String = aloha.pkgName + ): Seq[A] = { + val reflections = new Reflections(aloha.pkgName) + import scala.collection.JavaConversions.asScalaSet + // val classA = implicitly[ClassTag[A]].runtimeClass + val objects = reflections.getSubTypesOf(implicitly[ClassTag[OBJ]].runtimeClass).toSeq + + val suffixLength = objectSuffix.length + + objects.flatMap { + case o if isObject(o) => + Try { + // This may have some classloading issues. + val classObj = Class.forName(o.getCanonicalName.dropRight(suffixLength)) + classObj.getMethod(methodName).invoke(null) match { + case a: A => a + case _ => throw new IllegalStateException() + } + }.toOption + case _ => None + } + } +} + diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala similarity index 73% rename from aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala rename to aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala index 281bde1a..2ac9d561 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/SerializabilityEvidence.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala @@ -1,4 +1,4 @@ -package com.eharmony.aloha.models.multilabel +package com.eharmony.aloha.util /** * A type class used to indicate a parameter has a type that can be serialized in @@ -11,7 +11,7 @@ object SerializabilityEvidence { implicit def anyValEvidence[A <: AnyVal]: SerializabilityEvidence[A] = new SerializabilityEvidence[A]{} - implicit def serializableEvidence[A <: Serializable]: SerializabilityEvidence[A] = + implicit def serializableEvidence[A <: java.io.Serializable]: SerializabilityEvidence[A] = new SerializabilityEvidence[A]{} } diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala new file mode 100644 index 00000000..81734f26 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala @@ -0,0 +1,110 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.Closeable + +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.io.sources.ModelSource +import com.eharmony.aloha.models.multilabel._ +import com.eharmony.aloha.models.multilabel.json.MultilabelModelAst +import com.eharmony.aloha.reflect.RefInfo +import spray.json.{JsValue, JsonFormat, JsonReader} +import spray.json.DefaultJsonProtocol._ + +import scala.collection.immutable.ListMap +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 9/5/17. + */ +case class VwMultilabelModelPlugin[K]( + modelSource: ModelSource, + params: String, + defaultNs: List[Int], + namespaces: List[(String, List[Int])]) +extends SparsePredictorProducer[K] { + + import VwMultilabelModelPlugin._ + + override def apply(): VwSparsePredictorProducer[K] = + VwSparsePredictorProducer[K](modelSource, params, defaultNs, namespaces) +} + +object VwMultilabelModelPlugin extends MultilabelPluginProviderCompanion { + + private val JsonErrStrLength = 100 + + def multilabelPlugin = Plugin + object Plugin extends MultilabelModelParserPlugin { + override def name: String = "vw" + override def parser[K](ast: MultilabelModelAst) + (implicit ri: RefInfo[K], jf: JsonFormat[K]): VwMultilabelModelPluginJsonReader[K] = { + VwMultilabelModelPluginJsonReader[K](ast.features.keys.toVector) + } + } + + + private[multilabel] case class VwMultilabelAst( + modelSource: ModelSource, + params: Option[Either[Seq[String], String]] = Option(Right("")), + namespaces: Option[ListMap[String, Seq[String]]] = Some(ListMap.empty) + ) + + private[this] implicit val vwMultilabelAstFormat = jsonFormat3(VwMultilabelAst) + + case class VwSparsePredictorProducer[K]( + modelSource: ModelSource, + params: String, + defaultNs: List[Int], + namespaces: List[(String, List[Int])]) + extends SparseMultiLabelPredictor[K] + with Closeable { + + // TODO: Create the model. + + override def apply( + features: IndexedSeq[Sparse], + labels: sci.IndexedSeq[K], + indices: sci.IndexedSeq[Int], + labelDependentFeatures: sci.IndexedSeq[IndexedSeq[Sparse]] + ): Map[K, Double] = { + + // Create VW input. + + ??? + } + + override def close(): Unit = { + // TODO: Add actual + } + } + + case class VwMultilabelModelPluginJsonReader[K](featureNames: Vector[String]) extends JsonReader[VwMultilabelModelPlugin[K]] { + + private[multilabel] def namespaceMapping(featureNames: Vector[String], nsMap: ListMap[String, Seq[String]]): (List[Int], List[(String, List[Int])]) = { + + ??? + } + + override def read(json: JsValue): VwMultilabelModelPlugin[K] = { + val ast = json.asJsObject(notObjErr(json)).convertTo[VwMultilabelAst] + + val params = ast.params.map { + case Left(paramList) => paramList.mkString(" ") + case Right(paramsStr) => paramsStr + } + + params.map { ps => + val (defaultNs, namespaces) = namespaceMapping(featureNames, ast.namespaces.getOrElse(ListMap.empty)) + VwMultilabelModelPlugin[K](ast.modelSource, ps, defaultNs, namespaces) + } getOrElse { + throw new Exception("no VW params provided") + } + } + + protected[multilabel] def notObjErr(json: JsValue): String = { + val str = json.prettyPrint + val substr = str.substring(0, JsonErrStrLength) + s"JSON object expected. Found " + substr + (if (str.length != substr.length) " ..." else "") + } + } +} \ No newline at end of file From e450ce10e7188f19e5d9c67e4a613a9a93a78b50 Mon Sep 17 00:00:00 2001 From: amirziai Date: Thu, 7 Sep 2017 12:09:12 -0700 Subject: [PATCH 30/98] Empty label problems test --- .../multilabel/MultilabelModelTest.scala | 70 +++++++++++++------ .../reg/PolynomialEvaluationAlgoTest.scala | 1 - 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index e03df5c4..754a31d4 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -6,7 +6,7 @@ import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId -import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.semantics.func._ import org.junit.Test import org.junit.Assert._ import org.junit.runner.RunWith @@ -186,23 +186,49 @@ class MultilabelModelTest extends ModelSerializationTestHelper { assertEquals(report.natural, report.audited.value) } -// @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { -// // Test this: -// // val problems = -// // if (labelsShouldPredict.nonEmpty) None -// // else Option(labelsOfInterest.accessorOutputProblems(a)) -// -// fail() -// } -// -// @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { -// // Test this: -// // val noPrediction = -// // if (unsorted.size == labelsShouldPredict.size) Seq.empty -// // else labelsShouldPredict.filterNot(labelToInd.contains) -// -// fail() -// } + @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { + def extractLabelsOutOfExample(example: Map[String, String]) = + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq + + // Example with no problems + val example: Map[String, String] = Map( + "feature1" -> "1", + "feature2" -> "2", + "feature3" -> "2", + "label1" -> "a", + "label2" -> "b" + ) + val allLabels = sci.IndexedSeq("a", "b", "c") + val labelToInt = allLabels.zipWithIndex.toMap + val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + val labelsAndInfo = labelsForPrediction(example, labelsOfInterestExtractor, labelToInt) + assertEquals(None, labelsAndInfo.problems) + + // Example with 1 missing label + val exampleMissingOneLabel = Map("feature1" -> "1", "label1" -> "a") + val labelsAndInfoMissingOneLabel = labelsForPrediction( + exampleMissingOneLabel, + labelsOfInterestExtractor, + labelToInt) + assertEquals(None, labelsAndInfoMissingOneLabel.problems) + + // Example with no labels + val exampleNoLabels = Map("feature1" -> "1", "feature2" -> "2") + val labelsAndInfoNoLabels = labelsForPrediction(exampleNoLabels, + labelsOfInterestExtractor, + labelToInt) + val problemsNoLabels = Option(GenAggFuncAccessorProblems(Seq(), Seq())) + assertEquals(problemsNoLabels, labelsAndInfoNoLabels.problems) + } + + @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { + // Test this: + // val noPrediction = + // if (unsorted.size == labelsShouldPredict.size) Seq.empty + // else labelsShouldPredict.filterNot(labelToInd.contains) + + fail() + } // // @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { // // Test this: @@ -277,11 +303,13 @@ object MultilabelModelTest { Auditor ) - val aud = RootedTreeAuditor[Any, Map[Label, Double]]() + val aud: RootedTreeAuditor[Any, Map[Label, Double]] = RootedTreeAuditor[Any, Map[Label, Double]]() // private val failure = aud.failure() - val baseErrorMessage = Stream.continually("Label not in training labels: ") - val errorMessages = baseErrorMessage.zip(missingLabels).map{ case(msg, label) => s"$msg$label" } + val baseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") + val errorMessages: Seq[String] = baseErrorMessage.zip(missingLabels).map { + case(msg, label) => s"$msg$label" + } // TODO: Access information returned in audited value by using the following functions: // val aud: RootedTree[Any, Map[Label, Double]] = ??? diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala index 0acb71ea..41601246 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala @@ -30,7 +30,6 @@ class PolynomialEvaluationAlgoTest { private[this] val factory = ModelFactory.defaultFactory(semantics, OptionAuditor[Double]()) - @Test def testManualPolyEval() { val x = IndexedSeq( Seq(("intercept", 1.0)), From 77d57e0c3ec0c756a23c8d2470e7c2ecfeece39f Mon Sep 17 00:00:00 2001 From: amirziai Date: Thu, 7 Sep 2017 12:10:52 -0700 Subject: [PATCH 31/98] More explicit val name --- .../aloha/models/multilabel/MultilabelModelTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 754a31d4..55536b1c 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -217,8 +217,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val labelsAndInfoNoLabels = labelsForPrediction(exampleNoLabels, labelsOfInterestExtractor, labelToInt) - val problemsNoLabels = Option(GenAggFuncAccessorProblems(Seq(), Seq())) - assertEquals(problemsNoLabels, labelsAndInfoNoLabels.problems) + val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq())) + assertEquals(problemsNoLabelsExpected, labelsAndInfoNoLabels.problems) } @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { From 6dc5636a34668ccde57c3504c86c8d186cb57773 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 7 Sep 2017 17:42:37 -0700 Subject: [PATCH 32/98] MultilabelModel parsing is compiling. --- .../models/multilabel/MultilabelModel.scala | 153 ++++++++++-------- .../MultilabelModelParserPlugin.scala | 5 +- .../aloha/models/multilabel/PluginInfo.scala | 12 ++ .../multilabel/json/MultilabelModelAst.scala | 11 -- .../multilabel/json/MultilabelModelJson.scala | 43 +++++ .../json/MultilabelModelReader.scala | 111 +++++++++++++ 6 files changed, 257 insertions(+), 78 deletions(-) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala delete mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index c5088921..dba57ec5 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -8,12 +8,13 @@ import com.eharmony.aloha.factory._ import com.eharmony.aloha.factory.ri2jf.RefInfoToJsonFormat import com.eharmony.aloha.id.ModelIdentity import com.eharmony.aloha.models._ +import com.eharmony.aloha.models.multilabel.json.MultilabelModelReader import com.eharmony.aloha.models.reg.RegressionFeatures import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.{GenAggFunc, GenAggFuncAccessorProblems} -import com.eharmony.aloha.util.SerializabilityEvidence -import spray.json.{JsValue, JsonFormat, JsonReader} +import com.eharmony.aloha.util.{Logging, SerializabilityEvidence} +import spray.json.{JsonFormat, JsonReader} import scala.collection.{immutable => sci, mutable => scm} import scala.util.{Failure, Success, Try} @@ -353,7 +354,7 @@ object MultilabelModel extends ParserProviderCompanion { override def parser: ModelParser = Parser - object Parser extends ModelSubmodelParsingPlugin { + object Parser extends ModelSubmodelParsingPlugin with Logging { override val modelType: String = "multilabel-sparse" // TODO: Figure if a Option[JsonReader[MultilabelModel[U, _, A, B]]] can be returned. @@ -371,29 +372,34 @@ object MultilabelModel extends ParserProviderCompanion { auditor: Auditor[U, N, B])(implicit r: RefInfo[N], jf: JsonFormat[N] - ): Option[JsonReader[_ <: Model[A, B] with Submodel[_, A, B]]] = { - if (!RefInfoOps.isSubType[N, Map[_, Double]]) + ): Option[JsonReader[Model[A, B] with Submodel[N, A, B]]] = { + + // If the type N is a Map with Double values, we can specify a key type (call it K). + // Then the necessary type classes related to K are *instantiated* and a reader is + // created using types N and K. Ultimately, the reader consumes the type `K` but + // only the type N is exposed in the returned reader. + if (!RefInfoOps.isSubType[N, Map[_, Double]]) { + warn(s"N=${RefInfoOps.toString[N]} is not a Map[K, Double]. Cannot create a JsonReader for MultilabelModel.") None + } + else if (2 != RefInfoOps.typeParams(r).size) { + warn(s"N=${RefInfoOps.toString[N]} does not have 2 type parameters. Cannot infer label type K needed create a JsonReader for MultilabelModel.") + None + } else { + // This would be a prime candidate for the WriterT monad transformer: + // type Result[A] = WriterT[Option, Vector[String], A] + // https://github.com/typelevel/cats/blob/0.7.x/core/src/main/scala/cats/data/WriterT.scala#L8 val readerAttempt = for { - refInfoK <- getRefInfoK[N, Any](r).right - jsonFormatK <- getJsonFormatK(factory, refInfoK).right - serEvK <- getSerializableEvidenceK(refInfoK).right - } yield { - MultiLabelModelReader( - refInfoK, - jsonFormatK, - serEvK, - semantics, - auditor, - plugins - ) - } + ri <- refInfoOrError[N, Any](r).right // Force type of K = Any + jf <- jsonFormatOrError(factory, ri).right + se <- serializableEvidenceOrError(ri).right + } yield reader(semantics, auditor, ri, jf, se) readerAttempt match { case Left(err) => - // TODO: Log err + warn(err) None case Right(reader) => Option(reader) } @@ -401,7 +407,49 @@ object MultilabelModel extends ParserProviderCompanion { } } - def getSerializableEvidenceK[K](refInfoK: RefInfo[K]): Either[String, SerializabilityEvidence[K]] = { + /** + * Produce the reader. + * + * '''NOTE''': This function should only be applied after we know `N` equals `Map[K, Double]` + * for some `K`. + * + * @param semantics semantics to use for compiling specifications. + * @param auditor an auditor of type `Auditor[U, N, B]`. It's not of type + * `Auditor[U, Map[K, Double], B]`, but because we know the relationship + * between `N` and `Map[K, Double]`, we can cast `N` to `Map[K, Double]`. + * @param ri reflection information about `K`. + * @param jf a JSON format that can translate JSON ASTs to and from `K`s. + * @param se evidence that `K` is `Serializable`. + * @tparam U upper bound on model output type `B` + * @tparam N the expected natural output type the model that will be produced by the + * JSON reader. This should be isomorphic to `Map[K, Double]`. + * @tparam K type of label or class + * @tparam A input type of the model. + * @tparam B output type of the model. + * @return a JSON reader capable of producing a [[MultilabelModel]] from a JSON definition. + */ + private[multilabel] def reader[U, N, K, A, B <: U]( + semantics: Semantics[A], + auditor: Auditor[U, N, B], + ri: RefInfo[K], + jf: JsonFormat[K], + se: SerializabilityEvidence[K]): JsonReader[Model[A, B] with Submodel[N, A, B]] = { + // At this point, N = Map[K, Double], so we are just casting to itself essentially. + val aud = auditor.asInstanceOf[Auditor[U, Map[K, Double], B]] + + // Create a cast from Map[K, Double] to N. Ideally, this would be a Map[K, Double] <:< N + // rather than a Map[K, Double] => N. But it's hard (with good reason) to create a <:< + // and easy to create a function. + implicit val cast = (m: Map[K, Double]) => m.asInstanceOf[N] + + // MultilabelModelReader can produce a MultilabelModel. But there's a problem returning + // the proper type because the compiler doesn't have compile-time evidence that N is + // Map[K, Double] so, a less specific type (Model[A, B] with Submodel[N, A, B]) is + // returned. + MultilabelModelReader(semantics, aud, plugins)(ri, jf, se).untypedReader[N] + } + + private[multilabel] def serializableEvidence[K](refInfoK: RefInfo[K]) = { val serEv = if (RefInfoOps.isSubType(refInfoK, RefInfo.JavaSerializable)) Option(SerializabilityEvidence.serializableEvidence[java.io.Serializable]) @@ -409,60 +457,37 @@ object MultilabelModel extends ParserProviderCompanion { Option(SerializabilityEvidence.anyValEvidence[AnyVal]) else None - val serEvK = serEv.asInstanceOf[Option[SerializabilityEvidence[K]]] + serEv.asInstanceOf[Option[SerializabilityEvidence[K]]] + } - serEvK.toRight(s"Couldn't produce evidence that ${RefInfoOps.toString(refInfoK)} is Serializable.") + private[multilabel] def serializableEvidenceOrError[K](refInfoK: RefInfo[K]) = { + serializableEvidence(refInfoK) + .toRight(s"Couldn't produce evidence that ${RefInfoOps.toString(refInfoK)} is Serializable.") } - def getRefInfoK[N, K](rin: RefInfo[N]): Either[String, RefInfo[K]] = { - // Because N is a subtype of map, it "should" have two type parameters. - // This is obviously not true in all cases, like with LongMap - // http://scala-lang.org/files/archive/api/2.11.8/#scala.collection.immutable.LongMap - // TODO: Make this more robust. + /** + * Get reflection information about the label type `K` for the [[MultilabelModel]] to be produced. + * + * At the time of application, we should already know N is a subtype of Map with 2 type parameters. + * This will preclude things like + * [[http://scala-lang.org/files/archive/api/2.11.8/#scala.collection.immutable.LongMap scala.collection.immutable.LongMap]]. + * @param rin reflection information about `N`. + * @tparam N natural output type of the model. `N` should equal `Map[K, Double]`. + * @tparam K label type of the [[MultilabelModel]] + * @return reflection information about K. + */ + private[multilabel] def refInfo[N, K](rin: RefInfo[N]) = + RefInfoOps.typeParams(rin).headOption.asInstanceOf[Option[RefInfo[K]]] - Option(RefInfoOps.typeParams(rin).head).asInstanceOf[Option[RefInfo[K]]] + private[multilabel] def refInfoOrError[N, K](rin: RefInfo[N]) = { + refInfo[N, K](rin) .toRight(s"Couldn't extract key type from natural type: ${RefInfoOps.toString(rin)}") } - def getJsonFormatK[U, A, K](factory: SubmodelFactory[U, A], refInfoK: RefInfo[K]): Either[String, JsonFormat[K]] = { + private[multilabel] def jsonFormatOrError[U, A, K](factory: SubmodelFactory[U, A], refInfoK: RefInfo[K]): Either[String, JsonFormat[K]] = { // To allow custom class (key) types, we'll need to create a custom ModelFactoryImpl instance // with a specialized RefInfoToJsonFormat. factory.jsonFormat(refInfoK) .toRight(s"Couldn't find a JSON Format for ${RefInfoOps.toString(refInfoK)}. Consider using a different ${classOf[RefInfoToJsonFormat].getCanonicalName}.") } - - case class MultiLabelModelReader[U, N, K, A, B <: U]( - refInfoK: RefInfo[K], - jsonFormatK: JsonFormat[K], - serEvK: SerializabilityEvidence[K], - semantics: Semantics[A], - auditor: Auditor[U, N, B], - plugins: Map[String, MultilabelModelParserPlugin] - ) extends JsonReader[MultilabelModel[U, K, A, B]] { - override def read(json: JsValue): MultilabelModel[U, K, A, B] = { - - /* - val aud = auditor.asInstanceOf[Auditor[U, Double, B]] - - // Get the metadata necessary to create the model. - val d = json.convertTo[RegData] - - // Turn the map of features into a Seq to fix the order for all subsequent operations because they - // need a common understanding of the indices for the features. - val featureMap: Seq[(String, Spec)] = d.features.toSeq - val featureNameToIndex: Map[String, Int] = featureMap.map(_._1).zipWithIndex.toMap - - // This is the weight vector. - val beta = getBeta(d.features.size, d.weights, higherOrderFeatures(d, featureNameToIndex)) - - val (featureNames, featureFns) = features(featureMap, semantics).fold(f => throw new DeserializationException(f.mkString("\n")), identity).toIndexedSeq.unzip - - val m = RegressionModel(d.modelId, featureNames, featureFns, beta, identity, d.spline, d.numMissingThreshold, aud) - m - - */ - - ??? - } - } } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala index 9f8f7ac1..bb5b586f 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala @@ -1,6 +1,5 @@ package com.eharmony.aloha.models.multilabel -import com.eharmony.aloha.models.multilabel.json.MultilabelModelAst import com.eharmony.aloha.reflect.{RefInfo, RuntimeClasspathScanning} import spray.json.{JsonFormat, JsonReader} @@ -19,14 +18,14 @@ trait MultilabelModelParserPlugin { /** * Provide a JSON reader that can translate JSON ASTs to a `SparsePredictorProducer`. - * @param ast information about the multi-label model + * @param info information about the multi-label model passed to the plugin. * @param ri reflection information about the label type * @param jf a JSON format representing a bidirectional mapping between instances of * `K` and JSON ASTs. * @tparam K the label or class type to be produced by the multi-label model. * @return a JSON reader that can create `SparsePredictorProducer[K]` from JSON ASTs. */ - def parser[K](ast: MultilabelModelAst) + def parser[K](info: PluginInfo) (implicit ri: RefInfo[K], jf: JsonFormat[K]): JsonReader[SparsePredictorProducer[K]] } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala new file mode 100644 index 00000000..1fe8d0f3 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala @@ -0,0 +1,12 @@ +package com.eharmony.aloha.models.multilabel + +import com.eharmony.aloha.models.reg.json.Spec + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 9/7/17. + */ +trait PluginInfo { + def features: ListMap[String, Spec] +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala deleted file mode 100644 index 84e44571..00000000 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelAst.scala +++ /dev/null @@ -1,11 +0,0 @@ -package com.eharmony.aloha.models.multilabel.json - -import com.eharmony.aloha.models.reg.json.Spec - -import scala.collection.immutable.ListMap - -/** - * Created by ryan.deak on 9/6/17. - */ -case class MultilabelModelAst(features: ListMap[String, Spec]) - diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala new file mode 100644 index 00000000..c0e4dadf --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala @@ -0,0 +1,43 @@ +package com.eharmony.aloha.models.multilabel.json + +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.models.multilabel.PluginInfo +import com.eharmony.aloha.models.reg.json.{Spec, SpecJson} +import spray.json.DefaultJsonProtocol._ +import spray.json.{JsObject, JsonFormat, RootJsonFormat} + +import scala.collection.immutable.ListMap +import com.eharmony.aloha.factory.ScalaJsonFormats + +trait MultilabelModelJson extends SpecJson with ScalaJsonFormats { + + protected[this] case class Plugin(`type`: String) + + /** + * Data for the + * + * @param modelType + * @param modelId + * @param features + * @param numMissingThreshold + * @param labelsInTrainingSet + * @param labelsOfInterest + * @param underlying the underlying model that will be produced by a + * @tparam K + */ + protected[this] case class MultilabelData[K]( + modelType: String, + modelId: ModelId, + features: ListMap[String, Spec], + numMissingThreshold: Option[Int], + labelsInTrainingSet: Vector[K], + labelsOfInterest: Option[String], + underlying: JsObject + ) extends PluginInfo + + protected[this] final implicit def multilabelDataJsonFormat[K: JsonFormat]: RootJsonFormat[MultilabelData[K]] = + jsonFormat7(MultilabelData[K]) + + protected[this] final implicit val pluginJsonFormat: RootJsonFormat[Plugin] = + jsonFormat1(Plugin) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala new file mode 100644 index 00000000..fc94e039 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala @@ -0,0 +1,111 @@ +package com.eharmony.aloha.models.multilabel.json + +import com.eharmony.aloha.audit.Auditor +import com.eharmony.aloha.models.{Model, Submodel} +import com.eharmony.aloha.models.multilabel.{MultilabelModel, MultilabelModelParserPlugin} +import com.eharmony.aloha.models.reg.RegFeatureCompiler +import com.eharmony.aloha.models.reg.json.Spec +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.Semantics +import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.util.{EitherHelpers, SerializabilityEvidence} +import spray.json.{DeserializationException, JsValue, JsonFormat, JsonReader} + +import scala.collection.{immutable => sci} + +/** + * A JSON reader capable of turning JSON to a [[MultilabelModel]]. + * Created by ryan.deak on 9/7/17. + * @param semantics semantics used to generate features and labels of interest. + * @param auditor the auditor used to translate the [[MultilabelModel]] output values to `B` instances. + * @param plugins the possible plugins to which [[MultilabelModel]] can delegate to produce predictions. + * Plugins are responsible for creating the `predictorProducer` passed to the + * [[MultilabelModel]] constructor. + * @param refInfoK reflection information about the label type. + * @param jsonFormatK a JSON format capable of parsing the label type. + * @param serEvK evidence that the label type is `Serializable`. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam A input type of the model + * @tparam B output type of the model. + */ +final case class MultilabelModelReader[U, K, A, B <: U]( + semantics: Semantics[A], + auditor: Auditor[U, Map[K, Double], B], + plugins: Map[String, MultilabelModelParserPlugin]) +(implicit + refInfoK: RefInfo[K], + jsonFormatK: JsonFormat[K], + serEvK: SerializabilityEvidence[K]) + extends JsonReader[MultilabelModel[U, K, A, B]] + with MultilabelModelJson + with EitherHelpers + with RegFeatureCompiler { self => + + /** + * Creates a reader for [[MultilabelModel]] that has a less specific type assuming that we + * can produce evidence that `N` is a super type of `Map[K, Double]`. + * @param ev Provides a conversion from `Map[K, Double]` to `N`. This indicates a subtype + * relationship and is ''close to the same'' as implicit subtype evidence + * `<:<`, but is not as strong (because it's not statically checked by the compiler). + * @tparam N The natural output of the Model that is a super type of `Map[K, Double]`. Since + * auditors take values of type `N` as input, supplying a subtype should also work. + * @return A JSON reader the creates [[MultilabelModel]], but returns instances as a less + * specific type. + */ + private[multilabel] def untypedReader[N](implicit ev: Map[K, Double] => N): JsonReader[Model[A, B] with Submodel[N, A, B]] = + new JsonReader[Model[A, B] with Submodel[N, A, B]]{ + override def read(json: JsValue): Model[A, B] with Submodel[N, A, B] = + self.read(json).asInstanceOf[Model[A, B] with Submodel[N, A, B]] + } + + override def read(json: JsValue): MultilabelModel[U, K, A, B] = { + val mlData = json.convertTo[MultilabelData[K]] + val pluginType = mlData.underlying.convertTo[Plugin].`type` + + plugins.get(pluginType) + .map { plugin => model(mlData, plugin) } + .getOrElse { throw new DeserializationException(errorMsg(pluginType)) } + } + + private[multilabel] def errorMsg(pluginType: String): String = { + val pluginNames = plugins.mapValues(p => p.getClass.getCanonicalName) + s"Couldn't find plugin for type $pluginType. Plugins available: $pluginNames" + } + + private[multilabel] def compileLabelsOfInterest(labelsOfInterestSpec: String): ENS[GenAggFunc[A, sci.IndexedSeq[K]]] = { + semantics.createFunction[sci.IndexedSeq[K]](labelsOfInterestSpec, Option(Vector.empty[K])). + left.map { Seq(s"Error processing spec '$labelsOfInterestSpec'") ++ _ } + } + + private[multilabel] def model(mlData: MultilabelData[K], plugin: MultilabelModelParserPlugin): MultilabelModel[U, K, A, B] = { + val predProdReader = plugin.parser(mlData)(refInfoK, jsonFormatK) + val predProd = mlData.underlying.convertTo(predProdReader) + + // Compile the labels of interest function. + val labelsOfInterest = mlData.labelsOfInterest.map { spec => + compileLabelsOfInterest(spec) match { + case Left(errs) => throw new DeserializationException(errs.mkString("\n")) + case Right(f) => f + } + } + + // Compile the features. + val featureMap: Seq[(String, Spec)] = mlData.features.toSeq + val (featureNames, featureFns) = + features(featureMap, semantics) + .fold(f => throw new DeserializationException(f.mkString("\n")), identity) + .toIndexedSeq.unzip + + MultilabelModel( + modelId = mlData.modelId, + featureNames = featureNames, + featureFunctions = featureFns, + labelsInTrainingSet = mlData.labelsInTrainingSet, + labelsOfInterest = labelsOfInterest, + predictorProducer = predProd, + numMissingThreshold = mlData.numMissingThreshold, + auditor = auditor + )(serEvK) + } +} \ No newline at end of file From c6152ac9df78cb55297344a776e102086e488d53 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 7 Sep 2017 17:43:26 -0700 Subject: [PATCH 33/98] Added changed from Iterable[(String, Double)] to Sparse. --- .../com/eharmony/aloha/models/reg/RegFeatureCompiler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala index fd95d951..4045041a 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala @@ -1,5 +1,6 @@ package com.eharmony.aloha.models.reg +import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.models.reg.json.Spec import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.GenAggFunc @@ -23,7 +24,7 @@ trait RegFeatureCompiler { self: EitherHelpers => protected[this] def features[A](featureMap: Seq[(String, Spec)], semantics: Semantics[A]): ENS[Seq[(String, GenAggFunc[A, Iterable[(String, Double)]])]] = mapSeq(featureMap) { case (k, Spec(spec, default)) => - semantics.createFunction[Iterable[(String, Double)]](spec, default). + semantics.createFunction[Sparse](spec, default). left.map { Seq(s"Error processing spec '$spec'") ++ _ }. // Add the spec that errored. right.map { f => (k, f) } } From ff1725a7fc23f2871df0d29ad16c6e4d894f510d Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 7 Sep 2017 17:54:39 -0700 Subject: [PATCH 34/98] new line at EOF. --- .../aloha/models/multilabel/json/MultilabelModelReader.scala | 2 +- .../eharmony/aloha/models/multilabel/MultilabelModelTest.scala | 2 +- .../models/vw/jni/multilabel/VwMultilabelModelPlugin.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala index fc94e039..82e013bb 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala @@ -108,4 +108,4 @@ final case class MultilabelModelReader[U, K, A, B <: U]( auditor = auditor )(serEvK) } -} \ No newline at end of file +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 14897eb9..58f080c8 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -167,4 +167,4 @@ object MultilabelModelTest { // aud.missingVarNames // : Set[String] // aud.errorMsgs // : Seq[String] // aud.prob // : Option[Float] (Shouldn't need this) -} \ No newline at end of file +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala index 81734f26..c8ad7947 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala @@ -107,4 +107,4 @@ object VwMultilabelModelPlugin extends MultilabelPluginProviderCompanion { s"JSON object expected. Found " + substr + (if (str.length != substr.length) " ..." else "") } } -} \ No newline at end of file +} From 79422abe8e4802bc29f62cd4bf9bea0a03174851 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 8 Sep 2017 17:38:26 -0700 Subject: [PATCH 35/98] VW compiling but still a few holes to fill in. Added Namespaces trait. --- .../aloha/models/vw/jni/Namespaces.scala | 50 ++++++++ .../aloha/models/vw/jni/VwJniModel.scala | 21 ++-- .../multilabel/VwMultilabelModelPlugin.scala | 110 ------------------ .../VwSparseMultilabelPredictor.scala | 94 +++++++++++++++ .../VwSparseMultilabelPredictorProducer.scala | 49 ++++++++ .../json/VwMultilabelModelJson.scala | 23 ++++ .../VwMultilabelModelPluginJsonReader.scala | 55 +++++++++ 7 files changed, 278 insertions(+), 124 deletions(-) create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/Namespaces.scala delete mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/Namespaces.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/Namespaces.scala new file mode 100644 index 00000000..d268ccfc --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/Namespaces.scala @@ -0,0 +1,50 @@ +package com.eharmony.aloha.models.vw.jni + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 9/8/17. + */ +trait Namespaces { + + def namespaceIndicesAndMissing( + indices: Map[String, Int], + namespaces: ListMap[String, Seq[String]] + ): (List[(String, List[Int])], List[(String, List[String])]) = { + val nss = namespaces.map { case (ns, fs) => + val indMissing = fs.map { f => + val oInd = indices.get(f) + if (oInd.isEmpty) + (oInd, Option(f)) + else (oInd, None) + } + + val (oInd, oMissing) = indMissing.unzip + (ns, oInd.flatten.toList, oMissing.flatten.toList) + } + + val nsInd = nss.collect { case (ns, ind, _) if ind.nonEmpty => (ns, ind) }.toList + val nsMissing = nss.collect { case (ns, _, m) if m.nonEmpty => (ns, m) }.toList + (nsInd, nsMissing) + } + + def defaultNamespaceIndices( + indices: Map[String, Int], + namespaces: ListMap[String, Seq[String]] + ): List[Int] = { + val featuresInNss = namespaces.foldLeft(Set.empty[String]){ case (s, (_, fs)) => s ++ fs } + val unusedFeatures = indices.keySet -- featuresInNss + unusedFeatures.flatMap(indices.get).toList + } + + def allNamespaceIndices( + featureNames: Seq[String], + namespaces: ListMap[String, Seq[String]] + ): (List[(String, List[Int])], List[Int], List[(String, List[String])]) = { + val indices = featureNames.zipWithIndex.toMap + val (nsi, missing) = namespaceIndicesAndMissing(indices, namespaces) + val defaultNsInd = defaultNamespaceIndices(indices, namespaces) + + (nsi, defaultNsInd, missing) + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala index 41b30318..8887a9e6 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala @@ -219,7 +219,8 @@ object VwJniModel object Parser extends ModelSubmodelParsingPlugin with EitherHelpers - with RegFeatureCompiler { self => + with RegFeatureCompiler + with Namespaces { self => val modelType = "VwJNI" @@ -239,21 +240,13 @@ object VwJniModel case Right(featureMap) => val (names, functions) = featureMap.toIndexedSeq.unzip - val indices = featureMap.unzip._1.view.zipWithIndex.toMap - val nssRaw = vw.namespaces.getOrElse(sci.ListMap.empty) - val nss = nssRaw.map { case (ns, fs) => - val fi = fs.flatMap { f => - val oInd = indices.get(f) - if (oInd.isEmpty) info(s"Ignoring feature '$f' in namespace '$ns'. Not in the feature list.") - oInd - }.toList - (ns, fi) - }.toList + val (nss, defaultNs, missing) = + allNamespaceIndices(names, vw.namespaces.getOrElse(sci.ListMap.empty)) - val vwParams = vw.vw.params.fold("")(_.fold(_.mkString(" "), x => x)).trim - - val defaultNs = (indices.keySet -- (Set.empty[String] /: nssRaw)(_ ++ _._2)).flatMap(indices.get).toList + if (missing.nonEmpty) + info(s"Ignoring features in namespaces not in feature list: $missing") + val vwParams = vw.vw.params.fold("")(_.fold(_.mkString(" "), x => x)).trim val learnerCreator = LearnerCreator[N](vw.classLabels, vw.spline, vwParams) VwJniModel( diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala deleted file mode 100644 index c8ad7947..00000000 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelPlugin.scala +++ /dev/null @@ -1,110 +0,0 @@ -package com.eharmony.aloha.models.vw.jni.multilabel - -import java.io.Closeable - -import com.eharmony.aloha.dataset.density.Sparse -import com.eharmony.aloha.io.sources.ModelSource -import com.eharmony.aloha.models.multilabel._ -import com.eharmony.aloha.models.multilabel.json.MultilabelModelAst -import com.eharmony.aloha.reflect.RefInfo -import spray.json.{JsValue, JsonFormat, JsonReader} -import spray.json.DefaultJsonProtocol._ - -import scala.collection.immutable.ListMap -import scala.collection.{immutable => sci} - -/** - * Created by ryan.deak on 9/5/17. - */ -case class VwMultilabelModelPlugin[K]( - modelSource: ModelSource, - params: String, - defaultNs: List[Int], - namespaces: List[(String, List[Int])]) -extends SparsePredictorProducer[K] { - - import VwMultilabelModelPlugin._ - - override def apply(): VwSparsePredictorProducer[K] = - VwSparsePredictorProducer[K](modelSource, params, defaultNs, namespaces) -} - -object VwMultilabelModelPlugin extends MultilabelPluginProviderCompanion { - - private val JsonErrStrLength = 100 - - def multilabelPlugin = Plugin - object Plugin extends MultilabelModelParserPlugin { - override def name: String = "vw" - override def parser[K](ast: MultilabelModelAst) - (implicit ri: RefInfo[K], jf: JsonFormat[K]): VwMultilabelModelPluginJsonReader[K] = { - VwMultilabelModelPluginJsonReader[K](ast.features.keys.toVector) - } - } - - - private[multilabel] case class VwMultilabelAst( - modelSource: ModelSource, - params: Option[Either[Seq[String], String]] = Option(Right("")), - namespaces: Option[ListMap[String, Seq[String]]] = Some(ListMap.empty) - ) - - private[this] implicit val vwMultilabelAstFormat = jsonFormat3(VwMultilabelAst) - - case class VwSparsePredictorProducer[K]( - modelSource: ModelSource, - params: String, - defaultNs: List[Int], - namespaces: List[(String, List[Int])]) - extends SparseMultiLabelPredictor[K] - with Closeable { - - // TODO: Create the model. - - override def apply( - features: IndexedSeq[Sparse], - labels: sci.IndexedSeq[K], - indices: sci.IndexedSeq[Int], - labelDependentFeatures: sci.IndexedSeq[IndexedSeq[Sparse]] - ): Map[K, Double] = { - - // Create VW input. - - ??? - } - - override def close(): Unit = { - // TODO: Add actual - } - } - - case class VwMultilabelModelPluginJsonReader[K](featureNames: Vector[String]) extends JsonReader[VwMultilabelModelPlugin[K]] { - - private[multilabel] def namespaceMapping(featureNames: Vector[String], nsMap: ListMap[String, Seq[String]]): (List[Int], List[(String, List[Int])]) = { - - ??? - } - - override def read(json: JsValue): VwMultilabelModelPlugin[K] = { - val ast = json.asJsObject(notObjErr(json)).convertTo[VwMultilabelAst] - - val params = ast.params.map { - case Left(paramList) => paramList.mkString(" ") - case Right(paramsStr) => paramsStr - } - - params.map { ps => - val (defaultNs, namespaces) = namespaceMapping(featureNames, ast.namespaces.getOrElse(ListMap.empty)) - VwMultilabelModelPlugin[K](ast.modelSource, ps, defaultNs, namespaces) - } getOrElse { - throw new Exception("no VW params provided") - } - } - - protected[multilabel] def notObjErr(json: JsValue): String = { - val str = json.prettyPrint - val substr = str.substring(0, JsonErrStrLength) - s"JSON object expected. Found " + substr + (if (str.length != substr.length) " ..." else "") - } - } -} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala new file mode 100644 index 00000000..b536ccef --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -0,0 +1,94 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.Closeable + +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.io.sources.ModelSource +import com.eharmony.aloha.models.multilabel.SparseMultiLabelPredictor +import vowpalWabbit.responses.ActionScores +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} + +import scala.collection.{immutable => sci} +import scala.util.Try + +/** + * Created by ryan.deak on 9/8/17. + */ +case class VwSparseMultilabelPredictor[K]( + modelSource: ModelSource, + params: String, + defaultNs: List[Int], + namespaces: List[(String, List[Int])]) +extends SparseMultiLabelPredictor[K] + with Closeable { + + import VwSparseMultilabelPredictor._ + + @transient private[multilabel] lazy val vwModel = createLearner(modelSource, params).get + + { + // Force creation. + require(vwModel != null) + } + + override def apply( + features: IndexedSeq[Sparse], + labels: sci.IndexedSeq[K], + indices: sci.IndexedSeq[Int], + labelDependentFeatures: sci.IndexedSeq[IndexedSeq[Sparse]] + ): Map[K, Double] = { + val x = constructInput(features, indices, defaultNs, namespaces) + val pred = Try { vwModel.predict(x) } + val yOut = pred.map { y => produceOutput(y, labels, indices) } + + // TODO: Change the interface to Try[Map[K, Double]] + yOut.get + } + + override def close(): Unit = vwModel.close() +} + +object VwSparseMultilabelPredictor { + + private[multilabel] def constructInput[K]( + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + defaultNs: List[Int], + namespaces: List[(String, List[Int])] + ): Array[String] = { + ??? + } + + private[multilabel] def produceOutput[K]( + pred: ActionScores, + labels: sci.IndexedSeq[K], + indices: sci.IndexedSeq[Int] + ): Map[K, Double] = { + val n = labels.size + + // TODO: Possibly update the interface to pass this in (possibly non-strictly). + val indToLabel: Map[Int, K] = indices.zip(labels)(collection.breakOut) + + val y: Map[K, Double] = (for { + as <- pred.getActionScores + label <- indToLabel.get(as.getAction).toIterable + pred = as.getScore.toDouble + } yield label -> pred)(collection.breakOut) + + y + } + + private[multilabel] def paramsWithSource(modelSource: ModelSource, params: String): String = { + // TODO: Fill in. + ??? + } + + private[multilabel] def createLearner(modelSource: ModelSource, params: String): Try[VWActionScoresLearner] = { + val modelFile = modelSource.localVfs.replicatedToLocal() + val updatedparams = paramsWithSource(modelSource, params) + + // : _root_.com.eharmony.aloha.io.vfs.File + val vfs = modelSource.localVfs + Try { VWLearners.create[VWActionScoresLearner](params) } + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala new file mode 100644 index 00000000..99d73cfa --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala @@ -0,0 +1,49 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.io.sources.ModelSource +import com.eharmony.aloha.models.multilabel._ +import com.eharmony.aloha.models.vw.jni.multilabel.json.VwMultilabelModelPluginJsonReader +import com.eharmony.aloha.reflect.RefInfo +import spray.json.{JsonFormat, JsonReader} + +/** + * A thing wrapper responsible for creating a [[VwSparseMultilabelPredictor]]. This + * creation is deferred because VW JNI models are not Serializable because they are + * thin wrappers around native C models in memory. The underlying binary model is + * externalizable in a file, byte array, etc. and this information is contained in + * `modelSource`. The actual VW JNI model is created only where needed. In short, + * this class can be serialized but the JNI model created from `modelSource` cannot + * be. + * + * Created by ryan.deak on 9/5/17. + * + * @param modelSource a source from which the binary VW model information can be + * extracted and used to create a VW JNI model. + * @param params VW parameters passed to the JNI constructor + * @param defaultNs indices into the features that belong to VW's default namespace. + * @param namespaces namespace name to feature indices map. + * @tparam K type of the labels returned by the [[VwSparseMultilabelPredictor]] that + * will be produced. + */ +case class VwSparseMultilabelPredictorProducer[K]( + modelSource: ModelSource, + params: String, + defaultNs: List[Int], + namespaces: List[(String, List[Int])]) +extends SparsePredictorProducer[K] { + override def apply(): VwSparseMultilabelPredictor[K] = + VwSparseMultilabelPredictor[K](modelSource, params, defaultNs, namespaces) +} + +object VwSparseMultilabelPredictorProducer extends MultilabelPluginProviderCompanion { + def multilabelPlugin: MultilabelModelParserPlugin = Plugin + + object Plugin extends MultilabelModelParserPlugin { + override def name: String = "vw" + + override def parser[K](info: PluginInfo) + (implicit ri: RefInfo[K], jf: JsonFormat[K]): JsonReader[SparsePredictorProducer[K]] = { + VwMultilabelModelPluginJsonReader[K](info.features.keys.toVector) + } + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala new file mode 100644 index 00000000..4bfee0d8 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala @@ -0,0 +1,23 @@ +package com.eharmony.aloha.models.vw.jni.multilabel.json + +import com.eharmony.aloha.io.sources.ModelSource + +import spray.json.DefaultJsonProtocol._ + +import scala.collection.immutable.ListMap +import com.eharmony.aloha.factory.ScalaJsonFormats + +/** + * Created by ryan.deak on 9/8/17. + */ +trait VwMultilabelModelJson extends ScalaJsonFormats { + + private[multilabel] case class VwMultilabelAst( + `type`: String, + modelSource: ModelSource, + params: Either[Seq[String], String] = Right(""), + namespaces: Option[ListMap[String, Seq[String]]] = Some(ListMap.empty) + ) + + protected[this] implicit val vwMultilabelAstFormat = jsonFormat4(VwMultilabelAst) +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala new file mode 100644 index 00000000..3fb9956f --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala @@ -0,0 +1,55 @@ +package com.eharmony.aloha.models.vw.jni.multilabel.json + +import com.eharmony.aloha.models.multilabel.SparsePredictorProducer +import com.eharmony.aloha.models.vw.jni.Namespaces +import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictorProducer +import com.eharmony.aloha.util.Logging +import spray.json.{JsValue, JsonReader} + +import scala.collection.immutable.ListMap + +/** + * + * + * '''NOTE''': This class extends `JsonReader[SparsePredictorProducer[K]]` rather than the + * more specific type `JsonReader[VwSparseMultilabelPredictorProducer[K]]` because `JsonReader` + * is not covariant in its type parameter. + * + * Created by ryan.deak on 9/8/17. + * + * @param featureNames feature names from the multi-label model. + * @tparam K label type for the predictions outputted by the + */ +case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String]) + extends JsonReader[SparsePredictorProducer[K]] + with VwMultilabelModelJson + with Namespaces + with Logging { + + import VwMultilabelModelPluginJsonReader._ + + override def read(json: JsValue): VwSparseMultilabelPredictorProducer[K] = { + val ast = json.asJsObject(notObjErr(json)).convertTo[VwMultilabelAst] + val params = vwParams(ast.params) + val (namespaces, defaultNs, missing) = + allNamespaceIndices(featureNames, ast.namespaces.getOrElse(ListMap.empty)) + + if (missing.nonEmpty) + info(s"features in namespaces not found in featureNames: $missing") + + VwSparseMultilabelPredictorProducer[K](ast.modelSource, params, defaultNs, namespaces) + } +} + +object VwMultilabelModelPluginJsonReader extends Logging { + private val JsonErrStrLength = 100 + + private[multilabel] def vwParams(params: Either[Seq[String], String]): String = + params.fold(_ mkString " ", identity).trim + + private[multilabel] def notObjErr(json: JsValue): String = { + val str = json.prettyPrint + val substr = str.substring(0, JsonErrStrLength) + s"JSON object expected. Found " + substr + (if (str.length != substr.length) " ..." else "") + } +} \ No newline at end of file From 75462628532a603d847339fc2cfd540861bf1d60 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 11 Sep 2017 16:38:02 -0700 Subject: [PATCH 36/98] VW compiling. Added test shell. Fill in shell. --- .../dataset/vw/unlabeled/VwRowCreator.scala | 178 ++++++++++-------- .../models/multilabel/MultilabelModel.scala | 4 +- .../aloha/models/multilabel/package.scala | 3 +- .../VwSparseMultilabelPredictor.scala | 116 +++++++++--- .../multilabel/VwMultilabelModelTest.scala | 48 +++++ 5 files changed, 243 insertions(+), 106 deletions(-) create mode 100644 aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/unlabeled/VwRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/unlabeled/VwRowCreator.scala index f8800b13..7bebbcce 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/unlabeled/VwRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/unlabeled/VwRowCreator.scala @@ -2,11 +2,10 @@ package com.eharmony.aloha.dataset.vw.unlabeled import java.text.DecimalFormat +import com.eharmony.aloha.dataset._ import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.dataset.vw.VwCovariateProducer -import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator.{DefaultVwNamespaceName, inEpsilonInterval} import com.eharmony.aloha.dataset.vw.unlabeled.json.VwUnlabeledJson -import com.eharmony.aloha.dataset._ import com.eharmony.aloha.semantics.compiled.CompiledSemantics import com.eharmony.aloha.util.Logging import spray.json._ @@ -49,59 +48,74 @@ extends RowCreator[A] collect { case s if s.nonEmpty => s.toVector.sorted.mkString(", ") }. foreach { empty => info(s"The following namespaces were empty: $empty.") } + def unlabeledVwInput(features: IndexedSeq[Sparse]): CharSequence = { + val vwLine = VwRowCreator.unlabeledVwInput(features, defaultNamespace, nonEmptyNamespaces, includeZeroValues) + val out = normalizer.map(n => n(vwLine)).getOrElse(vwLine) + out + } def apply(data: A): (MissingAndErroneousFeatureInfo, CharSequence) = { val (extractionInfo, features) = featuresFunction(data) val vwIn = unlabeledVwInput(features) (extractionInfo, vwIn) } +} + +object VwRowCreator { + private[vw] val DefaultVwNamespaceName = "" /** - * Each namespace will only be present if it contains at least one feature in the namespace. - * @param features features to insert into VW input line. - * @return + * The reason to choose 17 digits is that + 1 1 == (1 - 1.0e-17) + 1 0.9999999999999999 == (1 - 1.0e-16). + * We want to retain as much information as possible without allowing long trailing sequences of zeroes. */ - def unlabeledVwInput(features: IndexedSeq[Sparse]) = { - - // RMD 2015-06-12: GOD I HATE THIS CODE!!! Maybe functionalize it in the future! - - val sb = new StringBuilder + private[vw] val LabelDecimalDigits = 17 + private[vw] val LabelDecimalFormatter = new DecimalFormat(List.fill(LabelDecimalDigits)("#").mkString("0.", "", "")) + private[this] val labelEps = math.pow(10, -LabelDecimalDigits) / 2 + private[this] val labelNegEps = -labelEps + private[vw] def labelInEpsilonInterval(label: Double) = labelNegEps < label && label < labelEps - // Whether a namespace has been added previously. - var nsAlreadyInserted = false + private[vw] val FeatureDecimalDigits = 6 + private[vw] val DecimalFormatter = new DecimalFormat(List.fill(FeatureDecimalDigits)("#").mkString("0.", "", "")) + private[this] val eps = math.pow(10, -FeatureDecimalDigits) / 2 + private[this] val negEps = -eps + private[vw] def inEpsilonInterval(x: Double) = negEps < x && x < eps - if (defaultNamespace.nonEmpty) - nsAlreadyInserted = addNamespaceFeaturesToVwLine(sb, nsAlreadyInserted, DefaultVwNamespaceName, defaultNamespace, features, includeZeroValues = includeZeroValues) + final class Producer[A] + extends RowCreatorProducer[A, VwRowCreator[A]] + with RowCreatorProducerName + with VwCovariateProducer[A] + with SparseCovariateProducer + with CompilerFailureMessages { - var nss: List[(String, List[Int])] = nonEmptyNamespaces - while (nss.nonEmpty) { - val ns = nss.head - nsAlreadyInserted = addNamespaceFeaturesToVwLine(sb, nsAlreadyInserted, ns._1, ns._2, features, includeZeroValues) - nss = nss.tail + type JsonType = VwUnlabeledJson + def parse(json: JsValue): Try[VwUnlabeledJson] = Try { json.convertTo[VwUnlabeledJson] } + def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: VwUnlabeledJson): Try[VwRowCreator[A]] = { + val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) + val spec = covariates.map(c => new VwRowCreator(c, default, nss, normalizer)) + spec } - - // If necessary, apply the normalizer. - val vwLine = normalizer.map(n => n(sb)).getOrElse(sb) - vwLine } + /** - * Add data from a namespace to the VW line. Data comes in the form of an iterable sequence of key-value pairs - * where keys are strings and values are doubles. The values are truncated according to - * [[VwRowCreator.DecimalFormatter]]. If the truncated value is the integer, 1, then the value is omitted from the - * output (as is allowed by VW). If the truncated value is zero, then the feature is included only if - * ''includeZeroValues'' is true. - * - * @param sb string builder into which data is - * @param previousNsInserted Whether a namespace has previously been inserted - * @param nsName the namespace name. - * @param nsFeatureIndices the feature indices included in the referenced namespace. - * @param features the entire list of features (across all namespaces). Since this is an ''IndexedSeq'', lookup - * by index is constant or near-constant time. - * @param includeZeroValues whether to include key-value pairs in the VW output whose values are zero. - * @return - */ - private[this] def addNamespaceFeaturesToVwLine( + * Add data from a namespace to the VW line. Data comes in the form of an iterable sequence of key-value pairs + * where keys are strings and values are doubles. The values are truncated according to + * [[VwRowCreator.DecimalFormatter]]. If the truncated value is the integer, 1, then the value is omitted from the + * output (as is allowed by VW). If the truncated value is zero, then the feature is included only if + * ''includeZeroValues'' is true. + * + * @param sb string builder into which data is + * @param previousNsInserted Whether a namespace has previously been inserted + * @param nsName the namespace name. + * @param nsFeatureIndices the feature indices included in the referenced namespace. + * @param features the entire list of features (across all namespaces). Since this is an ''IndexedSeq'', lookup + * by index is constant or near-constant time. + * @param includeZeroValues whether to include key-value pairs in the VW output whose values are zero. + * @return + */ + private[unlabeled] def addNamespaceFeaturesToVwLine( sb: StringBuilder, previousNsInserted: Boolean, nsName: String, @@ -125,18 +139,18 @@ extends RowCreator[A] while(it.hasNext) { val (feature, value) = it.next() - if (VwRowCreator.inEpsilonInterval(value - 1)) { - if (ins) sb.append(" ") - sb.append(feature) - } - else if (!inEpsilonInterval(value) || includeZeroValues) { - // For double values, format it to 6 decimals. VW seems to not handle crazy long - // numbers too well. Note that for super large numbers, this DecimalFormat will - // spit out very large strings of numbers. i don't think very large weights occur - // that often with VW so for now i'm not addressing that (potential) issue. - if (ins) sb.append(" ") - sb.append(feature).append(":").append(VwRowCreator.DecimalFormatter.format(value)) - } + if (VwRowCreator.inEpsilonInterval(value - 1)) { + if (ins) sb.append(" ") + sb.append(feature) + } + else if (!inEpsilonInterval(value) || includeZeroValues) { + // For double values, format it to 6 decimals. VW seems to not handle crazy long + // numbers too well. Note that for super large numbers, this DecimalFormat will + // spit out very large strings of numbers. i don't think very large weights occur + // that often with VW so for now i'm not addressing that (potential) issue. + if (ins) sb.append(" ") + sb.append(feature).append(":").append(VwRowCreator.DecimalFormatter.format(value)) + } } h(indices.tail, ins, nameIns) @@ -145,42 +159,42 @@ extends RowCreator[A] h(nsFeatureIndices, previousNsInserted, nameAlreadyInserted = false) } -} - -final object VwRowCreator { - private[vw] val DefaultVwNamespaceName = "" /** - * The reason to choose 17 digits is that - 1 1 == (1 - 1.0e-17) - 1 0.9999999999999999 == (1 - 1.0e-16). - * We want to retain as much information as possible without allowing long trailing sequences of zeroes. - */ - private[vw] val LabelDecimalDigits = 17 - private[vw] val LabelDecimalFormatter = new DecimalFormat(List.fill(LabelDecimalDigits)("#").mkString("0.", "", "")) - private[this] val labelEps = math.pow(10, -LabelDecimalDigits) / 2 - private[this] val labelNegEps = -labelEps - private[vw] def labelInEpsilonInterval(label: Double) = labelNegEps < label && label < labelEps + * + * Each namespace will only be present if it contains at least one feature in the namespace. + * @param features features to insert into VW input line. + * @param defaultNamespace + * @param nonEmptyNamespaces + * @param includeZeroValues + * @return + */ + private[aloha] def unlabeledVwInput( + features: IndexedSeq[Sparse], + defaultNamespace: List[Int], + nonEmptyNamespaces: List[(String, List[Int])], + includeZeroValues: Boolean): CharSequence = { - private[vw] val FeatureDecimalDigits = 6 - private[vw] val DecimalFormatter = new DecimalFormat(List.fill(FeatureDecimalDigits)("#").mkString("0.", "", "")) - private[this] val eps = math.pow(10, -FeatureDecimalDigits) / 2 - private[this] val negEps = -eps - private[vw] def inEpsilonInterval(x: Double) = negEps < x && x < eps + // RMD 2015-06-12: GOD I HATE THIS CODE!!! Maybe functionalize it in the future! - final class Producer[A] - extends RowCreatorProducer[A, VwRowCreator[A]] - with RowCreatorProducerName - with VwCovariateProducer[A] - with SparseCovariateProducer - with CompilerFailureMessages { + val sb = new StringBuilder - type JsonType = VwUnlabeledJson - def parse(json: JsValue): Try[VwUnlabeledJson] = Try { json.convertTo[VwUnlabeledJson] } - def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: VwUnlabeledJson): Try[VwRowCreator[A]] = { - val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) - val spec = covariates.map(c => new VwRowCreator(c, default, nss, normalizer)) - spec + // Whether a namespace has been added previously. + var nsAlreadyInserted = false + + if (defaultNamespace.nonEmpty) + nsAlreadyInserted = addNamespaceFeaturesToVwLine( + sb, nsAlreadyInserted, DefaultVwNamespaceName, defaultNamespace, + features, includeZeroValues = includeZeroValues) + + var nss: List[(String, List[Int])] = nonEmptyNamespaces + while (nss.nonEmpty) { + val ns = nss.head + nsAlreadyInserted = addNamespaceFeaturesToVwLine( + sb, nsAlreadyInserted, ns._1, ns._2, features, includeZeroValues) + nss = nss.tail } + + sb } } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index dba57ec5..3c63acd1 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -17,7 +17,7 @@ import com.eharmony.aloha.util.{Logging, SerializabilityEvidence} import spray.json.{JsonFormat, JsonReader} import scala.collection.{immutable => sci, mutable => scm} -import scala.util.{Failure, Success, Try} +import scala.util.{Failure, Success} // TODO: When adding label-dep features, a Seq[GenAggFunc[K, Sparse]] will be needed. @@ -110,7 +110,7 @@ extends SubmodelBase[U, Map[K, Double], A, B] reportTooManyMissing(modelId, li, missing, auditor) else { // TODO: To support label-dependent features, fill last parameter with a valid value. - val predictionTry = Try { predictor(x, li.labels, li.indices, sci.IndexedSeq.empty) } + val predictionTry = predictor(x, li.labels, li.indices, sci.IndexedSeq.empty) predictionTry match { case Success(pred) => reportSuccess(modelId, li, missing, pred, auditor) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index 94ff90b3..e1cd9564 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -3,6 +3,7 @@ package com.eharmony.aloha.models import com.eharmony.aloha.dataset.density.Sparse import scala.collection.{immutable => sci} +import scala.util.Try /** * Created by ryan.deak on 8/31/17. @@ -56,7 +57,7 @@ package object multilabel { * @tparam K the type of labels (or classes in the machine learning literature). */ type SparseMultiLabelPredictor[K] = - (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Map[K, Double] + (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Try[Map[K, Double]] /** * A lazy version of a sparse multi-label predictor. It is a curried zero-arg function that diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index b536ccef..62070ecc 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -1,19 +1,31 @@ package com.eharmony.aloha.models.vw.jni.multilabel -import java.io.Closeable +import java.io.{Closeable, File} import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator import com.eharmony.aloha.io.sources.ModelSource import com.eharmony.aloha.models.multilabel.SparseMultiLabelPredictor -import vowpalWabbit.responses.ActionScores import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} +import vowpalWabbit.responses.ActionScores import scala.collection.{immutable => sci} import scala.util.Try /** * Created by ryan.deak on 9/8/17. + * @param modelSource a specification of a location for the underlying VW model that + * will be materialized in this class. + * @param params VW parameters. + * @param defaultNs The list of indices into the `features` sequence that does not have + * an exist in any value of the `namespaces` map. + * @param namespaces Mapping from namespace name to indices in the `features` sequence passed + * to the `apply` method. There should be no empty namespaces, meaning + * ''key-value'' pairs appearing in the map should not values that are empty + * sequences. '''This is a requirement.''' + * @tparam K the label or class type. */ +// TODO: Comment this function. It requires a lot of assumptions. Make those known. case class VwSparseMultilabelPredictor[K]( modelSource: ModelSource, params: String, @@ -27,6 +39,12 @@ extends SparseMultiLabelPredictor[K] @transient private[multilabel] lazy val vwModel = createLearner(modelSource, params).get { + val emptyNss = namespaces collect { case (ns, ind) if ind.isEmpty => ns } + require( + emptyNss.isEmpty, + s"There should be no namespaces that are empty. Found: ${emptyNss mkString ", "}" + ) + // Force creation. require(vwModel != null) } @@ -36,27 +54,57 @@ extends SparseMultiLabelPredictor[K] labels: sci.IndexedSeq[K], indices: sci.IndexedSeq[Int], labelDependentFeatures: sci.IndexedSeq[IndexedSeq[Sparse]] - ): Map[K, Double] = { - val x = constructInput(features, indices, defaultNs, namespaces) + ): Try[Map[K, Double]] = { + val x = multiLabelClassifierInput(features, indices, defaultNs, namespaces) val pred = Try { vwModel.predict(x) } val yOut = pred.map { y => produceOutput(y, labels, indices) } - - // TODO: Change the interface to Try[Map[K, Double]] - yOut.get + yOut } override def close(): Unit = vwModel.close() } object VwSparseMultilabelPredictor { + private val DummyClassNS = "y" + private val ClassNS = "Y" + private val NegDummyClass = Int.MaxValue.toLong + 1 + private val PosDummyClass = NegDummyClass + 1 + private val NegDummyClassLine = s"$NegDummyClass |$DummyClassNS _C${NegDummyClass}_" + private val PosDummyClassLine = s"$PosDummyClass |$DummyClassNS _C${PosDummyClass}_" - private[multilabel] def constructInput[K]( + + private[multilabel] def multiLabelClassifierInput( features: IndexedSeq[Sparse], indices: sci.IndexedSeq[Int], defaultNs: List[Int], namespaces: List[(String, List[Int])] ): Array[String] = { - ??? + val n = indices.size + // The length of the output array is n + 3. The first row is the shared features. + // These are features that are not label dependent. Then we have two dummy classes. + // + // The class at the index 1 (0-based) is one that is always negative. The class at + // index 2 is one that is always positive. These dummy classes are required to make + // the probabilities work out for multi-label classifier. + val x = new Array[String](n + 3) + + val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) + // "shared" is a special keyword in VW multi-class (multi-row) format. + // See: https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html + x(0) = "shared " + shared + x(1) = NegDummyClassLine + x(2) = PosDummyClassLine + + // This is mutable because we want speed + // and there is nothing + var i = 0 + while (i < n) { + val labelInd = indices(i) + x(i + 3) = s"$labelInd |$ClassNS _C${labelInd}_" + i += 1 + } + + x } private[multilabel] def produceOutput[K]( @@ -64,31 +112,57 @@ object VwSparseMultilabelPredictor { labels: sci.IndexedSeq[K], indices: sci.IndexedSeq[Int] ): Map[K, Double] = { - val n = labels.size // TODO: Possibly update the interface to pass this in (possibly non-strictly). val indToLabel: Map[Int, K] = indices.zip(labels)(collection.breakOut) val y: Map[K, Double] = (for { - as <- pred.getActionScores + as <- pred.getActionScores label <- indToLabel.get(as.getAction).toIterable - pred = as.getScore.toDouble + pred = as.getScore.toDouble } yield label -> pred)(collection.breakOut) y } - private[multilabel] def paramsWithSource(modelSource: ModelSource, params: String): String = { - // TODO: Fill in. - ??? - } + + /** + * Update the parameters with the + * + * VW params of interest when doing multi-class: + * + - `--csoaa_ldf mc` Label-dependent features for multi-class classification + - `--csoaa_rank` (Probably) necessary to get scores for m-c classification. + - `--loss_function logistic` Standard logistic loss for learning. + - `--noconstant` Don't want a constant since it's not interacted with NS Y. + - `-q YX` Cross product of label-dependent and features and features + - `--ignore_linear Y` Don't care about the 1st-order wts of the label-dep features. + - `--ignore_linear X` Don't care about the 1st-order wts of the features. + - `--ignore y` Ignore everything related to the dummy class instances. + * + * {{{ + * val str = + * "shared |X feature" + "\n" + + * + * "0:1 |y _C0_" + "\n" + // These two instances are dummy classes + * "1:0 |y _C1_" + "\n" + + * + * "2:0 |Y _C2_" + "\n" + + * "3:1 |Y _C3_" + * + * val ex = str.split("\n") + * }}} + * @param modelSource + * @param params + * @return + */ + // TODO: How much of the parameter setup is up to the caller versus this function? + private[multilabel] def paramsWithSource(modelSource: File, params: String): String = + params + " -i" + modelSource.getCanonicalPath + " -t --quiet" private[multilabel] def createLearner(modelSource: ModelSource, params: String): Try[VWActionScoresLearner] = { val modelFile = modelSource.localVfs.replicatedToLocal() - val updatedparams = paramsWithSource(modelSource, params) - - // : _root_.com.eharmony.aloha.io.vfs.File - val vfs = modelSource.localVfs - Try { VWLearners.create[VWActionScoresLearner](params) } + val updatedParams = paramsWithSource(modelFile.fileObj, params) + Try { VWLearners.create[VWActionScoresLearner](updatedParams) } } } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala new file mode 100644 index 00000000..5389f702 --- /dev/null +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -0,0 +1,48 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.models.multilabel.MultilabelModel +import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.util.SerializabilityEvidence +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 9/11/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelModelTest { + import VwMultilabelModelTest._ + + @Test def test1(): Unit = { + val predProd = VwSparseMultilabelPredictorProducer[Label]( + modelSource = null, // ModelSource, + params = "", + defaultNs = List.empty[Int], + namespaces = List.empty[(String, List[Int])] + ) + + val model = + MultilabelModel( + modelId = ModelId(1, "model"), + featureNames = sci.IndexedSeq.empty[String], + featureFunctions = sci.IndexedSeq.empty[GenAggFunc[Domain, Sparse]], + labelsInTrainingSet = sci.IndexedSeq.empty[Label], + labelsOfInterest = Option.empty[GenAggFunc[Domain, sci.IndexedSeq[Label]]], + predictorProducer = predProd, + numMissingThreshold = Option(1000000), + auditor = Auditor) + } +} + +object VwMultilabelModelTest { + private type Label = String + private type Domain = Any + + private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() +} \ No newline at end of file From 6d5671d64b48fadfc9869d8f91909691fb225442 Mon Sep 17 00:00:00 2001 From: amirziai Date: Mon, 11 Sep 2017 17:12:52 -0700 Subject: [PATCH 37/98] Number of new changes --- .../multilabel/MultilabelModelTest.scala | 156 ++++++++++++++++-- 1 file changed, 138 insertions(+), 18 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 55536b1c..bff54d26 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -6,6 +6,7 @@ import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.semantics.SemanticsUdfException import com.eharmony.aloha.semantics.func._ import org.junit.Test import org.junit.Assert._ @@ -187,7 +188,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { - def extractLabelsOutOfExample(example: Map[String, String]) = + def extractLabelsOutOfExample(example: Map[String, String]): sci.IndexedSeq[String] = example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq // Example with no problems @@ -214,11 +215,70 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // Example with no labels val exampleNoLabels = Map("feature1" -> "1", "feature2" -> "2") - val labelsAndInfoNoLabels = labelsForPrediction(exampleNoLabels, + val labelsAndInfoNoLabels = labelsForPrediction( + exampleNoLabels, labelsOfInterestExtractor, labelToInt) val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq())) assertEquals(problemsNoLabelsExpected, labelsAndInfoNoLabels.problems) + + // missing labels + def badFunction(example: Map[String, String]) = example.get("feature1") + val gen1 = GenFunc1("concat _1", (m: Option[String]) => sci.IndexedSeq(s"${m}_1"), + GeneratedAccessor("extract feature 1", badFunction, None)) + + val labelsExtractor = + (m: Map[String, String]) => m.get("labels") match { + case ls: sci.IndexedSeq[_] if ls.forall { x: String => x.isInstanceOf[Label] } => + Option(ls.asInstanceOf[sci.IndexedSeq[Label]]) + case _ => None + } + + val f1 = + GenFunc.f1(GeneratedAccessor("labels", labelsExtractor, None))( + "def omitted", _ getOrElse sci.IndexedSeq.empty[Label] + ) + + val labelsAndInfoNoLabelsGen1 = labelsForPrediction( + Map[String, String](), + f1, + labelToInt) + val problemsNoLabelsExpectedGen1 = Option(GenAggFuncAccessorProblems(Seq("labels"), Seq())) + assertEquals(problemsNoLabelsExpectedGen1, labelsAndInfoNoLabelsGen1.problems) + + // TODO: figre this out + // error + val labelsExtractorError: (Map[String, String]) => Option[sci.IndexedSeq[Label]] = + (m: Map[String, String]) => m.get("labels") match { + case ls: sci.IndexedSeq[_] if ls.forall { x: String => x.isInstanceOf[Label] } => + Option(ls.asInstanceOf[sci.IndexedSeq[Label]]) + case _ => throw new Exception("labels does not exist") + } + + val f2 = + GenFunc.f1(GeneratedAccessor("labels", + ((m: Map[String, String]) => throw new Exception("errmsg")) : Map[String, String] => + Option[sci.IndexedSeq[String]] + , None))( + "def omitted", _ getOrElse sci.IndexedSeq.empty[Label] + ) + val f2Wrapped = EnrichedErrorGenAggFunc(f2) + + new SemanticsUdfException[Any](null, null, null, null, null, null) + + val problemsNoLabelsExpectedGen2 = Option(GenAggFuncAccessorProblems(Seq(), Seq("labels"))) + Try( + labelsForPrediction( + Map[String, String](), + f2Wrapped, + labelToInt) + ).failed.get match { + case ex: SemanticsUdfException[_] => + assertEquals(ex.accessorsInErr, labelsAndInfoNoLabelsGen2.problems) + } + + + // assertEquals(problemsNoLabelsExpectedGen2, labelsAndInfoNoLabelsGen2.problems) } @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { @@ -227,24 +287,84 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // if (unsorted.size == labelsShouldPredict.size) Seq.empty // else labelsShouldPredict.filterNot(labelToInd.contains) + def extractLabelsOutOfExample(example: Map[String, String]) = + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq + + val example: Map[String, String] = Map( + "feature1" -> "1", + "feature2" -> "2", + "feature3" -> "2", + "label1" -> "a", + "label2" -> "b" + ) + val allLabels = sci.IndexedSeq("a", "b", "c") + val labelToInt = allLabels.zipWithIndex.toMap + val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + val labelsAndInfo = labelsForPrediction(example, labelsOfInterestExtractor, labelToInt) + val missingLabels = labelsAndInfo.missingLabels + assertEquals(Seq(), missingLabels) + + // Extra label not in the list + val example2 = Map("label4" -> "d") + val labelsAndInfo2 = labelsForPrediction(example2, labelsOfInterestExtractor, labelToInt) + val missingLabels2 = labelsAndInfo2.missingLabels + assertEquals(Seq("d"), missingLabels2) + + // No labels + val example3 = Map("feature2" -> "5") + val labelsAndInfo3 = labelsForPrediction(example3, labelsOfInterestExtractor, labelToInt) + val missingLabels3 = labelsAndInfo3.missingLabels + assertEquals(Seq(), missingLabels3) + } + + @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { + // Test this: + // val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip + + def extractLabelsOutOfExample(example: Map[String, String]) = + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq + + val example: Map[String, String] = Map( + "feature1" -> "1", + "feature2" -> "2", + "feature3" -> "2", + "label1" -> "a", + "label2" -> "b", + "label3" -> "l23", + "label4" -> "100", + "label5" -> "235", + "label6" -> "c", + "label7" -> "1", + "label8" -> "l1" + ) + + val allLabels = sci.IndexedSeq("a", "b", "c", "235", "1", "l1", "l23", "100") + val labelToInt = allLabels.zipWithIndex.toMap + val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + val labelsAndInfo = labelsForPrediction(example, labelsOfInterestExtractor, labelToInt) + + assertEquals(allLabels, labelsAndInfo.labels) + assertEquals(allLabels.indices, labelsAndInfo.indices) + } + + @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { + // Test this: + // if (li.labels.isEmpty) + // reportNoPrediction(modelId, li, auditor) + def extractLabelsOutOfExample(example: Map[String, String]) = + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq + + val example = Map("" -> "") + val allLabels = sci.IndexedSeq("a", "b", "c") + val labelToInt = allLabels.zipWithIndex.toMap + val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + + + // labelsAndInfo(example,) + fail() } -// -// @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { -// // Test this: -// // val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip -// -// fail() -// } -// -// @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { -// // Test this: -// // if (li.labels.isEmpty) -// // reportNoPrediction(modelId, li, auditor) -// -// fail() -// } -// + // @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { // // When the amount of missing data exceeds the threshold, reportTooManyMissing should be // // called and its value should be returned. Instantiate a MultilabelModel and From 6f4514576218d35e4c41e0383245493cf46614c9 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 11 Sep 2017 21:31:31 -0700 Subject: [PATCH 38/98] Test passing. It appears we don't need the dummy classes in test mode. --- .../VwSparseMultilabelPredictor.scala | 33 +++-- .../multilabel/VwMultilabelModelTest.scala | 130 ++++++++++++++++-- 2 files changed, 139 insertions(+), 24 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index 62070ecc..b2e072c3 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -69,8 +69,8 @@ object VwSparseMultilabelPredictor { private val ClassNS = "Y" private val NegDummyClass = Int.MaxValue.toLong + 1 private val PosDummyClass = NegDummyClass + 1 - private val NegDummyClassLine = s"$NegDummyClass |$DummyClassNS _C${NegDummyClass}_" - private val PosDummyClassLine = s"$PosDummyClass |$DummyClassNS _C${PosDummyClass}_" + private val NegDummyClassLine = s"$NegDummyClass:1 |$DummyClassNS _C${NegDummyClass}_" + private val PosDummyClassLine = s"$PosDummyClass:0 |$DummyClassNS _C${PosDummyClass}_" private[multilabel] def multiLabelClassifierInput( @@ -86,24 +86,27 @@ object VwSparseMultilabelPredictor { // The class at the index 1 (0-based) is one that is always negative. The class at // index 2 is one that is always positive. These dummy classes are required to make // the probabilities work out for multi-label classifier. - val x = new Array[String](n + 3) + val x = new Array[String](n + 1) val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) // "shared" is a special keyword in VW multi-class (multi-row) format. // See: https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html x(0) = "shared " + shared - x(1) = NegDummyClassLine - x(2) = PosDummyClassLine - // This is mutable because we want speed - // and there is nothing + // This is mutable because we want speed. var i = 0 + + // This is fantastic! + // TODO: It appears that we don't have to add the dummy classes at test time. Double check. while (i < n) { val labelInd = indices(i) - x(i + 3) = s"$labelInd |$ClassNS _C${labelInd}_" + x(i + 1) = s"$labelInd:0 |$ClassNS _C${labelInd}_" i += 1 } +// x(n + 1) = NegDummyClassLine +// x(n + 2) = PosDummyClassLine + x } @@ -116,15 +119,25 @@ object VwSparseMultilabelPredictor { // TODO: Possibly update the interface to pass this in (possibly non-strictly). val indToLabel: Map[Int, K] = indices.zip(labels)(collection.breakOut) + // The last two action IDs in the action scores array are the dummy actions. val y: Map[K, Double] = (for { - as <- pred.getActionScores + as <- pred.getActionScores // if as.getAction < indices.size if we need to deal with dummy classes. label <- indToLabel.get(as.getAction).toIterable - pred = as.getScore.toDouble + pred = modifiedLogistic(as.getScore) } yield label -> pred)(collection.breakOut) y } + /** + * A modified logistic function where the sign of the exponent is opposite the usual + * definition. Since CSOAA in VW employs costs, it changes the sign of the normal + * logistic function so the definition becomes `1 / (1 + exp(x))`. + * + * @param x an input produced by a VW CSOAA prediction. + * @return a probability. + */ + @inline final private def modifiedLogistic(x: Float) = 1 / (1 + math.exp(x)) /** * Update the parameters with the diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 5389f702..a0dc4ce7 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -1,14 +1,19 @@ package com.eharmony.aloha.models.vw.jni.multilabel +import java.io.File + import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} +import com.eharmony.aloha.io.vfs.Vfs import com.eharmony.aloha.models.multilabel.MultilabelModel -import com.eharmony.aloha.semantics.func.GenAggFunc -import com.eharmony.aloha.util.SerializabilityEvidence +import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc0} +import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} import scala.collection.{immutable => sci} @@ -20,29 +25,126 @@ class VwMultilabelModelTest { import VwMultilabelModelTest._ @Test def test1(): Unit = { + val features = Vector[GenAggFunc[Domain, Sparse]]( + GenFunc0("", (_: Domain) => Iterable(("", 1d))) + ) + val predProd = VwSparseMultilabelPredictorProducer[Label]( - modelSource = null, // ModelSource, - params = "", + modelSource = TrainedModel, + params = "", // to see the output: "-p /dev/stdout", defaultNs = List.empty[Int], - namespaces = List.empty[(String, List[Int])] + namespaces = List(("X", List(0))) // List.empty[(String, List[Int])] ) val model = MultilabelModel( - modelId = ModelId(1, "model"), - featureNames = sci.IndexedSeq.empty[String], - featureFunctions = sci.IndexedSeq.empty[GenAggFunc[Domain, Sparse]], - labelsInTrainingSet = sci.IndexedSeq.empty[Label], - labelsOfInterest = Option.empty[GenAggFunc[Domain, sci.IndexedSeq[Label]]], - predictorProducer = predProd, - numMissingThreshold = Option(1000000), - auditor = Auditor) + modelId = ModelId(1, "model"), + featureNames = Vector(FeatureName), + featureFunctions = features, + labelsInTrainingSet = AllLabels, + labelsOfInterest = Option.empty[GenAggFunc[Domain, sci.IndexedSeq[Label]]], + predictorProducer = predProd, + numMissingThreshold = Option.empty[Int], + auditor = Auditor) + + val result = model("This value doesn't matter.") + + try { + result.value match { + case None => fail() + case Some(labelMap) => + assertEquals(AllLabels.toSet, labelMap.keySet) + assertEquals(0.7, labelMap(LabelSeven), 0.01) + assertEquals(0.8, labelMap(LabelEight), 0.01) + assertEquals(0.6, labelMap(LabelSix), 0.01) + } + } + finally { + model.close() + } } } object VwMultilabelModelTest { - private type Label = String + private type Label = String private type Domain = Any + private val LabelSeven = "seven" + private val LabelEight = "eight" + private val LabelSix = "six" + + /** + * The order here is important because it's the same as the indices in the training data. + - LabelSeven is _C0_ (index 0) + - LabelEight is _C1_ (index 1) + - LabelSix is _C2_ (index 2) + */ + private val AllLabels = Vector(LabelSeven, LabelEight, LabelSix) + + private def tmpFile() = { + val f = File.createTempFile(classOf[VwMultilabelModelTest].getSimpleName + "_", ".vw.model") + f.deleteOnExit() + f + } + + private def vwTrainingParams(modelFile: File = tmpFile()) = { + + val flags = + """ + | --quiet + | --csoaa_ldf mc + | --csoaa_rank + | --loss_function logistic + | --link logistic + | -q YX + | --noconstant + | --ignore_linear X + | --ignore y + | -f + """.stripMargin.trim + + (flags + " " + modelFile.getCanonicalPath).split("\n").map(_.trim).mkString(" ") + } + + private val ModelFile = tmpFile() + + private val FeatureName = "feature" + + /** + * A dataset that creates the following marginal distribution. + - Pr[seven] = 0.7 where seven is _C0_ + - Pr[eight] = 0.8 where eight is _C1_ + - Pr[six] = 0.6 where six is _C2_ + */ + private val TrainingData = + Vector( + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.084 |y _C2147483649_\n1:0.0 |Y _C1_\n0:-0.084 |Y _C0_\n2:-0.084 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.024 |y _C2147483649_\n1:0.0 |Y _C1_\n0:0.0 |Y _C0_\n2:0.0 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.336 |y _C2147483649_\n1:-0.336 |Y _C1_\n0:-0.336 |Y _C0_\n2:-0.336 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.056 |y _C2147483649_\n1:0.0 |Y _C1_\n0:-0.056 |Y _C0_\n2:0.0 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.144 |y _C2147483649_\n1:-0.144 |Y _C1_\n0:0.0 |Y _C0_\n2:-0.144 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.224 |y _C2147483649_\n1:-0.224 |Y _C1_\n0:-0.224 |Y _C0_\n2:0.0 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.036 |y _C2147483649_\n1:0.0 |Y _C1_\n0:0.0 |Y _C0_\n2:-0.036 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.096 |y _C2147483649_\n1:-0.096 |Y _C1_\n0:0.0 |Y _C0_\n2:0.0 |Y _C2_" + ).map(_.split(raw"\n")) + + private lazy val TrainedModel: ModelSource = { + val modelFile = tmpFile() + val params = vwTrainingParams(modelFile) + val learner = VWLearners.create[VWActionScoresLearner](params) + + for { + _ <- 1 to 50 + d <- TrainingData + } { + val asdf = d.toVector + learner.learn(asdf.toArray) + } + + learner.close() + + ExternalSource(Vfs.javaFileToAloha(modelFile)) + } + private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() } \ No newline at end of file From f71aba58d19bcd0f615bd38940aae7d6cd90508e Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 12 Sep 2017 10:21:03 -0700 Subject: [PATCH 39/98] Updated VwSparseMultilabelPredictor. It now seems to be fully working as show in the test VwMultilabelModelTest. --- .../models/multilabel/MultilabelModel.scala | 2 +- .../VwSparseMultilabelPredictor.scala | 37 +--- .../multilabel/VwMultilabelModelTest.scala | 174 +++++++++++++----- 3 files changed, 135 insertions(+), 78 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 3c63acd1..67059ff3 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -165,7 +165,7 @@ object MultilabelModel extends ParserProviderCompanion { private[multilabel] val TooManyMissingError = "Too many missing features encountered to produce prediction." - private[multilabel] val NoLabelsError = "No labels found in training set." + private[multilabel] val NoLabelsError = "No labels provided. Cannot produce a prediction." /** * Get the labels and information about the labels. diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index b2e072c3..58241a51 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -57,7 +57,7 @@ extends SparseMultiLabelPredictor[K] ): Try[Map[K, Double]] = { val x = multiLabelClassifierInput(features, indices, defaultNs, namespaces) val pred = Try { vwModel.predict(x) } - val yOut = pred.map { y => produceOutput(y, labels, indices) } + val yOut = pred.map { y => produceOutput(y, labels) } yOut } @@ -65,13 +65,7 @@ extends SparseMultiLabelPredictor[K] } object VwSparseMultilabelPredictor { - private val DummyClassNS = "y" private val ClassNS = "Y" - private val NegDummyClass = Int.MaxValue.toLong + 1 - private val PosDummyClass = NegDummyClass + 1 - private val NegDummyClassLine = s"$NegDummyClass:1 |$DummyClassNS _C${NegDummyClass}_" - private val PosDummyClassLine = s"$PosDummyClass:0 |$DummyClassNS _C${PosDummyClass}_" - private[multilabel] def multiLabelClassifierInput( features: IndexedSeq[Sparse], @@ -80,12 +74,9 @@ object VwSparseMultilabelPredictor { namespaces: List[(String, List[Int])] ): Array[String] = { val n = indices.size - // The length of the output array is n + 3. The first row is the shared features. - // These are features that are not label dependent. Then we have two dummy classes. - // - // The class at the index 1 (0-based) is one that is always negative. The class at - // index 2 is one that is always positive. These dummy classes are required to make - // the probabilities work out for multi-label classifier. + // The length of the output array is n + 1. The first row is the shared features. + // These are features that are not label dependent. Then come the features for the + // n labels. val x = new Array[String](n + 1) val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) @@ -96,37 +87,23 @@ object VwSparseMultilabelPredictor { // This is mutable because we want speed. var i = 0 - // This is fantastic! - // TODO: It appears that we don't have to add the dummy classes at test time. Double check. while (i < n) { val labelInd = indices(i) x(i + 1) = s"$labelInd:0 |$ClassNS _C${labelInd}_" i += 1 } -// x(n + 1) = NegDummyClassLine -// x(n + 2) = PosDummyClassLine - x } - private[multilabel] def produceOutput[K]( - pred: ActionScores, - labels: sci.IndexedSeq[K], - indices: sci.IndexedSeq[Int] - ): Map[K, Double] = { - - // TODO: Possibly update the interface to pass this in (possibly non-strictly). - val indToLabel: Map[Int, K] = indices.zip(labels)(collection.breakOut) + private[multilabel] def produceOutput[K](pred: ActionScores, labels: sci.IndexedSeq[K]): Map[K, Double] = { // The last two action IDs in the action scores array are the dummy actions. - val y: Map[K, Double] = (for { + (for { as <- pred.getActionScores // if as.getAction < indices.size if we need to deal with dummy classes. - label <- indToLabel.get(as.getAction).toIterable + label = labels(as.getAction) pred = modifiedLogistic(as.getScore) } yield label -> pred)(collection.breakOut) - - y } /** diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index a0dc4ce7..67aaf22c 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -2,20 +2,20 @@ package com.eharmony.aloha.models.vw.jni.multilabel import java.io.File -import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor -import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor} import com.eharmony.aloha.id.ModelId import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.models.Model import com.eharmony.aloha.models.multilabel.MultilabelModel -import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc0} +import com.eharmony.aloha.semantics.func.GenFunc0 import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} -import scala.collection.{immutable => sci} +import scala.annotation.tailrec /** * Created by ryan.deak on 9/11/17. @@ -25,62 +25,103 @@ class VwMultilabelModelTest { import VwMultilabelModelTest._ @Test def test1(): Unit = { - val features = Vector[GenAggFunc[Domain, Sparse]]( - GenFunc0("", (_: Domain) => Iterable(("", 1d))) - ) - - val predProd = VwSparseMultilabelPredictorProducer[Label]( - modelSource = TrainedModel, - params = "", // to see the output: "-p /dev/stdout", - defaultNs = List.empty[Int], - namespaces = List(("X", List(0))) // List.empty[(String, List[Int])] - ) - - val model = - MultilabelModel( - modelId = ModelId(1, "model"), - featureNames = Vector(FeatureName), - featureFunctions = features, - labelsInTrainingSet = AllLabels, - labelsOfInterest = Option.empty[GenAggFunc[Domain, sci.IndexedSeq[Label]]], - predictorProducer = predProd, - numMissingThreshold = Option.empty[Int], - auditor = Auditor) - - val result = model("This value doesn't matter.") + val model = Model try { - result.value match { - case None => fail() - case Some(labelMap) => - assertEquals(AllLabels.toSet, labelMap.keySet) - assertEquals(0.7, labelMap(LabelSeven), 0.01) - assertEquals(0.8, labelMap(LabelEight), 0.01) - assertEquals(0.6, labelMap(LabelSix), 0.01) - } + // Test the no label case. + testEmpty(model("")) + + // Test all elements of the power set of labels (except the empty set, done above). + for { + labels <- powerSet(AllLabels.toSet).filter(ls => ls.nonEmpty) + x = labels.mkString(",") + y = model(x) + } testOutput(y, labels, ExpectedMarginalDist) } finally { model.close() } } + + private[this] def testEmpty(yEmpty: PredictionOutput): Unit = { + assertEquals(None, yEmpty.value) + assertEquals(Vector("No labels provided. Cannot produce a prediction."), yEmpty.errorMsgs) + assertEquals(Set.empty, yEmpty.missingVarNames) + assertEquals(None, yEmpty.prob) + } + + private[this] def testOutput( + y: PredictionOutput, + labels: Set[Label], + expectedMarginalDist: Map[Label, Double] + ): Unit = y.value match { + case None => fail() + case Some(labelMap) => + assertEquals(labels, labelMap.keySet) + val exp = expectedMarginalDist.filterKeys(label => labels contains label) + exp foreach { case (label, expPr) => + assertEquals(expPr, labelMap(label), 0.01) + } + } } object VwMultilabelModelTest { private type Label = String - private type Domain = Any + private type Domain = String + private type PredictionOutput = RootedTree[Any, Map[Label, Double]] + + private[this] val TrainingEpochs = 30 + private val LabelSix = "six" private val LabelSeven = "seven" private val LabelEight = "eight" - private val LabelSix = "six" + + private val ExpectedMarginalDist = Map( + LabelSix -> 0.6, + LabelSeven -> 0.7, + LabelEight -> 0.8 + ) /** * The order here is important because it's the same as the indices in the training data. - LabelSeven is _C0_ (index 0) - LabelEight is _C1_ (index 1) - - LabelSix is _C2_ (index 2) + - LabelSix is _C2_ (index 2) */ private val AllLabels = Vector(LabelSeven, LabelEight, LabelSix) + private lazy val Model: Model[Domain, RootedTree[Any, Map[Label, Double]]] = { + val featureNames = Vector(FeatureName) + val features = Vector(GenFunc0("", (_: Domain) => Iterable(("", 1d)))) + + // Get the list of labels from the comma-separated list passed in the input string. + val labelsOfInterestFn = + GenFunc0("", + (x: Domain) => + x.split("\\s*,\\s*") + .map(label => label.trim) + .filter(label => label.nonEmpty) + .toVector + ) + + val predProd = VwSparseMultilabelPredictorProducer[Label]( + modelSource = TrainedModel, + params = "", // to see the output: "-p /dev/stdout", + defaultNs = List.empty[Int], + namespaces = List(("X", List(0))) + ) + + MultilabelModel( + modelId = ModelId(1, "model"), + featureNames = featureNames, + featureFunctions = features, + labelsInTrainingSet = AllLabels, + labelsOfInterest = Option(labelsOfInterestFn), + predictorProducer = predProd, + numMissingThreshold = Option.empty[Int], + auditor = Auditor) + } + private def tmpFile() = { val f = File.createTempFile(classOf[VwMultilabelModelTest].getSimpleName + "_", ".vw.model") f.deleteOnExit() @@ -89,13 +130,36 @@ object VwMultilabelModelTest { private def vwTrainingParams(modelFile: File = tmpFile()) = { + // NOTES: + // 1. `--csoaa_rank` is needed by VW to make a VWActionScoresLearner. + // 2. `-q YX` forms the cross product of features between the label-based features in Y + // and the side information in X. If the features in namespace Y are unique to each + // class, the cross product effectively makes one model per class by interacting the + // class features to the features in X. Since the class labels vary (possibly) + // independently, this cross product is capable of creating independent models + // (one per class). + // 3. `--ignore_linear Y` is provided because in the data, there is one constant feature + // in the Y namespace per class. This flag acts essentially the same as the + // `--noconstant` flag in the traditional binary classification context. It omits the + // one feature related to the class (the intercept). + // 4. `--noconstant` is specified because there's no need to have a "common intercept" + // whose value is the same across all models. This should be the job of the + // first-order features in Y (the per-class intercepts). + // 5. `--ignore y` is used to ignore the features in the namespace related to the dummy + // classes. We need these two dummy classes in training to make `--csoaa_ldf mc` + // output proper probabilities. + // 6. `--link logistic` doesn't actually seem to do anything. + // 7. `--loss_function logistic` works properly during training; however, when + // interrogating the scores, it is important to remember they appear as + // '''negative logits'''. This is because the CSOAA algorithm uses '''costs''', so + // "smaller is better". So, to get the probability, one must do `1/(1 + exp(-1 * -y))` + // or simply `1/(1 + exp(y))`. val flags = """ | --quiet | --csoaa_ldf mc | --csoaa_rank | --loss_function logistic - | --link logistic | -q YX | --noconstant | --ignore_linear X @@ -106,15 +170,17 @@ object VwMultilabelModelTest { (flags + " " + modelFile.getCanonicalPath).split("\n").map(_.trim).mkString(" ") } - private val ModelFile = tmpFile() - private val FeatureName = "feature" /** * A dataset that creates the following marginal distribution. - Pr[seven] = 0.7 where seven is _C0_ - Pr[eight] = 0.8 where eight is _C1_ - - Pr[six] = 0.6 where six is _C2_ + - Pr[six] = 0.6 where six is _C2_ + * + * The observant reader may notice these are oddly ordered. On each line C1 appears first, + * then C0, then C2. This is done to show ordering doesn't matter. What matters is the + * class '''indices'''. */ private val TrainingData = Vector( @@ -134,12 +200,9 @@ object VwMultilabelModelTest { val learner = VWLearners.create[VWActionScoresLearner](params) for { - _ <- 1 to 50 + _ <- 1 to TrainingEpochs d <- TrainingData - } { - val asdf = d.toVector - learner.learn(asdf.toArray) - } + } learner.learn(d) learner.close() @@ -147,4 +210,21 @@ object VwMultilabelModelTest { } private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() + + /** + * Creates the power set of the provided set. + * Answer provided by Chris Marshall (@oxbow_lakes) on + * [[https://stackoverflow.com/a/11581323 Stack Overflow]]. + * @param generatorSet a set for which a power set should be produced. + * @tparam A type of elements in the set. + * @return the power set of `generatorSet`. + */ + private def powerSet[A](generatorSet: Set[A]): Set[Set[A]] = { + @tailrec def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + // powerset of empty set is the set of the empty set. + pwr(generatorSet, Set(Set.empty[A])) + } } \ No newline at end of file From e8159fc243a24eb7d48408c8bb961974817c4ca1 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 12 Sep 2017 10:24:28 -0700 Subject: [PATCH 40/98] Added some test comments. --- .../models/vw/jni/multilabel/VwMultilabelModelTest.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 67aaf22c..4374df30 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -57,8 +57,14 @@ class VwMultilabelModelTest { ): Unit = y.value match { case None => fail() case Some(labelMap) => + // Should that model gives back all the keys that were request. assertEquals(labels, labelMap.keySet) + + // The expected return value. We check keys first to make sure the model + // doesn't return extraneous key-value pairs. val exp = expectedMarginalDist.filterKeys(label => labels contains label) + + // Check that the values returned are correct. exp foreach { case (label, expPr) => assertEquals(expPr, labelMap(label), 0.01) } From 6f98ce225c37c95d2f18637241415b3afafbfbb2 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 12 Sep 2017 10:51:14 -0700 Subject: [PATCH 41/98] Added comments to VwSparseMultilabelPredictor. --- .../VwSparseMultilabelPredictor.scala | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index 58241a51..c3432ee3 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -13,9 +13,15 @@ import scala.collection.{immutable => sci} import scala.util.Try /** + * * Created by ryan.deak on 9/8/17. - * @param modelSource a specification of a location for the underlying VW model that - * will be materialized in this class. + * @param modelSource a specification for the underlying ''Cost Sensitive One Against All'' + * VW model with ''label dependent features''. VW flag + * `--csoaa_ldf mc` is expected. For more information, see the + * [[https://github.com/JohnLangford/vowpal_wabbit/wiki/Cost-Sensitive-One-Against-All-(csoaa)-multi-class-example VW CSOAA wiki page]]. + * Also see the ''Cost-Sensitive Multiclass Classification'' section of + * Hal Daume's [[https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html On Multiclass Classification in VW]] + * page. This model specification will be materialized in this class. * @param params VW parameters. * @param defaultNs The list of indices into the `features` sequence that does not have * an exist in any value of the `namespaces` map. @@ -49,6 +55,16 @@ extends SparseMultiLabelPredictor[K] require(vwModel != null) } + /** + * Given the input, form a VW example, and delegate to the underlying ''CSOAA LDF'' VW model. + * @param features (non-label dependent) features shared across all labels. + * @param labels labels for which the VW learner should produce predictions. + * @param indices the indices `labels` into the sequence of all labels encountered + * during training. + * @param labelDependentFeatures Any label dependent features. This is not yet utilized and + * is currently ignored. + * @return a Map from label to prediction. + */ override def apply( features: IndexedSeq[Sparse], labels: sci.IndexedSeq[K], @@ -67,6 +83,17 @@ extends SparseMultiLabelPredictor[K] object VwSparseMultilabelPredictor { private val ClassNS = "Y" + /** + * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. + * @param features (non-label dependent) features shared across all labels. + * @param indices the indices `labels` into the sequence of all labels encountered + * during training. + * @param defaultNs the indices into `features` that should be placed in VW's default + * namespace. + * @param namespaces the indices into `features` that should be associated with each + * namespace. + * @return an array to be passed directly to an underlying `VWActionScoresLearner`. + */ private[multilabel] def multiLabelClassifierInput( features: IndexedSeq[Sparse], indices: sci.IndexedSeq[Int], @@ -96,11 +123,17 @@ object VwSparseMultilabelPredictor { x } + /** + * Produce the output given VW's output, `pred`, and the labels provided to the `apply` function. + * @param pred predictions returned by the underlying VW ''CSOAA LDF'' model. + * @param labels the labels provided to the `apply` function. This determines which predictions + * should be produced. + * @tparam K The label or class type. + * @return a map of predictions from label to prediction. + */ private[multilabel] def produceOutput[K](pred: ActionScores, labels: sci.IndexedSeq[K]): Map[K, Double] = { - - // The last two action IDs in the action scores array are the dummy actions. (for { - as <- pred.getActionScores // if as.getAction < indices.size if we need to deal with dummy classes. + as <- pred.getActionScores label = labels(as.getAction) pred = modifiedLogistic(as.getScore) } yield label -> pred)(collection.breakOut) @@ -108,8 +141,9 @@ object VwSparseMultilabelPredictor { /** * A modified logistic function where the sign of the exponent is opposite the usual - * definition. Since CSOAA in VW employs costs, it changes the sign of the normal - * logistic function so the definition becomes `1 / (1 + exp(x))`. + * definition. Since CSOAA in VW employs costs, it returns the negative logits which + * changes the sign of the normal logistic function so the definition becomes + * `1 / (1 + exp(x))`. * * @param x an input produced by a VW CSOAA prediction. * @return a probability. From 47c06581fac8fec3b838cdf616571a5e659c7fd6 Mon Sep 17 00:00:00 2001 From: amirziai Date: Tue, 12 Sep 2017 13:51:55 -0700 Subject: [PATCH 42/98] One more passing test, moving to performance testing now --- .../multilabel/MultilabelModelTest.scala | 46 ++++++++++++------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index bff54d26..d48df62d 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -246,7 +246,6 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val problemsNoLabelsExpectedGen1 = Option(GenAggFuncAccessorProblems(Seq("labels"), Seq())) assertEquals(problemsNoLabelsExpectedGen1, labelsAndInfoNoLabelsGen1.problems) - // TODO: figre this out // error val labelsExtractorError: (Map[String, String]) => Option[sci.IndexedSeq[Label]] = (m: Map[String, String]) => m.get("labels") match { @@ -274,11 +273,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { labelToInt) ).failed.get match { case ex: SemanticsUdfException[_] => - assertEquals(ex.accessorsInErr, labelsAndInfoNoLabelsGen2.problems) + assertEquals(ex.accessorsInErr, problemsNoLabelsExpectedGen2.get.errors) } - - - // assertEquals(problemsNoLabelsExpectedGen2, labelsAndInfoNoLabelsGen2.problems) } @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { @@ -356,22 +352,40 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val example = Map("" -> "") val allLabels = sci.IndexedSeq("a", "b", "c") - val labelToInt = allLabels.zipWithIndex.toMap + val labelToInd = allLabels.zipWithIndex.toMap val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + val defaultLabelInfo = LabelsAndInfo(allLabels.indices, allLabels, Seq.empty, None) + val li = labelsAndInfo(example, Option(labelsOfInterestExtractor), labelToInd, defaultLabelInfo) + val report = reportNoPrediction(ModelId(), li, aud) + assertEquals(Vector(NoLabelsError), report.audited.errorMsgs) + } + @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { + // When the amount of missing data exceeds the threshold, reportTooManyMissing should be + // called and its value should be returned. Instantiate a MultilabelModel and + // call apply with some missing data required by the features. + def extractLabelsOutOfExample(example: Map[String, String]) = + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq - // labelsAndInfo(example,) + val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + + // TODO: continue here + +// val modelWithThreshold = MultilabelModel( +// ModelId(), +// sci.IndexedSeq("a", "b", "c", "d"), +// sci.IndexedSeq(labelsOfInterestExtractor), +// sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), +// None, +// Lazy(ConstantPredictor[Label]()), +// Option(2), +// Auditor +// ) +// +// println(modelWithThreshold(5)) fail() } - -// @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { -// // When the amount of missing data exceeds the threshold, reportTooManyMissing should be -// // called and its value should be returned. Instantiate a MultilabelModel and -// // call apply with some missing data required by the features. -// -// fail() -// } // // @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { // // Create a predictorProducer that throws. Check that the model still returns a value @@ -403,7 +417,7 @@ object MultilabelModelTest { override def apply(v1: SparseFeatures, v2: Labels[K], v3: LabelIndices, - v4: SparseLabelDepFeatures): Map[K, Double] = v2.map(_ -> prediction).toMap + v4: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(v2.map(_ -> prediction).toMap) } case class Lazy[A](value: A) extends (() => A) { From c54d9154c65533a7b788dc34488d478b41726592 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 13 Sep 2017 14:16:18 -0700 Subject: [PATCH 43/98] updated split --- .../aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 4374df30..83edcc20 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -198,13 +198,12 @@ object VwMultilabelModelTest { s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.224 |y _C2147483649_\n1:-0.224 |Y _C1_\n0:-0.224 |Y _C0_\n2:0.0 |Y _C2_", s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.036 |y _C2147483649_\n1:0.0 |Y _C1_\n0:0.0 |Y _C0_\n2:-0.036 |Y _C2_", s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.096 |y _C2147483649_\n1:-0.096 |Y _C1_\n0:0.0 |Y _C0_\n2:0.0 |Y _C2_" - ).map(_.split(raw"\n")) + ).map(_.split("\n")) private lazy val TrainedModel: ModelSource = { val modelFile = tmpFile() val params = vwTrainingParams(modelFile) val learner = VWLearners.create[VWActionScoresLearner](params) - for { _ <- 1 to TrainingEpochs d <- TrainingData From 4c00c6f638c0d6aec649f228026a603920368c56 Mon Sep 17 00:00:00 2001 From: amirziai Date: Mon, 18 Sep 2017 13:34:33 -0700 Subject: [PATCH 44/98] Merging updates --- .../multilabel/VwMultilabelModelTest.scala | 233 ++++++++++++++++-- version.sbt | 2 +- 2 files changed, 211 insertions(+), 24 deletions(-) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 5389f702..30afc1d2 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -1,16 +1,21 @@ package com.eharmony.aloha.models.vw.jni.multilabel -import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor -import com.eharmony.aloha.dataset.density.Sparse +import java.io.File + +import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor} import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} +import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.models.Model import com.eharmony.aloha.models.multilabel.MultilabelModel -import com.eharmony.aloha.semantics.func.GenAggFunc -import com.eharmony.aloha.util.SerializabilityEvidence +import com.eharmony.aloha.semantics.func.GenFunc0 +import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} -import scala.collection.{immutable => sci} +import scala.annotation.tailrec /** * Created by ryan.deak on 9/11/17. @@ -20,29 +25,211 @@ class VwMultilabelModelTest { import VwMultilabelModelTest._ @Test def test1(): Unit = { - val predProd = VwSparseMultilabelPredictorProducer[Label]( - modelSource = null, // ModelSource, - params = "", - defaultNs = List.empty[Int], - namespaces = List.empty[(String, List[Int])] - ) + val model = Model + + try { + // Test the no label case. + testEmpty(model("")) + + // Test all elements of the power set of labels (except the empty set, done above). + for { + labels <- powerSet(AllLabels.toSet).filter(ls => ls.nonEmpty) + x = labels.mkString(",") + y = model(x) + } testOutput(y, labels, ExpectedMarginalDist) + } + finally { + model.close() + } + } + + private[this] def testEmpty(yEmpty: PredictionOutput): Unit = { + assertEquals(None, yEmpty.value) + assertEquals(Vector("No labels provided. Cannot produce a prediction."), yEmpty.errorMsgs) + assertEquals(Set.empty, yEmpty.missingVarNames) + assertEquals(None, yEmpty.prob) + } + + private[this] def testOutput( + y: PredictionOutput, + labels: Set[Label], + expectedMarginalDist: Map[Label, Double] + ): Unit = y.value match { + case None => fail() + case Some(labelMap) => + // Should that model gives back all the keys that were request. + assertEquals(labels, labelMap.keySet) - val model = - MultilabelModel( - modelId = ModelId(1, "model"), - featureNames = sci.IndexedSeq.empty[String], - featureFunctions = sci.IndexedSeq.empty[GenAggFunc[Domain, Sparse]], - labelsInTrainingSet = sci.IndexedSeq.empty[Label], - labelsOfInterest = Option.empty[GenAggFunc[Domain, sci.IndexedSeq[Label]]], - predictorProducer = predProd, - numMissingThreshold = Option(1000000), - auditor = Auditor) + // The expected return value. We check keys first to make sure the model + // doesn't return extraneous key-value pairs. + val exp = expectedMarginalDist.filterKeys(label => labels contains label) + + // Check that the values returned are correct. + exp foreach { case (label, expPr) => + assertEquals(expPr, labelMap(label), 0.01) + } } } object VwMultilabelModelTest { - private type Label = String - private type Domain = Any + private type Label = String + private type Domain = String + private type PredictionOutput = RootedTree[Any, Map[Label, Double]] + + private[this] val TrainingEpochs = 30 + + private val LabelSix = "six" + private val LabelSeven = "seven" + private val LabelEight = "eight" + + private val ExpectedMarginalDist = Map( + LabelSix -> 0.6, + LabelSeven -> 0.7, + LabelEight -> 0.8 + ) + + /** + * The order here is important because it's the same as the indices in the training data. + - LabelSeven is _C0_ (index 0) + - LabelEight is _C1_ (index 1) + - LabelSix is _C2_ (index 2) + */ + private val AllLabels = Vector(LabelSeven, LabelEight, LabelSix) + + private lazy val Model: Model[Domain, RootedTree[Any, Map[Label, Double]]] = { + val featureNames = Vector(FeatureName) + val features = Vector(GenFunc0("", (_: Domain) => Iterable(("", 1d)))) + + // Get the list of labels from the comma-separated list passed in the input string. + val labelsOfInterestFn = + GenFunc0("", + (x: Domain) => + x.split("\\s*,\\s*") + .map(label => label.trim) + .filter(label => label.nonEmpty) + .toVector + ) + + val predProd = VwSparseMultilabelPredictorProducer[Label]( + modelSource = TrainedModel, + params = "", // to see the output: "-p /dev/stdout", + defaultNs = List.empty[Int], + namespaces = List(("X", List(0))) + ) + + MultilabelModel( + modelId = ModelId(1, "model"), + featureNames = featureNames, + featureFunctions = features, + labelsInTrainingSet = AllLabels, + labelsOfInterest = Option(labelsOfInterestFn), + predictorProducer = predProd, + numMissingThreshold = Option.empty[Int], + auditor = Auditor) + } + + private def tmpFile() = { + val f = File.createTempFile(classOf[VwMultilabelModelTest].getSimpleName + "_", ".vw.model") + f.deleteOnExit() + f + } + + private def vwTrainingParams(modelFile: File = tmpFile()) = { + + // NOTES: + // 1. `--csoaa_rank` is needed by VW to make a VWActionScoresLearner. + // 2. `-q YX` forms the cross product of features between the label-based features in Y + // and the side information in X. If the features in namespace Y are unique to each + // class, the cross product effectively makes one model per class by interacting the + // class features to the features in X. Since the class labels vary (possibly) + // independently, this cross product is capable of creating independent models + // (one per class). + // 3. `--ignore_linear Y` is provided because in the data, there is one constant feature + // in the Y namespace per class. This flag acts essentially the same as the + // `--noconstant` flag in the traditional binary classification context. It omits the + // one feature related to the class (the intercept). + // 4. `--noconstant` is specified because there's no need to have a "common intercept" + // whose value is the same across all models. This should be the job of the + // first-order features in Y (the per-class intercepts). + // 5. `--ignore y` is used to ignore the features in the namespace related to the dummy + // classes. We need these two dummy classes in training to make `--csoaa_ldf mc` + // output proper probabilities. + // 6. `--link logistic` doesn't actually seem to do anything. + // 7. `--loss_function logistic` works properly during training; however, when + // interrogating the scores, it is important to remember they appear as + // '''negative logits'''. This is because the CSOAA algorithm uses '''costs''', so + // "smaller is better". So, to get the probability, one must do `1/(1 + exp(-1 * -y))` + // or simply `1/(1 + exp(y))`. + val flags = + """ + | --quiet + | --csoaa_ldf mc + | --csoaa_rank + | --loss_function logistic + | -q YX + | --noconstant + | --ignore_linear X + | --ignore y + | -f + """.stripMargin.trim + + (flags + " " + modelFile.getCanonicalPath).split("\n").map(_.trim).mkString(" ") + } + + private val FeatureName = "feature" + + /** + * A dataset that creates the following marginal distribution. + - Pr[seven] = 0.7 where seven is _C0_ + - Pr[eight] = 0.8 where eight is _C1_ + - Pr[six] = 0.6 where six is _C2_ + * + * The observant reader may notice these are oddly ordered. On each line C1 appears first, + * then C0, then C2. This is done to show ordering doesn't matter. What matters is the + * class '''indices'''. + */ + private val TrainingData = + Vector( + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.084 |y _C2147483649_\n1:0.0 |Y _C1_\n0:-0.084 |Y _C0_\n2:-0.084 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.024 |y _C2147483649_\n1:0.0 |Y _C1_\n0:0.0 |Y _C0_\n2:0.0 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.336 |y _C2147483649_\n1:-0.336 |Y _C1_\n0:-0.336 |Y _C0_\n2:-0.336 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.056 |y _C2147483649_\n1:0.0 |Y _C1_\n0:-0.056 |Y _C0_\n2:0.0 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.144 |y _C2147483649_\n1:-0.144 |Y _C1_\n0:0.0 |Y _C0_\n2:-0.144 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.224 |y _C2147483649_\n1:-0.224 |Y _C1_\n0:-0.224 |Y _C0_\n2:0.0 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.036 |y _C2147483649_\n1:0.0 |Y _C1_\n0:0.0 |Y _C0_\n2:-0.036 |Y _C2_", + s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.096 |y _C2147483649_\n1:-0.096 |Y _C1_\n0:0.0 |Y _C0_\n2:0.0 |Y _C2_" + ).map(_.split("\n")) + + private lazy val TrainedModel: ModelSource = { + val modelFile = tmpFile() + val params = vwTrainingParams(modelFile) + val learner = VWLearners.create[VWActionScoresLearner](params) + for { + _ <- 1 to TrainingEpochs + d <- TrainingData + } learner.learn(d) + + learner.close() + + ExternalSource(Vfs.javaFileToAloha(modelFile)) + } private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() + + /** + * Creates the power set of the provided set. + * Answer provided by Chris Marshall (@oxbow_lakes) on + * [[https://stackoverflow.com/a/11581323 Stack Overflow]]. + * @param generatorSet a set for which a power set should be produced. + * @tparam A type of elements in the set. + * @return the power set of `generatorSet`. + */ + private def powerSet[A](generatorSet: Set[A]): Set[Set[A]] = { + @tailrec def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + // powerset of empty set is the set of the empty set. + pwr(generatorSet, Set(Set.empty[A])) + } } \ No newline at end of file diff --git a/version.sbt b/version.sbt index cbf330b0..ebff5912 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "5.0.0-SNAPSHOT" +version in ThisBuild := "5.0.1-SNAPSHOT" From 86890c1a43c12e0c57c490571e28fd21fd84d8da Mon Sep 17 00:00:00 2001 From: amirziai Date: Mon, 18 Sep 2017 16:20:19 -0700 Subject: [PATCH 45/98] Figured out the gist of a few more tests, terrible code though --- .../multilabel/MultilabelModelTest.scala | 147 +++++++++++++----- 1 file changed, 110 insertions(+), 37 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index d48df62d..29232f5f 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -6,6 +6,7 @@ import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.models.Model import com.eharmony.aloha.semantics.SemanticsUdfException import com.eharmony.aloha.semantics.func._ import org.junit.Test @@ -24,12 +25,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { import MultilabelModel._ import MultilabelModelTest._ - // TODO: Fill in the test implementation and delete comments once done. - @Test def testSerialization(): Unit = { // Assuming all parameters passed to the MultilabelModel constructor are // Serializable, MultilabelModel should also be Serializable. - val modelRoundTrip = serializeDeserializeRoundTrip(modelNoFeatures) assertEquals(modelNoFeatures, modelRoundTrip) } @@ -364,47 +362,122 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // When the amount of missing data exceeds the threshold, reportTooManyMissing should be // called and its value should be returned. Instantiate a MultilabelModel and // call apply with some missing data required by the features. + + // TODO: this is a terrible setup + val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + GenFunc0("""Iterable(("", 1d))""", + (x: Map[String, String]) => if (x.keys.toVector.contains("a")) { + Iterable((x.keys.head, 1d)) + } else { + Iterable() + }) + + val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) + + val modelWithThreshold = MultilabelModel( + modelId = ModelId(1, "model1"), + featureNames = sci.IndexedSeq("a", "b", "c", "d"), + featureFunctions = featureFunctions, + labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + labelsOfInterest = None, + predictorProducer = Lazy(ConstantPredictor[Label]()), + numMissingThreshold = Option(0), + auditor = Auditor + ) + + val result = modelWithThreshold(Map()) + assertEquals(TooManyMissingError, result.errorMsgs.head) + } + + @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { + // Create a predictorProducer that throws. Check that the model still returns a value + // and that the error message is incorporated appropriately. + + val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + GenFunc0("""Iterable(("", 1d))""", + (x: Map[String, String]) => if (x.keys.toVector.contains("a")) { + Iterable((x.keys.head, 1d)) + } else { + Iterable() + }) + + val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) + + case class PredictorThatThrows[K](prediction: Double = 0d) extends + SparseMultiLabelPredictor[K] { + override def apply(v1: SparseFeatures, + v2: Labels[K], + v3: LabelIndices, + v4: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(throw new Exception("error")) + } + val modelWithThrowingPredictorProducer = MultilabelModel( + modelId = ModelId(1, "model1"), + featureNames = sci.IndexedSeq("a", "b", "c", "d"), + featureFunctions = featureFunctions, + labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + labelsOfInterest = None, + predictorProducer = Lazy(PredictorThatThrows[Label]()), + numMissingThreshold = None, + auditor = Auditor + ) + val result = modelWithThrowingPredictorProducer(Map()) + assertEquals(None, result.value) + assert(result.errorMsgs.head.contains("java.lang.Exception: error")) + } + + @Test def testSubvalueSuccess(): Unit = { + // Test the happy path by calling model.apply. Check the value, missing data, and error + // messages. + + val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + GenFunc0("""Iterable(("", 1d))""", (_: Any) => Iterable(("", 1d))) + + val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) + def extractLabelsOutOfExample(example: Map[String, String]) = example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) - // TODO: continue here - -// val modelWithThreshold = MultilabelModel( -// ModelId(), -// sci.IndexedSeq("a", "b", "c", "d"), -// sci.IndexedSeq(labelsOfInterestExtractor), -// sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), -// None, -// Lazy(ConstantPredictor[Label]()), -// Option(2), -// Auditor -// ) -// -// println(modelWithThreshold(5)) + val modelSuccess = MultilabelModel( + modelId = ModelId(1, "model1"), + featureNames = sci.IndexedSeq("a", "b", "c", "d"), + featureFunctions = featureFunctions, + labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + labelsOfInterest = Option(labelsOfInterestExtractor), + predictorProducer = Lazy(ConstantPredictor[Label]()), + numMissingThreshold = None, + auditor = Auditor + ) + val result = modelSuccess(Map("a" -> "b", "label1" -> "label1", "label2" -> "label2")) + assertEquals(Vector(), result.errorMsgs) + assertEquals(Set(), result.missingVarNames) + assertEquals(Option(Map("label1" -> 0.0, "label2" -> 0.0)), result.value) + } - fail() + @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { + // NOTE: This is by design. + + // TODO: this actually throws + val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + GenFunc0("""Iterable(("", 1d))""", (x: Map[String, String]) => Iterable((x("hello"), 1d))) + + val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) + + val modelSuccess = MultilabelModel( + modelId = ModelId(1, "model1"), + featureNames = sci.IndexedSeq("a", "b", "c", "d"), + featureFunctions = featureFunctions, + labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + labelsOfInterest = None, + predictorProducer = Lazy(ConstantPredictor[Label]()), + numMissingThreshold = None, + auditor = Auditor + ) + val result = modelSuccess(Map("a" -> "b", "label1" -> "label1", "label2" -> "label2")) + + println(result.value) } -// -// @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { -// // Create a predictorProducer that throws. Check that the model still returns a value -// // and that the error message is incorporated appropriately. -// -// fail() -// } -// -// @Test def testSubvalueSuccess(): Unit = { -// // Test the happy path by calling model.apply. Check the value, missing data, and error messages. -// -// fail() -// } -// -// @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { -// // NOTE: This is by design. -// -// fail() -// } } object MultilabelModelTest { From daffd68b1daa5d9839aa582f0da17a244a65997a Mon Sep 17 00:00:00 2001 From: amirziai Date: Tue, 19 Sep 2017 15:08:50 -0700 Subject: [PATCH 46/98] First pass over all tests --- .../multilabel/MultilabelModelTest.scala | 146 ++++++++++++------ 1 file changed, 101 insertions(+), 45 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 29232f5f..66d50261 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -1,12 +1,11 @@ package com.eharmony.aloha.models.multilabel -import java.io.{PrintWriter, StringWriter} +import java.io.{Closeable, PrintWriter, StringWriter} import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId -import com.eharmony.aloha.models.Model import com.eharmony.aloha.semantics.SemanticsUdfException import com.eharmony.aloha.semantics.func._ import org.junit.Test @@ -32,47 +31,105 @@ class MultilabelModelTest extends ModelSerializationTestHelper { assertEquals(modelNoFeatures, modelRoundTrip) } -// @Test def testModelCloseClosesPredictor(): Unit = { -// // Make the predictorProducer passed to the constructor be a -// // 'SparsePredictorProducer[K] with Closeable'. -// // predictorProducer should track whether it is closed (using an AtomicBoolean or something). -// // Call close on the MultilabelModel instance and ensure that the underlying predictor is -// // also closed. -// -// fail() -// } -// -// @Test def testLabelsOfInterestOmitted(): Unit = { -// // Test labelsAndInfo[A, K] function. -// // -// // When labelsOfInterest = None, labelsAndInfo should return: -// // LabelsAndInfo[K]( -// // indices = labelsInTrainingSet.indices, -// // labels = labelsInTrainingSet, -// // missingLabels = Seq.empty[K], -// // problems = None -// // ) -// -// fail() -// } -// -// @Test def testLabelsOfInterestProvided(): Unit = { -// // Test labelsAndInfo[A, K] function. -// // -// // labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) == -// // labelsForPrediction(a, labelsOfInterest.get, labelToInd) -// -// fail() -// } -// -// @Test def testReportTooManyMissing(): Unit = { -// // Make sure Subvalue.natural == None -// // Check the values of Subvalue.audited and make sure they are as expected. -// // Subvalue.audited.value should be None. Check the errors and missing values. -// -// fail() -// } + @Test def testModelCloseClosesPredictor(): Unit = { + // Make the predictorProducer passed to the constructor be a + // 'SparsePredictorProducer[K] with Closeable'. + // predictorProducer should track whether it is closed (using an AtomicBoolean or something). + // Call close on the MultilabelModel instance and ensure that the underlying predictor is + // also closed. + + case class ConstantPredictorClosable[K](prediction: Double = 0d) extends + SparseMultiLabelPredictor[K] with Closeable { + override def apply(v1: SparseFeatures, + v2: Labels[K], + v3: LabelIndices, + v4: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(v2.map(_ -> prediction).toMap) + + def close(): Unit = println("closing") + } + + val model = MultilabelModel( + ModelId(), + sci.IndexedSeq(), + sci.IndexedSeq[GenAggFunc[Int, Sparse]](), + sci.IndexedSeq[Label](), + None, + Lazy(ConstantPredictorClosable[Label]()), + None, + Auditor + ) + + // TODO: How do I test this? + model.close() + model.predictorProducer.apply() + + fail() + } + + @Test def testLabelsOfInterestOmitted(): Unit = { + // Test labelsAndInfo[A, K] function. + // + // When labelsOfInterest = None, labelsAndInfo should return: + // LabelsAndInfo[K]( + // indices = labelsInTrainingSet.indices, + // labels = labelsInTrainingSet, + // missingLabels = Seq.empty[K], + // problems = None + // ) + + val indices = sci.IndexedSeq[Int](1, 2, 3) + val labels = sci.IndexedSeq[Label]("label1", "label2", "label3") + val labelInfo = LabelsAndInfo( + indices = indices, + labels = labels, + missingLabels = Seq[Label](), + problems = None + ) + + val actual: LabelsAndInfo[String] = labelsAndInfo[Map[String, String], String]( + Map("label1" -> "label1"), + None, + Map("label1" -> 1), + labelInfo + ) + + assertEquals(indices, actual.indices) + assertEquals(labels, actual.labels) + assertEquals(Seq.empty[String], actual.missingLabels) + assertEquals(None, actual.problems) + } + + @Test def testLabelsOfInterestProvided(): Unit = { + // Test labelsAndInfo[A, K] function. + // + // labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) == + // labelsForPrediction(a, labelsOfInterest.get, labelToInd) + + // TODO: has the signature changed??? still appropriate? + } // + @Test def testReportTooManyMissing(): Unit = { + // Make sure Subvalue.natural == None + // Check the values of Subvalue.audited and make sure they are as expected. + // Subvalue.audited.value should be None. Check the errors and missing values. + + val labelInfo = LabelsAndInfo( + indices = sci.IndexedSeq[Int](), + labels = sci.IndexedSeq[Label](), + missingLabels = Seq[Label]("a"), + problems = None + ) + + val report = reportNoPrediction( + ModelId(1, "a"), + labelInfo, + Auditor + ) + + assertEquals(Vector(NoLabelsError), report.audited.errorMsgs.take(1)) + assertEquals(None, report.audited.value) + } + @Test def testReportNoPrediction(): Unit = { // Make sure Subvalue.natural == None // Check the values of Subvalue.audited and make sure they are as expected. @@ -92,7 +149,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { ) // TODO: check labelInfo values - assertEquals(Vector(NoLabelsError), report.audited.errorMsgs.take(1)) + assertEquals(None, report.natural) assertEquals(None, report.audited.value) } @@ -114,9 +171,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { Auditor ) - // TODO: look into problems assertEquals(Vector(NoLabelsError) ++ errorMessages, report.audited.errorMsgs) assertEquals(None, report.audited.value) + assertEquals(Set(), report.audited.missingVarNames) } @@ -163,7 +220,6 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // 'value' should equal 'value2'. // Check the errors and missing values. - // TODO: refactor this? val labelInfo = LabelsAndInfo( indices = sci.IndexedSeq[Int](), labels = sci.IndexedSeq[Label](), From 8280ef75b47c58ca499542b661d1049f9678d6b9 Mon Sep 17 00:00:00 2001 From: amirziai Date: Thu, 21 Sep 2017 16:59:51 -0700 Subject: [PATCH 47/98] Simplified some of the tests --- .../semantics/SemanticsUdfException.scala | 4 +- .../multilabel/MultilabelModelTest.scala | 217 ++++++++++-------- 2 files changed, 119 insertions(+), 102 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala b/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala index d869b1fa..b7d2a7f4 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala @@ -26,7 +26,7 @@ case class SemanticsUdfException[+A]( input: A) extends AlohaException(SemanticsUdfException.getMessage(specification, accessorOutput, accessorsMissingOutput, accessorsInErr, input), cause) -private object SemanticsUdfException { +object SemanticsUdfException { /** This method guards against throwing exceptions. * @param specification a specification for the feature that produced an error. @@ -38,7 +38,7 @@ private object SemanticsUdfException { * @tparam A type of input * @return */ - def getMessage[A]( + private def getMessage[A]( specification: String, accessorOutput: Map[String, Try[Any]], accessorsMissingOutput: List[String], diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 66d50261..09f6e963 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -1,6 +1,7 @@ package com.eharmony.aloha.models.multilabel import java.io.{Closeable, PrintWriter, StringWriter} +import java.util.concurrent.atomic.AtomicBoolean import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor @@ -14,7 +15,7 @@ import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner import scala.collection.{immutable => sci, mutable => scm} -import scala.util.Try +import scala.util.{Failure, Random, Success, Try} /** * Created by ryan.deak on 9/1/17. @@ -45,25 +46,26 @@ class MultilabelModelTest extends ModelSerializationTestHelper { v3: LabelIndices, v4: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(v2.map(_ -> prediction).toMap) - def close(): Unit = println("closing") + private[this] val closed = new AtomicBoolean(false) + override def close(): Unit = closed.set(true) + def isClosed: Boolean = closed.get() } + val pred = ConstantPredictorClosable[Label]() val model = MultilabelModel( ModelId(), sci.IndexedSeq(), sci.IndexedSeq[GenAggFunc[Int, Sparse]](), sci.IndexedSeq[Label](), None, - Lazy(ConstantPredictorClosable[Label]()), + Lazy(pred), None, Auditor ) // TODO: How do I test this? model.close() - model.predictorProducer.apply() - - fail() + assertTrue(pred.isClosed) } @Test def testLabelsOfInterestOmitted(): Unit = { @@ -86,17 +88,14 @@ class MultilabelModelTest extends ModelSerializationTestHelper { problems = None ) - val actual: LabelsAndInfo[String] = labelsAndInfo[Map[String, String], String]( - Map("label1" -> "label1"), - None, - Map("label1" -> 1), + val actual: LabelsAndInfo[Label] = labelsAndInfo[Unit, Label]( + a = (), + labelsOfInterest = None, + Map.empty, labelInfo ) - assertEquals(indices, actual.indices) - assertEquals(labels, actual.labels) - assertEquals(Seq.empty[String], actual.missingLabels) - assertEquals(None, actual.problems) + assertEquals(labelInfo, actual) } @Test def testLabelsOfInterestProvided(): Unit = { @@ -105,7 +104,27 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) == // labelsForPrediction(a, labelsOfInterest.get, labelToInd) - // TODO: has the signature changed??? still appropriate? + val indices = sci.IndexedSeq[Int](1, 2, 3) + val labels = sci.IndexedSeq[Label]("label1", "label2", "label3") + + val labelInfo = LabelsAndInfo( + indices = indices, + labels = labels, + missingLabels = Seq[Label](), + problems = None + ) + + val labelsOfInterest = Some(GenFunc0[Unit, sci.IndexedSeq[Label]]("", _ => labels)) + val actual: LabelsAndInfo[Label] = labelsAndInfo[Unit, Label]( + a = (), + labelsOfInterest = labelsOfInterest, + Map.empty, + labelInfo + ) + + val expected = labelsForPrediction((), labelsOfInterest.get, Map.empty[Label, Int]) + + assertEquals(expected, actual) } // @Test def testReportTooManyMissing(): Unit = { @@ -120,6 +139,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { problems = None ) + // TODO: reportTooManyMissing val report = reportNoPrediction( ModelId(1, "a"), labelInfo, @@ -128,6 +148,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { assertEquals(Vector(NoLabelsError), report.audited.errorMsgs.take(1)) assertEquals(None, report.audited.value) + assertEquals(None, report.natural) + + fail() } @Test def testReportNoPrediction(): Unit = { @@ -148,7 +171,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { Auditor ) - // TODO: check labelInfo values + assertEquals(Vector(NoLabelsError), report.audited.errorMsgs) assertEquals(None, report.natural) assertEquals(None, report.audited.value) } @@ -158,6 +181,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // Check the values of Subvalue.audited and make sure they are as expected. // Subvalue.audited.value should be None. Check the errors and missing values. + // The missing labels are reported in the error message + val labelInfo = LabelsAndInfo( indices = sci.IndexedSeq[Int](), labels = sci.IndexedSeq[Label](), @@ -245,6 +270,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { def extractLabelsOutOfExample(example: Map[String, String]): sci.IndexedSeq[String] = example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq + // TODO: break this up + // Example with no problems val example: Map[String, String] = Map( "feature1" -> "1", @@ -273,6 +300,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { exampleNoLabels, labelsOfInterestExtractor, labelToInt) + // No problems are actually found + // But the important thing is that problemsNoLabelsExpected is a Some and not a None val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq())) assertEquals(problemsNoLabelsExpected, labelsAndInfoNoLabels.problems) @@ -288,8 +317,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { case _ => None } + val featureWithMissingValue = "feature not present" val f1 = - GenFunc.f1(GeneratedAccessor("labels", labelsExtractor, None))( + GenFunc.f1(GeneratedAccessor(featureWithMissingValue, labelsExtractor, None))( "def omitted", _ getOrElse sci.IndexedSeq.empty[Label] ) @@ -297,37 +327,27 @@ class MultilabelModelTest extends ModelSerializationTestHelper { Map[String, String](), f1, labelToInt) - val problemsNoLabelsExpectedGen1 = Option(GenAggFuncAccessorProblems(Seq("labels"), Seq())) + val problemsNoLabelsExpectedGen1 = Option(GenAggFuncAccessorProblems(Seq(featureWithMissingValue), Seq())) assertEquals(problemsNoLabelsExpectedGen1, labelsAndInfoNoLabelsGen1.problems) // error - val labelsExtractorError: (Map[String, String]) => Option[sci.IndexedSeq[Label]] = - (m: Map[String, String]) => m.get("labels") match { - case ls: sci.IndexedSeq[_] if ls.forall { x: String => x.isInstanceOf[Label] } => - Option(ls.asInstanceOf[sci.IndexedSeq[Label]]) - case _ => throw new Exception("labels does not exist") - } - + val featureWithError = "feature has error" val f2 = - GenFunc.f1(GeneratedAccessor("labels", - ((m: Map[String, String]) => throw new Exception("errmsg")) : Map[String, String] => - Option[sci.IndexedSeq[String]] - , None))( + GenFunc.f1(GeneratedAccessor(featureWithError, + ( _ => throw new Exception("errmsg")) : Map[String, String] => Option[sci.IndexedSeq[String]], None))( "def omitted", _ getOrElse sci.IndexedSeq.empty[Label] ) val f2Wrapped = EnrichedErrorGenAggFunc(f2) - new SemanticsUdfException[Any](null, null, null, null, null, null) - - val problemsNoLabelsExpectedGen2 = Option(GenAggFuncAccessorProblems(Seq(), Seq("labels"))) + val problemsNoLabelsExpectedGen2 = Option(GenAggFuncAccessorProblems(Seq(), Seq(featureWithError))) Try( labelsForPrediction( Map[String, String](), f2Wrapped, labelToInt) ).failed.get match { - case ex: SemanticsUdfException[_] => - assertEquals(ex.accessorsInErr, problemsNoLabelsExpectedGen2.get.errors) + case SemanticsUdfException(_, _, _, accessorsInErr, _, _) => assertEquals(accessorsInErr, + problemsNoLabelsExpectedGen2.get.errors) } } @@ -365,6 +385,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val labelsAndInfo3 = labelsForPrediction(example3, labelsOfInterestExtractor, labelToInt) val missingLabels3 = labelsAndInfo3.missingLabels assertEquals(Seq(), missingLabels3) + + // TODO: add these as a different test at the model.apply(.) level + // assertEquals(None, modelNew(Vector("1"))) } @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { @@ -372,7 +395,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip def extractLabelsOutOfExample(example: Map[String, String]) = - example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.toIndexedSeq // sorted.toIndexedSeq val example: Map[String, String] = Map( "feature1" -> "1", @@ -388,30 +411,24 @@ class MultilabelModelTest extends ModelSerializationTestHelper { "label8" -> "l1" ) - val allLabels = sci.IndexedSeq("a", "b", "c", "235", "1", "l1", "l23", "100") - val labelToInt = allLabels.zipWithIndex.toMap - val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) - val labelsAndInfo = labelsForPrediction(example, labelsOfInterestExtractor, labelToInt) + val allLabels = extractLabelsOutOfExample(example).sorted + val labelToInt: Map[String, Int] = allLabels.zipWithIndex.toMap - assertEquals(allLabels, labelsAndInfo.labels) - assertEquals(allLabels.indices, labelsAndInfo.indices) + val rng = new Random(seed=0) + (1 to 10).foreach { _ => + val ex = rng.shuffle(example.toVector).take(rng.nextInt(example.size)).toMap + val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + val labelsAndInfo = labelsForPrediction(ex, labelsOfInterestExtractor, labelToInt) + assertEquals(labelsAndInfo.indices.sorted, labelsAndInfo.indices) + } } @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { // Test this: // if (li.labels.isEmpty) // reportNoPrediction(modelId, li, auditor) - def extractLabelsOutOfExample(example: Map[String, String]) = - example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq - val example = Map("" -> "") - val allLabels = sci.IndexedSeq("a", "b", "c") - val labelToInd = allLabels.zipWithIndex.toMap - val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) - val defaultLabelInfo = LabelsAndInfo(allLabels.indices, allLabels, Seq.empty, None) - val li = labelsAndInfo(example, Option(labelsOfInterestExtractor), labelToInd, defaultLabelInfo) - val report = reportNoPrediction(ModelId(), li, aud) - assertEquals(Vector(NoLabelsError), report.audited.errorMsgs) + assertEquals(None, modelNew.subvalue(Vector.empty).natural) } @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { @@ -419,22 +436,16 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // called and its value should be returned. Instantiate a MultilabelModel and // call apply with some missing data required by the features. - // TODO: this is a terrible setup val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = - GenFunc0("""Iterable(("", 1d))""", - (x: Map[String, String]) => if (x.keys.toVector.contains("a")) { - Iterable((x.keys.head, 1d)) - } else { - Iterable() - }) + GenFunc0("", _ => Iterable()) - val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) + val featureFunctions = Vector(EmptyIndicatorFn) val modelWithThreshold = MultilabelModel( modelId = ModelId(1, "model1"), - featureNames = sci.IndexedSeq("a", "b", "c", "d"), + featureNames = sci.IndexedSeq("a"), featureFunctions = featureFunctions, - labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2"), labelsOfInterest = None, predictorProducer = Lazy(ConstantPredictor[Label]()), numMissingThreshold = Option(0), @@ -449,90 +460,84 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // Create a predictorProducer that throws. Check that the model still returns a value // and that the error message is incorporated appropriately. - val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = - GenFunc0("""Iterable(("", 1d))""", - (x: Map[String, String]) => if (x.keys.toVector.contains("a")) { - Iterable((x.keys.head, 1d)) - } else { - Iterable() - }) - - val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) - - case class PredictorThatThrows[K](prediction: Double = 0d) extends - SparseMultiLabelPredictor[K] { + case object PredictorThatThrows extends + SparseMultiLabelPredictor[Label] { override def apply(v1: SparseFeatures, - v2: Labels[K], + v2: Labels[Label], v3: LabelIndices, - v4: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(throw new Exception("error")) + v4: SparseLabelDepFeatures): Try[Map[Label, Double]] = Try(throw new Exception("error")) } + val modelWithThrowingPredictorProducer = MultilabelModel( modelId = ModelId(1, "model1"), - featureNames = sci.IndexedSeq("a", "b", "c", "d"), - featureFunctions = featureFunctions, - labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + featureNames = sci.IndexedSeq.empty, + featureFunctions = sci.IndexedSeq.empty, + labelsInTrainingSet = sci.IndexedSeq(""), // we need at least 1 label to get the error labelsOfInterest = None, - predictorProducer = Lazy(PredictorThatThrows[Label]()), + predictorProducer = Lazy(PredictorThatThrows), numMissingThreshold = None, auditor = Auditor ) + val result = modelWithThrowingPredictorProducer(Map()) assertEquals(None, result.value) - assert(result.errorMsgs.head.contains("java.lang.Exception: error")) + assertEquals("java.lang.Exception: error", result.errorMsgs.head.split("\n").head) } @Test def testSubvalueSuccess(): Unit = { // Test the happy path by calling model.apply. Check the value, missing data, and error // messages. - val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = - GenFunc0("""Iterable(("", 1d))""", (_: Any) => Iterable(("", 1d))) - - val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) - def extractLabelsOutOfExample(example: Map[String, String]) = example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + val scoreToReturn = 5d + val modelSuccess = MultilabelModel( modelId = ModelId(1, "model1"), - featureNames = sci.IndexedSeq("a", "b", "c", "d"), - featureFunctions = featureFunctions, + featureNames = sci.IndexedSeq.empty, + featureFunctions = sci.IndexedSeq.empty, labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), labelsOfInterest = Option(labelsOfInterestExtractor), - predictorProducer = Lazy(ConstantPredictor[Label]()), + predictorProducer = Lazy(ConstantPredictor[Label](scoreToReturn)), numMissingThreshold = None, auditor = Auditor ) val result = modelSuccess(Map("a" -> "b", "label1" -> "label1", "label2" -> "label2")) assertEquals(Vector(), result.errorMsgs) assertEquals(Set(), result.missingVarNames) - assertEquals(Option(Map("label1" -> 0.0, "label2" -> 0.0)), result.value) + assertEquals(Option(Map("label1" -> scoreToReturn, "label2" -> scoreToReturn)), result.value) } @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { // NOTE: This is by design. - // TODO: this actually throws + val exception = new Exception("error") val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = - GenFunc0("""Iterable(("", 1d))""", (x: Map[String, String]) => Iterable((x("hello"), 1d))) + GenFunc0("", _ => throw exception) - val featureFunctions = Vector.fill(4)(EmptyIndicatorFn) + val featureFunctions = Vector(EmptyIndicatorFn) val modelSuccess = MultilabelModel( modelId = ModelId(1, "model1"), - featureNames = sci.IndexedSeq("a", "b", "c", "d"), + featureNames = sci.IndexedSeq("throwing feature"), featureFunctions = featureFunctions, - labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + labelsInTrainingSet = sci.IndexedSeq[Label](""), labelsOfInterest = None, predictorProducer = Lazy(ConstantPredictor[Label]()), numMissingThreshold = None, auditor = Auditor ) - val result = modelSuccess(Map("a" -> "b", "label1" -> "label1", "label2" -> "label2")) - println(result.value) + modelSuccess.copy(modelId=ModelId(1, "b")) + + val result = Try(modelSuccess(Map())) + result match { + case Success(_) => fail() + case Failure(ex) => assertEquals(exception, ex) + } } } @@ -543,10 +548,10 @@ object MultilabelModelTest { private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() case class ConstantPredictor[K](prediction: Double = 0d) extends SparseMultiLabelPredictor[K] { - override def apply(v1: SparseFeatures, - v2: Labels[K], - v3: LabelIndices, - v4: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(v2.map(_ -> prediction).toMap) + override def apply(featuresUnused: SparseFeatures, + labels: Labels[K], + indicesUnused: LabelIndices, + ldfUnused: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(labels.map(_ -> prediction).toMap) } case class Lazy[A](value: A) extends (() => A) { @@ -566,6 +571,18 @@ object MultilabelModelTest { Auditor ) + val modelNew = + MultilabelModel( + ModelId(), + sci.IndexedSeq(), + sci.IndexedSeq[GenAggFunc[Vector[String], Sparse]](), + sci.IndexedSeq[Label]("a", "b", "c"), + Some(GenFunc0("", (a: Vector[String]) => a)), + Lazy(ConstantPredictor[Label]()), + None, + Auditor + ) + val aud: RootedTreeAuditor[Any, Map[Label, Double]] = RootedTreeAuditor[Any, Map[Label, Double]]() // private val failure = aud.failure() From 80fc41cbf90672ef35e60d7c084e26243c924ced Mon Sep 17 00:00:00 2001 From: amirziai Date: Thu, 21 Sep 2017 23:55:21 -0700 Subject: [PATCH 48/98] Refactoring --- .../multilabel/MultilabelModelTest.scala | 58 +++++++------------ 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 09f6e963..6a244156 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -28,6 +28,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { @Test def testSerialization(): Unit = { // Assuming all parameters passed to the MultilabelModel constructor are // Serializable, MultilabelModel should also be Serializable. + val modelRoundTrip = serializeDeserializeRoundTrip(modelNoFeatures) assertEquals(modelNoFeatures, modelRoundTrip) } @@ -39,33 +40,23 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // Call close on the MultilabelModel instance and ensure that the underlying predictor is // also closed. - case class ConstantPredictorClosable[K](prediction: Double = 0d) extends - SparseMultiLabelPredictor[K] with Closeable { - override def apply(v1: SparseFeatures, + case class PredictorClosable[K](prediction: Double = 0d) + extends SparseMultiLabelPredictor[K] with Closeable { + def apply( + v1: SparseFeatures, v2: Labels[K], v3: LabelIndices, - v4: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(v2.map(_ -> prediction).toMap) - + v4: SparseLabelDepFeatures) = Try(Map()) private[this] val closed = new AtomicBoolean(false) override def close(): Unit = closed.set(true) def isClosed: Boolean = closed.get() } - val pred = ConstantPredictorClosable[Label]() - val model = MultilabelModel( - ModelId(), - sci.IndexedSeq(), - sci.IndexedSeq[GenAggFunc[Int, Sparse]](), - sci.IndexedSeq[Label](), - None, - Lazy(pred), - None, - Auditor - ) + val predictor = PredictorClosable[Label]() + val model = modelNoFeatures.copy(predictorProducer = Lazy(predictor)) - // TODO: How do I test this? model.close() - assertTrue(pred.isClosed) + assertTrue(predictor.isClosed) } @Test def testLabelsOfInterestOmitted(): Unit = { @@ -531,7 +522,6 @@ class MultilabelModelTest extends ModelSerializationTestHelper { auditor = Auditor ) - modelSuccess.copy(modelId=ModelId(1, "b")) val result = Try(modelSuccess(Map())) result match { @@ -561,26 +551,20 @@ object MultilabelModelTest { val missingLabels: Seq[Label] = Seq("a", "b") val modelNoFeatures = MultilabelModel( - ModelId(), - sci.IndexedSeq(), - sci.IndexedSeq[GenAggFunc[Int, Sparse]](), - sci.IndexedSeq[Label](), - None, - Lazy(ConstantPredictor[Label]()), - None, - Auditor + modelId = ModelId(), + featureNames = sci.IndexedSeq(), + featureFunctions = sci.IndexedSeq[GenAggFunc[Int, Sparse]](), + labelsInTrainingSet = sci.IndexedSeq[Label](), + labelsOfInterest = None, + predictorProducer = Lazy(ConstantPredictor[Label]()), + numMissingThreshold = None, + auditor = Auditor ) - val modelNew = - MultilabelModel( - ModelId(), - sci.IndexedSeq(), - sci.IndexedSeq[GenAggFunc[Vector[String], Sparse]](), - sci.IndexedSeq[Label]("a", "b", "c"), - Some(GenFunc0("", (a: Vector[String]) => a)), - Lazy(ConstantPredictor[Label]()), - None, - Auditor + val modelNew = modelNoFeatures.copy( + featureFunctions = sci.IndexedSeq[GenAggFunc[Vector[String], Sparse]](), + labelsInTrainingSet = sci.IndexedSeq[Label]("a", "b", "c"), + labelsOfInterest = Some(GenFunc0("", (a: Vector[String]) => a)) ) val aud: RootedTreeAuditor[Any, Map[Label, Double]] = RootedTreeAuditor[Any, Map[Label, Double]]() From a18a827a9faebbd3d82eafec7c22fed6cb1a4c02 Mon Sep 17 00:00:00 2001 From: amirziai Date: Fri, 22 Sep 2017 11:14:21 -0700 Subject: [PATCH 49/98] Refactoring common patterns into the companion object --- .../multilabel/MultilabelModelTest.scala | 209 ++++++------------ 1 file changed, 63 insertions(+), 146 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 6a244156..69815249 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -7,6 +7,7 @@ import com.eharmony.aloha.ModelSerializationTestHelper import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.models.multilabel.MultilabelModel.{reportNoPrediction, LabelsAndInfo, NumLinesToKeepInStackTrace} import com.eharmony.aloha.semantics.SemanticsUdfException import com.eharmony.aloha.semantics.func._ import org.junit.Test @@ -18,7 +19,7 @@ import scala.collection.{immutable => sci, mutable => scm} import scala.util.{Failure, Random, Success, Try} /** - * Created by ryan.deak on 9/1/17. + * Created by ryan.deak and amirziai on 9/1/17. */ @RunWith(classOf[BlockJUnit4ClassRunner]) class MultilabelModelTest extends ModelSerializationTestHelper { @@ -28,18 +29,11 @@ class MultilabelModelTest extends ModelSerializationTestHelper { @Test def testSerialization(): Unit = { // Assuming all parameters passed to the MultilabelModel constructor are // Serializable, MultilabelModel should also be Serializable. - val modelRoundTrip = serializeDeserializeRoundTrip(modelNoFeatures) assertEquals(modelNoFeatures, modelRoundTrip) } @Test def testModelCloseClosesPredictor(): Unit = { - // Make the predictorProducer passed to the constructor be a - // 'SparsePredictorProducer[K] with Closeable'. - // predictorProducer should track whether it is closed (using an AtomicBoolean or something). - // Call close on the MultilabelModel instance and ensure that the underlying predictor is - // also closed. - case class PredictorClosable[K](prediction: Double = 0d) extends SparseMultiLabelPredictor[K] with Closeable { def apply( @@ -60,8 +54,6 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testLabelsOfInterestOmitted(): Unit = { - // Test labelsAndInfo[A, K] function. - // // When labelsOfInterest = None, labelsAndInfo should return: // LabelsAndInfo[K]( // indices = labelsInTrainingSet.indices, @@ -70,58 +62,31 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // problems = None // ) - val indices = sci.IndexedSeq[Int](1, 2, 3) - val labels = sci.IndexedSeq[Label]("label1", "label2", "label3") - val labelInfo = LabelsAndInfo( - indices = indices, - labels = labels, - missingLabels = Seq[Label](), - problems = None - ) - - val actual: LabelsAndInfo[Label] = labelsAndInfo[Unit, Label]( - a = (), + val actual: LabelsAndInfo[Label] = labelsAndInfo( + a = (), labelsOfInterest = None, - Map.empty, - labelInfo + labelToInd = Map.empty, + defaultLabelInfo = labelsAndInfoEmpty ) - - assertEquals(labelInfo, actual) + assertEquals(labelsAndInfoEmpty, actual) } @Test def testLabelsOfInterestProvided(): Unit = { - // Test labelsAndInfo[A, K] function. - // - // labelsAndInfo(a, labelsInTrainingSet, labelsOfInterest, labelToInd) == - // labelsForPrediction(a, labelsOfInterest.get, labelToInd) - - val indices = sci.IndexedSeq[Int](1, 2, 3) - val labels = sci.IndexedSeq[Label]("label1", "label2", "label3") - - val labelInfo = LabelsAndInfo( - indices = indices, - labels = labels, - missingLabels = Seq[Label](), - problems = None - ) - - val labelsOfInterest = Some(GenFunc0[Unit, sci.IndexedSeq[Label]]("", _ => labels)) - val actual: LabelsAndInfo[Label] = labelsAndInfo[Unit, Label]( - a = (), + val a = () + val labelsOfInterest = + Option(GenFunc0[Unit,sci.IndexedSeq[Label]]("", _ => labelsInTrainingSet)) + val actual: LabelsAndInfo[Label] = labelsAndInfo( + a = a, labelsOfInterest = labelsOfInterest, Map.empty, - labelInfo + labelsAndInfoEmpty ) - - val expected = labelsForPrediction((), labelsOfInterest.get, Map.empty[Label, Int]) - + val expected = labelsForPrediction(a, labelsOfInterest.get, Map.empty[Label, Int]) assertEquals(expected, actual) } -// + @Test def testReportTooManyMissing(): Unit = { - // Make sure Subvalue.natural == None - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be None. Check the errors and missing values. + // labelsAndInfoEmpty.copy(missingLabels = Seq[Label]("a")), val labelInfo = LabelsAndInfo( indices = sci.IndexedSeq[Int](), @@ -145,84 +110,31 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testReportNoPrediction(): Unit = { - // Make sure Subvalue.natural == None - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be None. Check the errors and missing values. - - val labelInfo = LabelsAndInfo( - indices = sci.IndexedSeq[Int](), - labels = sci.IndexedSeq[Label](), - missingLabels = Seq[Label](), - problems = None - ) - - val report = reportNoPrediction( - ModelId(1, "a"), - labelInfo, - Auditor - ) - + val report = reportNoPredictionEmpty assertEquals(Vector(NoLabelsError), report.audited.errorMsgs) assertEquals(None, report.natural) assertEquals(None, report.audited.value) } @Test def testReportNoPredictionMissingLabelsDoNotExist(): Unit = { - // Make sure Subvalue.natural == None - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be None. Check the errors and missing values. - // The missing labels are reported in the error message - val labelInfo = LabelsAndInfo( - indices = sci.IndexedSeq[Int](), - labels = sci.IndexedSeq[Label](), - missingLabels = missingLabels, - problems = None - ) - - val report = reportNoPrediction( - ModelId(1, "a"), - labelInfo, - Auditor - ) - + val report = reportNoPredictionPartial(labelsAndInfoMissingLabels) assertEquals(Vector(NoLabelsError) ++ errorMessages, report.audited.errorMsgs) - assertEquals(None, report.audited.value) + assertEquals(None, reportNoPredictionEmpty.audited.value) assertEquals(Set(), report.audited.missingVarNames) } - @Test def testReportPredictorError(): Unit = { - // Make sure Subvalue.natural == None - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be None. Check the errors and missing values. - - val labelInfo = LabelsAndInfo( - indices = sci.IndexedSeq[Int](), - labels = sci.IndexedSeq[Label](), - missingLabels = missingLabels, - problems = None - ) - - val throwable = Try(throw new Exception("error")).failed.get - val sw = new StringWriter - val pw = new PrintWriter(sw) - throwable.printStackTrace(pw) - val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") - - // This is missing variables for a features + val (throwable, stackTrace) = getThrowable("error") val missingVariables = Seq("a", "b") - val missingFeatureMap = scm.Map("x" -> missingVariables) - val report = reportPredictorError( - ModelId(-1, "x"), - labelInfo, - missingFeatureMap, + ModelId(), + labelsAndInfoEmpty.copy(missingLabels = missingLabels), + scm.Map("" -> missingVariables), throwable, Auditor ) - assertEquals(Vector(stackTrace) ++ errorMessages, report.audited.errorMsgs) assertEquals(missingVariables.toSet, report.audited.missingVarNames) assertEquals(None, report.natural) @@ -230,28 +142,14 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testReportSuccess(): Unit = { - // Make sure Subvalue.natural == Some(value) - // Check the values of Subvalue.audited and make sure they are as expected. - // Subvalue.audited.value should be Some(value2). - // 'value' should equal 'value2'. - // Check the errors and missing values. - - val labelInfo = LabelsAndInfo( - indices = sci.IndexedSeq[Int](), - labels = sci.IndexedSeq[Label](), - missingLabels = missingLabels, - problems = None - ) - val predictions = Map("label1" -> 1.0) - - val report = reportSuccess( - ModelId(0, "ModelId"), - labelInfo, - scm.Map("x" -> missingLabels), - predictions, - Auditor + val predictions = Map("label" -> 1.0) + val report = reportSuccess( + modelId = ModelId(), + labelInfo = labelsAndInfoEmpty.copy(missingLabels = missingLabels), + missing = scm.Map("" -> missingLabels), + prediction = predictions, + auditor = Auditor ) - assertEquals(Some(predictions), report.natural) assertEquals(Some(predictions), report.audited.value) assertEquals(report.natural, report.audited.value) @@ -532,8 +430,6 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } object MultilabelModelTest { - // TODO: Use this label type and Auditor. - private type Label = String private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() @@ -541,15 +437,17 @@ object MultilabelModelTest { override def apply(featuresUnused: SparseFeatures, labels: Labels[K], indicesUnused: LabelIndices, - ldfUnused: SparseLabelDepFeatures): Try[Map[K, Double]] = Try(labels.map(_ -> prediction).toMap) + ldfUnused: SparseLabelDepFeatures): Try[Map[K, Double]] = + Try(labels.map(_ -> prediction).toMap) } case class Lazy[A](value: A) extends (() => A) { override def apply(): A = value } - val missingLabels: Seq[Label] = Seq("a", "b") + val labelsInTrainingSet = sci.IndexedSeq[Label]("a", "b", "c") + // Models val modelNoFeatures = MultilabelModel( modelId = ModelId(), featureNames = sci.IndexedSeq(), @@ -563,23 +461,42 @@ object MultilabelModelTest { val modelNew = modelNoFeatures.copy( featureFunctions = sci.IndexedSeq[GenAggFunc[Vector[String], Sparse]](), - labelsInTrainingSet = sci.IndexedSeq[Label]("a", "b", "c"), + labelsInTrainingSet = labelsInTrainingSet, labelsOfInterest = Some(GenFunc0("", (a: Vector[String]) => a)) ) - val aud: RootedTreeAuditor[Any, Map[Label, Double]] = RootedTreeAuditor[Any, Map[Label, Double]]() - // private val failure = aud.failure() + // LabelsAndInfo + val labelsAndInfoEmpty = LabelsAndInfo( + indices = sci.IndexedSeq.empty, + labels = sci.IndexedSeq.empty, + missingLabels = Seq[Label](), + problems = None + ) + val labelsAndInfoMissingLabels = labelsAndInfoEmpty.copy(missingLabels = missingLabels) + + // Reports + val reportNoPredictionPartial = reportNoPrediction( + modelId = ModelId(), + _: LabelsAndInfo[Label], + auditor = Auditor + ) + val reportNoPredictionEmpty = reportNoPredictionPartial(labelsAndInfoEmpty) + + val auditor: RootedTreeAuditor[Any, Map[Label, Double]] = + RootedTreeAuditor[Any, Map[Label, Double]]() val baseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") + val missingLabels: Seq[Label] = Seq("a", "b") val errorMessages: Seq[String] = baseErrorMessage.zip(missingLabels).map { case(msg, label) => s"$msg$label" } -// TODO: Access information returned in audited value by using the following functions: - // val aud: RootedTree[Any, Map[Label, Double]] = ??? - // aud.modelId // : ModelIdentity - // aud.value // : Option[Map[Label, Double]] // Should be missing on failure. - // aud.missingVarNames // : Set[String] - // aud.errorMsgs // : Seq[String] - // aud.prob // : Option[Float] (Shouldn't need this) + def getThrowable(errorMessage: String): (Throwable, String) = { + val throwable = Try(throw new Exception(errorMessage)).failed.get + val sw = new StringWriter + val pw = new PrintWriter(sw) + throwable.printStackTrace(pw) + val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + (throwable, stackTrace) + } } From 406b1f93350681943c129edf257933d16f9f6d8f Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 22 Sep 2017 14:28:56 -0700 Subject: [PATCH 50/98] committing VwMultilabelRowCreator and updating other stuff to use it. --- .../multilabel/VwMultilabelRowCreator.scala | 340 ++++++++++++++++++ .../multilabel/json/VwMultilabeledJson.scala | 23 ++ .../models/multilabel/MultilabelModel.scala | 8 +- .../VwSparseMultilabelPredictor.scala | 45 +-- .../VwSparseMultilabelPredictorProducer.scala | 2 + 5 files changed, 373 insertions(+), 45 deletions(-) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwMultilabeledJson.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala new file mode 100644 index 00000000..369b4a73 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -0,0 +1,340 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.AlohaException +import com.eharmony.aloha.dataset._ +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.dataset.vw.VwCovariateProducer +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson +import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import com.eharmony.aloha.semantics.func.GenAggFunc +import spray.json.JsValue + +import scala.collection.{breakOut, immutable => sci} +import scala.util.Try + +/** + * Created by ryan.deak on 9/13/17. + */ +final case class VwMultilabelRowCreator[-A, K]( + allLabelsInTrainingSet: sci.IndexedSeq[K], + featuresFunction: FeatureExtractorFunction[A, Sparse], + defaultNamespace: List[Int], + namespaces: List[(String, List[Int])], + normalizer: Option[CharSequence => CharSequence], + positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], + includeZeroValues: Boolean = false +) extends RowCreator[A, Array[String]] { + import VwMultilabelRowCreator._ + + @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap + + @transient private[this] lazy val nss = + determineLabelNamespaces(namespaces.map{ case (ns, _) => ns}(breakOut)) getOrElse { + // If there are so many VW namespaces that all available Unicode characters are taken, + // then a memory error will probably already have occurred. + throw new AlohaException( + "Could not find any Unicode characters to as VW namespaces. Namespaces provided: " + + namespaces.unzip._1.mkString(", ") + ) + } + + @transient private[this] lazy val classNs = nss._1 + @transient private[this] lazy val dummyClassNs = nss._2 + + override def apply(a: A): (MissingAndErroneousFeatureInfo, Array[String]) = { + val (missingAndErrs, features) = featuresFunction(a) + + // TODO: Should this be sci.BitSet? + val positiveIndices: Set[Int] = + positiveLabelsFunction(a).flatMap { y => labelToInd.get(y).toSeq }(breakOut) + + val x: Array[String] = trainingInput( + features, + allLabelsInTrainingSet.indices, + positiveIndices, + defaultNamespace, + namespaces, + classNs, + dummyClassNs + ) + + (missingAndErrs, x) + } +} + +object VwMultilabelRowCreator { + + /** + * VW allows long-based feature indices, but Aloha only allow's 32-bit indices + * on the features that produce the key-value pairs passed to VW. The negative + * dummy classes uses an ID outside of the allowable range of feature indices: + * 2^32^. + */ + private[this] val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString + + /** + * VW allows long-based feature indices, but Aloha only allow's 32-bit indices + * on the features that produce the key-value pairs passed to VW. The positive + * dummy classes uses an ID outside of the allowable range of feature indices: + * 2^32^ + 1. + */ + private[this] val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString + + /** + * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the + * dependent variable is based on cost (which is the negative of reward). + * As such, the ''reward'' of a positive example is designated to be one, + * so the cost (or negative reward) is -1. + */ + private[this] val Positive = (-1).toString + + /** + * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the + * dependent variable is based on cost (which is the negative of reward). + * As such, the ''reward'' of a negative example is designated to be zero, + * so the cost (or negative reward) is 0. + */ + private[this] val Negative = 0.toString + + + /** + * "shared" is a special keyword in VW multi-class (multi-row) format. + * See Hal Daume's [[https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html page]]. + * + * '''NOTE''': The trailing space should be here. + */ + private[this] val SharedFeatureIndicator = "shared" + " " + + private[this] val FirstValidCharacter = 0 // Could probably be '0'.toInt + + private[this] val PreferredLabelNamespaces = Seq(('Y', 'y'), ('Z', 'z'), ('Λ', 'λ')) + + /** + * Determine the label namespaces for VW. VW insists on the uniqueness of the first character + * of namespaces. The goal is to use try to use the first one of these combinations for + * label and dummy label namespaces where neither of the values are in `usedNss`. + * + - "Y", "y" + - "Z", "z" + - "Λ", "λ" + * + * If one of these combinations cannot be used because at least one of the elements in a given + * row is in `usedNss`, then iterate over the Unicode set and take the first two characters + * found that adhere to are deemed a valid character. These will then become the actual and + * dummy namespace names (respectively). + * + * The goal of this function is to try to use characters in literature used to denote a + * dependent variable. If that isn't possible (because the characters are already used by some + * other namespace, just find the first possible characters. + * @param usedNss names of namespaces used. + * @return the namespace for ''actual'' label information then the namespace for ''dummy'' + * label information. If two valid namespaces couldn't be produced, return None. + */ + private[aloha] def determineLabelNamespaces(usedNss: Set[String]): Option[(String, String)] = { + val nss = nssToFirstCharBitSet(usedNss) + preferredLabelNamespaces(nss) orElse bruteForceNsSearch(nss) + } + + private[multilabel] def preferredLabelNamespaces(nss: sci.BitSet): Option[(String, String)] = { + PreferredLabelNamespaces collectFirst { + case (actual, dummy) if !(nss contains actual.toInt) && !(nss contains dummy.toInt) => + (actual.toString, dummy.toString) + } + } + + private[multilabel] def nssToFirstCharBitSet(ss: Set[String]): sci.BitSet = + ss.collect { case s if s != "" => + s.charAt(0).toInt + }(breakOut[Set[String], Int, sci.BitSet]) + + private[multilabel] def validCharForNamespace(chr: Char): Boolean = { + // These might be overkill. + Character.isDefined(chr) && + Character.isLetter(chr) && + !Character.isISOControl(chr) && + !Character.isSpaceChar(chr) && + !Character.isWhitespace(chr) + } + + /** + * Find the first two valid characters that can be used as VW namespaces that, when converted + * to integers are not present in usedNss. + * @param usedNss the set of first characters in namespaces. + * @return the namespace to use for the actual classes and dummy classes, respectively. + */ + private[multilabel] def bruteForceNsSearch(usedNss: sci.BitSet): Option[(String, String)] = { + val found = + Stream + .from(FirstValidCharacter) + .filter(c => !(usedNss contains c) && validCharForNamespace(c.toChar)) + .take(2) + + found match { + case actual #:: dummy #:: Stream.Empty => + Option((actual.toChar.toString, dummy.toChar.toString)) + case _ => None + } + } + + + /** + * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. + * @param features (non-label dependent) features shared across all labels. + * @param indices the indices `labels` into the sequence of all labels encountered + * during training. + * @param defaultNs the indices into `features` that should be placed in VW's default + * namespace. + * @param namespaces the indices into `features` that should be associated with each + * namespace. + * @param classNamespace + * @param dummyClassNamespace + * @return an array to be passed directly to an underlying `VWActionScoresLearner`. + */ + private[aloha] def trainingInput( + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + positiveLabelIndices: Int => Boolean, + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + classNs: String, + dummyClassNs: String + ): Array[String] = { + + val n = indices.size + + // The length of the output array is n + 3. + // + // The first row is the shared features. These are features that are not label dependent. + // Then comes two dummy classes. These are to make the probabilities work out. + // Then come the features for each of the n labels. + val x = new Array[String](n + 3) + + val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) + x(0) = SharedFeatureIndicator + shared + + // These string interpolations are computed over and over but will always be the same + // for a given dummyClassNs. + // TODO: Precompute these in a class instance and pass in as parameters. + x(1) = s"$NegDummyClassId:0 |$dummyClassNs neg" + x(2) = s"$PosDummyClassId:-1 |$dummyClassNs pos" + + // This is mutable because we want speed. + var i = 0 + while (i < n) { + val labelInd = indices(i) + + // TODO or positives.contains(labelInd)? + val dv = if (positiveLabelIndices(i)) Positive else Negative + x(i + 3) = s"$labelInd:$dv |$classNs _$labelInd" + i += 1 + } + + x + } + + + /** + * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. + * @param features (non-label dependent) features shared across all labels. + * @param indices the indices `labels` into the sequence of all labels encountered + * during training. + * @param defaultNs the indices into `features` that should be placed in VW's default + * namespace. + * @param namespaces the indices into `features` that should be associated with each + * namespace. + * @param classNamespace + * @return an array to be passed directly to an underlying `VWActionScoresLearner`. + */ + private[aloha] def predictionInput( + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + classNamespace: String + ): Array[String] = { + + val n = indices.size + + // Use a (mutable) array (and iteration) for speed. + // The first row is the shared features. These are features that are not label dependent. + // Then come the features for each of the n labels. + val x = new Array[String](n + 1) + + val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) + x(0) = SharedFeatureIndicator + shared + + var i = 0 + while (i < n) { + val labelInd = indices(i) + x(i + 1) = s"$labelInd:0 |$classNamespace _$labelInd" + i += 1 + } + + x + } + + + /** + * A producer that can produce a [[VwMultilabelRowCreator]]. + * The requirement for [[RowCreatorProducer]] to only have zero-argument constructors is + * relaxed for this Producer because we don't have a way of generically constructing a + * list of labels. If the labels were encoded in the JSON, then a JsonReader for the label + * type would have to be passed to the constructor. Since the labels can't be encoded + * generically in the JSON, we accept that this Producer is a special case and allow the labels + * to be passed directly. The consequence is that this producer doesn't just rely on the + * dataset specification and the data itself. It also relying on the labels provided to the + * constructor. + * + * @param allLabelsInTrainingSet All of the labels that will be encountered in the training set. + * @param ev$1 reflection information about `K`. + * @tparam A type of input passed to the [[RowCreator]]. + * @tparam K the label type. + */ + final class Producer[A, K: RefInfo](allLabelsInTrainingSet: Vector[K]) + extends RowCreatorProducer[A, Array[String], VwMultilabelRowCreator[A, K]] + with RowCreatorProducerName + with VwCovariateProducer[A] + with DvProducer + with SparseCovariateProducer + with CompilerFailureMessages { + + override type JsonType = VwMultilabeledJson + + /** + * Attempt to parse the JSON AST to an intermediate representation that is used + * + * @param json + * @return + */ + override def parse(json: JsValue): Try[VwMultilabeledJson] = + Try { json.convertTo[VwMultilabeledJson] } + + /** + * Attempt to produce a Spec. + * + * @param semantics semantics used to make sense of the features in the JsonSpec + * @param jsonSpec a JSON specification to transform into a RowCreator. + * @return + */ + override def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: VwMultilabeledJson): Try[VwMultilabelRowCreator[A, K]] = { + val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) + + val spec = for { + cov <- covariates + pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) + sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) + } yield new VwMultilabelRowCreator[A, K](allLabelsInTrainingSet, cov, default, nss, normalizer, pos) + + spec + } + + private[multilabel] def positiveLabelsFn( + semantics: CompiledSemantics[A], + positiveLabels: String + ): Try[GenAggFunc[A, sci.IndexedSeq[K]]] = + getDv[A, sci.IndexedSeq[K]]( + semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K])) + } +} \ No newline at end of file diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwMultilabeledJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwMultilabeledJson.scala new file mode 100644 index 00000000..e59b8da2 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwMultilabeledJson.scala @@ -0,0 +1,23 @@ +package com.eharmony.aloha.dataset.vw.multilabel.json + +import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.json.VwJsonLike +import spray.json.{DefaultJsonProtocol, RootJsonFormat} + +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 9/13/17. + */ +final case class VwMultilabeledJson( + imports: sci.Seq[String], + features: sci.IndexedSeq[SparseSpec], + namespaces: Option[Seq[Namespace]] = Some(Nil), + normalizeFeatures: Option[Boolean] = Some(false), + positiveLabels: String) + extends VwJsonLike + +object VwMultilabeledJson extends DefaultJsonProtocol { + implicit val labeledVwJsonFormat: RootJsonFormat[VwMultilabeledJson] = + jsonFormat5(VwMultilabeledJson.apply) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 67059ff3..688209a7 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -140,8 +140,8 @@ object MultilabelModel extends ParserProviderCompanion { * be sorted in ascending order. * @param labels labels for which a prediction should be produced. labels are parallel to * indices so `indices(i)` is the index associated with `labels(i)`. - * @param missingLabels a sequence of labels derived from the input data that could not be - * found in the sequence of all labels seen during training. + * @param labelsNotInTrainingSet a sequence of labels derived from the input data that could + * not be found in the sequence of all labels seen during training. * @param problems any problems encountered when trying to get the labels. This should only * be present when the caller indicates labels should be embedded in the * input data passed to the prediction function in the MultilabelModel. @@ -150,12 +150,12 @@ object MultilabelModel extends ParserProviderCompanion { protected[multilabel] case class LabelsAndInfo[K]( indices: sci.IndexedSeq[Int], labels: sci.IndexedSeq[K], - missingLabels: Seq[K], + labelsNotInTrainingSet: Seq[K], problems: Option[GenAggFuncAccessorProblems] ) { def missingVarNames: Seq[String] = problems.map(p => p.missing).getOrElse(Nil) def errorMsgs: Seq[String] = { - missingLabels.map { lab => s"Label not in training labels: $lab" } ++ + labelsNotInTrainingSet.map { lab => s"Label not in training labels: $lab" } ++ problems.map(p => p.errors).getOrElse(Nil) } } diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index c3432ee3..60b1e68c 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -3,7 +3,7 @@ package com.eharmony.aloha.models.vw.jni.multilabel import java.io.{Closeable, File} import com.eharmony.aloha.dataset.density.Sparse -import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import com.eharmony.aloha.io.sources.ModelSource import com.eharmony.aloha.models.multilabel.SparseMultiLabelPredictor import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} @@ -71,7 +71,9 @@ extends SparseMultiLabelPredictor[K] indices: sci.IndexedSeq[Int], labelDependentFeatures: sci.IndexedSeq[IndexedSeq[Sparse]] ): Try[Map[K, Double]] = { - val x = multiLabelClassifierInput(features, indices, defaultNs, namespaces) + + // TODO: Pass ClassNS in via the constructor + val x = VwMultilabelRowCreator.predictionInput(features, indices, defaultNs, namespaces, ClassNS) val pred = Try { vwModel.predict(x) } val yOut = pred.map { y => produceOutput(y, labels) } yOut @@ -83,45 +85,6 @@ extends SparseMultiLabelPredictor[K] object VwSparseMultilabelPredictor { private val ClassNS = "Y" - /** - * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. - * @param features (non-label dependent) features shared across all labels. - * @param indices the indices `labels` into the sequence of all labels encountered - * during training. - * @param defaultNs the indices into `features` that should be placed in VW's default - * namespace. - * @param namespaces the indices into `features` that should be associated with each - * namespace. - * @return an array to be passed directly to an underlying `VWActionScoresLearner`. - */ - private[multilabel] def multiLabelClassifierInput( - features: IndexedSeq[Sparse], - indices: sci.IndexedSeq[Int], - defaultNs: List[Int], - namespaces: List[(String, List[Int])] - ): Array[String] = { - val n = indices.size - // The length of the output array is n + 1. The first row is the shared features. - // These are features that are not label dependent. Then come the features for the - // n labels. - val x = new Array[String](n + 1) - - val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) - // "shared" is a special keyword in VW multi-class (multi-row) format. - // See: https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html - x(0) = "shared " + shared - - // This is mutable because we want speed. - var i = 0 - - while (i < n) { - val labelInd = indices(i) - x(i + 1) = s"$labelInd:0 |$ClassNS _C${labelInd}_" - i += 1 - } - - x - } /** * Produce the output given VW's output, `pred`, and the labels provided to the `apply` function. diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala index 99d73cfa..c117feff 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala @@ -27,6 +27,8 @@ import spray.json.{JsonFormat, JsonReader} */ case class VwSparseMultilabelPredictorProducer[K]( modelSource: ModelSource, + + // TODO: Should we remove this. If not, it must contain the --ring_size [training labels + 10]. params: String, defaultNs: List[Int], namespaces: List[(String, List[Int])]) From dc363e6b08a50e3452c10b2b2eee4ea3ded36a96 Mon Sep 17 00:00:00 2001 From: amirziai Date: Fri, 22 Sep 2017 17:04:50 -0700 Subject: [PATCH 51/98] All tests pass, code structured a bit better --- .../multilabel/MultilabelModelTest.scala | 276 +++++++----------- 1 file changed, 102 insertions(+), 174 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 69815249..59950052 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -4,9 +4,10 @@ import java.io.{Closeable, PrintWriter, StringWriter} import java.util.concurrent.atomic.AtomicBoolean import com.eharmony.aloha.ModelSerializationTestHelper -import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor +import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor, Tree} import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.models.Subvalue import com.eharmony.aloha.models.multilabel.MultilabelModel.{reportNoPrediction, LabelsAndInfo, NumLinesToKeepInStackTrace} import com.eharmony.aloha.semantics.SemanticsUdfException import com.eharmony.aloha.semantics.func._ @@ -86,27 +87,16 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testReportTooManyMissing(): Unit = { - // labelsAndInfoEmpty.copy(missingLabels = Seq[Label]("a")), - - val labelInfo = LabelsAndInfo( - indices = sci.IndexedSeq[Int](), - labels = sci.IndexedSeq[Label](), - missingLabels = Seq[Label]("a"), - problems = None + val report = reportTooManyMissing( + modelId = ModelId(), + labelInfo = labelsAndInfoEmpty, + missing = scm.Map("" -> missingLabels), + auditor = auditor ) - // TODO: reportTooManyMissing - val report = reportNoPrediction( - ModelId(1, "a"), - labelInfo, - Auditor - ) - - assertEquals(Vector(NoLabelsError), report.audited.errorMsgs.take(1)) + assertEquals(Vector(TooManyMissingError), report.audited.errorMsgs.take(1)) assertEquals(None, report.audited.value) assertEquals(None, report.natural) - - fail() } @Test def testReportNoPrediction(): Unit = { @@ -155,88 +145,58 @@ class MultilabelModelTest extends ModelSerializationTestHelper { assertEquals(report.natural, report.audited.value) } - @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsEmpty(): Unit = { - def extractLabelsOutOfExample(example: Map[String, String]): sci.IndexedSeq[String] = - example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq - - // TODO: break this up - - // Example with no problems - val example: Map[String, String] = Map( - "feature1" -> "1", - "feature2" -> "2", - "feature3" -> "2", - "label1" -> "a", - "label2" -> "b" - ) - val allLabels = sci.IndexedSeq("a", "b", "c") - val labelToInt = allLabels.zipWithIndex.toMap - val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) - val labelsAndInfo = labelsForPrediction(example, labelsOfInterestExtractor, labelToInt) - assertEquals(None, labelsAndInfo.problems) - - // Example with 1 missing label - val exampleMissingOneLabel = Map("feature1" -> "1", "label1" -> "a") - val labelsAndInfoMissingOneLabel = labelsForPrediction( - exampleMissingOneLabel, - labelsOfInterestExtractor, - labelToInt) - assertEquals(None, labelsAndInfoMissingOneLabel.problems) - - // Example with no labels - val exampleNoLabels = Map("feature1" -> "1", "feature2" -> "2") + @Test def testLabelsForPredictionContainsProblemsWhenNoLabelProvided(): Unit = { val labelsAndInfoNoLabels = labelsForPrediction( - exampleNoLabels, - labelsOfInterestExtractor, - labelToInt) - // No problems are actually found - // But the important thing is that problemsNoLabelsExpected is a Some and not a None + example = Map[Label, Label](), // no label provided + labelsOfInterest = labelsOfInterestExtractor, + labelToInd = labelsInTrainingSetToIndex) + + // In this scenario no problems are found but an + // Option(GenAggFuncAccessorProblems) is returned instead of None val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq())) assertEquals(problemsNoLabelsExpected, labelsAndInfoNoLabels.problems) + } - // missing labels - def badFunction(example: Map[String, String]) = example.get("feature1") - val gen1 = GenFunc1("concat _1", (m: Option[String]) => sci.IndexedSeq(s"${m}_1"), - GeneratedAccessor("extract feature 1", badFunction, None)) - + @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsNotPresent(): Unit = { val labelsExtractor = - (m: Map[String, String]) => m.get("labels") match { - case ls: sci.IndexedSeq[_] if ls.forall { x: String => x.isInstanceOf[Label] } => - Option(ls.asInstanceOf[sci.IndexedSeq[Label]]) + (label: Map[Label, Label]) => label.get("labels") match { + case ls: sci.IndexedSeq[_] => Option(ls.asInstanceOf[sci.IndexedSeq[Label]]) case _ => None } - val featureWithMissingValue = "feature not present" - val f1 = - GenFunc.f1(GeneratedAccessor(featureWithMissingValue, labelsExtractor, None))( - "def omitted", _ getOrElse sci.IndexedSeq.empty[Label] + val descriptor = "label is missing" + val labelsOfInterest = + GenFunc.f1(GeneratedAccessor(descriptor, labelsExtractor, None))( + "", _ getOrElse sci.IndexedSeq.empty[Label] ) val labelsAndInfoNoLabelsGen1 = labelsForPrediction( - Map[String, String](), - f1, - labelToInt) - val problemsNoLabelsExpectedGen1 = Option(GenAggFuncAccessorProblems(Seq(featureWithMissingValue), Seq())) + example = Map[Label, Label](), + labelsOfInterest = labelsOfInterest, + labelToInd = labelsInTrainingSetToIndex) + val problemsNoLabelsExpectedGen1 = + Option(GenAggFuncAccessorProblems(Seq(descriptor), Seq())) assertEquals(problemsNoLabelsExpectedGen1, labelsAndInfoNoLabelsGen1.problems) + } - // error - val featureWithError = "feature has error" - val f2 = - GenFunc.f1(GeneratedAccessor(featureWithError, - ( _ => throw new Exception("errmsg")) : Map[String, String] => Option[sci.IndexedSeq[String]], None))( - "def omitted", _ getOrElse sci.IndexedSeq.empty[Label] + @Test def testLabelsForPredictionContainsProblemsWhenLabelAccessorThrows(): Unit = { + val descriptor = "label accessor that throws" + val labelsOfInterest = + GenFunc.f1(GeneratedAccessor(descriptor, + ( _ => throw new Exception()) : Map[String, String] => Option[sci.IndexedSeq[String]], None))( + "", _ getOrElse sci.IndexedSeq.empty[Label] ) - val f2Wrapped = EnrichedErrorGenAggFunc(f2) + val labelsOfInterestWrapped = EnrichedErrorGenAggFunc(labelsOfInterest) - val problemsNoLabelsExpectedGen2 = Option(GenAggFuncAccessorProblems(Seq(), Seq(featureWithError))) + val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq(descriptor))) Try( labelsForPrediction( Map[String, String](), - f2Wrapped, - labelToInt) + labelsOfInterestWrapped, + labelsInTrainingSetToIndex) ).failed.get match { - case SemanticsUdfException(_, _, _, accessorsInErr, _, _) => assertEquals(accessorsInErr, - problemsNoLabelsExpectedGen2.get.errors) + case SemanticsUdfException(_, _, _, accessorsInErr, _, _) => + assertEquals(accessorsInErr, problemsNoLabelsExpected.get.errors) } } @@ -280,75 +240,49 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { - // Test this: - // val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip - - def extractLabelsOutOfExample(example: Map[String, String]) = - example.filterKeys(_.startsWith("label")).toSeq.unzip._2.toIndexedSeq // sorted.toIndexedSeq - val example: Map[String, String] = Map( "feature1" -> "1", - "feature2" -> "2", - "feature3" -> "2", "label1" -> "a", - "label2" -> "b", "label3" -> "l23", "label4" -> "100", - "label5" -> "235", "label6" -> "c", - "label7" -> "1", "label8" -> "l1" ) val allLabels = extractLabelsOutOfExample(example).sorted - val labelToInt: Map[String, Int] = allLabels.zipWithIndex.toMap + val labelToInt = allLabels.zipWithIndex.toMap - val rng = new Random(seed=0) + val random = new Random(seed=0) (1 to 10).foreach { _ => - val ex = rng.shuffle(example.toVector).take(rng.nextInt(example.size)).toMap - val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) + val ex = random.shuffle(example.toVector).take(random.nextInt(example.size)).toMap + val labelsOfInterestExtractor = GenFunc0("", extractLabelsOutOfExample) val labelsAndInfo = labelsForPrediction(ex, labelsOfInterestExtractor, labelToInt) assertEquals(labelsAndInfo.indices.sorted, labelsAndInfo.indices) } } @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { - // Test this: - // if (li.labels.isEmpty) - // reportNoPrediction(modelId, li, auditor) - - assertEquals(None, modelNew.subvalue(Vector.empty).natural) + assertEquals(None, modelWithFeatureFunctions.subvalue(Vector.empty).natural) } @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { // When the amount of missing data exceeds the threshold, reportTooManyMissing should be - // called and its value should be returned. Instantiate a MultilabelModel and - // call apply with some missing data required by the features. - - val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = - GenFunc0("", _ => Iterable()) - - val featureFunctions = Vector(EmptyIndicatorFn) - - val modelWithThreshold = MultilabelModel( - modelId = ModelId(1, "model1"), - featureNames = sci.IndexedSeq("a"), - featureFunctions = featureFunctions, - labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2"), - labelsOfInterest = None, - predictorProducer = Lazy(ConstantPredictor[Label]()), - numMissingThreshold = Option(0), - auditor = Auditor + // called and its value should be returned. + + val modelWithMissingThreshold = modelWithFeatureFunctions.copy( + featureNames = sci.IndexedSeq("feature1"), + featureFunctions = featureFunctions, + labelsInTrainingSet = labelsInTrainingSet, + labelsOfInterest = None, + numMissingThreshold = Option(0) ) - val result = modelWithThreshold(Map()) + val result = modelWithMissingThreshold(Map.empty) + assertEquals(None, result.value) assertEquals(TooManyMissingError, result.errorMsgs.head) } @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { - // Create a predictorProducer that throws. Check that the model still returns a value - // and that the error message is incorporated appropriately. - case object PredictorThatThrows extends SparseMultiLabelPredictor[Label] { override def apply(v1: SparseFeatures, @@ -357,43 +291,25 @@ class MultilabelModelTest extends ModelSerializationTestHelper { v4: SparseLabelDepFeatures): Try[Map[Label, Double]] = Try(throw new Exception("error")) } - val modelWithThrowingPredictorProducer = MultilabelModel( - modelId = ModelId(1, "model1"), - featureNames = sci.IndexedSeq.empty, - featureFunctions = sci.IndexedSeq.empty, - labelsInTrainingSet = sci.IndexedSeq(""), // we need at least 1 label to get the error - labelsOfInterest = None, - predictorProducer = Lazy(PredictorThatThrows), - numMissingThreshold = None, - auditor = Auditor + val modelWithThrowingPredictorProducer = modelWithFeatureFunctions.copy( + predictorProducer = Lazy(PredictorThatThrows), + labelsOfInterest = None ) - val result = modelWithThrowingPredictorProducer(Map()) + val result = modelWithThrowingPredictorProducer(Vector.empty) assertEquals(None, result.value) assertEquals("java.lang.Exception: error", result.errorMsgs.head.split("\n").head) } @Test def testSubvalueSuccess(): Unit = { - // Test the happy path by calling model.apply. Check the value, missing data, and error - // messages. - - def extractLabelsOutOfExample(example: Map[String, String]) = - example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq - - val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) - val scoreToReturn = 5d - - val modelSuccess = MultilabelModel( - modelId = ModelId(1, "model1"), - featureNames = sci.IndexedSeq.empty, - featureFunctions = sci.IndexedSeq.empty, + val modelSuccess = modelNoFeatures.copy( labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), labelsOfInterest = Option(labelsOfInterestExtractor), predictorProducer = Lazy(ConstantPredictor[Label](scoreToReturn)), - numMissingThreshold = None, - auditor = Auditor + featureFunctions = sci.IndexedSeq.empty ) + val result = modelSuccess(Map("a" -> "b", "label1" -> "label1", "label2" -> "label2")) assertEquals(Vector(), result.errorMsgs) assertEquals(Set(), result.missingVarNames) @@ -401,27 +317,20 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { - // NOTE: This is by design. + // This is by design. val exception = new Exception("error") - val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + val featureFunctionThatThrows: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = GenFunc0("", _ => throw exception) - val featureFunctions = Vector(EmptyIndicatorFn) - - val modelSuccess = MultilabelModel( - modelId = ModelId(1, "model1"), + val modelWithFeatureFunctionThatThrows = modelNoFeatures.copy( featureNames = sci.IndexedSeq("throwing feature"), - featureFunctions = featureFunctions, + featureFunctions = Vector(featureFunctionThatThrows), labelsInTrainingSet = sci.IndexedSeq[Label](""), - labelsOfInterest = None, - predictorProducer = Lazy(ConstantPredictor[Label]()), - numMissingThreshold = None, - auditor = Auditor + labelsOfInterest = None ) - - val result = Try(modelSuccess(Map())) + val result = Try(modelWithFeatureFunctionThatThrows(Map())) result match { case Success(_) => fail() case Failure(ex) => assertEquals(exception, ex) @@ -430,6 +339,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } object MultilabelModelTest { + // Types private type Label = String private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() @@ -445,7 +355,21 @@ object MultilabelModelTest { override def apply(): A = value } - val labelsInTrainingSet = sci.IndexedSeq[Label]("a", "b", "c") + // Common input + val labelsInTrainingSet: sci.IndexedSeq[Label] = sci.IndexedSeq[Label]("a", "b", "c") + val labelsInTrainingSetToIndex: Map[Label, Int] =labelsInTrainingSet.zipWithIndex.toMap + val missingLabels: Seq[Label] = Seq("a", "b") + val baseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") + val errorMessages: Seq[String] = baseErrorMessage.zip(missingLabels).map { + case(msg, label) => s"$msg$label" + } + val auditor: RootedTreeAuditor[Any, Map[Label, Double]] = + RootedTreeAuditor[Any, Map[Label, Double]]() + + // Feature functions + val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + GenFunc0("", _ => Iterable()) + val featureFunctions = Vector(EmptyIndicatorFn) // Models val modelNoFeatures = MultilabelModel( @@ -459,7 +383,9 @@ object MultilabelModelTest { auditor = Auditor ) - val modelNew = modelNoFeatures.copy( + val modelWithFeatureFunctions: + MultilabelModel[Tree[Any], String, Vector[String], RootedTree[Any, Map[Label, Double]]] = + modelNoFeatures.copy( featureFunctions = sci.IndexedSeq[GenAggFunc[Vector[String], Sparse]](), labelsInTrainingSet = labelsInTrainingSet, labelsOfInterest = Some(GenFunc0("", (a: Vector[String]) => a)) @@ -472,25 +398,21 @@ object MultilabelModelTest { missingLabels = Seq[Label](), problems = None ) - val labelsAndInfoMissingLabels = labelsAndInfoEmpty.copy(missingLabels = missingLabels) + val labelsAndInfoMissingLabels: LabelsAndInfo[Label] = + labelsAndInfoEmpty.copy(missingLabels = missingLabels) // Reports - val reportNoPredictionPartial = reportNoPrediction( + val reportNoPredictionPartial: + (LabelsAndInfo[Label]) => Subvalue[RootedTree[Any, Map[Label, Double]], Nothing] = + reportNoPrediction( modelId = ModelId(), _: LabelsAndInfo[Label], auditor = Auditor ) - val reportNoPredictionEmpty = reportNoPredictionPartial(labelsAndInfoEmpty) - - val auditor: RootedTreeAuditor[Any, Map[Label, Double]] = - RootedTreeAuditor[Any, Map[Label, Double]]() - - val baseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") - val missingLabels: Seq[Label] = Seq("a", "b") - val errorMessages: Seq[String] = baseErrorMessage.zip(missingLabels).map { - case(msg, label) => s"$msg$label" - } + val reportNoPredictionEmpty: Subvalue[RootedTree[Any, Map[Label, Double]], Nothing] = + reportNoPredictionPartial(labelsAndInfoEmpty) + // Throwable and stack trace def getThrowable(errorMessage: String): (Throwable, String) = { val throwable = Try(throw new Exception(errorMessage)).failed.get val sw = new StringWriter @@ -499,4 +421,10 @@ object MultilabelModelTest { val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") (throwable, stackTrace) } + + // Label extractors + def extractLabelsOutOfExample(example: Map[String, String]): sci.IndexedSeq[String] = + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.toIndexedSeq + + val labelsOfInterestExtractor = GenFunc0("", extractLabelsOutOfExample) } From 0d35d4cb3a93c0cf2ebb7c53d3f95123d95768f7 Mon Sep 17 00:00:00 2001 From: amirziai Date: Fri, 22 Sep 2017 17:14:59 -0700 Subject: [PATCH 52/98] Wasn't compiling after merge --- .../models/multilabel/MultilabelModelTest.scala | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 59950052..f950be65 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -120,7 +120,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val missingVariables = Seq("a", "b") val report = reportPredictorError( ModelId(), - labelsAndInfoEmpty.copy(missingLabels = missingLabels), + labelsAndInfoEmpty.copy(labelsNotInTrainingSet = missingLabels), scm.Map("" -> missingVariables), throwable, Auditor @@ -135,7 +135,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val predictions = Map("label" -> 1.0) val report = reportSuccess( modelId = ModelId(), - labelInfo = labelsAndInfoEmpty.copy(missingLabels = missingLabels), + labelInfo = labelsAndInfoEmpty.copy(labelsNotInTrainingSet = missingLabels), missing = scm.Map("" -> missingLabels), prediction = predictions, auditor = Auditor @@ -206,9 +206,6 @@ class MultilabelModelTest extends ModelSerializationTestHelper { // if (unsorted.size == labelsShouldPredict.size) Seq.empty // else labelsShouldPredict.filterNot(labelToInd.contains) - def extractLabelsOutOfExample(example: Map[String, String]) = - example.filterKeys(_.startsWith("label")).toSeq.unzip._2.sorted.toIndexedSeq - val example: Map[String, String] = Map( "feature1" -> "1", "feature2" -> "2", @@ -220,19 +217,19 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val labelToInt = allLabels.zipWithIndex.toMap val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) val labelsAndInfo = labelsForPrediction(example, labelsOfInterestExtractor, labelToInt) - val missingLabels = labelsAndInfo.missingLabels + val missingLabels = labelsAndInfo.labelsNotInTrainingSet assertEquals(Seq(), missingLabels) // Extra label not in the list val example2 = Map("label4" -> "d") val labelsAndInfo2 = labelsForPrediction(example2, labelsOfInterestExtractor, labelToInt) - val missingLabels2 = labelsAndInfo2.missingLabels + val missingLabels2 = labelsAndInfo2.labelsNotInTrainingSet assertEquals(Seq("d"), missingLabels2) // No labels val example3 = Map("feature2" -> "5") val labelsAndInfo3 = labelsForPrediction(example3, labelsOfInterestExtractor, labelToInt) - val missingLabels3 = labelsAndInfo3.missingLabels + val missingLabels3 = labelsAndInfo3.labelsNotInTrainingSet assertEquals(Seq(), missingLabels3) // TODO: add these as a different test at the model.apply(.) level @@ -395,11 +392,11 @@ object MultilabelModelTest { val labelsAndInfoEmpty = LabelsAndInfo( indices = sci.IndexedSeq.empty, labels = sci.IndexedSeq.empty, - missingLabels = Seq[Label](), + labelsNotInTrainingSet = Seq[Label](), problems = None ) val labelsAndInfoMissingLabels: LabelsAndInfo[Label] = - labelsAndInfoEmpty.copy(missingLabels = missingLabels) + labelsAndInfoEmpty.copy(labelsNotInTrainingSet = missingLabels) // Reports val reportNoPredictionPartial: From 201a8230858be87ae6c0f7d9cf9f85185b6adee2 Mon Sep 17 00:00:00 2001 From: amirziai Date: Fri, 22 Sep 2017 17:28:26 -0700 Subject: [PATCH 53/98] labels not in training set should be reported --- .../multilabel/MultilabelModelTest.scala | 38 +++---------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index f950be65..7af36314 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -201,39 +201,11 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { - // Test this: - // val noPrediction = - // if (unsorted.size == labelsShouldPredict.size) Seq.empty - // else labelsShouldPredict.filterNot(labelToInd.contains) - - val example: Map[String, String] = Map( - "feature1" -> "1", - "feature2" -> "2", - "feature3" -> "2", - "label1" -> "a", - "label2" -> "b" - ) - val allLabels = sci.IndexedSeq("a", "b", "c") - val labelToInt = allLabels.zipWithIndex.toMap - val labelsOfInterestExtractor = GenFunc0("empty spec", extractLabelsOutOfExample) - val labelsAndInfo = labelsForPrediction(example, labelsOfInterestExtractor, labelToInt) - val missingLabels = labelsAndInfo.labelsNotInTrainingSet - assertEquals(Seq(), missingLabels) - - // Extra label not in the list - val example2 = Map("label4" -> "d") - val labelsAndInfo2 = labelsForPrediction(example2, labelsOfInterestExtractor, labelToInt) - val missingLabels2 = labelsAndInfo2.labelsNotInTrainingSet - assertEquals(Seq("d"), missingLabels2) - - // No labels - val example3 = Map("feature2" -> "5") - val labelsAndInfo3 = labelsForPrediction(example3, labelsOfInterestExtractor, labelToInt) - val missingLabels3 = labelsAndInfo3.labelsNotInTrainingSet - assertEquals(Seq(), missingLabels3) - - // TODO: add these as a different test at the model.apply(.) level - // assertEquals(None, modelNew(Vector("1"))) + val labelNotInTrainingSet = "d" + val labelsAndInfo = labelsForPrediction(Map("label" -> labelNotInTrainingSet), + labelsOfInterestExtractor, labelsInTrainingSetToIndex) + val missingLabels2 = labelsAndInfo.labelsNotInTrainingSet + assertEquals(Seq(labelNotInTrainingSet), missingLabels2) } @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { From a42c9d76685a0db7fdf6e7f27420c2f45668755c Mon Sep 17 00:00:00 2001 From: amirziai Date: Fri, 22 Sep 2017 17:30:17 -0700 Subject: [PATCH 54/98] Renamed missingLabels->labelsNotInTrainingSet to conform to new signature --- .../models/multilabel/MultilabelModelTest.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index 7af36314..dc47cf65 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -90,7 +90,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val report = reportTooManyMissing( modelId = ModelId(), labelInfo = labelsAndInfoEmpty, - missing = scm.Map("" -> missingLabels), + missing = scm.Map("" -> labelsNotInTrainingSet), auditor = auditor ) @@ -120,7 +120,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val missingVariables = Seq("a", "b") val report = reportPredictorError( ModelId(), - labelsAndInfoEmpty.copy(labelsNotInTrainingSet = missingLabels), + labelsAndInfoEmpty.copy(labelsNotInTrainingSet = labelsNotInTrainingSet), scm.Map("" -> missingVariables), throwable, Auditor @@ -135,8 +135,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val predictions = Map("label" -> 1.0) val report = reportSuccess( modelId = ModelId(), - labelInfo = labelsAndInfoEmpty.copy(labelsNotInTrainingSet = missingLabels), - missing = scm.Map("" -> missingLabels), + labelInfo = labelsAndInfoEmpty.copy(labelsNotInTrainingSet = labelsNotInTrainingSet), + missing = scm.Map("" -> labelsNotInTrainingSet), prediction = predictions, auditor = Auditor ) @@ -204,8 +204,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val labelNotInTrainingSet = "d" val labelsAndInfo = labelsForPrediction(Map("label" -> labelNotInTrainingSet), labelsOfInterestExtractor, labelsInTrainingSetToIndex) - val missingLabels2 = labelsAndInfo.labelsNotInTrainingSet - assertEquals(Seq(labelNotInTrainingSet), missingLabels2) + assertEquals(Seq(labelNotInTrainingSet), labelsAndInfo.labelsNotInTrainingSet) } @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { @@ -327,9 +326,9 @@ object MultilabelModelTest { // Common input val labelsInTrainingSet: sci.IndexedSeq[Label] = sci.IndexedSeq[Label]("a", "b", "c") val labelsInTrainingSetToIndex: Map[Label, Int] =labelsInTrainingSet.zipWithIndex.toMap - val missingLabels: Seq[Label] = Seq("a", "b") + val labelsNotInTrainingSet: Seq[Label] = Seq("a", "b") val baseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") - val errorMessages: Seq[String] = baseErrorMessage.zip(missingLabels).map { + val errorMessages: Seq[String] = baseErrorMessage.zip(labelsNotInTrainingSet).map { case(msg, label) => s"$msg$label" } val auditor: RootedTreeAuditor[Any, Map[Label, Double]] = @@ -368,7 +367,7 @@ object MultilabelModelTest { problems = None ) val labelsAndInfoMissingLabels: LabelsAndInfo[Label] = - labelsAndInfoEmpty.copy(labelsNotInTrainingSet = missingLabels) + labelsAndInfoEmpty.copy(labelsNotInTrainingSet = labelsNotInTrainingSet) // Reports val reportNoPredictionPartial: From c8902eca7d16335bdfeb90baf22afd8fe0179034 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 22 Sep 2017 17:45:28 -0700 Subject: [PATCH 55/98] Added some unit tests. Still plenty more to do. --- .../multilabel/VwMultilabelRowCreator.scala | 20 +++-- .../VwMultilabelRowCreatorTest.scala | 89 +++++++++++++++++++ .../VwSparseMultilabelPredictorProducer.scala | 5 +- 3 files changed, 103 insertions(+), 11 deletions(-) create mode 100644 aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index 369b4a73..1404a73c 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -184,12 +184,14 @@ object VwMultilabelRowCreator { * @param features (non-label dependent) features shared across all labels. * @param indices the indices `labels` into the sequence of all labels encountered * during training. + * @param positiveLabelIndices a predicate telling whether the example should be positively + * associated with a label. * @param defaultNs the indices into `features` that should be placed in VW's default * namespace. * @param namespaces the indices into `features` that should be associated with each * namespace. - * @param classNamespace - * @param dummyClassNamespace + * @param classNs a namespace for features associated with class labels + * @param dummyClassNs a namespace for features associated with dummy class labels * @return an array to be passed directly to an underlying `VWActionScoresLearner`. */ private[aloha] def trainingInput( @@ -244,15 +246,15 @@ object VwMultilabelRowCreator { * namespace. * @param namespaces the indices into `features` that should be associated with each * namespace. - * @param classNamespace + * @param classNs a namespace for features associated with class labels * @return an array to be passed directly to an underlying `VWActionScoresLearner`. */ private[aloha] def predictionInput( - features: IndexedSeq[Sparse], - indices: sci.IndexedSeq[Int], - defaultNs: List[Int], - namespaces: List[(String, List[Int])], - classNamespace: String + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + classNs: String ): Array[String] = { val n = indices.size @@ -268,7 +270,7 @@ object VwMultilabelRowCreator { var i = 0 while (i < n) { val labelInd = indices(i) - x(i + 1) = s"$labelInd:0 |$classNamespace _$labelInd" + x(i + 1) = s"$labelInd:0 |$classNs _$labelInd" i += 1 } diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala new file mode 100644 index 00000000..0da4cd0b --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala @@ -0,0 +1,89 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.dataset.SparseFeatureExtractorFunction +import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc0} +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +/** + * Created by ryan.deak on 9/22/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelRowCreatorTest { + import VwMultilabelRowCreatorTest._ + + // TODO: Test that shared doesn't occur when there are no shared features. + @Test def testSharedOmittedWhenNoSharedFeaturesExist(): Unit = { + fail() + } + + @Test def testOneFeatureNoPos(): Unit = { + val nFeatures = 1 + val shared = List("shared | f1") // One feature because nFeatures = 1 + + val dummy = List( + s"$NegDummyClass:$NegVal |y neg", + s"$PosDummyClass:$PosVal |y pos" + ) + + // Only negative labels because no positive labels. + val classLabels = LabelsInTrainingSet.indices.map(i => s"$i:$NegVal |Y _$i") + + val rc = rowCreator(nFeatures) + val (missing, arr) = rc(Map.empty) + assertEquals(shared ++ dummy ++ classLabels, arr.toList) + } +} + +object VwMultilabelRowCreatorTest { + private type Domain = Map[String, String] + private type Label = String + private val Omitted = "" + private val LabelsInTrainingSet = Vector("zero", "one", "two") + private val NegDummyClass = Int.MaxValue.toLong + 1 + private val PosDummyClass = NegDummyClass + 1 + private val PosVal = -1 + private val NegVal = 0 + + private[this] val featureFns = SparseFeatureExtractorFunction[Domain](Vector( + "f1" -> GenFunc0(Omitted, _ => Seq(("", 1d))), + "f2" -> GenFunc0(Omitted, _ => Seq(("", 2d))) + )) + + private def featureFns(n: Int) = { + val ff = (1 to n) map { i => + s"f$i" -> GenFunc0(Omitted, (_: Any) => Seq(("", i.toDouble))) + } + + SparseFeatureExtractorFunction[Domain](ff) + } + + private def positiveLabels(ps: Label*): GenAggFunc[Any, Vector[Label]] = { + GenFunc0(Omitted, _ => Vector(ps:_*)) + } + + private def rowCreator(numFeatures: Int, posLabels: Label*): VwMultilabelRowCreator[Domain, Label] = { + val ff = featureFns(numFeatures) + val pos = positiveLabels(posLabels:_*) + StdRowCreator.copy( + featuresFunction = ff, + defaultNamespace = ff.features.indices.toList, + positiveLabelsFunction = pos + ) + } + + private[this] val StdRowCreator: VwMultilabelRowCreator[Domain, Label] = { + val ff = featureFns(0) + + VwMultilabelRowCreator[Domain, Label]( + allLabelsInTrainingSet = LabelsInTrainingSet, + featuresFunction = ff, + defaultNamespace = ff.features.indices.toList, // All features in default NS. + namespaces = List.empty[(String, List[Int])], + normalizer = Option.empty[CharSequence => CharSequence], + positiveLabelsFunction = positiveLabels() + ) + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala index c117feff..493bbc95 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala @@ -31,8 +31,9 @@ case class VwSparseMultilabelPredictorProducer[K]( // TODO: Should we remove this. If not, it must contain the --ring_size [training labels + 10]. params: String, defaultNs: List[Int], - namespaces: List[(String, List[Int])]) -extends SparsePredictorProducer[K] { + namespaces: List[(String, List[Int])], + labelNamespace: String +) extends SparsePredictorProducer[K] { override def apply(): VwSparseMultilabelPredictor[K] = VwSparseMultilabelPredictor[K](modelSource, params, defaultNs, namespaces) } From a1ce0dd45fda04e4a53f8e407d81c06e795c111f Mon Sep 17 00:00:00 2001 From: amirziai Date: Sun, 24 Sep 2017 11:31:27 -0700 Subject: [PATCH 56/98] Adding PR template --- .github/PULL_REQUEST_TEMPLATE.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..6ac875bb --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,26 @@ +### Summary +The summary should expand on the title of the pull request. What is the expected effect of the +pull request? Specifically, indicate how the behavior will be different than before the pull +request. + +### Bug Fixes/New Features +* Bullet list overview of bug fixes or new features added. + +### How to Verify +How can reviewers verify the request has the intended effect? This should include descriptions of +any new tests that are added or old tests that are updated. It should also indicate how the code +was tested. + +### Side Effects +Does the pull request contain any side effects? This should list non-obvious things that have +changed in the request. + +### Resolves +Fixes RID-1234 +Fixes AMG-1234 + +### Tests +What tests were created or changed for this feature or fix? Do all tests pass, and if not, why? + +### Code Reviewer(s) +@, @ From d1a40f3947da6b56f8d09efff37a3515b9bd7afa Mon Sep 17 00:00:00 2001 From: amirziai Date: Mon, 25 Sep 2017 15:18:07 -0700 Subject: [PATCH 57/98] Addressing comments --- .github/PULL_REQUEST_TEMPLATE.md | 26 --- .../aloha/util/SerializabilityEvidence.scala | 2 + .../multilabel/MultilabelModelTest.scala | 163 ++++++++---------- 3 files changed, 76 insertions(+), 115 deletions(-) delete mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 6ac875bb..00000000 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,26 +0,0 @@ -### Summary -The summary should expand on the title of the pull request. What is the expected effect of the -pull request? Specifically, indicate how the behavior will be different than before the pull -request. - -### Bug Fixes/New Features -* Bullet list overview of bug fixes or new features added. - -### How to Verify -How can reviewers verify the request has the intended effect? This should include descriptions of -any new tests that are added or old tests that are updated. It should also indicate how the code -was tested. - -### Side Effects -Does the pull request contain any side effects? This should list non-obvious things that have -changed in the request. - -### Resolves -Fixes RID-1234 -Fixes AMG-1234 - -### Tests -What tests were created or changed for this feature or fix? Do all tests pass, and if not, why? - -### Code Reviewer(s) -@, @ diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala index f817c864..653908f6 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala @@ -10,6 +10,8 @@ object SerializabilityEvidence { implicit def anyValEvidence[A <: AnyVal]: SerializabilityEvidence[A] = new SerializabilityEvidence[A]{} + implicit def serializableEvidence[A <: java.io.Serializable]: SerializabilityEvidence[A] = new SerializabilityEvidence[A]{} + } diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala index dc47cf65..8ec4a3ec 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -30,13 +30,12 @@ class MultilabelModelTest extends ModelSerializationTestHelper { @Test def testSerialization(): Unit = { // Assuming all parameters passed to the MultilabelModel constructor are // Serializable, MultilabelModel should also be Serializable. - val modelRoundTrip = serializeDeserializeRoundTrip(modelNoFeatures) - assertEquals(modelNoFeatures, modelRoundTrip) + val modelRoundTrip = serializeDeserializeRoundTrip(ModelNoFeatures) + assertEquals(ModelNoFeatures, modelRoundTrip) } @Test def testModelCloseClosesPredictor(): Unit = { - case class PredictorClosable[K](prediction: Double = 0d) - extends SparseMultiLabelPredictor[K] with Closeable { + class PredictorClosable[K] extends SparseMultiLabelPredictor[K] with Closeable { def apply( v1: SparseFeatures, v2: Labels[K], @@ -47,51 +46,42 @@ class MultilabelModelTest extends ModelSerializationTestHelper { def isClosed: Boolean = closed.get() } - val predictor = PredictorClosable[Label]() - val model = modelNoFeatures.copy(predictorProducer = Lazy(predictor)) + val predictor = new PredictorClosable[Label] + val model = ModelNoFeatures.copy(predictorProducer = Lazy(predictor)) model.close() assertTrue(predictor.isClosed) } @Test def testLabelsOfInterestOmitted(): Unit = { - // When labelsOfInterest = None, labelsAndInfo should return: - // LabelsAndInfo[K]( - // indices = labelsInTrainingSet.indices, - // labels = labelsInTrainingSet, - // missingLabels = Seq.empty[K], - // problems = None - // ) - val actual: LabelsAndInfo[Label] = labelsAndInfo( a = (), labelsOfInterest = None, labelToInd = Map.empty, - defaultLabelInfo = labelsAndInfoEmpty + defaultLabelInfo = LabelsAndInfoEmpty ) - assertEquals(labelsAndInfoEmpty, actual) + assertEquals(LabelsAndInfoEmpty, actual) } @Test def testLabelsOfInterestProvided(): Unit = { val a = () - val labelsOfInterest = - Option(GenFunc0[Unit,sci.IndexedSeq[Label]]("", _ => labelsInTrainingSet)) + val labelsOfInterest = GenFunc0[Unit,sci.IndexedSeq[Label]]("", _ => LabelsInTrainingSet) val actual: LabelsAndInfo[Label] = labelsAndInfo( a = a, - labelsOfInterest = labelsOfInterest, + labelsOfInterest = Option(labelsOfInterest), Map.empty, - labelsAndInfoEmpty + LabelsAndInfoEmpty ) - val expected = labelsForPrediction(a, labelsOfInterest.get, Map.empty[Label, Int]) + val expected = labelsForPrediction(a, labelsOfInterest, Map.empty[Label, Int]) assertEquals(expected, actual) } @Test def testReportTooManyMissing(): Unit = { val report = reportTooManyMissing( modelId = ModelId(), - labelInfo = labelsAndInfoEmpty, - missing = scm.Map("" -> labelsNotInTrainingSet), - auditor = auditor + labelInfo = LabelsAndInfoEmpty, + missing = scm.Map("" -> LabelsNotInTrainingSet), + auditor = Auditor ) assertEquals(Vector(TooManyMissingError), report.audited.errorMsgs.take(1)) @@ -100,7 +90,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testReportNoPrediction(): Unit = { - val report = reportNoPredictionEmpty + val report = ReportNoPredictionEmpty assertEquals(Vector(NoLabelsError), report.audited.errorMsgs) assertEquals(None, report.natural) assertEquals(None, report.audited.value) @@ -109,9 +99,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { @Test def testReportNoPredictionMissingLabelsDoNotExist(): Unit = { // The missing labels are reported in the error message - val report = reportNoPredictionPartial(labelsAndInfoMissingLabels) - assertEquals(Vector(NoLabelsError) ++ errorMessages, report.audited.errorMsgs) - assertEquals(None, reportNoPredictionEmpty.audited.value) + val report = ReportNoPredictionPartial(LabelsAndInfoMissingLabels) + assertEquals(Vector(NoLabelsError) ++ ErrorMessages, report.audited.errorMsgs) + assertEquals(None, ReportNoPredictionEmpty.audited.value) assertEquals(Set(), report.audited.missingVarNames) } @@ -120,12 +110,12 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val missingVariables = Seq("a", "b") val report = reportPredictorError( ModelId(), - labelsAndInfoEmpty.copy(labelsNotInTrainingSet = labelsNotInTrainingSet), + LabelsAndInfoEmpty.copy(labelsNotInTrainingSet = LabelsNotInTrainingSet), scm.Map("" -> missingVariables), throwable, Auditor ) - assertEquals(Vector(stackTrace) ++ errorMessages, report.audited.errorMsgs) + assertEquals(Vector(stackTrace) ++ ErrorMessages, report.audited.errorMsgs) assertEquals(missingVariables.toSet, report.audited.missingVarNames) assertEquals(None, report.natural) assertEquals(None, report.audited.value) @@ -135,8 +125,8 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val predictions = Map("label" -> 1.0) val report = reportSuccess( modelId = ModelId(), - labelInfo = labelsAndInfoEmpty.copy(labelsNotInTrainingSet = labelsNotInTrainingSet), - missing = scm.Map("" -> labelsNotInTrainingSet), + labelInfo = LabelsAndInfoEmpty.copy(labelsNotInTrainingSet = LabelsNotInTrainingSet), + missing = scm.Map("" -> LabelsNotInTrainingSet), prediction = predictions, auditor = Auditor ) @@ -147,9 +137,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { @Test def testLabelsForPredictionContainsProblemsWhenNoLabelProvided(): Unit = { val labelsAndInfoNoLabels = labelsForPrediction( - example = Map[Label, Label](), // no label provided - labelsOfInterest = labelsOfInterestExtractor, - labelToInd = labelsInTrainingSetToIndex) + example = Map[String, String](), // no label provided + labelsOfInterest = LabelsOfInterestExtractor, + labelToInd = LabelsInTrainingSetToIndex) // In this scenario no problems are found but an // Option(GenAggFuncAccessorProblems) is returned instead of None @@ -159,21 +149,16 @@ class MultilabelModelTest extends ModelSerializationTestHelper { @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsNotPresent(): Unit = { val labelsExtractor = - (label: Map[Label, Label]) => label.get("labels") match { - case ls: sci.IndexedSeq[_] => Option(ls.asInstanceOf[sci.IndexedSeq[Label]]) - case _ => None - } - + (label: Map[Label, Label]) => label.get("labels").map(sci.IndexedSeq[Label](_)) val descriptor = "label is missing" val labelsOfInterest = GenFunc.f1(GeneratedAccessor(descriptor, labelsExtractor, None))( "", _ getOrElse sci.IndexedSeq.empty[Label] ) - val labelsAndInfoNoLabelsGen1 = labelsForPrediction( example = Map[Label, Label](), labelsOfInterest = labelsOfInterest, - labelToInd = labelsInTrainingSetToIndex) + labelToInd = LabelsInTrainingSetToIndex) val problemsNoLabelsExpectedGen1 = Option(GenAggFuncAccessorProblems(Seq(descriptor), Seq())) assertEquals(problemsNoLabelsExpectedGen1, labelsAndInfoNoLabelsGen1.problems) @@ -189,21 +174,22 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val labelsOfInterestWrapped = EnrichedErrorGenAggFunc(labelsOfInterest) val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq(descriptor))) - Try( - labelsForPrediction( - Map[String, String](), - labelsOfInterestWrapped, - labelsInTrainingSetToIndex) - ).failed.get match { - case SemanticsUdfException(_, _, _, accessorsInErr, _, _) => + Try { + labelsForPrediction(Map.empty[String, String], labelsOfInterestWrapped, + LabelsInTrainingSetToIndex) + } match { + // Failure is scala.util.Failure + case Failure(SemanticsUdfException(_, _, _, accessorsInErr, _, _)) => assertEquals(accessorsInErr, problemsNoLabelsExpected.get.errors) + case _ => fail() } } @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { + // This test would be better done with a property-based testing framework such as ScalaCheck val labelNotInTrainingSet = "d" val labelsAndInfo = labelsForPrediction(Map("label" -> labelNotInTrainingSet), - labelsOfInterestExtractor, labelsInTrainingSetToIndex) + LabelsOfInterestExtractor, LabelsInTrainingSetToIndex) assertEquals(Seq(labelNotInTrainingSet), labelsAndInfo.labelsNotInTrainingSet) } @@ -230,17 +216,17 @@ class MultilabelModelTest extends ModelSerializationTestHelper { } @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { - assertEquals(None, modelWithFeatureFunctions.subvalue(Vector.empty).natural) + assertEquals(None, ModelWithFeatureFunctions.subvalue(Vector.empty).natural) } @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { // When the amount of missing data exceeds the threshold, reportTooManyMissing should be // called and its value should be returned. - val modelWithMissingThreshold = modelWithFeatureFunctions.copy( + val modelWithMissingThreshold = ModelWithFeatureFunctions.copy( featureNames = sci.IndexedSeq("feature1"), - featureFunctions = featureFunctions, - labelsInTrainingSet = labelsInTrainingSet, + featureFunctions = FeatureFunctions, + labelsInTrainingSet = LabelsInTrainingSet, labelsOfInterest = None, numMissingThreshold = Option(0) ) @@ -259,9 +245,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { v4: SparseLabelDepFeatures): Try[Map[Label, Double]] = Try(throw new Exception("error")) } - val modelWithThrowingPredictorProducer = modelWithFeatureFunctions.copy( + val modelWithThrowingPredictorProducer = ModelWithFeatureFunctions.copy( predictorProducer = Lazy(PredictorThatThrows), - labelsOfInterest = None + labelsOfInterest = None ) val result = modelWithThrowingPredictorProducer(Vector.empty) @@ -271,9 +257,9 @@ class MultilabelModelTest extends ModelSerializationTestHelper { @Test def testSubvalueSuccess(): Unit = { val scoreToReturn = 5d - val modelSuccess = modelNoFeatures.copy( + val modelSuccess = ModelNoFeatures.copy( labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), - labelsOfInterest = Option(labelsOfInterestExtractor), + labelsOfInterest = Option(LabelsOfInterestExtractor), predictorProducer = Lazy(ConstantPredictor[Label](scoreToReturn)), featureFunctions = sci.IndexedSeq.empty ) @@ -291,7 +277,7 @@ class MultilabelModelTest extends ModelSerializationTestHelper { val featureFunctionThatThrows: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = GenFunc0("", _ => throw exception) - val modelWithFeatureFunctionThatThrows = modelNoFeatures.copy( + val modelWithFeatureFunctionThatThrows = ModelNoFeatures.copy( featureNames = sci.IndexedSeq("throwing feature"), featureFunctions = Vector(featureFunctionThatThrows), labelsInTrainingSet = sci.IndexedSeq[Label](""), @@ -311,7 +297,8 @@ object MultilabelModelTest { private type Label = String private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() - case class ConstantPredictor[K](prediction: Double = 0d) extends SparseMultiLabelPredictor[K] { + private case class ConstantPredictor[K](prediction: Double = 0d) extends + SparseMultiLabelPredictor[K] { override def apply(featuresUnused: SparseFeatures, labels: Labels[K], indicesUnused: LabelIndices, @@ -319,28 +306,26 @@ object MultilabelModelTest { Try(labels.map(_ -> prediction).toMap) } - case class Lazy[A](value: A) extends (() => A) { + private case class Lazy[A](value: A) extends (() => A) { override def apply(): A = value } // Common input - val labelsInTrainingSet: sci.IndexedSeq[Label] = sci.IndexedSeq[Label]("a", "b", "c") - val labelsInTrainingSetToIndex: Map[Label, Int] =labelsInTrainingSet.zipWithIndex.toMap - val labelsNotInTrainingSet: Seq[Label] = Seq("a", "b") - val baseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") - val errorMessages: Seq[String] = baseErrorMessage.zip(labelsNotInTrainingSet).map { + private val LabelsInTrainingSet: sci.IndexedSeq[Label] = sci.IndexedSeq[Label]("a", "b", "c") + private val LabelsInTrainingSetToIndex: Map[Label, Int] = LabelsInTrainingSet.zipWithIndex.toMap + private val LabelsNotInTrainingSet: Seq[Label] = Seq("a", "b") + private val BaseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") + private val ErrorMessages: Seq[String] = BaseErrorMessage.zip(LabelsNotInTrainingSet).map { case(msg, label) => s"$msg$label" } - val auditor: RootedTreeAuditor[Any, Map[Label, Double]] = - RootedTreeAuditor[Any, Map[Label, Double]]() // Feature functions - val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + private val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = GenFunc0("", _ => Iterable()) - val featureFunctions = Vector(EmptyIndicatorFn) + private val FeatureFunctions = Vector(EmptyIndicatorFn) // Models - val modelNoFeatures = MultilabelModel( + private val ModelNoFeatures = MultilabelModel( modelId = ModelId(), featureNames = sci.IndexedSeq(), featureFunctions = sci.IndexedSeq[GenAggFunc[Int, Sparse]](), @@ -351,38 +336,38 @@ object MultilabelModelTest { auditor = Auditor ) - val modelWithFeatureFunctions: + private val ModelWithFeatureFunctions: MultilabelModel[Tree[Any], String, Vector[String], RootedTree[Any, Map[Label, Double]]] = - modelNoFeatures.copy( + ModelNoFeatures.copy( featureFunctions = sci.IndexedSeq[GenAggFunc[Vector[String], Sparse]](), - labelsInTrainingSet = labelsInTrainingSet, + labelsInTrainingSet = LabelsInTrainingSet, labelsOfInterest = Some(GenFunc0("", (a: Vector[String]) => a)) ) // LabelsAndInfo - val labelsAndInfoEmpty = LabelsAndInfo( + private val LabelsAndInfoEmpty = LabelsAndInfo( indices = sci.IndexedSeq.empty, labels = sci.IndexedSeq.empty, - labelsNotInTrainingSet = Seq[Label](), + labelsNotInTrainingSet = Seq.empty[Label], problems = None ) - val labelsAndInfoMissingLabels: LabelsAndInfo[Label] = - labelsAndInfoEmpty.copy(labelsNotInTrainingSet = labelsNotInTrainingSet) + private val LabelsAndInfoMissingLabels: LabelsAndInfo[Label] = + LabelsAndInfoEmpty.copy(labelsNotInTrainingSet = LabelsNotInTrainingSet) // Reports - val reportNoPredictionPartial: + private val ReportNoPredictionPartial: (LabelsAndInfo[Label]) => Subvalue[RootedTree[Any, Map[Label, Double]], Nothing] = reportNoPrediction( - modelId = ModelId(), - _: LabelsAndInfo[Label], - auditor = Auditor - ) - val reportNoPredictionEmpty: Subvalue[RootedTree[Any, Map[Label, Double]], Nothing] = - reportNoPredictionPartial(labelsAndInfoEmpty) + modelId = ModelId(), + _: LabelsAndInfo[Label], + auditor = Auditor + ) + private val ReportNoPredictionEmpty: Subvalue[RootedTree[Any, Map[Label, Double]], Nothing] = + ReportNoPredictionPartial(LabelsAndInfoEmpty) // Throwable and stack trace - def getThrowable(errorMessage: String): (Throwable, String) = { - val throwable = Try(throw new Exception(errorMessage)).failed.get + private def getThrowable(errorMessage: String): (Throwable, String) = { + val throwable = new Exception(errorMessage) val sw = new StringWriter val pw = new PrintWriter(sw) throwable.printStackTrace(pw) @@ -391,8 +376,8 @@ object MultilabelModelTest { } // Label extractors - def extractLabelsOutOfExample(example: Map[String, String]): sci.IndexedSeq[String] = + private def extractLabelsOutOfExample(example: Map[String, String]): sci.IndexedSeq[String] = example.filterKeys(_.startsWith("label")).toSeq.unzip._2.toIndexedSeq - val labelsOfInterestExtractor = GenFunc0("", extractLabelsOutOfExample) + private val LabelsOfInterestExtractor = GenFunc0("", extractLabelsOutOfExample) } From e3f9c59d611853147579e8a20e7398a228186050 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 25 Sep 2017 17:25:53 -0700 Subject: [PATCH 58/98] Getting everything to compile. Still some work to be done., --- .../eharmony/aloha/cli/ModelTypesTest.scala | 3 +- .../multilabel/VwMultilabelRowCreator.scala | 14 +++-- .../factory/JavaDefaultModelFactoryTest.java | 4 +- .../dataset/RowCreatorProducerTest.scala | 18 +++++- .../VwMultilabelRowCreatorTest.scala | 42 +++++++------ .../VwMultilabelModelPluginJsonReader.scala | 19 +++++- .../multilabel/VwMultilabelModelTest.scala | 59 +++++++++++-------- 7 files changed, 104 insertions(+), 55 deletions(-) diff --git a/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala b/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala index e6b8fd45..d6617048 100644 --- a/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala +++ b/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala @@ -28,7 +28,8 @@ class ModelTypesTest { "ModelDecisionTree", "Regression", "Segmentation", - "VwJNI" + "VwJNI", + "multilabel-sparse" ) val actual = ModelFactory.defaultFactory(null, null).parsers.map(_.modelType).sorted diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index 1404a73c..8ccbaaba 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -88,7 +88,7 @@ object VwMultilabelRowCreator { * As such, the ''reward'' of a positive example is designated to be one, * so the cost (or negative reward) is -1. */ - private[this] val Positive = (-1).toString + private[this] val PositiveCost = (-1).toString /** * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the @@ -96,7 +96,11 @@ object VwMultilabelRowCreator { * As such, the ''reward'' of a negative example is designated to be zero, * so the cost (or negative reward) is 0. */ - private[this] val Negative = 0.toString + private[this] val NegativeCost = 0.toString + + private[this] val PositiveDummyClassFeature = "P" + + private[this] val NegativeDummyClassFeature = "N" /** @@ -219,8 +223,8 @@ object VwMultilabelRowCreator { // These string interpolations are computed over and over but will always be the same // for a given dummyClassNs. // TODO: Precompute these in a class instance and pass in as parameters. - x(1) = s"$NegDummyClassId:0 |$dummyClassNs neg" - x(2) = s"$PosDummyClassId:-1 |$dummyClassNs pos" + x(1) = s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature" + x(2) = s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature" // This is mutable because we want speed. var i = 0 @@ -228,7 +232,7 @@ object VwMultilabelRowCreator { val labelInd = indices(i) // TODO or positives.contains(labelInd)? - val dv = if (positiveLabelIndices(i)) Positive else Negative + val dv = if (positiveLabelIndices(i)) PositiveCost else NegativeCost x(i + 3) = s"$labelInd:$dv |$classNs _$labelInd" i += 1 } diff --git a/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java b/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java index e857194e..c9dde596 100644 --- a/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java +++ b/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java @@ -12,6 +12,7 @@ import com.eharmony.aloha.models.conversion.DoubleToLongModel; import com.eharmony.aloha.models.exploration.BootstrapModel; import com.eharmony.aloha.models.exploration.EpsilonGreedyModel; +import com.eharmony.aloha.models.multilabel.MultilabelModel; import com.eharmony.aloha.reflect.RefInfo; import com.eharmony.aloha.semantics.NoSemantics; import com.eharmony.aloha.semantics.Semantics; @@ -78,7 +79,8 @@ public class JavaDefaultModelFactoryTest { ErrorSwallowingModel.parser().modelType(), EpsilonGreedyModel.parser().modelType(), BootstrapModel.parser().modelType(), - CloserTesterModel.parser().modelType() + CloserTesterModel.parser().modelType(), + MultilabelModel.parser().modelType() }; Arrays.sort(names); diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala index 553d30dc..438a10e4 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala @@ -3,15 +3,19 @@ package com.eharmony.aloha.dataset import java.lang.reflect.Modifier import com.eharmony.aloha +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import org.junit.Assert._ import org.junit.{Ignore, Test} import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner + import scala.collection.JavaConversions.asScalaSet import org.reflections.Reflections @RunWith(classOf[BlockJUnit4ClassRunner]) class RowCreatorProducerTest { + import RowCreatorProducerTest._ + private[this] def scanPkg = aloha.pkgName + ".dataset" @Test def testAllRowCreatorProducersHaveOnlyZeroArgConstructors() { @@ -20,9 +24,11 @@ class RowCreatorProducerTest { specProdClasses.foreach { clazz => val cons = clazz.getConstructors assertTrue(s"There should only be one constructor for ${clazz.getCanonicalName}. Found ${cons.length} constructors.", cons.length <= 1) - cons.headOption.foreach{ c => - val nParams = c.getParameterTypes.length - assertEquals(s"The constructor for ${clazz.getCanonicalName} should take 0 arguments. It takes $nParams.", 0, nParams) + cons.headOption.foreach { c => + if (!(WhitelistedRowCreatorProducers contains clazz)) { + val nParams = c.getParameterTypes.length + assertEquals(s"The constructor for ${clazz.getCanonicalName} should take 0 arguments. It takes $nParams.", 0, nParams) + } } } } @@ -41,3 +47,9 @@ class RowCreatorProducerTest { } } } + +object RowCreatorProducerTest { + private val WhitelistedRowCreatorProducers = Set[Class[_]]( + classOf[VwMultilabelRowCreator.Producer[_, _]] + ) +} \ No newline at end of file diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala index 0da4cd0b..1839d2d2 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala @@ -1,6 +1,6 @@ package com.eharmony.aloha.dataset.vw.multilabel -import com.eharmony.aloha.dataset.SparseFeatureExtractorFunction +import com.eharmony.aloha.dataset.{MissingAndErroneousFeatureInfo, SparseFeatureExtractorFunction} import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc0} import org.junit.Assert._ import org.junit.Test @@ -14,31 +14,25 @@ import org.junit.runners.BlockJUnit4ClassRunner class VwMultilabelRowCreatorTest { import VwMultilabelRowCreatorTest._ - // TODO: Test that shared doesn't occur when there are no shared features. - @Test def testSharedOmittedWhenNoSharedFeaturesExist(): Unit = { - fail() + @Test def testSharedIsPresentWhenNoFeaturesSpecified(): Unit = { + val rc = rowCreator(0) + val a = output(rc(X)) + val expected = SharedPrefix +: (DummyLabels ++ AllNegative) + assertEquals(expected, a.toList) } @Test def testOneFeatureNoPos(): Unit = { - val nFeatures = 1 - val shared = List("shared | f1") // One feature because nFeatures = 1 + val rc = rowCreator(numFeatures = 1) + val a = output(rc(X)) - val dummy = List( - s"$NegDummyClass:$NegVal |y neg", - s"$PosDummyClass:$PosVal |y pos" - ) - - // Only negative labels because no positive labels. - val classLabels = LabelsInTrainingSet.indices.map(i => s"$i:$NegVal |Y _$i") - - val rc = rowCreator(nFeatures) - val (missing, arr) = rc(Map.empty) - assertEquals(shared ++ dummy ++ classLabels, arr.toList) + val shared = s"$SharedPrefix| f1" + val expected = shared +: (DummyLabels ++ AllNegative) + assertEquals(expected, a.toList) } } object VwMultilabelRowCreatorTest { - private type Domain = Map[String, String] + private type Domain = Map[String, Any] private type Label = String private val Omitted = "" private val LabelsInTrainingSet = Vector("zero", "one", "two") @@ -46,6 +40,18 @@ object VwMultilabelRowCreatorTest { private val PosDummyClass = NegDummyClass + 1 private val PosVal = -1 private val NegVal = 0 + private val X = Map.empty[String, Any] + private val SharedPrefix = "shared " + + private val DummyLabels = List( + s"$NegDummyClass:$NegVal |y N", + s"$PosDummyClass:$PosVal |y P" + ) + + private val AllNegative = LabelsInTrainingSet.indices.map(i => s"$i:$NegVal |Y _$i") + + + private def output(out: (MissingAndErroneousFeatureInfo, Array[String])) = out._2 private[this] val featureFns = SparseFeatureExtractorFunction[Domain](Vector( "f1" -> GenFunc0(Omitted, _ => Seq(("", 1d))), diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala index 3fb9956f..9e3efbdc 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala @@ -1,11 +1,14 @@ package com.eharmony.aloha.models.vw.jni.multilabel.json +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.determineLabelNamespaces import com.eharmony.aloha.models.multilabel.SparsePredictorProducer import com.eharmony.aloha.models.vw.jni.Namespaces import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictorProducer import com.eharmony.aloha.util.Logging -import spray.json.{JsValue, JsonReader} +import spray.json.{DeserializationException, JsValue, JsonReader} +import scala.collection.breakOut import scala.collection.immutable.ListMap /** @@ -37,7 +40,19 @@ case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String]) if (missing.nonEmpty) info(s"features in namespaces not found in featureNames: $missing") - VwSparseMultilabelPredictorProducer[K](ast.modelSource, params, defaultNs, namespaces) + val namespaceNames: Set[String] = namespaces.map(_._1)(breakOut) + val labelAndDummyLabelNss = determineLabelNamespaces(namespaceNames) + + labelAndDummyLabelNss match { + case Some((labelNs, _)) => + // TODO: Should we remove this. If not, it must contain the --ring_size [training labels + 10]. + VwSparseMultilabelPredictorProducer[K](ast.modelSource, params, defaultNs, namespaces, labelNs) + case _ => + throw new DeserializationException( + "Could not determine label namespace. Found namespaces: " + + namespaceNames.mkString(", ") + ) + } } } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 47c2ed71..39062f83 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -3,6 +3,7 @@ package com.eharmony.aloha.models.vw.jni.multilabel import java.io.File import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor} +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import com.eharmony.aloha.id.ModelId import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} import com.eharmony.aloha.io.vfs.Vfs @@ -110,11 +111,16 @@ object VwMultilabelModelTest { .toVector ) + val namespaces = List(("X", List(0))) + val labelNs = VwMultilabelRowCreator.determineLabelNamespaces(namespaces.unzip._1.toSet).get._1 + + val predProd = VwSparseMultilabelPredictorProducer[Label]( modelSource = TrainedModel, params = "", // to see the output: "-p /dev/stdout", defaultNs = List.empty[Int], - namespaces = List(("X", List(0))) + namespaces = namespaces, + labelNamespace = labelNs ) MultilabelModel( @@ -134,6 +140,8 @@ object VwMultilabelModelTest { f } + private val (labelNs, dummyLabelNs) = VwMultilabelRowCreator.determineLabelNamespaces(Set.empty).get + private def vwTrainingParams(modelFile: File = tmpFile()) = { // NOTES: @@ -161,17 +169,17 @@ object VwMultilabelModelTest { // "smaller is better". So, to get the probability, one must do `1/(1 + exp(-1 * -y))` // or simply `1/(1 + exp(y))`. val flags = - """ - | --quiet - | --csoaa_ldf mc - | --csoaa_rank - | --loss_function logistic - | -q YX - | --noconstant - | --ignore_linear X - | --ignore y - | -f - """.stripMargin.trim + s""" + | --quiet + | --csoaa_ldf mc + | --csoaa_rank + | --loss_function logistic + | -q ${labelNs}X + | --noconstant + | --ignore_linear X + | --ignore $dummyLabelNs + | -f + """.stripMargin.trim (flags + " " + modelFile.getCanonicalPath).split("\n").map(_.trim).mkString(" ") } @@ -180,24 +188,25 @@ object VwMultilabelModelTest { /** * A dataset that creates the following marginal distribution. - - Pr[seven] = 0.7 where seven is _C0_ - - Pr[eight] = 0.8 where eight is _C1_ - - Pr[six] = 0.6 where six is _C2_ + - Pr[seven] = 0.7 where seven is _0 + - Pr[eight] = 0.8 where eight is _1 + - Pr[six] = 0.6 where six is _2 * - * The observant reader may notice these are oddly ordered. On each line C1 appears first, - * then C0, then C2. This is done to show ordering doesn't matter. What matters is the + * The observant reader may notice these are oddly ordered. On each line _1 appears first, + * then _0, then _2. This is done to show ordering doesn't matter. What matters is the * class '''indices'''. */ private val TrainingData = + Vector( - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.084 |y _C2147483649_\n1:0.0 |Y _C1_\n0:-0.084 |Y _C0_\n2:-0.084 |Y _C2_", - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.024 |y _C2147483649_\n1:0.0 |Y _C1_\n0:0.0 |Y _C0_\n2:0.0 |Y _C2_", - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.336 |y _C2147483649_\n1:-0.336 |Y _C1_\n0:-0.336 |Y _C0_\n2:-0.336 |Y _C2_", - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.056 |y _C2147483649_\n1:0.0 |Y _C1_\n0:-0.056 |Y _C0_\n2:0.0 |Y _C2_", - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.144 |y _C2147483649_\n1:-0.144 |Y _C1_\n0:0.0 |Y _C0_\n2:-0.144 |Y _C2_", - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.224 |y _C2147483649_\n1:-0.224 |Y _C1_\n0:-0.224 |Y _C0_\n2:0.0 |Y _C2_", - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.036 |y _C2147483649_\n1:0.0 |Y _C1_\n0:0.0 |Y _C0_\n2:-0.036 |Y _C2_", - s"shared |X $FeatureName\n2147483648:0.0 |y _C2147483648_\n2147483649:-0.096 |y _C2147483649_\n1:-0.096 |Y _C1_\n0:0.0 |Y _C0_\n2:0.0 |Y _C2_" + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.084 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:-0.084 |$labelNs _0\n2:-0.084 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.024 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:0.0 |$labelNs _0\n2:0.0 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.336 |$dummyLabelNs pos\n1:-0.336 |$labelNs _1\n0:-0.336 |$labelNs _0\n2:-0.336 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.056 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:-0.056 |$labelNs _0\n2:0.0 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.144 |$dummyLabelNs pos\n1:-0.144 |$labelNs _1\n0:0.0 |$labelNs _0\n2:-0.144 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.224 |$dummyLabelNs pos\n1:-0.224 |$labelNs _1\n0:-0.224 |$labelNs _0\n2:0.0 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.036 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:0.0 |$labelNs _0\n2:-0.036 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.096 |$dummyLabelNs pos\n1:-0.096 |$labelNs _1\n0:0.0 |$labelNs _0\n2:0.0 |$labelNs _2" ).map(_.split("\n")) private lazy val TrainedModel: ModelSource = { From 9dd35eefc291c8caeebadfc5716b481ca31efb29 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 27 Sep 2017 10:37:08 -0700 Subject: [PATCH 59/98] VW multi-label model parsing working correctly. Tests prove it! --- .../multilabel/VwMultilabelRowCreator.scala | 48 ++++--- .../aloha/factory/ri2jf/CollectionTypes.scala | 31 +++- .../MultilabelModelParserPlugin.scala | 2 +- .../aloha/models/multilabel/PluginInfo.scala | 3 +- .../multilabel/json/MultilabelModelJson.scala | 2 +- .../VwMultilabelRowCreatorTest.scala | 135 +++++++++++++++++- .../compiled/CompiledSemanticsInstances.scala | 28 ++++ .../AnyNameIdentitySemanticsPlugin.scala | 25 ++++ .../VwSparseMultilabelPredictor.scala | 25 +++- .../VwSparseMultilabelPredictorProducer.scala | 9 +- .../json/VwMultilabelModelJson.scala | 2 +- .../VwMultilabelModelPluginJsonReader.scala | 13 +- .../multilabel/VwMultilabelModelTest.scala | 81 ++++++++++- 13 files changed, 355 insertions(+), 49 deletions(-) create mode 100644 aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/CompiledSemanticsInstances.scala create mode 100644 aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/plugin/AnyNameIdentitySemanticsPlugin.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index 8ccbaaba..e9fd353c 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -12,7 +12,7 @@ import com.eharmony.aloha.semantics.func.GenAggFunc import spray.json.JsValue import scala.collection.{breakOut, immutable => sci} -import scala.util.Try +import scala.util.{Failure, Success, Try} /** * Created by ryan.deak on 9/13/17. @@ -24,24 +24,14 @@ final case class VwMultilabelRowCreator[-A, K]( namespaces: List[(String, List[Int])], normalizer: Option[CharSequence => CharSequence], positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], + classNs: String, + dummyClassNs: String, includeZeroValues: Boolean = false ) extends RowCreator[A, Array[String]] { import VwMultilabelRowCreator._ @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap - @transient private[this] lazy val nss = - determineLabelNamespaces(namespaces.map{ case (ns, _) => ns}(breakOut)) getOrElse { - // If there are so many VW namespaces that all available Unicode characters are taken, - // then a memory error will probably already have occurred. - throw new AlohaException( - "Could not find any Unicode characters to as VW namespaces. Namespaces provided: " + - namespaces.unzip._1.mkString(", ") - ) - } - - @transient private[this] lazy val classNs = nss._1 - @transient private[this] lazy val dummyClassNs = nss._2 override def apply(a: A): (MissingAndErroneousFeatureInfo, Array[String]) = { val (missingAndErrs, features) = featuresFunction(a) @@ -136,15 +126,15 @@ object VwMultilabelRowCreator { * @return the namespace for ''actual'' label information then the namespace for ''dummy'' * label information. If two valid namespaces couldn't be produced, return None. */ - private[aloha] def determineLabelNamespaces(usedNss: Set[String]): Option[(String, String)] = { + private[aloha] def determineLabelNamespaces(usedNss: Set[String]): Option[LabelNamespaces] = { val nss = nssToFirstCharBitSet(usedNss) preferredLabelNamespaces(nss) orElse bruteForceNsSearch(nss) } - private[multilabel] def preferredLabelNamespaces(nss: sci.BitSet): Option[(String, String)] = { + private[multilabel] def preferredLabelNamespaces(nss: sci.BitSet): Option[LabelNamespaces] = { PreferredLabelNamespaces collectFirst { case (actual, dummy) if !(nss contains actual.toInt) && !(nss contains dummy.toInt) => - (actual.toString, dummy.toString) + LabelNamespaces(actual.toString, dummy.toString) } } @@ -168,7 +158,7 @@ object VwMultilabelRowCreator { * @param usedNss the set of first characters in namespaces. * @return the namespace to use for the actual classes and dummy classes, respectively. */ - private[multilabel] def bruteForceNsSearch(usedNss: sci.BitSet): Option[(String, String)] = { + private[multilabel] def bruteForceNsSearch(usedNss: sci.BitSet): Option[LabelNamespaces] = { val found = Stream .from(FirstValidCharacter) @@ -177,7 +167,7 @@ object VwMultilabelRowCreator { found match { case actual #:: dummy #:: Stream.Empty => - Option((actual.toChar.toString, dummy.toChar.toString)) + Option(LabelNamespaces(actual.toChar.toString, dummy.toChar.toString)) case _ => None } } @@ -330,12 +320,30 @@ object VwMultilabelRowCreator { val spec = for { cov <- covariates pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) + labelNs <- labelNamespaces(nss) + actualLabelNs = labelNs.labelNs + dummyLabelNs = labelNs.dummyLabelNs sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) - } yield new VwMultilabelRowCreator[A, K](allLabelsInTrainingSet, cov, default, nss, normalizer, pos) + } yield new VwMultilabelRowCreator[A, K](allLabelsInTrainingSet, cov, default, nss, + normalizer, pos, actualLabelNs, dummyLabelNs) spec } + private[multilabel] def labelNamespaces(nss: List[(String, List[Int])]): Try[LabelNamespaces] = { + val nsNames: Set[String] = nss.map(_._1)(breakOut) + determineLabelNamespaces(nsNames) match { + case Some(ns) => Success(ns) + + // If there are so many VW namespaces that all available Unicode characters are taken, + // then a memory error will probably already have occurred. + case None => Failure(new AlohaException( + "Could not find any Unicode characters to as VW namespaces. Namespaces provided: " + + nsNames.mkString(", ") + )) + } + } + private[multilabel] def positiveLabelsFn( semantics: CompiledSemantics[A], positiveLabels: String @@ -343,4 +351,6 @@ object VwMultilabelRowCreator { getDv[A, sci.IndexedSeq[K]]( semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K])) } + + private[aloha] final case class LabelNamespaces(labelNs: String, dummyLabelNs: String) } \ No newline at end of file diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala index 72ec2954..8ae4f317 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala @@ -1,16 +1,39 @@ package com.eharmony.aloha.factory.ri2jf import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import spray.json.DefaultJsonProtocol.{immSeqFormat, mapFormat} import spray.json.JsonFormat -import spray.json.DefaultJsonProtocol.immSeqFormat /** * Created by ryan on 1/25/17. */ class CollectionTypes extends RefInfoToJsonFormatConversions { - def apply[A](implicit conv: RefInfoToJsonFormat, r: RefInfo[A], jf: JsonFormatCaster[A]): Option[JsonFormat[A]] = { - if (RefInfoOps.isSubType[A, collection.immutable.Seq[Any]]) - conv(r.typeArguments.head).flatMap(f => jf(immSeqFormat(f))) + def apply[A](implicit conv: RefInfoToJsonFormat, + r: RefInfo[A], + jf: JsonFormatCaster[A]): Option[JsonFormat[A]] = { + + val typeParams = RefInfoOps.typeParams[A] + + if (RefInfoOps.isSubType[A, collection.immutable.Map[Any, Any]]) + typeParams match { + case List(tKey, tVal) => + for { + k <- conv(tKey) + v <- conv(tVal) + f <- jf(mapFormat(k, v)) + } yield f + case _ => + // TODO: perhaps change the API at some point to better report errors here. + None + } + else if (RefInfoOps.isSubType[A, collection.immutable.Seq[Any]]) +// conv(r.typeArguments.head).flatMap(f => jf(immSeqFormat(f))) +// conv(typeParams.head).flatMap(f => jf(immSeqFormat(f))) + for { + tEl <- typeParams.headOption + el <- conv(tEl) + f <- jf(immSeqFormat(el)) + } yield f else None } } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala index bb5b586f..5a7d15bf 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala @@ -25,7 +25,7 @@ trait MultilabelModelParserPlugin { * @tparam K the label or class type to be produced by the multi-label model. * @return a JSON reader that can create `SparsePredictorProducer[K]` from JSON ASTs. */ - def parser[K](info: PluginInfo) + def parser[K](info: PluginInfo[K]) (implicit ri: RefInfo[K], jf: JsonFormat[K]): JsonReader[SparsePredictorProducer[K]] } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala index 1fe8d0f3..db38ce6d 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala @@ -7,6 +7,7 @@ import scala.collection.immutable.ListMap /** * Created by ryan.deak on 9/7/17. */ -trait PluginInfo { +trait PluginInfo[K] { def features: ListMap[String, Spec] + def labelsInTrainingSet: Vector[K] } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala index c0e4dadf..5d7be2c0 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala @@ -33,7 +33,7 @@ trait MultilabelModelJson extends SpecJson with ScalaJsonFormats { labelsInTrainingSet: Vector[K], labelsOfInterest: Option[String], underlying: JsObject - ) extends PluginInfo + ) extends PluginInfo[K] protected[this] final implicit def multilabelDataJsonFormat[K: JsonFormat]: RootJsonFormat[MultilabelData[K]] = jsonFormat7(MultilabelData[K]) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala index 1839d2d2..8e7d5492 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala @@ -1,11 +1,16 @@ package com.eharmony.aloha.dataset.vw.multilabel +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson import com.eharmony.aloha.dataset.{MissingAndErroneousFeatureInfo, SparseFeatureExtractorFunction} import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc0} import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner +import spray.json.pimpString +import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances + +import scala.collection.breakOut /** * Created by ryan.deak on 9/22/17. @@ -29,6 +34,95 @@ class VwMultilabelRowCreatorTest { val expected = shared +: (DummyLabels ++ AllNegative) assertEquals(expected, a.toList) } + + @Test def testRowCreationFromJson(): Unit = { + + // Some labels. These are intentionally out of order. Don't change. + val allLabelInts = Vector(7, 8, 6) + + // Run tests for the entire power set of positive labels. + val tests = Seq( + // positive labels: + 1, // {} + 6, // {6} + 7, // {7} + 8, // {8} + 6 * 7, // {6, 7} + 6 * 8, // {6, 8} + 7 * 8, // {7, 8} + 6 * 7 * 8 // {6, 7, 8} + ) + + // Could throw, but if so, test should fail. + val rc = getRowCreator(allLabelInts) + + tests foreach { test => + val exp = positiveAndNegativeValues(test, allLabelInts) + val (_, act) = rc(test) + testOutput(test, exp, act, MultilabelOutputPrefix, rc.classNs) + } + } + + private[this] def getRowCreator(allLabelInts: Vector[Int]) = { + val rcProducer = + new VwMultilabelRowCreator.Producer[Int, Label](allLabelInts.map(y => y.toString)) + + val rcTry = rcProducer.getRowCreator( + CompiledSemanticsInstances.anyNameIdentitySemantics[Int], + MultilabelDatasetJson + ) + + // Don't care about calling `.get`. If it fails, the test will appropriately blow up. + rcTry.get + } + + private[this] def positiveAndNegativeValues(n: Int, allLabels: Seq[Int]): Seq[Boolean] = { + val divisorsN = divisors(n) + val labelInd: Map[Int, Int] = allLabels.zipWithIndex.toMap + val allLabelSet = labelInd.keySet + val pos = allLabelSet intersect divisorsN + val neg = allLabelSet diff divisorsN + + // Because pos and neg form a partition of allLabelSet, labelInd.apply is safe. + val posAndNeg = + pos.map(p => labelInd(p) -> true) ++ + neg.map(p => labelInd(p) -> false) + + posAndNeg + .toSeq + .sorted + .map { case (_, isPos) => isPos } + } + + private[this] def testOutput( + n: Int, + expectedResults: Seq[Boolean], + actualResults: Array[String], + prefix: Seq[String], + labelNs: String + ): Unit = { + + val suffix = + expectedResults.zipWithIndex map { case (isPos, i) => + s"$i:${if (isPos) -1 else 0} |$labelNs _$i" + } + + assertEquals(prefix, actualResults.take(prefix.size).toSeq) + assertEquals(suffix, actualResults.drop(prefix.size).toSeq) + } + + + /** + * Get all divisors of `n`. + * This is not super efficient: O(sqrt(N)) + * @param n value for which divisors are desired. + * @return + */ + private[this] def divisors(n: Int): Set[Int] = + (1 to math.sqrt(n).ceil.toInt).flatMap { + case i if n % i == 0 => List(i, n / i) + case _ => Nil + }(breakOut) } object VwMultilabelRowCreatorTest { @@ -50,6 +144,42 @@ object VwMultilabelRowCreatorTest { private val AllNegative = LabelsInTrainingSet.indices.map(i => s"$i:$NegVal |Y _$i") + // Notice the positive labels. This says if the function input is 0 (mod v), where v is + // one of {6, 7, 8}, then make the example a positive example for the associated label. + // + // This JSON should be used with CompiledSemanticsInstances.anyNameIdentitySemantics[Int]. + // This variable `fn_input` works because of the way that the semantics works. It just + // returns the `fn_input` is just the input value of the function. + // + // All of the features are invariant to the input. + private[aloha] val MultilabelDatasetJson = + """ + |{ + | "imports": [ + | "com.eharmony.aloha.feature.BasicFunctions._" + | ], + | "features": [ + | { "name": "f1", "spec": "1" }, + | { "name": "f2", "spec": "1" } + | ], + | "namespaces": [ + | { "name": "ns1", "features": [ "f1" ] }, + | { "name": "ns2", "features": [ "f2" ] } + | ], + | "normalizeFeatures": false, + | "positiveLabels": "(6 to 8).flatMap(v => List(v.toString).filter(_ => ${fn_input} % v == 0))" + |} + """.stripMargin.parseJson.convertTo[VwMultilabeledJson] + + // This is what the first lines of the output are expected to be. This doesn't change b/c + // the features (and dummy class definitions) in MultilabelDatasetJson are invariant to + // the input. + private[aloha] val MultilabelOutputPrefix = Seq( + "shared |ns1 f1 |ns2 f2", // due to definition of features and namespaces. + "2147483648:0 |y N", // negative dummy class + "2147483649:-1 |y P" // positive dummy class + ) + private def output(out: (MissingAndErroneousFeatureInfo, Array[String])) = out._2 @@ -82,6 +212,7 @@ object VwMultilabelRowCreatorTest { private[this] val StdRowCreator: VwMultilabelRowCreator[Domain, Label] = { val ff = featureFns(0) + val labelNss = VwMultilabelRowCreator.determineLabelNamespaces(Set.empty).get VwMultilabelRowCreator[Domain, Label]( allLabelsInTrainingSet = LabelsInTrainingSet, @@ -89,7 +220,9 @@ object VwMultilabelRowCreatorTest { defaultNamespace = ff.features.indices.toList, // All features in default NS. namespaces = List.empty[(String, List[Int])], normalizer = Option.empty[CharSequence => CharSequence], - positiveLabelsFunction = positiveLabels() + positiveLabelsFunction = positiveLabels(), + classNs = labelNss.labelNs, + dummyClassNs = labelNss.dummyLabelNs ) } } diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/CompiledSemanticsInstances.scala b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/CompiledSemanticsInstances.scala new file mode 100644 index 00000000..3d5ff863 --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/CompiledSemanticsInstances.scala @@ -0,0 +1,28 @@ +package com.eharmony.aloha.semantics.compiled + +import com.eharmony.aloha.FileLocations +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler +import com.eharmony.aloha.semantics.compiled.plugin.AnyNameIdentitySemanticsPlugin + +/** + * Created by ryan.deak on 9/26/17. + */ +object CompiledSemanticsInstances { + + /** + * Creates a [[CompiledSemantics]] with the [[AnyNameIdentitySemanticsPlugin]] and with + * imports `List("com.eharmony.aloha.feature.BasicFunctions._")`. + * A class cache directory will be used. + * @tparam A the domain of the functions created by this semantics. + * @return + */ + private[aloha] def anyNameIdentitySemantics[A: RefInfo]: CompiledSemantics[A] = { + import scala.concurrent.ExecutionContext.Implicits.global + new CompiledSemantics( + new TwitterEvalCompiler(classCacheDir = FileLocations.testGeneratedClasses), + new AnyNameIdentitySemanticsPlugin[A], + List("com.eharmony.aloha.feature.BasicFunctions._") + ) + } +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/plugin/AnyNameIdentitySemanticsPlugin.scala b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/plugin/AnyNameIdentitySemanticsPlugin.scala new file mode 100644 index 00000000..40324d2c --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/plugin/AnyNameIdentitySemanticsPlugin.scala @@ -0,0 +1,25 @@ +package com.eharmony.aloha.semantics.compiled.plugin + +import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import com.eharmony.aloha.semantics.compiled.{CompiledSemanticsPlugin, RequiredAccessorCode} + +/** + * This plugin binds ANY variable name to the input. For example, if the domain of + * the generated function is `Int` and the function specification is any of the following: + * + - `"${asdf} + 1"` + - `"${jkl} + 1"` + - `"${i_dont_care_the_name} + 1"` + - ... + * + * the resulting function adds 1 to the input. + * + * Created by ryan.deak on 9/26/17. + * @param refInfoA reflection information about the function domain. + * @tparam A domain of the functions being generated from this plugin. + */ +private[aloha] class AnyNameIdentitySemanticsPlugin[A](implicit val refInfoA: RefInfo[A]) + extends CompiledSemanticsPlugin[A] { + override def accessorFunctionCode(spec: String): Right[Nothing, RequiredAccessorCode] = + Right(RequiredAccessorCode(Seq(s"identity[${RefInfoOps.toString[A]}]"))) +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index 60b1e68c..91af39ad 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -36,13 +36,15 @@ case class VwSparseMultilabelPredictor[K]( modelSource: ModelSource, params: String, defaultNs: List[Int], - namespaces: List[(String, List[Int])]) + namespaces: List[(String, List[Int])], + numLabelsInTrainingSet: Int) extends SparseMultiLabelPredictor[K] with Closeable { import VwSparseMultilabelPredictor._ - @transient private[multilabel] lazy val vwModel = createLearner(modelSource, params).get + @transient private[multilabel] lazy val vwModel = + createLearner(modelSource, params, numLabelsInTrainingSet).get { val emptyNss = namespaces collect { case (ns, ind) if ind.isEmpty => ns } @@ -85,6 +87,7 @@ extends SparseMultiLabelPredictor[K] object VwSparseMultilabelPredictor { private val ClassNS = "Y" + private[this] val AddlVwRingSize = 10 /** * Produce the output given VW's output, `pred`, and the labels provided to the `apply` function. @@ -144,12 +147,22 @@ object VwSparseMultilabelPredictor { * @return */ // TODO: How much of the parameter setup is up to the caller versus this function? - private[multilabel] def paramsWithSource(modelSource: File, params: String): String = - params + " -i" + modelSource.getCanonicalPath + " -t --quiet" + private[multilabel] def paramsWithSource( + modelSource: File, + params: String, + numLabelsInTrainingSet: Int + ): String = { + val ringSize = (numLabelsInTrainingSet + AddlVwRingSize) + s"$params -i ${modelSource.getCanonicalPath} --ring_size $ringSize --testonly --quiet" + } - private[multilabel] def createLearner(modelSource: ModelSource, params: String): Try[VWActionScoresLearner] = { + private[multilabel] def createLearner( + modelSource: ModelSource, + params: String, + numLabelsInTrainingSet: Int + ): Try[VWActionScoresLearner] = { val modelFile = modelSource.localVfs.replicatedToLocal() - val updatedParams = paramsWithSource(modelFile.fileObj, params) + val updatedParams = paramsWithSource(modelFile.fileObj, params, numLabelsInTrainingSet) Try { VWLearners.create[VWActionScoresLearner](updatedParams) } } } diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala index 493bbc95..0ddb3d8b 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala @@ -32,10 +32,11 @@ case class VwSparseMultilabelPredictorProducer[K]( params: String, defaultNs: List[Int], namespaces: List[(String, List[Int])], - labelNamespace: String + labelNamespace: String, + numLabelsInTrainingSet: Int ) extends SparsePredictorProducer[K] { override def apply(): VwSparseMultilabelPredictor[K] = - VwSparseMultilabelPredictor[K](modelSource, params, defaultNs, namespaces) + VwSparseMultilabelPredictor[K](modelSource, params, defaultNs, namespaces, numLabelsInTrainingSet) } object VwSparseMultilabelPredictorProducer extends MultilabelPluginProviderCompanion { @@ -44,9 +45,9 @@ object VwSparseMultilabelPredictorProducer extends MultilabelPluginProviderCompa object Plugin extends MultilabelModelParserPlugin { override def name: String = "vw" - override def parser[K](info: PluginInfo) + override def parser[K](info: PluginInfo[K]) (implicit ri: RefInfo[K], jf: JsonFormat[K]): JsonReader[SparsePredictorProducer[K]] = { - VwMultilabelModelPluginJsonReader[K](info.features.keys.toVector) + VwMultilabelModelPluginJsonReader[K](info.features.keys.toVector, info.labelsInTrainingSet.size) } } } diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala index 4bfee0d8..63ccbcc8 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala @@ -15,7 +15,7 @@ trait VwMultilabelModelJson extends ScalaJsonFormats { private[multilabel] case class VwMultilabelAst( `type`: String, modelSource: ModelSource, - params: Either[Seq[String], String] = Right(""), + params: Option[Either[Seq[String], String]] = Option(Right("")), namespaces: Option[ListMap[String, Seq[String]]] = Some(ListMap.empty) ) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala index 9e3efbdc..418cc35d 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala @@ -1,7 +1,6 @@ package com.eharmony.aloha.models.vw.jni.multilabel.json -import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator -import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.determineLabelNamespaces +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.{LabelNamespaces, determineLabelNamespaces} import com.eharmony.aloha.models.multilabel.SparsePredictorProducer import com.eharmony.aloha.models.vw.jni.Namespaces import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictorProducer @@ -23,7 +22,7 @@ import scala.collection.immutable.ListMap * @param featureNames feature names from the multi-label model. * @tparam K label type for the predictions outputted by the */ -case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String]) +case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String], numLabelsInTrainingSet: Int) extends JsonReader[SparsePredictorProducer[K]] with VwMultilabelModelJson with Namespaces @@ -44,9 +43,9 @@ case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String]) val labelAndDummyLabelNss = determineLabelNamespaces(namespaceNames) labelAndDummyLabelNss match { - case Some((labelNs, _)) => + case Some(LabelNamespaces(labelNs, _)) => // TODO: Should we remove this. If not, it must contain the --ring_size [training labels + 10]. - VwSparseMultilabelPredictorProducer[K](ast.modelSource, params, defaultNs, namespaces, labelNs) + VwSparseMultilabelPredictorProducer[K](ast.modelSource, params, defaultNs, namespaces, labelNs, numLabelsInTrainingSet) case _ => throw new DeserializationException( "Could not determine label namespace. Found namespaces: " + @@ -59,8 +58,8 @@ case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String]) object VwMultilabelModelPluginJsonReader extends Logging { private val JsonErrStrLength = 100 - private[multilabel] def vwParams(params: Either[Seq[String], String]): String = - params.fold(_ mkString " ", identity).trim + private[multilabel] def vwParams(params: Option[Either[Seq[String], String]]): String = + params.fold("")(e => e.fold(ps => ps.mkString(" "), identity[String])).trim private[multilabel] def notObjErr(json: JsValue): String = { val str = json.prettyPrint diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 39062f83..c77893b2 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -2,19 +2,25 @@ package com.eharmony.aloha.models.vw.jni.multilabel import java.io.File +import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor} import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.LabelNamespaces +import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.id.ModelId import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} import com.eharmony.aloha.io.vfs.Vfs import com.eharmony.aloha.models.Model import com.eharmony.aloha.models.multilabel.MultilabelModel +import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances import com.eharmony.aloha.semantics.func.GenFunc0 import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner +import spray.json.{DefaultJsonProtocol, JsonWriter} import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} +import spray.json.DefaultJsonProtocol.{StringJsonFormat, vectorFormat} import scala.annotation.tailrec @@ -25,7 +31,7 @@ import scala.annotation.tailrec class VwMultilabelModelTest { import VwMultilabelModelTest._ - @Test def test1(): Unit = { + @Test def testTrainedModelWorks(): Unit = { val model = Model try { @@ -44,6 +50,71 @@ class VwMultilabelModelTest { } } + @Test def testTrainedModelCanBeParsedAndUsed(): Unit = { + val factory = + ModelFactory.defaultFactory( + CompiledSemanticsInstances.anyNameIdentitySemantics[Any], + OptionAuditor[Map[String, Double]]() + ) + + val json = modelJson(TrainedModel, AllLabels) + val modelTry = factory.fromString(json) + val model = modelTry.get // : Model[Any, Option[Map[String, Double]]] + val x = () + val y = model(x) + model.close() + + y match { + case None => fail("Model should produce output. Produced None.") + case Some(m) => + assertEquals(ExpectedMarginalDist.keySet, m.keySet) + + ExpectedMarginalDist foreach { case (k, v) => + assertEquals(s"For key '$k':", v, m(k), 0.01) + } + } + } + + private[this] def modelJson[K: JsonWriter]( + modelSource: ModelSource, + labelsInTrainingSet: Vector[K], + labelsOfInterest: Option[String] = None) = { + + implicit val vecWriter = vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) + + val loi = labelsOfInterest.fold(""){ f => + val escaped = f.replaceAll("\"", "\\\"") + s""""labelsOfInterest": "$escaped",\n""" + } + + val json = + s""" + |{ + | "modelType": "multilabel-sparse", + | "modelId": { "id": 1, "name": "NONE" }, + | "features": { + | "feature": "1" + | }, + | "numMissingThreshold": 0, + | "labelsInTrainingSet": ${toJsonString(labelsInTrainingSet)}, + |$loi + | "underlying": { + | "type": "vw", + | "modelSource": ${toJsonString(modelSource)}, + | "namespaces": { + | "X": [ + | "feature" + | ] + | } + | } + |} + """.stripMargin.trim + json + } + + private[this] def toJsonString[A: JsonWriter](a: A): String = + implicitly[JsonWriter[A]].write(a).compactPrint + private[this] def testEmpty(yEmpty: PredictionOutput): Unit = { assertEquals(None, yEmpty.value) assertEquals(Vector("No labels provided. Cannot produce a prediction."), yEmpty.errorMsgs) @@ -112,7 +183,7 @@ object VwMultilabelModelTest { ) val namespaces = List(("X", List(0))) - val labelNs = VwMultilabelRowCreator.determineLabelNamespaces(namespaces.unzip._1.toSet).get._1 + val labelNs = VwMultilabelRowCreator.determineLabelNamespaces(namespaces.unzip._1.toSet).get.labelNs val predProd = VwSparseMultilabelPredictorProducer[Label]( @@ -120,7 +191,8 @@ object VwMultilabelModelTest { params = "", // to see the output: "-p /dev/stdout", defaultNs = List.empty[Int], namespaces = namespaces, - labelNamespace = labelNs + labelNamespace = labelNs, + numLabelsInTrainingSet = AllLabels.size ) MultilabelModel( @@ -140,7 +212,8 @@ object VwMultilabelModelTest { f } - private val (labelNs, dummyLabelNs) = VwMultilabelRowCreator.determineLabelNamespaces(Set.empty).get + private val LabelNamespaces(labelNs, dummyLabelNs) = + VwMultilabelRowCreator.determineLabelNamespaces(Set.empty).get private def vwTrainingParams(modelFile: File = tmpFile()) = { From 82746e75964c71363382eb88f755f31bdd98c19a Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 27 Sep 2017 10:50:01 -0700 Subject: [PATCH 60/98] exposed VW parameters to VwSparseMultilabelPredictor --- .../VwSparseMultilabelPredictor.scala | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index 91af39ad..f92ec156 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -34,6 +34,8 @@ import scala.util.Try // TODO: Comment this function. It requires a lot of assumptions. Make those known. case class VwSparseMultilabelPredictor[K]( modelSource: ModelSource, + + // TODO: Should these be removed? I don't think so but could be, w/o harm, in limited cases. params: String, defaultNs: List[Int], namespaces: List[(String, List[Int])], @@ -43,8 +45,12 @@ extends SparseMultiLabelPredictor[K] import VwSparseMultilabelPredictor._ - @transient private[multilabel] lazy val vwModel = - createLearner(modelSource, params, numLabelsInTrainingSet).get + @transient private[this] lazy val paramsAndVwModel = + createLearner(modelSource, params, numLabelsInTrainingSet) + + @transient private[this] lazy val updatedParams = paramsAndVwModel._1 + @transient private[multilabel] lazy val vwModel = paramsAndVwModel._2.get + { val emptyNss = namespaces collect { case (ns, ind) if ind.isEmpty => ns } @@ -57,6 +63,12 @@ extends SparseMultiLabelPredictor[K] require(vwModel != null) } + /** + * Get the VW parameters used to invoke the underlying VW model. + * @return VW parameters. + */ + def vwParams(): String = updatedParams + /** * Given the input, form a VW example, and delegate to the underlying ''CSOAA LDF'' VW model. * @param features (non-label dependent) features shared across all labels. @@ -142,8 +154,8 @@ object VwSparseMultilabelPredictor { * * val ex = str.split("\n") * }}} - * @param modelSource - * @param params + * @param modelSource location of a VW model. + * @param params the parameters passed to the model to which additional parameters will be added. * @return */ // TODO: How much of the parameter setup is up to the caller versus this function? @@ -152,7 +164,7 @@ object VwSparseMultilabelPredictor { params: String, numLabelsInTrainingSet: Int ): String = { - val ringSize = (numLabelsInTrainingSet + AddlVwRingSize) + val ringSize = numLabelsInTrainingSet + AddlVwRingSize s"$params -i ${modelSource.getCanonicalPath} --ring_size $ringSize --testonly --quiet" } @@ -160,9 +172,10 @@ object VwSparseMultilabelPredictor { modelSource: ModelSource, params: String, numLabelsInTrainingSet: Int - ): Try[VWActionScoresLearner] = { + ): (String, Try[VWActionScoresLearner]) = { val modelFile = modelSource.localVfs.replicatedToLocal() val updatedParams = paramsWithSource(modelFile.fileObj, params, numLabelsInTrainingSet) - Try { VWLearners.create[VWActionScoresLearner](updatedParams) } + val learner = Try { VWLearners.create[VWActionScoresLearner](updatedParams) } + (updatedParams, learner) } } From 5ab8e7cfddbc2d046419e1a2eb0e18c087825c49 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 27 Sep 2017 12:05:45 -0700 Subject: [PATCH 61/98] removed TODO. --- .../models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index f92ec156..609b20f8 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -86,7 +86,6 @@ extends SparseMultiLabelPredictor[K] labelDependentFeatures: sci.IndexedSeq[IndexedSeq[Sparse]] ): Try[Map[K, Double]] = { - // TODO: Pass ClassNS in via the constructor val x = VwMultilabelRowCreator.predictionInput(features, indices, defaultNs, namespaces, ClassNS) val pred = Try { vwModel.predict(x) } val yOut = pred.map { y => produceOutput(y, labels) } From 2023e74cc97b2b2f46f1b641be21e7f1a4aaa1a7 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 27 Sep 2017 12:51:12 -0700 Subject: [PATCH 62/98] updated tests to add coverage. --- .../multilabel/VwMultilabelModelTest.scala | 142 +++++++++++------- 1 file changed, 91 insertions(+), 51 deletions(-) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index c77893b2..ac8b561d 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -18,9 +18,9 @@ import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner +import spray.json.DefaultJsonProtocol.{StringJsonFormat, vectorFormat} import spray.json.{DefaultJsonProtocol, JsonWriter} import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} -import spray.json.DefaultJsonProtocol.{StringJsonFormat, vectorFormat} import scala.annotation.tailrec @@ -50,70 +50,69 @@ class VwMultilabelModelTest { } } - @Test def testTrainedModelCanBeParsedAndUsed(): Unit = { + @Test def testTrainedModelCanBeParsedAndUsed(): Unit = + testModel(None) + + @Test def testTrainedModelCanBeParsedAndUsedWithLabels67(): Unit = + testModel(Option(Set(LabelSix, LabelSeven))) + + @Test def testTrainedModelCanBeParsedAndUsedWithLabels678(): Unit = + testModel(Option(Set(LabelSix, LabelSeven, LabelEight))) + + @Test def testTrainedModelCanBeParsedAndUsedWithNoLabels(): Unit = + testModel(Option(Set.empty), modelShouldProduceOutput = false) + + /** + * + * @param desiredLabels Notice since this is a Set, label order doesn't matter. + * @param modelShouldProduceOutput whether the model is expected to + */ + private[this] def testModel( + desiredLabels: Option[Set[Label]], + modelShouldProduceOutput: Boolean = true + ): Unit = { + val factory = ModelFactory.defaultFactory( - CompiledSemanticsInstances.anyNameIdentitySemantics[Any], - OptionAuditor[Map[String, Double]]() + CompiledSemanticsInstances.anyNameIdentitySemantics[Set[Label]], + OptionAuditor[Map[Label, Double]]() ) - val json = modelJson(TrainedModel, AllLabels) + // "${desired_labels_from_input}.toVector" returns the input passed to the model in + // the form of a Vector. See anyNameIdentitySemantics for more information. + val desiredLabelFn = desiredLabels map (_ => "${desired_labels_from_input}.toVector") + val json = modelJson(TrainedModel, AllLabels, desiredLabelFn) + val modelTry = factory.fromString(json) - val model = modelTry.get // : Model[Any, Option[Map[String, Double]]] - val x = () + val model = modelTry.get // : Model[Set[Label], Option[Map[Label, Double]]] + val x = desiredLabels getOrElse Set.empty val y = model(x) model.close() y match { - case None => fail("Model should produce output. Produced None.") + case None => + if (modelShouldProduceOutput) + fail("Model should produce output. Produced None.") case Some(m) => - assertEquals(ExpectedMarginalDist.keySet, m.keySet) - - ExpectedMarginalDist foreach { case (k, v) => - assertEquals(s"For key '$k':", v, m(k), 0.01) + if (!modelShouldProduceOutput) + fail(s"Model should not produce output. Produced: $m") + else { + val expected = desiredLabels match { + case None => ExpectedMarginalDist + case Some(labels) => ExpectedMarginalDist.filterKeys(labels.contains) + } + + // The keys should be the same set, not a super set. + assertEquals(expected.keySet, m.keySet) + + // Test the value associated with the label. + expected foreach { case (k, v) => + assertEquals(s"For key '$k':", v, m(k), 0.01) + } } } } - private[this] def modelJson[K: JsonWriter]( - modelSource: ModelSource, - labelsInTrainingSet: Vector[K], - labelsOfInterest: Option[String] = None) = { - - implicit val vecWriter = vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) - - val loi = labelsOfInterest.fold(""){ f => - val escaped = f.replaceAll("\"", "\\\"") - s""""labelsOfInterest": "$escaped",\n""" - } - - val json = - s""" - |{ - | "modelType": "multilabel-sparse", - | "modelId": { "id": 1, "name": "NONE" }, - | "features": { - | "feature": "1" - | }, - | "numMissingThreshold": 0, - | "labelsInTrainingSet": ${toJsonString(labelsInTrainingSet)}, - |$loi - | "underlying": { - | "type": "vw", - | "modelSource": ${toJsonString(modelSource)}, - | "namespaces": { - | "X": [ - | "feature" - | ] - | } - | } - |} - """.stripMargin.trim - json - } - - private[this] def toJsonString[A: JsonWriter](a: A): String = - implicitly[JsonWriter[A]].write(a).compactPrint private[this] def testEmpty(yEmpty: PredictionOutput): Unit = { assertEquals(None, yEmpty.value) @@ -298,6 +297,47 @@ object VwMultilabelModelTest { private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() + private[multilabel] def modelJson[K: JsonWriter]( + modelSource: ModelSource, + labelsInTrainingSet: Vector[K], + labelsOfInterest: Option[String] = None) = { + + implicit val vecWriter = vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) + + val loi = labelsOfInterest.fold(""){ f => + val escaped = f.replaceAll("\"", "\\\"") + s""""labelsOfInterest": "$escaped",\n""" + } + + val json = + s""" + |{ + | "modelType": "multilabel-sparse", + | "modelId": { "id": 1, "name": "NONE" }, + | "features": { + | "feature": "1" + | }, + | "numMissingThreshold": 0, + | "labelsInTrainingSet": ${toJsonString(labelsInTrainingSet)}, + |$loi + | "underlying": { + | "type": "vw", + | "modelSource": ${toJsonString(modelSource)}, + | "namespaces": { + | "X": [ + | "feature" + | ] + | } + | } + |} + """.stripMargin.trim + json + } + + private[this] def toJsonString[A: JsonWriter](a: A): String = + implicitly[JsonWriter[A]].write(a).compactPrint + + /** * Creates the power set of the provided set. * Answer provided by Chris Marshall (@oxbow_lakes) on From 86c5c5017905b8c4757e790134099aa8462335fc Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 27 Sep 2017 15:52:12 -0700 Subject: [PATCH 63/98] Added additional tests. --- .../VwSparseMultilabelPredictor.scala | 2 +- .../multilabel/VwMultilabelModelTest.scala | 5 ++ .../VwSparseMultilabelPredictorTest.scala | 73 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index 609b20f8..6af487d9 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -98,7 +98,7 @@ extends SparseMultiLabelPredictor[K] object VwSparseMultilabelPredictor { private val ClassNS = "Y" - private[this] val AddlVwRingSize = 10 + private[multilabel] val AddlVwRingSize = 10 /** * Produce the output given VW's output, `pred`, and the labels provided to the `apply` function. diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index ac8b561d..d281588c 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -62,6 +62,11 @@ class VwMultilabelModelTest { @Test def testTrainedModelCanBeParsedAndUsedWithNoLabels(): Unit = testModel(Option(Set.empty), modelShouldProduceOutput = false) + // This is really more of an integration test. + @Test def testTrainingAndTesting(): Unit = { + fail() + } + /** * * @param desiredLabels Notice since this is a Set, label order doesn't matter. diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala new file mode 100644 index 00000000..38fc47ba --- /dev/null +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala @@ -0,0 +1,73 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.{ByteArrayOutputStream, File, FileInputStream} + +import com.eharmony.aloha.ModelSerializationTestHelper +import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource} +import org.apache.commons.codec.binary.Base64 +import org.apache.commons.io.IOUtils +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} + +/** + * Created by ryan.deak on 9/27/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwSparseMultilabelPredictorTest extends ModelSerializationTestHelper { + import VwSparseMultilabelPredictorTest._ + + @Test def testSerializability(): Unit = { + val predictor = getPredictor(getModelSource(), 3) + val ds = serializeDeserializeRoundTrip(predictor) + assertEquals(predictor, ds) + assertEquals(predictor.vwParams(), ds.vwParams()) + assertNotNull(ds.vwModel) + } + + @Test def testVwParameters(): Unit = { + val numLabelsInTrainingSet = 3 + val predictor = getPredictor(getModelSource(), numLabelsInTrainingSet) + + predictor.vwParams() match { + case Data(vwBinFilePath, ringSize) => + checkVwBinFile(vwBinFilePath) + checkVwRingSize(numLabelsInTrainingSet, ringSize.toInt) + case ps => fail(s"Unexpected VW parameters format. Found string: $ps") + } + } +} + +object VwSparseMultilabelPredictorTest { + private val Data = """\s*-i\s+(\S+)\s+--ring_size\s+(\d+)\s+--testonly\s+--quiet""".r + + private def getModelSource(): ModelSource = { + val f = File.createTempFile("i_dont", "care") + f.deleteOnExit() + val learner = VWLearners.create[VWActionScoresLearner](s"--quiet --csoaa_ldf mc --csoaa_rank -f ${f.getCanonicalPath}") + learner.close() + val baos = new ByteArrayOutputStream() + IOUtils.copy(new FileInputStream(f), baos) + val src = Base64StringSource(Base64.encodeBase64URLSafeString(baos.toByteArray)) + ExternalSource(src.localVfs) + } + + private def getPredictor(modelSrc: ModelSource, numLabelsInTrainingSet: Int) = + VwSparseMultilabelPredictor[Any](modelSrc, "", Nil, Nil, numLabelsInTrainingSet) + + private def checkVwBinFile(vwBinFilePath: String): Unit = { + val vwBinFile = new File(vwBinFilePath) + assertTrue("VW binary file should have been written to disk", vwBinFile.exists()) + vwBinFile.deleteOnExit() + } + + private def checkVwRingSize(numLabelsInTrainingSet: Int, ringSize: Int): Unit = { + assertEquals( + "vw --ring_size parameter is incorrect:", + numLabelsInTrainingSet + VwSparseMultilabelPredictor.AddlVwRingSize, + ringSize.toInt + ) + } +} \ No newline at end of file From b3fdff496c1e05bc4ef23afac9ebe6bcb8172fe2 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 28 Sep 2017 17:07:13 -0700 Subject: [PATCH 64/98] End to end testing working. Need to clean it up. --- .../multilabel/VwMultilabelRowCreator.scala | 2 +- .../models/multilabel/MultilabelModel.scala | 4 +- .../multilabel/VwMultilabelModelTest.scala | 173 +++++++++++++++++- 3 files changed, 172 insertions(+), 7 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index e9fd353c..77657d8d 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -288,7 +288,7 @@ object VwMultilabelRowCreator { * @tparam A type of input passed to the [[RowCreator]]. * @tparam K the label type. */ - final class Producer[A, K: RefInfo](allLabelsInTrainingSet: Vector[K]) + final class Producer[A, K: RefInfo](allLabelsInTrainingSet: sci.IndexedSeq[K]) extends RowCreatorProducer[A, Array[String], VwMultilabelRowCreator[A, K]] with RowCreatorProducerName with VwCovariateProducer[A] diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 688209a7..17022ca8 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -378,7 +378,9 @@ object MultilabelModel extends ParserProviderCompanion { // Then the necessary type classes related to K are *instantiated* and a reader is // created using types N and K. Ultimately, the reader consumes the type `K` but // only the type N is exposed in the returned reader. - if (!RefInfoOps.isSubType[N, Map[_, Double]]) { + // + // NOTE: The following is an not adequate check: !RefInfoOps.isSubType[N, Map[_, Double]] + if (!RefInfoOps.isSubType[N, Map[Any, Double]]) { warn(s"N=${RefInfoOps.toString[N]} is not a Map[K, Double]. Cannot create a JsonReader for MultilabelModel.") None } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index d281588c..c2ec9811 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -6,6 +6,7 @@ import com.eharmony.aloha.audit.impl.OptionAuditor import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor} import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.LabelNamespaces +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.id.ModelId import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} @@ -18,11 +19,12 @@ import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner -import spray.json.DefaultJsonProtocol.{StringJsonFormat, vectorFormat} -import spray.json.{DefaultJsonProtocol, JsonWriter} +import spray.json.DefaultJsonProtocol.{IntJsonFormat, StringJsonFormat} +import spray.json.{DefaultJsonProtocol, JsonWriter, RootJsonFormat, pimpString} import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} import scala.annotation.tailrec +import scala.util.Random /** * Created by ryan.deak on 9/11/17. @@ -64,7 +66,165 @@ class VwMultilabelModelTest { // This is really more of an integration test. @Test def testTrainingAndTesting(): Unit = { - fail() + + // ------------------------------------------------------------------------------------ + // Preliminaries + // ------------------------------------------------------------------------------------ + + type Lab = Int + + // Not abuse of notation. Model domain is indeed a vector of labels. + type Dom = Vector[Lab] + + val semantics = CompiledSemanticsInstances.anyNameIdentitySemantics[Dom] + val optAud = OptionAuditor[Map[Lab, Double]]() + + // ------------------------------------------------------------------------------------ + // Dataset and test example set up. + // ------------------------------------------------------------------------------------ + + // Marginal Prob Dist: Pr[ 8 ] = 0.80 = 20 / 25 = (12 + 8) / 25 + // Pr[ 4 ] = 0.40 = 10 / 25 = (2 + 8) / 25 + val unshuffledTrainingSet = Seq( // JPD: + Vector(None, None) -> 3, // pr = 0.12 = 3 / 25 + Vector(None, Option(8)) -> 12, // pr = 0.48 = 12 / 25 + Vector(Option(4), None) -> 2, // pr = 0.08 = 2 / 25 + Vector(Option(4), Option(8)) -> 8 // pr = 0.32 = 8 / 25 + ) flatMap { + case (k, n) => + val flattened = k.flatten + Vector.fill(n)(flattened) + } + + val trainingSet = new Random(0).shuffle(unshuffledTrainingSet) + + val labelsInTrainingSet = trainingSet.flatten.toSet.toVector.sorted + + val testExample = Vector.empty[Lab] + + val marginalDist = labelsInTrainingSet.map { label => + val z = trainingSet.size.toDouble + label -> trainingSet.count(row => row contains label) / z + }.toMap + + // ------------------------------------------------------------------------------------ + // Prepare dataset specification and read training dataset. + // ------------------------------------------------------------------------------------ + + val datasetJson = + """ + |{ + | "imports": [ + | "com.eharmony.aloha.feature.BasicFunctions._" + | ], + | "features": [ + | { "name": "feature", "spec": "1" } + | ], + | "namespaces": [ + | { "name": "X", "features": [ "feature" ] } + | ], + | "normalizeFeatures": false, + | "positiveLabels": "${labels_from_input}" + |} + """.stripMargin.parseJson.convertTo[VwMultilabeledJson] + + val rc = + new VwMultilabelRowCreator.Producer[Dom, Lab](labelsInTrainingSet). + getRowCreator(semantics, datasetJson).get + + // ------------------------------------------------------------------------------------ + // Prepare parameters for VW model that will be trained. + // ------------------------------------------------------------------------------------ + + + val modelFile = File.createTempFile("vw_", ".bin.model") + modelFile.deleteOnExit() + + val cacheFile = File.createTempFile("vw_", ".cache") + cacheFile.deleteOnExit() + + val vwParams = + s""" + | --quiet + | --csoaa_ldf mc + | --csoaa_rank + | --loss_function logistic + | -q ${labelNs}X + | --noconstant + | --ignore_linear X + | --ignore $dummyLabelNs + | -f ${modelFile.getCanonicalPath} + | --passes 50 + | --cache_file ${cacheFile.getCanonicalPath} + | --holdout_off + | --learning_rate 5 + | --decay_learning_rate 0.9 + """.stripMargin.trim.replaceAll("\n", " ") + + + // ------------------------------------------------------------------------------------ + // Train VW model + // ------------------------------------------------------------------------------------ + + val vwLearner = VWLearners.create[VWActionScoresLearner](vwParams) + trainingSet foreach { row => + val x = rc(row)._2 + vwLearner.learn(x) + } + vwLearner.close() + + // ------------------------------------------------------------------------------------ + // Create Aloha model JSON + // ------------------------------------------------------------------------------------ + + val modelSource: ModelSource = ExternalSource(Vfs.javaFileToAloha(modelFile)) + + val modelJson: String = + s""" + |{ + | "modelType": "multilabel-sparse", + | "modelId": { "id": 1, "name": "NONE" }, + | "features": { + | "feature": "1" + | }, + | "numMissingThreshold": 0, + | "labelsInTrainingSet": ${toJsonString(labelsInTrainingSet)}, + | "underlying": { + | "type": "vw", + | "modelSource": ${toJsonString(modelSource)}, + | "namespaces": { + | "X": [ + | "feature" + | ] + | } + | } + |} + """.stripMargin + + + // ------------------------------------------------------------------------------------ + // Instantiate Aloha Model + // ------------------------------------------------------------------------------------ + + val factory = ModelFactory.defaultFactory(semantics, optAud) + val modelTry = factory.fromString(modelJson) + val model = modelTry.get + + // ------------------------------------------------------------------------------------ + // Test Aloha Model + // ------------------------------------------------------------------------------------ + + val output = model(testExample) + model.close() + + output match { + case None => fail() + case Some(m) => + assertEquals(marginalDist.keySet, m.keySet) + marginalDist foreach { case (k, v) => + assertEquals(s"For key '$k':", v, m(k), 0.01) + } + } } /** @@ -302,12 +462,15 @@ object VwMultilabelModelTest { private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() + private implicit def vecWriter[K: JsonWriter]: RootJsonFormat[Vector[K]] = + DefaultJsonProtocol.vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) + private[multilabel] def modelJson[K: JsonWriter]( modelSource: ModelSource, labelsInTrainingSet: Vector[K], labelsOfInterest: Option[String] = None) = { - implicit val vecWriter = vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) +// implicit val vecWriter = vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) val loi = labelsOfInterest.fold(""){ f => val escaped = f.replaceAll("\"", "\\\"") @@ -339,7 +502,7 @@ object VwMultilabelModelTest { json } - private[this] def toJsonString[A: JsonWriter](a: A): String = + private def toJsonString[A: JsonWriter](a: A): String = implicitly[JsonWriter[A]].write(a).compactPrint From 7087d04f4744cc71bc7f36e926d4738b8598997d Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 28 Sep 2017 17:19:23 -0700 Subject: [PATCH 65/98] a little cleanup. --- .../models/vw/jni/multilabel/VwMultilabelModelTest.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index c2ec9811..ec8f5b5a 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -72,9 +72,7 @@ class VwMultilabelModelTest { // ------------------------------------------------------------------------------------ type Lab = Int - - // Not abuse of notation. Model domain is indeed a vector of labels. - type Dom = Vector[Lab] + type Dom = Vector[Lab] // Not an abuse of notation. Domain is indeed a vector of labels. val semantics = CompiledSemanticsInstances.anyNameIdentitySemantics[Dom] val optAud = OptionAuditor[Map[Lab, Double]]() @@ -100,7 +98,7 @@ class VwMultilabelModelTest { val labelsInTrainingSet = trainingSet.flatten.toSet.toVector.sorted - val testExample = Vector.empty[Lab] + val testExample: Dom = Vector.empty val marginalDist = labelsInTrainingSet.map { label => val z = trainingSet.size.toDouble From 2a0b5c070e852c9cc48a7c1035ea5abe04400716 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 29 Sep 2017 15:12:20 -0700 Subject: [PATCH 66/98] Removed implicit fn com.eharmony.aloha.factory.ScalaJsonFormats.lift(JsonReader). It was turning JsonFormats to JsonReaders and back to JsonFormats without writing capabilities. --- .../src/main/scala/com/eharmony/aloha/factory/formats.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala b/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala index 026fbd4f..80e00364 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala @@ -60,11 +60,6 @@ trait JavaJsonFormats { object ScalaJsonFormats extends ScalaJsonFormats trait ScalaJsonFormats { - // This is a very slightly modified copy of the lift from Additional formats that removes the type bound. - implicit def lift[A](implicit reader: JsonReader[A]): JsonFormat[A] = new JsonFormat[A] { - def write(a: A): JsValue = throw new UnsupportedOperationException("No JsonWriter[" + a.getClass + "] available") - def read(value: JsValue): A = reader.read(value) - } implicit def listMapFormat[K :JsonFormat, V :JsonFormat] = new RootJsonFormat[ListMap[K, V]] { def write(m: ListMap[K, V]) = JsObject { From e11aef19ab794d2241ff9e0980980e8f098293f4 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 29 Sep 2017 15:12:39 -0700 Subject: [PATCH 67/98] simplifying tests. --- .../vw/jni/multilabel/VwMultilabelModel.scala | 133 ++++++++++++++++++ .../vw/jni/multilabel/dataset_spec.json | 13 ++ .../multilabel/VwMultilabelModelTest.scala | 93 ++++++------ 3 files changed, 186 insertions(+), 53 deletions(-) create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala create mode 100644 aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala new file mode 100644 index 00000000..0455aa6d --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala @@ -0,0 +1,133 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson +import com.eharmony.aloha.id.{ModelId, ModelIdentity} +import com.eharmony.aloha.io.StringReadable +import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource} +import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.models.multilabel.MultilabelModel +import com.eharmony.aloha.models.multilabel.json.MultilabelModelJson +import com.eharmony.aloha.models.reg.json.Spec +import com.eharmony.aloha.models.vw.jni.VwJniModel +import com.eharmony.aloha.models.vw.jni.multilabel.json.VwMultilabelModelJson +import spray.json.{DefaultJsonProtocol, JsValue, JsonWriter, pimpAny, pimpString} + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 9/29/17. + */ +object VwMultilabelModel +extends MultilabelModelJson + with VwMultilabelModelJson { + + /** + * Create a JSON representation of an Aloha model. + * + * '''NOTE''': Because of the inclusion of the unrestricted `labelsOfInterest` parameter, + * the JSON produced by this function is not guaranteed to result in a valid + * Aloha model. This is because no semantics are required by this function and + * so, the `labelsOfInterest` function specification cannot be validated. + * + * @param datasetSpec a location of a dataset specification. + * @param binaryVwModel a location of a VW binary model file. + * @param id a model ID. + * @param labelsInTrainingSet The sequence of all labels encountered in the training set used + * to produce the `binaryVwModel`. + * '''''It is extremely important that this sequence has the + * same order as the sequence of labels used in the + * dataset creation process.''''' Otherwise, the VW model might + * associate scores with an incorrect label. + * @param labelsOfInterest It is possible that a model is trained on a super set of labels for + * which predictions can be made. If the labels at prediction time + * differs (''or should be extracted from the input to the model''), + * this function can provide that capability. + * @param vwArgs arguments that should be passed to the VW model. This likely isn't strictly + * necessary. + * @param externalModel whether the underlying binary VW model should remain as a separate + * file and be referenced by the Aloha model specification (`true`) + * or the binary model content should be embeeded directly into the model + * (`false`). '''Keep in mind''' Aloha models must be smaller than 2 GB + * because they are decoded to `String`s and `String`s are indexed by + * 32-bit integers (which have a max value of 2^32^ - 1). + * @param numMissingThreshold the number of missing features to tolerate before emitting a + * prediction failure. + * @tparam K the type of label or class. + * @return a JSON object. + */ + def json[K: JsonWriter]( + datasetSpec: Vfs, + binaryVwModel: Vfs, + id: ModelIdentity, + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String] = None, + vwArgs: Option[String] = None, + externalModel: Boolean = false, + numMissingThreshold: Option[Int] = None + ): JsValue = { + + val dsJsAst = StringReadable.fromInputStream(datasetSpec.inputStream).parseJson + val ds = dsJsAst.convertTo[VwMultilabeledJson] + val features = modelFeatures(ds.features) + val namespaces = ds.namespaces.map(modelNamespaces) + val modelSrc = modelSource(binaryVwModel, externalModel) + val vw = vwModelPlugin(modelSrc, vwArgs, namespaces) + val model = modelAst(id, features, numMissingThreshold, labelsInTrainingSet, labelsOfInterest, vw) + + // Even though we are only *writing* JSON, we need a JsonFormat[K], (which is a reader + // and writer) because multilabelDataJsonFormat has a JsonFormat context bound on K. + // So to turn MultilabelData into JSON, we need to lift the JsonWriter for K into a + // JsonFormat and store it as an implicit value (or create an implicit conversion + // function on implicit arguments. + implicit val labelJF = DefaultJsonProtocol.lift(implicitly[JsonWriter[K]]) + model.toJson + } + + private[multilabel] def modelFeatures(featuresSpecs: Seq[SparseSpec]) = + ListMap ( + featuresSpecs.map { case SparseSpec(name, spec, defVal) => + name -> Spec(spec, defVal.map(ts => ts.toSeq)) + }: _* + ) + + private[multilabel] def modelNamespaces(nss: Seq[Namespace]) = + ListMap( + nss.map(ns => ns.name -> ns.features) : _* + ) + + private[multilabel] def modelSource(binaryVwModel: Vfs, externalModel: Boolean) = + if (externalModel) + ExternalSource(binaryVwModel) + else Base64StringSource(VwJniModel.readBinaryVwModelToB64String(binaryVwModel.inputStream)) + + // Private b/c VwMultilabelAst is protected[this]. Don't let it escape. + private def vwModelPlugin( + modelSrc: ModelSource, + vwArgs: Option[String], + namespaces: Option[ListMap[String, Seq[String]]]) = + VwMultilabelAst( + VwSparseMultilabelPredictorProducer.multilabelPlugin.name, + modelSrc, + vwArgs.map(a => Right(a)), + namespaces + ) + + // Private b/c MultilabelData is private[this]. Don't let it escape. + private def modelAst[K]( + id: ModelIdentity, + features: ListMap[String, Spec], + numMissingThreshold: Option[Int], + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String], + vwPlugin: VwMultilabelAst) = + MultilabelData( + modelType = MultilabelModel.parser.modelType, + modelId = ModelId(id.getId(), id.getName()), + features = features, + numMissingThreshold = numMissingThreshold, + labelsInTrainingSet = labelsInTrainingSet.toVector, + labelsOfInterest = labelsOfInterest, + underlying = vwPlugin.toJson.asJsObject + ) +} diff --git a/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json new file mode 100644 index 00000000..fdf1bbf2 --- /dev/null +++ b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json @@ -0,0 +1,13 @@ +{ + "imports": [ + "com.eharmony.aloha.feature.BasicFunctions._" + ], + "features": [ + { "name": "feature", "spec": "1" } + ], + "namespaces": [ + { "name": "X", "features": [ "feature" ] } + ], + "normalizeFeatures": false, + "positiveLabels": "${labels_from_input}" +} diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index ec8f5b5a..68a46fc5 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -9,12 +9,14 @@ import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.LabelName import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson import com.eharmony.aloha.factory.ModelFactory import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.io.StringReadable import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} import com.eharmony.aloha.io.vfs.Vfs import com.eharmony.aloha.models.Model import com.eharmony.aloha.models.multilabel.MultilabelModel import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances import com.eharmony.aloha.semantics.func.GenFunc0 +import org.apache.commons.vfs2.VFS import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith @@ -81,17 +83,16 @@ class VwMultilabelModelTest { // Dataset and test example set up. // ------------------------------------------------------------------------------------ - // Marginal Prob Dist: Pr[ 8 ] = 0.80 = 20 / 25 = (12 + 8) / 25 - // Pr[ 4 ] = 0.40 = 10 / 25 = (2 + 8) / 25 - val unshuffledTrainingSet = Seq( // JPD: - Vector(None, None) -> 3, // pr = 0.12 = 3 / 25 - Vector(None, Option(8)) -> 12, // pr = 0.48 = 12 / 25 - Vector(Option(4), None) -> 2, // pr = 0.08 = 2 / 25 - Vector(Option(4), Option(8)) -> 8 // pr = 0.32 = 8 / 25 + // Marginal Distribution: Pr[8] = 0.80 = 20 / 25 = (12 + 8) / 25 + // Pr[4] = 0.40 = 10 / 25 = ( 2 + 8) / 25 + // + val unshuffledTrainingSet: Seq[Dom] = Seq( + Vector( ) -> 3, // pr = 0.12 = 3 / 25 <-- JPD + Vector( 8) -> 12, // pr = 0.48 = 12 / 25 + Vector(4 ) -> 2, // pr = 0.08 = 2 / 25 + Vector(4, 8) -> 8 // pr = 0.32 = 8 / 25 ) flatMap { - case (k, n) => - val flattened = k.flatten - Vector.fill(n)(flattened) + case (k, n) => Vector.fill(n)(k) } val trainingSet = new Random(0).shuffle(unshuffledTrainingSet) @@ -109,22 +110,8 @@ class VwMultilabelModelTest { // Prepare dataset specification and read training dataset. // ------------------------------------------------------------------------------------ - val datasetJson = - """ - |{ - | "imports": [ - | "com.eharmony.aloha.feature.BasicFunctions._" - | ], - | "features": [ - | { "name": "feature", "spec": "1" } - | ], - | "namespaces": [ - | { "name": "X", "features": [ "feature" ] } - | ], - | "normalizeFeatures": false, - | "positiveLabels": "${labels_from_input}" - |} - """.stripMargin.parseJson.convertTo[VwMultilabeledJson] + val datasetSpec = VFS.getManager.resolveFile(EndToEndDatasetSpec) + val datasetJson = StringReadable.fromVfs2(datasetSpec).parseJson.convertTo[VwMultilabeledJson] val rc = new VwMultilabelRowCreator.Producer[Dom, Lab](labelsInTrainingSet). @@ -135,12 +122,22 @@ class VwMultilabelModelTest { // ------------------------------------------------------------------------------------ - val modelFile = File.createTempFile("vw_", ".bin.model") - modelFile.deleteOnExit() + val binaryVwModel = File.createTempFile("vw_", ".bin.model") + binaryVwModel.deleteOnExit() val cacheFile = File.createTempFile("vw_", ".cache") cacheFile.deleteOnExit() + // probs / costs + // --csoaa_ldf mc vs m + // --csoaa_rank + // --loss_function logistic add this + // --noconstant + // --ignore_linear X + // --ignore $dummyLabelNs + // turn first order features into quadratics + // turn quadratic features into cubics + // val vwParams = s""" | --quiet @@ -151,7 +148,7 @@ class VwMultilabelModelTest { | --noconstant | --ignore_linear X | --ignore $dummyLabelNs - | -f ${modelFile.getCanonicalPath} + | -f ${binaryVwModel.getCanonicalPath} | --passes 50 | --cache_file ${cacheFile.getCanonicalPath} | --holdout_off @@ -175,29 +172,16 @@ class VwMultilabelModelTest { // Create Aloha model JSON // ------------------------------------------------------------------------------------ - val modelSource: ModelSource = ExternalSource(Vfs.javaFileToAloha(modelFile)) - - val modelJson: String = - s""" - |{ - | "modelType": "multilabel-sparse", - | "modelId": { "id": 1, "name": "NONE" }, - | "features": { - | "feature": "1" - | }, - | "numMissingThreshold": 0, - | "labelsInTrainingSet": ${toJsonString(labelsInTrainingSet)}, - | "underlying": { - | "type": "vw", - | "modelSource": ${toJsonString(modelSource)}, - | "namespaces": { - | "X": [ - | "feature" - | ] - | } - | } - |} - """.stripMargin + val modelJson = VwMultilabelModel.json( + datasetSpec = Vfs.apacheVfs2ToAloha(datasetSpec), + binaryVwModel = Vfs.javaFileToAloha(binaryVwModel), + id = ModelId(1, "NONE"), + labelsInTrainingSet = labelsInTrainingSet, + labelsOfInterest = Option.empty[String], + vwArgs = Option.empty[String], + externalModel = false, + numMissingThreshold = Option(0) + ) // ------------------------------------------------------------------------------------ @@ -205,7 +189,7 @@ class VwMultilabelModelTest { // ------------------------------------------------------------------------------------ val factory = ModelFactory.defaultFactory(semantics, optAud) - val modelTry = factory.fromString(modelJson) + val modelTry = factory.fromString(modelJson.prettyPrint) // Use `.compactPrint` in prod. val model = modelTry.get // ------------------------------------------------------------------------------------ @@ -330,6 +314,9 @@ object VwMultilabelModelTest { */ private val AllLabels = Vector(LabelSeven, LabelEight, LabelSix) + private val EndToEndDatasetSpec = + "res:com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json" + private lazy val Model: Model[Domain, RootedTree[Any, Map[Label, Double]]] = { val featureNames = Vector(FeatureName) val features = Vector(GenFunc0("", (_: Domain) => Iterable(("", 1d)))) From c2a048eda387ea39db744d1740d96c6ecf2f784e Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 3 Oct 2017 15:43:34 -0700 Subject: [PATCH 68/98] vw param function skeleton. --- .../models/vw/jni/multilabel/VwMultilabelModel.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala index 0455aa6d..00210139 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala @@ -22,6 +22,17 @@ object VwMultilabelModel extends MultilabelModelJson with VwMultilabelModelJson { + /** + * Add VW parameters to make the multilabel model work: + * @param vwParams current VW parameters passed to the VW JNI + * @param namespaceNames + * @return + */ + def updatedVwParams(vwParams: String, namespaceNames: Set[String]): String = { + + ??? + } + /** * Create a JSON representation of an Aloha model. * From 3cb9d2ecd675482f70550ef99493aee94680ce21 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 4 Oct 2017 17:27:15 -0700 Subject: [PATCH 69/98] non working VwMultilabelModel.updatedVwParams. Skeleton laid out. --- .../vw/jni/multilabel/VwMultilabelModel.scala | 182 +++++++++++++++++- .../VwSparseMultilabelPredictor.scala | 6 +- .../multilabel/VwMultilabelModelTest.scala | 16 ++ 3 files changed, 198 insertions(+), 6 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala index 00210139..fba5bd3e 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala @@ -1,6 +1,7 @@ package com.eharmony.aloha.models.vw.jni.multilabel import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson import com.eharmony.aloha.id.{ModelId, ModelIdentity} import com.eharmony.aloha.io.StringReadable @@ -10,17 +11,105 @@ import com.eharmony.aloha.models.multilabel.MultilabelModel import com.eharmony.aloha.models.multilabel.json.MultilabelModelJson import com.eharmony.aloha.models.reg.json.Spec import com.eharmony.aloha.models.vw.jni.VwJniModel +import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.ExpectedLearner import com.eharmony.aloha.models.vw.jni.multilabel.json.VwMultilabelModelJson +import org.apache.commons.io.IOUtils import spray.json.{DefaultJsonProtocol, JsValue, JsonWriter, pimpAny, pimpString} +import vowpalWabbit.learner.VWLearners import scala.collection.immutable.ListMap +import scala.util.matching.Regex +import scala.util.{Failure, Success, Try} + /** * Created by ryan.deak on 9/29/17. */ object VwMultilabelModel -extends MultilabelModelJson - with VwMultilabelModelJson { + extends MultilabelModelJson + with VwMultilabelModelJson { + + sealed trait VwParamError { + def originalParams: String + def modifiedParams: Option[String] + def errorMessage: String + } + + final case class UnrecoverableParams( + originalParams: String, + unrecoverable: Set[String] + ) extends VwParamError { + def modifiedParams: Option[String] = None + def errorMessage: String = + "Encountered parameters that cannot be augmented: " + + unrecoverable.toSeq.sorted.mkString(", ") + + s"\n\toriginal parameters: $originalParams" + } + + final case class IncorrectLearner( + originalParams: String, + modifiedPs: String, + learnerCanonicalName: String + ) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"Params produced an incorrect learner type. " + + s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + + s"Found: $learnerCanonicalName." + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" + } + + final case class ClassCastErr( + originalParams: String, + modifiedPs: String, + ccException: ClassCastException + ) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"Params produced an incorrect learner type. " + + s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + + s"Encountered ClassCastException: ${ccException.getMessage}" + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" + } + + final case class VwError( + originalParams: String, + modifiedPs: String, + vwErrMsg: String + ) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"VW could not create a learner of type " + + s"${classOf[ExpectedLearner].getCanonicalName}. Error: $vwErrMsg. " + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" + } + + private[this] def pad(s: String) = "\\s" + s + "\\s" + private[this] val ClassCastMsg = """(\S+) cannot be cast to (\S+)""".r + private[this] val CsoaaLdf = pad("""--csoaa_ldf\s+(mc?)""").r + private[this] val CsoaaRegression = "m" + private[this] val CsoaaClassification = "mc" + private[this] val Quiet = pad("--quiet").r + private[this] val Keep = pad("""--keep\s+(\S+)""").r + private[this] val Ignore = pad("""--ignore\s+(\S+)""").r + private[this] val IgnoreLinear = pad("""--ignore_linear\s+(\S+)""").r + private[this] val UnrecoverableFlags = pad("""--(cubic|redefine)""").r + private[this] val QuadraticsShort = pad("""-q\s*([\S])([\S])\s""").r + private[this] val QuadraticsLong = pad("""--quadratic\s+([\S])([\S])""").r + + //$VW -qas -qdf --quadratic fg + // --ignore + // --ignore_linear + // -q, --quadratic + // cubic return None + // --redefine return None + // --noconstant + // -C, --constant + // --csoaa_ldf mc + // --wap_ldf /** * Add VW parameters to make the multilabel model work: @@ -28,9 +117,49 @@ extends MultilabelModelJson * @param namespaceNames * @return */ - def updatedVwParams(vwParams: String, namespaceNames: Set[String]): String = { + def updatedVwParams(vwParams: String, namespaceNames: Set[String]): Either[VwParamError, String] = { + val padded = s" ${vwParams.trim} " + + // --ignore + // --ignore_linear + // -q, --quadratic + // cubic return None + // --redefine return None + // --noconstant + // -C, --constant + // --csoaa_ldf mc + // --wap_ldf + + val unrecoverableFlags = UnrecoverableFlags.findAllMatchIn(vwParams).map(m => m.group(1)).toSet + if (unrecoverableFlags.nonEmpty) + Left(UnrecoverableParams(vwParams, unrecoverableFlags)) + else { + val q = quadratics(padded) + val nsq = nssInQuadratics(q) + val k = kept(padded) + val i = ignored(padded) + val il = ignoredLinear(padded) + val quiet = isQuiet(padded) + + // TODO: or just determineLabelNamespaces(namespaceNames) ? + val allNs = k ++ nsq ++ namespaceNames.flatMap(_.take(1).toCharArray) + val labelNss = VwMultilabelRowCreator.determineLabelNamespaces(allNs.map(_.toString)) - ??? + // lift "kept" first order namespaces to quadratics with the output namespace + // lift quadratics to cubics with the output namespace + // introduce + // ignore for the dummy class namespace + // ignore linear for the other namespaces + // no constant + // + + // how to deal with ignore (not ignore linear) + + + val updatedParams: String = vwParams // TODO: Fill this in. + + validateVwParams(vwParams, updatedParams, !quiet) + } } /** @@ -95,6 +224,51 @@ extends MultilabelModelJson model.toJson } + // ============== updatedVwParams support functions =============== + + private[multilabel] def quad(r: Regex, chrSeq: CharSequence): Set[(Char, Char)] = + r.findAllMatchIn(chrSeq).map { m => + val Seq(a, b) = (1 to 2).map(i => m.group(i).charAt(0)) + if (a < b) (a, b) else (b, a) + }.toSet + + private[multilabel] def charsIn(r: Regex, chrSeq: CharSequence): Set[Char] = + r.findAllMatchIn(chrSeq).flatMap(m => m.group(1).toCharArray).toSet + + private[multilabel] def quadratics(padded: String): Set[(Char, Char)] = + quad(QuadraticsShort, padded) ++ quad(QuadraticsLong, padded) + + private[multilabel] def nssInQuadratics(quadratics: Set[(Char, Char)]): Set[Char] = + quadratics flatMap { case (a, b) => Set(a, b) } + + private[multilabel] def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty + + private[multilabel] def kept(padded: String): Set[Char] = charsIn(Keep, padded) + private[multilabel] def ignored(padded: String): Set[Char] = charsIn(Ignore, padded) + private[multilabel] def ignoredLinear(padded: String): Set[Char] = charsIn(IgnoreLinear, padded) + + private[multilabel] def handleClassCastException(orig: String, mod: String, ex: ClassCastException) = + ex.getMessage match { + case ClassCastMsg(from, _) => IncorrectLearner(orig, mod, from) + case _ => ClassCastErr(orig, mod, ex) + } + + private[multilabel] def validateVwParams(orig: String, mod: String, addQuiet: Boolean) = { + val ps = if (addQuiet) s"--quiet $mod" else mod + + Try { VWLearners.create[ExpectedLearner](ps) } match { + case Success(m) => + IOUtils.closeQuietly(m) + Right(mod) + case Failure(cce: ClassCastException) => + Left(handleClassCastException(orig, mod, cce)) + case Failure(ex) => + Left(VwError(orig, mod, ex.getMessage)) + } + } + + // ==================== json support functions ==================== + private[multilabel] def modelFeatures(featuresSpecs: Seq[SparseSpec]) = ListMap ( featuresSpecs.map { case SparseSpec(name, spec, defVal) => diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index 6af487d9..12504d1c 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -100,6 +100,8 @@ object VwSparseMultilabelPredictor { private[multilabel] val AddlVwRingSize = 10 + private[multilabel] type ExpectedLearner = VWActionScoresLearner + /** * Produce the output given VW's output, `pred`, and the labels provided to the `apply` function. * @param pred predictions returned by the underlying VW ''CSOAA LDF'' model. @@ -171,10 +173,10 @@ object VwSparseMultilabelPredictor { modelSource: ModelSource, params: String, numLabelsInTrainingSet: Int - ): (String, Try[VWActionScoresLearner]) = { + ): (String, Try[ExpectedLearner]) = { val modelFile = modelSource.localVfs.replicatedToLocal() val updatedParams = paramsWithSource(modelFile.fileObj, params, numLabelsInTrainingSet) - val learner = Try { VWLearners.create[VWActionScoresLearner](updatedParams) } + val learner = Try { VWLearners.create[ExpectedLearner](updatedParams) } (updatedParams, learner) } } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 68a46fc5..eebf3388 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -14,6 +14,7 @@ import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} import com.eharmony.aloha.io.vfs.Vfs import com.eharmony.aloha.models.Model import com.eharmony.aloha.models.multilabel.MultilabelModel +import com.eharmony.aloha.models.vw.jni.multilabel.VwMultilabelModel.IncorrectLearner import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances import com.eharmony.aloha.semantics.func.GenFunc0 import org.apache.commons.vfs2.VFS @@ -35,6 +36,21 @@ import scala.util.Random class VwMultilabelModelTest { import VwMultilabelModelTest._ + + @Test def testWrongLearner(): Unit = { + val args = "--csoaa_ldf mc" + VwMultilabelModel.updatedVwParams("--csoaa_ldf mc", Set.empty) match { + case Left(IncorrectLearner(o, _, c)) => + assertEquals(args, o) + assertNotEquals(classOf[VwSparseMultilabelPredictor.ExpectedLearner].getCanonicalName, c) + case _ => fail() + } + } + + // TODO: More VW argument augmentation tests!!! + + + @Test def testTrainedModelWorks(): Unit = { val model = Model From fb7a8f2241ee46c2775e2e308c60eb70e7596367 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 5 Oct 2017 14:09:38 -0700 Subject: [PATCH 70/98] quadratics and cubics seem to be working. --- .../multilabel/VwMultilabelRowCreator.scala | 14 +- .../VwMultilabelRowCreatorTest.scala | 2 +- .../vw/jni/multilabel/VwMultilabelModel.scala | 314 +----------------- .../VwMultilabelParamAugmentation.scala | 215 ++++++++++++ .../multilabel/VwMultlabelJsonCreator.scala | 133 ++++++++ .../vw/jni/multilabel/VwParamError.scala | 97 ++++++ .../VwSparseMultilabelPredictorProducer.scala | 2 +- .../multilabel/VwMultilabelModelTest.scala | 45 ++- 8 files changed, 495 insertions(+), 327 deletions(-) create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala create mode 100644 aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index 77657d8d..6c75df51 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -24,8 +24,8 @@ final case class VwMultilabelRowCreator[-A, K]( namespaces: List[(String, List[Int])], normalizer: Option[CharSequence => CharSequence], positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], - classNs: String, - dummyClassNs: String, + classNs: Char, + dummyClassNs: Char, includeZeroValues: Boolean = false ) extends RowCreator[A, Array[String]] { import VwMultilabelRowCreator._ @@ -134,7 +134,7 @@ object VwMultilabelRowCreator { private[multilabel] def preferredLabelNamespaces(nss: sci.BitSet): Option[LabelNamespaces] = { PreferredLabelNamespaces collectFirst { case (actual, dummy) if !(nss contains actual.toInt) && !(nss contains dummy.toInt) => - LabelNamespaces(actual.toString, dummy.toString) + LabelNamespaces(actual, dummy) } } @@ -167,7 +167,7 @@ object VwMultilabelRowCreator { found match { case actual #:: dummy #:: Stream.Empty => - Option(LabelNamespaces(actual.toChar.toString, dummy.toChar.toString)) + Option(LabelNamespaces(actual.toChar, dummy.toChar)) case _ => None } } @@ -194,8 +194,8 @@ object VwMultilabelRowCreator { positiveLabelIndices: Int => Boolean, defaultNs: List[Int], namespaces: List[(String, List[Int])], - classNs: String, - dummyClassNs: String + classNs: Char, + dummyClassNs: Char ): Array[String] = { val n = indices.size @@ -352,5 +352,5 @@ object VwMultilabelRowCreator { semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K])) } - private[aloha] final case class LabelNamespaces(labelNs: String, dummyLabelNs: String) + private[aloha] final case class LabelNamespaces(labelNs: Char, dummyLabelNs: Char) } \ No newline at end of file diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala index 8e7d5492..f4ac646f 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala @@ -99,7 +99,7 @@ class VwMultilabelRowCreatorTest { expectedResults: Seq[Boolean], actualResults: Array[String], prefix: Seq[String], - labelNs: String + labelNs: Char ): Unit = { val suffix = diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala index fba5bd3e..e754aae5 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala @@ -1,318 +1,8 @@ package com.eharmony.aloha.models.vw.jni.multilabel -import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} -import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator -import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson -import com.eharmony.aloha.id.{ModelId, ModelIdentity} -import com.eharmony.aloha.io.StringReadable -import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource} -import com.eharmony.aloha.io.vfs.Vfs -import com.eharmony.aloha.models.multilabel.MultilabelModel -import com.eharmony.aloha.models.multilabel.json.MultilabelModelJson -import com.eharmony.aloha.models.reg.json.Spec -import com.eharmony.aloha.models.vw.jni.VwJniModel -import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.ExpectedLearner -import com.eharmony.aloha.models.vw.jni.multilabel.json.VwMultilabelModelJson -import org.apache.commons.io.IOUtils -import spray.json.{DefaultJsonProtocol, JsValue, JsonWriter, pimpAny, pimpString} -import vowpalWabbit.learner.VWLearners - -import scala.collection.immutable.ListMap -import scala.util.matching.Regex -import scala.util.{Failure, Success, Try} - /** * Created by ryan.deak on 9/29/17. */ -object VwMultilabelModel - extends MultilabelModelJson - with VwMultilabelModelJson { - - sealed trait VwParamError { - def originalParams: String - def modifiedParams: Option[String] - def errorMessage: String - } - - final case class UnrecoverableParams( - originalParams: String, - unrecoverable: Set[String] - ) extends VwParamError { - def modifiedParams: Option[String] = None - def errorMessage: String = - "Encountered parameters that cannot be augmented: " + - unrecoverable.toSeq.sorted.mkString(", ") + - s"\n\toriginal parameters: $originalParams" - } - - final case class IncorrectLearner( - originalParams: String, - modifiedPs: String, - learnerCanonicalName: String - ) extends VwParamError { - override def modifiedParams: Option[String] = Option(modifiedPs) - override def errorMessage: String = - s"Params produced an incorrect learner type. " + - s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + - s"Found: $learnerCanonicalName." + - s"\n\toriginal parameters: $originalParams" + - s"\n\tmodified parameters: $modifiedPs" - } - - final case class ClassCastErr( - originalParams: String, - modifiedPs: String, - ccException: ClassCastException - ) extends VwParamError { - override def modifiedParams: Option[String] = Option(modifiedPs) - override def errorMessage: String = - s"Params produced an incorrect learner type. " + - s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + - s"Encountered ClassCastException: ${ccException.getMessage}" + - s"\n\toriginal parameters: $originalParams" + - s"\n\tmodified parameters: $modifiedPs" - } - - final case class VwError( - originalParams: String, - modifiedPs: String, - vwErrMsg: String - ) extends VwParamError { - override def modifiedParams: Option[String] = Option(modifiedPs) - override def errorMessage: String = - s"VW could not create a learner of type " + - s"${classOf[ExpectedLearner].getCanonicalName}. Error: $vwErrMsg. " + - s"\n\toriginal parameters: $originalParams" + - s"\n\tmodified parameters: $modifiedPs" - } - - private[this] def pad(s: String) = "\\s" + s + "\\s" - private[this] val ClassCastMsg = """(\S+) cannot be cast to (\S+)""".r - private[this] val CsoaaLdf = pad("""--csoaa_ldf\s+(mc?)""").r - private[this] val CsoaaRegression = "m" - private[this] val CsoaaClassification = "mc" - private[this] val Quiet = pad("--quiet").r - private[this] val Keep = pad("""--keep\s+(\S+)""").r - private[this] val Ignore = pad("""--ignore\s+(\S+)""").r - private[this] val IgnoreLinear = pad("""--ignore_linear\s+(\S+)""").r - private[this] val UnrecoverableFlags = pad("""--(cubic|redefine)""").r - private[this] val QuadraticsShort = pad("""-q\s*([\S])([\S])\s""").r - private[this] val QuadraticsLong = pad("""--quadratic\s+([\S])([\S])""").r - - //$VW -qas -qdf --quadratic fg - // --ignore - // --ignore_linear - // -q, --quadratic - // cubic return None - // --redefine return None - // --noconstant - // -C, --constant - // --csoaa_ldf mc - // --wap_ldf - - /** - * Add VW parameters to make the multilabel model work: - * @param vwParams current VW parameters passed to the VW JNI - * @param namespaceNames - * @return - */ - def updatedVwParams(vwParams: String, namespaceNames: Set[String]): Either[VwParamError, String] = { - val padded = s" ${vwParams.trim} " - - // --ignore - // --ignore_linear - // -q, --quadratic - // cubic return None - // --redefine return None - // --noconstant - // -C, --constant - // --csoaa_ldf mc - // --wap_ldf - - val unrecoverableFlags = UnrecoverableFlags.findAllMatchIn(vwParams).map(m => m.group(1)).toSet - if (unrecoverableFlags.nonEmpty) - Left(UnrecoverableParams(vwParams, unrecoverableFlags)) - else { - val q = quadratics(padded) - val nsq = nssInQuadratics(q) - val k = kept(padded) - val i = ignored(padded) - val il = ignoredLinear(padded) - val quiet = isQuiet(padded) - - // TODO: or just determineLabelNamespaces(namespaceNames) ? - val allNs = k ++ nsq ++ namespaceNames.flatMap(_.take(1).toCharArray) - val labelNss = VwMultilabelRowCreator.determineLabelNamespaces(allNs.map(_.toString)) - - // lift "kept" first order namespaces to quadratics with the output namespace - // lift quadratics to cubics with the output namespace - // introduce - // ignore for the dummy class namespace - // ignore linear for the other namespaces - // no constant - // - - // how to deal with ignore (not ignore linear) - - - val updatedParams: String = vwParams // TODO: Fill this in. - - validateVwParams(vwParams, updatedParams, !quiet) - } - } - - /** - * Create a JSON representation of an Aloha model. - * - * '''NOTE''': Because of the inclusion of the unrestricted `labelsOfInterest` parameter, - * the JSON produced by this function is not guaranteed to result in a valid - * Aloha model. This is because no semantics are required by this function and - * so, the `labelsOfInterest` function specification cannot be validated. - * - * @param datasetSpec a location of a dataset specification. - * @param binaryVwModel a location of a VW binary model file. - * @param id a model ID. - * @param labelsInTrainingSet The sequence of all labels encountered in the training set used - * to produce the `binaryVwModel`. - * '''''It is extremely important that this sequence has the - * same order as the sequence of labels used in the - * dataset creation process.''''' Otherwise, the VW model might - * associate scores with an incorrect label. - * @param labelsOfInterest It is possible that a model is trained on a super set of labels for - * which predictions can be made. If the labels at prediction time - * differs (''or should be extracted from the input to the model''), - * this function can provide that capability. - * @param vwArgs arguments that should be passed to the VW model. This likely isn't strictly - * necessary. - * @param externalModel whether the underlying binary VW model should remain as a separate - * file and be referenced by the Aloha model specification (`true`) - * or the binary model content should be embeeded directly into the model - * (`false`). '''Keep in mind''' Aloha models must be smaller than 2 GB - * because they are decoded to `String`s and `String`s are indexed by - * 32-bit integers (which have a max value of 2^32^ - 1). - * @param numMissingThreshold the number of missing features to tolerate before emitting a - * prediction failure. - * @tparam K the type of label or class. - * @return a JSON object. - */ - def json[K: JsonWriter]( - datasetSpec: Vfs, - binaryVwModel: Vfs, - id: ModelIdentity, - labelsInTrainingSet: Seq[K], - labelsOfInterest: Option[String] = None, - vwArgs: Option[String] = None, - externalModel: Boolean = false, - numMissingThreshold: Option[Int] = None - ): JsValue = { - - val dsJsAst = StringReadable.fromInputStream(datasetSpec.inputStream).parseJson - val ds = dsJsAst.convertTo[VwMultilabeledJson] - val features = modelFeatures(ds.features) - val namespaces = ds.namespaces.map(modelNamespaces) - val modelSrc = modelSource(binaryVwModel, externalModel) - val vw = vwModelPlugin(modelSrc, vwArgs, namespaces) - val model = modelAst(id, features, numMissingThreshold, labelsInTrainingSet, labelsOfInterest, vw) - - // Even though we are only *writing* JSON, we need a JsonFormat[K], (which is a reader - // and writer) because multilabelDataJsonFormat has a JsonFormat context bound on K. - // So to turn MultilabelData into JSON, we need to lift the JsonWriter for K into a - // JsonFormat and store it as an implicit value (or create an implicit conversion - // function on implicit arguments. - implicit val labelJF = DefaultJsonProtocol.lift(implicitly[JsonWriter[K]]) - model.toJson - } - - // ============== updatedVwParams support functions =============== - - private[multilabel] def quad(r: Regex, chrSeq: CharSequence): Set[(Char, Char)] = - r.findAllMatchIn(chrSeq).map { m => - val Seq(a, b) = (1 to 2).map(i => m.group(i).charAt(0)) - if (a < b) (a, b) else (b, a) - }.toSet - - private[multilabel] def charsIn(r: Regex, chrSeq: CharSequence): Set[Char] = - r.findAllMatchIn(chrSeq).flatMap(m => m.group(1).toCharArray).toSet - - private[multilabel] def quadratics(padded: String): Set[(Char, Char)] = - quad(QuadraticsShort, padded) ++ quad(QuadraticsLong, padded) - - private[multilabel] def nssInQuadratics(quadratics: Set[(Char, Char)]): Set[Char] = - quadratics flatMap { case (a, b) => Set(a, b) } - - private[multilabel] def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty - - private[multilabel] def kept(padded: String): Set[Char] = charsIn(Keep, padded) - private[multilabel] def ignored(padded: String): Set[Char] = charsIn(Ignore, padded) - private[multilabel] def ignoredLinear(padded: String): Set[Char] = charsIn(IgnoreLinear, padded) - - private[multilabel] def handleClassCastException(orig: String, mod: String, ex: ClassCastException) = - ex.getMessage match { - case ClassCastMsg(from, _) => IncorrectLearner(orig, mod, from) - case _ => ClassCastErr(orig, mod, ex) - } - - private[multilabel] def validateVwParams(orig: String, mod: String, addQuiet: Boolean) = { - val ps = if (addQuiet) s"--quiet $mod" else mod - - Try { VWLearners.create[ExpectedLearner](ps) } match { - case Success(m) => - IOUtils.closeQuietly(m) - Right(mod) - case Failure(cce: ClassCastException) => - Left(handleClassCastException(orig, mod, cce)) - case Failure(ex) => - Left(VwError(orig, mod, ex.getMessage)) - } - } - - // ==================== json support functions ==================== - - private[multilabel] def modelFeatures(featuresSpecs: Seq[SparseSpec]) = - ListMap ( - featuresSpecs.map { case SparseSpec(name, spec, defVal) => - name -> Spec(spec, defVal.map(ts => ts.toSeq)) - }: _* - ) - - private[multilabel] def modelNamespaces(nss: Seq[Namespace]) = - ListMap( - nss.map(ns => ns.name -> ns.features) : _* - ) - - private[multilabel] def modelSource(binaryVwModel: Vfs, externalModel: Boolean) = - if (externalModel) - ExternalSource(binaryVwModel) - else Base64StringSource(VwJniModel.readBinaryVwModelToB64String(binaryVwModel.inputStream)) - - // Private b/c VwMultilabelAst is protected[this]. Don't let it escape. - private def vwModelPlugin( - modelSrc: ModelSource, - vwArgs: Option[String], - namespaces: Option[ListMap[String, Seq[String]]]) = - VwMultilabelAst( - VwSparseMultilabelPredictorProducer.multilabelPlugin.name, - modelSrc, - vwArgs.map(a => Right(a)), - namespaces - ) - - // Private b/c MultilabelData is private[this]. Don't let it escape. - private def modelAst[K]( - id: ModelIdentity, - features: ListMap[String, Spec], - numMissingThreshold: Option[Int], - labelsInTrainingSet: Seq[K], - labelsOfInterest: Option[String], - vwPlugin: VwMultilabelAst) = - MultilabelData( - modelType = MultilabelModel.parser.modelType, - modelId = ModelId(id.getId(), id.getName()), - features = features, - numMissingThreshold = numMissingThreshold, - labelsInTrainingSet = labelsInTrainingSet.toVector, - labelsOfInterest = labelsOfInterest, - underlying = vwPlugin.toJson.asJsObject - ) -} +object VwMultilabelModel extends VwMultlabelJsonCreator + with VwMultilabelParamAugmentation {} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala new file mode 100644 index 00000000..1b19f621 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -0,0 +1,215 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.LabelNamespaces +import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.ExpectedLearner +import org.apache.commons.io.IOUtils +import vowpalWabbit.learner.VWLearners + +import scala.util.{Failure, Success, Try} +import scala.util.matching.Regex + + +/** + * Created by ryan.deak on 10/5/17. + */ +private[multilabel] trait VwMultilabelParamAugmentation { + + private[multilabel] type VWNsSet = Set[Char] + private[multilabel] type VWNsCrossProdSet = Set[(Char, Char)] + + + /** + * Add VW parameters to make the multilabel model work: + * + * + * @param vwParams current VW parameters passed to the VW JNI + * @param namespaceNames it is assumed that `namespaceNames` is a superset + * of all of the namespaces referred to by any + * @return + */ + def updatedVwParams(vwParams: String, namespaceNames: Set[String]): Either[VwParamError, String] = { + val padded = s" ${vwParams.trim} " + val unrecovFlags = unrecoverableFlags(padded) + if (unrecovFlags.nonEmpty) + Left(UnrecoverableParams(vwParams, unrecovFlags)) + else { + val q = quadratics(padded) // Turn these into cubic features later. + + // This won't effect anything if the definition of UnrecoverableFlags contains + // all of the flags referenced in the flagsRefMissingNss function. If there + // are flags referenced in flagsRefMissingNss but not in UnrecoverableFlags, + // then this is a valid check. + val flagsRefMissingNss = flagsReferencingMissingNss(padded, namespaceNames, q) + + if (flagsRefMissingNss.nonEmpty) + Left(NamespaceError(vwParams, namespaceNames, flagsRefMissingNss)) + else + VwMultilabelRowCreator.determineLabelNamespaces(namespaceNames).fold( + Left(LabelNamespaceError(vwParams, namespaceNames)): Either[VwParamError, String] + ){ labelNs => + val paramsWithoutRemoved = removeParams(padded) + val updatedParams = addParams(paramsWithoutRemoved, namespaceNames, q, labelNs) + validateVwParams(vwParams, updatedParams, !isQuiet(updatedParams)) + } + } + } + + private[this] def pad(s: String) = "\\s" + s + "\\s" + private[this] val NumRegex = """-?(\d+(\.\d*)?|\d*\.\d+)([eE][+-]?\d+)?""" + private[this] val ClassCastMsg = """(\S+) cannot be cast to (\S+)""".r + private[this] val CsoaaRank = pad("--csoaa_rank").r + private[this] val CsoaaLdf = pad("""--csoaa_ldf\s+(mc?)""").r + private[this] val CsoaaRegression = "m" + private[this] val CsoaaClassification = "mc" + private[this] val Quiet = pad("--quiet").r + private[this] val Keep = pad("""--keep\s+(\S+)""").r + private[this] val Ignore = pad("""--ignore\s+(\S+)""").r + private[this] val IgnoreLinear = pad("""--ignore_linear\s+(\S+)""").r + private[multilabel] val UnrecoverableFlagSet = + Set("cubic", "redefine", "stage_poly", "ignore", "ignore_linear", "keep") + private[this] val UnrecoverableFlags = pad("--(" + UnrecoverableFlagSet.mkString("|") + ")").r + private[this] val QuadraticsShort = pad("""-q\s*([\S])([\S])""").r + private[this] val QuadraticsLong = pad("""--quadratic\s+([\S])([\S])""").r + private[this] val NoConstant = pad("""--noconstant""").r + private[this] val ConstantShort = pad("""-C\s*(""" + NumRegex + ")").r + private[this] val ConstantLong = pad("""--constant\s+(""" + NumRegex + ")").r + + private[this] val OptionsRemoved = Seq( + QuadraticsShort, + QuadraticsLong, + NoConstant, + ConstantShort, + ConstantLong, + CsoaaRank + ) + + private[multilabel] def removeParams(padded: String): String = + OptionsRemoved.foldLeft(padded){(s, r) => + val v = r.replaceAllIn(s, " ") + println(s"after $r:".padTo(70, ' ') + v) + v + } + + private[multilabel] def addParams( + paramsAfterRemoved: String, + namespaceNames: Set[String], + oldQuadratics: VWNsCrossProdSet, + labelNs: LabelNamespaces + ): String = { + val il = toVwNsSet(namespaceNames) + val q = il.map(n => (labelNs.labelNs, n)) + val c = oldQuadratics.map { case (a, b) => firstThenOrderedCross(labelNs.labelNs, a, b) } + + val quadratics = q.toSeq.sorted.map{ case (y, x) => s"-q$y$x" }.mkString(" ") + val cubics = c.toSeq.sorted.map{ case (y, x1, x2) => s"--cubic $y$x1$x2" }.mkString(" ") + val igLin = il.toSeq.sorted.map(n => s"--ignore_linear $n").mkString(" ") + val ig = s"--ignore ${labelNs.dummyLabelNs}" + + (paramsAfterRemoved.trim + s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics").trim + } + + private[multilabel] def setOfFirstGroup(padded: String, regex: Regex): Set[String] = + regex.findAllMatchIn(padded).map(m => m.group(1)).toSet + + private[multilabel] def setOfFirstGroups(padded: String, regexs: Regex*): Set[String] = + regexs.aggregate(Set.empty[String])((_, r) => setOfFirstGroup(padded, r), _ ++ _) + + private[multilabel] def noconstant(padded: String): Boolean = + NoConstant.findFirstIn(padded).nonEmpty + + private[multilabel] def constant(padded: String): Set[Double] = + setOfFirstGroups(padded, ConstantShort, ConstantLong).map(s => s.toDouble) + + private[multilabel] def unorderedCross[A](a: A, b: A)(implicit o: Ordering[A]) = + if (o.lt(a, b)) (a, b) else (b, a) + + private[multilabel] def firstThenOrderedCross[A](first: A, b: A, c: A)(implicit o: Ordering[A]) = { + val (x, y) = unorderedCross(b, c) + (first, x, y) + } + + private[multilabel] def quad(r: Regex, chrSeq: CharSequence): VWNsCrossProdSet = + r.findAllMatchIn(chrSeq).map { m => + val Seq(a, b) = (1 to 2).map(i => m.group(i).charAt(0)) + unorderedCross(a, b) + }.toSet + + private[multilabel] def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = + r.findAllMatchIn(chrSeq).flatMap(m => m.group(1).toCharArray).toSet + + private[multilabel] def unrecoverableFlags(padded: String): Set[String] = + UnrecoverableFlags.findAllMatchIn(padded).map(m => m.group(1)).toSet + + private[multilabel] def quadratics(padded: String): VWNsCrossProdSet = + quad(QuadraticsShort, padded) ++ quad(QuadraticsLong, padded) + + private[multilabel] def crossToSet[A](cross: Set[(A, A)]): Set[A] = + cross flatMap { case (a, b) => Set(a, b) } + + private[multilabel] def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty + + private[multilabel] def kept(padded: String): VWNsSet = charsIn(Keep, padded) + private[multilabel] def ignored(padded: String): VWNsSet = charsIn(Ignore, padded) + private[multilabel] def ignoredLinear(padded: String): VWNsSet = charsIn(IgnoreLinear, padded) + + private[multilabel] def handleClassCastException(orig: String, mod: String, ex: ClassCastException) = + ex.getMessage match { + case ClassCastMsg(from, _) => IncorrectLearner(orig, mod, from) + case _ => ClassCastErr(orig, mod, ex) + } + + private[multilabel] def flagsReferencingMissingNss( + padded: String, + namespaceNames: Set[String], + q: VWNsCrossProdSet + ): Map[String, VWNsSet] = { + val nsq = crossToSet(q) + val k = kept(padded) + val i = ignored(padded) + val il = ignoredLinear(padded) + flagsReferencingMissingNss(namespaceNames, k, i, il, nsq) + } + + private[multilabel] def flagsReferencingMissingNss( + namespaceNames: Set[String], + k: VWNsSet, i: VWNsSet, il: VWNsSet, nsq: VWNsSet + ): Map[String, VWNsSet] = + nssNotInNamespaceNames( + namespaceNames, + "keep" -> k, + "ignore" -> i, + "ignore_linear" -> il, + "quadratic" -> nsq + ) + + private[multilabel] def toVwNsSet(nsNames: Set[String]): VWNsSet = + nsNames.flatMap(_.take(1).toCharArray) + + private[multilabel] def nssNotInNamespaceNames( + nsNames: Set[String], + sets: (String, VWNsSet)* + ): Map[String, VWNsSet] = { + val vwNss = toVwNsSet(nsNames) + + sets.foldLeft(Map.empty[String, VWNsSet]){ case (m, (setName, nss)) => + val extra = nss diff vwNss + if (extra.isEmpty) m + else m + (setName -> extra) + } + } + + private[multilabel] def validateVwParams(orig: String, mod: String, addQuiet: Boolean) = { + val ps = if (addQuiet) s"--quiet $mod" else mod + + Try { VWLearners.create[ExpectedLearner](ps) } match { + case Success(m) => + IOUtils.closeQuietly(m) + Right(mod) + case Failure(cce: ClassCastException) => + Left(handleClassCastException(orig, mod, cce)) + case Failure(ex) => + Left(VwError(orig, mod, ex.getMessage)) + } + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala new file mode 100644 index 00000000..47a8ce26 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala @@ -0,0 +1,133 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson +import com.eharmony.aloha.id.{ModelId, ModelIdentity} +import com.eharmony.aloha.io.StringReadable +import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource} +import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.models.multilabel.MultilabelModel +import com.eharmony.aloha.models.multilabel.json.MultilabelModelJson +import com.eharmony.aloha.models.reg.json.Spec +import com.eharmony.aloha.models.vw.jni.VwJniModel +import com.eharmony.aloha.models.vw.jni.multilabel.json.VwMultilabelModelJson +import spray.json.{DefaultJsonProtocol, JsValue, JsonWriter, pimpAny, pimpString} + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 10/5/17. + */ +private[multilabel] trait VwMultlabelJsonCreator +extends MultilabelModelJson + with VwMultilabelModelJson { + + /** + * Create a JSON representation of an Aloha model. + * + * '''NOTE''': Because of the inclusion of the unrestricted `labelsOfInterest` parameter, + * the JSON produced by this function is not guaranteed to result in a valid + * Aloha model. This is because no semantics are required by this function and + * so, the `labelsOfInterest` function specification cannot be validated. + * + * @param datasetSpec a location of a dataset specification. + * @param binaryVwModel a location of a VW binary model file. + * @param id a model ID. + * @param labelsInTrainingSet The sequence of all labels encountered in the training set used + * to produce the `binaryVwModel`. + * '''''It is extremely important that this sequence has the + * same order as the sequence of labels used in the + * dataset creation process.''''' Otherwise, the VW model might + * associate scores with an incorrect label. + * @param labelsOfInterest It is possible that a model is trained on a super set of labels for + * which predictions can be made. If the labels at prediction time + * differs (''or should be extracted from the input to the model''), + * this function can provide that capability. + * @param vwArgs arguments that should be passed to the VW model. This likely isn't strictly + * necessary. + * @param externalModel whether the underlying binary VW model should remain as a separate + * file and be referenced by the Aloha model specification (`true`) + * or the binary model content should be embeeded directly into the model + * (`false`). '''Keep in mind''' Aloha models must be smaller than 2 GB + * because they are decoded to `String`s and `String`s are indexed by + * 32-bit integers (which have a max value of 2^32^ - 1). + * @param numMissingThreshold the number of missing features to tolerate before emitting a + * prediction failure. + * @tparam K the type of label or class. + * @return a JSON object. + */ + def json[K: JsonWriter]( + datasetSpec: Vfs, + binaryVwModel: Vfs, + id: ModelIdentity, + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String] = None, + vwArgs: Option[String] = None, + externalModel: Boolean = false, + numMissingThreshold: Option[Int] = None + ): JsValue = { + + val dsJsAst = StringReadable.fromInputStream(datasetSpec.inputStream).parseJson + val ds = dsJsAst.convertTo[VwMultilabeledJson] + val features = modelFeatures(ds.features) + val namespaces = ds.namespaces.map(modelNamespaces) + val modelSrc = modelSource(binaryVwModel, externalModel) + val vw = vwModelPlugin(modelSrc, vwArgs, namespaces) + val model = modelAst(id, features, numMissingThreshold, labelsInTrainingSet, labelsOfInterest, vw) + + // Even though we are only *writing* JSON, we need a JsonFormat[K], (which is a reader + // and writer) because multilabelDataJsonFormat has a JsonFormat context bound on K. + // So to turn MultilabelData into JSON, we need to lift the JsonWriter for K into a + // JsonFormat and store it as an implicit value (or create an implicit conversion + // function on implicit arguments. + implicit val labelJF = DefaultJsonProtocol.lift(implicitly[JsonWriter[K]]) + model.toJson + } + + private[multilabel] def modelFeatures(featuresSpecs: Seq[SparseSpec]) = + ListMap ( + featuresSpecs.map { case SparseSpec(name, spec, defVal) => + name -> Spec(spec, defVal.map(ts => ts.toSeq)) + }: _* + ) + + private[multilabel] def modelNamespaces(nss: Seq[Namespace]) = + ListMap( + nss.map(ns => ns.name -> ns.features) : _* + ) + + private[multilabel] def modelSource(binaryVwModel: Vfs, externalModel: Boolean) = + if (externalModel) + ExternalSource(binaryVwModel) + else Base64StringSource(VwJniModel.readBinaryVwModelToB64String(binaryVwModel.inputStream)) + + // Private b/c VwMultilabelAst is protected[this]. Don't let it escape. + private def vwModelPlugin( + modelSrc: ModelSource, + vwArgs: Option[String], + namespaces: Option[ListMap[String, Seq[String]]]) = + VwMultilabelAst( + VwSparseMultilabelPredictorProducer.multilabelPlugin.name, + modelSrc, + vwArgs.map(a => Right(a)), + namespaces + ) + + // Private b/c MultilabelData is private[this]. Don't let it escape. + private def modelAst[K]( + id: ModelIdentity, + features: ListMap[String, Spec], + numMissingThreshold: Option[Int], + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String], + vwPlugin: VwMultilabelAst) = + MultilabelData( + modelType = MultilabelModel.parser.modelType, + modelId = ModelId(id.getId(), id.getName()), + features = features, + numMissingThreshold = numMissingThreshold, + labelsInTrainingSet = labelsInTrainingSet.toVector, + labelsOfInterest = labelsOfInterest, + underlying = vwPlugin.toJson.asJsObject + ) +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala new file mode 100644 index 00000000..52150de2 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala @@ -0,0 +1,97 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.ExpectedLearner + +/** + * Created by ryan.deak on 10/5/17. + */ +sealed trait VwParamError { + def originalParams: String + def modifiedParams: Option[String] + def errorMessage: String +} + +final case class UnrecoverableParams( + originalParams: String, + unrecoverable: Set[String] +) extends VwParamError { + def modifiedParams: Option[String] = None + def errorMessage: String = + "Encountered parameters that cannot be augmented: " + + unrecoverable.toSeq.sorted.mkString(", ") + + s"\n\toriginal parameters: $originalParams" +} + +final case class IncorrectLearner( + originalParams: String, + modifiedPs: String, + learnerCanonicalName: String +) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"Params produced an incorrect learner type. " + + s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + + s"Found: $learnerCanonicalName." + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" +} + +final case class ClassCastErr( + originalParams: String, + modifiedPs: String, + ccException: ClassCastException +) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"Params produced an incorrect learner type. " + + s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + + s"Encountered ClassCastException: ${ccException.getMessage}" + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" +} + +final case class VwError( + originalParams: String, + modifiedPs: String, + vwErrMsg: String +) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"VW could not create a learner of type " + + s"${classOf[ExpectedLearner].getCanonicalName}. Error: $vwErrMsg. " + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" +} + +final case class NamespaceError( + originalParams: String, + namespaceNames: Set[String], + flagsReferencingMissingNss: Map[String, Set[Char]] +) extends VwParamError { + override def modifiedParams: Option[String] = None + override def errorMessage: String = { + val flagErrs = + flagsReferencingMissingNss + .toSeq.sortBy(_._1) + .map { case(f, s) => s"$f: ${s.mkString(",")}" } + .mkString("; ") + val nss = namespaceNames.toVector.sorted.mkString(", ") + val vwNss = VwMultilabelModel.toVwNsSet(namespaceNames).toVector.sorted.mkString(",") + + s"Detected flags referencing namespaces not in the provided set. $flagErrs" + + s". Expected only $vwNss from provided namespaces: $nss." + + s"\n\toriginal parameters: $originalParams" + } +} + +final case class LabelNamespaceError( + originalParams: String, + namespaceNames: Set[String] +) extends VwParamError { + override def modifiedParams: Option[String] = None + override def errorMessage: String = { + val nss = namespaceNames.toSeq.sorted.mkString(", ") + s"Could not determine label namespaces from namespaces: $nss" + + s"\n\toriginal parameters: $originalParams" + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala index 0ddb3d8b..13d0a26b 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala @@ -32,7 +32,7 @@ case class VwSparseMultilabelPredictorProducer[K]( params: String, defaultNs: List[Int], namespaces: List[(String, List[Int])], - labelNamespace: String, + labelNamespace: Char, numLabelsInTrainingSet: Int ) extends SparsePredictorProducer[K] { override def apply(): VwSparseMultilabelPredictor[K] = diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index eebf3388..1bb7ccb0 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -14,7 +14,6 @@ import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} import com.eharmony.aloha.io.vfs.Vfs import com.eharmony.aloha.models.Model import com.eharmony.aloha.models.multilabel.MultilabelModel -import com.eharmony.aloha.models.vw.jni.multilabel.VwMultilabelModel.IncorrectLearner import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances import com.eharmony.aloha.semantics.func.GenFunc0 import org.apache.commons.vfs2.VFS @@ -36,13 +35,47 @@ import scala.util.Random class VwMultilabelModelTest { import VwMultilabelModelTest._ + @Test def testExpectedUnrecoverableFlags(): Unit = { + assertEquals( + "Unrecoverable flags has changed.", + Set("cubic", "redefine", "stage_poly", "ignore", "ignore_linear", "keep"), + VwMultilabelModel.UnrecoverableFlagSet + ) + } + + @Test def testUnrecoverable(): Unit = { + val unrec = VwMultilabelModel.UnrecoverableFlagSet.iterator.map { f => + VwMultilabelModel.updatedVwParams(s"--$f", Set.empty) + } + + unrec foreach { + case Left(UnrecoverableParams(p, us)) => assertEquals(p, us.map(u => s"--$u").mkString(" ")) + case _ => fail() + } + } - @Test def testWrongLearner(): Unit = { + @Test def testQuadraticCreation(): Unit = { val args = "--csoaa_ldf mc" - VwMultilabelModel.updatedVwParams("--csoaa_ldf mc", Set.empty) match { - case Left(IncorrectLearner(o, _, c)) => - assertEquals(args, o) - assertNotEquals(classOf[VwSparseMultilabelPredictor.ExpectedLearner].getCanonicalName, c) + + // Notice: ignore_linear and quadratics are in sorted order. + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear a --ignore_linear b -qYa -qYb" + VwMultilabelModel.updatedVwParams(args, Set("abc", "bcd")) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicCreation(): Unit = { + val args = "--csoaa_ldf mc -qab --quadratic cb" + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear a --ignore_linear b --ignore_linear c " + + "-qYa -qYb -qYc " + + "--cubic Yab --cubic Ybc" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, Set("abc", "bcd", "cde")) match { + case Right(s) => assertEquals(exp, s) case _ => fail() } } From f86b610e97f4843e1aaa7580977a659cb5a55108 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 5 Oct 2017 14:10:42 -0700 Subject: [PATCH 71/98] removed println --- .../vw/jni/multilabel/VwMultilabelParamAugmentation.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index 1b19f621..1044ffde 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -85,11 +85,7 @@ private[multilabel] trait VwMultilabelParamAugmentation { ) private[multilabel] def removeParams(padded: String): String = - OptionsRemoved.foldLeft(padded){(s, r) => - val v = r.replaceAllIn(s, " ") - println(s"after $r:".padTo(70, ' ') + v) - v - } + OptionsRemoved.foldLeft(padded)((s, r) => r.replaceAllIn(s, " ")) private[multilabel] def addParams( paramsAfterRemoved: String, From 4c259721e20ba341b560e891020103bd3ebdd0e1 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Thu, 5 Oct 2017 14:18:24 -0700 Subject: [PATCH 72/98] made ignore_linear more concise --- .../vw/jni/multilabel/VwMultilabelParamAugmentation.scala | 2 +- .../models/vw/jni/multilabel/VwMultilabelModelTest.scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index 1044ffde..2b49005e 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -99,7 +99,7 @@ private[multilabel] trait VwMultilabelParamAugmentation { val quadratics = q.toSeq.sorted.map{ case (y, x) => s"-q$y$x" }.mkString(" ") val cubics = c.toSeq.sorted.map{ case (y, x1, x2) => s"--cubic $y$x1$x2" }.mkString(" ") - val igLin = il.toSeq.sorted.map(n => s"--ignore_linear $n").mkString(" ") + val igLin = if (il.nonEmpty) il.toSeq.sorted.mkString("--ignore_linear ", "", "") else "" val ig = s"--ignore ${labelNs.dummyLabelNs}" (paramsAfterRemoved.trim + s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics").trim diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 1bb7ccb0..85bda745 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -59,7 +59,7 @@ class VwMultilabelModelTest { // Notice: ignore_linear and quadratics are in sorted order. val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + - "--ignore_linear a --ignore_linear b -qYa -qYb" + "--ignore_linear ab -qYa -qYb" VwMultilabelModel.updatedVwParams(args, Set("abc", "bcd")) match { case Right(s) => assertEquals(exp, s) case _ => fail() @@ -69,12 +69,12 @@ class VwMultilabelModelTest { @Test def testCubicCreation(): Unit = { val args = "--csoaa_ldf mc -qab --quadratic cb" val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + - "--ignore_linear a --ignore_linear b --ignore_linear c " + - "-qYa -qYb -qYc " + + "--ignore_linear abcd " + + "-qYa -qYb -qYc -qYd " + "--cubic Yab --cubic Ybc" // Notice: ignore_linear and quadratics are in sorted order. - VwMultilabelModel.updatedVwParams(args, Set("abc", "bcd", "cde")) match { + VwMultilabelModel.updatedVwParams(args, Set("abc", "bcd", "cde", "def")) match { case Right(s) => assertEquals(exp, s) case _ => fail() } From 79f988beb347edd6373cf371cca2d9837596a2ff Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 6 Oct 2017 18:18:24 -0700 Subject: [PATCH 73/98] lots of stuff working. More tests to write for VwMultilabelParamAugmentation. --- .../VwMultilabelParamAugmentation.scala | 192 ++++++++++++------ .../vw/jni/multilabel/VwParamError.scala | 17 +- .../multilabel/VwMultilabelModelTest.scala | 49 ----- .../VwMultilabelParamAugmentationTest.scala | 159 +++++++++++++++ 4 files changed, 300 insertions(+), 117 deletions(-) create mode 100644 aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index 2b49005e..da3e23d2 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -6,6 +6,7 @@ import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.E import org.apache.commons.io.IOUtils import vowpalWabbit.learner.VWLearners +import scala.annotation.tailrec import scala.util.{Failure, Success, Try} import scala.util.matching.Regex @@ -13,11 +14,10 @@ import scala.util.matching.Regex /** * Created by ryan.deak on 10/5/17. */ -private[multilabel] trait VwMultilabelParamAugmentation { - - private[multilabel] type VWNsSet = Set[Char] - private[multilabel] type VWNsCrossProdSet = Set[(Char, Char)] +protected trait VwMultilabelParamAugmentation { + protected type VWNsSet = Set[Char] + protected type VWNsCrossProdSet = Set[(Char, Char)] /** * Add VW parameters to make the multilabel model work: @@ -30,17 +30,22 @@ private[multilabel] trait VwMultilabelParamAugmentation { */ def updatedVwParams(vwParams: String, namespaceNames: Set[String]): Either[VwParamError, String] = { val padded = s" ${vwParams.trim} " - val unrecovFlags = unrecoverableFlags(padded) - if (unrecovFlags.nonEmpty) + lazy val unrecovFlags = unrecoverableFlags(padded) + + if (WapOrCsoaa.findFirstMatchIn(padded).isEmpty) + Left(NotCsoaaOrWap(vwParams)) + else if (unrecovFlags.nonEmpty) Left(UnrecoverableParams(vwParams, unrecovFlags)) else { - val q = quadratics(padded) // Turn these into cubic features later. + val is = interactions(padded) + val i = ignored(padded) + val il = ignoredLinear(padded) // This won't effect anything if the definition of UnrecoverableFlags contains // all of the flags referenced in the flagsRefMissingNss function. If there // are flags referenced in flagsRefMissingNss but not in UnrecoverableFlags, // then this is a valid check. - val flagsRefMissingNss = flagsReferencingMissingNss(padded, namespaceNames, q) + val flagsRefMissingNss = flagsReferencingMissingNss(padded, namespaceNames, i, il, is) if (flagsRefMissingNss.nonEmpty) Left(NamespaceError(vwParams, namespaceNames, flagsRefMissingNss)) @@ -49,28 +54,29 @@ private[multilabel] trait VwMultilabelParamAugmentation { Left(LabelNamespaceError(vwParams, namespaceNames)): Either[VwParamError, String] ){ labelNs => val paramsWithoutRemoved = removeParams(padded) - val updatedParams = addParams(paramsWithoutRemoved, namespaceNames, q, labelNs) + val updatedParams = addParams(paramsWithoutRemoved, namespaceNames, i, il, is, labelNs) validateVwParams(vwParams, updatedParams, !isQuiet(updatedParams)) } } } + protected val UnrecoverableFlagSet: Set[String] = + Set("redefine", "stage_poly", "keep", "permutations") + private[this] def pad(s: String) = "\\s" + s + "\\s" private[this] val NumRegex = """-?(\d+(\.\d*)?|\d*\.\d+)([eE][+-]?\d+)?""" private[this] val ClassCastMsg = """(\S+) cannot be cast to (\S+)""".r private[this] val CsoaaRank = pad("--csoaa_rank").r - private[this] val CsoaaLdf = pad("""--csoaa_ldf\s+(mc?)""").r - private[this] val CsoaaRegression = "m" - private[this] val CsoaaClassification = "mc" + private[this] val WapOrCsoaa = pad("""--(csoaa|wap)_ldf\s+(mc?)""").r private[this] val Quiet = pad("--quiet").r private[this] val Keep = pad("""--keep\s+(\S+)""").r - private[this] val Ignore = pad("""--ignore\s+(\S+)""").r - private[this] val IgnoreLinear = pad("""--ignore_linear\s+(\S+)""").r - private[multilabel] val UnrecoverableFlagSet = - Set("cubic", "redefine", "stage_poly", "ignore", "ignore_linear", "keep") + protected val Ignore : Regex = pad("""--ignore\s+(\S+)""").r + protected val IgnoreLinear: Regex = pad("""--ignore_linear\s+(\S+)""").r private[this] val UnrecoverableFlags = pad("--(" + UnrecoverableFlagSet.mkString("|") + ")").r - private[this] val QuadraticsShort = pad("""-q\s*([\S])([\S])""").r - private[this] val QuadraticsLong = pad("""--quadratic\s+([\S])([\S])""").r + private[this] val QuadraticsShort = pad("""-q\s*([\S]{2})""").r + private[this] val QuadraticsLong = pad("""--quadratic\s+(\S{2})""").r + private[this] val Cubics = pad("""--cubic\s+(\S{3})""").r + private[this] val Interactions = pad("""--interactions\s+(\S\S+)""").r private[this] val NoConstant = pad("""--noconstant""").r private[this] val ConstantShort = pad("""-C\s*(""" + NumRegex + ")").r private[this] val ConstantLong = pad("""--constant\s+(""" + NumRegex + ")").r @@ -78,111 +84,164 @@ private[multilabel] trait VwMultilabelParamAugmentation { private[this] val OptionsRemoved = Seq( QuadraticsShort, QuadraticsLong, + Cubics, + Interactions, NoConstant, ConstantShort, ConstantLong, - CsoaaRank + CsoaaRank, + IgnoreLinear, + Ignore ) - private[multilabel] def removeParams(padded: String): String = - OptionsRemoved.foldLeft(padded)((s, r) => r.replaceAllIn(s, " ")) + protected def removeParams(padded: String): String = { + @tailrec def replaceAll(s: String, r: Regex): String = { + val str = r.replaceAllIn(s, " ").trim + s" $str " match { + case v if v == s => v + case v => replaceAll(v, r) + } + } + + // TODO: Figure out why Regex.replaceAllIn doesn't replace all. + OptionsRemoved.foldLeft(padded)((s, r) => replaceAll(s, r)) + } - private[multilabel] def addParams( + protected def addParams( paramsAfterRemoved: String, namespaceNames: Set[String], - oldQuadratics: VWNsCrossProdSet, + oldIgnored: VWNsSet, + oldIgnoredLinear: VWNsSet, + oldInteractions: Set[String], labelNs: LabelNamespaces ): String = { - val il = toVwNsSet(namespaceNames) - val q = il.map(n => (labelNs.labelNs, n)) - val c = oldQuadratics.map { case (a, b) => firstThenOrderedCross(labelNs.labelNs, a, b) } + val i = oldIgnored + labelNs.dummyLabelNs + val il = (toVwNsSet(namespaceNames) ++ oldIgnoredLinear) -- i + val qs = il.flatMap(n => + if (oldIgnored.contains(n) || oldIgnoredLinear.contains(n)) Nil + else List(s"${labelNs.labelNs}$n") + ) + + // Turn quadratic into cubic and cubic into higher-order interactions. + val cs = createLabelInteractions(oldInteractions, oldIgnored, labelNs, _ == 2) + val hos = createLabelInteractions(oldInteractions, oldIgnored, labelNs, _ >= 3) - val quadratics = q.toSeq.sorted.map{ case (y, x) => s"-q$y$x" }.mkString(" ") - val cubics = c.toSeq.sorted.map{ case (y, x1, x2) => s"--cubic $y$x1$x2" }.mkString(" ") + val quadratics = qs.toSeq.sorted.map(q => s"-q$q" ).mkString(" ") + val cubics = cs.toSeq.sorted.map(c => s"--cubic $c").mkString(" ") + val ints = hos.toSeq.sorted.map(ho => s"--interactions $ho").mkString(" ") val igLin = if (il.nonEmpty) il.toSeq.sorted.mkString("--ignore_linear ", "", "") else "" - val ig = s"--ignore ${labelNs.dummyLabelNs}" + val ig = s"--ignore ${i.mkString("")}" + + (paramsAfterRemoved.trim + s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics $ints").trim + } + + protected def createLabelInteractions( + interactions: Set[String], + ignored: VWNsSet, + labelNs: LabelNamespaces, + filter: Int => Boolean + ): Set[String] = + interactions.collect { + case i if filter(i.length) && !i.toCharArray.exists(ignored.contains) => s"${labelNs.labelNs}$i" + } - (paramsAfterRemoved.trim + s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics").trim + protected def interactions(padded: String): Set[String] = { + List( + QuadraticsShort, + QuadraticsLong, + Cubics, + Interactions + ).foldLeft(Set.empty[String]){(is, r) => + is ++ firstGroups(padded, r).map(s => s.sorted) + } } - private[multilabel] def setOfFirstGroup(padded: String, regex: Regex): Set[String] = - regex.findAllMatchIn(padded).map(m => m.group(1)).toSet + protected def firstGroups(padded: String, regex: Regex): Iterator[String] = + regex.findAllMatchIn(padded).map(m => m.group(1)) + +// protected def setOfFirstGroup(padded: String, regex: Regex): Set[String] = +// regex.findAllMatchIn(padded).map(m => m.group(1)).toSet - private[multilabel] def setOfFirstGroups(padded: String, regexs: Regex*): Set[String] = - regexs.aggregate(Set.empty[String])((_, r) => setOfFirstGroup(padded, r), _ ++ _) +// protected def setOfFirstGroups(padded: String, regexs: Regex*): Set[String] = +// regexs.aggregate(Set.empty[String])((_, r) => setOfFirstGroup(padded, r), _ ++ _) - private[multilabel] def noconstant(padded: String): Boolean = + protected def noconstant(padded: String): Boolean = NoConstant.findFirstIn(padded).nonEmpty - private[multilabel] def constant(padded: String): Set[Double] = - setOfFirstGroups(padded, ConstantShort, ConstantLong).map(s => s.toDouble) +// protected def constant(padded: String): Set[Double] = +// setOfFirstGroups(padded, ConstantShort, ConstantLong).map(s => s.toDouble) - private[multilabel] def unorderedCross[A](a: A, b: A)(implicit o: Ordering[A]) = + protected def unorderedCross[A](a: A, b: A)(implicit o: Ordering[A]): (A, A) = if (o.lt(a, b)) (a, b) else (b, a) - private[multilabel] def firstThenOrderedCross[A](first: A, b: A, c: A)(implicit o: Ordering[A]) = { + protected def firstThenOrderedCross[A](first: A, b: A, c: A)(implicit o: Ordering[A]): (A, A, A) = { val (x, y) = unorderedCross(b, c) (first, x, y) } - private[multilabel] def quad(r: Regex, chrSeq: CharSequence): VWNsCrossProdSet = + protected def quad(r: Regex, chrSeq: CharSequence): VWNsCrossProdSet = r.findAllMatchIn(chrSeq).map { m => val Seq(a, b) = (1 to 2).map(i => m.group(i).charAt(0)) unorderedCross(a, b) }.toSet - private[multilabel] def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = + protected def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = r.findAllMatchIn(chrSeq).flatMap(m => m.group(1).toCharArray).toSet - private[multilabel] def unrecoverableFlags(padded: String): Set[String] = + protected def unrecoverableFlags(padded: String): Set[String] = UnrecoverableFlags.findAllMatchIn(padded).map(m => m.group(1)).toSet - private[multilabel] def quadratics(padded: String): VWNsCrossProdSet = + protected def quadratics(padded: String): VWNsCrossProdSet = quad(QuadraticsShort, padded) ++ quad(QuadraticsLong, padded) - private[multilabel] def crossToSet[A](cross: Set[(A, A)]): Set[A] = + protected def crossToSet[A](cross: Set[(A, A)]): Set[A] = cross flatMap { case (a, b) => Set(a, b) } - private[multilabel] def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty + protected def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty - private[multilabel] def kept(padded: String): VWNsSet = charsIn(Keep, padded) - private[multilabel] def ignored(padded: String): VWNsSet = charsIn(Ignore, padded) - private[multilabel] def ignoredLinear(padded: String): VWNsSet = charsIn(IgnoreLinear, padded) +// protected def kept(padded: String): VWNsSet = charsIn(Keep, padded) + protected def ignored(padded: String): VWNsSet = charsIn(Ignore, padded) + protected def ignoredLinear(padded: String): VWNsSet = charsIn(IgnoreLinear, padded) - private[multilabel] def handleClassCastException(orig: String, mod: String, ex: ClassCastException) = + protected def handleClassCastException(orig: String, mod: String, ex: ClassCastException): VwParamError = ex.getMessage match { case ClassCastMsg(from, _) => IncorrectLearner(orig, mod, from) case _ => ClassCastErr(orig, mod, ex) } - private[multilabel] def flagsReferencingMissingNss( + protected def flagsReferencingMissingNss( padded: String, namespaceNames: Set[String], - q: VWNsCrossProdSet + i: VWNsSet, + il: VWNsSet, + is: Set[String] ): Map[String, VWNsSet] = { - val nsq = crossToSet(q) - val k = kept(padded) - val i = ignored(padded) - val il = ignoredLinear(padded) - flagsReferencingMissingNss(namespaceNames, k, i, il, nsq) + val q = filterAndFlattenInteractions(is, _ == 2) + val c = filterAndFlattenInteractions(is, _ == 3) + val ho = filterAndFlattenInteractions(is, _ >= 4) + flagsReferencingMissingNss(namespaceNames, i, il, q, c, ho) } - private[multilabel] def flagsReferencingMissingNss( + protected def filterAndFlattenInteractions(is: Set[String], filter: Int => Boolean): VWNsSet = + is.flatMap { + case interaction if filter(interaction.length) => interaction.toCharArray + case _ => Nil + } + + protected def flagsReferencingMissingNss( namespaceNames: Set[String], - k: VWNsSet, i: VWNsSet, il: VWNsSet, nsq: VWNsSet + i: VWNsSet, il: VWNsSet, q: VWNsSet, c: VWNsSet, ho: VWNsSet ): Map[String, VWNsSet] = nssNotInNamespaceNames( namespaceNames, - "keep" -> k, "ignore" -> i, "ignore_linear" -> il, - "quadratic" -> nsq + "quadratic" -> q, + "cubic" -> c, + "interactions" -> ho ) - private[multilabel] def toVwNsSet(nsNames: Set[String]): VWNsSet = - nsNames.flatMap(_.take(1).toCharArray) - - private[multilabel] def nssNotInNamespaceNames( + protected def nssNotInNamespaceNames( nsNames: Set[String], sets: (String, VWNsSet)* ): Map[String, VWNsSet] = { @@ -195,7 +254,10 @@ private[multilabel] trait VwMultilabelParamAugmentation { } } - private[multilabel] def validateVwParams(orig: String, mod: String, addQuiet: Boolean) = { + private[multilabel] def toVwNsSet(nsNames: Set[String]): VWNsSet = + nsNames.flatMap(_.take(1).toCharArray) + + protected def validateVwParams(orig: String, mod: String, addQuiet: Boolean): Either[VwParamError, String] = { val ps = if (addQuiet) s"--quiet $mod" else mod Try { VWLearners.create[ExpectedLearner](ps) } match { diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala index 52150de2..235b0fb7 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala @@ -22,6 +22,14 @@ final case class UnrecoverableParams( s"\n\toriginal parameters: $originalParams" } +final case class NotCsoaaOrWap(originalParams: String) extends VwParamError { + def modifiedParams: Option[String] = None + def errorMessage: String = + "Model must be a csoaa_ldf or wap_ldf model." + s"\n\toriginal parameters: $originalParams" +} + + final case class IncorrectLearner( originalParams: String, modifiedPs: String, @@ -73,13 +81,16 @@ final case class NamespaceError( val flagErrs = flagsReferencingMissingNss .toSeq.sortBy(_._1) - .map { case(f, s) => s"$f: ${s.mkString(",")}" } + .map { case(f, s) => s"--$f: ${s.mkString(",")}" } .mkString("; ") val nss = namespaceNames.toVector.sorted.mkString(", ") val vwNss = VwMultilabelModel.toVwNsSet(namespaceNames).toVector.sorted.mkString(",") - s"Detected flags referencing namespaces not in the provided set. $flagErrs" + - s". Expected only $vwNss from provided namespaces: $nss." + + s"Detected flags referencing namespaces not in the provided set. $flagErrs. " + + ( + if (vwNss.isEmpty) "No namespaces provided." + else s"Expected only $vwNss from provided namespaces: $nss." + ) + s"\n\toriginal parameters: $originalParams" } } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 85bda745..68a46fc5 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -35,55 +35,6 @@ import scala.util.Random class VwMultilabelModelTest { import VwMultilabelModelTest._ - @Test def testExpectedUnrecoverableFlags(): Unit = { - assertEquals( - "Unrecoverable flags has changed.", - Set("cubic", "redefine", "stage_poly", "ignore", "ignore_linear", "keep"), - VwMultilabelModel.UnrecoverableFlagSet - ) - } - - @Test def testUnrecoverable(): Unit = { - val unrec = VwMultilabelModel.UnrecoverableFlagSet.iterator.map { f => - VwMultilabelModel.updatedVwParams(s"--$f", Set.empty) - } - - unrec foreach { - case Left(UnrecoverableParams(p, us)) => assertEquals(p, us.map(u => s"--$u").mkString(" ")) - case _ => fail() - } - } - - @Test def testQuadraticCreation(): Unit = { - val args = "--csoaa_ldf mc" - - // Notice: ignore_linear and quadratics are in sorted order. - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + - "--ignore_linear ab -qYa -qYb" - VwMultilabelModel.updatedVwParams(args, Set("abc", "bcd")) match { - case Right(s) => assertEquals(exp, s) - case _ => fail() - } - } - - @Test def testCubicCreation(): Unit = { - val args = "--csoaa_ldf mc -qab --quadratic cb" - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + - "--ignore_linear abcd " + - "-qYa -qYb -qYc -qYd " + - "--cubic Yab --cubic Ybc" - - // Notice: ignore_linear and quadratics are in sorted order. - VwMultilabelModel.updatedVwParams(args, Set("abc", "bcd", "cde", "def")) match { - case Right(s) => assertEquals(exp, s) - case _ => fail() - } - } - - // TODO: More VW argument augmentation tests!!! - - - @Test def testTrainedModelWorks(): Unit = { val model = Model diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala new file mode 100644 index 00000000..8b9fd262 --- /dev/null +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -0,0 +1,159 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +/** + * Created by ryan.deak on 10/6/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { + + @Test def testNotCsoaaWap(): Unit = { + val args = "" + VwMultilabelModel.updatedVwParams(args, Set.empty) match { + case Left(NotCsoaaOrWap(ps)) => assertEquals(args, ps) + case _ => fail() + } + } + + @Test def testExpectedUnrecoverableFlags(): Unit = { + assertEquals( + "Unrecoverable flags has changed.", + Set("redefine", "stage_poly", "keep", "permutations"), + UnrecoverableFlagSet + ) + } + + @Test def testUnrecoverable(): Unit = { + val unrec = UnrecoverableFlagSet.iterator.map { f => + VwMultilabelModel.updatedVwParams(s"--csoaa_ldf mc --$f", Set.empty) + } + + unrec foreach { + case Left(UnrecoverableParams(p, us)) => + assertEquals( + p, + us.map(u => s"--$u") + .mkString("--csoaa_ldf mc ", " ", "") + ) + case _ => fail() + } + } + + @Test def testIgnoredNotInNsSet(): Unit = { + val args = "--csoaa_ldf mc --ignore a" + val origNss = Set.empty[String] + VwMultilabelModel.updatedVwParams(args, origNss) match { + case Left(NamespaceError(o, nss, bad)) => + assertEquals(args, o) + assertEquals(origNss, nss) + assertEquals(Map("ignore" -> Set('a')), bad) + case _ => fail() + } + } + + @Test def testIgnoredNotInNsSet2(): Unit = { + val args = "--csoaa_ldf mc --ignore ab" + val origNss = Set("a") + VwMultilabelModel.updatedVwParams(args, origNss) match { + case Left(NamespaceError(o, nss, bad)) => + assertEquals(args, o) + assertEquals(origNss, nss) + assertEquals(Map("ignore" -> Set('b')), bad) + case _ => fail() + } + } + + @Test def testQuadraticCreation(): Unit = { + val args = "--csoaa_ldf mc" + val nss = Set("abc", "bcd") + + // Notice: ignore_linear and quadratics are in sorted order. + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear ab -qYa -qYb" + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testIgnoredNoQuadraticCreation(): Unit = { + val args = "--csoaa_ldf mc --ignore_linear a" + val nss = Set("abc", "bcd") + + // Notice: ignore_linear and quadratics are in sorted order. + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear ab -qYb" + + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicCreation(): Unit = { + val args = "--csoaa_ldf mc -qab --quadratic cb" + val nss = Set("abc", "bcd", "cde", "def") + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear abcd " + + "-qYa -qYb -qYc -qYd " + + "--cubic Yab --cubic Ybc" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicCreationIgnoredLinear(): Unit = { + val args = "--csoaa_ldf mc -qab --quadratic cb --ignore_linear d" + val nss = Set("abc", "bcd", "cde", "def") + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear abcd " + + "-qYa -qYb -qYc " + + "--cubic Yab --cubic Ybc" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicCreationIgnored(): Unit = { + val args = "--csoaa_ldf mc -qab --quadratic cb --ignore c" + val nss = Set("abc", "bcd", "cde", "def") + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore cy " + + "--ignore_linear abd " + + "-qYa -qYb -qYd " + + "--cubic Yab" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicWithInteractionsCreationIgnored(): Unit = { + val args = "--csoaa_ldf mc --interactions ab --interactions cb --ignore c --ignore d" + val nss = Set("abc", "bcd", "cde", "def") + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore cy " + + "--ignore_linear abd " + + "-qYa -qYb -qYd " + + "--cubic Yab" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + // TODO: More VW argument augmentation tests!!! + +} From 9f615973bbb39111a1955519c2507207737b3d07 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 6 Oct 2017 18:22:51 -0700 Subject: [PATCH 74/98] tested higher order interactions. --- .../VwMultilabelParamAugmentationTest.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 8b9fd262..370d0c57 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -154,6 +154,19 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { } } + @Test def testHigherOrderInteractions(): Unit = { + val args = "--csoaa_ldf mc --interactions abcd --ignore_linear abcd" + val nss = Set("abc", "bcd", "cde", "def") + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear abcd " + + "--interactions Yabcd" + + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s.replaceAll(" +", " ")) + case _ => fail() + } + } + // TODO: More VW argument augmentation tests!!! } From ca84c366c944e6af2137c0377d6e7124eb1f3533 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 6 Oct 2017 19:38:01 -0700 Subject: [PATCH 75/98] removed extra whitespace in string output. --- .../vw/jni/multilabel/VwMultilabelParamAugmentation.scala | 5 +++-- .../jni/multilabel/VwMultilabelParamAugmentationTest.scala | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index da3e23d2..f553afcc 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -131,8 +131,9 @@ protected trait VwMultilabelParamAugmentation { val ints = hos.toSeq.sorted.map(ho => s"--interactions $ho").mkString(" ") val igLin = if (il.nonEmpty) il.toSeq.sorted.mkString("--ignore_linear ", "", "") else "" val ig = s"--ignore ${i.mkString("")}" - - (paramsAfterRemoved.trim + s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics $ints").trim + val additions = s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics $ints" + .replaceAll("\\s+", " ") + (paramsAfterRemoved.trim + additions).trim } protected def createLabelInteractions( diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 370d0c57..93682e00 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -162,7 +162,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { "--interactions Yabcd" VwMultilabelModel.updatedVwParams(args, nss) match { - case Right(s) => assertEquals(exp, s.replaceAll(" +", " ")) + case Right(s) => assertEquals(exp, s) case _ => fail() } } From db8967f285dfbb06dc326fd3eebd6e1c494790f2 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 9 Oct 2017 14:12:10 -0700 Subject: [PATCH 76/98] working but will change regex padding to use zero-width positive lookahead. --- .../VwMultilabelParamAugmentation.scala | 232 +++++++++++++----- .../VwMultilabelParamAugmentationTest.scala | 60 ++++- 2 files changed, 230 insertions(+), 62 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index f553afcc..89d203a7 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -20,8 +20,103 @@ protected trait VwMultilabelParamAugmentation { protected type VWNsCrossProdSet = Set[(Char, Char)] /** - * Add VW parameters to make the multilabel model work: + * Adds VW parameters to make the parameters work as an Aloha multilabel model. * + * The algorithm works as follows: + * + 1. Ensure the VW `csoaa_ldf` or `wap_ldf` reduction is specified in the supplied VW + parameter list (''with the appropriate option for the flag''). + 1. Ensure that no "''unrecoverable''" flags appear in the supplied VW parameter list. + See `UnrecoverableFlagSet` for flags whose appearance is considered + "''unrecoverable''". + 1. Ensure that ''ignore'' and ''interaction'' flags (`--ignore`, `--ignore_linear`, `-q`, + `--quadratic`, `--cubic`) do not refer to namespaces not supplied in + the `namespaceNames` parameter. + 1. Attempt to determine namespace names that can be used for the labels. For more + information on the label namespace resolution algorithm, see: + `com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.determineLabelNamespaces`. + 1. Remove flags and options found in `FlagsToRemove`. + 1. Add `--noconstant` and `--csoaa_rank` flags. `--noconstant` is added because per-label + intercepts will be included and take the place of a single intercept. `--csoaa_rank` + is added to make the `VWLearner` a `VWActionScoresLearner`. + 1. Create interactions between features and the label namespaces created above. + a. If a namespace in `namespaceNames` appears as an option to VW's `ignore_linear` flag, + '''do not''' create a quadratic interaction between that namespace and the label + namespace. + a. For each interaction term (`-q`, `--quadratic`, `--cubic`, `--interactions`), replace it + with an interaction term also interacted with the label namespace. This increases the + arity of the interaction by 1. + * + * ==Success Examples== + * + * {{{ + * import com.eharmony.aloha.models.vw.jni.multilabel.VwMultilabelModel.updatedVwParams + * + * // This is a basic example. 'y' and 'Y' in the output are label + * // namespaces. Notice all namespaces are quadratically interacted + * // with the label namespace. + * val uvw1 = updatedVwParams( + * "--csoaa_ldf mc", + * Set("a", "b", "c") + * ) + * // Right("--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + * // "--ignore_linear abc -qYa -qYb -qYc") + * + * // Here since 'a' is in 'ignore_linear', no '-qYa' term appears + * // in the output. + * val uvw2 = updatedVwParams( + * "--csoaa_ldf mc --ignore_linear a -qbc", + * Set("a", "b", "c") + * ) + * // Right("--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + * // "--ignore_linear abc -qYb -qYc --cubic Ybc) + * + * // 'a' is in 'ignore', so no terms with 'a' are emitted. 'b' is + * // in 'ignore_linear' so it does occur in any quadratic + * // interactions in the output, but can appear in interaction + * // terms of higher arity like the cubic interaction. + * val uvw3 = updatedVwParams( + * "--csoaa_ldf mc --ignore a --ignore_linear b -qbc --cubic abc", + * Set("a", "b", "c") + * ) + * // Right("--csoaa_ldf mc --noconstant --csoaa_rank --ignore ay " + + * // "--ignore_linear bc -qYc --cubic Ybc") + * }}} + * + * ==Errors Examples== + * + * {{{ + * import com.eharmony.aloha.models.vw.jni.multilabel.VwMultilabelModel.updatedVwParams + * import com.eharmony.aloha.models.vw.jni.multilabel.{ + * NotCsoaaOrWap, + * NamespaceError + * } + * + * assert( updatedVwParams("", Set()) == Left(NotCsoaaOrWap("")) ) + * + * assert( + * updatedVwParams("--wap_ldf m -qaa", Set()) == + * Left(NamespaceError("--wap_ldf m -qaa",Set(),Map("quadratic" -> Set('a')))) + * ) + * + * assert( + * updatedVwParams( + * "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde", + * Set() + * ) == + * Left( + * NamespaceError( + * "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde", + * Set(), + * Map( + * "ignore" -> Set('a'), + * "quadratic" -> Set('b'), + * "cubic" -> Set('b', 'c', 'd', 'e') + * ) + * ) + * ) + * ) + * }}} * * @param vwParams current VW parameters passed to the VW JNI * @param namespaceNames it is assumed that `namespaceNames` is a superset @@ -60,28 +155,39 @@ protected trait VwMultilabelParamAugmentation { } } + /** + * VW Flags automatically resulting in an error. + */ protected val UnrecoverableFlagSet: Set[String] = Set("redefine", "stage_poly", "keep", "permutations") + /** + * Use this to pad Regex instances ending in `\S+`. + * `findAllMatchesIn` finds non-overlapping matches. Because of greediness + * in regex matching, the next character (if one exists) must be whitespace when + * a regex ends in `\S+`. + * @param s string representation of Regex to be left-padded w/ one whitespace. + * @return a whitespace left-padded version of `s`. + */ + private[this] def leftPad(s: String) = "\\s" + s private[this] def pad(s: String) = "\\s" + s + "\\s" private[this] val NumRegex = """-?(\d+(\.\d*)?|\d*\.\d+)([eE][+-]?\d+)?""" private[this] val ClassCastMsg = """(\S+) cannot be cast to (\S+)""".r private[this] val CsoaaRank = pad("--csoaa_rank").r private[this] val WapOrCsoaa = pad("""--(csoaa|wap)_ldf\s+(mc?)""").r private[this] val Quiet = pad("--quiet").r - private[this] val Keep = pad("""--keep\s+(\S+)""").r - protected val Ignore : Regex = pad("""--ignore\s+(\S+)""").r - protected val IgnoreLinear: Regex = pad("""--ignore_linear\s+(\S+)""").r + protected val Ignore : Regex = leftPad("""--ignore\s+(\S+)""").r + protected val IgnoreLinear: Regex = leftPad("""--ignore_linear\s+(\S+)""").r private[this] val UnrecoverableFlags = pad("--(" + UnrecoverableFlagSet.mkString("|") + ")").r private[this] val QuadraticsShort = pad("""-q\s*([\S]{2})""").r private[this] val QuadraticsLong = pad("""--quadratic\s+(\S{2})""").r private[this] val Cubics = pad("""--cubic\s+(\S{3})""").r - private[this] val Interactions = pad("""--interactions\s+(\S\S+)""").r + private[this] val Interactions = leftPad("""--interactions\s+(\S\S+)""").r private[this] val NoConstant = pad("""--noconstant""").r private[this] val ConstantShort = pad("""-C\s*(""" + NumRegex + ")").r private[this] val ConstantLong = pad("""--constant\s+(""" + NumRegex + ")").r - private[this] val OptionsRemoved = Seq( + private[this] val FlagsToRemove = Seq( QuadraticsShort, QuadraticsLong, Cubics, @@ -94,17 +200,27 @@ protected trait VwMultilabelParamAugmentation { Ignore ) + /** + * Remove flags (and options) for the flags listed in `FlagsToRemove`. + * @param padded a padded version of the params passed to `updateVwParams`. + * @return + */ protected def removeParams(padded: String): String = { + // r.replaceAllIn replaces non-overlapping matches. Since multiple regular + // expressions begin or end with whitespace there can be a whitespace character + // that is part of two matches. To accommodate this, replaceAll keeps replacing + // the matches with whitespace until an equilibrium is reached. Once equilibrium + // is reached, move on to the next Regex. @tailrec def replaceAll(s: String, r: Regex): String = { - val str = r.replaceAllIn(s, " ").trim - s" $str " match { + // Replace the matches with whitespace. Trim and pad to avoid stack overflows. + val str = s" ${r.replaceAllIn(s, " ").trim} " + str match { case v if v == s => v case v => replaceAll(v, r) } } - // TODO: Figure out why Regex.replaceAllIn doesn't replace all. - OptionsRemoved.foldLeft(padded)((s, r) => replaceAll(s, r)) + FlagsToRemove.foldLeft(padded)(replaceAll) } protected def addParams( @@ -116,7 +232,12 @@ protected trait VwMultilabelParamAugmentation { labelNs: LabelNamespaces ): String = { val i = oldIgnored + labelNs.dummyLabelNs + + // Don't include namespaces that are ignored in ignore_linear. val il = (toVwNsSet(namespaceNames) ++ oldIgnoredLinear) -- i + + // Don't turn a given namespace into quadratics interacted on label when the + // namespace is listed in the ignore_linear flag. val qs = il.flatMap(n => if (oldIgnored.contains(n) || oldIgnoredLinear.contains(n)) Nil else List(s"${labelNs.labelNs}$n") @@ -130,7 +251,11 @@ protected trait VwMultilabelParamAugmentation { val cubics = cs.toSeq.sorted.map(c => s"--cubic $c").mkString(" ") val ints = hos.toSeq.sorted.map(ho => s"--interactions $ho").mkString(" ") val igLin = if (il.nonEmpty) il.toSeq.sorted.mkString("--ignore_linear ", "", "") else "" + + // This is non-empty b/c i is non-empty. val ig = s"--ignore ${i.mkString("")}" + + // Consolidate whitespace because there shouldn't be whitespace in these flags' options. val additions = s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics $ints" .replaceAll("\\s+", " ") (paramsAfterRemoved.trim + additions).trim @@ -143,68 +268,39 @@ protected trait VwMultilabelParamAugmentation { filter: Int => Boolean ): Set[String] = interactions.collect { - case i if filter(i.length) && !i.toCharArray.exists(ignored.contains) => s"${labelNs.labelNs}$i" + case i if filter(i.length) && // Filter based on arity. + !i.toCharArray.exists(ignored.contains) => // Filter out ignored. + s"${labelNs.labelNs}$i" } - protected def interactions(padded: String): Set[String] = { + /** + * Get the set of interactions (encoded as Strings). String length represents the + * interaction arity. + * @param padded the padded version of the original parameters based to `updatedVwParams`. + * @return + */ + protected def interactions(padded: String): Set[String] = List( QuadraticsShort, QuadraticsLong, Cubics, Interactions ).foldLeft(Set.empty[String]){(is, r) => - is ++ firstGroups(padded, r).map(s => s.sorted) + is ++ firstCaptureGroups(padded, r).map(s => s.sorted) } - } - - protected def firstGroups(padded: String, regex: Regex): Iterator[String] = - regex.findAllMatchIn(padded).map(m => m.group(1)) - -// protected def setOfFirstGroup(padded: String, regex: Regex): Set[String] = -// regex.findAllMatchIn(padded).map(m => m.group(1)).toSet - -// protected def setOfFirstGroups(padded: String, regexs: Regex*): Set[String] = -// regexs.aggregate(Set.empty[String])((_, r) => setOfFirstGroup(padded, r), _ ++ _) - - protected def noconstant(padded: String): Boolean = - NoConstant.findFirstIn(padded).nonEmpty - -// protected def constant(padded: String): Set[Double] = -// setOfFirstGroups(padded, ConstantShort, ConstantLong).map(s => s.toDouble) - - protected def unorderedCross[A](a: A, b: A)(implicit o: Ordering[A]): (A, A) = - if (o.lt(a, b)) (a, b) else (b, a) - - protected def firstThenOrderedCross[A](first: A, b: A, c: A)(implicit o: Ordering[A]): (A, A, A) = { - val (x, y) = unorderedCross(b, c) - (first, x, y) - } - - protected def quad(r: Regex, chrSeq: CharSequence): VWNsCrossProdSet = - r.findAllMatchIn(chrSeq).map { m => - val Seq(a, b) = (1 to 2).map(i => m.group(i).charAt(0)) - unorderedCross(a, b) - }.toSet - - protected def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = - r.findAllMatchIn(chrSeq).flatMap(m => m.group(1).toCharArray).toSet protected def unrecoverableFlags(padded: String): Set[String] = UnrecoverableFlags.findAllMatchIn(padded).map(m => m.group(1)).toSet - protected def quadratics(padded: String): VWNsCrossProdSet = - quad(QuadraticsShort, padded) ++ quad(QuadraticsLong, padded) - - protected def crossToSet[A](cross: Set[(A, A)]): Set[A] = - cross flatMap { case (a, b) => Set(a, b) } - protected def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty - -// protected def kept(padded: String): VWNsSet = charsIn(Keep, padded) protected def ignored(padded: String): VWNsSet = charsIn(Ignore, padded) protected def ignoredLinear(padded: String): VWNsSet = charsIn(IgnoreLinear, padded) - protected def handleClassCastException(orig: String, mod: String, ex: ClassCastException): VwParamError = + protected def handleClassCastException( + orig: String, + mod: String, + ex: ClassCastException + ): VwParamError = ex.getMessage match { case ClassCastMsg(from, _) => IncorrectLearner(orig, mod, from) case _ => ClassCastErr(orig, mod, ex) @@ -255,10 +351,11 @@ protected trait VwMultilabelParamAugmentation { } } - private[multilabel] def toVwNsSet(nsNames: Set[String]): VWNsSet = - nsNames.flatMap(_.take(1).toCharArray) - - protected def validateVwParams(orig: String, mod: String, addQuiet: Boolean): Either[VwParamError, String] = { + protected def validateVwParams( + orig: String, + mod: String, + addQuiet: Boolean + ): Either[VwParamError, String] = { val ps = if (addQuiet) s"--quiet $mod" else mod Try { VWLearners.create[ExpectedLearner](ps) } match { @@ -271,4 +368,21 @@ protected trait VwMultilabelParamAugmentation { Left(VwError(orig, mod, ex.getMessage)) } } + + // More general functions. + + /** + * Find all of the regex matches and extract the first capture group from the match. + * @param padded the padded version of the original parameters based to `updatedVwParams`. + * @param regex with at least one capture group (this is unchecked). + * @return Iterator of the matches' first capture group. + */ + protected def firstCaptureGroups(padded: String, regex: Regex): Iterator[String] = + regex.findAllMatchIn(padded).map(m => m.group(1)) + + protected def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = + r.findAllMatchIn(chrSeq).flatMap(m => m.group(1).toCharArray).toSet + + private[multilabel] def toVwNsSet(nsNames: Set[String]): VWNsSet = + nsNames.flatMap(_.take(1).toCharArray) } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 93682e00..6e090cf4 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -142,9 +142,9 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testCubicWithInteractionsCreationIgnored(): Unit = { val args = "--csoaa_ldf mc --interactions ab --interactions cb --ignore c --ignore d" val nss = Set("abc", "bcd", "cde", "def") - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore cy " + - "--ignore_linear abd " + - "-qYa -qYb -qYd " + + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore cdy " + + "--ignore_linear ab " + + "-qYa -qYb " + "--cubic Yab" // Notice: ignore_linear and quadratics are in sorted order. @@ -167,6 +167,60 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { } } + @Test def testMultipleInteractions(): Unit = { + val nss = ('a' to 'e').map(_.toString).toSet + + val args = s"--csoaa_ldf mc --interactions ab --interactions abc " + + "--interactions abcd --interactions abcde" + + val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + "--ignore_linear abcde " + + "-qYa -qYb -qYc -qYd -qYe " + + "--cubic Yab " + + "--interactions Yabc " + + "--interactions Yabcd " + + "--interactions Yabcde" + + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def interactionsWithSelf(): Unit = { + val nss = Set("a") + val args = "--wap_ldf m -qaa --cubic aaa --interactions aaaa" + val exp = "--wap_ldf m --noconstant --csoaa_rank --ignore y --ignore_linear a " + + "-qYa " + + "--cubic Yaa " + + "--interactions Yaaa " + + "--interactions Yaaaa" + + VwMultilabelModel.updatedVwParams(args, nss) match { + case Right(s) => assertEquals(exp, s) + case x => assertEquals("", x) + } + } + +// @Test def adf(): Unit = { +// val args = "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde" +// val updated = updatedVwParams(args, Set()) +// +// +// val exp = Left( +// NamespaceError( +// "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde", +// Set(), +// Map( +// "ignore" -> Set('a'), +// "quadratic" -> Set('b', 'd'), +// "cubic" -> Set('b', 'c', 'd', 'e') +// ) +// ) +// ) +// +// assertEquals(exp, updated) +// } // TODO: More VW argument augmentation tests!!! } From 6f896304f6a6e17c46c04912bc32d7574d34609f Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 9 Oct 2017 14:39:05 -0700 Subject: [PATCH 77/98] added different padding. --- .../VwMultilabelParamAugmentation.scala | 41 ++++++++++------- .../VwMultilabelParamAugmentationTest.scala | 44 ++++++++++--------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index 89d203a7..ea82a90d 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -162,27 +162,37 @@ protected trait VwMultilabelParamAugmentation { Set("redefine", "stage_poly", "keep", "permutations") /** - * Use this to pad Regex instances ending in `\S+`. - * `findAllMatchesIn` finds non-overlapping matches. Because of greediness - * in regex matching, the next character (if one exists) must be whitespace when - * a regex ends in `\S+`. - * @param s string representation of Regex to be left-padded w/ one whitespace. - * @return a whitespace left-padded version of `s`. + * This is the capture group containing the content when the regex has been + * padded with the pad function. */ - private[this] def leftPad(s: String) = "\\s" + s - private[this] def pad(s: String) = "\\s" + s + "\\s" + private val CaptureGroupWithContent = 2 + + + /** + * Pad the regular expression with a prefix and suffix that makes matching work. + * The prefix is `(^|\s)` and means the if there's a character preceding the main + * content in `s`, then that character should be whitespace. The suffix is + * `(?=\s|$)` which means that if a character follows the main content matched by + * `s`, then that character should be whitespace '''AND''' ''that character should + * not'' be consumed by the Regex. By allowing that character to be present for the + * next matching of a regex, it is consumable by the prefix of a regex padded with + * the `pad` function. + * @param s a string + * @return + */ + private[this] def pad(s: String) = """(^|\s)""" + s + """(?=\s|$)""" private[this] val NumRegex = """-?(\d+(\.\d*)?|\d*\.\d+)([eE][+-]?\d+)?""" private[this] val ClassCastMsg = """(\S+) cannot be cast to (\S+)""".r private[this] val CsoaaRank = pad("--csoaa_rank").r private[this] val WapOrCsoaa = pad("""--(csoaa|wap)_ldf\s+(mc?)""").r private[this] val Quiet = pad("--quiet").r - protected val Ignore : Regex = leftPad("""--ignore\s+(\S+)""").r - protected val IgnoreLinear: Regex = leftPad("""--ignore_linear\s+(\S+)""").r + protected val Ignore : Regex = pad("""--ignore\s+(\S+)""").r + protected val IgnoreLinear: Regex = pad("""--ignore_linear\s+(\S+)""").r private[this] val UnrecoverableFlags = pad("--(" + UnrecoverableFlagSet.mkString("|") + ")").r private[this] val QuadraticsShort = pad("""-q\s*([\S]{2})""").r private[this] val QuadraticsLong = pad("""--quadratic\s+(\S{2})""").r private[this] val Cubics = pad("""--cubic\s+(\S{3})""").r - private[this] val Interactions = leftPad("""--interactions\s+(\S\S+)""").r + private[this] val Interactions = pad("""--interactions\s+(\S{2,})""").r private[this] val NoConstant = pad("""--noconstant""").r private[this] val ConstantShort = pad("""-C\s*(""" + NumRegex + ")").r private[this] val ConstantLong = pad("""--constant\s+(""" + NumRegex + ")").r @@ -220,7 +230,8 @@ protected trait VwMultilabelParamAugmentation { } } - FlagsToRemove.foldLeft(padded)(replaceAll) +// FlagsToRemove.foldLeft(padded)(replaceAll) + FlagsToRemove.foldLeft(padded)((s, r) => r.replaceAllIn(s, " ")) } protected def addParams( @@ -290,7 +301,7 @@ protected trait VwMultilabelParamAugmentation { } protected def unrecoverableFlags(padded: String): Set[String] = - UnrecoverableFlags.findAllMatchIn(padded).map(m => m.group(1)).toSet + firstCaptureGroups(padded, UnrecoverableFlags).toSet protected def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty protected def ignored(padded: String): VWNsSet = charsIn(Ignore, padded) @@ -378,10 +389,10 @@ protected trait VwMultilabelParamAugmentation { * @return Iterator of the matches' first capture group. */ protected def firstCaptureGroups(padded: String, regex: Regex): Iterator[String] = - regex.findAllMatchIn(padded).map(m => m.group(1)) + regex.findAllMatchIn(padded).map(m => m.group(CaptureGroupWithContent)) protected def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = - r.findAllMatchIn(chrSeq).flatMap(m => m.group(1).toCharArray).toSet + r.findAllMatchIn(chrSeq).flatMap(m => m.group(CaptureGroupWithContent).toCharArray).toSet private[multilabel] def toVwNsSet(nsNames: Set[String]): VWNsSet = nsNames.flatMap(_.take(1).toCharArray) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 6e090cf4..ccddca6e 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -30,7 +30,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testUnrecoverable(): Unit = { val unrec = UnrecoverableFlagSet.iterator.map { f => VwMultilabelModel.updatedVwParams(s"--csoaa_ldf mc --$f", Set.empty) - } + }.toList unrec foreach { case Left(UnrecoverableParams(p, us)) => @@ -202,25 +202,29 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { } } -// @Test def adf(): Unit = { -// val args = "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde" -// val updated = updatedVwParams(args, Set()) -// -// -// val exp = Left( -// NamespaceError( -// "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde", -// Set(), -// Map( -// "ignore" -> Set('a'), -// "quadratic" -> Set('b', 'd'), -// "cubic" -> Set('b', 'c', 'd', 'e') -// ) -// ) -// ) -// -// assertEquals(exp, updated) -// } + @Test def testNamespaceErrors(): Unit = { + val args = "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd " + + "--cubic bcd --interactions dde --interactions abcde" + val updated = updatedVwParams(args, Set()) + + + val exp = Left( + NamespaceError( + "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd --cubic bcd " + + "--interactions dde --interactions abcde", + Set(), + Map( + "ignore" -> Set('a'), + "ignore_linear" -> Set('b'), + "quadratic" -> Set('b', 'd'), + "cubic" -> Set('b', 'c', 'd', 'e'), + "interactions" -> Set('a', 'b', 'c', 'd', 'e') + ) + ) + ) + + assertEquals(exp, updated) + } // TODO: More VW argument augmentation tests!!! } From 95e0c924f9ab966641fc0822fc7316e85bc0538a Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 9 Oct 2017 14:59:45 -0700 Subject: [PATCH 78/98] Updated documentation and tests. Looks good. --- .../VwMultilabelParamAugmentation.scala | 84 ++++++++----------- .../VwMultilabelParamAugmentationTest.scala | 64 ++++++++------ 2 files changed, 73 insertions(+), 75 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index ea82a90d..21706628 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -6,9 +6,8 @@ import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.E import org.apache.commons.io.IOUtils import vowpalWabbit.learner.VWLearners -import scala.annotation.tailrec -import scala.util.{Failure, Success, Try} import scala.util.matching.Regex +import scala.util.{Failure, Success, Try} /** @@ -92,26 +91,30 @@ protected trait VwMultilabelParamAugmentation { * NamespaceError * } * - * assert( updatedVwParams("", Set()) == Left(NotCsoaaOrWap("")) ) + * assert( updatedVwParams("", Set()) == Left(NotCsoaaOrWap("")) ) * * assert( * updatedVwParams("--wap_ldf m -qaa", Set()) == - * Left(NamespaceError("--wap_ldf m -qaa",Set(),Map("quadratic" -> Set('a')))) + * Left(NamespaceError("--wap_ldf m -qaa", Set(), Map("quadratic" -> Set('a')))) * ) * * assert( * updatedVwParams( - * "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde", + * "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd " + + "--cubic bcd --interactions dde --interactions abcde", * Set() * ) == * Left( * NamespaceError( - * "--wap_ldf m --ignore a -qbb -qbd --cubic bcd --interactions dde", + * "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd --cubic bcd " + + * "--interactions dde --interactions abcde", * Set(), * Map( - * "ignore" -> Set('a'), - * "quadratic" -> Set('b'), - * "cubic" -> Set('b', 'c', 'd', 'e') + * "ignore" -> Set('a'), + * "ignore_linear" -> Set('b'), + * "quadratic" -> Set('b', 'd'), + * "cubic" -> Set('b', 'c', 'd', 'e'), + * "interactions" -> Set('a', 'b', 'c', 'd', 'e') * ) * ) * ) @@ -120,27 +123,27 @@ protected trait VwMultilabelParamAugmentation { * * @param vwParams current VW parameters passed to the VW JNI * @param namespaceNames it is assumed that `namespaceNames` is a superset - * of all of the namespaces referred to by any + * of all of the namespaces referred to by any flags + * found in `vwParams`. * @return */ def updatedVwParams(vwParams: String, namespaceNames: Set[String]): Either[VwParamError, String] = { - val padded = s" ${vwParams.trim} " - lazy val unrecovFlags = unrecoverableFlags(padded) + lazy val unrecovFlags = unrecoverableFlags(vwParams) - if (WapOrCsoaa.findFirstMatchIn(padded).isEmpty) + if (WapOrCsoaa.findFirstMatchIn(vwParams).isEmpty) Left(NotCsoaaOrWap(vwParams)) else if (unrecovFlags.nonEmpty) Left(UnrecoverableParams(vwParams, unrecovFlags)) else { - val is = interactions(padded) - val i = ignored(padded) - val il = ignoredLinear(padded) + val is = interactions(vwParams) + val i = ignored(vwParams) + val il = ignoredLinear(vwParams) // This won't effect anything if the definition of UnrecoverableFlags contains // all of the flags referenced in the flagsRefMissingNss function. If there // are flags referenced in flagsRefMissingNss but not in UnrecoverableFlags, // then this is a valid check. - val flagsRefMissingNss = flagsReferencingMissingNss(padded, namespaceNames, i, il, is) + val flagsRefMissingNss = flagsReferencingMissingNss(namespaceNames, i, il, is) if (flagsRefMissingNss.nonEmpty) Left(NamespaceError(vwParams, namespaceNames, flagsRefMissingNss)) @@ -148,7 +151,7 @@ protected trait VwMultilabelParamAugmentation { VwMultilabelRowCreator.determineLabelNamespaces(namespaceNames).fold( Left(LabelNamespaceError(vwParams, namespaceNames)): Either[VwParamError, String] ){ labelNs => - val paramsWithoutRemoved = removeParams(padded) + val paramsWithoutRemoved = removeParams(vwParams) val updatedParams = addParams(paramsWithoutRemoved, namespaceNames, i, il, is, labelNs) validateVwParams(vwParams, updatedParams, !isQuiet(updatedParams)) } @@ -212,27 +215,11 @@ protected trait VwMultilabelParamAugmentation { /** * Remove flags (and options) for the flags listed in `FlagsToRemove`. - * @param padded a padded version of the params passed to `updateVwParams`. + * @param vwParams VW params passed to the `updatedVwParams` function. * @return */ - protected def removeParams(padded: String): String = { - // r.replaceAllIn replaces non-overlapping matches. Since multiple regular - // expressions begin or end with whitespace there can be a whitespace character - // that is part of two matches. To accommodate this, replaceAll keeps replacing - // the matches with whitespace until an equilibrium is reached. Once equilibrium - // is reached, move on to the next Regex. - @tailrec def replaceAll(s: String, r: Regex): String = { - // Replace the matches with whitespace. Trim and pad to avoid stack overflows. - val str = s" ${r.replaceAllIn(s, " ").trim} " - str match { - case v if v == s => v - case v => replaceAll(v, r) - } - } - -// FlagsToRemove.foldLeft(padded)(replaceAll) - FlagsToRemove.foldLeft(padded)((s, r) => r.replaceAllIn(s, " ")) - } + protected def removeParams(vwParams: String): String = + FlagsToRemove.foldLeft(vwParams)((s, r) => r.replaceAllIn(s, "")) protected def addParams( paramsAfterRemoved: String, @@ -287,25 +274,25 @@ protected trait VwMultilabelParamAugmentation { /** * Get the set of interactions (encoded as Strings). String length represents the * interaction arity. - * @param padded the padded version of the original parameters based to `updatedVwParams`. + * @param vwParams VW params passed to the `updatedVwParams` function. * @return */ - protected def interactions(padded: String): Set[String] = + protected def interactions(vwParams: String): Set[String] = List( QuadraticsShort, QuadraticsLong, Cubics, Interactions ).foldLeft(Set.empty[String]){(is, r) => - is ++ firstCaptureGroups(padded, r).map(s => s.sorted) + is ++ firstCaptureGroups(vwParams, r).map(s => s.sorted) } - protected def unrecoverableFlags(padded: String): Set[String] = - firstCaptureGroups(padded, UnrecoverableFlags).toSet + protected def unrecoverableFlags(vwParams: String): Set[String] = + firstCaptureGroups(vwParams, UnrecoverableFlags).toSet - protected def isQuiet(padded: String): Boolean = Quiet.findFirstIn(padded).nonEmpty - protected def ignored(padded: String): VWNsSet = charsIn(Ignore, padded) - protected def ignoredLinear(padded: String): VWNsSet = charsIn(IgnoreLinear, padded) + protected def isQuiet(vwParams: String): Boolean = Quiet.findFirstIn(vwParams).nonEmpty + protected def ignored(vwParams: String): VWNsSet = charsIn(Ignore, vwParams) + protected def ignoredLinear(vwParams: String): VWNsSet = charsIn(IgnoreLinear, vwParams) protected def handleClassCastException( orig: String, @@ -318,7 +305,6 @@ protected trait VwMultilabelParamAugmentation { } protected def flagsReferencingMissingNss( - padded: String, namespaceNames: Set[String], i: VWNsSet, il: VWNsSet, @@ -384,12 +370,12 @@ protected trait VwMultilabelParamAugmentation { /** * Find all of the regex matches and extract the first capture group from the match. - * @param padded the padded version of the original parameters based to `updatedVwParams`. + * @param vwParams VW params passed to the `updatedVwParams` function. * @param regex with at least one capture group (this is unchecked). * @return Iterator of the matches' first capture group. */ - protected def firstCaptureGroups(padded: String, regex: Regex): Iterator[String] = - regex.findAllMatchIn(padded).map(m => m.group(CaptureGroupWithContent)) + protected def firstCaptureGroups(vwParams: String, regex: Regex): Iterator[String] = + regex.findAllMatchIn(vwParams).map(m => m.group(CaptureGroupWithContent)) protected def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = r.findAllMatchIn(chrSeq).flatMap(m => m.group(CaptureGroupWithContent).toCharArray).toSet diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index ccddca6e..34b20124 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -67,6 +67,44 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { } } + @Test def testNamespaceErrors(): Unit = { + val args = "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd " + + "--cubic bcd --interactions dde --interactions abcde" + val updated = updatedVwParams(args, Set()) + + val exp = Left( + NamespaceError( + "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd --cubic bcd " + + "--interactions dde --interactions abcde", + Set(), + Map( + "ignore" -> Set('a'), + "ignore_linear" -> Set('b'), + "quadratic" -> Set('b', 'd'), + "cubic" -> Set('b', 'c', 'd', 'e'), + "interactions" -> Set('a', 'b', 'c', 'd', 'e') + ) + ) + ) + + assertEquals(exp, updated) + } + + @Test def testBadVwFlag(): Unit = { + val args = "--wap_ldf m --NO_A_VALID_VW_FLAG" + + val exp = VwError( + args, + "--wap_ldf m --NO_A_VALID_VW_FLAG --noconstant --csoaa_rank --ignore y", + "unrecognised option '--NO_A_VALID_VW_FLAG'" + ) + + VwMultilabelModel.updatedVwParams(args, Set.empty) match { + case Left(e) => assertEquals(exp, e) + case _ => fail() + } + } + @Test def testQuadraticCreation(): Unit = { val args = "--csoaa_ldf mc" val nss = Set("abc", "bcd") @@ -201,30 +239,4 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { case x => assertEquals("", x) } } - - @Test def testNamespaceErrors(): Unit = { - val args = "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd " + - "--cubic bcd --interactions dde --interactions abcde" - val updated = updatedVwParams(args, Set()) - - - val exp = Left( - NamespaceError( - "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd --cubic bcd " + - "--interactions dde --interactions abcde", - Set(), - Map( - "ignore" -> Set('a'), - "ignore_linear" -> Set('b'), - "quadratic" -> Set('b', 'd'), - "cubic" -> Set('b', 'c', 'd', 'e'), - "interactions" -> Set('a', 'b', 'c', 'd', 'e') - ) - ) - ) - - assertEquals(exp, updated) - } - // TODO: More VW argument augmentation tests!!! - } From a7b168ab886a7a736c09f01e6b401d9541b6f738 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 9 Oct 2017 15:56:39 -0700 Subject: [PATCH 79/98] Updated VW label NS algo. Added test for when a NS can't be found. --- .../vw/multilabel/VwMultilabelRowCreator.scala | 7 ++++--- .../VwMultilabelParamAugmentationTest.scala | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index 6c75df51..72b1dc7b 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -160,13 +160,14 @@ object VwMultilabelRowCreator { */ private[multilabel] def bruteForceNsSearch(usedNss: sci.BitSet): Option[LabelNamespaces] = { val found = - Stream - .from(FirstValidCharacter) + Iterator + .range(Char.MinValue, Char.MaxValue) .filter(c => !(usedNss contains c) && validCharForNamespace(c.toChar)) .take(2) + .toList found match { - case actual #:: dummy #:: Stream.Empty => + case actual :: dummy :: Nil => Option(LabelNamespaces(actual.toChar, dummy.toChar)) case _ => None } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 34b20124..5c1bbb3e 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -90,6 +90,19 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { assertEquals(exp, updated) } + @Test def testNoAvailableLabelNss(): Unit = { + // All namespaces taken. + val nss = (Char.MinValue to Char.MaxValue).map(_.toString).toSet + val validArgs = "--csoaa_ldf mc" + + VwMultilabelModel.updatedVwParams(validArgs, nss) match { + case Left(LabelNamespaceError(orig, nssOut)) => + assertEquals(validArgs, orig) + assertEquals(nss, nssOut) + case _ => fail() + } + } + @Test def testBadVwFlag(): Unit = { val args = "--wap_ldf m --NO_A_VALID_VW_FLAG" From 9c4362b7c1df7749e67de368dc661b084f374a22 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 10 Oct 2017 18:05:19 -0700 Subject: [PATCH 80/98] hacky solution to flags with options referencing files. Use tmp files :-( --- .../VwMultilabelParamAugmentation.scala | 80 +++++++++++++++++-- .../multilabel/VwMultilabelModelTest.scala | 22 ++--- .../VwMultilabelParamAugmentationTest.scala | 2 +- 3 files changed, 80 insertions(+), 24 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index 21706628..5ed4d1da 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -1,9 +1,11 @@ package com.eharmony.aloha.models.vw.jni.multilabel +import java.io.File + import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.LabelNamespaces import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.ExpectedLearner -import org.apache.commons.io.IOUtils +import org.apache.commons.io.{FileUtils, IOUtils} import vowpalWabbit.learner.VWLearners import scala.util.matching.Regex @@ -45,6 +47,10 @@ protected trait VwMultilabelParamAugmentation { a. For each interaction term (`-q`, `--quadratic`, `--cubic`, `--interactions`), replace it with an interaction term also interacted with the label namespace. This increases the arity of the interaction by 1. + 1. Finally, change the flag options that reference files to point to temp files so that + VW doesn't change the files. This may represent a problem if VW needs to read the file + in the option because although it should exist, it will be empty. + 1. Let VW doing any validations it can. * * ==Success Examples== * @@ -153,7 +159,13 @@ protected trait VwMultilabelParamAugmentation { ){ labelNs => val paramsWithoutRemoved = removeParams(vwParams) val updatedParams = addParams(paramsWithoutRemoved, namespaceNames, i, il, is, labelNs) - validateVwParams(vwParams, updatedParams, !isQuiet(updatedParams)) + + val (finalParams, flagToFileMap) = replaceFileBasedFlags(updatedParams, FileBasedFlags) + val ps = validateVwParams( + vwParams, updatedParams, finalParams, flagToFileMap, !isQuiet(updatedParams) + ) + flagToFileMap.values foreach FileUtils.deleteQuietly // IO: Delete the files. + ps } } } @@ -162,14 +174,25 @@ protected trait VwMultilabelParamAugmentation { * VW Flags automatically resulting in an error. */ protected val UnrecoverableFlagSet: Set[String] = - Set("redefine", "stage_poly", "keep", "permutations") + Set("redefine", "stage_poly", "keep", "permutations", "autolink") /** * This is the capture group containing the content when the regex has been * padded with the pad function. */ - private val CaptureGroupWithContent = 2 - + private[this] val CaptureGroupWithContent = 2 + + private[this] val FileBasedFlags = Set( + "-f", "--final_regressor", + "--readable_model", + "--invert_hash", + "--output_feature_regularizer_binary", + "--output_feature_regularizer_text", + "-p", "--predictions", + "-r", "--raw_predictions", + "-c", "--cache", + "--cache_file" + ) /** * Pad the regular expression with a prefix and suffix that makes matching work. @@ -200,6 +223,7 @@ protected trait VwMultilabelParamAugmentation { private[this] val ConstantShort = pad("""-C\s*(""" + NumRegex + ")").r private[this] val ConstantLong = pad("""--constant\s+(""" + NumRegex + ")").r + private[this] val FlagsToRemove = Seq( QuadraticsShort, QuadraticsLong, @@ -259,6 +283,47 @@ protected trait VwMultilabelParamAugmentation { (paramsAfterRemoved.trim + additions).trim } + /** + * VW will actually update / replace files if files appear as options to flags. To overcome + * this, an attempt is made to detect flags referencing files and if found, replace the the + * files with temp files. These files should be deleted before exiting the main program. + * @param updatedParams the parameters after the updates. + * @param flagsWithFiles the flag + * @return a tuple2 of the final string to try with VW for validation along with the mapping + * from flag to file that was used. + */ + protected def replaceFileBasedFlags(updatedParams: String, flagsWithFiles: Set[String]): (String, Map[String, File]) = { + // This is rather hairy function. + + def flagRe(flags: Set[String], groupsForFlag: Int, c1: String, c2: String, c3: String) = + if (flags.nonEmpty) + Option(pad(flags.map(_ drop groupsForFlag).toVector.sorted.mkString(c1, c2, c3)).r) + else None + + // Get short and long flags. + val shrt = flagsWithFiles.filter(s => s.startsWith("-") && 2 == s.length && s.charAt(1).isLetterOrDigit) + val lng = flagsWithFiles.filter(s => s.startsWith("--") && 2 < s.length) + + val regexes = List( + flagRe(shrt, 1, "(-[", "", """])\s*(\S+)"""), + flagRe(lng, 2, "(--(", "|", """))\s+(\S+)""") + ).flatten + + regexes.foldLeft((updatedParams, Map[String, File]())) { case ((ps, ffm), r) => + // Fold right to not affect subsequent replacements. + r.findAllMatchIn(ps).foldRight((ps, ffm)) { case (m, (ps1, ffm1)) => + val f = File.createTempFile("REPLACED_", "_FILE") + f.deleteOnExit() // Attempt to be safe here. + + val flag = m.group(CaptureGroupWithContent) + val rep = s"$flag ${f.getCanonicalPath}" + val ps2 = ps1.take(m.start) + rep + ps1.drop(m.end) + val ffm2 = ffm1 + (flag -> f) + (ps2, ffm2) + } + } + } + protected def createLabelInteractions( interactions: Set[String], ignored: VWNsSet, @@ -348,12 +413,15 @@ protected trait VwMultilabelParamAugmentation { } } + // TODO: Change file protected def validateVwParams( orig: String, mod: String, + finalPs: String, + flagToFileMap: Map[String, File], addQuiet: Boolean ): Either[VwParamError, String] = { - val ps = if (addQuiet) s"--quiet $mod" else mod + val ps = if (addQuiet) s"--quiet $finalPs" else finalPs Try { VWLearners.create[ExpectedLearner](ps) } match { case Success(m) => diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 68a46fc5..77946d93 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -121,33 +121,17 @@ class VwMultilabelModelTest { // Prepare parameters for VW model that will be trained. // ------------------------------------------------------------------------------------ - val binaryVwModel = File.createTempFile("vw_", ".bin.model") binaryVwModel.deleteOnExit() val cacheFile = File.createTempFile("vw_", ".cache") cacheFile.deleteOnExit() - // probs / costs - // --csoaa_ldf mc vs m - // --csoaa_rank - // --loss_function logistic add this - // --noconstant - // --ignore_linear X - // --ignore $dummyLabelNs - // turn first order features into quadratics - // turn quadratic features into cubics - // - val vwParams = + val origParams = s""" | --quiet | --csoaa_ldf mc - | --csoaa_rank | --loss_function logistic - | -q ${labelNs}X - | --noconstant - | --ignore_linear X - | --ignore $dummyLabelNs | -f ${binaryVwModel.getCanonicalPath} | --passes 50 | --cache_file ${cacheFile.getCanonicalPath} @@ -156,6 +140,10 @@ class VwMultilabelModelTest { | --decay_learning_rate 0.9 """.stripMargin.trim.replaceAll("\n", " ") + val vwParams = VwMultilabelModel.updatedVwParams(origParams, Set("X")) fold ( + e => throw new Exception(e.errorMessage), + ps => ps + ) // ------------------------------------------------------------------------------------ // Train VW model diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 5c1bbb3e..0da75dc3 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -22,7 +22,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testExpectedUnrecoverableFlags(): Unit = { assertEquals( "Unrecoverable flags has changed.", - Set("redefine", "stage_poly", "keep", "permutations"), + Set("redefine", "stage_poly", "keep", "permutations", "autolink"), UnrecoverableFlagSet ) } From 6df3f836950dc5a622646c0c78decb67e86beeb0 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 30 Oct 2017 16:28:29 -0700 Subject: [PATCH 81/98] Looks good. --- .../VwMultilabelParamAugmentation.scala | 2 +- .../multilabel/VwMultilabelModelTest.scala | 12 ++++++- .../VwMultilabelParamAugmentationTest.scala | 34 ++++++++++++++++++- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index 5ed4d1da..dfa408be 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -180,7 +180,7 @@ protected trait VwMultilabelParamAugmentation { * This is the capture group containing the content when the regex has been * padded with the pad function. */ - private[this] val CaptureGroupWithContent = 2 + protected val CaptureGroupWithContent = 2 private[this] val FileBasedFlags = Set( "-f", "--final_regressor", diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 77946d93..df2529dd 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -26,6 +26,7 @@ import spray.json.{DefaultJsonProtocol, JsonWriter, RootJsonFormat, pimpString} import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} import scala.annotation.tailrec +import scala.collection.breakOut import scala.util.Random /** @@ -117,6 +118,7 @@ class VwMultilabelModelTest { new VwMultilabelRowCreator.Producer[Dom, Lab](labelsInTrainingSet). getRowCreator(semantics, datasetJson).get + // ------------------------------------------------------------------------------------ // Prepare parameters for VW model that will be trained. // ------------------------------------------------------------------------------------ @@ -140,11 +142,17 @@ class VwMultilabelModelTest { | --decay_learning_rate 0.9 """.stripMargin.trim.replaceAll("\n", " ") - val vwParams = VwMultilabelModel.updatedVwParams(origParams, Set("X")) fold ( + // Get the namespace names from the row creator. + val nsNames = rc.namespaces.map(_._1)(breakOut): Set[String] + + // Take the parameters and augment with additional parameters to make + // multilabel w/ probabilities work correctly. + val vwParams = VwMultilabelModel.updatedVwParams(origParams, nsNames) fold ( e => throw new Exception(e.errorMessage), ps => ps ) + // ------------------------------------------------------------------------------------ // Train VW model // ------------------------------------------------------------------------------------ @@ -156,6 +164,7 @@ class VwMultilabelModelTest { } vwLearner.close() + // ------------------------------------------------------------------------------------ // Create Aloha model JSON // ------------------------------------------------------------------------------------ @@ -180,6 +189,7 @@ class VwMultilabelModelTest { val modelTry = factory.fromString(modelJson.prettyPrint) // Use `.compactPrint` in prod. val model = modelTry.get + // ------------------------------------------------------------------------------------ // Test Aloha Model // ------------------------------------------------------------------------------------ diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 0da75dc3..02f6995a 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -1,5 +1,6 @@ package com.eharmony.aloha.models.vw.jni.multilabel +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith @@ -238,7 +239,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { } } - @Test def interactionsWithSelf(): Unit = { + @Test def testInteractionsWithSelf(): Unit = { val nss = Set("a") val args = "--wap_ldf m -qaa --cubic aaa --interactions aaaa" val exp = "--wap_ldf m --noconstant --csoaa_rank --ignore y --ignore_linear a " + @@ -252,4 +253,35 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { case x => assertEquals("", x) } } + + @Test def testClassLabels(): Unit = { + val args = "--wap_ldf m" + val nss = Set.empty[String] + + VwMultilabelModel.updatedVwParams(args, nss) match { + case Left(_) => fail() + case Right(p) => + val ignored = + VwMultilabelRowCreator.determineLabelNamespaces(nss) match { + case None => fail() + case Some(labelNs) => + val ignoredNs = + Ignore + .findAllMatchIn(p) + .map(m => m.group(CaptureGroupWithContent)) + .reduce(_ + _) + .toCharArray + .toSet + + assertEquals(Set('y'), ignoredNs) + labelNs.labelNs + } + + assertEquals('Y', ignored) + } + } } + +object VwMultilabelParamAugmentationTest { + +} \ No newline at end of file From cd500069d13f7adaa56ca0c1c5a96d347e01c065 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 1 Nov 2017 10:58:07 -0700 Subject: [PATCH 82/98] Precompute positive and negative dummy class strings. --- .../multilabel/VwMultilabelRowCreator.scala | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index 72b1dc7b..a862da2d 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -32,6 +32,14 @@ final case class VwMultilabelRowCreator[-A, K]( @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap + // Precompute these for efficiency rather recompute than inside a hot loop. + // Notice these are not lazy vals. + + private[this] val negativeDummyStr = + s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature" + + private[this] val positiveDummyStr = + s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature" override def apply(a: A): (MissingAndErroneousFeatureInfo, Array[String]) = { val (missingAndErrs, features) = featuresFunction(a) @@ -47,7 +55,8 @@ final case class VwMultilabelRowCreator[-A, K]( defaultNamespace, namespaces, classNs, - dummyClassNs + negativeDummyStr, + positiveDummyStr ) (missingAndErrs, x) @@ -62,7 +71,7 @@ object VwMultilabelRowCreator { * dummy classes uses an ID outside of the allowable range of feature indices: * 2^32^. */ - private[this] val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString + private val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString /** * VW allows long-based feature indices, but Aloha only allow's 32-bit indices @@ -70,7 +79,7 @@ object VwMultilabelRowCreator { * dummy classes uses an ID outside of the allowable range of feature indices: * 2^32^ + 1. */ - private[this] val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString + private val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString /** * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the @@ -78,7 +87,7 @@ object VwMultilabelRowCreator { * As such, the ''reward'' of a positive example is designated to be one, * so the cost (or negative reward) is -1. */ - private[this] val PositiveCost = (-1).toString + private val PositiveCost = (-1).toString /** * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the @@ -86,12 +95,11 @@ object VwMultilabelRowCreator { * As such, the ''reward'' of a negative example is designated to be zero, * so the cost (or negative reward) is 0. */ - private[this] val NegativeCost = 0.toString - - private[this] val PositiveDummyClassFeature = "P" + private val NegativeCost = 0.toString - private[this] val NegativeDummyClassFeature = "N" + private val PositiveDummyClassFeature = "P" + private val NegativeDummyClassFeature = "N" /** * "shared" is a special keyword in VW multi-class (multi-row) format. @@ -101,8 +109,6 @@ object VwMultilabelRowCreator { */ private[this] val SharedFeatureIndicator = "shared" + " " - private[this] val FirstValidCharacter = 0 // Could probably be '0'.toInt - private[this] val PreferredLabelNamespaces = Seq(('Y', 'y'), ('Z', 'z'), ('Λ', 'λ')) /** @@ -173,7 +179,6 @@ object VwMultilabelRowCreator { } } - /** * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. * @param features (non-label dependent) features shared across all labels. @@ -196,7 +201,8 @@ object VwMultilabelRowCreator { defaultNs: List[Int], namespaces: List[(String, List[Int])], classNs: Char, - dummyClassNs: Char + negativeDummyStr: String, + positiveDummyStr: String ): Array[String] = { val n = indices.size @@ -213,9 +219,8 @@ object VwMultilabelRowCreator { // These string interpolations are computed over and over but will always be the same // for a given dummyClassNs. - // TODO: Precompute these in a class instance and pass in as parameters. - x(1) = s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature" - x(2) = s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature" + x(1) = negativeDummyStr + x(2) = positiveDummyStr // This is mutable because we want speed. var i = 0 @@ -231,7 +236,6 @@ object VwMultilabelRowCreator { x } - /** * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. * @param features (non-label dependent) features shared across all labels. From ffd10ee620fc2f9cc86d9066025bc8a71890b39f Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 1 Nov 2017 15:26:34 -0700 Subject: [PATCH 83/98] Adding numUniqueLabels parameter to updatedVwParams to add VW's --ring_size parameter --- .../VwMultilabelParamAugmentation.scala | 21 +++++-- .../multilabel/VwMultilabelModelTest.scala | 4 +- .../VwMultilabelParamAugmentationTest.scala | 59 ++++++++++--------- 3 files changed, 49 insertions(+), 35 deletions(-) diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala index dfa408be..c841ba5d 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -131,9 +131,16 @@ protected trait VwMultilabelParamAugmentation { * @param namespaceNames it is assumed that `namespaceNames` is a superset * of all of the namespaces referred to by any flags * found in `vwParams`. + * @param numUniqueLabels the number of unique labels in the training set. + * This is used to calculate the appropriate VW + * `ring_size` parameter. * @return */ - def updatedVwParams(vwParams: String, namespaceNames: Set[String]): Either[VwParamError, String] = { + def updatedVwParams( + vwParams: String, + namespaceNames: Set[String], + numUniqueLabels: Int + ): Either[VwParamError, String] = { lazy val unrecovFlags = unrecoverableFlags(vwParams) if (WapOrCsoaa.findFirstMatchIn(vwParams).isEmpty) @@ -158,9 +165,11 @@ protected trait VwMultilabelParamAugmentation { Left(LabelNamespaceError(vwParams, namespaceNames)): Either[VwParamError, String] ){ labelNs => val paramsWithoutRemoved = removeParams(vwParams) - val updatedParams = addParams(paramsWithoutRemoved, namespaceNames, i, il, is, labelNs) + val updatedParams = + addParams(paramsWithoutRemoved, namespaceNames, i, il, is, labelNs, numUniqueLabels) val (finalParams, flagToFileMap) = replaceFileBasedFlags(updatedParams, FileBasedFlags) + val ps = validateVwParams( vwParams, updatedParams, finalParams, flagToFileMap, !isQuiet(updatedParams) ) @@ -223,7 +232,6 @@ protected trait VwMultilabelParamAugmentation { private[this] val ConstantShort = pad("""-C\s*(""" + NumRegex + ")").r private[this] val ConstantLong = pad("""--constant\s+(""" + NumRegex + ")").r - private[this] val FlagsToRemove = Seq( QuadraticsShort, QuadraticsLong, @@ -251,7 +259,8 @@ protected trait VwMultilabelParamAugmentation { oldIgnored: VWNsSet, oldIgnoredLinear: VWNsSet, oldInteractions: Set[String], - labelNs: LabelNamespaces + labelNs: LabelNamespaces, + numUniqueLabels: Int ): String = { val i = oldIgnored + labelNs.dummyLabelNs @@ -274,11 +283,13 @@ protected trait VwMultilabelParamAugmentation { val ints = hos.toSeq.sorted.map(ho => s"--interactions $ho").mkString(" ") val igLin = if (il.nonEmpty) il.toSeq.sorted.mkString("--ignore_linear ", "", "") else "" + val rs = s"--ring_size ${numUniqueLabels + VwSparseMultilabelPredictor.AddlVwRingSize}" + // This is non-empty b/c i is non-empty. val ig = s"--ignore ${i.mkString("")}" // Consolidate whitespace because there shouldn't be whitespace in these flags' options. - val additions = s" --noconstant --csoaa_rank $ig $igLin $quadratics $cubics $ints" + val additions = s" --noconstant --csoaa_rank $rs $ig $igLin $quadratics $cubics $ints" .replaceAll("\\s+", " ") (paramsAfterRemoved.trim + additions).trim } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index df2529dd..ade05aa9 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -147,7 +147,7 @@ class VwMultilabelModelTest { // Take the parameters and augment with additional parameters to make // multilabel w/ probabilities work correctly. - val vwParams = VwMultilabelModel.updatedVwParams(origParams, nsNames) fold ( + val vwParams = VwMultilabelModel.updatedVwParams(origParams, nsNames, 2) fold ( e => throw new Exception(e.errorMessage), ps => ps ) @@ -362,7 +362,7 @@ object VwMultilabelModelTest { private val LabelNamespaces(labelNs, dummyLabelNs) = VwMultilabelRowCreator.determineLabelNamespaces(Set.empty).get - private def vwTrainingParams(modelFile: File = tmpFile()) = { + private def vwTrainingParams(modelFile: File) = { // NOTES: // 1. `--csoaa_rank` is needed by VW to make a VWActionScoresLearner. diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 02f6995a..900881d2 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -11,10 +11,11 @@ import org.junit.runners.BlockJUnit4ClassRunner */ @RunWith(classOf[BlockJUnit4ClassRunner]) class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { + import VwMultilabelParamAugmentationTest._ @Test def testNotCsoaaWap(): Unit = { val args = "" - VwMultilabelModel.updatedVwParams(args, Set.empty) match { + VwMultilabelModel.updatedVwParams(args, Set.empty, DefaultNumLabels) match { case Left(NotCsoaaOrWap(ps)) => assertEquals(args, ps) case _ => fail() } @@ -30,7 +31,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testUnrecoverable(): Unit = { val unrec = UnrecoverableFlagSet.iterator.map { f => - VwMultilabelModel.updatedVwParams(s"--csoaa_ldf mc --$f", Set.empty) + VwMultilabelModel.updatedVwParams(s"--csoaa_ldf mc --$f", Set.empty, DefaultNumLabels) }.toList unrec foreach { @@ -47,7 +48,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testIgnoredNotInNsSet(): Unit = { val args = "--csoaa_ldf mc --ignore a" val origNss = Set.empty[String] - VwMultilabelModel.updatedVwParams(args, origNss) match { + VwMultilabelModel.updatedVwParams(args, origNss, DefaultNumLabels) match { case Left(NamespaceError(o, nss, bad)) => assertEquals(args, o) assertEquals(origNss, nss) @@ -59,7 +60,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testIgnoredNotInNsSet2(): Unit = { val args = "--csoaa_ldf mc --ignore ab" val origNss = Set("a") - VwMultilabelModel.updatedVwParams(args, origNss) match { + VwMultilabelModel.updatedVwParams(args, origNss, DefaultNumLabels) match { case Left(NamespaceError(o, nss, bad)) => assertEquals(args, o) assertEquals(origNss, nss) @@ -71,7 +72,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testNamespaceErrors(): Unit = { val args = "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd " + "--cubic bcd --interactions dde --interactions abcde" - val updated = updatedVwParams(args, Set()) + val updated = updatedVwParams(args, Set(), DefaultNumLabels) val exp = Left( NamespaceError( @@ -96,7 +97,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { val nss = (Char.MinValue to Char.MaxValue).map(_.toString).toSet val validArgs = "--csoaa_ldf mc" - VwMultilabelModel.updatedVwParams(validArgs, nss) match { + VwMultilabelModel.updatedVwParams(validArgs, nss, DefaultNumLabels) match { case Left(LabelNamespaceError(orig, nssOut)) => assertEquals(validArgs, orig) assertEquals(nss, nssOut) @@ -109,11 +110,11 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { val exp = VwError( args, - "--wap_ldf m --NO_A_VALID_VW_FLAG --noconstant --csoaa_rank --ignore y", + s"--wap_ldf m --NO_A_VALID_VW_FLAG --noconstant --csoaa_rank $DefaultRingSize --ignore y", "unrecognised option '--NO_A_VALID_VW_FLAG'" ) - VwMultilabelModel.updatedVwParams(args, Set.empty) match { + VwMultilabelModel.updatedVwParams(args, Set.empty, DefaultNumLabels) match { case Left(e) => assertEquals(exp, e) case _ => fail() } @@ -124,9 +125,9 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { val nss = Set("abc", "bcd") // Notice: ignore_linear and quadratics are in sorted order. - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + "--ignore_linear ab -qYa -qYb" - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -137,10 +138,10 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { val nss = Set("abc", "bcd") // Notice: ignore_linear and quadratics are in sorted order. - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + "--ignore_linear ab -qYb" - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -149,13 +150,13 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testCubicCreation(): Unit = { val args = "--csoaa_ldf mc -qab --quadratic cb" val nss = Set("abc", "bcd", "cde", "def") - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + "--ignore_linear abcd " + "-qYa -qYb -qYc -qYd " + "--cubic Yab --cubic Ybc" // Notice: ignore_linear and quadratics are in sorted order. - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -164,13 +165,13 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testCubicCreationIgnoredLinear(): Unit = { val args = "--csoaa_ldf mc -qab --quadratic cb --ignore_linear d" val nss = Set("abc", "bcd", "cde", "def") - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + "--ignore_linear abcd " + "-qYa -qYb -qYc " + "--cubic Yab --cubic Ybc" // Notice: ignore_linear and quadratics are in sorted order. - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -179,13 +180,13 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testCubicCreationIgnored(): Unit = { val args = "--csoaa_ldf mc -qab --quadratic cb --ignore c" val nss = Set("abc", "bcd", "cde", "def") - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore cy " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore cy " + "--ignore_linear abd " + "-qYa -qYb -qYd " + "--cubic Yab" // Notice: ignore_linear and quadratics are in sorted order. - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -194,13 +195,13 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testCubicWithInteractionsCreationIgnored(): Unit = { val args = "--csoaa_ldf mc --interactions ab --interactions cb --ignore c --ignore d" val nss = Set("abc", "bcd", "cde", "def") - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore cdy " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore cdy " + "--ignore_linear ab " + "-qYa -qYb " + "--cubic Yab" // Notice: ignore_linear and quadratics are in sorted order. - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -209,11 +210,11 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testHigherOrderInteractions(): Unit = { val args = "--csoaa_ldf mc --interactions abcd --ignore_linear abcd" val nss = Set("abc", "bcd", "cde", "def") - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + "--ignore_linear abcd " + "--interactions Yabcd" - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -225,7 +226,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { val args = s"--csoaa_ldf mc --interactions ab --interactions abc " + "--interactions abcd --interactions abcde" - val exp = "--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + "--ignore_linear abcde " + "-qYa -qYb -qYc -qYd -qYe " + "--cubic Yab " + @@ -233,7 +234,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { "--interactions Yabcd " + "--interactions Yabcde" - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case _ => fail() } @@ -242,13 +243,13 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { @Test def testInteractionsWithSelf(): Unit = { val nss = Set("a") val args = "--wap_ldf m -qaa --cubic aaa --interactions aaaa" - val exp = "--wap_ldf m --noconstant --csoaa_rank --ignore y --ignore_linear a " + + val exp = s"--wap_ldf m --noconstant --csoaa_rank $DefaultRingSize --ignore y --ignore_linear a " + "-qYa " + "--cubic Yaa " + "--interactions Yaaa " + "--interactions Yaaaa" - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Right(s) => assertEquals(exp, s) case x => assertEquals("", x) } @@ -258,7 +259,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { val args = "--wap_ldf m" val nss = Set.empty[String] - VwMultilabelModel.updatedVwParams(args, nss) match { + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { case Left(_) => fail() case Right(p) => val ignored = @@ -283,5 +284,7 @@ class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { } object VwMultilabelParamAugmentationTest { - + val DefaultNumLabels = 0 + val DefaultRingSize = + s"--ring_size ${DefaultNumLabels + VwSparseMultilabelPredictor.AddlVwRingSize}" } \ No newline at end of file From 9592b09ea04f48fef92f5d4e33b3c22d64931951 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 3 Nov 2017 16:03:56 -0700 Subject: [PATCH 84/98] stateful row creator and reservoir sampling. --- .../aloha/dataset/StatefulRowCreator.scala | 14 ++ .../com/eharmony/aloha/util/rand/Rand.scala | 120 ++++++++++++++++++ .../eharmony/aloha/util/rand/RandTest.scala | 97 ++++++++++++++ 3 files changed, 231 insertions(+) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala create mode 100644 aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala new file mode 100644 index 00000000..57705ac2 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -0,0 +1,14 @@ +package com.eharmony.aloha.dataset + +/** + * Created by ryan.deak on 11/2/17. + */ +trait StatefulRowCreator[-A, +B, S] extends Serializable { + def initialState: S + + def apply(a: A, s: S): ((MissingAndErroneousFeatureInfo, Option[B]), S) + + def apply(as: Iterator[A], s: S): Iterator[(MissingAndErroneousFeatureInfo, Option[B])] + + def apply(as: Vector[A], s: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala new file mode 100644 index 00000000..9ed0780e --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala @@ -0,0 +1,120 @@ +package com.eharmony.aloha.util.rand + +/** + * Some stateless random sampling utilities. + * + * @author deaktator + * @since 11/3/2017 + */ +private[aloha] trait Rand { + + type Seed = Long + type Index = Short + + /** + * Perform the initial scramble. This should be called '''''once''''' on the initial + * seed prior to the first call to `sampleCombination`. + * @param seed an initial seed + * @return a more scrambled seed. + */ + protected def initSeedScramble(seed: Seed): Seed = + (seed ^ 0x5DEECE66DL) & 0xFFFFFFFFFFFFL + + + /** + * Sample a ''k''-combination from a population of ''n''. + * + * This algorithm uses a linear congruential pseudorandom number generator (see Knuth) + * to perform reservoir sampling via "''Algorithm R''". + * + * It is ~ '''''O'''''('''''n'''''). + * + * If `n` ≤ `k`, then return 0, ..., `n` - 1; otherwise, if `k` < `n`, the + * returned array have length `k` with values between 0 and `n - 1` (inclusive) + * but it is '''NOT''' guaranteed to be sorted. + * + * '''NOTE''': This is a pure function. It produces the same results as if + * `java.util.Random` was used to perform reservoir sampling but since it doesn't + * carry state, this can be trivially operated in parallel with no locking or CAS + * loop overhead. The consequence is that the `seed` must be provided on every call + * and a new seed will be returned as part of the output. + * + * To get this function to act like `java.util.Random`, the first time it is called, the + * seed should be produce by running the desired seed through `initSeedScramble`. For + * instance: + * + * {{{ + * val (kComb1, newSeed1) = sampleCombination(4, 2, initSeedScramble(0)) + * val (kComb2, newSeed2) = sampleCombination(4, 2, newSeed1) + * }}} + * + * For more information, see: + * + - [[http://luc.devroye.org/chapter_twelve.pdf Non-Uniform Random Variate Generation, Luc Devroye, Ch. 12]] + - [[https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_R Reservoir sampling (Wikipedia)]] + - [[http://grepcode.com/file/repository.grepcode.com/java/root/jdk/openjdk/6-b14/java/util/Random.java grepcode for Random]] + - [[https://en.wikipedia.org/wiki/Linear_congruential_generator Linear congruential generator (Wikipedia)]] + * + * @param n population size (< 2^15^). + * @param k combination size (< 2^15^). + * @param seed the seed to use for random selection + * @return a tuple 2 containing the array of 0-based indices representing + * the ''k''-combination and a new random seed. + */ + protected def sampleCombination(n: Int, k: Int, seed: Seed): (Array[Index], Seed) = { + + // NOTE: This isn't idiomatic Scala code but it will operate in hot loops in minimizing + // object creation is important. + + if (n <= k) { + ((0 until n).map(i => i.toShort).toArray, seed) + } + else { + var i = k + 1 + var nextSeed = seed + var reservoirSwapInd = 0 + var bits = 0 + var value = 0 + + // Fill reservoir with the first k indices. + val reservoir = (0 until k).map(i => i.toShort).toArray + + // Loop over the rest of the indices outside the reservoir and determine if + // swapping should occur. If so, swap the index in the reservoir with the + // current element, i - 1. + while (i <= n) { + reservoirSwapInd = + if ((i & -i) == i) { + // i = 2^j, for some j. + + // To understand these constants, see + // https://en.wikipedia.org/wiki/Linear_congruential_generator + nextSeed = (nextSeed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL + + // 17 = log2(modulus) - shift = 48 - 31 + ((i * (nextSeed >>> 17)) >> 31).toInt + } + else { + // Loop at least once per swap index. + do { + nextSeed = (nextSeed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL + bits = (nextSeed >>> 17).toInt + value = bits % i + } while (bits - value + (i - 1) < 0) + + value + } + + // This is the key to maintaining the proper probabilities in the reservoir sampling. + // Wikipedia does a good job describing this in the proof by induction in the + // explanation for Algorithm R. + if (reservoirSwapInd < k) + reservoir(reservoirSwapInd) = (i - 1).toShort + + i += 1 + } + + (reservoir, nextSeed) + } + } +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala new file mode 100644 index 00000000..02212d0b --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala @@ -0,0 +1,97 @@ +package com.eharmony.aloha.util.rand + +import java.util.Random + +import org.junit.Assert.assertEquals +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +import scala.annotation.tailrec + +/** + * Test the randomness of the random! + * + * @author deaktator + * @since 11/3/2017 + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class RandTest extends Rand { + import RandTest._ + + @Test def testSampleCombinationProbabilities(): Unit = { + val trials = 20 + val maxN = 5 + val r = new Random(0x105923abdd8L) + + val results = Iterator.fill(trials){ + val initSeed = r.nextLong() + val samples = 1000 + r.nextInt(9001) + val n = r.nextInt(maxN + 1) + val k = r.nextInt(n + 1) + val sd = samplingDist(initSeed, samples, n, k) + (samples, n, k, sd) + } + + checkResult(results) + } + + /** + * For any `n` and `k`, `choose(n, k)` states should be sampled with uniform probability. + * @param results different `numSamples`, `n`, `k`, and ''sample probabilities'' for + * each state sampled. + * @tparam A type of random variable. + */ + private def checkResult[A](results: Iterator[(NumSamples, N, K, Distribution[A])]): Unit = { + results.foreach { case (samples, n, k, statesAndProbs) => + // This can be driven down by increasing `samples`. + // This is proportional to 1 / sqrt(samples). + // So, if samples is 10000, this is about 10%. + val pctDiffFromExpPr = 100 * (10 / math.sqrt(samples)) + + val expStates = choose(n, k) + val expPr = 1d / expStates + + // Check that all states are sampled. + assertEquals(expStates, statesAndProbs.size) + + statesAndProbs.foreach { case (state, pr) => + // Check probabilities are within reason. + assertEquals(s"for key '$state'", expPr, pr, expPr * (pctDiffFromExpPr / 100)) + } + } + } + + private def samplingDist(initSeed: Long, samples: Int, n: Int, k: Int): Distribution[String] = + drawSamples(initSeed, samples, n, k).mapValues { c => c / samples.toDouble } + + private def drawSamples(initSeed: Long, samples: Int, n: Int, k: Int): Map[String, NumSamples] = { + val seed = initSeedScramble(initSeed) + + val samplesAndSeed = + (1 to samples) + .foldLeft((Map.empty[String, Int], seed)){ case ((m, s), _) => + val (ind, newSeed) = sampleCombination(n, k, s) + val key = ind.sorted.mkString(",") + val updatedMap = m + (key -> (m.getOrElse(key, 0) + 1)) + (updatedMap, newSeed) + } + + // Throw away final seed. + samplesAndSeed._1 + } + + private def choose(n: Int, k: Int): Long = { + @tailrec def fact(n: Int, p: Long = 1): Long = + if (n <= 1) p else fact(n - 1, n * p) + + fact(n) / (fact(k) * fact(n-k)) + } +} + +object RandTest { + private type NumSamples = Int + private type N = Int + private type K = Int + private type Distribution[A] = Map[A, Double] +} From 8e478f3827c80fae20b96df0e0ef1716fe40f1f4 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Fri, 3 Nov 2017 17:21:08 -0700 Subject: [PATCH 85/98] provided concrete implementations of iterator and vector apply method. --- .../aloha/dataset/StatefulRowCreator.scala | 72 +++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala index 57705ac2..88067511 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -1,14 +1,78 @@ package com.eharmony.aloha.dataset /** + * A row creator that requires state. This state should be modeled functionally, meaning + * implementations should be referentially transparent. + * * Created by ryan.deak on 11/2/17. */ trait StatefulRowCreator[-A, +B, S] extends Serializable { - def initialState: S - def apply(a: A, s: S): ((MissingAndErroneousFeatureInfo, Option[B]), S) + /** + * Some initial state that can be used on the very first call to `apply(A, S)`. + * @return some state. + */ + val initialState: S - def apply(as: Iterator[A], s: S): Iterator[(MissingAndErroneousFeatureInfo, Option[B])] + /** + * Given an `a` and some `state`, produce output, including a new state. + * + * When using this function, the user is responsible for keeping track of, + * and providing the state. + * + * The implementation of this function should be referentially transparent. + * + * @param a input + * @param state the state + * @return a tuple where the first element is a Tuple2 whose first element is + * missing and error information and second element is an optional result. + * The second element of the outer Tuple2 is the new state. + */ + def apply(a: A, state: S): ((MissingAndErroneousFeatureInfo, Option[B]), S) - def apply(as: Vector[A], s: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S) + /** + * Apply the `apply(A, S)` method to the elements of the iterator. In the first + * application of `apply(A, S)`, `state` will be used as the state. In subsequent + * applications, the state will come from the state generated in the output of the + * previous application of `apply(A, S)`. + * + * @param as Note the first element of `as` will be forced in this method in order + * to construct the output. + * @param state the initial state to use at the start of the iterator. + * @return an iterator containing the a mapped to a + * `(MissingAndErroneousFeatureInfo, Option[B])` along with the resulting + * state that is created in the process. + */ + def apply(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = { + if (as.isEmpty) + Iterator.empty + else { + // Force the first A. Then apply the `apply` transformation to get + // the initial element of a scanLeft. Inside the scanLeft, use the + // state outputted by previous `apply` calls as input to current + // calls to `apply(A, S)`. + val firstA = as.next() + val initEl = apply(firstA, state) + as.scanLeft(initEl){ case ((_, mostRecentState), a) => apply(a, mostRecentState) } + } + } + + /** + * Apply the `apply(A, S)` method to the elements of the Vector. In the first + * application of `apply(A, S)`, `state` will be used as the state. In subsequent + * applications, the state will come from the state generated in the output of the + * previous application of `apply(A, S)`. + * + * @param as input to map. + * @param state the initial state to use at the start of the Vector. + * @return a Tuple2 where the first element is a vector of results and the second + * element is the resulting state. + */ + def apply(as: Vector[A], state: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S) = { + as.foldLeft((Vector.empty[(MissingAndErroneousFeatureInfo, Option[B])], state)){ + case ((bs, s), a) => + val (b, newS) = apply(a, s) + (bs :+ b, newS) + } + } } From b7594fa4efe1046c1a6156448d851aa1749fae71 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Sat, 4 Nov 2017 10:23:21 -0700 Subject: [PATCH 86/98] more purity in test. --- .../eharmony/aloha/util/rand/RandTest.scala | 110 +++++++++++++----- 1 file changed, 80 insertions(+), 30 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala index 02212d0b..aed115fc 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala @@ -2,8 +2,8 @@ package com.eharmony.aloha.util.rand import java.util.Random -import org.junit.Assert.assertEquals import org.junit.Test +import org.junit.Assert.fail import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner @@ -20,46 +20,88 @@ class RandTest extends Rand { import RandTest._ @Test def testSampleCombinationProbabilities(): Unit = { - val trials = 20 - val maxN = 5 - val r = new Random(0x105923abdd8L) + val failures = findFailures( + trials = 25, + maxN = 6, + minSamples = 1000, + maxSamples = 10000, + seed = 0 + ) + + reportFailures(failures) + } + + private def findFailures(trials: Int, maxN: Int, minSamples: Int, maxSamples: Int, seed: Long) = { + val r = new Random(seed) - val results = Iterator.fill(trials){ + Iterator.fill(trials){ val initSeed = r.nextLong() - val samples = 1000 + r.nextInt(9001) + val samples = minSamples + r.nextInt(maxSamples - minSamples + 1) val n = r.nextInt(maxN + 1) val k = r.nextInt(n + 1) - val sd = samplingDist(initSeed, samples, n, k) - (samples, n, k, sd) + (initSeed, samples, n, k) + } flatMap { case (initSeed, samples, n, k) => + val dist = samplingDist(initSeed, samples, n, k) + checkDistributionUniformity(samples, n, k, dist).toIterable } + } + + private def reportFailures[A](failures: Iterator[TestFailure[A]]): Unit = { + if (failures.nonEmpty) { + val errMsg = + failures.foldLeft(""){ case (msg, TestFailure(samples, n, k, fails, dist)) => - checkResult(results) + val thisFail = + s"For (n: $n, k: $k, samples: $samples), produced distribution: $dist. Failures:" + + fails.mkString("\n\t", "\n\t", "\n\n") + + msg + thisFail + } + fail(errMsg) + } } /** * For any `n` and `k`, `choose(n, k)` states should be sampled with uniform probability. - * @param results different `numSamples`, `n`, `k`, and ''sample probabilities'' for - * each state sampled. + * @param samples number of samples drawn + * @param n number of objects from which to choose. + * @param k number of elements drawn from the `n` objects. + * @param dist distribution created from `samples` ''samples''. * @tparam A type of random variable. + * @return a potential error. */ - private def checkResult[A](results: Iterator[(NumSamples, N, K, Distribution[A])]): Unit = { - results.foreach { case (samples, n, k, statesAndProbs) => - // This can be driven down by increasing `samples`. - // This is proportional to 1 / sqrt(samples). - // So, if samples is 10000, this is about 10%. - val pctDiffFromExpPr = 100 * (10 / math.sqrt(samples)) - - val expStates = choose(n, k) - val expPr = 1d / expStates - - // Check that all states are sampled. - assertEquals(expStates, statesAndProbs.size) - - statesAndProbs.foreach { case (state, pr) => - // Check probabilities are within reason. - assertEquals(s"for key '$state'", expPr, pr, expPr * (pctDiffFromExpPr / 100)) + private def checkDistributionUniformity[A]( + samples: Int, + n: Int, + k: Int, + dist: Distribution[A] + ): Option[TestFailure[A]] = { + + val expStates = choose(n, k) + val expPr = 1d / expStates + + // This can be driven down by increasing `samples`. + // This is proportional to 1 / sqrt(samples). + // So, if samples is 10000, this is about 10%. + val pctDiffFromExpPr = 100 * (10 / math.sqrt(samples)) + val delta = expPr * (pctDiffFromExpPr / 100) + + // Check that all states are sampled. + val allStates = + if (expStates == dist.size) + List.empty[FailureReason] + else List(MissingStates(expStates.toInt, dist.size)) + + val errors = + dist.foldRight(allStates){ case ((state, pr), errs) => + if (math.abs(expPr - pr) < delta) + errs + else WrongProbability(state, expPr, pr, delta) :: errs } - } + + if (errors.isEmpty) + None + else Option(TestFailure(samples, n, k, errors, dist)) } private def samplingDist(initSeed: Long, samples: Int, n: Int, k: Int): Distribution[String] = @@ -91,7 +133,15 @@ class RandTest extends Rand { object RandTest { private type NumSamples = Int - private type N = Int - private type K = Int private type Distribution[A] = Map[A, Double] + + private sealed trait FailureReason + private final case class MissingStates(expected: Int, actual: Int) extends FailureReason { + override def toString = s"States missing. Expected $expected states. Found $actual." + } + + private final case class WrongProbability[A](state: A, expected: Double, actual: Double, delta: Double) extends FailureReason { + override def toString = s"Incorrect probability for state $state. Expected pr = $expected +- $delta. Found pr = $actual." + } + private final case class TestFailure[A](samples: Int, n: Int, k: Int, failures: Seq[FailureReason], dist: Distribution[A]) } From 819095a5c8e825c79d7ce938fbdb50f748fb1c08 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Sat, 4 Nov 2017 10:45:34 -0700 Subject: [PATCH 87/98] separated pure and impure code. --- .../eharmony/aloha/util/rand/RandTest.scala | 61 +++++++++++++------ 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala index aed115fc..c4da14a2 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala @@ -27,11 +27,35 @@ class RandTest extends Rand { maxSamples = 10000, seed = 0 ) - - reportFailures(failures) + + val failureMsg = allFailuresMsg(failures) + reportFailureIfPresent(failureMsg) } - private def findFailures(trials: Int, maxN: Int, minSamples: Int, maxSamples: Int, seed: Long) = { + /** + * Produce junit failure if necessary. ('''IMPURE''') + * @param failureMsg A failure message that when empty, will cause the tests to fail. + */ + private def reportFailureIfPresent(failureMsg: String): Unit = { + if (failureMsg.nonEmpty) + fail(failureMsg) + } + + /** + * Sets up the sampling scenarios. ('''IMPURE''') + * @param trials number of trials to run. + * @param maxN maximum N value + * @param minSamples minimum number of samples to draw per trial. + * @param maxSamples maximum number of samples to draw per trial. + * @param seed a random seed to use for generating the parameters in the scenarios. + * @return sampling scenarios + */ + private def samplingScenarios( + trials: Int, + maxN: Int, + minSamples: Int, + maxSamples: Int, + seed: Long): Iterator[SamplingScenario] = { val r = new Random(seed) Iterator.fill(trials){ @@ -39,25 +63,26 @@ class RandTest extends Rand { val samples = minSamples + r.nextInt(maxSamples - minSamples + 1) val n = r.nextInt(maxN + 1) val k = r.nextInt(n + 1) - (initSeed, samples, n, k) - } flatMap { case (initSeed, samples, n, k) => - val dist = samplingDist(initSeed, samples, n, k) - checkDistributionUniformity(samples, n, k, dist).toIterable + SamplingScenario(initSeed, samples, n, k) } } - private def reportFailures[A](failures: Iterator[TestFailure[A]]): Unit = { - if (failures.nonEmpty) { - val errMsg = - failures.foldLeft(""){ case (msg, TestFailure(samples, n, k, fails, dist)) => + private def findFailures(trials: Int, maxN: Int, minSamples: Int, maxSamples: Int, seed: Long) = { + samplingScenarios(trials, maxN, minSamples, maxSamples, seed).flatMap { + case SamplingScenario(initSeed, samples, n, k) => + val dist = samplingDist(initSeed, samples, n, k) + checkDistributionUniformity(samples, n, k, dist).toIterable + } + } - val thisFail = - s"For (n: $n, k: $k, samples: $samples), produced distribution: $dist. Failures:" + - fails.mkString("\n\t", "\n\t", "\n\n") + private def allFailuresMsg[A](failures: Iterator[TestFailure[A]]): String = { + failures.foldLeft(""){ case (msg, TestFailure(samples, n, k, fails, dist)) => - msg + thisFail - } - fail(errMsg) + val thisFail = + s"For (n: $n, k: $k, samples: $samples), produced distribution: $dist. Failures:" + + fails.mkString("\n\t", "\n\t", "\n\n") + + msg + thisFail } } @@ -135,6 +160,8 @@ object RandTest { private type NumSamples = Int private type Distribution[A] = Map[A, Double] + private case class SamplingScenario(initSeed: Long, samples: Int, n: Int, k: Int) + private sealed trait FailureReason private final case class MissingStates(expected: Int, actual: Int) extends FailureReason { override def toString = s"States missing. Expected $expected states. Found $actual." From b1908216ad2e26f7ade2814283a9fe6c3cae0792 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Sat, 4 Nov 2017 11:41:59 -0700 Subject: [PATCH 88/98] no more. good enough. --- .../eharmony/aloha/util/rand/RandTest.scala | 59 ++++++++++--------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala index c4da14a2..8e89f70b 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala @@ -29,20 +29,11 @@ class RandTest extends Rand { ) val failureMsg = allFailuresMsg(failures) - reportFailureIfPresent(failureMsg) + failureMsg foreach fail } /** - * Produce junit failure if necessary. ('''IMPURE''') - * @param failureMsg A failure message that when empty, will cause the tests to fail. - */ - private def reportFailureIfPresent(failureMsg: String): Unit = { - if (failureMsg.nonEmpty) - fail(failureMsg) - } - - /** - * Sets up the sampling scenarios. ('''IMPURE''') + * Sets up the sampling scenarios. * @param trials number of trials to run. * @param maxN maximum N value * @param minSamples minimum number of samples to draw per trial. @@ -55,34 +46,46 @@ class RandTest extends Rand { maxN: Int, minSamples: Int, maxSamples: Int, - seed: Long): Iterator[SamplingScenario] = { - val r = new Random(seed) - - Iterator.fill(trials){ - val initSeed = r.nextLong() - val samples = minSamples + r.nextInt(maxSamples - minSamples + 1) - val n = r.nextInt(maxN + 1) - val k = r.nextInt(n + 1) - SamplingScenario(initSeed, samples, n, k) - } + seed: Long): Seq[SamplingScenario] = { + + // Get the scenarios eagerly to avoid carrying around the PRNG. + // If there was stateless "randomness", this could be non-strict. + val (scenarios, _) = + (1 to trials).foldLeft((List.empty[SamplingScenario], new Random(seed))){ + case ((ss, r), _) => + val initSeed = r.nextLong() + val samples = minSamples + r.nextInt(maxSamples - minSamples + 1) + val n = r.nextInt(maxN + 1) + val k = r.nextInt(n + 1) + (SamplingScenario(initSeed, samples, n, k) :: ss, r) + } + + scenarios } private def findFailures(trials: Int, maxN: Int, minSamples: Int, maxSamples: Int, seed: Long) = { - samplingScenarios(trials, maxN, minSamples, maxSamples, seed).flatMap { + samplingScenarios(trials, maxN, minSamples, maxSamples, seed).iterator.flatMap { case SamplingScenario(initSeed, samples, n, k) => val dist = samplingDist(initSeed, samples, n, k) checkDistributionUniformity(samples, n, k, dist).toIterable } } - private def allFailuresMsg[A](failures: Iterator[TestFailure[A]]): String = { - failures.foldLeft(""){ case (msg, TestFailure(samples, n, k, fails, dist)) => + private def allFailuresMsg[A](failures: Iterator[TestFailure[A]]): Option[String] = { + if (failures.isEmpty) + None + else { + val errorMsg = + failures.foldLeft(""){ case (msg, TestFailure(samples, n, k, fails, dist)) => - val thisFail = - s"For (n: $n, k: $k, samples: $samples), produced distribution: $dist. Failures:" + - fails.mkString("\n\t", "\n\t", "\n\n") + val thisFail = + s"For (n: $n, k: $k, samples: $samples), produced distribution: $dist. Failures:" + + fails.mkString("\n\t", "\n\t", "\n\n") + + msg + thisFail + } - msg + thisFail + Option(errorMsg) } } From 5e1ebd062d16cc262873f1d509f638f6b42612a8 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Sat, 4 Nov 2017 13:41:01 -0700 Subject: [PATCH 89/98] It's never good enough, even on a football Saturday. --- .../scala/com/eharmony/aloha/util/rand/RandTest.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala index 8e89f70b..cd370a49 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala @@ -77,7 +77,6 @@ class RandTest extends Rand { else { val errorMsg = failures.foldLeft(""){ case (msg, TestFailure(samples, n, k, fails, dist)) => - val thisFail = s"For (n: $n, k: $k, samples: $samples), produced distribution: $dist. Failures:" + fails.mkString("\n\t", "\n\t", "\n\n") @@ -85,7 +84,7 @@ class RandTest extends Rand { msg + thisFail } - Option(errorMsg) + Option(errorMsg.trim) } } @@ -141,7 +140,14 @@ class RandTest extends Rand { val samplesAndSeed = (1 to samples) .foldLeft((Map.empty[String, Int], seed)){ case ((m, s), _) => + + // This is the sampling method whose uniformity is being tested. val (ind, newSeed) = sampleCombination(n, k, s) + + // It is important to sort here because order shouldn't matter since these + // samples are to be treated as sets, not sequences. By contract, + // `sampleCombinations` doesn't guarantee a particular ordering but the + // results are returned in an array (for computational efficiency). val key = ind.sorted.mkString(",") val updatedMap = m + (key -> (m.getOrElse(key, 0) + 1)) (updatedMap, newSeed) From 0889971bf320d7cfffabf29ae12a0fc9b890ad0a Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Sat, 4 Nov 2017 18:35:58 -0700 Subject: [PATCH 90/98] Seq -> List --- .../src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala index cd370a49..f60b1a5c 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala @@ -46,7 +46,7 @@ class RandTest extends Rand { maxN: Int, minSamples: Int, maxSamples: Int, - seed: Long): Seq[SamplingScenario] = { + seed: Long): List[SamplingScenario] = { // Get the scenarios eagerly to avoid carrying around the PRNG. // If there was stateless "randomness", this could be non-strict. From d60f226e53b936959405872b5964c427cf99786a Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Mon, 6 Nov 2017 17:19:08 -0800 Subject: [PATCH 91/98] VwDownsampledMultilabelRowCreator and supporting infrastructure and tests. --- .../aloha/dataset/StatefulRowCreator.scala | 2 +- .../dataset/StatefulRowCreatorProducer.scala | 43 +++ .../multilabel/PositiveLabelsFunction.scala | 42 +++ .../VwDownsampledMultilabelRowCreator.scala | 286 ++++++++++++++++++ .../multilabel/VwMultilabelRowCreator.scala | 85 +++--- .../json/VwDownsampledMultilabeledJson.scala | 31 ++ .../VwMultilabelRowCreatorTest.scala | 13 +- .../downsampled_neg_dataset_spec.json | 14 + .../VwMultilabelDownsampledModelTest.scala | 188 ++++++++++++ .../multilabel/VwMultilabelModelTest.scala | 2 - 10 files changed, 648 insertions(+), 58 deletions(-) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreatorProducer.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/PositiveLabelsFunction.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala create mode 100644 aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json create mode 100644 aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala index 88067511..8c1d7b70 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -6,7 +6,7 @@ package com.eharmony.aloha.dataset * * Created by ryan.deak on 11/2/17. */ -trait StatefulRowCreator[-A, +B, S] extends Serializable { +trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] extends Serializable { /** * Some initial state that can be used on the very first call to `apply(A, S)`. diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreatorProducer.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreatorProducer.scala new file mode 100644 index 00000000..4a5ce8a0 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreatorProducer.scala @@ -0,0 +1,43 @@ +package com.eharmony.aloha.dataset + +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import spray.json.JsValue + +import scala.util.Try + +/** + * Created by deaktator on 11/6/17. + * + * @tparam A + * @tparam B + * @tparam S + * @tparam Impl + */ +trait StatefulRowCreatorProducer[A, +B, S, +Impl <: StatefulRowCreator[A, B, S]] { + + /** + * Type of parsed JSON object. + */ + type JsonType + + /** + * Name of this producer. + * @return + */ + def name: String + + /** + * Attempt to parse the JSON AST to an intermediate representation that is used + * @param json + * @return + */ + def parse(json: JsValue): Try[JsonType] + + /** + * Attempt to produce a Spec. + * @param semantics semantics used to make sense of the features in the JsonSpec + * @param jsonSpec a JSON specification to transform into a StatefulRowCreator. + * @return + */ + def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: JsonType): Try[Impl] +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/PositiveLabelsFunction.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/PositiveLabelsFunction.scala new file mode 100644 index 00000000..caf577a4 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/PositiveLabelsFunction.scala @@ -0,0 +1,42 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.AlohaException +import com.eharmony.aloha.dataset.DvProducer +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.{determineLabelNamespaces, LabelNamespaces} +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import com.eharmony.aloha.semantics.func.GenAggFunc + +import scala.collection.breakOut +import scala.util.{Failure, Success, Try} +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 11/6/17. + * @param ev$1 + * @tparam A + * @tparam K + */ +private[multilabel] abstract class PositiveLabelsFunction[A, K: RefInfo] { self: DvProducer => + + private[multilabel] def positiveLabelsFn( + semantics: CompiledSemantics[A], + positiveLabels: String + ): Try[GenAggFunc[A, sci.IndexedSeq[K]]] = + getDv[A, sci.IndexedSeq[K]]( + semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K])) + + private[multilabel] def labelNamespaces(nss: List[(String, List[Int])]): Try[LabelNamespaces] = { + val nsNames: Set[String] = nss.map(_._1)(breakOut) + determineLabelNamespaces(nsNames) match { + case Some(ns) => Success(ns) + + // If there are so many VW namespaces that all available Unicode characters are taken, + // then a memory error will probably already have occurred. + case None => Failure(new AlohaException( + "Could not find any Unicode characters to as VW namespaces. Namespaces provided: " + + nsNames.mkString(", ") + )) + } + } +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala new file mode 100644 index 00000000..3f311d08 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala @@ -0,0 +1,286 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.dataset._ +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.dataset.vw.VwCovariateProducer +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator._ +import com.eharmony.aloha.dataset.vw.multilabel.json.VwDownsampledMultilabeledJson +import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.util.Logging +import com.eharmony.aloha.util.rand.Rand +import spray.json.JsValue + +import scala.collection.{breakOut, immutable => sci} +import scala.util.Try + + +/** + * Created by ryan.deak on 11/6/17. + * + * @param allLabelsInTrainingSet + * @param featuresFunction + * @param defaultNamespace + * @param namespaces + * @param normalizer + * @param positiveLabelsFunction + * @param classNs + * @param dummyClassNs + * @param numDownsampledLabels + * @param initialSeed a way to start off randomness + * @param includeZeroValues include zero values in VW input? + * @tparam A + * @tparam K + */ +final case class VwDownsampledMultilabelRowCreator[-A, K]( + allLabelsInTrainingSet: sci.IndexedSeq[K], + featuresFunction: FeatureExtractorFunction[A, Sparse], + defaultNamespace: List[Int], + namespaces: List[(String, List[Int])], + normalizer: Option[CharSequence => CharSequence], + positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], + classNs: Char, + dummyClassNs: Char, + numDownsampledLabels: Int, + initialSeed: () => Long, + includeZeroValues: Boolean = false +) extends StatefulRowCreator[A, Array[String], Long] + with Logging { + + import VwDownsampledMultilabelRowCreator._ + + @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap + + private[this] val negativeDummyStr = + s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature" + + private[this] val positiveDummyStr = + s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature" + + /** + * Some initial state that can be used on the very first call to `apply(A, S)`. + * + * @return some state. + */ + @transient override lazy val initialState: Long = { + + val seed = initialSeed() + + // For logging. Try to get time as close as possible to calling initialSeed. + // Note: There's a good chance this will differ. + val time = System.nanoTime() + val ip = java.net.InetAddress.getLocalHost.getHostAddress + val thread = Thread.currentThread() + val threadId = thread.getId + val threadName = thread.getName + + val scrambled = scramble(seed) + + info( + s"${getClass.getSimpleName} seed: $seed, scrambled: $scrambled, " + + s"nanotime: $time, ip: $ip, threadId: $threadId, threadName: $threadName" + ) + + scrambled + } + + /** + * Given an `a` and some `state`, produce output, including a new state. + * + * When using this function, the user is responsible for keeping track of, + * and providing the state. + * + * The implementation of this function should be referentially transparent. + * + * @param a input + * @param seed the random seed which is updated on each call. + * @return a tuple where the first element is a Tuple2 whose first element is + * missing and error information and second element is an optional result. + * The second element of the outer Tuple2 is the new state. + */ + override def apply(a: A, seed: Long): ((MissingAndErroneousFeatureInfo, Option[Array[String]]), Long) = { + val (missingAndErrs, features) = featuresFunction(a) + + // TODO: Should this be sci.BitSet? + val positiveIndices: Set[Int] = + positiveLabelsFunction(a).flatMap { y => labelToInd.get(y).toSeq }(breakOut) + + val (x, newSeed) = sampledTrainingInput( + features, + allLabelsInTrainingSet.indices, + positiveIndices, + defaultNamespace, + namespaces, + classNs, + dummyClassNs, + negativeDummyStr, + positiveDummyStr, + seed, + numDownsampledLabels + ) + + ((missingAndErrs, Option(x)), newSeed) + } +} + + +object VwDownsampledMultilabelRowCreator extends Rand { + + private def scramble(initSeed: Long): Long = initSeedScramble(initSeed) + + private[multilabel] def sampledTrainingInput( + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + positiveLabelIndices: Int => Boolean, + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + classNs: Char, + dummyClassNs: Char, + negativeDummyStr: String, + positiveDummyStr: String, + seed: Long, + numNegLabelsTarget: Int + ): (Array[String], Long) = { + + // Partition into positive and negative indices. + val (posLabelInd, negLabelInd) = indices.partition(positiveLabelIndices) + val negLabels = negLabelInd.size + val numDownsampledLabels = math.min(numNegLabelsTarget, negLabels) + + // Sample indices in negLabelInd to use in the output. + val (downsampledNegLabelInd, newSeed) = + sampleCombination(negLabels, numDownsampledLabels, seed) + + val p = posLabelInd.size + val n = downsampledNegLabelInd.length + + // This code is written with knowledge of the constant's value. + // TODO: Write calling code to abstract over costs so knowledge of the costs isn't necessary. + val negWt = + if (negLabels <= numNegLabelsTarget) { + // No downsampling occurs. + NegativeCost.toString + } + else { + // Determine the weight for the downsampled negatives. + // If the cost for positive examples is 0, and negative examples have cost 1, + // the weight will be in the interval (NegativeCost, Infinity). + + f"${NegativeCost * negLabels / n.toDouble}%.5g" + } + + // The length of the output array is n + 3. + // + // The first row is the shared features. These are features that are not label dependent. + // Then comes two dummy classes. These are to make the probabilities work out. + // Then come the features for each of the n labels. + val x = new Array[String](p + n + 3) + + val shared = VwRowCreator.unlabeledVwInput( + features, defaultNs, namespaces, includeZeroValues = false + ) + + x(0) = SharedFeatureIndicator + shared + x(1) = negativeDummyStr + x(2) = positiveDummyStr + + // vvvvv This is mutable because we want speed. vvvvv + + // Negative weights + var ni = 0 + while (ni < n) { + val labelInd = negLabelInd(downsampledNegLabelInd(ni)) + x(ni + 3) = s"$labelInd:$negWt |$classNs _$labelInd" + ni += 1 + } + + // Positive weights + var pi = 0 + while (pi < p) { + val labelInd = posLabelInd(pi) + x(pi + n + 3) = s"$labelInd:$PositiveCost |$classNs _$labelInd" + pi += 1 + } + + (x, newSeed) + } + + /** + * A producer that can produce a [[VwDownsampledMultilabelRowCreator]]. + * The requirement for [[StatefulRowCreatorProducer]] to only have zero-argument constructors + * is relaxed for this Producer because we don't have a way of generically constructing a + * list of labels. If the labels were encoded in the JSON, then a JsonReader for the label + * type would have to be passed to the constructor. Since the labels can't be encoded + * generically in the JSON, we accept that this Producer is a special case and allow the labels + * to be passed directly. The consequence is that this producer doesn't just rely on the + * dataset specification and the data itself. It also relying on the labels provided to the + * constructor. + * + * @param allLabelsInTrainingSet All of the labels that will be encountered in the training set. + * @param seedCreator a "''function''" that creates a seed that will be used for randomness. + * The implementation of this function is important. It should create a + * unique value for each unit of parallelism. If for example, row + * creation is parallelized across multiple threads on one machine, the + * unit of parallelism is threads and `seedCreator` should produce unique + * values for each thread. If row creation is parallelized across multiple + * machines, the `seedCreator` should produce a unique value for each + * machine. If row creation is parallelized across machines and threads on + * each machine, the `seedCreator` should create unique values for each + * thread on each machine. Otherwise, randomness will be striped which + * @param ev$1 reflection information about `K`. + * @tparam A type of input passed to the [[StatefulRowCreator]]. + * @tparam K the label type. + */ + final class Producer[A, K: RefInfo]( + allLabelsInTrainingSet: sci.IndexedSeq[K], + seedCreator: () => Long + ) extends PositiveLabelsFunction[A, K] + with StatefulRowCreatorProducer[A, Array[String], Long, VwDownsampledMultilabelRowCreator[A, K]] + with RowCreatorProducerName + with VwCovariateProducer[A] + with DvProducer + with SparseCovariateProducer + with CompilerFailureMessages { + + override type JsonType = VwDownsampledMultilabeledJson + + /** + * Attempt to parse the JSON AST to an intermediate representation that is used + * to create the row creator. + * @param json JSON AST. + * @return + */ + override def parse(json: JsValue): Try[VwDownsampledMultilabeledJson] = + Try { json.convertTo[VwDownsampledMultilabeledJson] } + + /** + * Attempt to produce a Spec. + * + * @param semantics semantics used to make sense of the features in the JsonSpec + * @param jsonSpec a JSON specification to transform into a RowCreator. + * @return + */ + override def getRowCreator( + semantics: CompiledSemantics[A], + jsonSpec: VwDownsampledMultilabeledJson + ): Try[VwDownsampledMultilabelRowCreator[A, K]] = { + val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) + + val rc = for { + cov <- covariates + pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) + labelNs <- labelNamespaces(nss) + actualLabelNs = labelNs.labelNs + dummyLabelNs = labelNs.dummyLabelNs + sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) + numNeg = jsonSpec.numDownsampledNegLabels + } yield new VwDownsampledMultilabelRowCreator[A, K]( + allLabelsInTrainingSet, cov, default, nss, normalizer, + pos, actualLabelNs, dummyLabelNs, numNeg, seedCreator) + + rc + } + } +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index a862da2d..d48f527c 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -1,6 +1,5 @@ package com.eharmony.aloha.dataset.vw.multilabel -import com.eharmony.aloha.AlohaException import com.eharmony.aloha.dataset._ import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.dataset.vw.VwCovariateProducer @@ -9,10 +8,11 @@ import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator import com.eharmony.aloha.reflect.RefInfo import com.eharmony.aloha.semantics.compiled.CompiledSemantics import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.util.rand.Rand import spray.json.JsValue import scala.collection.{breakOut, immutable => sci} -import scala.util.{Failure, Success, Try} +import scala.util.Try /** * Created by ryan.deak on 9/13/17. @@ -26,8 +26,8 @@ final case class VwMultilabelRowCreator[-A, K]( positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], classNs: Char, dummyClassNs: Char, - includeZeroValues: Boolean = false -) extends RowCreator[A, Array[String]] { + includeZeroValues: Boolean = false) +extends RowCreator[A, Array[String]] { import VwMultilabelRowCreator._ @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap @@ -63,7 +63,7 @@ final case class VwMultilabelRowCreator[-A, K]( } } -object VwMultilabelRowCreator { +object VwMultilabelRowCreator extends Rand { /** * VW allows long-based feature indices, but Aloha only allow's 32-bit indices @@ -71,7 +71,7 @@ object VwMultilabelRowCreator { * dummy classes uses an ID outside of the allowable range of feature indices: * 2^32^. */ - private val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString + private[multilabel] val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString /** * VW allows long-based feature indices, but Aloha only allow's 32-bit indices @@ -79,7 +79,11 @@ object VwMultilabelRowCreator { * dummy classes uses an ID outside of the allowable range of feature indices: * 2^32^ + 1. */ - private val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString + private[multilabel] val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString + + // NOTE: If PositiveCost and NegativeCost change, + // VwDownsampledMultilabledRowCreator.sampledTrainingInput + // will also need to change. /** * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the @@ -87,7 +91,7 @@ object VwMultilabelRowCreator { * As such, the ''reward'' of a positive example is designated to be one, * so the cost (or negative reward) is -1. */ - private val PositiveCost = (-1).toString + private[multilabel] val PositiveCost = 0 /** * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the @@ -95,11 +99,11 @@ object VwMultilabelRowCreator { * As such, the ''reward'' of a negative example is designated to be zero, * so the cost (or negative reward) is 0. */ - private val NegativeCost = 0.toString + private[multilabel] val NegativeCost = 1 - private val PositiveDummyClassFeature = "P" + private[multilabel] val PositiveDummyClassFeature = "P" - private val NegativeDummyClassFeature = "N" + private[multilabel] val NegativeDummyClassFeature = "N" /** * "shared" is a special keyword in VW multi-class (multi-row) format. @@ -107,7 +111,7 @@ object VwMultilabelRowCreator { * * '''NOTE''': The trailing space should be here. */ - private[this] val SharedFeatureIndicator = "shared" + " " + private[multilabel] val SharedFeatureIndicator = "shared" + " " private[this] val PreferredLabelNamespaces = Seq(('Y', 'y'), ('Z', 'z'), ('Λ', 'λ')) @@ -191,10 +195,12 @@ object VwMultilabelRowCreator { * @param namespaces the indices into `features` that should be associated with each * namespace. * @param classNs a namespace for features associated with class labels - * @param dummyClassNs a namespace for features associated with dummy class labels + * // @param dummyClassNs a namespace for features associated with dummy class labels + * @param negativeDummyStr + * @param positiveDummyStr * @return an array to be passed directly to an underlying `VWActionScoresLearner`. */ - private[aloha] def trainingInput( + private[multilabel] def trainingInput( features: IndexedSeq[Sparse], indices: sci.IndexedSeq[Int], positiveLabelIndices: Int => Boolean, @@ -222,7 +228,8 @@ object VwMultilabelRowCreator { x(1) = negativeDummyStr x(2) = positiveDummyStr - // This is mutable because we want speed. + // vvvvv This is mutable because we want speed. vvvvv + var i = 0 while (i < n) { val labelInd = indices(i) @@ -276,7 +283,6 @@ object VwMultilabelRowCreator { x } - /** * A producer that can produce a [[VwMultilabelRowCreator]]. * The requirement for [[RowCreatorProducer]] to only have zero-argument constructors is @@ -294,7 +300,8 @@ object VwMultilabelRowCreator { * @tparam K the label type. */ final class Producer[A, K: RefInfo](allLabelsInTrainingSet: sci.IndexedSeq[K]) - extends RowCreatorProducer[A, Array[String], VwMultilabelRowCreator[A, K]] + extends PositiveLabelsFunction[A, K] + with RowCreatorProducer[A, Array[String], VwMultilabelRowCreator[A, K]] with RowCreatorProducerName with VwCovariateProducer[A] with DvProducer @@ -305,8 +312,8 @@ object VwMultilabelRowCreator { /** * Attempt to parse the JSON AST to an intermediate representation that is used - * - * @param json + * to create the row creator. + * @param json JSON AST. * @return */ override def parse(json: JsValue): Try[VwMultilabeledJson] = @@ -319,42 +326,24 @@ object VwMultilabelRowCreator { * @param jsonSpec a JSON specification to transform into a RowCreator. * @return */ - override def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: VwMultilabeledJson): Try[VwMultilabelRowCreator[A, K]] = { + override def getRowCreator( + semantics: CompiledSemantics[A], + jsonSpec: VwMultilabeledJson + ): Try[VwMultilabelRowCreator[A, K]] = { val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) - val spec = for { - cov <- covariates - pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) - labelNs <- labelNamespaces(nss) + val rc = for { + cov <- covariates + pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) + labelNs <- labelNamespaces(nss) actualLabelNs = labelNs.labelNs - dummyLabelNs = labelNs.dummyLabelNs - sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) + dummyLabelNs = labelNs.dummyLabelNs + sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) } yield new VwMultilabelRowCreator[A, K](allLabelsInTrainingSet, cov, default, nss, normalizer, pos, actualLabelNs, dummyLabelNs) - spec + rc } - - private[multilabel] def labelNamespaces(nss: List[(String, List[Int])]): Try[LabelNamespaces] = { - val nsNames: Set[String] = nss.map(_._1)(breakOut) - determineLabelNamespaces(nsNames) match { - case Some(ns) => Success(ns) - - // If there are so many VW namespaces that all available Unicode characters are taken, - // then a memory error will probably already have occurred. - case None => Failure(new AlohaException( - "Could not find any Unicode characters to as VW namespaces. Namespaces provided: " + - nsNames.mkString(", ") - )) - } - } - - private[multilabel] def positiveLabelsFn( - semantics: CompiledSemantics[A], - positiveLabels: String - ): Try[GenAggFunc[A, sci.IndexedSeq[K]]] = - getDv[A, sci.IndexedSeq[K]]( - semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K])) } private[aloha] final case class LabelNamespaces(labelNs: Char, dummyLabelNs: Char) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala new file mode 100644 index 00000000..db505b94 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala @@ -0,0 +1,31 @@ +package com.eharmony.aloha.dataset.vw.multilabel.json + +import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.json.VwJsonLike +import spray.json.{DefaultJsonProtocol, RootJsonFormat} + +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 11/6/17. + * + * @param imports + * @param features + * @param namespaces + * @param normalizeFeatures + * @param positiveLabels + * @param numDownsampledNegLabels + */ +final case class VwDownsampledMultilabeledJson( + imports: sci.Seq[String], + features: sci.IndexedSeq[SparseSpec], + namespaces: Option[Seq[Namespace]] = Some(Nil), + normalizeFeatures: Option[Boolean] = Some(false), + positiveLabels: String, + numDownsampledNegLabels: Int +) extends VwJsonLike + +object VwDownsampledMultilabeledJson extends DefaultJsonProtocol { + implicit val vwDownsampledMultilabeledJson: RootJsonFormat[VwDownsampledMultilabeledJson] = + jsonFormat6(VwDownsampledMultilabeledJson.apply) +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala index f4ac646f..7817c15a 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala @@ -104,7 +104,7 @@ class VwMultilabelRowCreatorTest { val suffix = expectedResults.zipWithIndex map { case (isPos, i) => - s"$i:${if (isPos) -1 else 0} |$labelNs _$i" + s"$i:${if (isPos) PosVal else NegVal} |$labelNs _$i" } assertEquals(prefix, actualResults.take(prefix.size).toSeq) @@ -132,8 +132,8 @@ object VwMultilabelRowCreatorTest { private val LabelsInTrainingSet = Vector("zero", "one", "two") private val NegDummyClass = Int.MaxValue.toLong + 1 private val PosDummyClass = NegDummyClass + 1 - private val PosVal = -1 - private val NegVal = 0 + private val PosVal = 0 + private val NegVal = 1 private val X = Map.empty[String, Any] private val SharedPrefix = "shared " @@ -175,12 +175,11 @@ object VwMultilabelRowCreatorTest { // the features (and dummy class definitions) in MultilabelDatasetJson are invariant to // the input. private[aloha] val MultilabelOutputPrefix = Seq( - "shared |ns1 f1 |ns2 f2", // due to definition of features and namespaces. - "2147483648:0 |y N", // negative dummy class - "2147483649:-1 |y P" // positive dummy class + "shared |ns1 f1 |ns2 f2", // due to definition of features and namespaces. + s"2147483648:$NegVal |y N", // negative dummy class + s"2147483649:$PosVal |y P" // positive dummy class ) - private def output(out: (MissingAndErroneousFeatureInfo, Array[String])) = out._2 private[this] val featureFns = SparseFeatureExtractorFunction[Domain](Vector( diff --git a/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json new file mode 100644 index 00000000..8910bc6f --- /dev/null +++ b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json @@ -0,0 +1,14 @@ +{ + "imports": [ + "com.eharmony.aloha.feature.BasicFunctions._" + ], + "features": [ + { "name": "feature", "spec": "1" } + ], + "namespaces": [ + { "name": "X", "features": [ "feature" ] } + ], + "normalizeFeatures": false, + "positiveLabels": "${labels_from_input}", + "numDownsampledNegLabels": 1 +} diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala new file mode 100644 index 00000000..640f7cd2 --- /dev/null +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala @@ -0,0 +1,188 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.File + +import com.eharmony.aloha.audit.impl.OptionAuditor +import com.eharmony.aloha.dataset.vw.multilabel.VwDownsampledMultilabelRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.json.VwDownsampledMultilabeledJson +import com.eharmony.aloha.factory.ModelFactory +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.io.StringReadable +import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances +import org.apache.commons.vfs2.VFS +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner +import org.junit.Assert.{assertEquals, fail} +import spray.json.pimpString +import spray.json.DefaultJsonProtocol.IntJsonFormat +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} + +import scala.collection.breakOut +import scala.util.Random + +/** + * Created by ryan.deak on 11/6/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelDownsampledModelTest { + + import VwMultilabelDownsampledModelTest.EndToEndDatasetSpec + + @Test def test(): Unit = { + + // ------------------------------------------------------------------------------------ + // Preliminaries + // ------------------------------------------------------------------------------------ + + type Lab = Int + type Dom = Vector[Lab] // Not an abuse of notation. Domain is indeed a vector of labels. + + val semantics = CompiledSemanticsInstances.anyNameIdentitySemantics[Dom] + val optAud = OptionAuditor[Map[Lab, Double]]() + + // 10 passes over the data, sampling the negatives. + val repetitions = 10 + + + // ------------------------------------------------------------------------------------ + // Dataset and test example set up. + // ------------------------------------------------------------------------------------ + + // Marginal Distribution: Pr[8] = 0.80 = 20 / 25 = (12 + 8) / 25 + // Pr[4] = 0.40 = 10 / 25 = ( 2 + 8) / 25 + // + val unshuffledTrainingSet: Seq[Dom] = Seq( + Vector( ) -> 3, // pr = 0.12 = 3 / 25 <-- JPD + Vector( 8) -> 12, // pr = 0.48 = 12 / 25 + Vector(4 ) -> 2, // pr = 0.08 = 2 / 25 + Vector(4, 8) -> 8 // pr = 0.32 = 8 / 25 + ) flatMap { + case (k, n) => Vector.fill(n * repetitions)(k) + } + + val trainingSet = new Random(0x0912e40f395bL).shuffle(unshuffledTrainingSet) + + val labelsInTrainingSet = trainingSet.flatten.toSet.toVector.sorted + + val testExample: Dom = Vector.empty + + val marginalDist = labelsInTrainingSet.map { label => + val z = trainingSet.size.toDouble + label -> trainingSet.count(row => row contains label) / z + }.toMap + + // ------------------------------------------------------------------------------------ + // Prepare dataset specification and read training dataset. + // ------------------------------------------------------------------------------------ + + val datasetSpec = VFS.getManager.resolveFile(EndToEndDatasetSpec) + val datasetJson = + StringReadable + .fromVfs2(datasetSpec) + .parseJson + .convertTo[VwDownsampledMultilabeledJson] + + val rc = + new VwDownsampledMultilabelRowCreator.Producer[Dom, Lab](labelsInTrainingSet, () => 0L). + getRowCreator(semantics, datasetJson).get + + + // ------------------------------------------------------------------------------------ + // Prepare parameters for VW model that will be trained. + // ------------------------------------------------------------------------------------ + + val binaryVwModel = File.createTempFile("vw_", ".bin.model") + binaryVwModel.deleteOnExit() + + val cacheFile = File.createTempFile("vw_", ".cache") + cacheFile.deleteOnExit() + + val origParams = + s""" + | -P 2.0 + | --cache_file ${cacheFile.getCanonicalPath} + | --holdout_off + | --passes 50 + | --learning_rate 5 + | --decay_learning_rate 0.9 + | --csoaa_ldf mc + | --loss_function logistic + | -f ${binaryVwModel.getCanonicalPath} + """.stripMargin.trim.replaceAll("\n", " ") + + // Get the namespace names from the row creator. + val nsNames = rc.namespaces.map(_._1)(breakOut): Set[String] + + // Take the parameters and augment with additional parameters to make + // multilabel w/ probabilities work correctly. + val vwParams = VwMultilabelModel.updatedVwParams(origParams, nsNames, 2) fold ( + e => throw new Exception(e.errorMessage), + ps => ps + ) + + // ------------------------------------------------------------------------------------ + // Train VW model + // ------------------------------------------------------------------------------------ + + // Get the iterator of the examples produced. + val examples = rc(trainingSet.iterator, rc.initialState) collect { + case ((_, Some(x)), _) => x + } + + val vwLearner = VWLearners.create[VWActionScoresLearner](vwParams) + + examples foreach { x => + vwLearner.learn(x) + } + + vwLearner.close() + + + // ------------------------------------------------------------------------------------ + // Create Aloha model JSON + // ------------------------------------------------------------------------------------ + + val modelJson = VwMultilabelModel.json( + datasetSpec = Vfs.apacheVfs2ToAloha(datasetSpec), + binaryVwModel = Vfs.javaFileToAloha(binaryVwModel), + id = ModelId(1, "NONE"), + labelsInTrainingSet = labelsInTrainingSet, + labelsOfInterest = Option.empty[String], + vwArgs = Option.empty[String], + externalModel = false, + numMissingThreshold = Option(0) + ) + + // ------------------------------------------------------------------------------------ + // Instantiate Aloha Model + // ------------------------------------------------------------------------------------ + + val factory = ModelFactory.defaultFactory(semantics, optAud) + val modelTry = factory.fromString(modelJson.prettyPrint) // Use `.compactPrint` in prod. + val model = modelTry.get + + + // ------------------------------------------------------------------------------------ + // Test Aloha Model + // ------------------------------------------------------------------------------------ + + val output = model(testExample) + model.close() + + output match { + case None => fail() + case Some(m) => + assertEquals(marginalDist.keySet, m.keySet) + marginalDist foreach { case (k, v) => + assertEquals(s"For key '$k':", v, m(k), 0.01) + } + } + } +} + +object VwMultilabelDownsampledModelTest { + private val EndToEndDatasetSpec = + "res:com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json" +} diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index ade05aa9..a8d27fb5 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -453,8 +453,6 @@ object VwMultilabelModelTest { labelsInTrainingSet: Vector[K], labelsOfInterest: Option[String] = None) = { -// implicit val vecWriter = vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) - val loi = labelsOfInterest.fold(""){ f => val escaped = f.replaceAll("\"", "\\\"") s""""labelsOfInterest": "$escaped",\n""" From fe8dca2d0021da91687b69c24689af325333e465 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 7 Nov 2017 13:42:25 -0800 Subject: [PATCH 92/98] Made Rand use Int indices, made k < 2^15 in neg label sampling. Updated StatefulRowCreator API. --- .../aloha/dataset/StatefulRowCreator.scala | 55 +++++++- .../VwDownsampledMultilabelRowCreator.scala | 132 ++++++++++++++---- .../multilabel/VwMultilabelRowCreator.scala | 29 +++- .../json/VwDownsampledMultilabeledJson.scala | 10 +- .../com/eharmony/aloha/util/rand/Rand.scala | 8 +- .../VwMultilabelDownsampledModelTest.scala | 15 +- .../multilabel/VwMultilabelModelTest.scala | 2 +- 7 files changed, 201 insertions(+), 50 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala index 8c1d7b70..399715c7 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -36,14 +36,53 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten * applications, the state will come from the state generated in the output of the * previous application of `apply(A, S)`. * - * @param as Note the first element of `as` will be forced in this method in order + * ''This variant of mapping with state is'' '''non-strict''', so if that's a requirement, + * prefer this function over the `mapSeqWithState` variant. Note that first this method + * to work, the first element is computed eagerly. + * + * To verify non-strictness, this method could be rewritten as: + * + * {{{ + * def statefulMap[A, B, S](as: Iterator[A], s: S) + * (f: (A, S) => (B, S)): Iterator[(B, S)] = { + * if (as.isEmpty) + * Iterator.empty + * else { + * val firstA = as.next() + * val initEl = f(firstA, s) + * as.scanLeft(initEl){ case ((_, newS), a) => f(a, newS) } + * } + * } + * }}} + * + * Then using the method, it's easy to verify non-strictness: + * + * {{{ + * // cycleModulo4 and res are infinite. + * val cycleModulo4 = Iterator.iterate(0)(s => (s + 1) % 4) + * val res = statefulMap(cycleModulo4, 0)((a, s) => ((a + s).toDouble, s + 1)) + * + * // returns: + * res.take(8) + * .map(_._1) + * .toVector + * .groupBy(identity) + * .mapValues(_.size) + * .toVector + * .sorted + * .foreach(println) + * + * res.size // never returns + * }}} + * + * @param as Note the first element of `as` ''will be forced'' in this method in order * to construct the output. * @param state the initial state to use at the start of the iterator. - * @return an iterator containing the a mapped to a + * @return an iterator containing the `a` mapped to a * `(MissingAndErroneousFeatureInfo, Option[B])` along with the resulting * state that is created in the process. */ - def apply(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = { + def mapIteratorWithState(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = { if (as.isEmpty) Iterator.empty else { @@ -58,17 +97,23 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten } /** - * Apply the `apply(A, S)` method to the elements of the Vector. In the first + * Apply the `apply(A, S)` method to the elements of the sequence. In the first * application of `apply(A, S)`, `state` will be used as the state. In subsequent * applications, the state will come from the state generated in the output of the * previous application of `apply(A, S)`. * + * ''This variant of mapping with state is'' '''strict'''. + * + * '''NOTE''': This method isn't really parallelizable via chunking. The way to + * parallelize this method is to provide a separate starting state for each unit + * of parallelism. + * * @param as input to map. * @param state the initial state to use at the start of the Vector. * @return a Tuple2 where the first element is a vector of results and the second * element is the resulting state. */ - def apply(as: Vector[A], state: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S) = { + def mapSeqWithState(as: Seq[A], state: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S) = { as.foldLeft((Vector.empty[(MissingAndErroneousFeatureInfo, Option[B])], state)){ case ((bs, s), a) => val (b, newS) = apply(a, s) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala index 3f311d08..10696c0e 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala @@ -18,21 +18,48 @@ import scala.util.Try /** - * Created by ryan.deak on 11/6/17. + * Creates training data for multilabel models in Vowpal Wabbit's CSOAA LDF and WAP LDF format + * for the JNI. In this row creator, negative labels are downsampled and costs for the + * downsampled labels are adjusted to produced an unbiased estimator. It is assumed that + * negative labels are in the majority. Downsampling negatives can improve both training + * time and possibly model performance. See the following resources for intuition: * - * @param allLabelsInTrainingSet - * @param featuresFunction - * @param defaultNamespace - * @param namespaces - * @param normalizer - * @param positiveLabelsFunction - * @param classNs - * @param dummyClassNs - * @param numDownsampledLabels + - [[https://www3.nd.edu/~nchawla/papers/SPRINGER05.pdf Chawla, Nitesh V. "Data mining for + imbalanced datasets: An overview." Data mining and knowledge discovery handbook. + Springer US, 2009. 875-886.]] + - [[https://www3.nd.edu/~dial/publications/chawla2004editorial.pdf Chawla, Nitesh V., Nathalie + Japkowicz, and Aleksander Kotcz. "Special issue on learning from imbalanced data sets." + ACM SIGKDD Explorations Newsletter 6.1 (2004): 1-6.]] + - [[http://www.marcoaltini.com/blog/dealing-with-imbalanced-data-undersampling-oversampling-and-proper-cross-validation + Dealing with imbalanced data: undersampling, oversampling, and proper cross validation, + Marco Altini, Aug 17, 2015.]] + * + * This row creator, since it is stateful, requires the caller to maintain state. If however, + * it is only called via an iterator or sequence, then this row creator can maintain the state + * during iteration over the iterator or sequence. In the case of iterators, the mapping is + * '''non-strict''' and in the case of sequences (`Seq`), it is '''strict'''. + * + * @param allLabelsInTrainingSet all labels in the training set. This is a sequence because + * order matters. Order here can be chosen arbitrarily, but it + * must be consistent in the training and test formulation. + * @param featuresFunction features to extract from the data of type `A`. + * @param defaultNamespace list of feature indices in the default VW namespace. + * @param namespaces a mapping from VW namespace name to feature indices in that namespace. + * @param normalizer can modify VW output (currently unused) + * @param positiveLabelsFunction A method that can extract positive class labels. + * @param classNs the namespace name for class information. + * @param dummyClassNs the namespace name for dummy class information. 2 dummy classes are + * added to make the predicted probabilities work. + * @param numDownsampledNegLabels '''a positive value''' representing the number of negative + * labels to include in each row. If this is less than the + * number of negative examples for a given row, then no + * downsampling of negatives will take place. * @param initialSeed a way to start off randomness * @param includeZeroValues include zero values in VW input? - * @tparam A - * @tparam K + * @tparam A the input type + * @tparam K the label or class type + * @author deaktator + * @since 11/6/2017 */ final case class VwDownsampledMultilabelRowCreator[-A, K]( allLabelsInTrainingSet: sci.IndexedSeq[K], @@ -43,16 +70,23 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], classNs: Char, dummyClassNs: Char, - numDownsampledLabels: Int, + numDownsampledNegLabels: Short, initialSeed: () => Long, includeZeroValues: Boolean = false ) extends StatefulRowCreator[A, Array[String], Long] with Logging { + require( + 0 < numDownsampledNegLabels, + s"numDownsampledNegLabels must be positive, found $numDownsampledNegLabels" + ) + import VwDownsampledMultilabelRowCreator._ @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap + // Precomputed for efficiency. + private[this] val negativeDummyStr = s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature" @@ -61,7 +95,6 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( /** * Some initial state that can be used on the very first call to `apply(A, S)`. - * * @return some state. */ @transient override lazy val initialState: Long = { @@ -69,7 +102,7 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( val seed = initialSeed() // For logging. Try to get time as close as possible to calling initialSeed. - // Note: There's a good chance this will differ. + // Note: There's a very good chance this will differ. val time = System.nanoTime() val ip = java.net.InetAddress.getLocalHost.getHostAddress val thread = Thread.currentThread() @@ -87,10 +120,10 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( } /** - * Given an `a` and some `state`, produce output, including a new state. + * Given an `a` and some `seed`, produce output, including a new seed. * * When using this function, the user is responsible for keeping track of, - * and providing the state. + * and providing the seeds. * * The implementation of this function should be referentially transparent. * @@ -103,11 +136,32 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( override def apply(a: A, seed: Long): ((MissingAndErroneousFeatureInfo, Option[Array[String]]), Long) = { val (missingAndErrs, features) = featuresFunction(a) + // Get the lazy val once. + val labToInd = labelToInd + + // TODO: This seems like it could be optimized. + // The positiveLabelsFunction is invoked once and all labels are produced. + // Then a set is produced that is used to partition all of the labels in + // the training data into positive and negative. It seems like a slightly + // more efficient way to do this would be to create a sorted array of positives + // indices. Then a negative index array could be constructed with the appropriate + // size based on the size of allLabelsInTrainingSet and the size of the positives + // array. Next allLabelsInTrainingSet.indices and the positiveIndices array could + // be iterated over simultaneously and if the current index in + // allLabelsInTrainingSet.indices isn't in the positiveIndices array, then it must + // be in the negativeIndices array. Then then offsets into the two arrays are + // incremented. Both the current and proposed algorithms are O(N), where + // N = allLabelsInTrainingSet.indices.size. But the proposed algorithm would likely + // have much better constant factors. + // + // Note that labels produced by positiveLabelsFunction that are not in the + // allLabelsInTrainingSet are discarded without notice. + // // TODO: Should this be sci.BitSet? val positiveIndices: Set[Int] = - positiveLabelsFunction(a).flatMap { y => labelToInd.get(y).toSeq }(breakOut) + positiveLabelsFunction(a).flatMap { y => labToInd.get(y).toSeq }(breakOut) - val (x, newSeed) = sampledTrainingInput( + val (vwInput, newSeed) = sampledTrainingInput( features, allLabelsInTrainingSet.indices, positiveIndices, @@ -118,18 +172,45 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( negativeDummyStr, positiveDummyStr, seed, - numDownsampledLabels + numDownsampledNegLabels ) - ((missingAndErrs, Option(x)), newSeed) + ((missingAndErrs, Option(vwInput)), newSeed) } } object VwDownsampledMultilabelRowCreator extends Rand { + // Expose initSeedScramble to companion class. private def scramble(initSeed: Long): Long = initSeedScramble(initSeed) + /** + * + * '''This should not be used directly.''' ''It is exposed for testing.'' + * + * @param features the common features produced + * @param indices indices of all labels in the training set. + * @param positiveLabelIndices indices of labels in the training set that positive are + * positive for this training example + * @param defaultNs list of feature indices in the default VW namespace. + * @param namespaces a mapping from VW namespace name to feature indices in that namespace. + * @param classNs the namespace name for class information. + * @param dummyClassNs the namespace name for dummy class information. 2 dummy classes are + * added to make the predicted probabilities work. + * @param negativeDummyStr line in VW input associated with the negative dummy class. + * @param positiveDummyStr line in VW input associated with the positive dummy class. + * @param seed a seed for randomness. The second part of the output of this function is a + * new seed that should be used on the next call to this function. + * @param numNegLabelsTarget the desired number of negative labels. If it is determined + * that there are less negative labels than desired, no negative + * label downsampling will occur; otherwise, the negative labels + * will be downsampled to this target value. + * @return an array representing the VW input that can either be passed directly to the + * VW JNI library, or `mkString("", "\n", "\n")` can be called to pass to the VW + * CLI. The second part of the return value is a new seed to use as the `seed` + * parameter in the next call to this function. + */ private[multilabel] def sampledTrainingInput( features: IndexedSeq[Sparse], indices: sci.IndexedSeq[Int], @@ -141,12 +222,13 @@ object VwDownsampledMultilabelRowCreator extends Rand { negativeDummyStr: String, positiveDummyStr: String, seed: Long, - numNegLabelsTarget: Int + numNegLabelsTarget: Short ): (Array[String], Long) = { // Partition into positive and negative indices. val (posLabelInd, negLabelInd) = indices.partition(positiveLabelIndices) val negLabels = negLabelInd.size + val numDownsampledLabels = math.min(numNegLabelsTarget, negLabels) // Sample indices in negLabelInd to use in the output. @@ -159,14 +241,14 @@ object VwDownsampledMultilabelRowCreator extends Rand { // This code is written with knowledge of the constant's value. // TODO: Write calling code to abstract over costs so knowledge of the costs isn't necessary. val negWt = - if (negLabels <= numNegLabelsTarget) { + if (numDownsampledLabels == negLabels) { // No downsampling occurs. NegativeCost.toString } else { // Determine the weight for the downsampled negatives. - // If the cost for positive examples is 0, and negative examples have cost 1, - // the weight will be in the interval (NegativeCost, Infinity). + // If the cost of negative examples is positive, then the weight will be + // strictly greater than . f"${NegativeCost * negLabels / n.toDouble}%.5g" } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index d48f527c..587395ab 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -15,7 +15,25 @@ import scala.collection.{breakOut, immutable => sci} import scala.util.Try /** - * Created by ryan.deak on 9/13/17. + * Creates training data for multilabel models in Vowpal Wabbit's CSOAA LDF and WAP LDF format + * for the JNI. + * + * @param allLabelsInTrainingSet all labels in the training set. This is a sequence because + * order matters. Order here can be chosen arbitrarily, but it + * must be consistent in the training and test formulation. + * @param featuresFunction features to extract from the data of type `A`. + * @param defaultNamespace list of feature indices in the default VW namespace. + * @param namespaces a mapping from VW namespace name to feature indices in that namespace. + * @param normalizer can modify VW output (currently unused) + * @param positiveLabelsFunction A method that can extract positive class labels. + * @param classNs the namespace name for class information. + * @param dummyClassNs the namespace name for dummy class information. 2 dummy classes are + * added to make the predicted probabilities work. + * @param includeZeroValues include zero values in VW input? + * @tparam A the input type + * @tparam K the label or class type + * @author deaktator + * @since 9/13/2017 */ final case class VwMultilabelRowCreator[-A, K]( allLabelsInTrainingSet: sci.IndexedSeq[K], @@ -88,16 +106,15 @@ object VwMultilabelRowCreator extends Rand { /** * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the * dependent variable is based on cost (which is the negative of reward). - * As such, the ''reward'' of a positive example is designated to be one, - * so the cost (or negative reward) is -1. + * As such, the ''reward'' of a positive example is designated to be zero. */ private[multilabel] val PositiveCost = 0 /** * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the * dependent variable is based on cost (which is the negative of reward). - * As such, the ''reward'' of a negative example is designated to be zero, - * so the cost (or negative reward) is 0. + * As such, the ''reward'' of a negative example is designated to be -1, + * so the cost (or negative reward) is 1. */ private[multilabel] val NegativeCost = 1 @@ -347,4 +364,4 @@ object VwMultilabelRowCreator extends Rand { } private[aloha] final case class LabelNamespaces(labelNs: Char, dummyLabelNs: Char) -} \ No newline at end of file +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala index db505b94..145ec591 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala @@ -22,8 +22,14 @@ final case class VwDownsampledMultilabeledJson( namespaces: Option[Seq[Namespace]] = Some(Nil), normalizeFeatures: Option[Boolean] = Some(false), positiveLabels: String, - numDownsampledNegLabels: Int -) extends VwJsonLike + numDownsampledNegLabels: Short +) extends VwJsonLike { + + require( + 0 < numDownsampledNegLabels, + s"numDownsampledNegLabels must be positive, found $numDownsampledNegLabels" + ) +} object VwDownsampledMultilabeledJson extends DefaultJsonProtocol { implicit val vwDownsampledMultilabeledJson: RootJsonFormat[VwDownsampledMultilabeledJson] = diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala index 9ed0780e..1a9d194a 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala @@ -9,7 +9,7 @@ package com.eharmony.aloha.util.rand private[aloha] trait Rand { type Seed = Long - type Index = Short + type Index = Int /** * Perform the initial scramble. This should be called '''''once''''' on the initial @@ -55,7 +55,7 @@ private[aloha] trait Rand { - [[http://grepcode.com/file/repository.grepcode.com/java/root/jdk/openjdk/6-b14/java/util/Random.java grepcode for Random]] - [[https://en.wikipedia.org/wiki/Linear_congruential_generator Linear congruential generator (Wikipedia)]] * - * @param n population size (< 2^15^). + * @param n population size (< 2^31^). * @param k combination size (< 2^15^). * @param seed the seed to use for random selection * @return a tuple 2 containing the array of 0-based indices representing @@ -67,7 +67,7 @@ private[aloha] trait Rand { // object creation is important. if (n <= k) { - ((0 until n).map(i => i.toShort).toArray, seed) + ((0 until n).toArray, seed) } else { var i = k + 1 @@ -77,7 +77,7 @@ private[aloha] trait Rand { var value = 0 // Fill reservoir with the first k indices. - val reservoir = (0 until k).map(i => i.toShort).toArray + val reservoir = (0 until k).toArray // Loop over the rest of the indices outside the reservoir and determine if // swapping should occur. If so, swap the index in the reservoir with the diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala index 640f7cd2..f7530a97 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala @@ -43,7 +43,7 @@ class VwMultilabelDownsampledModelTest { val optAud = OptionAuditor[Map[Lab, Double]]() // 10 passes over the data, sampling the negatives. - val repetitions = 10 + val repetitions = 8 // ------------------------------------------------------------------------------------ @@ -101,10 +101,10 @@ class VwMultilabelDownsampledModelTest { val origParams = s""" - | -P 2.0 + | --quiet | --cache_file ${cacheFile.getCanonicalPath} | --holdout_off - | --passes 50 + | --passes 10 | --learning_rate 5 | --decay_learning_rate 0.9 | --csoaa_ldf mc @@ -126,15 +126,16 @@ class VwMultilabelDownsampledModelTest { // Train VW model // ------------------------------------------------------------------------------------ - // Get the iterator of the examples produced. - val examples = rc(trainingSet.iterator, rc.initialState) collect { + // Get the iterator of the examples produced. This is similar to what one may do + // within a `mapPartitions` in Spark. + val examples = rc.mapIteratorWithState(trainingSet.iterator, rc.initialState) collect { case ((_, Some(x)), _) => x } val vwLearner = VWLearners.create[VWActionScoresLearner](vwParams) - examples foreach { x => - vwLearner.learn(x) + examples foreach { yx => + vwLearner.learn(yx) } vwLearner.close() diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index a8d27fb5..6566fbeb 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -135,7 +135,7 @@ class VwMultilabelModelTest { | --csoaa_ldf mc | --loss_function logistic | -f ${binaryVwModel.getCanonicalPath} - | --passes 50 + | --passes 40 | --cache_file ${cacheFile.getCanonicalPath} | --holdout_off | --learning_rate 5 From baf7466e5f5d9b8ba432765a4007d1bfa65b9312 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 7 Nov 2017 17:35:41 -0800 Subject: [PATCH 93/98] Addressing PR comments. Removed VW params from VW multi-label model code. --- .../VwDownsampledMultilabelRowCreator.scala | 3 +- .../multilabel/VwMultilabelRowCreator.scala | 12 ++-- .../aloha/factory/ri2jf/CollectionTypes.scala | 2 - .../models/multilabel/MultilabelModel.scala | 16 +++-- .../MultilabelModelParserPlugin.scala | 2 +- .../multilabel/json/MultilabelModelJson.scala | 9 ++- .../aloha/models/multilabel/package.scala | 2 - .../reflect/RuntimeClasspathScanning.scala | 5 +- .../aloha/ModelSerializabilityTestBase.scala | 13 +++- .../dataset/RowCreatorProducerTest.scala | 2 +- .../VwMultilabelRowCreatorTest.scala | 8 ++- .../multilabel/VwMultlabelJsonCreator.scala | 39 ++++++------ .../VwSparseMultilabelPredictor.scala | 60 ++++--------------- .../VwSparseMultilabelPredictorProducer.scala | 14 ++--- .../json/VwMultilabelModelJson.scala | 3 +- .../VwMultilabelModelPluginJsonReader.scala | 9 +-- .../VwMultilabelDownsampledModelTest.scala | 1 - .../multilabel/VwMultilabelModelTest.scala | 2 - .../VwSparseMultilabelPredictorTest.scala | 2 +- 19 files changed, 81 insertions(+), 123 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala index 10696c0e..8ae5b8d7 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala @@ -158,8 +158,7 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( // allLabelsInTrainingSet are discarded without notice. // // TODO: Should this be sci.BitSet? - val positiveIndices: Set[Int] = - positiveLabelsFunction(a).flatMap { y => labToInd.get(y).toSeq }(breakOut) + val positiveIndices: Set[Int] = positiveLabelsFunction(a).flatMap(labToInd.get)(breakOut) val (vwInput, newSeed) = sampledTrainingInput( features, diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala index 587395ab..784bbee6 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -62,9 +62,11 @@ extends RowCreator[A, Array[String]] { override def apply(a: A): (MissingAndErroneousFeatureInfo, Array[String]) = { val (missingAndErrs, features) = featuresFunction(a) + // Get the lazy val once. + val labToInd = labelToInd + // TODO: Should this be sci.BitSet? - val positiveIndices: Set[Int] = - positiveLabelsFunction(a).flatMap { y => labelToInd.get(y).toSeq }(breakOut) + val positiveIndices: Set[Int] = positiveLabelsFunction(a).flatMap(labToInd.get)(breakOut) val x: Array[String] = trainingInput( features, @@ -148,7 +150,7 @@ object VwMultilabelRowCreator extends Rand { * * The goal of this function is to try to use characters in literature used to denote a * dependent variable. If that isn't possible (because the characters are already used by some - * other namespace, just find the first possible characters. + * other namespace), just find the first possible characters. * @param usedNss names of namespaces used. * @return the namespace for ''actual'' label information then the namespace for ''dummy'' * label information. If two valid namespaces couldn't be produced, return None. @@ -166,7 +168,7 @@ object VwMultilabelRowCreator extends Rand { } private[multilabel] def nssToFirstCharBitSet(ss: Set[String]): sci.BitSet = - ss.collect { case s if s != "" => + ss.collect { case s if s.length != 0 => s.charAt(0).toInt }(breakOut[Set[String], Int, sci.BitSet]) @@ -237,7 +239,7 @@ object VwMultilabelRowCreator extends Rand { // Then come the features for each of the n labels. val x = new Array[String](n + 3) - val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) + val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, includeZeroValues = false) x(0) = SharedFeatureIndicator + shared // These string interpolations are computed over and over but will always be the same diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala index 8ae4f317..49dd61c7 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala @@ -27,8 +27,6 @@ class CollectionTypes extends RefInfoToJsonFormatConversions { None } else if (RefInfoOps.isSubType[A, collection.immutable.Seq[Any]]) -// conv(r.typeArguments.head).flatMap(f => jf(immSeqFormat(f))) -// conv(typeParams.head).flatMap(f => jf(immSeqFormat(f))) for { tEl <- typeParams.headOption el <- conv(tEl) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 17022ca8..1d3a9527 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -19,14 +19,6 @@ import spray.json.{JsonFormat, JsonReader} import scala.collection.{immutable => sci, mutable => scm} import scala.util.{Failure, Success} - -// TODO: When adding label-dep features, a Seq[GenAggFunc[K, Sparse]] will be needed. -// TODO: To create a Seq[GenAggFunc[K, Sparse]], a Semantics[K] needs to be derived from a Semantics[A]. -// TODO: MorphableSemantics provides this. If K is *embedded inside* A, it should be possible in some cases. -// TODO: An alternative is to pass a Map[K, Sparse], Map[K, Option[Sparse]], Map[K, Seq[Sparse]] or something. -// TODO: Directly passing the map of LDFs avoids the need to derive a Semantics[K]. This is easier to code. -// TODO: Directly passing LDFs would however be more burdensome to the data scientists. - /** * A multi-label predictor. * @@ -312,7 +304,11 @@ object MultilabelModel extends ParserProviderCompanion { } /** - * Get labels from the input for which a prediction should be produced. + * Get labels from the input for which a prediction should be produced. If + * `labelsOfInterest` produces a label not in the training set, it will not + * be present in the prediction output but it will appear in + * `LabelsAndInfo.labelsNotInTrainingSet`. + * * @param example the example provided to the model * @param labelsOfInterest a function used to extract labels for which a * prediction should be produced. @@ -330,6 +326,8 @@ object MultilabelModel extends ParserProviderCompanion { val labelsShouldPredict = labelsOfInterest(example) + // Notice here that if the label isn't present in the labeltoInd map, it is ignored + // and not inserted into the `unsorted` sequence. val unsorted = for { label <- labelsShouldPredict diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala index 5a7d15bf..cd2b4792 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala @@ -4,7 +4,7 @@ import com.eharmony.aloha.reflect.{RefInfo, RuntimeClasspathScanning} import spray.json.{JsonFormat, JsonReader} /** - * A plugin that will produce the + * A plugin that will ultimately produce the [[SparseMultiLabelPredictor]]. * Created by ryan.deak on 9/6/17. */ trait MultilabelModelParserPlugin { diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala index 5d7be2c0..dad7cfb2 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala @@ -14,14 +14,17 @@ trait MultilabelModelJson extends SpecJson with ScalaJsonFormats { protected[this] case class Plugin(`type`: String) /** - * Data for the + * AST for multi-label models. * * @param modelType * @param modelId * @param features * @param numMissingThreshold - * @param labelsInTrainingSet - * @param labelsOfInterest + * @param labelsInTrainingSet The sequence of all labels encountered in training. '''It is + * important''' that this is sequence (''with the same order as the + * labels in the training set''). This is because some algorithms + * may require indices based on the training data. + * @param labelsOfInterest a string representing a function that will be used to extract labels. * @param underlying the underlying model that will be produced by a * @tparam K */ diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala index e1cd9564..5796856b 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -52,8 +52,6 @@ package object multilabel { * * and returns a Map from the labels passed in, to the prediction associated with the label. * - * '''NOTE''': This is exposed as package private for testing. - * * @tparam K the type of labels (or classes in the machine learning literature). */ type SparseMultiLabelPredictor[K] = diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala index 43520318..629cb8ca 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala @@ -3,7 +3,7 @@ package com.eharmony.aloha.reflect import com.eharmony.aloha import org.reflections.Reflections -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import scala.util.Try /** @@ -44,8 +44,7 @@ trait RuntimeClasspathScanning { ): Seq[A] = { val reflections = new Reflections(aloha.pkgName) import scala.collection.JavaConversions.asScalaSet - // val classA = implicitly[ClassTag[A]].runtimeClass - val objects = reflections.getSubTypesOf(implicitly[ClassTag[OBJ]].runtimeClass).toSeq + val objects = reflections.getSubTypesOf(classTag[OBJ].runtimeClass).toSeq val suffixLength = objectSuffix.length diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala b/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala index bb94fc74..6334401a 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala @@ -9,13 +9,16 @@ import org.reflections.Reflections import scala.collection.JavaConversions.asScalaSet import scala.util.Try - import java.lang.reflect.{Method, Modifier} +import com.eharmony.aloha.util.Logging + /** * Created by ryan on 12/7/15. */ -abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[String]) { +abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[String]) +extends Logging { + def this() = this(pkgs = Seq(aloha.pkgName), Seq.empty) @Test def testSerialization(): Unit = { @@ -27,6 +30,12 @@ abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[S outFilters.exists(name.matches) } + debug { + modelClasses + .map(_.getCanonicalName) + .mkString("Models tested for Serializability:\n\t", "\n\t", "") + } + modelClasses.foreach { c => val m = for { testClass <- getTestClass(c.getCanonicalName) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala index 438a10e4..e7d894bf 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala @@ -52,4 +52,4 @@ object RowCreatorProducerTest { private val WhitelistedRowCreatorProducers = Set[Class[_]]( classOf[VwMultilabelRowCreator.Producer[_, _]] ) -} \ No newline at end of file +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala index 7817c15a..da56b63c 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala @@ -134,12 +134,16 @@ object VwMultilabelRowCreatorTest { private val PosDummyClass = NegDummyClass + 1 private val PosVal = 0 private val NegVal = 1 + private val PosFeature = 'P' + private val NegFeature = 'N' + private val DummyNs = 'y' + private val X = Map.empty[String, Any] private val SharedPrefix = "shared " private val DummyLabels = List( - s"$NegDummyClass:$NegVal |y N", - s"$PosDummyClass:$PosVal |y P" + s"$NegDummyClass:$NegVal |$DummyNs $NegFeature", + s"$PosDummyClass:$PosVal |$DummyNs $PosFeature" ) private val AllNegative = LabelsInTrainingSet.indices.map(i => s"$i:$NegVal |Y _$i") diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala index 47a8ce26..bc3e452c 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala @@ -43,8 +43,6 @@ extends MultilabelModelJson * which predictions can be made. If the labels at prediction time * differs (''or should be extracted from the input to the model''), * this function can provide that capability. - * @param vwArgs arguments that should be passed to the VW model. This likely isn't strictly - * necessary. * @param externalModel whether the underlying binary VW model should remain as a separate * file and be referenced by the Aloha model specification (`true`) * or the binary model content should be embeeded directly into the model @@ -57,22 +55,21 @@ extends MultilabelModelJson * @return a JSON object. */ def json[K: JsonWriter]( - datasetSpec: Vfs, - binaryVwModel: Vfs, - id: ModelIdentity, - labelsInTrainingSet: Seq[K], - labelsOfInterest: Option[String] = None, - vwArgs: Option[String] = None, - externalModel: Boolean = false, - numMissingThreshold: Option[Int] = None - ): JsValue = { + datasetSpec: Vfs, + binaryVwModel: Vfs, + id: ModelIdentity, + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String] = None, + externalModel: Boolean = false, + numMissingThreshold: Option[Int] = None + ): JsValue = { val dsJsAst = StringReadable.fromInputStream(datasetSpec.inputStream).parseJson val ds = dsJsAst.convertTo[VwMultilabeledJson] val features = modelFeatures(ds.features) val namespaces = ds.namespaces.map(modelNamespaces) val modelSrc = modelSource(binaryVwModel, externalModel) - val vw = vwModelPlugin(modelSrc, vwArgs, namespaces) + val vw = vwModelPlugin(modelSrc, namespaces) val model = modelAst(id, features, numMissingThreshold, labelsInTrainingSet, labelsOfInterest, vw) // Even though we are only *writing* JSON, we need a JsonFormat[K], (which is a reader @@ -103,24 +100,22 @@ extends MultilabelModelJson // Private b/c VwMultilabelAst is protected[this]. Don't let it escape. private def vwModelPlugin( - modelSrc: ModelSource, - vwArgs: Option[String], - namespaces: Option[ListMap[String, Seq[String]]]) = + modelSrc: ModelSource, + namespaces: Option[ListMap[String, Seq[String]]]) = VwMultilabelAst( VwSparseMultilabelPredictorProducer.multilabelPlugin.name, modelSrc, - vwArgs.map(a => Right(a)), namespaces ) // Private b/c MultilabelData is private[this]. Don't let it escape. private def modelAst[K]( - id: ModelIdentity, - features: ListMap[String, Spec], - numMissingThreshold: Option[Int], - labelsInTrainingSet: Seq[K], - labelsOfInterest: Option[String], - vwPlugin: VwMultilabelAst) = + id: ModelIdentity, + features: ListMap[String, Spec], + numMissingThreshold: Option[Int], + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String], + vwPlugin: VwMultilabelAst) = MultilabelData( modelType = MultilabelModel.parser.modelType, modelId = ModelId(id.getId(), id.getName()), diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala index 12504d1c..fe453e84 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -13,16 +13,14 @@ import scala.collection.{immutable => sci} import scala.util.Try /** - * - * Created by ryan.deak on 9/8/17. + * Creates a VW multi-label predictor plugin for `MultilabelModel`. * @param modelSource a specification for the underlying ''Cost Sensitive One Against All'' - * VW model with ''label dependent features''. VW flag - * `--csoaa_ldf mc` is expected. For more information, see the + * VW model with ''label dependent features''. VW flag `--csoaa_ldf mc` + * or `--wap_ldf mc` is expected. For more information, see the * [[https://github.com/JohnLangford/vowpal_wabbit/wiki/Cost-Sensitive-One-Against-All-(csoaa)-multi-class-example VW CSOAA wiki page]]. * Also see the ''Cost-Sensitive Multiclass Classification'' section of * Hal Daume's [[https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html On Multiclass Classification in VW]] * page. This model specification will be materialized in this class. - * @param params VW parameters. * @param defaultNs The list of indices into the `features` sequence that does not have * an exist in any value of the `namespaces` map. * @param namespaces Mapping from namespace name to indices in the `features` sequence passed @@ -30,13 +28,11 @@ import scala.util.Try * ''key-value'' pairs appearing in the map should not values that are empty * sequences. '''This is a requirement.''' * @tparam K the label or class type. + * @author deaktator + * @since 9/8/2017 */ -// TODO: Comment this function. It requires a lot of assumptions. Make those known. case class VwSparseMultilabelPredictor[K]( modelSource: ModelSource, - - // TODO: Should these be removed? I don't think so but could be, w/o harm, in limited cases. - params: String, defaultNs: List[Int], namespaces: List[(String, List[Int])], numLabelsInTrainingSet: Int) @@ -46,7 +42,7 @@ extends SparseMultiLabelPredictor[K] import VwSparseMultilabelPredictor._ @transient private[this] lazy val paramsAndVwModel = - createLearner(modelSource, params, numLabelsInTrainingSet) + createLearner(modelSource, numLabelsInTrainingSet) @transient private[this] lazy val updatedParams = paramsAndVwModel._1 @transient private[multilabel] lazy val vwModel = paramsAndVwModel._2.get @@ -130,52 +126,22 @@ object VwSparseMultilabelPredictor { @inline final private def modifiedLogistic(x: Float) = 1 / (1 + math.exp(x)) /** - * Update the parameters with the - * - * VW params of interest when doing multi-class: - * - - `--csoaa_ldf mc` Label-dependent features for multi-class classification - - `--csoaa_rank` (Probably) necessary to get scores for m-c classification. - - `--loss_function logistic` Standard logistic loss for learning. - - `--noconstant` Don't want a constant since it's not interacted with NS Y. - - `-q YX` Cross product of label-dependent and features and features - - `--ignore_linear Y` Don't care about the 1st-order wts of the label-dep features. - - `--ignore_linear X` Don't care about the 1st-order wts of the features. - - `--ignore y` Ignore everything related to the dummy class instances. - * - * {{{ - * val str = - * "shared |X feature" + "\n" + - * - * "0:1 |y _C0_" + "\n" + // These two instances are dummy classes - * "1:0 |y _C1_" + "\n" + - * - * "2:0 |Y _C2_" + "\n" + - * "3:1 |Y _C3_" - * - * val ex = str.split("\n") - * }}} - * @param modelSource location of a VW model. - * @param params the parameters passed to the model to which additional parameters will be added. - * @return + * Update parameters with initial regressior, ring size, testonly, and quiet. + * @param modelSource a trained VW binary model. + * @param numLabelsInTrainingSet number of labels in the training set informs VW's ring size. + * @return updated parameters. */ - // TODO: How much of the parameter setup is up to the caller versus this function? - private[multilabel] def paramsWithSource( - modelSource: File, - params: String, - numLabelsInTrainingSet: Int - ): String = { + private[multilabel] def paramsWithSource(modelSource: File, numLabelsInTrainingSet: Int): String = { val ringSize = numLabelsInTrainingSet + AddlVwRingSize - s"$params -i ${modelSource.getCanonicalPath} --ring_size $ringSize --testonly --quiet" + s"-i ${modelSource.getCanonicalPath} --ring_size $ringSize --testonly --quiet" } private[multilabel] def createLearner( modelSource: ModelSource, - params: String, numLabelsInTrainingSet: Int ): (String, Try[ExpectedLearner]) = { val modelFile = modelSource.localVfs.replicatedToLocal() - val updatedParams = paramsWithSource(modelFile.fileObj, params, numLabelsInTrainingSet) + val updatedParams = paramsWithSource(modelFile.fileObj, numLabelsInTrainingSet) val learner = Try { VWLearners.create[ExpectedLearner](updatedParams) } (updatedParams, learner) } diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala index 13d0a26b..0950499b 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala @@ -7,36 +7,32 @@ import com.eharmony.aloha.reflect.RefInfo import spray.json.{JsonFormat, JsonReader} /** - * A thing wrapper responsible for creating a [[VwSparseMultilabelPredictor]]. This - * creation is deferred because VW JNI models are not Serializable because they are + * A wrapper responsible for creating a [[VwSparseMultilabelPredictor]]. This defers + * creation since VW JNI models are not Serializable. This is because they are * thin wrappers around native C models in memory. The underlying binary model is * externalizable in a file, byte array, etc. and this information is contained in * `modelSource`. The actual VW JNI model is created only where needed. In short, * this class can be serialized but the JNI model created from `modelSource` cannot * be. * - * Created by ryan.deak on 9/5/17. - * * @param modelSource a source from which the binary VW model information can be * extracted and used to create a VW JNI model. - * @param params VW parameters passed to the JNI constructor * @param defaultNs indices into the features that belong to VW's default namespace. * @param namespaces namespace name to feature indices map. * @tparam K type of the labels returned by the [[VwSparseMultilabelPredictor]] that * will be produced. + * @author deaktator + * @since 9/5/2017 */ case class VwSparseMultilabelPredictorProducer[K]( modelSource: ModelSource, - - // TODO: Should we remove this. If not, it must contain the --ring_size [training labels + 10]. - params: String, defaultNs: List[Int], namespaces: List[(String, List[Int])], labelNamespace: Char, numLabelsInTrainingSet: Int ) extends SparsePredictorProducer[K] { override def apply(): VwSparseMultilabelPredictor[K] = - VwSparseMultilabelPredictor[K](modelSource, params, defaultNs, namespaces, numLabelsInTrainingSet) + VwSparseMultilabelPredictor[K](modelSource, defaultNs, namespaces, numLabelsInTrainingSet) } object VwSparseMultilabelPredictorProducer extends MultilabelPluginProviderCompanion { diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala index 63ccbcc8..bc41d6f5 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala @@ -15,9 +15,8 @@ trait VwMultilabelModelJson extends ScalaJsonFormats { private[multilabel] case class VwMultilabelAst( `type`: String, modelSource: ModelSource, - params: Option[Either[Seq[String], String]] = Option(Right("")), namespaces: Option[ListMap[String, Seq[String]]] = Some(ListMap.empty) ) - protected[this] implicit val vwMultilabelAstFormat = jsonFormat4(VwMultilabelAst) + protected[this] implicit val vwMultilabelAstFormat = jsonFormat3(VwMultilabelAst) } diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala index 418cc35d..ed3bc824 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala @@ -11,7 +11,7 @@ import scala.collection.breakOut import scala.collection.immutable.ListMap /** - * + * A JSON reader for SparsePredictorProducer. * * '''NOTE''': This class extends `JsonReader[SparsePredictorProducer[K]]` rather than the * more specific type `JsonReader[VwSparseMultilabelPredictorProducer[K]]` because `JsonReader` @@ -32,7 +32,6 @@ case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String], numLa override def read(json: JsValue): VwSparseMultilabelPredictorProducer[K] = { val ast = json.asJsObject(notObjErr(json)).convertTo[VwMultilabelAst] - val params = vwParams(ast.params) val (namespaces, defaultNs, missing) = allNamespaceIndices(featureNames, ast.namespaces.getOrElse(ListMap.empty)) @@ -44,8 +43,7 @@ case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String], numLa labelAndDummyLabelNss match { case Some(LabelNamespaces(labelNs, _)) => - // TODO: Should we remove this. If not, it must contain the --ring_size [training labels + 10]. - VwSparseMultilabelPredictorProducer[K](ast.modelSource, params, defaultNs, namespaces, labelNs, numLabelsInTrainingSet) + VwSparseMultilabelPredictorProducer[K](ast.modelSource, defaultNs, namespaces, labelNs, numLabelsInTrainingSet) case _ => throw new DeserializationException( "Could not determine label namespace. Found namespaces: " + @@ -58,9 +56,6 @@ case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String], numLa object VwMultilabelModelPluginJsonReader extends Logging { private val JsonErrStrLength = 100 - private[multilabel] def vwParams(params: Option[Either[Seq[String], String]]): String = - params.fold("")(e => e.fold(ps => ps.mkString(" "), identity[String])).trim - private[multilabel] def notObjErr(json: JsValue): String = { val str = json.prettyPrint val substr = str.substring(0, JsonErrStrLength) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala index f7530a97..6b4a87ab 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala @@ -151,7 +151,6 @@ class VwMultilabelDownsampledModelTest { id = ModelId(1, "NONE"), labelsInTrainingSet = labelsInTrainingSet, labelsOfInterest = Option.empty[String], - vwArgs = Option.empty[String], externalModel = false, numMissingThreshold = Option(0) ) diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 6566fbeb..43b2ee59 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -175,7 +175,6 @@ class VwMultilabelModelTest { id = ModelId(1, "NONE"), labelsInTrainingSet = labelsInTrainingSet, labelsOfInterest = Option.empty[String], - vwArgs = Option.empty[String], externalModel = false, numMissingThreshold = Option(0) ) @@ -335,7 +334,6 @@ object VwMultilabelModelTest { val predProd = VwSparseMultilabelPredictorProducer[Label]( modelSource = TrainedModel, - params = "", // to see the output: "-p /dev/stdout", defaultNs = List.empty[Int], namespaces = namespaces, labelNamespace = labelNs, diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala index 38fc47ba..2568a47c 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala @@ -55,7 +55,7 @@ object VwSparseMultilabelPredictorTest { } private def getPredictor(modelSrc: ModelSource, numLabelsInTrainingSet: Int) = - VwSparseMultilabelPredictor[Any](modelSrc, "", Nil, Nil, numLabelsInTrainingSet) + VwSparseMultilabelPredictor[Any](modelSrc, Nil, Nil, numLabelsInTrainingSet) private def checkVwBinFile(vwBinFilePath: String): Unit = { val vwBinFile = new File(vwBinFilePath) From 22d95478f83e80c24c39f473649c6b889c0e217e Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 7 Nov 2017 17:43:55 -0800 Subject: [PATCH 94/98] Downsampling can now operate over 2^31 - 1 (2 billion) labels. --- .../VwDownsampledMultilabelRowCreator.scala | 21 ++++++++++++++----- .../json/VwDownsampledMultilabeledJson.scala | 15 ++++++++----- .../com/eharmony/aloha/util/rand/Rand.scala | 4 ++-- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala index 8ae5b8d7..46a2c407 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala @@ -54,7 +54,18 @@ import scala.util.Try * labels to include in each row. If this is less than the * number of negative examples for a given row, then no * downsampling of negatives will take place. - * @param initialSeed a way to start off randomness + * @param seedCreator a "''function''" that creates a seed that will be used for randomness. + * The implementation of this function is important. It should create a + * unique value for each unit of parallelism. If for example, row + * creation is parallelized across multiple threads on one machine, the + * unit of parallelism is threads and `seedCreator` should produce unique + * values for each thread. If row creation is parallelized across multiple + * machines, the `seedCreator` should produce a unique value for each + * machine. If row creation is parallelized across machines and threads on + * each machine, the `seedCreator` should create unique values for each + * thread on each machine. Otherwise, randomness will be striped which + + * * @param includeZeroValues include zero values in VW input? * @tparam A the input type * @tparam K the label or class type @@ -70,8 +81,8 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], classNs: Char, dummyClassNs: Char, - numDownsampledNegLabels: Short, - initialSeed: () => Long, + numDownsampledNegLabels: Int, + seedCreator: () => Long, includeZeroValues: Boolean = false ) extends StatefulRowCreator[A, Array[String], Long] with Logging { @@ -99,7 +110,7 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( */ @transient override lazy val initialState: Long = { - val seed = initialSeed() + val seed = seedCreator() // For logging. Try to get time as close as possible to calling initialSeed. // Note: There's a very good chance this will differ. @@ -221,7 +232,7 @@ object VwDownsampledMultilabelRowCreator extends Rand { negativeDummyStr: String, positiveDummyStr: String, seed: Long, - numNegLabelsTarget: Short + numNegLabelsTarget: Int ): (Array[String], Long) = { // Partition into positive and negative indices. diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala index 145ec591..08b812ce 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala @@ -7,14 +7,19 @@ import spray.json.{DefaultJsonProtocol, RootJsonFormat} import scala.collection.{immutable => sci} /** - * Created by ryan.deak on 11/6/17. - * + * JSON AST for `VwDownsampledMultilabelRowCreator`. * @param imports * @param features * @param namespaces * @param normalizeFeatures - * @param positiveLabels - * @param numDownsampledNegLabels + * @param positiveLabels string representing a function that will be used to extract positive + * labels from the input. + * @param numDownsampledNegLabels '''a positive value''' representing the number of negative + * labels to include in each row. If this is less than the + * number of negative examples for a given row, then no + * downsampling of negatives will take place. + * @author deaktator + * @since 11/6/2017 */ final case class VwDownsampledMultilabeledJson( imports: sci.Seq[String], @@ -22,7 +27,7 @@ final case class VwDownsampledMultilabeledJson( namespaces: Option[Seq[Namespace]] = Some(Nil), normalizeFeatures: Option[Boolean] = Some(false), positiveLabels: String, - numDownsampledNegLabels: Short + numDownsampledNegLabels: Int ) extends VwJsonLike { require( diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala index 1a9d194a..39a9116d 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala @@ -55,8 +55,8 @@ private[aloha] trait Rand { - [[http://grepcode.com/file/repository.grepcode.com/java/root/jdk/openjdk/6-b14/java/util/Random.java grepcode for Random]] - [[https://en.wikipedia.org/wiki/Linear_congruential_generator Linear congruential generator (Wikipedia)]] * - * @param n population size (< 2^31^). - * @param k combination size (< 2^15^). + * @param n population size + * @param k combination size * @param seed the seed to use for random selection * @return a tuple 2 containing the array of 0-based indices representing * the ''k''-combination and a new random seed. From c463c11340c2f826d6cf39c0dc938e91c0e15a66 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 7 Nov 2017 18:33:48 -0800 Subject: [PATCH 95/98] changed Iterator.isEmpty to hasNext. Updated docs. --- .../scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala index 399715c7..8ccd0202 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -37,7 +37,7 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten * previous application of `apply(A, S)`. * * ''This variant of mapping with state is'' '''non-strict''', so if that's a requirement, - * prefer this function over the `mapSeqWithState` variant. Note that first this method + * prefer this function over the `mapSeqWithState` variant. Note that for this method * to work, the first element is computed eagerly. * * To verify non-strictness, this method could be rewritten as: @@ -83,7 +83,7 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten * state that is created in the process. */ def mapIteratorWithState(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = { - if (as.isEmpty) + if (as.hasNext) Iterator.empty else { // Force the first A. Then apply the `apply` transformation to get From 79fddb69b5506d23603200d94239c827ad3bfb55 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 7 Nov 2017 19:01:41 -0800 Subject: [PATCH 96/98] forgot logical not in if statement. --- .../scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala index 8ccd0202..b76d93d3 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -83,7 +83,7 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten * state that is created in the process. */ def mapIteratorWithState(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = { - if (as.hasNext) + if (!as.hasNext) Iterator.empty else { // Force the first A. Then apply the `apply` transformation to get From 1c0102c385efac191fb38bcf62c986efd1360268 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Tue, 7 Nov 2017 23:16:14 -0800 Subject: [PATCH 97/98] removed toShort from Rand. --- .../src/main/scala/com/eharmony/aloha/util/rand/Rand.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala index 39a9116d..d6038922 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala @@ -109,7 +109,7 @@ private[aloha] trait Rand { // Wikipedia does a good job describing this in the proof by induction in the // explanation for Algorithm R. if (reservoirSwapInd < k) - reservoir(reservoirSwapInd) = (i - 1).toShort + reservoir(reservoirSwapInd) = i - 1 i += 1 } From 6e21f8c09969a740398dec93a1c7a88281433cf4 Mon Sep 17 00:00:00 2001 From: Ryan Deak Date: Wed, 8 Nov 2017 13:40:06 -0800 Subject: [PATCH 98/98] addressing PR comments. Changed name of multilabel model to 'SparseMultilabel'. Generalized StatefulRowCreator. --- .../eharmony/aloha/cli/ModelTypesTest.scala | 4 +- .../aloha/dataset/StatefulRowCreator.scala | 81 ++++------------ .../VwDownsampledMultilabelRowCreator.scala | 10 +- .../models/multilabel/MultilabelModel.scala | 2 +- .../eharmony/aloha/util/StatefulMapOps.scala | 93 +++++++++++++++++++ .../VwMultilabelDownsampledModelTest.scala | 2 +- .../multilabel/VwMultilabelModelTest.scala | 2 +- .../VwMultilabelParamAugmentationTest.scala | 2 +- 8 files changed, 121 insertions(+), 75 deletions(-) create mode 100644 aloha-core/src/main/scala/com/eharmony/aloha/util/StatefulMapOps.scala diff --git a/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala b/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala index d6617048..8f8e7811 100644 --- a/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala +++ b/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala @@ -28,8 +28,8 @@ class ModelTypesTest { "ModelDecisionTree", "Regression", "Segmentation", - "VwJNI", - "multilabel-sparse" + "SparseMultilabel", + "VwJNI" ) val actual = ModelFactory.defaultFactory(null, null).parsers.map(_.modelType).sorted diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala index b76d93d3..3bb3d5c2 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -1,12 +1,17 @@ package com.eharmony.aloha.dataset +import com.eharmony.aloha.util.StatefulMapOps + +import scala.collection.{SeqLike, immutable => sci} +import scala.collection.generic.{CanBuildFrom => CBF} + /** * A row creator that requires state. This state should be modeled functionally, meaning * implementations should be referentially transparent. * * Created by ryan.deak on 11/2/17. */ -trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] extends Serializable { +trait StatefulRowCreator[-A, +B, S] extends Serializable { /** * Some initial state that can be used on the very first call to `apply(A, S)`. @@ -36,44 +41,7 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten * applications, the state will come from the state generated in the output of the * previous application of `apply(A, S)`. * - * ''This variant of mapping with state is'' '''non-strict''', so if that's a requirement, - * prefer this function over the `mapSeqWithState` variant. Note that for this method - * to work, the first element is computed eagerly. - * - * To verify non-strictness, this method could be rewritten as: - * - * {{{ - * def statefulMap[A, B, S](as: Iterator[A], s: S) - * (f: (A, S) => (B, S)): Iterator[(B, S)] = { - * if (as.isEmpty) - * Iterator.empty - * else { - * val firstA = as.next() - * val initEl = f(firstA, s) - * as.scanLeft(initEl){ case ((_, newS), a) => f(a, newS) } - * } - * } - * }}} - * - * Then using the method, it's easy to verify non-strictness: - * - * {{{ - * // cycleModulo4 and res are infinite. - * val cycleModulo4 = Iterator.iterate(0)(s => (s + 1) % 4) - * val res = statefulMap(cycleModulo4, 0)((a, s) => ((a + s).toDouble, s + 1)) - * - * // returns: - * res.take(8) - * .map(_._1) - * .toVector - * .groupBy(identity) - * .mapValues(_.size) - * .toVector - * .sorted - * .foreach(println) - * - * res.size // never returns - * }}} + * For more information, see [[com.eharmony.aloha.util.StatefulMapOps]] * * @param as Note the first element of `as` ''will be forced'' in this method in order * to construct the output. @@ -82,19 +50,8 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten * `(MissingAndErroneousFeatureInfo, Option[B])` along with the resulting * state that is created in the process. */ - def mapIteratorWithState(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = { - if (!as.hasNext) - Iterator.empty - else { - // Force the first A. Then apply the `apply` transformation to get - // the initial element of a scanLeft. Inside the scanLeft, use the - // state outputted by previous `apply` calls as input to current - // calls to `apply(A, S)`. - val firstA = as.next() - val initEl = apply(firstA, state) - as.scanLeft(initEl){ case ((_, mostRecentState), a) => apply(a, mostRecentState) } - } - } + def statefulMap(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = + StatefulMapOps.statefulMap(as, state)(apply) /** * Apply the `apply(A, S)` method to the elements of the sequence. In the first @@ -102,22 +59,18 @@ trait StatefulRowCreator[-A, +B, @specialized(Int, Float, Long, Double) S] exten * applications, the state will come from the state generated in the output of the * previous application of `apply(A, S)`. * - * ''This variant of mapping with state is'' '''strict'''. - * * '''NOTE''': This method isn't really parallelizable via chunking. The way to * parallelize this method is to provide a separate starting state for each unit * of parallelism. * + * For more information, see [[com.eharmony.aloha.util.StatefulMapOps]] + * * @param as input to map. - * @param state the initial state to use at the start of the Vector. - * @return a Tuple2 where the first element is a vector of results and the second - * element is the resulting state. + * @param state the initial state to use at the start of mapping. + * @param cbf object responsible for building the output collection. + * @return */ - def mapSeqWithState(as: Seq[A], state: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S) = { - as.foldLeft((Vector.empty[(MissingAndErroneousFeatureInfo, Option[B])], state)){ - case ((bs, s), a) => - val (b, newS) = apply(a, s) - (bs :+ b, newS) - } - } + def statefulMap[In <: sci.Seq[A], Out](as: SeqLike[A, In], state: S)(implicit + cbf: CBF[In, ((MissingAndErroneousFeatureInfo, Option[B]), S), Out] + ): Out = StatefulMapOps.statefulMap(as, state)(apply) } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala index 46a2c407..000a909a 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala @@ -64,8 +64,7 @@ import scala.util.Try * machine. If row creation is parallelized across machines and threads on * each machine, the `seedCreator` should create unique values for each * thread on each machine. Otherwise, randomness will be striped which - - * + * is bad. * @param includeZeroValues include zero values in VW input? * @tparam A the input type * @tparam K the label or class type @@ -160,7 +159,7 @@ final case class VwDownsampledMultilabelRowCreator[-A, K]( // array. Next allLabelsInTrainingSet.indices and the positiveIndices array could // be iterated over simultaneously and if the current index in // allLabelsInTrainingSet.indices isn't in the positiveIndices array, then it must - // be in the negativeIndices array. Then then offsets into the two arrays are + // be in the negativeIndices array. Then the offsets into the two arrays are // incremented. Both the current and proposed algorithms are O(N), where // N = allLabelsInTrainingSet.indices.size. But the proposed algorithm would likely // have much better constant factors. @@ -258,7 +257,7 @@ object VwDownsampledMultilabelRowCreator extends Rand { else { // Determine the weight for the downsampled negatives. // If the cost of negative examples is positive, then the weight will be - // strictly greater than . + // strictly greater than NegativeCost. f"${NegativeCost * negLabels / n.toDouble}%.5g" } @@ -320,7 +319,8 @@ object VwDownsampledMultilabelRowCreator extends Rand { * machines, the `seedCreator` should produce a unique value for each * machine. If row creation is parallelized across machines and threads on * each machine, the `seedCreator` should create unique values for each - * thread on each machine. Otherwise, randomness will be striped which + * thread on each machine. Otherwise, randomness will be striped which is + * bad. * @param ev$1 reflection information about `K`. * @tparam A type of input passed to the [[StatefulRowCreator]]. * @tparam K the label type. diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala index 1d3a9527..2397e9e3 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -353,7 +353,7 @@ object MultilabelModel extends ParserProviderCompanion { override def parser: ModelParser = Parser object Parser extends ModelSubmodelParsingPlugin with Logging { - override val modelType: String = "multilabel-sparse" + override val modelType: String = "SparseMultilabel" // TODO: Figure if a Option[JsonReader[MultilabelModel[U, _, A, B]]] can be returned. // See: parser that returns SegmentationModel[U, _, N, A, B] diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/StatefulMapOps.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/StatefulMapOps.scala new file mode 100644 index 00000000..b240364f --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/StatefulMapOps.scala @@ -0,0 +1,93 @@ +package com.eharmony.aloha.util + +import scala.collection.{breakOut, SeqLike, immutable => sci} +import scala.collection.generic.{CanBuildFrom => CBF} + +/** + * Some ways to map with state without needing extra libraries like ''cats'' or ''scalaz''. + * @author deaktator + * @since 11/8/2017 + */ +private[aloha] object StatefulMapOps { + + /** + * Map over `as` while maintaining a running state to produce `(B, S)` pairs contained + * within some container type. + * + * Note that the input container type `In` determines the strictness of this method. + * For instance, this can produce values ''strictly'' with a strict sequence like + * `Vector` or ''non-strictly'' for a non-strict sequence like a `Stream`. + * For instance: + * + * {{{ + * val sCycleModulo4 = Stream.iterate(0)(s => (s + 1) % 4) + * val vCycleModulo4 = Vector.iterate(0, 10)(s => (s + 1) % 4) + * val addAndIncrement = (a: Int, s: Int) => ((a + s).toDouble, s + 1) + * + * + * // Non-strict, so it returns even though it's an infinite stream. O(1). + * // type: Stream[(Double, Int)] + * val stream = statefulMap(sCycleModulo4, 0)(addAndIncrement) + * + * // Strict, so it's runtime is O(N), where N = vCycleModulo4.size. + * import scala.collection.breakOut + * val list: List[(Double, Int)] = + * statefulMap(vCycleModulo4, 0)(addAndIncrement)(breakOut) + * }}} + * + * @param as elements to map (given some state that may change). '''NOTE''': the first + * element of `as` ''will be forced'' in this method in order to construct the + * output. + * @param startState the initial state. + * @param f a function that maps a value and a state and to an output and a new state. + * @param cbf object responsible for building the output. + * @tparam A the input element type + * @tparam B the output element type + * @tparam S the state type + * @tparam In a concrete implementation of the input sequence. + * @tparam Out the output container type (including type parameters). + * @return an output container with the `A`s mapped to `B`s and the state resulting + * from applying `f` and getting the second element (''e.g.'': `f(a, s)._2`). + */ + def statefulMap[A, B, S, In <: sci.Seq[A], Out](as: SeqLike[A, In], startState: S) + (f: (A, S) => (B, S)) + (implicit cbf: CBF[In, (B, S), Out]): Out = { + if (as.isEmpty) + cbf().result() + else { + val initEl = f(as.head, startState) + as.tail.scanLeft(initEl){ case ((_, newState), a) => f(a, newState) }(breakOut) + } + } + + + /** + * Map over `as` while maintaining a running state to produce `(B, S)` pairs contained + * within some container type. + * + * {{{ + * val sCycleModulo4 = Iterator.iterate(0)(s => (s + 1) % 4) + * val addAndIncrement = (a: Int, s: Int) => ((a + s).toDouble, s + 1) + * + * val stream = statefulMap(sCycleModulo4, 0)(addAndIncrement) + * }}} + * @param as elements to map (given some state that may change). '''NOTE''': the first + * element of `as` ''will be forced'' in this method in order to construct the + * output. + * @param startState the initial state. + * @param f a function that maps a value and a state and to an output and a new state. + * @tparam A the input element type + * @tparam B the output element type + * @tparam S the state type + * @return an iterator with the `A`s mapped to `B`s and the state resulting + * from applying `f` and getting the second element (''e.g.'': `f(a, s)._2`). + */ + def statefulMap[A, B, S](as: Iterator[A], startState: S)(f: (A, S) => (B, S)): Iterator[(B, S)] = { + if (!as.hasNext) + Iterator.empty + else { + val initEl = f(as.next(), startState) + as.scanLeft(initEl){ case ((_, newState), a) => f(a, newState) } + } + } +} diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala index 6b4a87ab..44900c9e 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala @@ -128,7 +128,7 @@ class VwMultilabelDownsampledModelTest { // Get the iterator of the examples produced. This is similar to what one may do // within a `mapPartitions` in Spark. - val examples = rc.mapIteratorWithState(trainingSet.iterator, rc.initialState) collect { + val examples = rc.statefulMap(trainingSet.iterator, rc.initialState) collect { case ((_, Some(x)), _) => x } diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala index 43b2ee59..cf003042 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -459,7 +459,7 @@ object VwMultilabelModelTest { val json = s""" |{ - | "modelType": "multilabel-sparse", + | "modelType": "SparseMultilabel", | "modelId": { "id": 1, "name": "NONE" }, | "features": { | "feature": "1" diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala index 900881d2..0ed7d478 100644 --- a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -287,4 +287,4 @@ object VwMultilabelParamAugmentationTest { val DefaultNumLabels = 0 val DefaultRingSize = s"--ring_size ${DefaultNumLabels + VwSparseMultilabelPredictor.AddlVwRingSize}" -} \ No newline at end of file +}