diff --git a/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala b/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala index e6b8fd45..8f8e7811 100644 --- a/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala +++ b/aloha-cli/src/test/scala/com/eharmony/aloha/cli/ModelTypesTest.scala @@ -28,6 +28,7 @@ class ModelTypesTest { "ModelDecisionTree", "Regression", "Segmentation", + "SparseMultilabel", "VwJNI" ) diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala new file mode 100644 index 00000000..3bb3d5c2 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreator.scala @@ -0,0 +1,76 @@ +package com.eharmony.aloha.dataset + +import com.eharmony.aloha.util.StatefulMapOps + +import scala.collection.{SeqLike, immutable => sci} +import scala.collection.generic.{CanBuildFrom => CBF} + +/** + * A row creator that requires state. This state should be modeled functionally, meaning + * implementations should be referentially transparent. + * + * Created by ryan.deak on 11/2/17. + */ +trait StatefulRowCreator[-A, +B, S] extends Serializable { + + /** + * Some initial state that can be used on the very first call to `apply(A, S)`. + * @return some state. + */ + val initialState: S + + /** + * Given an `a` and some `state`, produce output, including a new state. + * + * When using this function, the user is responsible for keeping track of, + * and providing the state. + * + * The implementation of this function should be referentially transparent. + * + * @param a input + * @param state the state + * @return a tuple where the first element is a Tuple2 whose first element is + * missing and error information and second element is an optional result. + * The second element of the outer Tuple2 is the new state. + */ + def apply(a: A, state: S): ((MissingAndErroneousFeatureInfo, Option[B]), S) + + /** + * Apply the `apply(A, S)` method to the elements of the iterator. In the first + * application of `apply(A, S)`, `state` will be used as the state. In subsequent + * applications, the state will come from the state generated in the output of the + * previous application of `apply(A, S)`. + * + * For more information, see [[com.eharmony.aloha.util.StatefulMapOps]] + * + * @param as Note the first element of `as` ''will be forced'' in this method in order + * to construct the output. + * @param state the initial state to use at the start of the iterator. + * @return an iterator containing the `a` mapped to a + * `(MissingAndErroneousFeatureInfo, Option[B])` along with the resulting + * state that is created in the process. + */ + def statefulMap(as: Iterator[A], state: S): Iterator[((MissingAndErroneousFeatureInfo, Option[B]), S)] = + StatefulMapOps.statefulMap(as, state)(apply) + + /** + * Apply the `apply(A, S)` method to the elements of the sequence. In the first + * application of `apply(A, S)`, `state` will be used as the state. In subsequent + * applications, the state will come from the state generated in the output of the + * previous application of `apply(A, S)`. + * + * '''NOTE''': This method isn't really parallelizable via chunking. The way to + * parallelize this method is to provide a separate starting state for each unit + * of parallelism. + * + * For more information, see [[com.eharmony.aloha.util.StatefulMapOps]] + * + * @param as input to map. 
+ * @param state the initial state to use at the start of mapping. + * @param cbf object responsible for building the output collection. + * @return + */ + def statefulMap[In <: sci.Seq[A], Out](as: SeqLike[A, In], state: S)(implicit + cbf: CBF[In, ((MissingAndErroneousFeatureInfo, Option[B]), S), Out] + ): Out = StatefulMapOps.statefulMap(as, state)(apply) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreatorProducer.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreatorProducer.scala new file mode 100644 index 00000000..4a5ce8a0 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/StatefulRowCreatorProducer.scala @@ -0,0 +1,43 @@ +package com.eharmony.aloha.dataset + +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import spray.json.JsValue + +import scala.util.Try + +/** + * Created by deaktator on 11/6/17. + * + * @tparam A + * @tparam B + * @tparam S + * @tparam Impl + */ +trait StatefulRowCreatorProducer[A, +B, S, +Impl <: StatefulRowCreator[A, B, S]] { + + /** + * Type of parsed JSON object. + */ + type JsonType + + /** + * Name of this producer. + * @return + */ + def name: String + + /** + * Attempt to parse the JSON AST to an intermediate representation that is used + * @param json + * @return + */ + def parse(json: JsValue): Try[JsonType] + + /** + * Attempt to produce a Spec. + * @param semantics semantics used to make sense of the features in the JsonSpec + * @param jsonSpec a JSON specification to transform into a StatefulRowCreator. + * @return + */ + def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: JsonType): Try[Impl] +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/PositiveLabelsFunction.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/PositiveLabelsFunction.scala new file mode 100644 index 00000000..caf577a4 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/PositiveLabelsFunction.scala @@ -0,0 +1,42 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.AlohaException +import com.eharmony.aloha.dataset.DvProducer +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.{determineLabelNamespaces, LabelNamespaces} +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import com.eharmony.aloha.semantics.func.GenAggFunc + +import scala.collection.breakOut +import scala.util.{Failure, Success, Try} +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 11/6/17. + * @param ev$1 + * @tparam A + * @tparam K + */ +private[multilabel] abstract class PositiveLabelsFunction[A, K: RefInfo] { self: DvProducer => + + private[multilabel] def positiveLabelsFn( + semantics: CompiledSemantics[A], + positiveLabels: String + ): Try[GenAggFunc[A, sci.IndexedSeq[K]]] = + getDv[A, sci.IndexedSeq[K]]( + semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K])) + + private[multilabel] def labelNamespaces(nss: List[(String, List[Int])]): Try[LabelNamespaces] = { + val nsNames: Set[String] = nss.map(_._1)(breakOut) + determineLabelNamespaces(nsNames) match { + case Some(ns) => Success(ns) + + // If there are so many VW namespaces that all available Unicode characters are taken, + // then a memory error will probably already have occurred. + case None => Failure(new AlohaException( + "Could not find any Unicode characters to as VW namespaces. 
Namespaces provided: " + + nsNames.mkString(", ") + )) + } + } +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala new file mode 100644 index 00000000..000a909a --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwDownsampledMultilabelRowCreator.scala @@ -0,0 +1,378 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.dataset._ +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.dataset.vw.VwCovariateProducer +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator._ +import com.eharmony.aloha.dataset.vw.multilabel.json.VwDownsampledMultilabeledJson +import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.util.Logging +import com.eharmony.aloha.util.rand.Rand +import spray.json.JsValue + +import scala.collection.{breakOut, immutable => sci} +import scala.util.Try + + +/** + * Creates training data for multilabel models in Vowpal Wabbit's CSOAA LDF and WAP LDF format + * for the JNI. In this row creator, negative labels are downsampled and costs for the + * downsampled labels are adjusted to produced an unbiased estimator. It is assumed that + * negative labels are in the majority. Downsampling negatives can improve both training + * time and possibly model performance. See the following resources for intuition: + * + - [[https://www3.nd.edu/~nchawla/papers/SPRINGER05.pdf Chawla, Nitesh V. "Data mining for + imbalanced datasets: An overview." Data mining and knowledge discovery handbook. + Springer US, 2009. 875-886.]] + - [[https://www3.nd.edu/~dial/publications/chawla2004editorial.pdf Chawla, Nitesh V., Nathalie + Japkowicz, and Aleksander Kotcz. "Special issue on learning from imbalanced data sets." + ACM SIGKDD Explorations Newsletter 6.1 (2004): 1-6.]] + - [[http://www.marcoaltini.com/blog/dealing-with-imbalanced-data-undersampling-oversampling-and-proper-cross-validation + Dealing with imbalanced data: undersampling, oversampling, and proper cross validation, + Marco Altini, Aug 17, 2015.]] + * + * This row creator, since it is stateful, requires the caller to maintain state. If however, + * it is only called via an iterator or sequence, then this row creator can maintain the state + * during iteration over the iterator or sequence. In the case of iterators, the mapping is + * '''non-strict''' and in the case of sequences (`Seq`), it is '''strict'''. + * + * @param allLabelsInTrainingSet all labels in the training set. This is a sequence because + * order matters. Order here can be chosen arbitrarily, but it + * must be consistent in the training and test formulation. + * @param featuresFunction features to extract from the data of type `A`. + * @param defaultNamespace list of feature indices in the default VW namespace. + * @param namespaces a mapping from VW namespace name to feature indices in that namespace. + * @param normalizer can modify VW output (currently unused) + * @param positiveLabelsFunction A method that can extract positive class labels. + * @param classNs the namespace name for class information. + * @param dummyClassNs the namespace name for dummy class information. 
2 dummy classes are + * added to make the predicted probabilities work. + * @param numDownsampledNegLabels '''a positive value''' representing the number of negative + * labels to include in each row. If this is less than the + * number of negative examples for a given row, then no + * downsampling of negatives will take place. + * @param seedCreator a "''function''" that creates a seed that will be used for randomness. + * The implementation of this function is important. It should create a + * unique value for each unit of parallelism. If for example, row + * creation is parallelized across multiple threads on one machine, the + * unit of parallelism is threads and `seedCreator` should produce unique + * values for each thread. If row creation is parallelized across multiple + * machines, the `seedCreator` should produce a unique value for each + * machine. If row creation is parallelized across machines and threads on + * each machine, the `seedCreator` should create unique values for each + * thread on each machine. Otherwise, randomness will be striped which + * is bad. + * @param includeZeroValues include zero values in VW input? + * @tparam A the input type + * @tparam K the label or class type + * @author deaktator + * @since 11/6/2017 + */ +final case class VwDownsampledMultilabelRowCreator[-A, K]( + allLabelsInTrainingSet: sci.IndexedSeq[K], + featuresFunction: FeatureExtractorFunction[A, Sparse], + defaultNamespace: List[Int], + namespaces: List[(String, List[Int])], + normalizer: Option[CharSequence => CharSequence], + positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], + classNs: Char, + dummyClassNs: Char, + numDownsampledNegLabels: Int, + seedCreator: () => Long, + includeZeroValues: Boolean = false +) extends StatefulRowCreator[A, Array[String], Long] + with Logging { + + require( + 0 < numDownsampledNegLabels, + s"numDownsampledNegLabels must be positive, found $numDownsampledNegLabels" + ) + + import VwDownsampledMultilabelRowCreator._ + + @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap + + // Precomputed for efficiency. + + private[this] val negativeDummyStr = + s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature" + + private[this] val positiveDummyStr = + s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature" + + /** + * Some initial state that can be used on the very first call to `apply(A, S)`. + * @return some state. + */ + @transient override lazy val initialState: Long = { + + val seed = seedCreator() + + // For logging. Try to get time as close as possible to calling initialSeed. + // Note: There's a very good chance this will differ. + val time = System.nanoTime() + val ip = java.net.InetAddress.getLocalHost.getHostAddress + val thread = Thread.currentThread() + val threadId = thread.getId + val threadName = thread.getName + + val scrambled = scramble(seed) + + info( + s"${getClass.getSimpleName} seed: $seed, scrambled: $scrambled, " + + s"nanotime: $time, ip: $ip, threadId: $threadId, threadName: $threadName" + ) + + scrambled + } + + /** + * Given an `a` and some `seed`, produce output, including a new seed. + * + * When using this function, the user is responsible for keeping track of, + * and providing the seeds. + * + * The implementation of this function should be referentially transparent. + * + * @param a input + * @param seed the random seed which is updated on each call. 
+ * @return a tuple where the first element is a Tuple2 whose first element is + * missing and error information and second element is an optional result. + * The second element of the outer Tuple2 is the new state. + */ + override def apply(a: A, seed: Long): ((MissingAndErroneousFeatureInfo, Option[Array[String]]), Long) = { + val (missingAndErrs, features) = featuresFunction(a) + + // Get the lazy val once. + val labToInd = labelToInd + + // TODO: This seems like it could be optimized. + // The positiveLabelsFunction is invoked once and all labels are produced. + // Then a set is produced that is used to partition all of the labels in + // the training data into positive and negative. It seems like a slightly + // more efficient way to do this would be to create a sorted array of positives + // indices. Then a negative index array could be constructed with the appropriate + // size based on the size of allLabelsInTrainingSet and the size of the positives + // array. Next allLabelsInTrainingSet.indices and the positiveIndices array could + // be iterated over simultaneously and if the current index in + // allLabelsInTrainingSet.indices isn't in the positiveIndices array, then it must + // be in the negativeIndices array. Then the offsets into the two arrays are + // incremented. Both the current and proposed algorithms are O(N), where + // N = allLabelsInTrainingSet.indices.size. But the proposed algorithm would likely + // have much better constant factors. + // + // Note that labels produced by positiveLabelsFunction that are not in the + // allLabelsInTrainingSet are discarded without notice. + // + // TODO: Should this be sci.BitSet? + val positiveIndices: Set[Int] = positiveLabelsFunction(a).flatMap(labToInd.get)(breakOut) + + val (vwInput, newSeed) = sampledTrainingInput( + features, + allLabelsInTrainingSet.indices, + positiveIndices, + defaultNamespace, + namespaces, + classNs, + dummyClassNs, + negativeDummyStr, + positiveDummyStr, + seed, + numDownsampledNegLabels + ) + + ((missingAndErrs, Option(vwInput)), newSeed) + } +} + + +object VwDownsampledMultilabelRowCreator extends Rand { + + // Expose initSeedScramble to companion class. + private def scramble(initSeed: Long): Long = initSeedScramble(initSeed) + + /** + * + * '''This should not be used directly.''' ''It is exposed for testing.'' + * + * @param features the common features produced + * @param indices indices of all labels in the training set. + * @param positiveLabelIndices indices of labels in the training set that positive are + * positive for this training example + * @param defaultNs list of feature indices in the default VW namespace. + * @param namespaces a mapping from VW namespace name to feature indices in that namespace. + * @param classNs the namespace name for class information. + * @param dummyClassNs the namespace name for dummy class information. 2 dummy classes are + * added to make the predicted probabilities work. + * @param negativeDummyStr line in VW input associated with the negative dummy class. + * @param positiveDummyStr line in VW input associated with the positive dummy class. + * @param seed a seed for randomness. The second part of the output of this function is a + * new seed that should be used on the next call to this function. + * @param numNegLabelsTarget the desired number of negative labels. 
If it is determined + * that there are less negative labels than desired, no negative + * label downsampling will occur; otherwise, the negative labels + * will be downsampled to this target value. + * @return an array representing the VW input that can either be passed directly to the + * VW JNI library, or `mkString("", "\n", "\n")` can be called to pass to the VW + * CLI. The second part of the return value is a new seed to use as the `seed` + * parameter in the next call to this function. + */ + private[multilabel] def sampledTrainingInput( + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + positiveLabelIndices: Int => Boolean, + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + classNs: Char, + dummyClassNs: Char, + negativeDummyStr: String, + positiveDummyStr: String, + seed: Long, + numNegLabelsTarget: Int + ): (Array[String], Long) = { + + // Partition into positive and negative indices. + val (posLabelInd, negLabelInd) = indices.partition(positiveLabelIndices) + val negLabels = negLabelInd.size + + val numDownsampledLabels = math.min(numNegLabelsTarget, negLabels) + + // Sample indices in negLabelInd to use in the output. + val (downsampledNegLabelInd, newSeed) = + sampleCombination(negLabels, numDownsampledLabels, seed) + + val p = posLabelInd.size + val n = downsampledNegLabelInd.length + + // This code is written with knowledge of the constant's value. + // TODO: Write calling code to abstract over costs so knowledge of the costs isn't necessary. + val negWt = + if (numDownsampledLabels == negLabels) { + // No downsampling occurs. + NegativeCost.toString + } + else { + // Determine the weight for the downsampled negatives. + // If the cost of negative examples is positive, then the weight will be + // strictly greater than NegativeCost. + + f"${NegativeCost * negLabels / n.toDouble}%.5g" + } + + // The length of the output array is n + 3. + // + // The first row is the shared features. These are features that are not label dependent. + // Then comes two dummy classes. These are to make the probabilities work out. + // Then come the features for each of the n labels. + val x = new Array[String](p + n + 3) + + val shared = VwRowCreator.unlabeledVwInput( + features, defaultNs, namespaces, includeZeroValues = false + ) + + x(0) = SharedFeatureIndicator + shared + x(1) = negativeDummyStr + x(2) = positiveDummyStr + + // vvvvv This is mutable because we want speed. vvvvv + + // Negative weights + var ni = 0 + while (ni < n) { + val labelInd = negLabelInd(downsampledNegLabelInd(ni)) + x(ni + 3) = s"$labelInd:$negWt |$classNs _$labelInd" + ni += 1 + } + + // Positive weights + var pi = 0 + while (pi < p) { + val labelInd = posLabelInd(pi) + x(pi + n + 3) = s"$labelInd:$PositiveCost |$classNs _$labelInd" + pi += 1 + } + + (x, newSeed) + } + + /** + * A producer that can produce a [[VwDownsampledMultilabelRowCreator]]. + * The requirement for [[StatefulRowCreatorProducer]] to only have zero-argument constructors + * is relaxed for this Producer because we don't have a way of generically constructing a + * list of labels. If the labels were encoded in the JSON, then a JsonReader for the label + * type would have to be passed to the constructor. Since the labels can't be encoded + * generically in the JSON, we accept that this Producer is a special case and allow the labels + * to be passed directly. The consequence is that this producer doesn't just rely on the + * dataset specification and the data itself. 
It also relying on the labels provided to the + * constructor. + * + * @param allLabelsInTrainingSet All of the labels that will be encountered in the training set. + * @param seedCreator a "''function''" that creates a seed that will be used for randomness. + * The implementation of this function is important. It should create a + * unique value for each unit of parallelism. If for example, row + * creation is parallelized across multiple threads on one machine, the + * unit of parallelism is threads and `seedCreator` should produce unique + * values for each thread. If row creation is parallelized across multiple + * machines, the `seedCreator` should produce a unique value for each + * machine. If row creation is parallelized across machines and threads on + * each machine, the `seedCreator` should create unique values for each + * thread on each machine. Otherwise, randomness will be striped which is + * bad. + * @param ev$1 reflection information about `K`. + * @tparam A type of input passed to the [[StatefulRowCreator]]. + * @tparam K the label type. + */ + final class Producer[A, K: RefInfo]( + allLabelsInTrainingSet: sci.IndexedSeq[K], + seedCreator: () => Long + ) extends PositiveLabelsFunction[A, K] + with StatefulRowCreatorProducer[A, Array[String], Long, VwDownsampledMultilabelRowCreator[A, K]] + with RowCreatorProducerName + with VwCovariateProducer[A] + with DvProducer + with SparseCovariateProducer + with CompilerFailureMessages { + + override type JsonType = VwDownsampledMultilabeledJson + + /** + * Attempt to parse the JSON AST to an intermediate representation that is used + * to create the row creator. + * @param json JSON AST. + * @return + */ + override def parse(json: JsValue): Try[VwDownsampledMultilabeledJson] = + Try { json.convertTo[VwDownsampledMultilabeledJson] } + + /** + * Attempt to produce a Spec. + * + * @param semantics semantics used to make sense of the features in the JsonSpec + * @param jsonSpec a JSON specification to transform into a RowCreator. 
+ * @return + */ + override def getRowCreator( + semantics: CompiledSemantics[A], + jsonSpec: VwDownsampledMultilabeledJson + ): Try[VwDownsampledMultilabelRowCreator[A, K]] = { + val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) + + val rc = for { + cov <- covariates + pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) + labelNs <- labelNamespaces(nss) + actualLabelNs = labelNs.labelNs + dummyLabelNs = labelNs.dummyLabelNs + sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) + numNeg = jsonSpec.numDownsampledNegLabels + } yield new VwDownsampledMultilabelRowCreator[A, K]( + allLabelsInTrainingSet, cov, default, nss, normalizer, + pos, actualLabelNs, dummyLabelNs, numNeg, seedCreator) + + rc + } + } +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala new file mode 100644 index 00000000..784bbee6 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreator.scala @@ -0,0 +1,369 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.dataset._ +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.dataset.vw.VwCovariateProducer +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson +import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.CompiledSemantics +import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.util.rand.Rand +import spray.json.JsValue + +import scala.collection.{breakOut, immutable => sci} +import scala.util.Try + +/** + * Creates training data for multilabel models in Vowpal Wabbit's CSOAA LDF and WAP LDF format + * for the JNI. + * + * @param allLabelsInTrainingSet all labels in the training set. This is a sequence because + * order matters. Order here can be chosen arbitrarily, but it + * must be consistent in the training and test formulation. + * @param featuresFunction features to extract from the data of type `A`. + * @param defaultNamespace list of feature indices in the default VW namespace. + * @param namespaces a mapping from VW namespace name to feature indices in that namespace. + * @param normalizer can modify VW output (currently unused) + * @param positiveLabelsFunction A method that can extract positive class labels. + * @param classNs the namespace name for class information. + * @param dummyClassNs the namespace name for dummy class information. 2 dummy classes are + * added to make the predicted probabilities work. + * @param includeZeroValues include zero values in VW input? 
+ * @tparam A the input type
+ * @tparam K the label or class type
+ * @author deaktator
+ * @since 9/13/2017
+ */
+final case class VwMultilabelRowCreator[-A, K](
+    allLabelsInTrainingSet: sci.IndexedSeq[K],
+    featuresFunction: FeatureExtractorFunction[A, Sparse],
+    defaultNamespace: List[Int],
+    namespaces: List[(String, List[Int])],
+    normalizer: Option[CharSequence => CharSequence],
+    positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]],
+    classNs: Char,
+    dummyClassNs: Char,
+    includeZeroValues: Boolean = false)
+extends RowCreator[A, Array[String]] {
+  import VwMultilabelRowCreator._
+
+  @transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap
+
+  // Precompute these for efficiency rather than recompute them inside a hot loop.
+  // Notice these are not lazy vals.
+
+  private[this] val negativeDummyStr =
+    s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature"
+
+  private[this] val positiveDummyStr =
+    s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature"
+
+  override def apply(a: A): (MissingAndErroneousFeatureInfo, Array[String]) = {
+    val (missingAndErrs, features) = featuresFunction(a)
+
+    // Get the lazy val once.
+    val labToInd = labelToInd
+
+    // TODO: Should this be sci.BitSet?
+    val positiveIndices: Set[Int] = positiveLabelsFunction(a).flatMap(labToInd.get)(breakOut)
+
+    val x: Array[String] = trainingInput(
+      features,
+      allLabelsInTrainingSet.indices,
+      positiveIndices,
+      defaultNamespace,
+      namespaces,
+      classNs,
+      negativeDummyStr,
+      positiveDummyStr
+    )
+
+    (missingAndErrs, x)
+  }
+}
+
+object VwMultilabelRowCreator extends Rand {
+
+  /**
+    * VW allows long-based feature indices, but Aloha only allows 32-bit indices
+    * on the features that produce the key-value pairs passed to VW.  The negative
+    * dummy class uses an ID outside of the allowable range of feature indices:
+    * 2^31^.
+    */
+  private[multilabel] val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString
+
+  /**
+    * VW allows long-based feature indices, but Aloha only allows 32-bit indices
+    * on the features that produce the key-value pairs passed to VW.  The positive
+    * dummy class uses an ID outside of the allowable range of feature indices:
+    * 2^31^ + 1.
+    */
+  private[multilabel] val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString
+
+  // NOTE: If PositiveCost and NegativeCost change,
+  //       VwDownsampledMultilabelRowCreator.sampledTrainingInput
+  //       will also need to change.
+
+  /**
+    * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the
+    * dependent variable is based on cost (which is the negative of reward).
+    * As such, the ''reward'' of a positive example is designated to be zero.
+    */
+  private[multilabel] val PositiveCost = 0
+
+  /**
+    * Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the
+    * dependent variable is based on cost (which is the negative of reward).
+    * As such, the ''reward'' of a negative example is designated to be -1,
+    * so the cost (or negative reward) is 1.
+    */
+  private[multilabel] val NegativeCost = 1
+
+  private[multilabel] val PositiveDummyClassFeature = "P"
+
+  private[multilabel] val NegativeDummyClassFeature = "N"
+
+  /**
+    * "shared" is a special keyword in VW multi-class (multi-row) format.
+    * See Hal Daume's [[https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html page]].
+    *
+    * '''NOTE''': The trailing space should be here.
+    */
+  private[multilabel] val SharedFeatureIndicator = "shared" + " "
+
+  private[this] val PreferredLabelNamespaces = Seq(('Y', 'y'), ('Z', 'z'), ('Λ', 'λ'))
+
+  /**
+    * Determine the label namespaces for VW. VW insists on the uniqueness of the first character
+    * of namespaces. The goal is to try to use the first of these combinations for the
+    * label and dummy label namespaces where neither of the values is in `usedNss`:
+    *
+    - "Y", "y"
+    - "Z", "z"
+    - "Λ", "λ"
+    *
+    * If none of these combinations can be used because at least one of the elements in each
+    * pair is in `usedNss`, then iterate over the Unicode characters and take the first two
+    * that are deemed valid namespace characters. These will then become the actual and
+    * dummy namespace names (respectively).
+    *
+    * The goal of this function is to try to use characters that the literature uses to denote
+    * a dependent variable. If that isn't possible (because the characters are already used by
+    * some other namespace), just find the first possible characters.
+    * @param usedNss names of namespaces used.
+    * @return the namespace for ''actual'' label information then the namespace for ''dummy''
+    *         label information. If two valid namespaces couldn't be produced, return None.
+    */
+  private[aloha] def determineLabelNamespaces(usedNss: Set[String]): Option[LabelNamespaces] = {
+    val nss = nssToFirstCharBitSet(usedNss)
+    preferredLabelNamespaces(nss) orElse bruteForceNsSearch(nss)
+  }
+
+  private[multilabel] def preferredLabelNamespaces(nss: sci.BitSet): Option[LabelNamespaces] = {
+    PreferredLabelNamespaces collectFirst {
+      case (actual, dummy) if !(nss contains actual.toInt) && !(nss contains dummy.toInt) =>
+        LabelNamespaces(actual, dummy)
+    }
+  }
+
+  private[multilabel] def nssToFirstCharBitSet(ss: Set[String]): sci.BitSet =
+    ss.collect { case s if s.length != 0 =>
+      s.charAt(0).toInt
+    }(breakOut[Set[String], Int, sci.BitSet])
+
+  private[multilabel] def validCharForNamespace(chr: Char): Boolean = {
+    // These might be overkill.
+    Character.isDefined(chr) &&
+      Character.isLetter(chr) &&
+      !Character.isISOControl(chr) &&
+      !Character.isSpaceChar(chr) &&
+      !Character.isWhitespace(chr)
+  }
+
+  /**
+    * Find the first two valid characters that can be used as VW namespaces and that, when
+    * converted to integers, are not present in `usedNss`.
+    * @param usedNss the set of first characters in namespaces.
+    * @return the namespace to use for the actual classes and dummy classes, respectively.
+    */
+  private[multilabel] def bruteForceNsSearch(usedNss: sci.BitSet): Option[LabelNamespaces] = {
+    val found =
+      Iterator
+        .range(Char.MinValue, Char.MaxValue)
+        .filter(c => !(usedNss contains c) && validCharForNamespace(c.toChar))
+        .take(2)
+        .toList
+
+    found match {
+      case actual :: dummy :: Nil =>
+        Option(LabelNamespaces(actual.toChar, dummy.toChar))
+      case _ => None
+    }
+  }
+
+  /**
+    * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model.
+    * @param features (non-label dependent) features shared across all labels.
+    * @param indices the indices of `labels` into the sequence of all labels encountered
+    *                during training.
+    * @param positiveLabelIndices a predicate telling whether the example should be positively
+    *                             associated with a label.
+    * @param defaultNs the indices into `features` that should be placed in VW's default
+    *                  namespace.
+    * @param namespaces the indices into `features` that should be associated with each
+    *                   namespace.
+ * @param classNs a namespace for features associated with class labels + * // @param dummyClassNs a namespace for features associated with dummy class labels + * @param negativeDummyStr + * @param positiveDummyStr + * @return an array to be passed directly to an underlying `VWActionScoresLearner`. + */ + private[multilabel] def trainingInput( + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + positiveLabelIndices: Int => Boolean, + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + classNs: Char, + negativeDummyStr: String, + positiveDummyStr: String + ): Array[String] = { + + val n = indices.size + + // The length of the output array is n + 3. + // + // The first row is the shared features. These are features that are not label dependent. + // Then comes two dummy classes. These are to make the probabilities work out. + // Then come the features for each of the n labels. + val x = new Array[String](n + 3) + + val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, includeZeroValues = false) + x(0) = SharedFeatureIndicator + shared + + // These string interpolations are computed over and over but will always be the same + // for a given dummyClassNs. + x(1) = negativeDummyStr + x(2) = positiveDummyStr + + // vvvvv This is mutable because we want speed. vvvvv + + var i = 0 + while (i < n) { + val labelInd = indices(i) + + // TODO or positives.contains(labelInd)? + val dv = if (positiveLabelIndices(i)) PositiveCost else NegativeCost + x(i + 3) = s"$labelInd:$dv |$classNs _$labelInd" + i += 1 + } + + x + } + + /** + * Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. + * @param features (non-label dependent) features shared across all labels. + * @param indices the indices `labels` into the sequence of all labels encountered + * during training. + * @param defaultNs the indices into `features` that should be placed in VW's default + * namespace. + * @param namespaces the indices into `features` that should be associated with each + * namespace. + * @param classNs a namespace for features associated with class labels + * @return an array to be passed directly to an underlying `VWActionScoresLearner`. + */ + private[aloha] def predictionInput( + features: IndexedSeq[Sparse], + indices: sci.IndexedSeq[Int], + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + classNs: String + ): Array[String] = { + + val n = indices.size + + // Use a (mutable) array (and iteration) for speed. + // The first row is the shared features. These are features that are not label dependent. + // Then come the features for each of the n labels. + val x = new Array[String](n + 1) + + val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) + x(0) = SharedFeatureIndicator + shared + + var i = 0 + while (i < n) { + val labelInd = indices(i) + x(i + 1) = s"$labelInd:0 |$classNs _$labelInd" + i += 1 + } + + x + } + + /** + * A producer that can produce a [[VwMultilabelRowCreator]]. + * The requirement for [[RowCreatorProducer]] to only have zero-argument constructors is + * relaxed for this Producer because we don't have a way of generically constructing a + * list of labels. If the labels were encoded in the JSON, then a JsonReader for the label + * type would have to be passed to the constructor. Since the labels can't be encoded + * generically in the JSON, we accept that this Producer is a special case and allow the labels + * to be passed directly. 
The consequence is that this producer doesn't just rely on the + * dataset specification and the data itself. It also relying on the labels provided to the + * constructor. + * + * @param allLabelsInTrainingSet All of the labels that will be encountered in the training set. + * @param ev$1 reflection information about `K`. + * @tparam A type of input passed to the [[RowCreator]]. + * @tparam K the label type. + */ + final class Producer[A, K: RefInfo](allLabelsInTrainingSet: sci.IndexedSeq[K]) + extends PositiveLabelsFunction[A, K] + with RowCreatorProducer[A, Array[String], VwMultilabelRowCreator[A, K]] + with RowCreatorProducerName + with VwCovariateProducer[A] + with DvProducer + with SparseCovariateProducer + with CompilerFailureMessages { + + override type JsonType = VwMultilabeledJson + + /** + * Attempt to parse the JSON AST to an intermediate representation that is used + * to create the row creator. + * @param json JSON AST. + * @return + */ + override def parse(json: JsValue): Try[VwMultilabeledJson] = + Try { json.convertTo[VwMultilabeledJson] } + + /** + * Attempt to produce a Spec. + * + * @param semantics semantics used to make sense of the features in the JsonSpec + * @param jsonSpec a JSON specification to transform into a RowCreator. + * @return + */ + override def getRowCreator( + semantics: CompiledSemantics[A], + jsonSpec: VwMultilabeledJson + ): Try[VwMultilabelRowCreator[A, K]] = { + val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) + + val rc = for { + cov <- covariates + pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) + labelNs <- labelNamespaces(nss) + actualLabelNs = labelNs.labelNs + dummyLabelNs = labelNs.dummyLabelNs + sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) + } yield new VwMultilabelRowCreator[A, K](allLabelsInTrainingSet, cov, default, nss, + normalizer, pos, actualLabelNs, dummyLabelNs) + + rc + } + } + + private[aloha] final case class LabelNamespaces(labelNs: Char, dummyLabelNs: Char) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala new file mode 100644 index 00000000..08b812ce --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwDownsampledMultilabeledJson.scala @@ -0,0 +1,42 @@ +package com.eharmony.aloha.dataset.vw.multilabel.json + +import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.json.VwJsonLike +import spray.json.{DefaultJsonProtocol, RootJsonFormat} + +import scala.collection.{immutable => sci} + +/** + * JSON AST for `VwDownsampledMultilabelRowCreator`. + * @param imports + * @param features + * @param namespaces + * @param normalizeFeatures + * @param positiveLabels string representing a function that will be used to extract positive + * labels from the input. + * @param numDownsampledNegLabels '''a positive value''' representing the number of negative + * labels to include in each row. If this is less than the + * number of negative examples for a given row, then no + * downsampling of negatives will take place. 
+ * @author deaktator + * @since 11/6/2017 + */ +final case class VwDownsampledMultilabeledJson( + imports: sci.Seq[String], + features: sci.IndexedSeq[SparseSpec], + namespaces: Option[Seq[Namespace]] = Some(Nil), + normalizeFeatures: Option[Boolean] = Some(false), + positiveLabels: String, + numDownsampledNegLabels: Int +) extends VwJsonLike { + + require( + 0 < numDownsampledNegLabels, + s"numDownsampledNegLabels must be positive, found $numDownsampledNegLabels" + ) +} + +object VwDownsampledMultilabeledJson extends DefaultJsonProtocol { + implicit val vwDownsampledMultilabeledJson: RootJsonFormat[VwDownsampledMultilabeledJson] = + jsonFormat6(VwDownsampledMultilabeledJson.apply) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwMultilabeledJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwMultilabeledJson.scala new file mode 100644 index 00000000..e59b8da2 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/dataset/vw/multilabel/json/VwMultilabeledJson.scala @@ -0,0 +1,23 @@ +package com.eharmony.aloha.dataset.vw.multilabel.json + +import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.json.VwJsonLike +import spray.json.{DefaultJsonProtocol, RootJsonFormat} + +import scala.collection.{immutable => sci} + +/** + * Created by ryan.deak on 9/13/17. + */ +final case class VwMultilabeledJson( + imports: sci.Seq[String], + features: sci.IndexedSeq[SparseSpec], + namespaces: Option[Seq[Namespace]] = Some(Nil), + normalizeFeatures: Option[Boolean] = Some(false), + positiveLabels: String) + extends VwJsonLike + +object VwMultilabeledJson extends DefaultJsonProtocol { + implicit val labeledVwJsonFormat: RootJsonFormat[VwMultilabeledJson] = + jsonFormat5(VwMultilabeledJson.apply) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala index 516f9146..1e608a78 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ModelFactory.scala @@ -1,19 +1,17 @@ package com.eharmony.aloha.factory -import com.eharmony.aloha import com.eharmony.aloha.audit.{Auditor, MorphableAuditor} import com.eharmony.aloha.factory.ModelFactory.{InlineReader, ModelInlineReader, SubmodelInlineReader} import com.eharmony.aloha.factory.ex.{AlohaFactoryException, RecursiveModelDefinitionException} import com.eharmony.aloha.factory.jsext.JsValueExtensions import com.eharmony.aloha.factory.ri2jf.{RefInfoToJsonFormat, StdRefInfoToJsonFormat} -import com.eharmony.aloha.io.{GZippedReadable, LocationLoggingReadable, ReadableByString} import com.eharmony.aloha.io.multiple.{MultipleAlohaReadable, SequenceMultipleReadable} import com.eharmony.aloha.io.sources.ReadableSource +import com.eharmony.aloha.io.{GZippedReadable, LocationLoggingReadable, ReadableByString} import com.eharmony.aloha.models.{Model, Submodel} -import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps, RuntimeClasspathScanning} import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.util.Logging -import org.reflections.Reflections import spray.json.DefaultJsonProtocol.{StringJsonFormat, jsonFormat2, optionFormat} import spray.json.{CompactPrinter, JsObject, JsValue, JsonFormat, JsonReader, RootJsonFormat, pimpString} @@ -268,7 +266,7 @@ case class ModelFactoryImpl[U, N, A, B 
<: U]( } } -object ModelFactory { +object ModelFactory extends RuntimeClasspathScanning { /** Provides a default factory capable of producing models defined in aloha-core. The list of models come from * the knownModelParsers method. @@ -286,23 +284,7 @@ object ModelFactory { /** Get the list of models on the classpath with parsers that can be used by a model factory. * @return */ - def knownModelParsers(): Seq[ModelParser] = { - val reflections = new Reflections(aloha.pkgName) - import scala.collection.JavaConversions.asScalaSet - val parserProviderCompanions = reflections.getSubTypesOf(classOf[ParserProviderCompanion]).toSeq - - parserProviderCompanions.flatMap { - case ppc if ppc.getCanonicalName.endsWith("$") => - Try { - val c = Class.forName(ppc.getCanonicalName.dropRight(1)) - c.getMethod("parser").invoke(null) match { - case mp: ModelParser => mp - case _ => throw new IllegalStateException() - } - }.toOption - case _ => None - } - } + def knownModelParsers(): Seq[ModelParser] = scanObjects[ParserProviderCompanion, ModelParser]("parser") private[factory] sealed trait InlineReader[U, N, -A, +B <: U, Y] { def jsonReader(parser: ModelParser): Try[JsonReader[_ <: Y]] diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala b/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala index 026fbd4f..80e00364 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/factory/formats.scala @@ -60,11 +60,6 @@ trait JavaJsonFormats { object ScalaJsonFormats extends ScalaJsonFormats trait ScalaJsonFormats { - // This is a very slightly modified copy of the lift from Additional formats that removes the type bound. - implicit def lift[A](implicit reader: JsonReader[A]): JsonFormat[A] = new JsonFormat[A] { - def write(a: A): JsValue = throw new UnsupportedOperationException("No JsonWriter[" + a.getClass + "] available") - def read(value: JsValue): A = reader.read(value) - } implicit def listMapFormat[K :JsonFormat, V :JsonFormat] = new RootJsonFormat[ListMap[K, V]] { def write(m: ListMap[K, V]) = JsObject { diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala index 72ec2954..49dd61c7 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/factory/ri2jf/CollectionTypes.scala @@ -1,16 +1,37 @@ package com.eharmony.aloha.factory.ri2jf import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import spray.json.DefaultJsonProtocol.{immSeqFormat, mapFormat} import spray.json.JsonFormat -import spray.json.DefaultJsonProtocol.immSeqFormat /** * Created by ryan on 1/25/17. 
*/ class CollectionTypes extends RefInfoToJsonFormatConversions { - def apply[A](implicit conv: RefInfoToJsonFormat, r: RefInfo[A], jf: JsonFormatCaster[A]): Option[JsonFormat[A]] = { - if (RefInfoOps.isSubType[A, collection.immutable.Seq[Any]]) - conv(r.typeArguments.head).flatMap(f => jf(immSeqFormat(f))) + def apply[A](implicit conv: RefInfoToJsonFormat, + r: RefInfo[A], + jf: JsonFormatCaster[A]): Option[JsonFormat[A]] = { + + val typeParams = RefInfoOps.typeParams[A] + + if (RefInfoOps.isSubType[A, collection.immutable.Map[Any, Any]]) + typeParams match { + case List(tKey, tVal) => + for { + k <- conv(tKey) + v <- conv(tVal) + f <- jf(mapFormat(k, v)) + } yield f + case _ => + // TODO: perhaps change the API at some point to better report errors here. + None + } + else if (RefInfoOps.isSubType[A, collection.immutable.Seq[Any]]) + for { + tEl <- typeParams.headOption + el <- conv(tEl) + f <- jf(immSeqFormat(el)) + } yield f else None } } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala new file mode 100644 index 00000000..2397e9e3 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModel.scala @@ -0,0 +1,493 @@ +package com.eharmony.aloha.models.multilabel + +import java.io.{Closeable, PrintWriter, StringWriter} + +import com.eharmony.aloha.audit.Auditor +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.factory._ +import com.eharmony.aloha.factory.ri2jf.RefInfoToJsonFormat +import com.eharmony.aloha.id.ModelIdentity +import com.eharmony.aloha.models._ +import com.eharmony.aloha.models.multilabel.json.MultilabelModelReader +import com.eharmony.aloha.models.reg.RegressionFeatures +import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import com.eharmony.aloha.semantics.Semantics +import com.eharmony.aloha.semantics.func.{GenAggFunc, GenAggFuncAccessorProblems} +import com.eharmony.aloha.util.{Logging, SerializabilityEvidence} +import spray.json.{JsonFormat, JsonReader} + +import scala.collection.{immutable => sci, mutable => scm} +import scala.util.{Failure, Success} + +/** + * A multi-label predictor. + * + * Created by ryan.deak on 8/29/17. + * + * @param modelId An identifier for the model. Used in score and error reporting. + * @param featureNames feature names (parallel to featureFunctions) + * @param featureFunctions feature extracting functions. + * @param labelsInTrainingSet a sequence of all labels encountered during training. Note: the + * order of labels may relate to the predictor produced by + * predictorProducer. It is the caller's responsibility to ensure + * the order is correct. To mitigate such problems, both labels + * and indices into labelsInTrainingSet are passed to the predictor + * produced by predictorProducer. + * @param labelsOfInterest if provided, a sequence of labels will be extracted from the example + * for which a prediction is desired. The ''intersection'' of the + * extracted labels and the training labels will be the labels for which + * predictions will be produced. + * @param predictorProducer the function produced when calling this function is responsible for + * getting the data into the correct type and using it within an + * underlying ML library to produce a prediction. The mapping back to + * (K, Double) pairs is also its responsibility. 
If the predictor produced by predictorProducer is Closeable, it will be closed when MultilabelModel's close method is called.
+ * @param numMissingThreshold if provided, we check whether the threshold is exceeded. If so,
+ *                            return an error instead of the computed score. This is for missing
+ *                            data situations.
+ * @param auditor transforms a `Map[K, Double]` to a `B`. Reports successes and errors.
+ * @param ev evidence that `K` is serializable.
+ * @tparam U upper bound on model output type `B`
+ * @tparam K type of label or class
+ * @tparam A input type of the model
+ * @tparam B output type of the model.
+ */
+case class MultilabelModel[U, K, -A, +B <: U](
+    modelId: ModelIdentity,
+    featureNames: sci.IndexedSeq[String],
+    featureFunctions: sci.IndexedSeq[GenAggFunc[A, Sparse]],
+    labelsInTrainingSet: sci.IndexedSeq[K],
+    labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]],
+    predictorProducer: SparsePredictorProducer[K],
+    numMissingThreshold: Option[Int],
+    auditor: Auditor[U, Map[K, Double], B])
+(implicit ev: SerializabilityEvidence[K])
+extends SubmodelBase[U, Map[K, Double], A, B]
+  with RegressionFeatures[A] {
+
+  import MultilabelModel._
+
+  /**
+    * predictor is a transient lazy value because we don't need to worry about serialization.
+    * We don't care about the lazy property. It should be created eagerly.
+    */
+  @transient private[this] lazy val predictor = predictorProducer()
+  predictor // Force predictor eagerly
+
+  /**
+    * Cache this in case labelsOfInterest is None. In that case, we don't want to repeatedly
+    * create this because it could create a GC burden for no real reason.
+    */
+  private[this] val defaultLabelInfo =
+    // Hopefully, making this non-transient doesn't increase the serialized size much.
+    // labelsInTrainingSet isn't serialized twice, is it?
+    LabelsAndInfo(labelsInTrainingSet.indices, labelsInTrainingSet, Seq.empty, None)
+
+  /**
+    * Mapping from label to index into the sequence of all labels encountered during training.
+    */
+  private[this] val labelToInd: Map[K, Int] =
+    labelsInTrainingSet.zipWithIndex.map { case (label, i) => label -> i }(collection.breakOut)
+
+  override def subvalue(a: A): Subvalue[B, Map[K, Double]] = {
+    val li = labelsAndInfo(a, labelsOfInterest, labelToInd, defaultLabelInfo)
+
+    if (li.labels.isEmpty)
+      reportNoPrediction(modelId, li, auditor)
+    else {
+      val Features(x, missing, missingOk) = constructFeatures(a)
+
+      if (!missingOk)
+        reportTooManyMissing(modelId, li, missing, auditor)
+      else {
+        // TODO: To support label-dependent features, fill last parameter with a valid value.
+        val predictionTry = predictor(x, li.labels, li.indices, sci.IndexedSeq.empty)
+
+        predictionTry match {
+          case Success(pred) => reportSuccess(modelId, li, missing, pred, auditor)
+          case Failure(ex) => reportPredictorError(modelId, li, missing, ex, auditor)
+        }
+      }
+    }
+  }
+
+  /**
+    * When the `predictor` passed to the constructor is a java.io.Closeable, its `close`
+    * method is called.
+    */
+  override def close(): Unit =
+    predictor match {
+      case closeable: Closeable => closeable.close()
+      case _ =>
+    }
+}
+
+object MultilabelModel extends ParserProviderCompanion {
+
+  /**
+    * Contains information about the labels to be used for predictions, and problems encountered
+    * while trying to get those labels.
+    * @param indices indices into the sequence of all labels seen during training. These should
+    *                be sorted in ascending order.
+    * @param labels labels for which a prediction should be produced.
labels are parallel to + * indices so `indices(i)` is the index associated with `labels(i)`. + * @param labelsNotInTrainingSet a sequence of labels derived from the input data that could + * not be found in the sequence of all labels seen during training. + * @param problems any problems encountered when trying to get the labels. This should only + * be present when the caller indicates labels should be embedded in the + * input data passed to the prediction function in the MultilabelModel. + * @tparam K type of label or class + */ + protected[multilabel] case class LabelsAndInfo[K]( + indices: sci.IndexedSeq[Int], + labels: sci.IndexedSeq[K], + labelsNotInTrainingSet: Seq[K], + problems: Option[GenAggFuncAccessorProblems] + ) { + def missingVarNames: Seq[String] = problems.map(p => p.missing).getOrElse(Nil) + def errorMsgs: Seq[String] = { + labelsNotInTrainingSet.map { lab => s"Label not in training labels: $lab" } ++ + problems.map(p => p.errors).getOrElse(Nil) + } + } + + private[multilabel] val NumLinesToKeepInStackTrace = 20 + + private[multilabel] val TooManyMissingError = + "Too many missing features encountered to produce prediction." + + private[multilabel] val NoLabelsError = "No labels provided. Cannot produce a prediction." + + /** + * Get the labels and information about the labels. + * @param a an input from which label information should be derived if labelsOfInterest is not empty. + * @param labelsOfInterest an optional function used to extract label information from the input `a`. + * @param labelToInd a mapping from label to index into the sequence of all labels seen during training. + * @param defaultLabelInfo label information related to all labels seen at training time. If + * `labelsOfInterest` is not provided, this information will be used. + * @tparam A input type of the model + * @tparam K type of label or class + * @return labels and information about the labels. + */ + protected[multilabel] def labelsAndInfo[A, K]( + a: A, + labelsOfInterest: Option[GenAggFunc[A, sci.IndexedSeq[K]]], + labelToInd: Map[K, Int], + defaultLabelInfo: LabelsAndInfo[K] + ): LabelsAndInfo[K] = + labelsOfInterest.fold(defaultLabelInfo)(f => labelsForPrediction(a, f, labelToInd)) + + /** + * Combine the missing variables found into a set. + * @param labelInfo labels and information about the labels. + * @param missing missing features from + * @tparam K type of label or class + * @return a set of missing features + */ + protected[multilabel] def combineMissing[K]( + labelInfo: LabelsAndInfo[K], + missing: scm.Map[String, Seq[String]] + ): Set[String] = missing.values.flatten.toSet ++ labelInfo.missingVarNames + + /** + * Report that a prediction could not be made because too many missing features were encountered. + * @param modelId An identifier for the model. Used in error reporting. + * @param labelInfo labels and information about the labels. + * @param missing missing features from + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating failure. + */ + protected[multilabel] def reportTooManyMissing[U, K, B <: U]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + missing: scm.Map[String, Seq[String]], + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Nothing] = { + + // TODO: Check that missing.values.flatten.toSet AND labelInfo.missingFeatures have the same format. 
+ val aud = auditor.failure( + modelId, + errorMsgs = TooManyMissingError +: labelInfo.errorMsgs, + missingVarNames = combineMissing(labelInfo, missing) + ) + Subvalue(aud, None) + } + + /** + * Report that no prediction attempt was made because of issues with the labels. + * @param modelId An identifier for the model. Used in error reporting. + * @param labelInfo labels and information about the labels. + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating failure. + */ + protected[multilabel] def reportNoPrediction[U, K, B <: U]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Nothing] = { + val aud = auditor.failure( + modelId, + errorMsgs = NoLabelsError +: labelInfo.errorMsgs, + missingVarNames = labelInfo.missingVarNames.toSet + ) + Subvalue(aud, None) + } + + /** + * Report that the model succeeded. + * @param modelId An identifier for the model. Used in score reporting. + * @param labelInfo labels and information about the labels. + * @param missing missing features from + * @param prediction the prediction(s) made by the embedded predictor. + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating success. + */ + protected[multilabel] def reportSuccess[U, K, B <: U]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + missing: scm.Map[String, Seq[String]], + prediction: Map[K, Double], + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Map[K, Double]] = { + + val aud = auditor.success( + modelId, + prediction, + errorMsgs = labelInfo.errorMsgs, + missingVarNames = combineMissing(labelInfo, missing) + ) + + Subvalue(aud, Option(prediction)) + } + + /** + * Report that a `Throwable` was thrown while invoking the predictor + * @param modelId An identifier for the model. Used in error reporting. + * @param labelInfo labels and information about the labels. + * @param missingFeatureMap missing features from RegressionFeatures + * @param throwable the error the occurred in the predictor. + * @param auditor an auditor used to audit the output. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam B output type of the model. + * @return a SubValue indicating failure. + */ + protected[multilabel] def reportPredictorError[U, K, B <: U]( + modelId: ModelIdentity, + labelInfo: LabelsAndInfo[K], + missingFeatureMap: scm.Map[String, Seq[String]], + throwable: Throwable, + auditor: Auditor[U, Map[K, Double], B] + ): Subvalue[B, Nothing] = { + + val sw = new StringWriter + val pw = new PrintWriter(sw) + throwable.printStackTrace(pw) + val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + + val aud = auditor.failure( + modelId, + errorMsgs = stackTrace +: labelInfo.errorMsgs, + missingVarNames = combineMissing(labelInfo, missingFeatureMap) + ) + Subvalue(aud, None) + } + + /** + * Get labels from the input for which a prediction should be produced. If + * `labelsOfInterest` produces a label not in the training set, it will not + * be present in the prediction output but it will appear in + * `LabelsAndInfo.labelsNotInTrainingSet`. 
+ * + * @param example the example provided to the model + * @param labelsOfInterest a function used to extract labels for which a + * prediction should be produced. + * @param labelToInd mapping from Label to index into the sequence of all + * labels seen in the training set. + * @tparam A input type of the model + * @tparam K type of label or class + * @return labels and information about the labels. + */ + protected[multilabel] def labelsForPrediction[A, K]( + example: A, + labelsOfInterest: GenAggFunc[A, sci.IndexedSeq[K]], + labelToInd: Map[K, Int] + ): LabelsAndInfo[K] = { + + val labelsShouldPredict = labelsOfInterest(example) + + // Notice here that if the label isn't present in the labeltoInd map, it is ignored + // and not inserted into the `unsorted` sequence. + val unsorted = + for { + label <- labelsShouldPredict + ind <- labelToInd.get(label).toList + } yield (ind, label) + + val problems = + if (labelsShouldPredict.nonEmpty) None + else Option(labelsOfInterest.accessorOutputProblems(example)) + + val noPrediction = + if (unsorted.size == labelsShouldPredict.size) Seq.empty + else labelsShouldPredict.filterNot(labelToInd.contains) + + val (ind, lab) = unsorted.sortBy{ case (i, _) => i }.unzip + + LabelsAndInfo(ind, lab, noPrediction, problems) + } + + private[multilabel] lazy val plugins: Map[String, MultilabelModelParserPlugin] = + MultilabelModelParserPlugin.plugins().map(p => p.name -> p).toMap + + override def parser: ModelParser = Parser + + object Parser extends ModelSubmodelParsingPlugin with Logging { + override val modelType: String = "SparseMultilabel" + + // TODO: Figure if a Option[JsonReader[MultilabelModel[U, _, A, B]]] can be returned. + // See: parser that returns SegmentationModel[U, _, N, A, B] + // See: parser that returns RegressionModel[U, A, B] + // Seems like this should be possible but we get the error: + // + // [error] method commonJsonReader has incompatible type + // [error] override def commonJsonReader[U, N, A, B <: U]( + // [error] ^ + // + override def commonJsonReader[U, N, A, B <: U]( + factory: SubmodelFactory[U, A], + semantics: Semantics[A], + auditor: Auditor[U, N, B])(implicit + r: RefInfo[N], + jf: JsonFormat[N] + ): Option[JsonReader[Model[A, B] with Submodel[N, A, B]]] = { + + // If the type N is a Map with Double values, we can specify a key type (call it K). + // Then the necessary type classes related to K are *instantiated* and a reader is + // created using types N and K. Ultimately, the reader consumes the type `K` but + // only the type N is exposed in the returned reader. + // + // NOTE: The following is an not adequate check: !RefInfoOps.isSubType[N, Map[_, Double]] + if (!RefInfoOps.isSubType[N, Map[Any, Double]]) { + warn(s"N=${RefInfoOps.toString[N]} is not a Map[K, Double]. Cannot create a JsonReader for MultilabelModel.") + None + } + else if (2 != RefInfoOps.typeParams(r).size) { + warn(s"N=${RefInfoOps.toString[N]} does not have 2 type parameters. 
Cannot infer label type K needed create a JsonReader for MultilabelModel.") + None + } + else { + // This would be a prime candidate for the WriterT monad transformer: + // type Result[A] = WriterT[Option, Vector[String], A] + // https://github.com/typelevel/cats/blob/0.7.x/core/src/main/scala/cats/data/WriterT.scala#L8 + val readerAttempt = + for { + ri <- refInfoOrError[N, Any](r).right // Force type of K = Any + jf <- jsonFormatOrError(factory, ri).right + se <- serializableEvidenceOrError(ri).right + } yield reader(semantics, auditor, ri, jf, se) + + readerAttempt match { + case Left(err) => + warn(err) + None + case Right(reader) => Option(reader) + } + } + } + } + + /** + * Produce the reader. + * + * '''NOTE''': This function should only be applied after we know `N` equals `Map[K, Double]` + * for some `K`. + * + * @param semantics semantics to use for compiling specifications. + * @param auditor an auditor of type `Auditor[U, N, B]`. It's not of type + * `Auditor[U, Map[K, Double], B]`, but because we know the relationship + * between `N` and `Map[K, Double]`, we can cast `N` to `Map[K, Double]`. + * @param ri reflection information about `K`. + * @param jf a JSON format that can translate JSON ASTs to and from `K`s. + * @param se evidence that `K` is `Serializable`. + * @tparam U upper bound on model output type `B` + * @tparam N the expected natural output type the model that will be produced by the + * JSON reader. This should be isomorphic to `Map[K, Double]`. + * @tparam K type of label or class + * @tparam A input type of the model. + * @tparam B output type of the model. + * @return a JSON reader capable of producing a [[MultilabelModel]] from a JSON definition. + */ + private[multilabel] def reader[U, N, K, A, B <: U]( + semantics: Semantics[A], + auditor: Auditor[U, N, B], + ri: RefInfo[K], + jf: JsonFormat[K], + se: SerializabilityEvidence[K]): JsonReader[Model[A, B] with Submodel[N, A, B]] = { + // At this point, N = Map[K, Double], so we are just casting to itself essentially. + val aud = auditor.asInstanceOf[Auditor[U, Map[K, Double], B]] + + // Create a cast from Map[K, Double] to N. Ideally, this would be a Map[K, Double] <:< N + // rather than a Map[K, Double] => N. But it's hard (with good reason) to create a <:< + // and easy to create a function. + implicit val cast = (m: Map[K, Double]) => m.asInstanceOf[N] + + // MultilabelModelReader can produce a MultilabelModel. But there's a problem returning + // the proper type because the compiler doesn't have compile-time evidence that N is + // Map[K, Double] so, a less specific type (Model[A, B] with Submodel[N, A, B]) is + // returned. + MultilabelModelReader(semantics, aud, plugins)(ri, jf, se).untypedReader[N] + } + + private[multilabel] def serializableEvidence[K](refInfoK: RefInfo[K]) = { + val serEv = + if (RefInfoOps.isSubType(refInfoK, RefInfo.JavaSerializable)) + Option(SerializabilityEvidence.serializableEvidence[java.io.Serializable]) + else if (RefInfoOps.isSubType(refInfoK, RefInfo.AnyVal)) + Option(SerializabilityEvidence.anyValEvidence[AnyVal]) + else None + + serEv.asInstanceOf[Option[SerializabilityEvidence[K]]] + } + + private[multilabel] def serializableEvidenceOrError[K](refInfoK: RefInfo[K]) = { + serializableEvidence(refInfoK) + .toRight(s"Couldn't produce evidence that ${RefInfoOps.toString(refInfoK)} is Serializable.") + } + + /** + * Get reflection information about the label type `K` for the [[MultilabelModel]] to be produced. 
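+   * For instance (illustrative), for `N = Map[String, Double]` this yields
+   * `Option(RefInfo[String])`; for `N = Map[Int, Double]` it yields `Option(RefInfo[Int])`.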
+ * + * At the time of application, we should already know N is a subtype of Map with 2 type parameters. + * This will preclude things like + * [[http://scala-lang.org/files/archive/api/2.11.8/#scala.collection.immutable.LongMap scala.collection.immutable.LongMap]]. + * @param rin reflection information about `N`. + * @tparam N natural output type of the model. `N` should equal `Map[K, Double]`. + * @tparam K label type of the [[MultilabelModel]] + * @return reflection information about K. + */ + private[multilabel] def refInfo[N, K](rin: RefInfo[N]) = + RefInfoOps.typeParams(rin).headOption.asInstanceOf[Option[RefInfo[K]]] + + private[multilabel] def refInfoOrError[N, K](rin: RefInfo[N]) = { + refInfo[N, K](rin) + .toRight(s"Couldn't extract key type from natural type: ${RefInfoOps.toString(rin)}") + } + + private[multilabel] def jsonFormatOrError[U, A, K](factory: SubmodelFactory[U, A], refInfoK: RefInfo[K]): Either[String, JsonFormat[K]] = { + // To allow custom class (key) types, we'll need to create a custom ModelFactoryImpl instance + // with a specialized RefInfoToJsonFormat. + factory.jsonFormat(refInfoK) + .toRight(s"Couldn't find a JSON Format for ${RefInfoOps.toString(refInfoK)}. Consider using a different ${classOf[RefInfoToJsonFormat].getCanonicalName}.") + } +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala new file mode 100644 index 00000000..cd2b4792 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelModelParserPlugin.scala @@ -0,0 +1,40 @@ +package com.eharmony.aloha.models.multilabel + +import com.eharmony.aloha.reflect.{RefInfo, RuntimeClasspathScanning} +import spray.json.{JsonFormat, JsonReader} + +/** + * A plugin that will ultimately produce the [[SparseMultiLabelPredictor]]. + * Created by ryan.deak on 9/6/17. + */ +trait MultilabelModelParserPlugin { + + /** + * A globally unique name for the plugin. If there are name collisions, + * [[MultilabelModelParserPlugin.plugins()]] will error out. + * @return the plugin name. + */ + def name: String + + /** + * Provide a JSON reader that can translate JSON ASTs to a `SparsePredictorProducer`. + * @param info information about the multi-label model passed to the plugin. + * @param ri reflection information about the label type + * @param jf a JSON format representing a bidirectional mapping between instances of + * `K` and JSON ASTs. + * @tparam K the label or class type to be produced by the multi-label model. + * @return a JSON reader that can create `SparsePredictorProducer[K]` from JSON ASTs. + */ + def parser[K](info: PluginInfo[K]) + (implicit ri: RefInfo[K], jf: JsonFormat[K]): JsonReader[SparsePredictorProducer[K]] +} + +object MultilabelModelParserPlugin extends RuntimeClasspathScanning { + + /** + * Finds the plugins in the `com.eharmony.aloha` namespace. + */ + // TODO: Consider making this implicit so that we can parameterize the parser implicitly on a sequence of plugins. 
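+  // Illustrative sketch of a provider that this scan would discover (hypothetical
+  // names; the object must live under the com.eharmony.aloha namespace):
+  //
+  //   object MyMultilabelPlugin extends MultilabelPluginProviderCompanion {
+  //     override def multilabelPlugin: MultilabelModelParserPlugin = new MyParserPlugin
+  //   }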
+ protected[multilabel] def plugins(): Seq[MultilabelModelParserPlugin] = + scanObjects[MultilabelPluginProviderCompanion, MultilabelModelParserPlugin]("multilabelPlugin") +} \ No newline at end of file diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelPluginProviderCompanion.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelPluginProviderCompanion.scala new file mode 100644 index 00000000..652590ec --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/MultilabelPluginProviderCompanion.scala @@ -0,0 +1,8 @@ +package com.eharmony.aloha.models.multilabel + +/** + * Created by ryan.deak on 9/6/17. + */ +trait MultilabelPluginProviderCompanion { + def multilabelPlugin: MultilabelModelParserPlugin +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala new file mode 100644 index 00000000..db38ce6d --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/PluginInfo.scala @@ -0,0 +1,13 @@ +package com.eharmony.aloha.models.multilabel + +import com.eharmony.aloha.models.reg.json.Spec + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 9/7/17. + */ +trait PluginInfo[K] { + def features: ListMap[String, Spec] + def labelsInTrainingSet: Vector[K] +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala new file mode 100644 index 00000000..dad7cfb2 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelJson.scala @@ -0,0 +1,46 @@ +package com.eharmony.aloha.models.multilabel.json + +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.models.multilabel.PluginInfo +import com.eharmony.aloha.models.reg.json.{Spec, SpecJson} +import spray.json.DefaultJsonProtocol._ +import spray.json.{JsObject, JsonFormat, RootJsonFormat} + +import scala.collection.immutable.ListMap +import com.eharmony.aloha.factory.ScalaJsonFormats + +trait MultilabelModelJson extends SpecJson with ScalaJsonFormats { + + protected[this] case class Plugin(`type`: String) + + /** + * AST for multi-label models. + * + * @param modelType + * @param modelId + * @param features + * @param numMissingThreshold + * @param labelsInTrainingSet The sequence of all labels encountered in training. '''It is + * important''' that this is sequence (''with the same order as the + * labels in the training set''). This is because some algorithms + * may require indices based on the training data. + * @param labelsOfInterest a string representing a function that will be used to extract labels. 
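+   *                         The string is compiled by the semantics into a
+   *                         `GenAggFunc[A, sci.IndexedSeq[K]]` (see
+   *                         `MultilabelModelReader.compileLabelsOfInterest`).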
+ * @param underlying the underlying model that will be produced by a + * @tparam K + */ + protected[this] case class MultilabelData[K]( + modelType: String, + modelId: ModelId, + features: ListMap[String, Spec], + numMissingThreshold: Option[Int], + labelsInTrainingSet: Vector[K], + labelsOfInterest: Option[String], + underlying: JsObject + ) extends PluginInfo[K] + + protected[this] final implicit def multilabelDataJsonFormat[K: JsonFormat]: RootJsonFormat[MultilabelData[K]] = + jsonFormat7(MultilabelData[K]) + + protected[this] final implicit val pluginJsonFormat: RootJsonFormat[Plugin] = + jsonFormat1(Plugin) +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala new file mode 100644 index 00000000..82e013bb --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/json/MultilabelModelReader.scala @@ -0,0 +1,111 @@ +package com.eharmony.aloha.models.multilabel.json + +import com.eharmony.aloha.audit.Auditor +import com.eharmony.aloha.models.{Model, Submodel} +import com.eharmony.aloha.models.multilabel.{MultilabelModel, MultilabelModelParserPlugin} +import com.eharmony.aloha.models.reg.RegFeatureCompiler +import com.eharmony.aloha.models.reg.json.Spec +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.Semantics +import com.eharmony.aloha.semantics.func.GenAggFunc +import com.eharmony.aloha.util.{EitherHelpers, SerializabilityEvidence} +import spray.json.{DeserializationException, JsValue, JsonFormat, JsonReader} + +import scala.collection.{immutable => sci} + +/** + * A JSON reader capable of turning JSON to a [[MultilabelModel]]. + * Created by ryan.deak on 9/7/17. + * @param semantics semantics used to generate features and labels of interest. + * @param auditor the auditor used to translate the [[MultilabelModel]] output values to `B` instances. + * @param plugins the possible plugins to which [[MultilabelModel]] can delegate to produce predictions. + * Plugins are responsible for creating the `predictorProducer` passed to the + * [[MultilabelModel]] constructor. + * @param refInfoK reflection information about the label type. + * @param jsonFormatK a JSON format capable of parsing the label type. + * @param serEvK evidence that the label type is `Serializable`. + * @tparam U upper bound on model output type `B` + * @tparam K type of label or class + * @tparam A input type of the model + * @tparam B output type of the model. + */ +final case class MultilabelModelReader[U, K, A, B <: U]( + semantics: Semantics[A], + auditor: Auditor[U, Map[K, Double], B], + plugins: Map[String, MultilabelModelParserPlugin]) +(implicit + refInfoK: RefInfo[K], + jsonFormatK: JsonFormat[K], + serEvK: SerializabilityEvidence[K]) + extends JsonReader[MultilabelModel[U, K, A, B]] + with MultilabelModelJson + with EitherHelpers + with RegFeatureCompiler { self => + + /** + * Creates a reader for [[MultilabelModel]] that has a less specific type assuming that we + * can produce evidence that `N` is a super type of `Map[K, Double]`. + * @param ev Provides a conversion from `Map[K, Double]` to `N`. This indicates a subtype + * relationship and is ''close to the same'' as implicit subtype evidence + * `<:<`, but is not as strong (because it's not statically checked by the compiler). + * @tparam N The natural output of the Model that is a super type of `Map[K, Double]`. 
Since + * auditors take values of type `N` as input, supplying a subtype should also work. + * @return A JSON reader the creates [[MultilabelModel]], but returns instances as a less + * specific type. + */ + private[multilabel] def untypedReader[N](implicit ev: Map[K, Double] => N): JsonReader[Model[A, B] with Submodel[N, A, B]] = + new JsonReader[Model[A, B] with Submodel[N, A, B]]{ + override def read(json: JsValue): Model[A, B] with Submodel[N, A, B] = + self.read(json).asInstanceOf[Model[A, B] with Submodel[N, A, B]] + } + + override def read(json: JsValue): MultilabelModel[U, K, A, B] = { + val mlData = json.convertTo[MultilabelData[K]] + val pluginType = mlData.underlying.convertTo[Plugin].`type` + + plugins.get(pluginType) + .map { plugin => model(mlData, plugin) } + .getOrElse { throw new DeserializationException(errorMsg(pluginType)) } + } + + private[multilabel] def errorMsg(pluginType: String): String = { + val pluginNames = plugins.mapValues(p => p.getClass.getCanonicalName) + s"Couldn't find plugin for type $pluginType. Plugins available: $pluginNames" + } + + private[multilabel] def compileLabelsOfInterest(labelsOfInterestSpec: String): ENS[GenAggFunc[A, sci.IndexedSeq[K]]] = { + semantics.createFunction[sci.IndexedSeq[K]](labelsOfInterestSpec, Option(Vector.empty[K])). + left.map { Seq(s"Error processing spec '$labelsOfInterestSpec'") ++ _ } + } + + private[multilabel] def model(mlData: MultilabelData[K], plugin: MultilabelModelParserPlugin): MultilabelModel[U, K, A, B] = { + val predProdReader = plugin.parser(mlData)(refInfoK, jsonFormatK) + val predProd = mlData.underlying.convertTo(predProdReader) + + // Compile the labels of interest function. + val labelsOfInterest = mlData.labelsOfInterest.map { spec => + compileLabelsOfInterest(spec) match { + case Left(errs) => throw new DeserializationException(errs.mkString("\n")) + case Right(f) => f + } + } + + // Compile the features. + val featureMap: Seq[(String, Spec)] = mlData.features.toSeq + val (featureNames, featureFns) = + features(featureMap, semantics) + .fold(f => throw new DeserializationException(f.mkString("\n")), identity) + .toIndexedSeq.unzip + + MultilabelModel( + modelId = mlData.modelId, + featureNames = featureNames, + featureFunctions = featureFns, + labelsInTrainingSet = mlData.labelsInTrainingSet, + labelsOfInterest = labelsOfInterest, + predictorProducer = predProd, + numMissingThreshold = mlData.numMissingThreshold, + auditor = auditor + )(serEvK) + } +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala new file mode 100644 index 00000000..5796856b --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/multilabel/package.scala @@ -0,0 +1,73 @@ +package com.eharmony.aloha.models + +import com.eharmony.aloha.dataset.density.Sparse + +import scala.collection.{immutable => sci} +import scala.util.Try + +/** + * Created by ryan.deak on 8/31/17. + */ +package object multilabel { + + // All but the last two types are package private, for testing. The others are public. + + /** + * Features about the input value (NOT including features based on labels). + * This should probably be a `sci.IndexedSeq[Sparse]` but `RegressionFeatures` + * returns a `collection.IndexedSeq` and using + * [[com.eharmony.aloha.models.reg.RegressionFeatures.constructFeatures]] is + * preferable and will provide consistent results across many model types. 
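+   *
+   * For example (illustrative), one extracted feature might expand to
+   * `List(("f=male", 1.0))`, making a `SparseFeatures` value an `IndexedSeq`
+   * of such key-value iterables.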
+ */ + private[multilabel] type SparseFeatures = IndexedSeq[Sparse] + + /** + * Indices of the labels for which predictions should be produced into the + * sequence of all labels. Indices will be sorted in ascending order. + */ + private[multilabel] type LabelIndices = sci.IndexedSeq[Int] + + /** + * Labels for which predictions should be produced. This can be an improper subset of all labels. + * `Labels` align with `LabelIndices` meaning `LabelIndices[i]` will give the index of `Labels[i]` + * into the sequence of all labels a model knows about. + * + * @tparam K the type of labels (or classes in the machine learning literature). + */ + private[multilabel] type Labels[K] = sci.IndexedSeq[K] + + /** + * Sparse features related to the labels. Other outer sequence aligns with the `Labels` and `LabelIndices` + * sequences meaning `SparseLabelDepFeatures[i]` relates to the features of `Labels[i]`. + */ + private[multilabel] type SparseLabelDepFeatures = Labels[SparseFeatures] + + /** + * A sparse multi-label predictor takes: + * + - features + - labels for which a prediction should be produced + - indices of those labels into sequence of all of the labels the model knows about. + - label dependent-features + * + * and returns a Map from the labels passed in, to the prediction associated with the label. + * + * @tparam K the type of labels (or classes in the machine learning literature). + */ + type SparseMultiLabelPredictor[K] = + (SparseFeatures, Labels[K], LabelIndices, SparseLabelDepFeatures) => Try[Map[K, Double]] + + /** + * A lazy version of a sparse multi-label predictor. It is a curried zero-arg function that + * produces a sparse multi-label predictor. + * + * This definition is "lazy" because we can't guarantee that the underlying predictor is + * `Serializable` so we pass around a function that can be cached in a ''transient'' + * `lazy val`. This function should however be `Serializable` and testing should be done + * to ensure that each predictor producer is `Serializable`. + * + * @tparam K the type of labels (or classes in the machine learning literature). + */ + type SparsePredictorProducer[K] = () => SparseMultiLabelPredictor[K] +} + diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala index fd95d951..4045041a 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegFeatureCompiler.scala @@ -1,5 +1,6 @@ package com.eharmony.aloha.models.reg +import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.models.reg.json.Spec import com.eharmony.aloha.semantics.Semantics import com.eharmony.aloha.semantics.func.GenAggFunc @@ -23,7 +24,7 @@ trait RegFeatureCompiler { self: EitherHelpers => protected[this] def features[A](featureMap: Seq[(String, Spec)], semantics: Semantics[A]): ENS[Seq[(String, GenAggFunc[A, Iterable[(String, Double)]])]] = mapSeq(featureMap) { case (k, Spec(spec, default)) => - semantics.createFunction[Iterable[(String, Double)]](spec, default). + semantics.createFunction[Sparse](spec, default). left.map { Seq(s"Error processing spec '$spec'") ++ _ }. // Add the spec that errored. 
right.map { f => (k, f) } } diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala index 21be6fb0..da9bdb48 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/models/reg/RegressionFeatures.scala @@ -1,5 +1,6 @@ package com.eharmony.aloha.models.reg +import com.eharmony.aloha.dataset.density.Sparse import com.eharmony.aloha.semantics.func.GenAggFunc import scala.collection.{immutable => sci, mutable => scm} @@ -20,7 +21,7 @@ trait RegressionFeatures[-A] { /** * Parallel to featureNames. This is the sequence of functions that extract data from the input value. */ - protected[this] val featureFunctions: sci.IndexedSeq[GenAggFunc[A, Iterable[(String, Double)]]] + protected[this] val featureFunctions: sci.IndexedSeq[GenAggFunc[A, Sparse]] /** * A threshold dictating how many missing features to allow before making the prediction fail. None means @@ -66,7 +67,7 @@ trait RegressionFeatures[-A] { * 1 the map of bad features to the missing values in the raw data that were needed to compute the feature * 1 whether the amount of missing data is acceptable to still continue */ - protected[this] final def constructFeatures(a: A): Features[IndexedSeq[Iterable[(String, Double)]]] = { + protected[this] final def constructFeatures(a: A): Features[IndexedSeq[Sparse]] = { // NOTE: Since this function is at the center of the regression process and will be called many times, it // needs to be efficient. Therefore, it uses some things that are not idiomatic scala. For instance, // there are mutable variables, while loops instead of for comprehensions or Range.foreach, etc. @@ -95,7 +96,7 @@ trait RegressionFeatures[-A] { i += 1 } - val numMissingOk = numMissingThreshold map { missing.size <= _ } getOrElse true + val numMissingOk = numMissingThreshold.forall(t => missing.size <= t) // If we are going to err out, allow a linear scan (with repeated work so that we can get richer error // diagnostics. Only include the values where the list of missing accessors variables is not empty. 
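The `forall` rewrite of `numMissingOk` above is behavior-preserving: with no configured threshold the check passes, and with a threshold it compares the number of bad features against it. A minimal sketch with hypothetical counts:

    val missingCount = 3

    // No threshold configured: the prediction is allowed to proceed.
    assert(Option.empty[Int].forall(t => missingCount <= t))

    // Threshold configured: the check fails only once too many features are missing.
    assert(Option(5).forall(t => missingCount <= t))
    assert(!Option(2).forall(t => missingCount <= t))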
diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala index 7290b1dc..703bead4 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RefInfo.scala @@ -1,34 +1,35 @@ package com.eharmony.aloha.reflect -import java.{lang => jl} +import java.{lang => jl, io => ji} import deaktator.reflect.runtime.manifest.ManifestParser object RefInfo { - val Any: RefInfo[Any] = RefInfoOps.refInfo[Any] - val AnyRef: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] - val AnyVal: RefInfo[AnyVal] = RefInfoOps.refInfo[AnyVal] - val Boolean: RefInfo[Boolean] = RefInfoOps.refInfo[Boolean] - val Byte: RefInfo[Byte] = RefInfoOps.refInfo[Byte] - val Char: RefInfo[Char] = RefInfoOps.refInfo[Char] - val Double: RefInfo[Double] = RefInfoOps.refInfo[Double] - val Float: RefInfo[Float] = RefInfoOps.refInfo[Float] - val Int: RefInfo[Int] = RefInfoOps.refInfo[Int] - val JavaBoolean: RefInfo[jl.Boolean] = RefInfoOps.refInfo[jl.Boolean] - val JavaByte: RefInfo[jl.Byte] = RefInfoOps.refInfo[jl.Byte] - val JavaCharacter: RefInfo[jl.Character] = RefInfoOps.refInfo[jl.Character] - val JavaDouble: RefInfo[jl.Double] = RefInfoOps.refInfo[jl.Double] - val JavaFloat: RefInfo[jl.Float] = RefInfoOps.refInfo[jl.Float] - val JavaInteger: RefInfo[jl.Integer] = RefInfoOps.refInfo[jl.Integer] - val JavaLong: RefInfo[jl.Long] = RefInfoOps.refInfo[jl.Long] - val JavaShort: RefInfo[jl.Short] = RefInfoOps.refInfo[jl.Short] - val Long: RefInfo[Long] = RefInfoOps.refInfo[Long] - val Nothing: RefInfo[Nothing] = RefInfoOps.refInfo[Nothing] - val Null: RefInfo[Null] = RefInfoOps.refInfo[Null] - val Object: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] - val Short: RefInfo[Short] = RefInfoOps.refInfo[Short] - val Unit: RefInfo[Unit] = RefInfoOps.refInfo[Unit] - val String: RefInfo[String] = RefInfoOps.refInfo[String] + val Any: RefInfo[Any] = RefInfoOps.refInfo[Any] + val AnyRef: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] + val AnyVal: RefInfo[AnyVal] = RefInfoOps.refInfo[AnyVal] + val Boolean: RefInfo[Boolean] = RefInfoOps.refInfo[Boolean] + val Byte: RefInfo[Byte] = RefInfoOps.refInfo[Byte] + val Char: RefInfo[Char] = RefInfoOps.refInfo[Char] + val Double: RefInfo[Double] = RefInfoOps.refInfo[Double] + val Float: RefInfo[Float] = RefInfoOps.refInfo[Float] + val Int: RefInfo[Int] = RefInfoOps.refInfo[Int] + val JavaBoolean: RefInfo[jl.Boolean] = RefInfoOps.refInfo[jl.Boolean] + val JavaByte: RefInfo[jl.Byte] = RefInfoOps.refInfo[jl.Byte] + val JavaCharacter: RefInfo[jl.Character] = RefInfoOps.refInfo[jl.Character] + val JavaDouble: RefInfo[jl.Double] = RefInfoOps.refInfo[jl.Double] + val JavaFloat: RefInfo[jl.Float] = RefInfoOps.refInfo[jl.Float] + val JavaInteger: RefInfo[jl.Integer] = RefInfoOps.refInfo[jl.Integer] + val JavaLong: RefInfo[jl.Long] = RefInfoOps.refInfo[jl.Long] + val JavaShort: RefInfo[jl.Short] = RefInfoOps.refInfo[jl.Short] + val JavaSerializable: RefInfo[ji.Serializable] = RefInfoOps.refInfo[ji.Serializable] + val Long: RefInfo[Long] = RefInfoOps.refInfo[Long] + val Nothing: RefInfo[Nothing] = RefInfoOps.refInfo[Nothing] + val Null: RefInfo[Null] = RefInfoOps.refInfo[Null] + val Object: RefInfo[AnyRef] = RefInfoOps.refInfo[AnyRef] + val Short: RefInfo[Short] = RefInfoOps.refInfo[Short] + val Unit: RefInfo[Unit] = RefInfoOps.refInfo[Unit] + val String: RefInfo[String] = RefInfoOps.refInfo[String] def apply[A: RefInfo]: RefInfo[A] = RefInfoOps.refInfo[A] 
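The new `JavaSerializable` entry is what `MultilabelModel.serializableEvidence` checks against when deciding whether a label type can be captured inside a `Serializable` model. A minimal sketch of that subtype test (assuming the standard implicit `RefInfo[String]` is available):

    import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps}

    // String is java.io.Serializable, so the branch used by serializableEvidence applies.
    val stringIsSerializable = RefInfoOps.isSubType(RefInfo[String], RefInfo.JavaSerializable)
    assert(stringIsSerializable)

    // Primitive label types such as Int fall through to the AnyVal branch instead.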
diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala new file mode 100644 index 00000000..629cb8ca --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/reflect/RuntimeClasspathScanning.scala @@ -0,0 +1,65 @@ +package com.eharmony.aloha.reflect + +import com.eharmony.aloha +import org.reflections.Reflections + +import scala.reflect.{classTag, ClassTag} +import scala.util.Try + +/** + * Created by ryan.deak on 9/6/17. + */ +trait RuntimeClasspathScanning { + + private[this] val objectSuffix = "$" + + /** + * Determine if the class is a Scala object by looking at the class name. If it + * ends in the `objectSuffix` + * @param c a `Class` instance. + * @tparam A type of Class. + * @return + */ + private[this] def isObject[A](c: Class[A]) = c.getCanonicalName.endsWith(objectSuffix) + + /** + * Scan the classpath within subpackages of `packageToSearch` to find Scala objects that + * contain extend `OBJ` and contain method called `methodName` that should return an `A`. + * Call the method and get the result. If the result is indeed an `A`, add to it to + * the resulting sequence. + * + * This function makes some assumptions. It assumes that Scala objects have classes + * ending with '$'. It also assumes that methods in objects have static forwarders. + * + * @param methodName the name of a method in an object that should be found and called. + * @param packageToSearch the package to search for candidates + * @tparam OBJ the super type of the objects the search should find + * @tparam A the output type of elements that should be returned by the method named + * `methodName` + * @return a sequence of `A` instances that could be found. + */ + protected[this] def scanObjects[OBJ: ClassTag, A: ClassTag]( + methodName: String, + packageToSearch: String = aloha.pkgName + ): Seq[A] = { + val reflections = new Reflections(aloha.pkgName) + import scala.collection.JavaConversions.asScalaSet + val objects = reflections.getSubTypesOf(classTag[OBJ].runtimeClass).toSeq + + val suffixLength = objectSuffix.length + + objects.flatMap { + case o if isObject(o) => + Try { + // This may have some classloading issues. + val classObj = Class.forName(o.getCanonicalName.dropRight(suffixLength)) + classObj.getMethod(methodName).invoke(null) match { + case a: A => a + case _ => throw new IllegalStateException() + } + }.toOption + case _ => None + } + } +} + diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala b/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala index d869b1fa..b7d2a7f4 100644 --- a/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala +++ b/aloha-core/src/main/scala/com/eharmony/aloha/semantics/SemanticsUdfException.scala @@ -26,7 +26,7 @@ case class SemanticsUdfException[+A]( input: A) extends AlohaException(SemanticsUdfException.getMessage(specification, accessorOutput, accessorsMissingOutput, accessorsInErr, input), cause) -private object SemanticsUdfException { +object SemanticsUdfException { /** This method guards against throwing exceptions. * @param specification a specification for the feature that produced an error. 
@@ -38,7 +38,7 @@ private object SemanticsUdfException { * @tparam A type of input * @return */ - def getMessage[A]( + private def getMessage[A]( specification: String, accessorOutput: Map[String, Try[Any]], accessorsMissingOutput: List[String], diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala new file mode 100644 index 00000000..653908f6 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/SerializabilityEvidence.scala @@ -0,0 +1,17 @@ +package com.eharmony.aloha.util + +/** + * A type class used to indicate a parameter has a type that can be serialized in + * a larger Serializable object. + */ +sealed trait SerializabilityEvidence[A] + +object SerializabilityEvidence { + + implicit def anyValEvidence[A <: AnyVal]: SerializabilityEvidence[A] = + new SerializabilityEvidence[A]{} + + implicit def serializableEvidence[A <: java.io.Serializable]: SerializabilityEvidence[A] = + new SerializabilityEvidence[A]{} + +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/StatefulMapOps.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/StatefulMapOps.scala new file mode 100644 index 00000000..b240364f --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/StatefulMapOps.scala @@ -0,0 +1,93 @@ +package com.eharmony.aloha.util + +import scala.collection.{breakOut, SeqLike, immutable => sci} +import scala.collection.generic.{CanBuildFrom => CBF} + +/** + * Some ways to map with state without needing extra libraries like ''cats'' or ''scalaz''. + * @author deaktator + * @since 11/8/2017 + */ +private[aloha] object StatefulMapOps { + + /** + * Map over `as` while maintaining a running state to produce `(B, S)` pairs contained + * within some container type. + * + * Note that the input container type `In` determines the strictness of this method. + * For instance, this can produce values ''strictly'' with a strict sequence like + * `Vector` or ''non-strictly'' for a non-strict sequence like a `Stream`. + * For instance: + * + * {{{ + * val sCycleModulo4 = Stream.iterate(0)(s => (s + 1) % 4) + * val vCycleModulo4 = Vector.iterate(0, 10)(s => (s + 1) % 4) + * val addAndIncrement = (a: Int, s: Int) => ((a + s).toDouble, s + 1) + * + * + * // Non-strict, so it returns even though it's an infinite stream. O(1). + * // type: Stream[(Double, Int)] + * val stream = statefulMap(sCycleModulo4, 0)(addAndIncrement) + * + * // Strict, so it's runtime is O(N), where N = vCycleModulo4.size. + * import scala.collection.breakOut + * val list: List[(Double, Int)] = + * statefulMap(vCycleModulo4, 0)(addAndIncrement)(breakOut) + * }}} + * + * @param as elements to map (given some state that may change). '''NOTE''': the first + * element of `as` ''will be forced'' in this method in order to construct the + * output. + * @param startState the initial state. + * @param f a function that maps a value and a state and to an output and a new state. + * @param cbf object responsible for building the output. + * @tparam A the input element type + * @tparam B the output element type + * @tparam S the state type + * @tparam In a concrete implementation of the input sequence. + * @tparam Out the output container type (including type parameters). + * @return an output container with the `A`s mapped to `B`s and the state resulting + * from applying `f` and getting the second element (''e.g.'': `f(a, s)._2`). 
+ */ + def statefulMap[A, B, S, In <: sci.Seq[A], Out](as: SeqLike[A, In], startState: S) + (f: (A, S) => (B, S)) + (implicit cbf: CBF[In, (B, S), Out]): Out = { + if (as.isEmpty) + cbf().result() + else { + val initEl = f(as.head, startState) + as.tail.scanLeft(initEl){ case ((_, newState), a) => f(a, newState) }(breakOut) + } + } + + + /** + * Map over `as` while maintaining a running state to produce `(B, S)` pairs contained + * within some container type. + * + * {{{ + * val sCycleModulo4 = Iterator.iterate(0)(s => (s + 1) % 4) + * val addAndIncrement = (a: Int, s: Int) => ((a + s).toDouble, s + 1) + * + * val stream = statefulMap(sCycleModulo4, 0)(addAndIncrement) + * }}} + * @param as elements to map (given some state that may change). '''NOTE''': the first + * element of `as` ''will be forced'' in this method in order to construct the + * output. + * @param startState the initial state. + * @param f a function that maps a value and a state and to an output and a new state. + * @tparam A the input element type + * @tparam B the output element type + * @tparam S the state type + * @return an iterator with the `A`s mapped to `B`s and the state resulting + * from applying `f` and getting the second element (''e.g.'': `f(a, s)._2`). + */ + def statefulMap[A, B, S](as: Iterator[A], startState: S)(f: (A, S) => (B, S)): Iterator[(B, S)] = { + if (!as.hasNext) + Iterator.empty + else { + val initEl = f(as.next(), startState) + as.scanLeft(initEl){ case ((_, newState), a) => f(a, newState) } + } + } +} diff --git a/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala new file mode 100644 index 00000000..d6038922 --- /dev/null +++ b/aloha-core/src/main/scala/com/eharmony/aloha/util/rand/Rand.scala @@ -0,0 +1,120 @@ +package com.eharmony.aloha.util.rand + +/** + * Some stateless random sampling utilities. + * + * @author deaktator + * @since 11/3/2017 + */ +private[aloha] trait Rand { + + type Seed = Long + type Index = Int + + /** + * Perform the initial scramble. This should be called '''''once''''' on the initial + * seed prior to the first call to `sampleCombination`. + * @param seed an initial seed + * @return a more scrambled seed. + */ + protected def initSeedScramble(seed: Seed): Seed = + (seed ^ 0x5DEECE66DL) & 0xFFFFFFFFFFFFL + + + /** + * Sample a ''k''-combination from a population of ''n''. + * + * This algorithm uses a linear congruential pseudorandom number generator (see Knuth) + * to perform reservoir sampling via "''Algorithm R''". + * + * It is ~ '''''O'''''('''''n'''''). + * + * If `n` ≤ `k`, then return 0, ..., `n` - 1; otherwise, if `k` < `n`, the + * returned array have length `k` with values between 0 and `n - 1` (inclusive) + * but it is '''NOT''' guaranteed to be sorted. + * + * '''NOTE''': This is a pure function. It produces the same results as if + * `java.util.Random` was used to perform reservoir sampling but since it doesn't + * carry state, this can be trivially operated in parallel with no locking or CAS + * loop overhead. The consequence is that the `seed` must be provided on every call + * and a new seed will be returned as part of the output. + * + * To get this function to act like `java.util.Random`, the first time it is called, the + * seed should be produce by running the desired seed through `initSeedScramble`. 
For + * instance: + * + * {{{ + * val (kComb1, newSeed1) = sampleCombination(4, 2, initSeedScramble(0)) + * val (kComb2, newSeed2) = sampleCombination(4, 2, newSeed1) + * }}} + * + * For more information, see: + * + - [[http://luc.devroye.org/chapter_twelve.pdf Non-Uniform Random Variate Generation, Luc Devroye, Ch. 12]] + - [[https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_R Reservoir sampling (Wikipedia)]] + - [[http://grepcode.com/file/repository.grepcode.com/java/root/jdk/openjdk/6-b14/java/util/Random.java grepcode for Random]] + - [[https://en.wikipedia.org/wiki/Linear_congruential_generator Linear congruential generator (Wikipedia)]] + * + * @param n population size + * @param k combination size + * @param seed the seed to use for random selection + * @return a tuple 2 containing the array of 0-based indices representing + * the ''k''-combination and a new random seed. + */ + protected def sampleCombination(n: Int, k: Int, seed: Seed): (Array[Index], Seed) = { + + // NOTE: This isn't idiomatic Scala code but it will operate in hot loops in minimizing + // object creation is important. + + if (n <= k) { + ((0 until n).toArray, seed) + } + else { + var i = k + 1 + var nextSeed = seed + var reservoirSwapInd = 0 + var bits = 0 + var value = 0 + + // Fill reservoir with the first k indices. + val reservoir = (0 until k).toArray + + // Loop over the rest of the indices outside the reservoir and determine if + // swapping should occur. If so, swap the index in the reservoir with the + // current element, i - 1. + while (i <= n) { + reservoirSwapInd = + if ((i & -i) == i) { + // i = 2^j, for some j. + + // To understand these constants, see + // https://en.wikipedia.org/wiki/Linear_congruential_generator + nextSeed = (nextSeed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL + + // 17 = log2(modulus) - shift = 48 - 31 + ((i * (nextSeed >>> 17)) >> 31).toInt + } + else { + // Loop at least once per swap index. + do { + nextSeed = (nextSeed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL + bits = (nextSeed >>> 17).toInt + value = bits % i + } while (bits - value + (i - 1) < 0) + + value + } + + // This is the key to maintaining the proper probabilities in the reservoir sampling. + // Wikipedia does a good job describing this in the proof by induction in the + // explanation for Algorithm R. 
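+          // Concretely: reservoirSwapInd is uniform on [0, i), so element i - 1 enters
+          // the reservoir with probability k / i, which by induction keeps every element
+          // seen so far in the reservoir with probability k / i.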
+ if (reservoirSwapInd < k) + reservoir(reservoirSwapInd) = i - 1 + + i += 1 + } + + (reservoir, nextSeed) + } + } +} diff --git a/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java b/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java index e857194e..c9dde596 100644 --- a/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java +++ b/aloha-core/src/test/java/com/eharmony/aloha/factory/JavaDefaultModelFactoryTest.java @@ -12,6 +12,7 @@ import com.eharmony.aloha.models.conversion.DoubleToLongModel; import com.eharmony.aloha.models.exploration.BootstrapModel; import com.eharmony.aloha.models.exploration.EpsilonGreedyModel; +import com.eharmony.aloha.models.multilabel.MultilabelModel; import com.eharmony.aloha.reflect.RefInfo; import com.eharmony.aloha.semantics.NoSemantics; import com.eharmony.aloha.semantics.Semantics; @@ -78,7 +79,8 @@ public class JavaDefaultModelFactoryTest { ErrorSwallowingModel.parser().modelType(), EpsilonGreedyModel.parser().modelType(), BootstrapModel.parser().modelType(), - CloserTesterModel.parser().modelType() + CloserTesterModel.parser().modelType(), + MultilabelModel.parser().modelType() }; Arrays.sort(names); diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala b/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala index bb94fc74..6334401a 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/ModelSerializabilityTestBase.scala @@ -9,13 +9,16 @@ import org.reflections.Reflections import scala.collection.JavaConversions.asScalaSet import scala.util.Try - import java.lang.reflect.{Method, Modifier} +import com.eharmony.aloha.util.Logging + /** * Created by ryan on 12/7/15. 
*/ -abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[String]) { +abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[String]) +extends Logging { + def this() = this(pkgs = Seq(aloha.pkgName), Seq.empty) @Test def testSerialization(): Unit = { @@ -27,6 +30,12 @@ abstract class ModelSerializabilityTestBase(pkgs: Seq[String], outFilters: Seq[S outFilters.exists(name.matches) } + debug { + modelClasses + .map(_.getCanonicalName) + .mkString("Models tested for Serializability:\n\t", "\n\t", "") + } + modelClasses.foreach { c => val m = for { testClass <- getTestClass(c.getCanonicalName) diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala index 553d30dc..e7d894bf 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/RowCreatorProducerTest.scala @@ -3,15 +3,19 @@ package com.eharmony.aloha.dataset import java.lang.reflect.Modifier import com.eharmony.aloha +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator import org.junit.Assert._ import org.junit.{Ignore, Test} import org.junit.runner.RunWith import org.junit.runners.BlockJUnit4ClassRunner + import scala.collection.JavaConversions.asScalaSet import org.reflections.Reflections @RunWith(classOf[BlockJUnit4ClassRunner]) class RowCreatorProducerTest { + import RowCreatorProducerTest._ + private[this] def scanPkg = aloha.pkgName + ".dataset" @Test def testAllRowCreatorProducersHaveOnlyZeroArgConstructors() { @@ -20,9 +24,11 @@ class RowCreatorProducerTest { specProdClasses.foreach { clazz => val cons = clazz.getConstructors assertTrue(s"There should only be one constructor for ${clazz.getCanonicalName}. Found ${cons.length} constructors.", cons.length <= 1) - cons.headOption.foreach{ c => - val nParams = c.getParameterTypes.length - assertEquals(s"The constructor for ${clazz.getCanonicalName} should take 0 arguments. It takes $nParams.", 0, nParams) + cons.headOption.foreach { c => + if (!(WhitelistedRowCreatorProducers contains clazz)) { + val nParams = c.getParameterTypes.length + assertEquals(s"The constructor for ${clazz.getCanonicalName} should take 0 arguments. 
It takes $nParams.", 0, nParams) + } } } } @@ -41,3 +47,9 @@ class RowCreatorProducerTest { } } } + +object RowCreatorProducerTest { + private val WhitelistedRowCreatorProducers = Set[Class[_]]( + classOf[VwMultilabelRowCreator.Producer[_, _]] + ) +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala new file mode 100644 index 00000000..da56b63c --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/dataset/vw/multilabel/VwMultilabelRowCreatorTest.scala @@ -0,0 +1,231 @@ +package com.eharmony.aloha.dataset.vw.multilabel + +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson +import com.eharmony.aloha.dataset.{MissingAndErroneousFeatureInfo, SparseFeatureExtractorFunction} +import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc0} +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner +import spray.json.pimpString +import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances + +import scala.collection.breakOut + +/** + * Created by ryan.deak on 9/22/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelRowCreatorTest { + import VwMultilabelRowCreatorTest._ + + @Test def testSharedIsPresentWhenNoFeaturesSpecified(): Unit = { + val rc = rowCreator(0) + val a = output(rc(X)) + val expected = SharedPrefix +: (DummyLabels ++ AllNegative) + assertEquals(expected, a.toList) + } + + @Test def testOneFeatureNoPos(): Unit = { + val rc = rowCreator(numFeatures = 1) + val a = output(rc(X)) + + val shared = s"$SharedPrefix| f1" + val expected = shared +: (DummyLabels ++ AllNegative) + assertEquals(expected, a.toList) + } + + @Test def testRowCreationFromJson(): Unit = { + + // Some labels. These are intentionally out of order. Don't change. + val allLabelInts = Vector(7, 8, 6) + + // Run tests for the entire power set of positive labels. + val tests = Seq( + // positive labels: + 1, // {} + 6, // {6} + 7, // {7} + 8, // {8} + 6 * 7, // {6, 7} + 6 * 8, // {6, 8} + 7 * 8, // {7, 8} + 6 * 7 * 8 // {6, 7, 8} + ) + + // Could throw, but if so, test should fail. + val rc = getRowCreator(allLabelInts) + + tests foreach { test => + val exp = positiveAndNegativeValues(test, allLabelInts) + val (_, act) = rc(test) + testOutput(test, exp, act, MultilabelOutputPrefix, rc.classNs) + } + } + + private[this] def getRowCreator(allLabelInts: Vector[Int]) = { + val rcProducer = + new VwMultilabelRowCreator.Producer[Int, Label](allLabelInts.map(y => y.toString)) + + val rcTry = rcProducer.getRowCreator( + CompiledSemanticsInstances.anyNameIdentitySemantics[Int], + MultilabelDatasetJson + ) + + // Don't care about calling `.get`. If it fails, the test will appropriately blow up. + rcTry.get + } + + private[this] def positiveAndNegativeValues(n: Int, allLabels: Seq[Int]): Seq[Boolean] = { + val divisorsN = divisors(n) + val labelInd: Map[Int, Int] = allLabels.zipWithIndex.toMap + val allLabelSet = labelInd.keySet + val pos = allLabelSet intersect divisorsN + val neg = allLabelSet diff divisorsN + + // Because pos and neg form a partition of allLabelSet, labelInd.apply is safe. 
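+    // Worked example: for n = 42 and allLabels = Seq(7, 8, 6), divisors(42) contains
+    // 6 and 7 but not 8, so the expected per-label outcomes (ordered by label index)
+    // are Seq(true, false, true).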
+ val posAndNeg = + pos.map(p => labelInd(p) -> true) ++ + neg.map(p => labelInd(p) -> false) + + posAndNeg + .toSeq + .sorted + .map { case (_, isPos) => isPos } + } + + private[this] def testOutput( + n: Int, + expectedResults: Seq[Boolean], + actualResults: Array[String], + prefix: Seq[String], + labelNs: Char + ): Unit = { + + val suffix = + expectedResults.zipWithIndex map { case (isPos, i) => + s"$i:${if (isPos) PosVal else NegVal} |$labelNs _$i" + } + + assertEquals(prefix, actualResults.take(prefix.size).toSeq) + assertEquals(suffix, actualResults.drop(prefix.size).toSeq) + } + + + /** + * Get all divisors of `n`. + * This is not super efficient: O(sqrt(N)) + * @param n value for which divisors are desired. + * @return + */ + private[this] def divisors(n: Int): Set[Int] = + (1 to math.sqrt(n).ceil.toInt).flatMap { + case i if n % i == 0 => List(i, n / i) + case _ => Nil + }(breakOut) +} + +object VwMultilabelRowCreatorTest { + private type Domain = Map[String, Any] + private type Label = String + private val Omitted = "" + private val LabelsInTrainingSet = Vector("zero", "one", "two") + private val NegDummyClass = Int.MaxValue.toLong + 1 + private val PosDummyClass = NegDummyClass + 1 + private val PosVal = 0 + private val NegVal = 1 + private val PosFeature = 'P' + private val NegFeature = 'N' + private val DummyNs = 'y' + + private val X = Map.empty[String, Any] + private val SharedPrefix = "shared " + + private val DummyLabels = List( + s"$NegDummyClass:$NegVal |$DummyNs $NegFeature", + s"$PosDummyClass:$PosVal |$DummyNs $PosFeature" + ) + + private val AllNegative = LabelsInTrainingSet.indices.map(i => s"$i:$NegVal |Y _$i") + + // Notice the positive labels. This says if the function input is 0 (mod v), where v is + // one of {6, 7, 8}, then make the example a positive example for the associated label. + // + // This JSON should be used with CompiledSemanticsInstances.anyNameIdentitySemantics[Int]. + // This variable `fn_input` works because of the way that the semantics works. It just + // returns the `fn_input` is just the input value of the function. + // + // All of the features are invariant to the input. + private[aloha] val MultilabelDatasetJson = + """ + |{ + | "imports": [ + | "com.eharmony.aloha.feature.BasicFunctions._" + | ], + | "features": [ + | { "name": "f1", "spec": "1" }, + | { "name": "f2", "spec": "1" } + | ], + | "namespaces": [ + | { "name": "ns1", "features": [ "f1" ] }, + | { "name": "ns2", "features": [ "f2" ] } + | ], + | "normalizeFeatures": false, + | "positiveLabels": "(6 to 8).flatMap(v => List(v.toString).filter(_ => ${fn_input} % v == 0))" + |} + """.stripMargin.parseJson.convertTo[VwMultilabeledJson] + + // This is what the first lines of the output are expected to be. This doesn't change b/c + // the features (and dummy class definitions) in MultilabelDatasetJson are invariant to + // the input. + private[aloha] val MultilabelOutputPrefix = Seq( + "shared |ns1 f1 |ns2 f2", // due to definition of features and namespaces. 
+ s"2147483648:$NegVal |y N", // negative dummy class + s"2147483649:$PosVal |y P" // positive dummy class + ) + + private def output(out: (MissingAndErroneousFeatureInfo, Array[String])) = out._2 + + private[this] val featureFns = SparseFeatureExtractorFunction[Domain](Vector( + "f1" -> GenFunc0(Omitted, _ => Seq(("", 1d))), + "f2" -> GenFunc0(Omitted, _ => Seq(("", 2d))) + )) + + private def featureFns(n: Int) = { + val ff = (1 to n) map { i => + s"f$i" -> GenFunc0(Omitted, (_: Any) => Seq(("", i.toDouble))) + } + + SparseFeatureExtractorFunction[Domain](ff) + } + + private def positiveLabels(ps: Label*): GenAggFunc[Any, Vector[Label]] = { + GenFunc0(Omitted, _ => Vector(ps:_*)) + } + + private def rowCreator(numFeatures: Int, posLabels: Label*): VwMultilabelRowCreator[Domain, Label] = { + val ff = featureFns(numFeatures) + val pos = positiveLabels(posLabels:_*) + StdRowCreator.copy( + featuresFunction = ff, + defaultNamespace = ff.features.indices.toList, + positiveLabelsFunction = pos + ) + } + + private[this] val StdRowCreator: VwMultilabelRowCreator[Domain, Label] = { + val ff = featureFns(0) + val labelNss = VwMultilabelRowCreator.determineLabelNamespaces(Set.empty).get + + VwMultilabelRowCreator[Domain, Label]( + allLabelsInTrainingSet = LabelsInTrainingSet, + featuresFunction = ff, + defaultNamespace = ff.features.indices.toList, // All features in default NS. + namespaces = List.empty[(String, List[Int])], + normalizer = Option.empty[CharSequence => CharSequence], + positiveLabelsFunction = positiveLabels(), + classNs = labelNss.labelNs, + dummyClassNs = labelNss.dummyLabelNs + ) + } +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala new file mode 100644 index 00000000..8ec4a3ec --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/multilabel/MultilabelModelTest.scala @@ -0,0 +1,383 @@ +package com.eharmony.aloha.models.multilabel + +import java.io.{Closeable, PrintWriter, StringWriter} +import java.util.concurrent.atomic.AtomicBoolean + +import com.eharmony.aloha.ModelSerializationTestHelper +import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor, Tree} +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.models.Subvalue +import com.eharmony.aloha.models.multilabel.MultilabelModel.{reportNoPrediction, LabelsAndInfo, NumLinesToKeepInStackTrace} +import com.eharmony.aloha.semantics.SemanticsUdfException +import com.eharmony.aloha.semantics.func._ +import org.junit.Test +import org.junit.Assert._ +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +import scala.collection.{immutable => sci, mutable => scm} +import scala.util.{Failure, Random, Success, Try} + +/** + * Created by ryan.deak and amirziai on 9/1/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class MultilabelModelTest extends ModelSerializationTestHelper { + import MultilabelModel._ + import MultilabelModelTest._ + + @Test def testSerialization(): Unit = { + // Assuming all parameters passed to the MultilabelModel constructor are + // Serializable, MultilabelModel should also be Serializable. 
+ val modelRoundTrip = serializeDeserializeRoundTrip(ModelNoFeatures) + assertEquals(ModelNoFeatures, modelRoundTrip) + } + + @Test def testModelCloseClosesPredictor(): Unit = { + class PredictorClosable[K] extends SparseMultiLabelPredictor[K] with Closeable { + def apply( + v1: SparseFeatures, + v2: Labels[K], + v3: LabelIndices, + v4: SparseLabelDepFeatures) = Try(Map()) + private[this] val closed = new AtomicBoolean(false) + override def close(): Unit = closed.set(true) + def isClosed: Boolean = closed.get() + } + + val predictor = new PredictorClosable[Label] + val model = ModelNoFeatures.copy(predictorProducer = Lazy(predictor)) + + model.close() + assertTrue(predictor.isClosed) + } + + @Test def testLabelsOfInterestOmitted(): Unit = { + val actual: LabelsAndInfo[Label] = labelsAndInfo( + a = (), + labelsOfInterest = None, + labelToInd = Map.empty, + defaultLabelInfo = LabelsAndInfoEmpty + ) + assertEquals(LabelsAndInfoEmpty, actual) + } + + @Test def testLabelsOfInterestProvided(): Unit = { + val a = () + val labelsOfInterest = GenFunc0[Unit,sci.IndexedSeq[Label]]("", _ => LabelsInTrainingSet) + val actual: LabelsAndInfo[Label] = labelsAndInfo( + a = a, + labelsOfInterest = Option(labelsOfInterest), + Map.empty, + LabelsAndInfoEmpty + ) + val expected = labelsForPrediction(a, labelsOfInterest, Map.empty[Label, Int]) + assertEquals(expected, actual) + } + + @Test def testReportTooManyMissing(): Unit = { + val report = reportTooManyMissing( + modelId = ModelId(), + labelInfo = LabelsAndInfoEmpty, + missing = scm.Map("" -> LabelsNotInTrainingSet), + auditor = Auditor + ) + + assertEquals(Vector(TooManyMissingError), report.audited.errorMsgs.take(1)) + assertEquals(None, report.audited.value) + assertEquals(None, report.natural) + } + + @Test def testReportNoPrediction(): Unit = { + val report = ReportNoPredictionEmpty + assertEquals(Vector(NoLabelsError), report.audited.errorMsgs) + assertEquals(None, report.natural) + assertEquals(None, report.audited.value) + } + + @Test def testReportNoPredictionMissingLabelsDoNotExist(): Unit = { + // The missing labels are reported in the error message + + val report = ReportNoPredictionPartial(LabelsAndInfoMissingLabels) + assertEquals(Vector(NoLabelsError) ++ ErrorMessages, report.audited.errorMsgs) + assertEquals(None, ReportNoPredictionEmpty.audited.value) + assertEquals(Set(), report.audited.missingVarNames) + } + + @Test def testReportPredictorError(): Unit = { + val (throwable, stackTrace) = getThrowable("error") + val missingVariables = Seq("a", "b") + val report = reportPredictorError( + ModelId(), + LabelsAndInfoEmpty.copy(labelsNotInTrainingSet = LabelsNotInTrainingSet), + scm.Map("" -> missingVariables), + throwable, + Auditor + ) + assertEquals(Vector(stackTrace) ++ ErrorMessages, report.audited.errorMsgs) + assertEquals(missingVariables.toSet, report.audited.missingVarNames) + assertEquals(None, report.natural) + assertEquals(None, report.audited.value) + } + + @Test def testReportSuccess(): Unit = { + val predictions = Map("label" -> 1.0) + val report = reportSuccess( + modelId = ModelId(), + labelInfo = LabelsAndInfoEmpty.copy(labelsNotInTrainingSet = LabelsNotInTrainingSet), + missing = scm.Map("" -> LabelsNotInTrainingSet), + prediction = predictions, + auditor = Auditor + ) + assertEquals(Some(predictions), report.natural) + assertEquals(Some(predictions), report.audited.value) + assertEquals(report.natural, report.audited.value) + } + + @Test def testLabelsForPredictionContainsProblemsWhenNoLabelProvided(): Unit = { + val 
labelsAndInfoNoLabels = labelsForPrediction( + example = Map[String, String](), // no label provided + labelsOfInterest = LabelsOfInterestExtractor, + labelToInd = LabelsInTrainingSetToIndex) + + // In this scenario no problems are found but an + // Option(GenAggFuncAccessorProblems) is returned instead of None + val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq())) + assertEquals(problemsNoLabelsExpected, labelsAndInfoNoLabels.problems) + } + + @Test def testLabelsForPredictionContainsProblemsWhenLabelsIsNotPresent(): Unit = { + val labelsExtractor = + (label: Map[Label, Label]) => label.get("labels").map(sci.IndexedSeq[Label](_)) + val descriptor = "label is missing" + val labelsOfInterest = + GenFunc.f1(GeneratedAccessor(descriptor, labelsExtractor, None))( + "", _ getOrElse sci.IndexedSeq.empty[Label] + ) + val labelsAndInfoNoLabelsGen1 = labelsForPrediction( + example = Map[Label, Label](), + labelsOfInterest = labelsOfInterest, + labelToInd = LabelsInTrainingSetToIndex) + val problemsNoLabelsExpectedGen1 = + Option(GenAggFuncAccessorProblems(Seq(descriptor), Seq())) + assertEquals(problemsNoLabelsExpectedGen1, labelsAndInfoNoLabelsGen1.problems) + } + + @Test def testLabelsForPredictionContainsProblemsWhenLabelAccessorThrows(): Unit = { + val descriptor = "label accessor that throws" + val labelsOfInterest = + GenFunc.f1(GeneratedAccessor(descriptor, + ( _ => throw new Exception()) : Map[String, String] => Option[sci.IndexedSeq[String]], None))( + "", _ getOrElse sci.IndexedSeq.empty[Label] + ) + val labelsOfInterestWrapped = EnrichedErrorGenAggFunc(labelsOfInterest) + + val problemsNoLabelsExpected = Option(GenAggFuncAccessorProblems(Seq(), Seq(descriptor))) + Try { + labelsForPrediction(Map.empty[String, String], labelsOfInterestWrapped, + LabelsInTrainingSetToIndex) + } match { + // Failure is scala.util.Failure + case Failure(SemanticsUdfException(_, _, _, accessorsInErr, _, _)) => + assertEquals(accessorsInErr, problemsNoLabelsExpected.get.errors) + case _ => fail() + } + } + + @Test def testLabelsForPredictionProvidesLabelsThatCantBePredicted(): Unit = { + // This test would be better done with a property-based testing framework such as ScalaCheck + val labelNotInTrainingSet = "d" + val labelsAndInfo = labelsForPrediction(Map("label" -> labelNotInTrainingSet), + LabelsOfInterestExtractor, LabelsInTrainingSetToIndex) + assertEquals(Seq(labelNotInTrainingSet), labelsAndInfo.labelsNotInTrainingSet) + } + + @Test def testLabelsForPredictionReturnsLabelsSortedByIndex(): Unit = { + val example: Map[String, String] = Map( + "feature1" -> "1", + "label1" -> "a", + "label3" -> "l23", + "label4" -> "100", + "label6" -> "c", + "label8" -> "l1" + ) + + val allLabels = extractLabelsOutOfExample(example).sorted + val labelToInt = allLabels.zipWithIndex.toMap + + val random = new Random(seed=0) + (1 to 10).foreach { _ => + val ex = random.shuffle(example.toVector).take(random.nextInt(example.size)).toMap + val labelsOfInterestExtractor = GenFunc0("", extractLabelsOutOfExample) + val labelsAndInfo = labelsForPrediction(ex, labelsOfInterestExtractor, labelToInt) + assertEquals(labelsAndInfo.indices.sorted, labelsAndInfo.indices) + } + } + + @Test def testSubvalueReportsNoPredictionWhenNoLabelsAreProvided(): Unit = { + assertEquals(None, ModelWithFeatureFunctions.subvalue(Vector.empty).natural) + } + + @Test def testSubvalueReportsTooManyMissingWhenThereAreTooManyMissingFeatures(): Unit = { + // When the amount of missing data exceeds the threshold, reportTooManyMissing 
should be + // called and its value should be returned. + + val modelWithMissingThreshold = ModelWithFeatureFunctions.copy( + featureNames = sci.IndexedSeq("feature1"), + featureFunctions = FeatureFunctions, + labelsInTrainingSet = LabelsInTrainingSet, + labelsOfInterest = None, + numMissingThreshold = Option(0) + ) + + val result = modelWithMissingThreshold(Map.empty) + assertEquals(None, result.value) + assertEquals(TooManyMissingError, result.errorMsgs.head) + } + + @Test def testExceptionsThrownByPredictorAreHandledGracefully(): Unit = { + case object PredictorThatThrows extends + SparseMultiLabelPredictor[Label] { + override def apply(v1: SparseFeatures, + v2: Labels[Label], + v3: LabelIndices, + v4: SparseLabelDepFeatures): Try[Map[Label, Double]] = Try(throw new Exception("error")) + } + + val modelWithThrowingPredictorProducer = ModelWithFeatureFunctions.copy( + predictorProducer = Lazy(PredictorThatThrows), + labelsOfInterest = None + ) + + val result = modelWithThrowingPredictorProducer(Vector.empty) + assertEquals(None, result.value) + assertEquals("java.lang.Exception: error", result.errorMsgs.head.split("\n").head) + } + + @Test def testSubvalueSuccess(): Unit = { + val scoreToReturn = 5d + val modelSuccess = ModelNoFeatures.copy( + labelsInTrainingSet = sci.IndexedSeq[Label]("label1", "label2", "label3", "label4"), + labelsOfInterest = Option(LabelsOfInterestExtractor), + predictorProducer = Lazy(ConstantPredictor[Label](scoreToReturn)), + featureFunctions = sci.IndexedSeq.empty + ) + + val result = modelSuccess(Map("a" -> "b", "label1" -> "label1", "label2" -> "label2")) + assertEquals(Vector(), result.errorMsgs) + assertEquals(Set(), result.missingVarNames) + assertEquals(Option(Map("label1" -> scoreToReturn, "label2" -> scoreToReturn)), result.value) + } + + @Test def testExceptionsThrownInFeatureFunctionsAreNotCaught(): Unit = { + // This is by design. 
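+    // Exceptions thrown while evaluating feature functions are expected to propagate
+    // to the caller rather than being converted into an audited error (hence the
+    // Failure case asserted below).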
+ + val exception = new Exception("error") + val featureFunctionThatThrows: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + GenFunc0("", _ => throw exception) + + val modelWithFeatureFunctionThatThrows = ModelNoFeatures.copy( + featureNames = sci.IndexedSeq("throwing feature"), + featureFunctions = Vector(featureFunctionThatThrows), + labelsInTrainingSet = sci.IndexedSeq[Label](""), + labelsOfInterest = None + ) + + val result = Try(modelWithFeatureFunctionThatThrows(Map())) + result match { + case Success(_) => fail() + case Failure(ex) => assertEquals(exception, ex) + } + } +} + +object MultilabelModelTest { + // Types + private type Label = String + private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() + + private case class ConstantPredictor[K](prediction: Double = 0d) extends + SparseMultiLabelPredictor[K] { + override def apply(featuresUnused: SparseFeatures, + labels: Labels[K], + indicesUnused: LabelIndices, + ldfUnused: SparseLabelDepFeatures): Try[Map[K, Double]] = + Try(labels.map(_ -> prediction).toMap) + } + + private case class Lazy[A](value: A) extends (() => A) { + override def apply(): A = value + } + + // Common input + private val LabelsInTrainingSet: sci.IndexedSeq[Label] = sci.IndexedSeq[Label]("a", "b", "c") + private val LabelsInTrainingSetToIndex: Map[Label, Int] = LabelsInTrainingSet.zipWithIndex.toMap + private val LabelsNotInTrainingSet: Seq[Label] = Seq("a", "b") + private val BaseErrorMessage: Seq[String] = Stream.continually("Label not in training labels: ") + private val ErrorMessages: Seq[String] = BaseErrorMessage.zip(LabelsNotInTrainingSet).map { + case(msg, label) => s"$msg$label" + } + + // Feature functions + private val EmptyIndicatorFn: GenAggFunc[Map[String, String], Iterable[(String, Double)]] = + GenFunc0("", _ => Iterable()) + private val FeatureFunctions = Vector(EmptyIndicatorFn) + + // Models + private val ModelNoFeatures = MultilabelModel( + modelId = ModelId(), + featureNames = sci.IndexedSeq(), + featureFunctions = sci.IndexedSeq[GenAggFunc[Int, Sparse]](), + labelsInTrainingSet = sci.IndexedSeq[Label](), + labelsOfInterest = None, + predictorProducer = Lazy(ConstantPredictor[Label]()), + numMissingThreshold = None, + auditor = Auditor + ) + + private val ModelWithFeatureFunctions: + MultilabelModel[Tree[Any], String, Vector[String], RootedTree[Any, Map[Label, Double]]] = + ModelNoFeatures.copy( + featureFunctions = sci.IndexedSeq[GenAggFunc[Vector[String], Sparse]](), + labelsInTrainingSet = LabelsInTrainingSet, + labelsOfInterest = Some(GenFunc0("", (a: Vector[String]) => a)) + ) + + // LabelsAndInfo + private val LabelsAndInfoEmpty = LabelsAndInfo( + indices = sci.IndexedSeq.empty, + labels = sci.IndexedSeq.empty, + labelsNotInTrainingSet = Seq.empty[Label], + problems = None + ) + private val LabelsAndInfoMissingLabels: LabelsAndInfo[Label] = + LabelsAndInfoEmpty.copy(labelsNotInTrainingSet = LabelsNotInTrainingSet) + + // Reports + private val ReportNoPredictionPartial: + (LabelsAndInfo[Label]) => Subvalue[RootedTree[Any, Map[Label, Double]], Nothing] = + reportNoPrediction( + modelId = ModelId(), + _: LabelsAndInfo[Label], + auditor = Auditor + ) + private val ReportNoPredictionEmpty: Subvalue[RootedTree[Any, Map[Label, Double]], Nothing] = + ReportNoPredictionPartial(LabelsAndInfoEmpty) + + // Throwable and stack trace + private def getThrowable(errorMessage: String): (Throwable, String) = { + val throwable = new Exception(errorMessage) + val sw = new StringWriter + val pw = new PrintWriter(sw) + 
throwable.printStackTrace(pw) + val stackTrace = sw.toString.split("\n").take(NumLinesToKeepInStackTrace).mkString("\n") + (throwable, stackTrace) + } + + // Label extractors + private def extractLabelsOutOfExample(example: Map[String, String]): sci.IndexedSeq[String] = + example.filterKeys(_.startsWith("label")).toSeq.unzip._2.toIndexedSeq + + private val LabelsOfInterestExtractor = GenFunc0("", extractLabelsOutOfExample) +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala index 0acb71ea..41601246 100644 --- a/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala +++ b/aloha-core/src/test/scala/com/eharmony/aloha/models/reg/PolynomialEvaluationAlgoTest.scala @@ -30,7 +30,6 @@ class PolynomialEvaluationAlgoTest { private[this] val factory = ModelFactory.defaultFactory(semantics, OptionAuditor[Double]()) - @Test def testManualPolyEval() { val x = IndexedSeq( Seq(("intercept", 1.0)), diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/CompiledSemanticsInstances.scala b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/CompiledSemanticsInstances.scala new file mode 100644 index 00000000..3d5ff863 --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/CompiledSemanticsInstances.scala @@ -0,0 +1,28 @@ +package com.eharmony.aloha.semantics.compiled + +import com.eharmony.aloha.FileLocations +import com.eharmony.aloha.reflect.RefInfo +import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler +import com.eharmony.aloha.semantics.compiled.plugin.AnyNameIdentitySemanticsPlugin + +/** + * Created by ryan.deak on 9/26/17. + */ +object CompiledSemanticsInstances { + + /** + * Creates a [[CompiledSemantics]] with the [[AnyNameIdentitySemanticsPlugin]] and with + * imports `List("com.eharmony.aloha.feature.BasicFunctions._")`. + * A class cache directory will be used. + * @tparam A the domain of the functions created by this semantics. + * @return + */ + private[aloha] def anyNameIdentitySemantics[A: RefInfo]: CompiledSemantics[A] = { + import scala.concurrent.ExecutionContext.Implicits.global + new CompiledSemantics( + new TwitterEvalCompiler(classCacheDir = FileLocations.testGeneratedClasses), + new AnyNameIdentitySemanticsPlugin[A], + List("com.eharmony.aloha.feature.BasicFunctions._") + ) + } +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/plugin/AnyNameIdentitySemanticsPlugin.scala b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/plugin/AnyNameIdentitySemanticsPlugin.scala new file mode 100644 index 00000000..40324d2c --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/semantics/compiled/plugin/AnyNameIdentitySemanticsPlugin.scala @@ -0,0 +1,25 @@ +package com.eharmony.aloha.semantics.compiled.plugin + +import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps} +import com.eharmony.aloha.semantics.compiled.{CompiledSemanticsPlugin, RequiredAccessorCode} + +/** + * This plugin binds ANY variable name to the input. For example, if the domain of + * the generated function is `Int` and the function specification is any of the following: + * + - `"${asdf} + 1"` + - `"${jkl} + 1"` + - `"${i_dont_care_the_name} + 1"` + - ... + * + * the resulting function adds 1 to the input. + * + * Created by ryan.deak on 9/26/17. + * @param refInfoA reflection information about the function domain. 
+ * @tparam A domain of the functions being generated from this plugin. + */ +private[aloha] class AnyNameIdentitySemanticsPlugin[A](implicit val refInfoA: RefInfo[A]) + extends CompiledSemanticsPlugin[A] { + override def accessorFunctionCode(spec: String): Right[Nothing, RequiredAccessorCode] = + Right(RequiredAccessorCode(Seq(s"identity[${RefInfoOps.toString[A]}]"))) +} diff --git a/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala new file mode 100644 index 00000000..f60b1a5c --- /dev/null +++ b/aloha-core/src/test/scala/com/eharmony/aloha/util/rand/RandTest.scala @@ -0,0 +1,183 @@ +package com.eharmony.aloha.util.rand + +import java.util.Random + +import org.junit.Test +import org.junit.Assert.fail +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +import scala.annotation.tailrec + +/** + * Test the randomness of the random! + * + * @author deaktator + * @since 11/3/2017 + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class RandTest extends Rand { + import RandTest._ + + @Test def testSampleCombinationProbabilities(): Unit = { + val failures = findFailures( + trials = 25, + maxN = 6, + minSamples = 1000, + maxSamples = 10000, + seed = 0 + ) + + val failureMsg = allFailuresMsg(failures) + failureMsg foreach fail + } + + /** + * Sets up the sampling scenarios. + * @param trials number of trials to run. + * @param maxN maximum N value + * @param minSamples minimum number of samples to draw per trial. + * @param maxSamples maximum number of samples to draw per trial. + * @param seed a random seed to use for generating the parameters in the scenarios. + * @return sampling scenarios + */ + private def samplingScenarios( + trials: Int, + maxN: Int, + minSamples: Int, + maxSamples: Int, + seed: Long): List[SamplingScenario] = { + + // Get the scenarios eagerly to avoid carrying around the PRNG. + // If there was stateless "randomness", this could be non-strict. + val (scenarios, _) = + (1 to trials).foldLeft((List.empty[SamplingScenario], new Random(seed))){ + case ((ss, r), _) => + val initSeed = r.nextLong() + val samples = minSamples + r.nextInt(maxSamples - minSamples + 1) + val n = r.nextInt(maxN + 1) + val k = r.nextInt(n + 1) + (SamplingScenario(initSeed, samples, n, k) :: ss, r) + } + + scenarios + } + + private def findFailures(trials: Int, maxN: Int, minSamples: Int, maxSamples: Int, seed: Long) = { + samplingScenarios(trials, maxN, minSamples, maxSamples, seed).iterator.flatMap { + case SamplingScenario(initSeed, samples, n, k) => + val dist = samplingDist(initSeed, samples, n, k) + checkDistributionUniformity(samples, n, k, dist).toIterable + } + } + + private def allFailuresMsg[A](failures: Iterator[TestFailure[A]]): Option[String] = { + if (failures.isEmpty) + None + else { + val errorMsg = + failures.foldLeft(""){ case (msg, TestFailure(samples, n, k, fails, dist)) => + val thisFail = + s"For (n: $n, k: $k, samples: $samples), produced distribution: $dist. Failures:" + + fails.mkString("\n\t", "\n\t", "\n\n") + + msg + thisFail + } + + Option(errorMsg.trim) + } + } + + /** + * For any `n` and `k`, `choose(n, k)` states should be sampled with uniform probability. + * @param samples number of samples drawn + * @param n number of objects from which to choose. + * @param k number of elements drawn from the `n` objects. + * @param dist distribution created from `samples` ''samples''. + * @tparam A type of random variable. + * @return a potential error. 
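+ *         (Illustrative arithmetic: with `samples = 10000`, the tolerated deviation is
+ *         `100 * 10 / math.sqrt(10000) = 10` percent of the expected probability
+ *         `1 / choose(n, k)`.)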
+ */ + private def checkDistributionUniformity[A]( + samples: Int, + n: Int, + k: Int, + dist: Distribution[A] + ): Option[TestFailure[A]] = { + + val expStates = choose(n, k) + val expPr = 1d / expStates + + // This can be driven down by increasing `samples`. + // This is proportional to 1 / sqrt(samples). + // So, if samples is 10000, this is about 10%. + val pctDiffFromExpPr = 100 * (10 / math.sqrt(samples)) + val delta = expPr * (pctDiffFromExpPr / 100) + + // Check that all states are sampled. + val allStates = + if (expStates == dist.size) + List.empty[FailureReason] + else List(MissingStates(expStates.toInt, dist.size)) + + val errors = + dist.foldRight(allStates){ case ((state, pr), errs) => + if (math.abs(expPr - pr) < delta) + errs + else WrongProbability(state, expPr, pr, delta) :: errs + } + + if (errors.isEmpty) + None + else Option(TestFailure(samples, n, k, errors, dist)) + } + + private def samplingDist(initSeed: Long, samples: Int, n: Int, k: Int): Distribution[String] = + drawSamples(initSeed, samples, n, k).mapValues { c => c / samples.toDouble } + + private def drawSamples(initSeed: Long, samples: Int, n: Int, k: Int): Map[String, NumSamples] = { + val seed = initSeedScramble(initSeed) + + val samplesAndSeed = + (1 to samples) + .foldLeft((Map.empty[String, Int], seed)){ case ((m, s), _) => + + // This is the sampling method whose uniformity is being tested. + val (ind, newSeed) = sampleCombination(n, k, s) + + // It is important to sort here because order shouldn't matter since these + // samples are to be treated as sets, not sequences. By contract, + // `sampleCombinations` doesn't guarantee a particular ordering but the + // results are returned in an array (for computational efficiency). + val key = ind.sorted.mkString(",") + val updatedMap = m + (key -> (m.getOrElse(key, 0) + 1)) + (updatedMap, newSeed) + } + + // Throw away final seed. + samplesAndSeed._1 + } + + private def choose(n: Int, k: Int): Long = { + @tailrec def fact(n: Int, p: Long = 1): Long = + if (n <= 1) p else fact(n - 1, n * p) + + fact(n) / (fact(k) * fact(n-k)) + } +} + +object RandTest { + private type NumSamples = Int + private type Distribution[A] = Map[A, Double] + + private case class SamplingScenario(initSeed: Long, samples: Int, n: Int, k: Int) + + private sealed trait FailureReason + private final case class MissingStates(expected: Int, actual: Int) extends FailureReason { + override def toString = s"States missing. Expected $expected states. Found $actual." + } + + private final case class WrongProbability[A](state: A, expected: Double, actual: Double, delta: Double) extends FailureReason { + override def toString = s"Incorrect probability for state $state. Expected pr = $expected +- $delta. Found pr = $actual." + } + private final case class TestFailure[A](samples: Int, n: Int, k: Int, failures: Seq[FailureReason], dist: Distribution[A]) +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/Namespaces.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/Namespaces.scala new file mode 100644 index 00000000..d268ccfc --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/Namespaces.scala @@ -0,0 +1,50 @@ +package com.eharmony.aloha.models.vw.jni + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 9/8/17. 
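+ *
+ * Illustrative sketch of `allNamespaceIndices` below (values are inferred from the
+ * implementation, not taken from a test, so treat them as an assumption):
+ * {{{
+ * // featureNames: Seq("height", "weight", "age")
+ * // namespaces:   ListMap("personal" -> Seq("weight", "age", "eyeColor"))
+ * //
+ * // allNamespaceIndices(featureNames, namespaces) should yield:
+ * //   ( List(("personal", List(1, 2))),          // namespace -> feature indices
+ * //     List(0),                                 // default namespace indices
+ * //     List(("personal", List("eyeColor"))) )   // namespace -> missing features
+ * }}}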
+ */ +trait Namespaces { + + def namespaceIndicesAndMissing( + indices: Map[String, Int], + namespaces: ListMap[String, Seq[String]] + ): (List[(String, List[Int])], List[(String, List[String])]) = { + val nss = namespaces.map { case (ns, fs) => + val indMissing = fs.map { f => + val oInd = indices.get(f) + if (oInd.isEmpty) + (oInd, Option(f)) + else (oInd, None) + } + + val (oInd, oMissing) = indMissing.unzip + (ns, oInd.flatten.toList, oMissing.flatten.toList) + } + + val nsInd = nss.collect { case (ns, ind, _) if ind.nonEmpty => (ns, ind) }.toList + val nsMissing = nss.collect { case (ns, _, m) if m.nonEmpty => (ns, m) }.toList + (nsInd, nsMissing) + } + + def defaultNamespaceIndices( + indices: Map[String, Int], + namespaces: ListMap[String, Seq[String]] + ): List[Int] = { + val featuresInNss = namespaces.foldLeft(Set.empty[String]){ case (s, (_, fs)) => s ++ fs } + val unusedFeatures = indices.keySet -- featuresInNss + unusedFeatures.flatMap(indices.get).toList + } + + def allNamespaceIndices( + featureNames: Seq[String], + namespaces: ListMap[String, Seq[String]] + ): (List[(String, List[Int])], List[Int], List[(String, List[String])]) = { + val indices = featureNames.zipWithIndex.toMap + val (nsi, missing) = namespaceIndicesAndMissing(indices, namespaces) + val defaultNsInd = defaultNamespaceIndices(indices, namespaces) + + (nsi, defaultNsInd, missing) + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala index 41b30318..8887a9e6 100644 --- a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/VwJniModel.scala @@ -219,7 +219,8 @@ object VwJniModel object Parser extends ModelSubmodelParsingPlugin with EitherHelpers - with RegFeatureCompiler { self => + with RegFeatureCompiler + with Namespaces { self => val modelType = "VwJNI" @@ -239,21 +240,13 @@ object VwJniModel case Right(featureMap) => val (names, functions) = featureMap.toIndexedSeq.unzip - val indices = featureMap.unzip._1.view.zipWithIndex.toMap - val nssRaw = vw.namespaces.getOrElse(sci.ListMap.empty) - val nss = nssRaw.map { case (ns, fs) => - val fi = fs.flatMap { f => - val oInd = indices.get(f) - if (oInd.isEmpty) info(s"Ignoring feature '$f' in namespace '$ns'. Not in the feature list.") - oInd - }.toList - (ns, fi) - }.toList + val (nss, defaultNs, missing) = + allNamespaceIndices(names, vw.namespaces.getOrElse(sci.ListMap.empty)) - val vwParams = vw.vw.params.fold("")(_.fold(_.mkString(" "), x => x)).trim - - val defaultNs = (indices.keySet -- (Set.empty[String] /: nssRaw)(_ ++ _._2)).flatMap(indices.get).toList + if (missing.nonEmpty) + info(s"Ignoring features in namespaces not in feature list: $missing") + val vwParams = vw.vw.params.fold("")(_.fold(_.mkString(" "), x => x)).trim val learnerCreator = LearnerCreator[N](vw.classLabels, vw.spline, vwParams) VwJniModel( diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala new file mode 100644 index 00000000..e754aae5 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModel.scala @@ -0,0 +1,8 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + + +/** + * Created by ryan.deak on 9/29/17. 
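+ *
+ * Rough usage sketch (method names come from the mixed-in traits below; the argument
+ * values and identifiers are illustrative assumptions, not tested output):
+ * {{{
+ * // VwMultilabelModel.updatedVwParams("--csoaa_ldf mc", Set("a", "b"), numUniqueLabels = 3)
+ * // VwMultilabelModel.json(datasetSpecVfs, binaryVwModelVfs, modelId, labelsInTrainingSet = Seq("l1", "l2"))
+ * }}}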
+ */ +object VwMultilabelModel extends VwMultlabelJsonCreator + with VwMultilabelParamAugmentation {} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala new file mode 100644 index 00000000..c841ba5d --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentation.scala @@ -0,0 +1,464 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.File + +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.LabelNamespaces +import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.ExpectedLearner +import org.apache.commons.io.{FileUtils, IOUtils} +import vowpalWabbit.learner.VWLearners + +import scala.util.matching.Regex +import scala.util.{Failure, Success, Try} + + +/** + * Created by ryan.deak on 10/5/17. + */ +protected trait VwMultilabelParamAugmentation { + + protected type VWNsSet = Set[Char] + protected type VWNsCrossProdSet = Set[(Char, Char)] + + /** + * Adds VW parameters to make the parameters work as an Aloha multilabel model. + * + * The algorithm works as follows: + * + 1. Ensure the VW `csoaa_ldf` or `wap_ldf` reduction is specified in the supplied VW + parameter list (''with the appropriate option for the flag''). + 1. Ensure that no "''unrecoverable''" flags appear in the supplied VW parameter list. + See `UnrecoverableFlagSet` for flags whose appearance is considered + "''unrecoverable''". + 1. Ensure that ''ignore'' and ''interaction'' flags (`--ignore`, `--ignore_linear`, `-q`, + `--quadratic`, `--cubic`) do not refer to namespaces not supplied in + the `namespaceNames` parameter. + 1. Attempt to determine namespace names that can be used for the labels. For more + information on the label namespace resolution algorithm, see: + `com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.determineLabelNamespaces`. + 1. Remove flags and options found in `FlagsToRemove`. + 1. Add `--noconstant` and `--csoaa_rank` flags. `--noconstant` is added because per-label + intercepts will be included and take the place of a single intercept. `--csoaa_rank` + is added to make the `VWLearner` a `VWActionScoresLearner`. + 1. Create interactions between features and the label namespaces created above. + a. If a namespace in `namespaceNames` appears as an option to VW's `ignore_linear` flag, + '''do not''' create a quadratic interaction between that namespace and the label + namespace. + a. For each interaction term (`-q`, `--quadratic`, `--cubic`, `--interactions`), replace it + with an interaction term also interacted with the label namespace. This increases the + arity of the interaction by 1. + 1. Finally, change the flag options that reference files to point to temp files so that + VW doesn't change the files. This may represent a problem if VW needs to read the file + in the option because although it should exist, it will be empty. + 1. Let VW doing any validations it can. + * + * ==Success Examples== + * + * {{{ + * import com.eharmony.aloha.models.vw.jni.multilabel.VwMultilabelModel.updatedVwParams + * + * // This is a basic example. 'y' and 'Y' in the output are label + * // namespaces. Notice all namespaces are quadratically interacted + * // with the label namespace. 
+ * val uvw1 = updatedVwParams( + * "--csoaa_ldf mc", + * Set("a", "b", "c") + * ) + * // Right("--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + * // "--ignore_linear abc -qYa -qYb -qYc") + * + * // Here since 'a' is in 'ignore_linear', no '-qYa' term appears + * // in the output. + * val uvw2 = updatedVwParams( + * "--csoaa_ldf mc --ignore_linear a -qbc", + * Set("a", "b", "c") + * ) + * // Right("--csoaa_ldf mc --noconstant --csoaa_rank --ignore y " + + * // "--ignore_linear abc -qYb -qYc --cubic Ybc) + * + * // 'a' is in 'ignore', so no terms with 'a' are emitted. 'b' is + * // in 'ignore_linear' so it does occur in any quadratic + * // interactions in the output, but can appear in interaction + * // terms of higher arity like the cubic interaction. + * val uvw3 = updatedVwParams( + * "--csoaa_ldf mc --ignore a --ignore_linear b -qbc --cubic abc", + * Set("a", "b", "c") + * ) + * // Right("--csoaa_ldf mc --noconstant --csoaa_rank --ignore ay " + + * // "--ignore_linear bc -qYc --cubic Ybc") + * }}} + * + * ==Errors Examples== + * + * {{{ + * import com.eharmony.aloha.models.vw.jni.multilabel.VwMultilabelModel.updatedVwParams + * import com.eharmony.aloha.models.vw.jni.multilabel.{ + * NotCsoaaOrWap, + * NamespaceError + * } + * + * assert( updatedVwParams("", Set()) == Left(NotCsoaaOrWap("")) ) + * + * assert( + * updatedVwParams("--wap_ldf m -qaa", Set()) == + * Left(NamespaceError("--wap_ldf m -qaa", Set(), Map("quadratic" -> Set('a')))) + * ) + * + * assert( + * updatedVwParams( + * "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd " + + "--cubic bcd --interactions dde --interactions abcde", + * Set() + * ) == + * Left( + * NamespaceError( + * "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd --cubic bcd " + + * "--interactions dde --interactions abcde", + * Set(), + * Map( + * "ignore" -> Set('a'), + * "ignore_linear" -> Set('b'), + * "quadratic" -> Set('b', 'd'), + * "cubic" -> Set('b', 'c', 'd', 'e'), + * "interactions" -> Set('a', 'b', 'c', 'd', 'e') + * ) + * ) + * ) + * ) + * }}} + * + * @param vwParams current VW parameters passed to the VW JNI + * @param namespaceNames it is assumed that `namespaceNames` is a superset + * of all of the namespaces referred to by any flags + * found in `vwParams`. + * @param numUniqueLabels the number of unique labels in the training set. + * This is used to calculate the appropriate VW + * `ring_size` parameter. + * @return + */ + def updatedVwParams( + vwParams: String, + namespaceNames: Set[String], + numUniqueLabels: Int + ): Either[VwParamError, String] = { + lazy val unrecovFlags = unrecoverableFlags(vwParams) + + if (WapOrCsoaa.findFirstMatchIn(vwParams).isEmpty) + Left(NotCsoaaOrWap(vwParams)) + else if (unrecovFlags.nonEmpty) + Left(UnrecoverableParams(vwParams, unrecovFlags)) + else { + val is = interactions(vwParams) + val i = ignored(vwParams) + val il = ignoredLinear(vwParams) + + // This won't effect anything if the definition of UnrecoverableFlags contains + // all of the flags referenced in the flagsRefMissingNss function. If there + // are flags referenced in flagsRefMissingNss but not in UnrecoverableFlags, + // then this is a valid check. 
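+      // flagsRefMissingNss maps a flag name ("ignore", "ignore_linear", "quadratic",
+      // "cubic", "interactions") to the namespaces it references that are absent from
+      // namespaceNames.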
+ val flagsRefMissingNss = flagsReferencingMissingNss(namespaceNames, i, il, is) + + if (flagsRefMissingNss.nonEmpty) + Left(NamespaceError(vwParams, namespaceNames, flagsRefMissingNss)) + else + VwMultilabelRowCreator.determineLabelNamespaces(namespaceNames).fold( + Left(LabelNamespaceError(vwParams, namespaceNames)): Either[VwParamError, String] + ){ labelNs => + val paramsWithoutRemoved = removeParams(vwParams) + val updatedParams = + addParams(paramsWithoutRemoved, namespaceNames, i, il, is, labelNs, numUniqueLabels) + + val (finalParams, flagToFileMap) = replaceFileBasedFlags(updatedParams, FileBasedFlags) + + val ps = validateVwParams( + vwParams, updatedParams, finalParams, flagToFileMap, !isQuiet(updatedParams) + ) + flagToFileMap.values foreach FileUtils.deleteQuietly // IO: Delete the files. + ps + } + } + } + + /** + * VW Flags automatically resulting in an error. + */ + protected val UnrecoverableFlagSet: Set[String] = + Set("redefine", "stage_poly", "keep", "permutations", "autolink") + + /** + * This is the capture group containing the content when the regex has been + * padded with the pad function. + */ + protected val CaptureGroupWithContent = 2 + + private[this] val FileBasedFlags = Set( + "-f", "--final_regressor", + "--readable_model", + "--invert_hash", + "--output_feature_regularizer_binary", + "--output_feature_regularizer_text", + "-p", "--predictions", + "-r", "--raw_predictions", + "-c", "--cache", + "--cache_file" + ) + + /** + * Pad the regular expression with a prefix and suffix that makes matching work. + * The prefix is `(^|\s)` and means the if there's a character preceding the main + * content in `s`, then that character should be whitespace. The suffix is + * `(?=\s|$)` which means that if a character follows the main content matched by + * `s`, then that character should be whitespace '''AND''' ''that character should + * not'' be consumed by the Regex. By allowing that character to be present for the + * next matching of a regex, it is consumable by the prefix of a regex padded with + * the `pad` function. + * @param s a string + * @return + */ + private[this] def pad(s: String) = """(^|\s)""" + s + """(?=\s|$)""" + private[this] val NumRegex = """-?(\d+(\.\d*)?|\d*\.\d+)([eE][+-]?\d+)?""" + private[this] val ClassCastMsg = """(\S+) cannot be cast to (\S+)""".r + private[this] val CsoaaRank = pad("--csoaa_rank").r + private[this] val WapOrCsoaa = pad("""--(csoaa|wap)_ldf\s+(mc?)""").r + private[this] val Quiet = pad("--quiet").r + protected val Ignore : Regex = pad("""--ignore\s+(\S+)""").r + protected val IgnoreLinear: Regex = pad("""--ignore_linear\s+(\S+)""").r + private[this] val UnrecoverableFlags = pad("--(" + UnrecoverableFlagSet.mkString("|") + ")").r + private[this] val QuadraticsShort = pad("""-q\s*([\S]{2})""").r + private[this] val QuadraticsLong = pad("""--quadratic\s+(\S{2})""").r + private[this] val Cubics = pad("""--cubic\s+(\S{3})""").r + private[this] val Interactions = pad("""--interactions\s+(\S{2,})""").r + private[this] val NoConstant = pad("""--noconstant""").r + private[this] val ConstantShort = pad("""-C\s*(""" + NumRegex + ")").r + private[this] val ConstantLong = pad("""--constant\s+(""" + NumRegex + ")").r + + private[this] val FlagsToRemove = Seq( + QuadraticsShort, + QuadraticsLong, + Cubics, + Interactions, + NoConstant, + ConstantShort, + ConstantLong, + CsoaaRank, + IgnoreLinear, + Ignore + ) + + /** + * Remove flags (and options) for the flags listed in `FlagsToRemove`. 
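+   *
+   * A minimal illustration (assumed behavior; leftover whitespace in the result may differ):
+   * {{{
+   * // removeParams("--csoaa_ldf mc --noconstant -qab --ignore_linear c --quiet")
+   * //   ~ "--csoaa_ldf mc --quiet"
+   * }}}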
+ * @param vwParams VW params passed to the `updatedVwParams` function. + * @return + */ + protected def removeParams(vwParams: String): String = + FlagsToRemove.foldLeft(vwParams)((s, r) => r.replaceAllIn(s, "")) + + protected def addParams( + paramsAfterRemoved: String, + namespaceNames: Set[String], + oldIgnored: VWNsSet, + oldIgnoredLinear: VWNsSet, + oldInteractions: Set[String], + labelNs: LabelNamespaces, + numUniqueLabels: Int + ): String = { + val i = oldIgnored + labelNs.dummyLabelNs + + // Don't include namespaces that are ignored in ignore_linear. + val il = (toVwNsSet(namespaceNames) ++ oldIgnoredLinear) -- i + + // Don't turn a given namespace into quadratics interacted on label when the + // namespace is listed in the ignore_linear flag. + val qs = il.flatMap(n => + if (oldIgnored.contains(n) || oldIgnoredLinear.contains(n)) Nil + else List(s"${labelNs.labelNs}$n") + ) + + // Turn quadratic into cubic and cubic into higher-order interactions. + val cs = createLabelInteractions(oldInteractions, oldIgnored, labelNs, _ == 2) + val hos = createLabelInteractions(oldInteractions, oldIgnored, labelNs, _ >= 3) + + val quadratics = qs.toSeq.sorted.map(q => s"-q$q" ).mkString(" ") + val cubics = cs.toSeq.sorted.map(c => s"--cubic $c").mkString(" ") + val ints = hos.toSeq.sorted.map(ho => s"--interactions $ho").mkString(" ") + val igLin = if (il.nonEmpty) il.toSeq.sorted.mkString("--ignore_linear ", "", "") else "" + + val rs = s"--ring_size ${numUniqueLabels + VwSparseMultilabelPredictor.AddlVwRingSize}" + + // This is non-empty b/c i is non-empty. + val ig = s"--ignore ${i.mkString("")}" + + // Consolidate whitespace because there shouldn't be whitespace in these flags' options. + val additions = s" --noconstant --csoaa_rank $rs $ig $igLin $quadratics $cubics $ints" + .replaceAll("\\s+", " ") + (paramsAfterRemoved.trim + additions).trim + } + + /** + * VW will actually update / replace files if files appear as options to flags. To overcome + * this, an attempt is made to detect flags referencing files and if found, replace the the + * files with temp files. These files should be deleted before exiting the main program. + * @param updatedParams the parameters after the updates. + * @param flagsWithFiles the flag + * @return a tuple2 of the final string to try with VW for validation along with the mapping + * from flag to file that was used. + */ + protected def replaceFileBasedFlags(updatedParams: String, flagsWithFiles: Set[String]): (String, Map[String, File]) = { + // This is rather hairy function. + + def flagRe(flags: Set[String], groupsForFlag: Int, c1: String, c2: String, c3: String) = + if (flags.nonEmpty) + Option(pad(flags.map(_ drop groupsForFlag).toVector.sorted.mkString(c1, c2, c3)).r) + else None + + // Get short and long flags. + val shrt = flagsWithFiles.filter(s => s.startsWith("-") && 2 == s.length && s.charAt(1).isLetterOrDigit) + val lng = flagsWithFiles.filter(s => s.startsWith("--") && 2 < s.length) + + val regexes = List( + flagRe(shrt, 1, "(-[", "", """])\s*(\S+)"""), + flagRe(lng, 2, "(--(", "|", """))\s+(\S+)""") + ).flatten + + regexes.foldLeft((updatedParams, Map[String, File]())) { case ((ps, ffm), r) => + // Fold right to not affect subsequent replacements. + r.findAllMatchIn(ps).foldRight((ps, ffm)) { case (m, (ps1, ffm1)) => + val f = File.createTempFile("REPLACED_", "_FILE") + f.deleteOnExit() // Attempt to be safe here. 
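+        // Splice "<flag> <temp file path>" over the original match and record the
+        // flag -> temp file mapping so the temp files can be deleted after validation.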
+ + val flag = m.group(CaptureGroupWithContent) + val rep = s"$flag ${f.getCanonicalPath}" + val ps2 = ps1.take(m.start) + rep + ps1.drop(m.end) + val ffm2 = ffm1 + (flag -> f) + (ps2, ffm2) + } + } + } + + protected def createLabelInteractions( + interactions: Set[String], + ignored: VWNsSet, + labelNs: LabelNamespaces, + filter: Int => Boolean + ): Set[String] = + interactions.collect { + case i if filter(i.length) && // Filter based on arity. + !i.toCharArray.exists(ignored.contains) => // Filter out ignored. + s"${labelNs.labelNs}$i" + } + + /** + * Get the set of interactions (encoded as Strings). String length represents the + * interaction arity. + * @param vwParams VW params passed to the `updatedVwParams` function. + * @return + */ + protected def interactions(vwParams: String): Set[String] = + List( + QuadraticsShort, + QuadraticsLong, + Cubics, + Interactions + ).foldLeft(Set.empty[String]){(is, r) => + is ++ firstCaptureGroups(vwParams, r).map(s => s.sorted) + } + + protected def unrecoverableFlags(vwParams: String): Set[String] = + firstCaptureGroups(vwParams, UnrecoverableFlags).toSet + + protected def isQuiet(vwParams: String): Boolean = Quiet.findFirstIn(vwParams).nonEmpty + protected def ignored(vwParams: String): VWNsSet = charsIn(Ignore, vwParams) + protected def ignoredLinear(vwParams: String): VWNsSet = charsIn(IgnoreLinear, vwParams) + + protected def handleClassCastException( + orig: String, + mod: String, + ex: ClassCastException + ): VwParamError = + ex.getMessage match { + case ClassCastMsg(from, _) => IncorrectLearner(orig, mod, from) + case _ => ClassCastErr(orig, mod, ex) + } + + protected def flagsReferencingMissingNss( + namespaceNames: Set[String], + i: VWNsSet, + il: VWNsSet, + is: Set[String] + ): Map[String, VWNsSet] = { + val q = filterAndFlattenInteractions(is, _ == 2) + val c = filterAndFlattenInteractions(is, _ == 3) + val ho = filterAndFlattenInteractions(is, _ >= 4) + flagsReferencingMissingNss(namespaceNames, i, il, q, c, ho) + } + + protected def filterAndFlattenInteractions(is: Set[String], filter: Int => Boolean): VWNsSet = + is.flatMap { + case interaction if filter(interaction.length) => interaction.toCharArray + case _ => Nil + } + + protected def flagsReferencingMissingNss( + namespaceNames: Set[String], + i: VWNsSet, il: VWNsSet, q: VWNsSet, c: VWNsSet, ho: VWNsSet + ): Map[String, VWNsSet] = + nssNotInNamespaceNames( + namespaceNames, + "ignore" -> i, + "ignore_linear" -> il, + "quadratic" -> q, + "cubic" -> c, + "interactions" -> ho + ) + + protected def nssNotInNamespaceNames( + nsNames: Set[String], + sets: (String, VWNsSet)* + ): Map[String, VWNsSet] = { + val vwNss = toVwNsSet(nsNames) + + sets.foldLeft(Map.empty[String, VWNsSet]){ case (m, (setName, nss)) => + val extra = nss diff vwNss + if (extra.isEmpty) m + else m + (setName -> extra) + } + } + + // TODO: Change file + protected def validateVwParams( + orig: String, + mod: String, + finalPs: String, + flagToFileMap: Map[String, File], + addQuiet: Boolean + ): Either[VwParamError, String] = { + val ps = if (addQuiet) s"--quiet $finalPs" else finalPs + + Try { VWLearners.create[ExpectedLearner](ps) } match { + case Success(m) => + IOUtils.closeQuietly(m) + Right(mod) + case Failure(cce: ClassCastException) => + Left(handleClassCastException(orig, mod, cce)) + case Failure(ex) => + Left(VwError(orig, mod, ex.getMessage)) + } + } + + // More general functions. + + /** + * Find all of the regex matches and extract the first capture group from the match. 
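+   *
+   * For example (illustrative; `Ignore` is the regex defined above):
+   * {{{
+   * // firstCaptureGroups("--ignore ab --ignore c", Ignore).toList == List("ab", "c")
+   * }}}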
+ * @param vwParams VW params passed to the `updatedVwParams` function. + * @param regex with at least one capture group (this is unchecked). + * @return Iterator of the matches' first capture group. + */ + protected def firstCaptureGroups(vwParams: String, regex: Regex): Iterator[String] = + regex.findAllMatchIn(vwParams).map(m => m.group(CaptureGroupWithContent)) + + protected def charsIn(r: Regex, chrSeq: CharSequence): VWNsSet = + r.findAllMatchIn(chrSeq).flatMap(m => m.group(CaptureGroupWithContent).toCharArray).toSet + + private[multilabel] def toVwNsSet(nsNames: Set[String]): VWNsSet = + nsNames.flatMap(_.take(1).toCharArray) +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala new file mode 100644 index 00000000..bc3e452c --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultlabelJsonCreator.scala @@ -0,0 +1,128 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson +import com.eharmony.aloha.id.{ModelId, ModelIdentity} +import com.eharmony.aloha.io.StringReadable +import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource} +import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.models.multilabel.MultilabelModel +import com.eharmony.aloha.models.multilabel.json.MultilabelModelJson +import com.eharmony.aloha.models.reg.json.Spec +import com.eharmony.aloha.models.vw.jni.VwJniModel +import com.eharmony.aloha.models.vw.jni.multilabel.json.VwMultilabelModelJson +import spray.json.{DefaultJsonProtocol, JsValue, JsonWriter, pimpAny, pimpString} + +import scala.collection.immutable.ListMap + +/** + * Created by ryan.deak on 10/5/17. + */ +private[multilabel] trait VwMultlabelJsonCreator +extends MultilabelModelJson + with VwMultilabelModelJson { + + /** + * Create a JSON representation of an Aloha model. + * + * '''NOTE''': Because of the inclusion of the unrestricted `labelsOfInterest` parameter, + * the JSON produced by this function is not guaranteed to result in a valid + * Aloha model. This is because no semantics are required by this function and + * so, the `labelsOfInterest` function specification cannot be validated. + * + * @param datasetSpec a location of a dataset specification. + * @param binaryVwModel a location of a VW binary model file. + * @param id a model ID. + * @param labelsInTrainingSet The sequence of all labels encountered in the training set used + * to produce the `binaryVwModel`. + * '''''It is extremely important that this sequence has the + * same order as the sequence of labels used in the + * dataset creation process.''''' Otherwise, the VW model might + * associate scores with an incorrect label. + * @param labelsOfInterest It is possible that a model is trained on a super set of labels for + * which predictions can be made. If the labels at prediction time + * differs (''or should be extracted from the input to the model''), + * this function can provide that capability. + * @param externalModel whether the underlying binary VW model should remain as a separate + * file and be referenced by the Aloha model specification (`true`) + * or the binary model content should be embeeded directly into the model + * (`false`). 
'''Keep in mind''' Aloha models must be smaller than 2 GB + * because they are decoded to `String`s and `String`s are indexed by + * 32-bit integers (which have a max value of 2^32^ - 1). + * @param numMissingThreshold the number of missing features to tolerate before emitting a + * prediction failure. + * @tparam K the type of label or class. + * @return a JSON object. + */ + def json[K: JsonWriter]( + datasetSpec: Vfs, + binaryVwModel: Vfs, + id: ModelIdentity, + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String] = None, + externalModel: Boolean = false, + numMissingThreshold: Option[Int] = None + ): JsValue = { + + val dsJsAst = StringReadable.fromInputStream(datasetSpec.inputStream).parseJson + val ds = dsJsAst.convertTo[VwMultilabeledJson] + val features = modelFeatures(ds.features) + val namespaces = ds.namespaces.map(modelNamespaces) + val modelSrc = modelSource(binaryVwModel, externalModel) + val vw = vwModelPlugin(modelSrc, namespaces) + val model = modelAst(id, features, numMissingThreshold, labelsInTrainingSet, labelsOfInterest, vw) + + // Even though we are only *writing* JSON, we need a JsonFormat[K], (which is a reader + // and writer) because multilabelDataJsonFormat has a JsonFormat context bound on K. + // So to turn MultilabelData into JSON, we need to lift the JsonWriter for K into a + // JsonFormat and store it as an implicit value (or create an implicit conversion + // function on implicit arguments. + implicit val labelJF = DefaultJsonProtocol.lift(implicitly[JsonWriter[K]]) + model.toJson + } + + private[multilabel] def modelFeatures(featuresSpecs: Seq[SparseSpec]) = + ListMap ( + featuresSpecs.map { case SparseSpec(name, spec, defVal) => + name -> Spec(spec, defVal.map(ts => ts.toSeq)) + }: _* + ) + + private[multilabel] def modelNamespaces(nss: Seq[Namespace]) = + ListMap( + nss.map(ns => ns.name -> ns.features) : _* + ) + + private[multilabel] def modelSource(binaryVwModel: Vfs, externalModel: Boolean) = + if (externalModel) + ExternalSource(binaryVwModel) + else Base64StringSource(VwJniModel.readBinaryVwModelToB64String(binaryVwModel.inputStream)) + + // Private b/c VwMultilabelAst is protected[this]. Don't let it escape. + private def vwModelPlugin( + modelSrc: ModelSource, + namespaces: Option[ListMap[String, Seq[String]]]) = + VwMultilabelAst( + VwSparseMultilabelPredictorProducer.multilabelPlugin.name, + modelSrc, + namespaces + ) + + // Private b/c MultilabelData is private[this]. Don't let it escape. 
+ private def modelAst[K]( + id: ModelIdentity, + features: ListMap[String, Spec], + numMissingThreshold: Option[Int], + labelsInTrainingSet: Seq[K], + labelsOfInterest: Option[String], + vwPlugin: VwMultilabelAst) = + MultilabelData( + modelType = MultilabelModel.parser.modelType, + modelId = ModelId(id.getId(), id.getName()), + features = features, + numMissingThreshold = numMissingThreshold, + labelsInTrainingSet = labelsInTrainingSet.toVector, + labelsOfInterest = labelsOfInterest, + underlying = vwPlugin.toJson.asJsObject + ) +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala new file mode 100644 index 00000000..235b0fb7 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwParamError.scala @@ -0,0 +1,108 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictor.ExpectedLearner + +/** + * Created by ryan.deak on 10/5/17. + */ +sealed trait VwParamError { + def originalParams: String + def modifiedParams: Option[String] + def errorMessage: String +} + +final case class UnrecoverableParams( + originalParams: String, + unrecoverable: Set[String] +) extends VwParamError { + def modifiedParams: Option[String] = None + def errorMessage: String = + "Encountered parameters that cannot be augmented: " + + unrecoverable.toSeq.sorted.mkString(", ") + + s"\n\toriginal parameters: $originalParams" +} + +final case class NotCsoaaOrWap(originalParams: String) extends VwParamError { + def modifiedParams: Option[String] = None + def errorMessage: String = + "Model must be a csoaa_ldf or wap_ldf model." + s"\n\toriginal parameters: $originalParams" +} + + +final case class IncorrectLearner( + originalParams: String, + modifiedPs: String, + learnerCanonicalName: String +) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"Params produced an incorrect learner type. " + + s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + + s"Found: $learnerCanonicalName." + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" +} + +final case class ClassCastErr( + originalParams: String, + modifiedPs: String, + ccException: ClassCastException +) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"Params produced an incorrect learner type. " + + s"Expected: ${classOf[ExpectedLearner].getCanonicalName} " + + s"Encountered ClassCastException: ${ccException.getMessage}" + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" +} + +final case class VwError( + originalParams: String, + modifiedPs: String, + vwErrMsg: String +) extends VwParamError { + override def modifiedParams: Option[String] = Option(modifiedPs) + override def errorMessage: String = + s"VW could not create a learner of type " + + s"${classOf[ExpectedLearner].getCanonicalName}. Error: $vwErrMsg. 
" + + s"\n\toriginal parameters: $originalParams" + + s"\n\tmodified parameters: $modifiedPs" +} + +final case class NamespaceError( + originalParams: String, + namespaceNames: Set[String], + flagsReferencingMissingNss: Map[String, Set[Char]] +) extends VwParamError { + override def modifiedParams: Option[String] = None + override def errorMessage: String = { + val flagErrs = + flagsReferencingMissingNss + .toSeq.sortBy(_._1) + .map { case(f, s) => s"--$f: ${s.mkString(",")}" } + .mkString("; ") + val nss = namespaceNames.toVector.sorted.mkString(", ") + val vwNss = VwMultilabelModel.toVwNsSet(namespaceNames).toVector.sorted.mkString(",") + + s"Detected flags referencing namespaces not in the provided set. $flagErrs. " + + ( + if (vwNss.isEmpty) "No namespaces provided." + else s"Expected only $vwNss from provided namespaces: $nss." + ) + + s"\n\toriginal parameters: $originalParams" + } +} + +final case class LabelNamespaceError( + originalParams: String, + namespaceNames: Set[String] +) extends VwParamError { + override def modifiedParams: Option[String] = None + override def errorMessage: String = { + val nss = namespaceNames.toSeq.sorted.mkString(", ") + s"Could not determine label namespaces from namespaces: $nss" + + s"\n\toriginal parameters: $originalParams" + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala new file mode 100644 index 00000000..fe453e84 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictor.scala @@ -0,0 +1,148 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.{Closeable, File} + +import com.eharmony.aloha.dataset.density.Sparse +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator +import com.eharmony.aloha.io.sources.ModelSource +import com.eharmony.aloha.models.multilabel.SparseMultiLabelPredictor +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} +import vowpalWabbit.responses.ActionScores + +import scala.collection.{immutable => sci} +import scala.util.Try + +/** + * Creates a VW multi-label predictor plugin for `MultilabelModel`. + * @param modelSource a specification for the underlying ''Cost Sensitive One Against All'' + * VW model with ''label dependent features''. VW flag `--csoaa_ldf mc` + * or `--wap_ldf mc` is expected. For more information, see the + * [[https://github.com/JohnLangford/vowpal_wabbit/wiki/Cost-Sensitive-One-Against-All-(csoaa)-multi-class-example VW CSOAA wiki page]]. + * Also see the ''Cost-Sensitive Multiclass Classification'' section of + * Hal Daume's [[https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html On Multiclass Classification in VW]] + * page. This model specification will be materialized in this class. + * @param defaultNs The list of indices into the `features` sequence that does not have + * an exist in any value of the `namespaces` map. + * @param namespaces Mapping from namespace name to indices in the `features` sequence passed + * to the `apply` method. There should be no empty namespaces, meaning + * ''key-value'' pairs appearing in the map should not values that are empty + * sequences. '''This is a requirement.''' + * @tparam K the label or class type. 
+ * @author deaktator + * @since 9/8/2017 + */ +case class VwSparseMultilabelPredictor[K]( + modelSource: ModelSource, + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + numLabelsInTrainingSet: Int) +extends SparseMultiLabelPredictor[K] + with Closeable { + + import VwSparseMultilabelPredictor._ + + @transient private[this] lazy val paramsAndVwModel = + createLearner(modelSource, numLabelsInTrainingSet) + + @transient private[this] lazy val updatedParams = paramsAndVwModel._1 + @transient private[multilabel] lazy val vwModel = paramsAndVwModel._2.get + + + { + val emptyNss = namespaces collect { case (ns, ind) if ind.isEmpty => ns } + require( + emptyNss.isEmpty, + s"There should be no namespaces that are empty. Found: ${emptyNss mkString ", "}" + ) + + // Force creation. + require(vwModel != null) + } + + /** + * Get the VW parameters used to invoke the underlying VW model. + * @return VW parameters. + */ + def vwParams(): String = updatedParams + + /** + * Given the input, form a VW example, and delegate to the underlying ''CSOAA LDF'' VW model. + * @param features (non-label dependent) features shared across all labels. + * @param labels labels for which the VW learner should produce predictions. + * @param indices the indices `labels` into the sequence of all labels encountered + * during training. + * @param labelDependentFeatures Any label dependent features. This is not yet utilized and + * is currently ignored. + * @return a Map from label to prediction. + */ + override def apply( + features: IndexedSeq[Sparse], + labels: sci.IndexedSeq[K], + indices: sci.IndexedSeq[Int], + labelDependentFeatures: sci.IndexedSeq[IndexedSeq[Sparse]] + ): Try[Map[K, Double]] = { + + val x = VwMultilabelRowCreator.predictionInput(features, indices, defaultNs, namespaces, ClassNS) + val pred = Try { vwModel.predict(x) } + val yOut = pred.map { y => produceOutput(y, labels) } + yOut + } + + override def close(): Unit = vwModel.close() +} + +object VwSparseMultilabelPredictor { + private val ClassNS = "Y" + + private[multilabel] val AddlVwRingSize = 10 + + private[multilabel] type ExpectedLearner = VWActionScoresLearner + + /** + * Produce the output given VW's output, `pred`, and the labels provided to the `apply` function. + * @param pred predictions returned by the underlying VW ''CSOAA LDF'' model. + * @param labels the labels provided to the `apply` function. This determines which predictions + * should be produced. + * @tparam K The label or class type. + * @return a map of predictions from label to prediction. + */ + private[multilabel] def produceOutput[K](pred: ActionScores, labels: sci.IndexedSeq[K]): Map[K, Double] = { + (for { + as <- pred.getActionScores + label = labels(as.getAction) + pred = modifiedLogistic(as.getScore) + } yield label -> pred)(collection.breakOut) + } + + /** + * A modified logistic function where the sign of the exponent is opposite the usual + * definition. Since CSOAA in VW employs costs, it returns the negative logits which + * changes the sign of the normal logistic function so the definition becomes + * `1 / (1 + exp(x))`. + * + * @param x an input produced by a VW CSOAA prediction. + * @return a probability. + */ + @inline final private def modifiedLogistic(x: Float) = 1 / (1 + math.exp(x)) + + /** + * Update parameters with initial regressior, ring size, testonly, and quiet. + * @param modelSource a trained VW binary model. + * @param numLabelsInTrainingSet number of labels in the training set informs VW's ring size. 
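+   *                                For example, with `AddlVwRingSize = 10`, 3 labels yields
+   *                                `--ring_size 13` in the returned parameter string
+   *                                (illustrative arithmetic).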
+ * @return updated parameters. + */ + private[multilabel] def paramsWithSource(modelSource: File, numLabelsInTrainingSet: Int): String = { + val ringSize = numLabelsInTrainingSet + AddlVwRingSize + s"-i ${modelSource.getCanonicalPath} --ring_size $ringSize --testonly --quiet" + } + + private[multilabel] def createLearner( + modelSource: ModelSource, + numLabelsInTrainingSet: Int + ): (String, Try[ExpectedLearner]) = { + val modelFile = modelSource.localVfs.replicatedToLocal() + val updatedParams = paramsWithSource(modelFile.fileObj, numLabelsInTrainingSet) + val learner = Try { VWLearners.create[ExpectedLearner](updatedParams) } + (updatedParams, learner) + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala new file mode 100644 index 00000000..0950499b --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorProducer.scala @@ -0,0 +1,49 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.io.sources.ModelSource +import com.eharmony.aloha.models.multilabel._ +import com.eharmony.aloha.models.vw.jni.multilabel.json.VwMultilabelModelPluginJsonReader +import com.eharmony.aloha.reflect.RefInfo +import spray.json.{JsonFormat, JsonReader} + +/** + * A wrapper responsible for creating a [[VwSparseMultilabelPredictor]]. This defers + * creation since VW JNI models are not Serializable. This is because they are + * thin wrappers around native C models in memory. The underlying binary model is + * externalizable in a file, byte array, etc. and this information is contained in + * `modelSource`. The actual VW JNI model is created only where needed. In short, + * this class can be serialized but the JNI model created from `modelSource` cannot + * be. + * + * @param modelSource a source from which the binary VW model information can be + * extracted and used to create a VW JNI model. + * @param defaultNs indices into the features that belong to VW's default namespace. + * @param namespaces namespace name to feature indices map. + * @tparam K type of the labels returned by the [[VwSparseMultilabelPredictor]] that + * will be produced. 
+ * @author deaktator + * @since 9/5/2017 + */ +case class VwSparseMultilabelPredictorProducer[K]( + modelSource: ModelSource, + defaultNs: List[Int], + namespaces: List[(String, List[Int])], + labelNamespace: Char, + numLabelsInTrainingSet: Int +) extends SparsePredictorProducer[K] { + override def apply(): VwSparseMultilabelPredictor[K] = + VwSparseMultilabelPredictor[K](modelSource, defaultNs, namespaces, numLabelsInTrainingSet) +} + +object VwSparseMultilabelPredictorProducer extends MultilabelPluginProviderCompanion { + def multilabelPlugin: MultilabelModelParserPlugin = Plugin + + object Plugin extends MultilabelModelParserPlugin { + override def name: String = "vw" + + override def parser[K](info: PluginInfo[K]) + (implicit ri: RefInfo[K], jf: JsonFormat[K]): JsonReader[SparsePredictorProducer[K]] = { + VwMultilabelModelPluginJsonReader[K](info.features.keys.toVector, info.labelsInTrainingSet.size) + } + } +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala new file mode 100644 index 00000000..bc41d6f5 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelJson.scala @@ -0,0 +1,22 @@ +package com.eharmony.aloha.models.vw.jni.multilabel.json + +import com.eharmony.aloha.io.sources.ModelSource + +import spray.json.DefaultJsonProtocol._ + +import scala.collection.immutable.ListMap +import com.eharmony.aloha.factory.ScalaJsonFormats + +/** + * Created by ryan.deak on 9/8/17. + */ +trait VwMultilabelModelJson extends ScalaJsonFormats { + + private[multilabel] case class VwMultilabelAst( + `type`: String, + modelSource: ModelSource, + namespaces: Option[ListMap[String, Seq[String]]] = Some(ListMap.empty) + ) + + protected[this] implicit val vwMultilabelAstFormat = jsonFormat3(VwMultilabelAst) +} diff --git a/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala new file mode 100644 index 00000000..ed3bc824 --- /dev/null +++ b/aloha-vw-jni/src/main/scala/com/eharmony/aloha/models/vw/jni/multilabel/json/VwMultilabelModelPluginJsonReader.scala @@ -0,0 +1,64 @@ +package com.eharmony.aloha.models.vw.jni.multilabel.json + +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.{LabelNamespaces, determineLabelNamespaces} +import com.eharmony.aloha.models.multilabel.SparsePredictorProducer +import com.eharmony.aloha.models.vw.jni.Namespaces +import com.eharmony.aloha.models.vw.jni.multilabel.VwSparseMultilabelPredictorProducer +import com.eharmony.aloha.util.Logging +import spray.json.{DeserializationException, JsValue, JsonReader} + +import scala.collection.breakOut +import scala.collection.immutable.ListMap + +/** + * A JSON reader for SparsePredictorProducer. + * + * '''NOTE''': This class extends `JsonReader[SparsePredictorProducer[K]]` rather than the + * more specific type `JsonReader[VwSparseMultilabelPredictorProducer[K]]` because `JsonReader` + * is not covariant in its type parameter. + * + * Created by ryan.deak on 9/8/17. + * + * @param featureNames feature names from the multi-label model. 
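+ * @param numLabelsInTrainingSet the number of labels in the training set; forwarded to the
+ *                               produced [[VwSparseMultilabelPredictorProducer]].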
+ * @tparam K label type for the predictions output by the produced predictor.
+ */
+case class VwMultilabelModelPluginJsonReader[K](featureNames: Seq[String], numLabelsInTrainingSet: Int)
+  extends JsonReader[SparsePredictorProducer[K]]
+     with VwMultilabelModelJson
+     with Namespaces
+     with Logging {
+
+  import VwMultilabelModelPluginJsonReader._
+
+  override def read(json: JsValue): VwSparseMultilabelPredictorProducer[K] = {
+    val ast = json.asJsObject(notObjErr(json)).convertTo[VwMultilabelAst]
+    val (namespaces, defaultNs, missing) =
+      allNamespaceIndices(featureNames, ast.namespaces.getOrElse(ListMap.empty))
+
+    if (missing.nonEmpty)
+      info(s"features in namespaces not found in featureNames: $missing")
+
+    val namespaceNames: Set[String] = namespaces.map(_._1)(breakOut)
+    val labelAndDummyLabelNss = determineLabelNamespaces(namespaceNames)
+
+    labelAndDummyLabelNss match {
+      case Some(LabelNamespaces(labelNs, _)) =>
+        VwSparseMultilabelPredictorProducer[K](ast.modelSource, defaultNs, namespaces, labelNs, numLabelsInTrainingSet)
+      case _ =>
+        throw new DeserializationException(
+          "Could not determine label namespace. Found namespaces: " +
+            namespaceNames.mkString(", ")
+        )
+    }
+  }
+}
+
+object VwMultilabelModelPluginJsonReader extends Logging {
+  private val JsonErrStrLength = 100
+
+  private[multilabel] def notObjErr(json: JsValue): String = {
+    val str = json.prettyPrint
+    val substr = str.substring(0, math.min(JsonErrStrLength, str.length))
+    s"JSON object expected. Found " + substr + (if (str.length != substr.length) " ..." else "")
+  }
+}
\ No newline at end of file
diff --git a/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json
new file mode 100644
index 00000000..fdf1bbf2
--- /dev/null
+++ b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json
@@ -0,0 +1,13 @@
+{
+  "imports": [
+    "com.eharmony.aloha.feature.BasicFunctions._"
+  ],
+  "features": [
+    { "name": "feature", "spec": "1" }
+  ],
+  "namespaces": [
+    { "name": "X", "features": [ "feature" ] }
+  ],
+  "normalizeFeatures": false,
+  "positiveLabels": "${labels_from_input}"
+}
diff --git a/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json
new file mode 100644
index 00000000..8910bc6f
--- /dev/null
+++ b/aloha-vw-jni/src/test/resources/com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json
@@ -0,0 +1,14 @@
+{
+  "imports": [
+    "com.eharmony.aloha.feature.BasicFunctions._"
+  ],
+  "features": [
+    { "name": "feature", "spec": "1" }
+  ],
+  "namespaces": [
+    { "name": "X", "features": [ "feature" ] }
+  ],
+  "normalizeFeatures": false,
+  "positiveLabels": "${labels_from_input}",
+  "numDownsampledNegLabels": 1
+}
diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala
new file mode 100644
index 00000000..44900c9e
--- /dev/null
+++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelDownsampledModelTest.scala
@@ -0,0 +1,188 @@
+package com.eharmony.aloha.models.vw.jni.multilabel
+
+import java.io.File
+
+import com.eharmony.aloha.audit.impl.OptionAuditor
+import 
com.eharmony.aloha.dataset.vw.multilabel.VwDownsampledMultilabelRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.json.VwDownsampledMultilabeledJson +import com.eharmony.aloha.factory.ModelFactory +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.io.StringReadable +import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances +import org.apache.commons.vfs2.VFS +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner +import org.junit.Assert.{assertEquals, fail} +import spray.json.pimpString +import spray.json.DefaultJsonProtocol.IntJsonFormat +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} + +import scala.collection.breakOut +import scala.util.Random + +/** + * Created by ryan.deak on 11/6/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelDownsampledModelTest { + + import VwMultilabelDownsampledModelTest.EndToEndDatasetSpec + + @Test def test(): Unit = { + + // ------------------------------------------------------------------------------------ + // Preliminaries + // ------------------------------------------------------------------------------------ + + type Lab = Int + type Dom = Vector[Lab] // Not an abuse of notation. Domain is indeed a vector of labels. + + val semantics = CompiledSemanticsInstances.anyNameIdentitySemantics[Dom] + val optAud = OptionAuditor[Map[Lab, Double]]() + + // 10 passes over the data, sampling the negatives. + val repetitions = 8 + + + // ------------------------------------------------------------------------------------ + // Dataset and test example set up. + // ------------------------------------------------------------------------------------ + + // Marginal Distribution: Pr[8] = 0.80 = 20 / 25 = (12 + 8) / 25 + // Pr[4] = 0.40 = 10 / 25 = ( 2 + 8) / 25 + // + val unshuffledTrainingSet: Seq[Dom] = Seq( + Vector( ) -> 3, // pr = 0.12 = 3 / 25 <-- JPD + Vector( 8) -> 12, // pr = 0.48 = 12 / 25 + Vector(4 ) -> 2, // pr = 0.08 = 2 / 25 + Vector(4, 8) -> 8 // pr = 0.32 = 8 / 25 + ) flatMap { + case (k, n) => Vector.fill(n * repetitions)(k) + } + + val trainingSet = new Random(0x0912e40f395bL).shuffle(unshuffledTrainingSet) + + val labelsInTrainingSet = trainingSet.flatten.toSet.toVector.sorted + + val testExample: Dom = Vector.empty + + val marginalDist = labelsInTrainingSet.map { label => + val z = trainingSet.size.toDouble + label -> trainingSet.count(row => row contains label) / z + }.toMap + + // ------------------------------------------------------------------------------------ + // Prepare dataset specification and read training dataset. + // ------------------------------------------------------------------------------------ + + val datasetSpec = VFS.getManager.resolveFile(EndToEndDatasetSpec) + val datasetJson = + StringReadable + .fromVfs2(datasetSpec) + .parseJson + .convertTo[VwDownsampledMultilabeledJson] + + val rc = + new VwDownsampledMultilabelRowCreator.Producer[Dom, Lab](labelsInTrainingSet, () => 0L). + getRowCreator(semantics, datasetJson).get + + + // ------------------------------------------------------------------------------------ + // Prepare parameters for VW model that will be trained. 
+ // ------------------------------------------------------------------------------------ + + val binaryVwModel = File.createTempFile("vw_", ".bin.model") + binaryVwModel.deleteOnExit() + + val cacheFile = File.createTempFile("vw_", ".cache") + cacheFile.deleteOnExit() + + val origParams = + s""" + | --quiet + | --cache_file ${cacheFile.getCanonicalPath} + | --holdout_off + | --passes 10 + | --learning_rate 5 + | --decay_learning_rate 0.9 + | --csoaa_ldf mc + | --loss_function logistic + | -f ${binaryVwModel.getCanonicalPath} + """.stripMargin.trim.replaceAll("\n", " ") + + // Get the namespace names from the row creator. + val nsNames = rc.namespaces.map(_._1)(breakOut): Set[String] + + // Take the parameters and augment with additional parameters to make + // multilabel w/ probabilities work correctly. + val vwParams = VwMultilabelModel.updatedVwParams(origParams, nsNames, 2) fold ( + e => throw new Exception(e.errorMessage), + ps => ps + ) + + // ------------------------------------------------------------------------------------ + // Train VW model + // ------------------------------------------------------------------------------------ + + // Get the iterator of the examples produced. This is similar to what one may do + // within a `mapPartitions` in Spark. + val examples = rc.statefulMap(trainingSet.iterator, rc.initialState) collect { + case ((_, Some(x)), _) => x + } + + val vwLearner = VWLearners.create[VWActionScoresLearner](vwParams) + + examples foreach { yx => + vwLearner.learn(yx) + } + + vwLearner.close() + + + // ------------------------------------------------------------------------------------ + // Create Aloha model JSON + // ------------------------------------------------------------------------------------ + + val modelJson = VwMultilabelModel.json( + datasetSpec = Vfs.apacheVfs2ToAloha(datasetSpec), + binaryVwModel = Vfs.javaFileToAloha(binaryVwModel), + id = ModelId(1, "NONE"), + labelsInTrainingSet = labelsInTrainingSet, + labelsOfInterest = Option.empty[String], + externalModel = false, + numMissingThreshold = Option(0) + ) + + // ------------------------------------------------------------------------------------ + // Instantiate Aloha Model + // ------------------------------------------------------------------------------------ + + val factory = ModelFactory.defaultFactory(semantics, optAud) + val modelTry = factory.fromString(modelJson.prettyPrint) // Use `.compactPrint` in prod. 
+ val model = modelTry.get + + + // ------------------------------------------------------------------------------------ + // Test Aloha Model + // ------------------------------------------------------------------------------------ + + val output = model(testExample) + model.close() + + output match { + case None => fail() + case Some(m) => + assertEquals(marginalDist.keySet, m.keySet) + marginalDist foreach { case (k, v) => + assertEquals(s"For key '$k':", v, m(k), 0.01) + } + } + } +} + +object VwMultilabelDownsampledModelTest { + private val EndToEndDatasetSpec = + "res:com/eharmony/aloha/models/vw/jni/multilabel/downsampled_neg_dataset_spec.json" +} diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala new file mode 100644 index 00000000..cf003042 --- /dev/null +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelModelTest.scala @@ -0,0 +1,504 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.File + +import com.eharmony.aloha.audit.impl.OptionAuditor +import com.eharmony.aloha.audit.impl.tree.{RootedTree, RootedTreeAuditor} +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator.LabelNamespaces +import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson +import com.eharmony.aloha.factory.ModelFactory +import com.eharmony.aloha.id.ModelId +import com.eharmony.aloha.io.StringReadable +import com.eharmony.aloha.io.sources.{ExternalSource, ModelSource} +import com.eharmony.aloha.io.vfs.Vfs +import com.eharmony.aloha.models.Model +import com.eharmony.aloha.models.multilabel.MultilabelModel +import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances +import com.eharmony.aloha.semantics.func.GenFunc0 +import org.apache.commons.vfs2.VFS +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner +import spray.json.DefaultJsonProtocol.{IntJsonFormat, StringJsonFormat} +import spray.json.{DefaultJsonProtocol, JsonWriter, RootJsonFormat, pimpString} +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} + +import scala.annotation.tailrec +import scala.collection.breakOut +import scala.util.Random + +/** + * Created by ryan.deak on 9/11/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelModelTest { + import VwMultilabelModelTest._ + + @Test def testTrainedModelWorks(): Unit = { + val model = Model + + try { + // Test the no label case. + testEmpty(model("")) + + // Test all elements of the power set of labels (except the empty set, done above). + for { + labels <- powerSet(AllLabels.toSet).filter(ls => ls.nonEmpty) + x = labels.mkString(",") + y = model(x) + } testOutput(y, labels, ExpectedMarginalDist) + } + finally { + model.close() + } + } + + @Test def testTrainedModelCanBeParsedAndUsed(): Unit = + testModel(None) + + @Test def testTrainedModelCanBeParsedAndUsedWithLabels67(): Unit = + testModel(Option(Set(LabelSix, LabelSeven))) + + @Test def testTrainedModelCanBeParsedAndUsedWithLabels678(): Unit = + testModel(Option(Set(LabelSix, LabelSeven, LabelEight))) + + @Test def testTrainedModelCanBeParsedAndUsedWithNoLabels(): Unit = + testModel(Option(Set.empty), modelShouldProduceOutput = false) + + // This is really more of an integration test. 
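+  // End-to-end flow: build a training set, create VW examples with the row creator, train a
+  // VW CSOAA LDF model, emit the Aloha model JSON, instantiate it through ModelFactory, and
+  // check that the predicted marginal label probabilities match the training distribution.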
+ @Test def testTrainingAndTesting(): Unit = { + + // ------------------------------------------------------------------------------------ + // Preliminaries + // ------------------------------------------------------------------------------------ + + type Lab = Int + type Dom = Vector[Lab] // Not an abuse of notation. Domain is indeed a vector of labels. + + val semantics = CompiledSemanticsInstances.anyNameIdentitySemantics[Dom] + val optAud = OptionAuditor[Map[Lab, Double]]() + + // ------------------------------------------------------------------------------------ + // Dataset and test example set up. + // ------------------------------------------------------------------------------------ + + // Marginal Distribution: Pr[8] = 0.80 = 20 / 25 = (12 + 8) / 25 + // Pr[4] = 0.40 = 10 / 25 = ( 2 + 8) / 25 + // + val unshuffledTrainingSet: Seq[Dom] = Seq( + Vector( ) -> 3, // pr = 0.12 = 3 / 25 <-- JPD + Vector( 8) -> 12, // pr = 0.48 = 12 / 25 + Vector(4 ) -> 2, // pr = 0.08 = 2 / 25 + Vector(4, 8) -> 8 // pr = 0.32 = 8 / 25 + ) flatMap { + case (k, n) => Vector.fill(n)(k) + } + + val trainingSet = new Random(0).shuffle(unshuffledTrainingSet) + + val labelsInTrainingSet = trainingSet.flatten.toSet.toVector.sorted + + val testExample: Dom = Vector.empty + + val marginalDist = labelsInTrainingSet.map { label => + val z = trainingSet.size.toDouble + label -> trainingSet.count(row => row contains label) / z + }.toMap + + // ------------------------------------------------------------------------------------ + // Prepare dataset specification and read training dataset. + // ------------------------------------------------------------------------------------ + + val datasetSpec = VFS.getManager.resolveFile(EndToEndDatasetSpec) + val datasetJson = StringReadable.fromVfs2(datasetSpec).parseJson.convertTo[VwMultilabeledJson] + + val rc = + new VwMultilabelRowCreator.Producer[Dom, Lab](labelsInTrainingSet). + getRowCreator(semantics, datasetJson).get + + + // ------------------------------------------------------------------------------------ + // Prepare parameters for VW model that will be trained. + // ------------------------------------------------------------------------------------ + + val binaryVwModel = File.createTempFile("vw_", ".bin.model") + binaryVwModel.deleteOnExit() + + val cacheFile = File.createTempFile("vw_", ".cache") + cacheFile.deleteOnExit() + + val origParams = + s""" + | --quiet + | --csoaa_ldf mc + | --loss_function logistic + | -f ${binaryVwModel.getCanonicalPath} + | --passes 40 + | --cache_file ${cacheFile.getCanonicalPath} + | --holdout_off + | --learning_rate 5 + | --decay_learning_rate 0.9 + """.stripMargin.trim.replaceAll("\n", " ") + + // Get the namespace names from the row creator. + val nsNames = rc.namespaces.map(_._1)(breakOut): Set[String] + + // Take the parameters and augment with additional parameters to make + // multilabel w/ probabilities work correctly. 
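+    // For instance (illustrative, mirroring VwMultilabelParamAugmentationTest): given
+    // "--csoaa_ldf mc", namespaces {"abc", "bcd"} and 0 labels, `updatedVwParams` yields
+    // "--csoaa_ldf mc --noconstant --csoaa_rank --ring_size 10 --ignore y --ignore_linear ab -qYa -qYb".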
+ val vwParams = VwMultilabelModel.updatedVwParams(origParams, nsNames, 2) fold ( + e => throw new Exception(e.errorMessage), + ps => ps + ) + + + // ------------------------------------------------------------------------------------ + // Train VW model + // ------------------------------------------------------------------------------------ + + val vwLearner = VWLearners.create[VWActionScoresLearner](vwParams) + trainingSet foreach { row => + val x = rc(row)._2 + vwLearner.learn(x) + } + vwLearner.close() + + + // ------------------------------------------------------------------------------------ + // Create Aloha model JSON + // ------------------------------------------------------------------------------------ + + val modelJson = VwMultilabelModel.json( + datasetSpec = Vfs.apacheVfs2ToAloha(datasetSpec), + binaryVwModel = Vfs.javaFileToAloha(binaryVwModel), + id = ModelId(1, "NONE"), + labelsInTrainingSet = labelsInTrainingSet, + labelsOfInterest = Option.empty[String], + externalModel = false, + numMissingThreshold = Option(0) + ) + + + // ------------------------------------------------------------------------------------ + // Instantiate Aloha Model + // ------------------------------------------------------------------------------------ + + val factory = ModelFactory.defaultFactory(semantics, optAud) + val modelTry = factory.fromString(modelJson.prettyPrint) // Use `.compactPrint` in prod. + val model = modelTry.get + + + // ------------------------------------------------------------------------------------ + // Test Aloha Model + // ------------------------------------------------------------------------------------ + + val output = model(testExample) + model.close() + + output match { + case None => fail() + case Some(m) => + assertEquals(marginalDist.keySet, m.keySet) + marginalDist foreach { case (k, v) => + assertEquals(s"For key '$k':", v, m(k), 0.01) + } + } + } + + /** + * + * @param desiredLabels Notice since this is a Set, label order doesn't matter. + * @param modelShouldProduceOutput whether the model is expected to + */ + private[this] def testModel( + desiredLabels: Option[Set[Label]], + modelShouldProduceOutput: Boolean = true + ): Unit = { + + val factory = + ModelFactory.defaultFactory( + CompiledSemanticsInstances.anyNameIdentitySemantics[Set[Label]], + OptionAuditor[Map[Label, Double]]() + ) + + // "${desired_labels_from_input}.toVector" returns the input passed to the model in + // the form of a Vector. See anyNameIdentitySemantics for more information. + val desiredLabelFn = desiredLabels map (_ => "${desired_labels_from_input}.toVector") + val json = modelJson(TrainedModel, AllLabels, desiredLabelFn) + + val modelTry = factory.fromString(json) + val model = modelTry.get // : Model[Set[Label], Option[Map[Label, Double]]] + val x = desiredLabels getOrElse Set.empty + val y = model(x) + model.close() + + y match { + case None => + if (modelShouldProduceOutput) + fail("Model should produce output. Produced None.") + case Some(m) => + if (!modelShouldProduceOutput) + fail(s"Model should not produce output. Produced: $m") + else { + val expected = desiredLabels match { + case None => ExpectedMarginalDist + case Some(labels) => ExpectedMarginalDist.filterKeys(labels.contains) + } + + // The keys should be the same set, not a super set. + assertEquals(expected.keySet, m.keySet) + + // Test the value associated with the label. 
+        expected foreach { case (k, v) =>
+          assertEquals(s"For key '$k':", v, m(k), 0.01)
+        }
+      }
+    }
+  }
+
+
+  private[this] def testEmpty(yEmpty: PredictionOutput): Unit = {
+    assertEquals(None, yEmpty.value)
+    assertEquals(Vector("No labels provided. Cannot produce a prediction."), yEmpty.errorMsgs)
+    assertEquals(Set.empty, yEmpty.missingVarNames)
+    assertEquals(None, yEmpty.prob)
+  }
+
+  private[this] def testOutput(
+      y: PredictionOutput,
+      labels: Set[Label],
+      expectedMarginalDist: Map[Label, Double]
+  ): Unit = y.value match {
+    case None => fail()
+    case Some(labelMap) =>
+      // Check that the model gives back all of the keys that were requested.
+      assertEquals(labels, labelMap.keySet)
+
+      // The expected return value. We check keys first to make sure the model
+      // doesn't return extraneous key-value pairs.
+      val exp = expectedMarginalDist.filterKeys(label => labels contains label)
+
+      // Check that the values returned are correct.
+      exp foreach { case (label, expPr) =>
+        assertEquals(expPr, labelMap(label), 0.01)
+      }
+  }
+}
+
+object VwMultilabelModelTest {
+  private type Label = String
+  private type Domain = String
+  private type PredictionOutput = RootedTree[Any, Map[Label, Double]]
+
+  private[this] val TrainingEpochs = 30
+
+  private val LabelSix = "six"
+  private val LabelSeven = "seven"
+  private val LabelEight = "eight"
+
+  private val ExpectedMarginalDist = Map(
+    LabelSix -> 0.6,
+    LabelSeven -> 0.7,
+    LabelEight -> 0.8
+  )
+
+  /**
+    * The order here is important because it's the same as the indices in the training data.
+    - LabelSeven is _C0_ (index 0)
+    - LabelEight is _C1_ (index 1)
+    - LabelSix is _C2_ (index 2)
+    */
+  private val AllLabels = Vector(LabelSeven, LabelEight, LabelSix)
+
+  private val EndToEndDatasetSpec =
+    "res:com/eharmony/aloha/models/vw/jni/multilabel/dataset_spec.json"
+
+  private lazy val Model: Model[Domain, RootedTree[Any, Map[Label, Double]]] = {
+    val featureNames = Vector(FeatureName)
+    val features = Vector(GenFunc0("", (_: Domain) => Iterable(("", 1d))))
+
+    // Get the list of labels from the comma-separated list passed in the input string.
+    val labelsOfInterestFn =
+      GenFunc0("",
+        (x: Domain) =>
+          x.split("\\s*,\\s*")
+            .map(label => label.trim)
+            .filter(label => label.nonEmpty)
+            .toVector
+      )
+
+    val namespaces = List(("X", List(0)))
+    val labelNs = VwMultilabelRowCreator.determineLabelNamespaces(namespaces.unzip._1.toSet).get.labelNs
+
+
+    val predProd = VwSparseMultilabelPredictorProducer[Label](
+      modelSource = TrainedModel,
+      defaultNs = List.empty[Int],
+      namespaces = namespaces,
+      labelNamespace = labelNs,
+      numLabelsInTrainingSet = AllLabels.size
+    )
+
+    MultilabelModel(
+      modelId = ModelId(1, "model"),
+      featureNames = featureNames,
+      featureFunctions = features,
+      labelsInTrainingSet = AllLabels,
+      labelsOfInterest = Option(labelsOfInterestFn),
+      predictorProducer = predProd,
+      numMissingThreshold = Option.empty[Int],
+      auditor = Auditor)
+  }
+
+  private def tmpFile() = {
+    val f = File.createTempFile(classOf[VwMultilabelModelTest].getSimpleName + "_", ".vw.model")
+    f.deleteOnExit()
+    f
+  }
+
+  private val LabelNamespaces(labelNs, dummyLabelNs) =
+    VwMultilabelRowCreator.determineLabelNamespaces(Set.empty).get
+
+  private def vwTrainingParams(modelFile: File) = {
+
+    // NOTES:
+    //  1.  `--csoaa_rank` is needed by VW to make a VWActionScoresLearner.
+    //  2.  `-q YX` forms the cross product of features between the label-based features in Y
+    //      and the side information in X. 
If the features in namespace Y are unique to each + // class, the cross product effectively makes one model per class by interacting the + // class features to the features in X. Since the class labels vary (possibly) + // independently, this cross product is capable of creating independent models + // (one per class). + // 3. `--ignore_linear Y` is provided because in the data, there is one constant feature + // in the Y namespace per class. This flag acts essentially the same as the + // `--noconstant` flag in the traditional binary classification context. It omits the + // one feature related to the class (the intercept). + // 4. `--noconstant` is specified because there's no need to have a "common intercept" + // whose value is the same across all models. This should be the job of the + // first-order features in Y (the per-class intercepts). + // 5. `--ignore y` is used to ignore the features in the namespace related to the dummy + // classes. We need these two dummy classes in training to make `--csoaa_ldf mc` + // output proper probabilities. + // 6. `--link logistic` doesn't actually seem to do anything. + // 7. `--loss_function logistic` works properly during training; however, when + // interrogating the scores, it is important to remember they appear as + // '''negative logits'''. This is because the CSOAA algorithm uses '''costs''', so + // "smaller is better". So, to get the probability, one must do `1/(1 + exp(-1 * -y))` + // or simply `1/(1 + exp(y))`. + val flags = + s""" + | --quiet + | --csoaa_ldf mc + | --csoaa_rank + | --loss_function logistic + | -q ${labelNs}X + | --noconstant + | --ignore_linear X + | --ignore $dummyLabelNs + | -f + """.stripMargin.trim + + (flags + " " + modelFile.getCanonicalPath).split("\n").map(_.trim).mkString(" ") + } + + private val FeatureName = "feature" + + /** + * A dataset that creates the following marginal distribution. + - Pr[seven] = 0.7 where seven is _0 + - Pr[eight] = 0.8 where eight is _1 + - Pr[six] = 0.6 where six is _2 + * + * The observant reader may notice these are oddly ordered. On each line _1 appears first, + * then _0, then _2. This is done to show ordering doesn't matter. What matters is the + * class '''indices'''. 
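+    *
+    * Each line below is one CSOAA LDF training example: a cost of 0.0 marks a negative class
+    * and a negative cost marks a positive class. The cost magnitudes appear to be the joint
+    * probability of the particular label combination, e.g. the first example's -0.084
+    * corresponds to 0.7 * (1 - 0.8) * 0.6 (only _0 and _2 positive).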
+ */ + private val TrainingData = + + Vector( + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.084 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:-0.084 |$labelNs _0\n2:-0.084 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.024 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:0.0 |$labelNs _0\n2:0.0 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.336 |$dummyLabelNs pos\n1:-0.336 |$labelNs _1\n0:-0.336 |$labelNs _0\n2:-0.336 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.056 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:-0.056 |$labelNs _0\n2:0.0 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.144 |$dummyLabelNs pos\n1:-0.144 |$labelNs _1\n0:0.0 |$labelNs _0\n2:-0.144 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.224 |$dummyLabelNs pos\n1:-0.224 |$labelNs _1\n0:-0.224 |$labelNs _0\n2:0.0 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.036 |$dummyLabelNs pos\n1:0.0 |$labelNs _1\n0:0.0 |$labelNs _0\n2:-0.036 |$labelNs _2", + s"shared |X $FeatureName\n2147483648:0.0 |$dummyLabelNs neg\n2147483649:-0.096 |$dummyLabelNs pos\n1:-0.096 |$labelNs _1\n0:0.0 |$labelNs _0\n2:0.0 |$labelNs _2" + ).map(_.split("\n")) + + private lazy val TrainedModel: ModelSource = { + val modelFile = tmpFile() + val params = vwTrainingParams(modelFile) + val learner = VWLearners.create[VWActionScoresLearner](params) + for { + _ <- 1 to TrainingEpochs + d <- TrainingData + } learner.learn(d) + + learner.close() + + ExternalSource(Vfs.javaFileToAloha(modelFile)) + } + + private val Auditor = RootedTreeAuditor.noUpperBound[Map[Label, Double]]() + + private implicit def vecWriter[K: JsonWriter]: RootJsonFormat[Vector[K]] = + DefaultJsonProtocol.vectorFormat(DefaultJsonProtocol.lift(implicitly[JsonWriter[K]])) + + private[multilabel] def modelJson[K: JsonWriter]( + modelSource: ModelSource, + labelsInTrainingSet: Vector[K], + labelsOfInterest: Option[String] = None) = { + + val loi = labelsOfInterest.fold(""){ f => + val escaped = f.replaceAll("\"", "\\\"") + s""""labelsOfInterest": "$escaped",\n""" + } + + val json = + s""" + |{ + | "modelType": "SparseMultilabel", + | "modelId": { "id": 1, "name": "NONE" }, + | "features": { + | "feature": "1" + | }, + | "numMissingThreshold": 0, + | "labelsInTrainingSet": ${toJsonString(labelsInTrainingSet)}, + |$loi + | "underlying": { + | "type": "vw", + | "modelSource": ${toJsonString(modelSource)}, + | "namespaces": { + | "X": [ + | "feature" + | ] + | } + | } + |} + """.stripMargin.trim + json + } + + private def toJsonString[A: JsonWriter](a: A): String = + implicitly[JsonWriter[A]].write(a).compactPrint + + + /** + * Creates the power set of the provided set. + * Answer provided by Chris Marshall (@oxbow_lakes) on + * [[https://stackoverflow.com/a/11581323 Stack Overflow]]. + * @param generatorSet a set for which a power set should be produced. + * @tparam A type of elements in the set. + * @return the power set of `generatorSet`. + */ + private def powerSet[A](generatorSet: Set[A]): Set[Set[A]] = { + @tailrec def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + // powerset of empty set is the set of the empty set. 
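+    // e.g., powerSet(Set(1, 2)) == Set(Set(), Set(1), Set(2), Set(1, 2))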
+ pwr(generatorSet, Set(Set.empty[A])) + } +} diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala new file mode 100644 index 00000000..0ed7d478 --- /dev/null +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwMultilabelParamAugmentationTest.scala @@ -0,0 +1,290 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner + +/** + * Created by ryan.deak on 10/6/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwMultilabelParamAugmentationTest extends VwMultilabelParamAugmentation { + import VwMultilabelParamAugmentationTest._ + + @Test def testNotCsoaaWap(): Unit = { + val args = "" + VwMultilabelModel.updatedVwParams(args, Set.empty, DefaultNumLabels) match { + case Left(NotCsoaaOrWap(ps)) => assertEquals(args, ps) + case _ => fail() + } + } + + @Test def testExpectedUnrecoverableFlags(): Unit = { + assertEquals( + "Unrecoverable flags has changed.", + Set("redefine", "stage_poly", "keep", "permutations", "autolink"), + UnrecoverableFlagSet + ) + } + + @Test def testUnrecoverable(): Unit = { + val unrec = UnrecoverableFlagSet.iterator.map { f => + VwMultilabelModel.updatedVwParams(s"--csoaa_ldf mc --$f", Set.empty, DefaultNumLabels) + }.toList + + unrec foreach { + case Left(UnrecoverableParams(p, us)) => + assertEquals( + p, + us.map(u => s"--$u") + .mkString("--csoaa_ldf mc ", " ", "") + ) + case _ => fail() + } + } + + @Test def testIgnoredNotInNsSet(): Unit = { + val args = "--csoaa_ldf mc --ignore a" + val origNss = Set.empty[String] + VwMultilabelModel.updatedVwParams(args, origNss, DefaultNumLabels) match { + case Left(NamespaceError(o, nss, bad)) => + assertEquals(args, o) + assertEquals(origNss, nss) + assertEquals(Map("ignore" -> Set('a')), bad) + case _ => fail() + } + } + + @Test def testIgnoredNotInNsSet2(): Unit = { + val args = "--csoaa_ldf mc --ignore ab" + val origNss = Set("a") + VwMultilabelModel.updatedVwParams(args, origNss, DefaultNumLabels) match { + case Left(NamespaceError(o, nss, bad)) => + assertEquals(args, o) + assertEquals(origNss, nss) + assertEquals(Map("ignore" -> Set('b')), bad) + case _ => fail() + } + } + + @Test def testNamespaceErrors(): Unit = { + val args = "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd " + + "--cubic bcd --interactions dde --interactions abcde" + val updated = updatedVwParams(args, Set(), DefaultNumLabels) + + val exp = Left( + NamespaceError( + "--wap_ldf m --ignore_linear b --ignore a -qbb -qbd --cubic bcd " + + "--interactions dde --interactions abcde", + Set(), + Map( + "ignore" -> Set('a'), + "ignore_linear" -> Set('b'), + "quadratic" -> Set('b', 'd'), + "cubic" -> Set('b', 'c', 'd', 'e'), + "interactions" -> Set('a', 'b', 'c', 'd', 'e') + ) + ) + ) + + assertEquals(exp, updated) + } + + @Test def testNoAvailableLabelNss(): Unit = { + // All namespaces taken. 
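+    // Every single-character namespace name is occupied, so determineLabelNamespaces has
+    // nothing left to allocate and a LabelNamespaceError is the expected outcome.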
+ val nss = (Char.MinValue to Char.MaxValue).map(_.toString).toSet + val validArgs = "--csoaa_ldf mc" + + VwMultilabelModel.updatedVwParams(validArgs, nss, DefaultNumLabels) match { + case Left(LabelNamespaceError(orig, nssOut)) => + assertEquals(validArgs, orig) + assertEquals(nss, nssOut) + case _ => fail() + } + } + + @Test def testBadVwFlag(): Unit = { + val args = "--wap_ldf m --NO_A_VALID_VW_FLAG" + + val exp = VwError( + args, + s"--wap_ldf m --NO_A_VALID_VW_FLAG --noconstant --csoaa_rank $DefaultRingSize --ignore y", + "unrecognised option '--NO_A_VALID_VW_FLAG'" + ) + + VwMultilabelModel.updatedVwParams(args, Set.empty, DefaultNumLabels) match { + case Left(e) => assertEquals(exp, e) + case _ => fail() + } + } + + @Test def testQuadraticCreation(): Unit = { + val args = "--csoaa_ldf mc" + val nss = Set("abc", "bcd") + + // Notice: ignore_linear and quadratics are in sorted order. + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + + "--ignore_linear ab -qYa -qYb" + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testIgnoredNoQuadraticCreation(): Unit = { + val args = "--csoaa_ldf mc --ignore_linear a" + val nss = Set("abc", "bcd") + + // Notice: ignore_linear and quadratics are in sorted order. + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + + "--ignore_linear ab -qYb" + + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicCreation(): Unit = { + val args = "--csoaa_ldf mc -qab --quadratic cb" + val nss = Set("abc", "bcd", "cde", "def") + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + + "--ignore_linear abcd " + + "-qYa -qYb -qYc -qYd " + + "--cubic Yab --cubic Ybc" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicCreationIgnoredLinear(): Unit = { + val args = "--csoaa_ldf mc -qab --quadratic cb --ignore_linear d" + val nss = Set("abc", "bcd", "cde", "def") + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + + "--ignore_linear abcd " + + "-qYa -qYb -qYc " + + "--cubic Yab --cubic Ybc" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicCreationIgnored(): Unit = { + val args = "--csoaa_ldf mc -qab --quadratic cb --ignore c" + val nss = Set("abc", "bcd", "cde", "def") + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore cy " + + "--ignore_linear abd " + + "-qYa -qYb -qYd " + + "--cubic Yab" + + // Notice: ignore_linear and quadratics are in sorted order. 
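+    // Since namespace 'c' is ignored outright, it moves to --ignore (together with the dummy
+    // label namespace 'y'), it is dropped from --ignore_linear, and interactions involving it
+    // (the original "--quadratic cb") do not get label-interacted cubic terms.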
+ VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testCubicWithInteractionsCreationIgnored(): Unit = { + val args = "--csoaa_ldf mc --interactions ab --interactions cb --ignore c --ignore d" + val nss = Set("abc", "bcd", "cde", "def") + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore cdy " + + "--ignore_linear ab " + + "-qYa -qYb " + + "--cubic Yab" + + // Notice: ignore_linear and quadratics are in sorted order. + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testHigherOrderInteractions(): Unit = { + val args = "--csoaa_ldf mc --interactions abcd --ignore_linear abcd" + val nss = Set("abc", "bcd", "cde", "def") + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + + "--ignore_linear abcd " + + "--interactions Yabcd" + + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testMultipleInteractions(): Unit = { + val nss = ('a' to 'e').map(_.toString).toSet + + val args = s"--csoaa_ldf mc --interactions ab --interactions abc " + + "--interactions abcd --interactions abcde" + + val exp = s"--csoaa_ldf mc --noconstant --csoaa_rank $DefaultRingSize --ignore y " + + "--ignore_linear abcde " + + "-qYa -qYb -qYc -qYd -qYe " + + "--cubic Yab " + + "--interactions Yabc " + + "--interactions Yabcd " + + "--interactions Yabcde" + + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case _ => fail() + } + } + + @Test def testInteractionsWithSelf(): Unit = { + val nss = Set("a") + val args = "--wap_ldf m -qaa --cubic aaa --interactions aaaa" + val exp = s"--wap_ldf m --noconstant --csoaa_rank $DefaultRingSize --ignore y --ignore_linear a " + + "-qYa " + + "--cubic Yaa " + + "--interactions Yaaa " + + "--interactions Yaaaa" + + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Right(s) => assertEquals(exp, s) + case x => assertEquals("", x) + } + } + + @Test def testClassLabels(): Unit = { + val args = "--wap_ldf m" + val nss = Set.empty[String] + + VwMultilabelModel.updatedVwParams(args, nss, DefaultNumLabels) match { + case Left(_) => fail() + case Right(p) => + val ignored = + VwMultilabelRowCreator.determineLabelNamespaces(nss) match { + case None => fail() + case Some(labelNs) => + val ignoredNs = + Ignore + .findAllMatchIn(p) + .map(m => m.group(CaptureGroupWithContent)) + .reduce(_ + _) + .toCharArray + .toSet + + assertEquals(Set('y'), ignoredNs) + labelNs.labelNs + } + + assertEquals('Y', ignored) + } + } +} + +object VwMultilabelParamAugmentationTest { + val DefaultNumLabels = 0 + val DefaultRingSize = + s"--ring_size ${DefaultNumLabels + VwSparseMultilabelPredictor.AddlVwRingSize}" +} diff --git a/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala new file mode 100644 index 00000000..2568a47c --- /dev/null +++ b/aloha-vw-jni/src/test/scala/com/eharmony/aloha/models/vw/jni/multilabel/VwSparseMultilabelPredictorTest.scala @@ -0,0 +1,73 @@ +package com.eharmony.aloha.models.vw.jni.multilabel + +import java.io.{ByteArrayOutputStream, File, FileInputStream} + +import 
com.eharmony.aloha.ModelSerializationTestHelper +import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource} +import org.apache.commons.codec.binary.Base64 +import org.apache.commons.io.IOUtils +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.BlockJUnit4ClassRunner +import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners} + +/** + * Created by ryan.deak on 9/27/17. + */ +@RunWith(classOf[BlockJUnit4ClassRunner]) +class VwSparseMultilabelPredictorTest extends ModelSerializationTestHelper { + import VwSparseMultilabelPredictorTest._ + + @Test def testSerializability(): Unit = { + val predictor = getPredictor(getModelSource(), 3) + val ds = serializeDeserializeRoundTrip(predictor) + assertEquals(predictor, ds) + assertEquals(predictor.vwParams(), ds.vwParams()) + assertNotNull(ds.vwModel) + } + + @Test def testVwParameters(): Unit = { + val numLabelsInTrainingSet = 3 + val predictor = getPredictor(getModelSource(), numLabelsInTrainingSet) + + predictor.vwParams() match { + case Data(vwBinFilePath, ringSize) => + checkVwBinFile(vwBinFilePath) + checkVwRingSize(numLabelsInTrainingSet, ringSize.toInt) + case ps => fail(s"Unexpected VW parameters format. Found string: $ps") + } + } +} + +object VwSparseMultilabelPredictorTest { + private val Data = """\s*-i\s+(\S+)\s+--ring_size\s+(\d+)\s+--testonly\s+--quiet""".r + + private def getModelSource(): ModelSource = { + val f = File.createTempFile("i_dont", "care") + f.deleteOnExit() + val learner = VWLearners.create[VWActionScoresLearner](s"--quiet --csoaa_ldf mc --csoaa_rank -f ${f.getCanonicalPath}") + learner.close() + val baos = new ByteArrayOutputStream() + IOUtils.copy(new FileInputStream(f), baos) + val src = Base64StringSource(Base64.encodeBase64URLSafeString(baos.toByteArray)) + ExternalSource(src.localVfs) + } + + private def getPredictor(modelSrc: ModelSource, numLabelsInTrainingSet: Int) = + VwSparseMultilabelPredictor[Any](modelSrc, Nil, Nil, numLabelsInTrainingSet) + + private def checkVwBinFile(vwBinFilePath: String): Unit = { + val vwBinFile = new File(vwBinFilePath) + assertTrue("VW binary file should have been written to disk", vwBinFile.exists()) + vwBinFile.deleteOnExit() + } + + private def checkVwRingSize(numLabelsInTrainingSet: Int, ringSize: Int): Unit = { + assertEquals( + "vw --ring_size parameter is incorrect:", + numLabelsInTrainingSet + VwSparseMultilabelPredictor.AddlVwRingSize, + ringSize.toInt + ) + } +} \ No newline at end of file
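
To make the intent of the new `SparseMultilabel` model type concrete, here is a minimal sketch of how the pieces above fit together outside of a test, assuming a dataset specification on the classpath and an already-trained VW binary model on disk. The file paths, labels, `ModelId`, and the `SparseMultilabelSketch` object name are hypothetical; everything else mirrors the calls exercised in the tests above.

```scala
import java.io.File

import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.factory.ModelFactory
import com.eharmony.aloha.id.ModelId
import com.eharmony.aloha.io.vfs.Vfs
import com.eharmony.aloha.models.vw.jni.multilabel.VwMultilabelModel
import com.eharmony.aloha.semantics.compiled.CompiledSemanticsInstances
import org.apache.commons.vfs2.VFS
import spray.json.DefaultJsonProtocol.IntJsonFormat

object SparseMultilabelSketch {
  def main(args: Array[String]): Unit = {
    type Label = Int
    type Domain = Vector[Label]

    // Same semantics and auditor used in the tests above.
    val semantics = CompiledSemanticsInstances.anyNameIdentitySemantics[Domain]
    val auditor   = OptionAuditor[Map[Label, Double]]()

    // Hypothetical inputs: a dataset spec on the classpath and a VW binary model on disk
    // (for instance, one produced by a training loop like the ones in the tests).
    val datasetSpec   = VFS.getManager.resolveFile("res:path/to/dataset_spec.json")
    val binaryVwModel = new File("/tmp/multilabel.vw.bin")

    // Produce the Aloha model JSON for the "SparseMultilabel" model type.
    val modelJson = VwMultilabelModel.json(
      datasetSpec         = Vfs.apacheVfs2ToAloha(datasetSpec),
      binaryVwModel       = Vfs.javaFileToAloha(binaryVwModel),
      id                  = ModelId(1, "example"),
      labelsInTrainingSet = Vector(4, 8),
      labelsOfInterest    = Option.empty[String],
      externalModel       = false,
      numMissingThreshold = Option(0)
    )

    // Instantiate the model from JSON and score an input.
    val factory = ModelFactory.defaultFactory(semantics, auditor)
    val model   = factory.fromString(modelJson.compactPrint).get

    val prediction: Option[Map[Label, Double]] = model(Vector.empty)
    println(prediction) // e.g. Some(Map(4 -> ..., 8 -> ...))

    model.close()
  }
}
```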