Closes #163 #193
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,361 @@ | ||
package com.eharmony.aloha.dataset.vw.multilabel | ||
|
||
import com.eharmony.aloha.AlohaException | ||
import com.eharmony.aloha.dataset._ | ||
import com.eharmony.aloha.dataset.density.Sparse | ||
import com.eharmony.aloha.dataset.vw.VwCovariateProducer | ||
import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson | ||
import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator | ||
import com.eharmony.aloha.reflect.RefInfo | ||
import com.eharmony.aloha.semantics.compiled.CompiledSemantics | ||
import com.eharmony.aloha.semantics.func.GenAggFunc | ||
import spray.json.JsValue | ||
|
||
import scala.collection.{breakOut, immutable => sci} | ||
import scala.util.{Failure, Success, Try} | ||
|
||
/** | ||
* Created by ryan.deak on 9/13/17. | ||
*/ | ||
final case class VwMultilabelRowCreator[-A, K]( | ||
allLabelsInTrainingSet: sci.IndexedSeq[K], | ||
featuresFunction: FeatureExtractorFunction[A, Sparse], | ||
defaultNamespace: List[Int], | ||
namespaces: List[(String, List[Int])], | ||
normalizer: Option[CharSequence => CharSequence], | ||
positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]], | ||
classNs: Char, | ||
dummyClassNs: Char, | ||
includeZeroValues: Boolean = false | ||
I think we're missing something important here. We have experimentally shown that, when training on a very large multilabel problem, the ability to randomly sample from the negative labels at train time both decreases training time and can increase model performance. I think that a parameter to the row creator should be a function that allows me to sample a percentage of negative labels to add to each training line. This needs to be random, and the negative labels added need to change with EVERY line.

Yeah. I like that. Let me think on that for a few and come up with something.

@jmorra I think the requirement that "negative labels added need to change with EVERY line" is important, but I don't think it fits into the current architecture. It seems like you are (at least implicitly) advocating two things:
What if we have a training set that is distributed across multiple nodes? We apply the
What do you think about this stuff? I think if we try to jam this type of mutability into the current interface, it'll most likely produce disastrous results.

@jmorra: I am proposing the following trait:

  trait StatefulRowCreator[-A, +B, S] extends Serializable {
    // An initial state derived from the constructor
    def initialState: S
    // Caller maintains state.
    def apply(a: A, s: S): ((MissingAndErroneousFeatureInfo, Option[B]), S)
    // Iterator doesn't return state because it's non-strict. Output Iterator length
    // is same as input iterator length. May be infinite length. This should be a constant
    // time operation.
    def apply(as: Iterator[A], s: S): Iterator[(MissingAndErroneousFeatureInfo, Option[B])]
    // Output Vector length is same as input length. Returns the final state.
    // This is a linear time operation.
    def apply(as: Vector[A], s: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S)
  }

I think it was a mistake in retrospect to not have I'm happy to change the function names but I like the signatures. ( We could consider putting the state as the first parameter, at least in the Iterator and Vector methods, so that we can curry and supply the initial state to get back a function of one argument. |
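To make the discussion above concrete, here is a minimal sketch of how a caller-maintained state could drive per-line negative sampling. It is not the trait proposed in this thread: `SimpleStatefulRowCreator`, `applyAll`, and `SampledLabels` are hypothetical names, the `MissingAndErroneousFeatureInfo` and `Option[B]` parts of the return type are dropped for brevity, and using a `Long` PRNG seed as the state is only one possible choice.

```scala
import scala.util.Random

object StatefulSketch {

  // Hypothetical, simplified stand-in for the proposed trait. The real proposal
  // also returns MissingAndErroneousFeatureInfo and an Option[B].
  trait SimpleStatefulRowCreator[-A, +B, S] extends Serializable {
    def initialState: S

    // Caller maintains state: each call returns the output and the next state.
    def apply(a: A, s: S): (B, S)

    // Strict variant: threads the state through every element (linear time)
    // and returns the final state along with the outputs.
    def applyAll(as: Vector[A], s: S): (Vector[B], S) =
      as.foldLeft((Vector.empty[B], s)) { case ((bs, st), a) =>
        val (b, next) = apply(a, st)
        (bs :+ b, next)
      }
  }

  // Hypothetical row creator: keeps every positive label index and a random
  // sample of the negative ones. The state is the seed for the next row.
  final class SampledLabels(numLabels: Int, negativesPerRow: Int)
      extends SimpleStatefulRowCreator[Set[Int], Vector[Int], Long] {

    val initialState: Long = 0L

    def apply(positives: Set[Int], seed: Long): (Vector[Int], Long) = {
      val rng = new Random(seed)
      val negatives = (0 until numLabels).filterNot(positives).toVector
      val sampled = rng.shuffle(negatives).take(negativesPerRow)
      ((positives.toVector ++ sampled).sorted, rng.nextLong())
    }
  }

  val creator = new SampledLabels(numLabels = 10, negativesPerRow = 3)

  // Two rows with the same positive label. Each row is sampled with a
  // different seed, so the negatives will typically differ between rows.
  val (rows, finalState) =
    creator.applyAll(Vector(Set(1), Set(1)), creator.initialState)
}
```

Because the state is an explicit value rather than a hidden mutable field, the same initial state always reproduces the same rows, which may help when the training set is partitioned across nodes and each partition is given its own starting state.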
||
) extends RowCreator[A, Array[String]] { | ||
import VwMultilabelRowCreator._ | ||
|
||
@transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap | ||
|
||
  // Precompute these for efficiency rather than recompute them inside a hot loop. | ||
// Notice these are not lazy vals. | ||
|
||
private[this] val negativeDummyStr = | ||
s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature" | ||
|
||
private[this] val positiveDummyStr = | ||
s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature" | ||
|
||
override def apply(a: A): (MissingAndErroneousFeatureInfo, Array[String]) = { | ||
val (missingAndErrs, features) = featuresFunction(a) | ||
|
||
// TODO: Should this be sci.BitSet? | ||
val positiveIndices: Set[Int] = | ||
You might want to consider
|
||
positiveLabelsFunction(a).flatMap { y => labelToInd.get(y).toSeq }(breakOut) | ||
|
||
val x: Array[String] = trainingInput( | ||
features, | ||
allLabelsInTrainingSet.indices, | ||
positiveIndices, | ||
defaultNamespace, | ||
namespaces, | ||
classNs, | ||
negativeDummyStr, | ||
positiveDummyStr | ||
) | ||
|
||
(missingAndErrs, x) | ||
} | ||
} | ||
|
||
object VwMultilabelRowCreator { | ||
|
||
/** | ||
Love these descriptions. |
||
   * VW allows long-based feature indices, but Aloha only allows 32-bit indices | ||
   * on the features that produce the key-value pairs passed to VW. The negative | ||
   * dummy class uses an ID outside of the allowable range of feature indices: | ||
   * 2^31^. | ||
*/ | ||
private val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString | ||
|
||
/** | ||
   * VW allows long-based feature indices, but Aloha only allows 32-bit indices | ||
   * on the features that produce the key-value pairs passed to VW. The positive | ||
   * dummy class uses an ID outside of the allowable range of feature indices: | ||
   * 2^31^ + 1. | ||
*/ | ||
private val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString | ||
|
||
/** | ||
* Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the | ||
* dependent variable is based on cost (which is the negative of reward). | ||
* As such, the ''reward'' of a positive example is designated to be one, | ||
* so the cost (or negative reward) is -1. | ||
*/ | ||
private val PositiveCost = (-1).toString | ||
|
||
/** | ||
* Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the | ||
* dependent variable is based on cost (which is the negative of reward). | ||
* As such, the ''reward'' of a negative example is designated to be zero, | ||
* so the cost (or negative reward) is 0. | ||
*/ | ||
private val NegativeCost = 0.toString | ||
|
||
private val PositiveDummyClassFeature = "P" | ||
|
||
private val NegativeDummyClassFeature = "N" | ||
|
||
/** | ||
* "shared" is a special keyword in VW multi-class (multi-row) format. | ||
* See Hal Daume's [[https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html page]]. | ||
* | ||
* '''NOTE''': The trailing space should be here. | ||
*/ | ||
private[this] val SharedFeatureIndicator = "shared" + " " | ||
Why not just write this as

Emphasis. |
||
|
||
private[this] val PreferredLabelNamespaces = Seq(('Y', 'y'), ('Z', 'z'), ('Λ', 'λ')) | ||
|
||
/** | ||
   * Determine the label namespaces for VW. VW insists on the uniqueness of the first character | ||
   * of namespaces. The goal is to use the first one of these combinations for | ||
   * label and dummy label namespaces where neither of the values is in `usedNss`. | ||
   * | ||
   *   - "Y", "y" | ||
   *   - "Z", "z" | ||
   *   - "Λ", "λ" | ||
   * | ||
   * If none of these combinations can be used because at least one of the elements in each | ||
   * row is in `usedNss`, then iterate over the Unicode set and take the first two characters | ||
   * found that are deemed valid. These will then become the actual and | ||
   * dummy namespace names (respectively). | ||
   * | ||
   * The goal of this function is to try to use characters used in the literature to denote a | ||
   * dependent variable. If that isn't possible (because the characters are already used by some | ||
   * other namespace), just find the first possible characters. | ||
Close your parenthesis.

done |
||
* @param usedNss names of namespaces used. | ||
* @return the namespace for ''actual'' label information then the namespace for ''dummy'' | ||
* label information. If two valid namespaces couldn't be produced, return None. | ||
*/ | ||
private[aloha] def determineLabelNamespaces(usedNss: Set[String]): Option[LabelNamespaces] = { | ||
val nss = nssToFirstCharBitSet(usedNss) | ||
preferredLabelNamespaces(nss) orElse bruteForceNsSearch(nss) | ||
} | ||
|
||
private[multilabel] def preferredLabelNamespaces(nss: sci.BitSet): Option[LabelNamespaces] = { | ||
PreferredLabelNamespaces collectFirst { | ||
case (actual, dummy) if !(nss contains actual.toInt) && !(nss contains dummy.toInt) => | ||
LabelNamespaces(actual, dummy) | ||
} | ||
} | ||
|
||
private[multilabel] def nssToFirstCharBitSet(ss: Set[String]): sci.BitSet = | ||
ss.collect { case s if s != "" => | ||
Requires class instantiation. Changing to |
||
s.charAt(0).toInt | ||
}(breakOut[Set[String], Int, sci.BitSet]) | ||
|
||
private[multilabel] def validCharForNamespace(chr: Char): Boolean = { | ||
// These might be overkill. | ||
Character.isDefined(chr) && | ||
Character.isLetter(chr) && | ||
!Character.isISOControl(chr) && | ||
!Character.isSpaceChar(chr) && | ||
!Character.isWhitespace(chr) | ||
} | ||
|
||
/** | ||
   * Find the first two valid characters that can be used as VW namespaces that, when converted | ||
   * to integers, are not present in `usedNss`. | ||
* @param usedNss the set of first characters in namespaces. | ||
* @return the namespace to use for the actual classes and dummy classes, respectively. | ||
*/ | ||
private[multilabel] def bruteForceNsSearch(usedNss: sci.BitSet): Option[LabelNamespaces] = { | ||
val found = | ||
Iterator | ||
.range(Char.MinValue, Char.MaxValue) | ||
.filter(c => !(usedNss contains c) && validCharForNamespace(c.toChar)) | ||
.take(2) | ||
.toList | ||
|
||
found match { | ||
case actual :: dummy :: Nil => | ||
Option(LabelNamespaces(actual.toChar, dummy.toChar)) | ||
case _ => None | ||
} | ||
} | ||
|
||
/** | ||
* Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. | ||
* @param features (non-label dependent) features shared across all labels. | ||
   * @param indices the indices into the sequence of all labels encountered | ||
   * during training. | ||
* @param positiveLabelIndices a predicate telling whether the example should be positively | ||
* associated with a label. | ||
* @param defaultNs the indices into `features` that should be placed in VW's default | ||
* namespace. | ||
* @param namespaces the indices into `features` that should be associated with each | ||
* namespace. | ||
* @param classNs a namespace for features associated with class labels | ||
   * @param negativeDummyStr the precomputed VW line for the negative dummy class. | ||
   * @param positiveDummyStr the precomputed VW line for the positive dummy class. | ||
* @return an array to be passed directly to an underlying `VWActionScoresLearner`. | ||
*/ | ||
private[aloha] def trainingInput( | ||
features: IndexedSeq[Sparse], | ||
Should these types be more general collections? |
||
indices: sci.IndexedSeq[Int], | ||
positiveLabelIndices: Int => Boolean, | ||
defaultNs: List[Int], | ||
namespaces: List[(String, List[Int])], | ||
classNs: Char, | ||
negativeDummyStr: String, | ||
positiveDummyStr: String | ||
): Array[String] = { | ||
|
||
val n = indices.size | ||
|
||
// The length of the output array is n + 3. | ||
// | ||
// The first row is the shared features. These are features that are not label dependent. | ||
    // Then come two dummy classes. These are to make the probabilities work out. | ||
// Then come the features for each of the n labels. | ||
val x = new Array[String](n + 3) | ||
|
||
val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) | ||
Can you name the boolean parameter?

Sure |
||
x(0) = SharedFeatureIndicator + shared | ||
|
||
    // The dummy class lines are always the same for a given dummyClassNs, so they are | ||
    // precomputed by the caller and passed in. | ||
x(1) = negativeDummyStr | ||
x(2) = positiveDummyStr | ||
|
||
// This is mutable because we want speed. | ||
var i = 0 | ||
while (i < n) { | ||
val labelInd = indices(i) | ||
|
||
      // Test the label index itself; it equals the loop position only when `indices` is 0 until n. | ||
      val dv = if (positiveLabelIndices(labelInd)) PositiveCost else NegativeCost | ||
So the only problem with the way this is written now is that the function as provided doesn't guarantee an O(1) runtime, which you must have for performance. I would consider changing it so the reader of this function knows that it's an O(1) algorithm.

I don't understand. Are you saying that |
||
x(i + 3) = s"$labelInd:$dv |$classNs _$labelInd" | ||
i += 1 | ||
} | ||
|
||
x | ||
} | ||
|
||
/** | ||
* Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model. | ||
* @param features (non-label dependent) features shared across all labels. | ||
   * @param indices the indices into the sequence of all labels encountered | ||
   * during training. | ||
* @param defaultNs the indices into `features` that should be placed in VW's default | ||
* namespace. | ||
* @param namespaces the indices into `features` that should be associated with each | ||
* namespace. | ||
* @param classNs a namespace for features associated with class labels | ||
* @return an array to be passed directly to an underlying `VWActionScoresLearner`. | ||
*/ | ||
private[aloha] def predictionInput( | ||
features: IndexedSeq[Sparse], | ||
indices: sci.IndexedSeq[Int], | ||
defaultNs: List[Int], | ||
namespaces: List[(String, List[Int])], | ||
classNs: String | ||
): Array[String] = { | ||
|
||
val n = indices.size | ||
|
||
// Use a (mutable) array (and iteration) for speed. | ||
// The first row is the shared features. These are features that are not label dependent. | ||
// Then come the features for each of the n labels. | ||
val x = new Array[String](n + 1) | ||
|
||
val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false) | ||
I see that this called function is tail recursive with the following comment from @deaktator
Since we're concerned about performance here, is this worth revisiting now?

This is more performant than an immutable implementation. That comment was in reference to mutability, not performance. |
||
x(0) = SharedFeatureIndicator + shared | ||
|
||
var i = 0 | ||
while (i < n) { | ||
val labelInd = indices(i) | ||
x(i + 1) = s"$labelInd:0 |$classNs _$labelInd" | ||
i += 1 | ||
} | ||
|
||
x | ||
} | ||
|
||
|
||
/** | ||
* A producer that can produce a [[VwMultilabelRowCreator]]. | ||
* The requirement for [[RowCreatorProducer]] to only have zero-argument constructors is | ||
* relaxed for this Producer because we don't have a way of generically constructing a | ||
* list of labels. If the labels were encoded in the JSON, then a JsonReader for the label | ||
* type would have to be passed to the constructor. Since the labels can't be encoded | ||
* generically in the JSON, we accept that this Producer is a special case and allow the labels | ||
* to be passed directly. The consequence is that this producer doesn't just rely on the | ||
   * dataset specification and the data itself. It also relies on the labels provided to the | ||
* constructor. | ||
* | ||
* @param allLabelsInTrainingSet All of the labels that will be encountered in the training set. | ||
* @param ev$1 reflection information about `K`. | ||
* @tparam A type of input passed to the [[RowCreator]]. | ||
* @tparam K the label type. | ||
*/ | ||
final class Producer[A, K: RefInfo](allLabelsInTrainingSet: sci.IndexedSeq[K]) | ||
extends RowCreatorProducer[A, Array[String], VwMultilabelRowCreator[A, K]] | ||
with RowCreatorProducerName | ||
with VwCovariateProducer[A] | ||
with DvProducer | ||
with SparseCovariateProducer | ||
with CompilerFailureMessages { | ||
|
||
override type JsonType = VwMultilabeledJson | ||
|
||
/** | ||
* Attempt to parse the JSON AST to an intermediate representation that is used | ||
Can you finish your thought here? |
||
* | ||
* @param json | ||
* @return | ||
*/ | ||
override def parse(json: JsValue): Try[VwMultilabeledJson] = | ||
Try { json.convertTo[VwMultilabeledJson] } | ||
|
||
/** | ||
* Attempt to produce a Spec. | ||
* | ||
* @param semantics semantics used to make sense of the features in the JsonSpec | ||
* @param jsonSpec a JSON specification to transform into a RowCreator. | ||
* @return | ||
*/ | ||
override def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: VwMultilabeledJson): Try[VwMultilabelRowCreator[A, K]] = { | ||
val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec) | ||
|
||
val spec = for { | ||
cov <- covariates | ||
pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels) | ||
labelNs <- labelNamespaces(nss) | ||
actualLabelNs = labelNs.labelNs | ||
dummyLabelNs = labelNs.dummyLabelNs | ||
sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports) | ||
} yield new VwMultilabelRowCreator[A, K](allLabelsInTrainingSet, cov, default, nss, | ||
normalizer, pos, actualLabelNs, dummyLabelNs) | ||
|
||
spec | ||
} | ||
|
||
private[multilabel] def labelNamespaces(nss: List[(String, List[Int])]): Try[LabelNamespaces] = { | ||
val nsNames: Set[String] = nss.map(_._1)(breakOut) | ||
determineLabelNamespaces(nsNames) match { | ||
case Some(ns) => Success(ns) | ||
|
||
// If there are so many VW namespaces that all available Unicode characters are taken, | ||
// then a memory error will probably already have occurred. | ||
case None => Failure(new AlohaException( | ||
          "Could not find any Unicode characters to use as VW namespaces. Namespaces provided: " + | ||
nsNames.mkString(", ") | ||
)) | ||
} | ||
} | ||
|
||
private[multilabel] def positiveLabelsFn( | ||
semantics: CompiledSemantics[A], | ||
positiveLabels: String | ||
): Try[GenAggFunc[A, sci.IndexedSeq[K]]] = | ||
getDv[A, sci.IndexedSeq[K]]( | ||
semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K])) | ||
} | ||
|
||
private[aloha] final case class LabelNamespaces(labelNs: Char, dummyLabelNs: Char) | ||
} | ||
newline |
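For readers following the review, the multi-line layout that `trainingInput` builds can be sketched in isolation. This is a simplified illustration, not the code under review: `CsoaaLdfLayoutSketch` and `trainingLines` are hypothetical names, the `shared` string is an assumed stand-in for whatever `VwRowCreator.unlabeledVwInput` returns, and the costs and dummy features are inlined from the constants in the diff above.

```scala
object CsoaaLdfLayoutSketch {

  // Same values as NegDummyClassId and PosDummyClassId in the diff:
  // the first two IDs above the 32-bit signed integer range.
  private val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString
  private val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString

  def trainingLines(
      shared: String,            // assumed: already formatted shared features
      numLabels: Int,
      positive: Int => Boolean,  // true if the label index is a positive label
      classNs: Char,
      dummyClassNs: Char
  ): Array[String] = {
    // One shared line, two dummy class lines, then one line per label.
    val x = new Array[String](numLabels + 3)
    x(0) = "shared " + shared
    x(1) = s"$NegDummyClassId:0 |$dummyClassNs N"
    x(2) = s"$PosDummyClassId:-1 |$dummyClassNs P"

    var i = 0
    while (i < numLabels) {
      // Cost is the negative of reward: -1 for positives, 0 for negatives.
      val cost = if (positive(i)) "-1" else "0"
      x(i + 3) = s"$i:$cost |$classNs _$i"
      i += 1
    }
    x
  }

  // Three labels, of which only label 1 is positive.
  val example: Array[String] =
    trainingLines("| f1:1 f2:0.5", numLabels = 3, positive = _ == 1,
      classNs = 'Y', dummyClassNs = 'y')
}
```

Under these assumptions, `example` holds six lines: the shared line, `2147483648:0 |y N`, `2147483649:-1 |y P`, and then `0:0 |Y _0`, `1:-1 |Y _1`, `2:0 |Y _2`.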
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package com.eharmony.aloha.dataset.vw.multilabel.json | ||
|
||
import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec} | ||
import com.eharmony.aloha.dataset.vw.json.VwJsonLike | ||
import spray.json.{DefaultJsonProtocol, RootJsonFormat} | ||
|
||
import scala.collection.{immutable => sci} | ||
|
||
/** | ||
* Created by ryan.deak on 9/13/17. | ||
*/ | ||
final case class VwMultilabeledJson( | ||
imports: sci.Seq[String], | ||
features: sci.IndexedSeq[SparseSpec], | ||
namespaces: Option[Seq[Namespace]] = Some(Nil), | ||
normalizeFeatures: Option[Boolean] = Some(false), | ||
positiveLabels: String) | ||
extends VwJsonLike | ||
|
||
object VwMultilabeledJson extends DefaultJsonProtocol { | ||
implicit val labeledVwJsonFormat: RootJsonFormat[VwMultilabeledJson] = | ||
jsonFormat5(VwMultilabeledJson.apply) | ||
} |
Why is the casing different here?
Yeah. I want opinions on the name here. I am neither set on, nor particularly enamored with, the name "multilabel-sparse". Suggestions?