Closes #163 #193

Merged
merged 108 commits into from
Nov 8, 2017
Changes from 92 commits
Commits
773b3b1
initial commit with some ideas on how to do multilabel.
deaktator Aug 30, 2017
66817f7
Additional hacking on multilabel model.
deaktator Aug 30, 2017
c2516f5
Added multilabel type aliases. Updated MultilabelModel but it's broken.
deaktator Aug 31, 2017
73d378b
Added additional commentst.
deaktator Aug 31, 2017
48a6157
Made SparseLabelDepFeatures type alias a little nicer but abused nota…
deaktator Aug 31, 2017
0ebc108
New plan. No label-dependent features for now.
deaktator Sep 1, 2017
8ba3e6e
moved MultilabelModel to multilabel package.
deaktator Sep 1, 2017
5931ef0
updated comments
deaktator Sep 1, 2017
5e1f1dc
Removed the requirements that SparseMultiLabelPredictor is Closeable.
deaktator Sep 1, 2017
0aea10c
Added comment to predictor.
deaktator Sep 1, 2017
85229f9
Updated skeleton of MultilabelModel and small change to RegressionFea…
deaktator Sep 1, 2017
7812019
Added B <: U in helper methods.
deaktator Sep 1, 2017
f78e4bd
Added a few comments to model and added test skeleton.
deaktator Sep 1, 2017
f21fa4f
removed parameters to Auditor. Just use defaults because the values …
deaktator Sep 1, 2017
a9a0f8b
Added import to companion object Label type and Auditor.
deaktator Sep 1, 2017
2d6de6f
Added additional test.
deaktator Sep 1, 2017
6defbe0
SparseMultiLabelPredictor was made package private for testing.
deaktator Sep 1, 2017
7661df2
updated privacy of type aliases in multilabel package object
deaktator Sep 1, 2017
9fd15ad
Serialization test
amirziai Sep 1, 2017
ce66b7c
First test passing
amirziai Sep 1, 2017
75dee22
Hopefully code complete for the case class.
deaktator Sep 1, 2017
599e336
Merge branch '163-multilabel' into 163-multilabel-test
amirziai Sep 1, 2017
2e0f4eb
Added comment.
deaktator Sep 1, 2017
3dc2921
Merge branch '163-multilabel' into 163-multilabel-test
amirziai Sep 1, 2017
3b2652b
lessened privileges.
deaktator Sep 5, 2017
dd7dc71
Merge remote-tracking branch 'upstream/163-multilabel' into 163-multi…
amirziai Sep 5, 2017
999348b
Adding more multi-label tests
amirziai Sep 5, 2017
a38b5ea
Success report test case
amirziai Sep 6, 2017
a5c5259
Addressing JMorra's PR comments.
deaktator Sep 6, 2017
2657672
More tests
amirziai Sep 6, 2017
163609e
Merge remote-tracking branch 'upstream/163-multilabel' into 163-multi…
amirziai Sep 6, 2017
71e21aa
java Serializable needs to be here
amirziai Sep 6, 2017
ed08228
Adding MultilabelModel parsing stuff, plugins, VW version, etc.
deaktator Sep 7, 2017
e450ce1
Empty label problems test
amirziai Sep 7, 2017
77d57e0
More explicit val name
amirziai Sep 7, 2017
6dc5636
MultilabelModel parsing is compiling.
deaktator Sep 8, 2017
c6152ac
Added changed from Iterable[(String, Double)] to Sparse.
deaktator Sep 8, 2017
ff1725a
new line at EOF.
deaktator Sep 8, 2017
79422ab
VW compiling but still a few holes to fill in. Added Namespaces trait.
deaktator Sep 9, 2017
7546262
VW compiling. Added test shell. Fill in shell.
deaktator Sep 11, 2017
6d5671d
Number of new changes
amirziai Sep 12, 2017
6f45145
Test passing. It appears we don't need the dummy classes in test mode.
deaktator Sep 12, 2017
8d1caaa
Merge remote-tracking branch 'upstream/163-multilabel' into 163-multi…
amirziai Sep 12, 2017
f71aba5
Updated VwSparseMultilabelPredictor. It now seems to be fully workin…
deaktator Sep 12, 2017
e8159fc
Added some test comments.
deaktator Sep 12, 2017
6f98ce2
Added comments to VwSparseMultilabelPredictor.
deaktator Sep 12, 2017
47c0658
One more passing test, moving to performance testing now
amirziai Sep 12, 2017
c54d915
updated split
deaktator Sep 13, 2017
1356881
Merged from master.
deaktator Sep 14, 2017
4c00c6f
Merging updates
amirziai Sep 18, 2017
a0a2c0d
Merge remote-tracking branch 'upstream/163-multilabel' into 163-multi…
amirziai Sep 18, 2017
86890c1
Figured out the gist of a few more tests, terrible code though
amirziai Sep 18, 2017
daffd68
First pass over all tests
amirziai Sep 19, 2017
8280ef7
Simplified some of the tests
amirziai Sep 21, 2017
80fc41c
Refactoring
amirziai Sep 22, 2017
a18a827
Refactoring common patterns into the companion object
amirziai Sep 22, 2017
406b1f9
committing VwMultilabelRowCreator and updating other stuff to use it.
deaktator Sep 22, 2017
dc363e6
All tests pass, code structured a bit better
amirziai Sep 23, 2017
b0c1b77
Merge remote-tracking branch 'upstream/163-multilabel' into 163-multi…
amirziai Sep 23, 2017
0d35d4c
Wasn't compiling after merge
amirziai Sep 23, 2017
201a823
labels not in training set should be reported
amirziai Sep 23, 2017
a42c9d7
Renamed missingLabels->labelsNotInTrainingSet to conform to new signa…
amirziai Sep 23, 2017
c8902ec
Added some unit tests. Still plenty more to do.
deaktator Sep 23, 2017
a1ce0dd
Adding PR template
amirziai Sep 24, 2017
d1a40f3
Addressing comments
amirziai Sep 25, 2017
ddf40fd
Merge pull request #182 from amirziai/163-multilabel-test
deaktator Sep 25, 2017
e3f9c59
Getting everything to compile. Still some work to be done.,
deaktator Sep 26, 2017
9dd35ee
VW multi-label model parsing working correctly. Tests prove it!
deaktator Sep 27, 2017
82746e7
exposed VW parameters to VwSparseMultilabelPredictor
deaktator Sep 27, 2017
5ab8e7c
removed TODO.
deaktator Sep 27, 2017
2023e74
updated tests to add coverage.
deaktator Sep 27, 2017
86c5c50
Added additional tests.
deaktator Sep 27, 2017
b3fdff4
End to end testing working. Need to clean it up.
deaktator Sep 29, 2017
7087d04
a little cleanup.
deaktator Sep 29, 2017
2a0b5c0
Removed implicit fn com.eharmony.aloha.factory.ScalaJsonFormats.lift(…
deaktator Sep 29, 2017
e11aef1
simplifying tests.
deaktator Sep 29, 2017
c2a048e
vw param function skeleton.
deaktator Oct 3, 2017
0ce7d68
merge from development branch. Build passing.
deaktator Oct 4, 2017
3cb9d2e
non working VwMultilabelModel.updatedVwParams. Skeleton laid out.
deaktator Oct 5, 2017
fb7a8f2
quadratics and cubics seem to be working.
deaktator Oct 5, 2017
f86b610
removed println
deaktator Oct 5, 2017
4c25972
made ignore_linear more concise
deaktator Oct 5, 2017
79f988b
lots of stuff working. More tests to write for VwMultilabelParamAugm…
deaktator Oct 7, 2017
9f61597
tested higher order interactions.
deaktator Oct 7, 2017
ca84c36
removed extra whitespace in string output.
deaktator Oct 7, 2017
db8967f
working but will change regex padding to use zero-width positive look…
deaktator Oct 9, 2017
6f89630
added different padding.
deaktator Oct 9, 2017
95e0c92
Updated documentation and tests. Looks good.
deaktator Oct 9, 2017
a7b168a
Updated VW label NS algo. Added test for when a NS can't be found.
deaktator Oct 9, 2017
9c4362b
hacky solution to flags with options referencing files. Use tmp file…
deaktator Oct 11, 2017
6df3f83
Looks good.
deaktator Oct 30, 2017
cd50006
Precompute positive and negative dummy class strings.
deaktator Nov 1, 2017
ffd10ee
Adding numUniqueLabels parameter to updatedVwParams to add VW's --rin…
deaktator Nov 1, 2017
9592b09
stateful row creator and reservoir sampling.
deaktator Nov 3, 2017
8e478f3
provided concrete implementations of iterator and vector apply method.
deaktator Nov 4, 2017
b7594fa
more purity in test.
deaktator Nov 4, 2017
819095a
separated pure and impure code.
deaktator Nov 4, 2017
b190821
no more. good enough.
deaktator Nov 4, 2017
5e1ebd0
It's never good enough, even on a football Saturday.
deaktator Nov 4, 2017
0889971
Seq -> List
deaktator Nov 5, 2017
d60f226
VwDownsampledMultilabelRowCreator and supporting infrastructure and t…
deaktator Nov 7, 2017
fe8dca2
Made Rand use Int indices, made k < 2^15 in neg label sampling. Upda…
deaktator Nov 7, 2017
baf7466
Addressing PR comments. Removed VW params from VW multi-label model …
deaktator Nov 8, 2017
22d9547
Downsampling can now operate over 2^31 - 1 (2 billion) labels.
deaktator Nov 8, 2017
c463c11
changed Iterator.isEmpty to hasNext. Updated docs.
deaktator Nov 8, 2017
79fddb6
forgot logical not in if statement.
deaktator Nov 8, 2017
1c0102c
removed toShort from Rand.
deaktator Nov 8, 2017
6e21f8c
addressing PR comments. Changed name of multilabel model to 'SparseM…
deaktator Nov 8, 2017
@@ -28,7 +28,8 @@ class ModelTypesTest {
"ModelDecisionTree",
"Regression",
"Segmentation",
"VwJNI"
"VwJNI",
"multilabel-sparse"
Collaborator

Why is the casing different here?

Contributor Author

Yeah. I want opinions on the name here. I am neither set on, nor particularly enamored with, the name "multilabel-sparse". Suggestions?

)

val actual = ModelFactory.defaultFactory(null, null).parsers.map(_.modelType).sorted
@@ -0,0 +1,361 @@
package com.eharmony.aloha.dataset.vw.multilabel

import com.eharmony.aloha.AlohaException
import com.eharmony.aloha.dataset._
import com.eharmony.aloha.dataset.density.Sparse
import com.eharmony.aloha.dataset.vw.VwCovariateProducer
import com.eharmony.aloha.dataset.vw.multilabel.json.VwMultilabeledJson
import com.eharmony.aloha.dataset.vw.unlabeled.VwRowCreator
import com.eharmony.aloha.reflect.RefInfo
import com.eharmony.aloha.semantics.compiled.CompiledSemantics
import com.eharmony.aloha.semantics.func.GenAggFunc
import spray.json.JsValue

import scala.collection.{breakOut, immutable => sci}
import scala.util.{Failure, Success, Try}

/**
* Created by ryan.deak on 9/13/17.
*/
final case class VwMultilabelRowCreator[-A, K](
allLabelsInTrainingSet: sci.IndexedSeq[K],
featuresFunction: FeatureExtractorFunction[A, Sparse],
defaultNamespace: List[Int],
namespaces: List[(String, List[Int])],
normalizer: Option[CharSequence => CharSequence],
positiveLabelsFunction: GenAggFunc[A, sci.IndexedSeq[K]],
classNs: Char,
dummyClassNs: Char,
includeZeroValues: Boolean = false
Collaborator

I think we're missing something important here. We have experimentally shown that when training on a very large multilabel problem, the ability to randomly sample from the negative labels at train time both decreases train time and can increase model performance. I think that a parameter to the row creator should be a function that allows me to sample a percentage of negative labels to add to each train line. This needs to be random, and the negative labels added need to change with EVERY line.

Contributor

Yeah. I like that. Let me think on that for a few and come up with something.

Contributor

@jmorra I think the requirement that "negative labels added need to change with EVERY line" is important, but I don't think it fits into the current architecture. It seems like you are (at least implicitly) advocating two things:

  1. Intentionally making apply "more" impure, and definitely not idempotent or a function (in the math sense).
  2. Maintaining state inside VwMultilabelRowCreator (by virtue of (1))

What if we have a training set that is distributed across multiple nodes, apply the VwMultilabelRowCreator to the elements on each node, and combine the results? If we want to do what I think you are suggesting, I think we'll need to make some pretty big design changes. I suggest the following:

  1. We'll need a constructor seed parameter. This could be a value, but if so, the calling code would be responsible for varying the seed for each node on which the VwMultilabelRowCreator is run. It could also be an impure, side-effecting function that is idempotent (in context), meaning it would produce the same result every time it is run on the same node. This would allow the seed generator parameter to be set once and called on each node to produce the seed for that node.
    seed: () => Long // Not necessarily a long but something to seed pseudo-randomness.
  2. We'll need a new interface like StatefulRowCreator. This interface will use something like the State monad in cats or scalaz. I would recommend a stack-safe implementation (remember the tut bug in 0.5.3?).
  3. We'll most likely have to use the StatefulRowCreator differently. We could expose two methods: one that operates on single items and requires the caller to maintain the state and pass it in, and another that operates on multiple elements (collections, iterators) and hides the state information.

What do you think about this stuff? I think if we try to jam this type of mutability into the current interface, it'll most likely produce disastrous results.
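
A minimal sketch of the seeded, per-row negative sampling discussed above, assuming a plain `scala.util.Random` and a simple shuffle rather than the reservoir sampling mentioned in the commit list. `sampleNegatives` and its parameters are hypothetical names, not part of Aloha. Returning a fresh seed alongside the sample keeps the function pure, so the caller decides how state moves from one row to the next:

```scala
import scala.util.Random

object NegativeSamplingSketch {
  /** Hypothetical helper: sample up to `k` negative label indices.
    * The second element of the result is a seed for the next row. */
  def sampleNegatives(
      numLabels: Int,
      positives: Set[Int],
      k: Int,
      seed: Long
  ): (Vector[Int], Long) = {
    val rng = new Random(seed)
    // Every label index that is not a positive for this row.
    val negatives = (0 until numLabels).filterNot(positives).toVector
    // Shuffle-and-take is a stand-in for a real sampling scheme.
    val sampled = rng.shuffle(negatives).take(k).sorted
    (sampled, rng.nextLong())
  }
}
```

Calling it twice with the same seed yields the same sample, while feeding the returned seed into the next call changes the negatives from row to row. Materializing every negative is linear in the number of labels, so this approach would need rethinking for very large label sets.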

Contributor

@ryan-deak-zefr Nov 2, 2017

@jmorra: I am proposing the following trait:

trait StatefulRowCreator[-A, +B, S] extends Serializable {
  // An initial state derived from the constructor
  def initialState: S

  // Caller maintains state.
  def apply(a: A, s: S): ((MissingAndErroneousFeatureInfo, Option[B]), S)

  // Iterator doesn't return state because it's non-strict.  Output Iterator length
  // is same as input iterator length.  May be infinite length.  This should be a constant
  // time operation.
  def apply(as: Iterator[A], s: S): Iterator[(MissingAndErroneousFeatureInfo, Option[B])]

  // Output Vector length is same as input length.  Returns the final state.
  // This is a linear time operation.
  def apply(as: Vector[A], s: S): (Vector[(MissingAndErroneousFeatureInfo, Option[B])], S)
}

I think it was a mistake in retrospect to not have Option[B] in the RowCreator trait. This should be rectified here and we can address the Option in the RowCreator elsewhere.

I'm happy to change the function names but I like the signatures. (map maybe?).

We could consider putting the state as the first parameter at least in the Iterator and Vector methods so that we can curry and supply the initial state to get back a function of one argument.
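
One way the proposed trait could be exercised, as a sketch only. `Info` is a stand-in for Aloha's `MissingAndErroneousFeatureInfo`, `EveryOther` is a toy implementation invented for illustration, and the default bodies for the `Iterator` and `Vector` overloads are an assumption about how they might be derived from the single-item method:

```scala
object StatefulSketch {
  // Stand-in for Aloha's MissingAndErroneousFeatureInfo (details omitted).
  final case class Info(missing: Seq[String])

  trait StatefulRowCreator[-A, +B, S] extends Serializable {
    def initialState: S

    // Caller maintains state.
    def apply(a: A, s: S): ((Info, Option[B]), S)

    // Non-strict: the state advances only as the iterator is consumed.
    def apply(as: Iterator[A], s: S): Iterator[(Info, Option[B])] = {
      var state = s
      as.map { a =>
        val (out, next) = apply(a, state)
        state = next
        out
      }
    }

    // Strict: fold over the input and return the final state.
    def apply(as: Vector[A], s: S): (Vector[(Info, Option[B])], S) =
      as.foldLeft((Vector.empty[(Info, Option[B])], s)) { case ((acc, st), a) =>
        val (out, next) = apply(a, st)
        (acc :+ out, next)
      }
  }

  // Toy creator: the state is a row counter and only even-numbered rows are emitted.
  object EveryOther extends StatefulRowCreator[String, String, Int] {
    def initialState: Int = 0
    def apply(a: String, s: Int): ((Info, Option[String]), Int) = {
      val row = if (s % 2 == 0) Some(s"$s:$a") else None
      ((Info(Nil), row), s + 1)
    }
  }
}
```

The `Vector` overload hands back the final state so that a caller can continue from where a batch left off; the `Iterator` overload cannot do that without forcing the iterator, which matches the comment in the proposal.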

) extends RowCreator[A, Array[String]] {
import VwMultilabelRowCreator._

@transient private[this] lazy val labelToInd = allLabelsInTrainingSet.zipWithIndex.toMap

// Precompute these for efficiency rather than recompute them inside a hot loop.
// Notice these are not lazy vals.

private[this] val negativeDummyStr =
s"$NegDummyClassId:$NegativeCost |$dummyClassNs $NegativeDummyClassFeature"

private[this] val positiveDummyStr =
s"$PosDummyClassId:$PositiveCost |$dummyClassNs $PositiveDummyClassFeature"

override def apply(a: A): (MissingAndErroneousFeatureInfo, Array[String]) = {
val (missingAndErrs, features) = featuresFunction(a)

// TODO: Should this be sci.BitSet?
val positiveIndices: Set[Int] =
Collaborator

You might want to consider BitSet but only if the performance improvement is drastic. If we care about performance shouldn't we consider not changing collection types? Isn't .toSeq an O(N) operation?

Contributor

.toSeq is on an option. It could be removed without effect due to the implicit function that changes options to iterables.
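
The point about `.toSeq` can be checked in isolation. This sketch uses made-up labels and `.toSet` instead of `breakOut` (which was removed in Scala 2.13) so that it compiles on newer Scala versions too:

```scala
object OptionFlatMapSketch {
  // Hypothetical label index, analogous to labelToInd in the diff.
  val labelToInd: Map[String, Int] = Map("cat" -> 0, "dog" -> 1)
  val positives: Vector[String] = Vector("dog", "bird", "cat")

  // With the explicit conversion, as in the diff.
  val withToSeq: Set[Int] =
    positives.flatMap(y => labelToInd.get(y).toSeq).toSet

  // Without it: the Option is treated as a collection of zero or one elements.
  val withoutToSeq: Set[Int] =
    positives.flatMap(y => labelToInd.get(y)).toSet
}
```

Both values contain the same indices; labels missing from the map (here "bird") are silently dropped in either form.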

positiveLabelsFunction(a).flatMap { y => labelToInd.get(y).toSeq }(breakOut)

val x: Array[String] = trainingInput(
features,
allLabelsInTrainingSet.indices,
positiveIndices,
defaultNamespace,
namespaces,
classNs,
negativeDummyStr,
positiveDummyStr
)

(missingAndErrs, x)
}
}

object VwMultilabelRowCreator {

/**
Collaborator

Love these descriptions.

* VW allows long-based feature indices, but Aloha only allows 32-bit indices
* on the features that produce the key-value pairs passed to VW. The negative
* dummy class uses an ID outside of the allowable range of feature indices:
* 2^31^.
*/
private val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString

/**
* VW allows long-based feature indices, but Aloha only allows 32-bit indices
* on the features that produce the key-value pairs passed to VW. The positive
* dummy class uses an ID outside of the allowable range of feature indices:
* 2^31^ + 1.
*/
private val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString
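
For reference, a quick check of the arithmetic behind the two dummy class IDs. `Int.MaxValue` is 2^31 - 1, so the IDs are the first two values past the range of non-negative `Int` indices (`DummyIdSketch` is just an illustrative wrapper, not Aloha code):

```scala
object DummyIdSketch {
  // Same expressions as NegDummyClassId and PosDummyClassId, before .toString.
  val negId: Long = Int.MaxValue.toLong + 1L
  val posId: Long = Int.MaxValue.toLong + 2L

  // 1L << 31 is 2^31, computed in Long arithmetic to avoid Int overflow.
  val twoToThe31: Long = 1L << 31
}
```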

/**
* Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the
* dependent variable is based on cost (which is the negative of reward).
* As such, the ''reward'' of a positive example is designated to be one,
* so the cost (or negative reward) is -1.
*/
private val PositiveCost = (-1).toString

/**
* Since VW CSOAA stands for '''COST''' ''Sensitive One Against All'', the
* dependent variable is based on cost (which is the negative of reward).
* As such, the ''reward'' of a negative example is designated to be zero,
* so the cost (or negative reward) is 0.
*/
private val NegativeCost = 0.toString

private val PositiveDummyClassFeature = "P"

private val NegativeDummyClassFeature = "N"

/**
* "shared" is a special keyword in VW multi-class (multi-row) format.
* See Hal Daume's [[https://www.umiacs.umd.edu/%7Ehal/tmp/multiclassVW.html page]].
*
* '''NOTE''': The trailing space should be here.
*/
private[this] val SharedFeatureIndicator = "shared" + " "
Collaborator

Why not just write this as "shared "?

Contributor

Emphasis.


private[this] val PreferredLabelNamespaces = Seq(('Y', 'y'), ('Z', 'z'), ('Λ', 'λ'))

/**
* Determine the label namespaces for VW. VW insists on the uniqueness of the first character
* of namespaces. The goal is to try to use the first one of these combinations for
* label and dummy label namespaces where neither of the values is in `usedNss`.
*
*   - "Y", "y"
*   - "Z", "z"
*   - "Λ", "λ"
*
* If one of these combinations cannot be used because at least one of the elements in a given
* row is in `usedNss`, then iterate over the Unicode set and take the first two characters
* found that are deemed valid. These will then become the actual and
* dummy namespace names (respectively).
*
* The goal of this function is to try to use characters in literature used to denote a
* dependent variable. If that isn't possible (because the characters are already used by some
* other namespace, just find the first possible characters.
Collaborator

Close your parenthesis.

Contributor

done

* @param usedNss names of namespaces used.
* @return the namespace for ''actual'' label information then the namespace for ''dummy''
* label information. If two valid namespaces couldn't be produced, return None.
*/
private[aloha] def determineLabelNamespaces(usedNss: Set[String]): Option[LabelNamespaces] = {
val nss = nssToFirstCharBitSet(usedNss)
preferredLabelNamespaces(nss) orElse bruteForceNsSearch(nss)
}

private[multilabel] def preferredLabelNamespaces(nss: sci.BitSet): Option[LabelNamespaces] = {
PreferredLabelNamespaces collectFirst {
case (actual, dummy) if !(nss contains actual.toInt) && !(nss contains dummy.toInt) =>
LabelNamespaces(actual, dummy)
}
}

private[multilabel] def nssToFirstCharBitSet(ss: Set[String]): sci.BitSet =
ss.collect { case s if s != "" =>
Collaborator

s.nonEmpty is more legible

Contributor Author

Requires class instantiation. Changing to case s if s.length != 0 => .

s.charAt(0).toInt
}(breakOut[Set[String], Int, sci.BitSet])

private[multilabel] def validCharForNamespace(chr: Char): Boolean = {
// These might be overkill.
Character.isDefined(chr) &&
Character.isLetter(chr) &&
!Character.isISOControl(chr) &&
!Character.isSpaceChar(chr) &&
!Character.isWhitespace(chr)
}

/**
* Find the first two valid characters that can be used as VW namespaces that, when converted
* to integers are not present in usedNss.
* @param usedNss the set of first characters in namespaces.
* @return the namespace to use for the actual classes and dummy classes, respectively.
*/
private[multilabel] def bruteForceNsSearch(usedNss: sci.BitSet): Option[LabelNamespaces] = {
val found =
Iterator
.range(Char.MinValue, Char.MaxValue)
.filter(c => !(usedNss contains c) && validCharForNamespace(c.toChar))
.take(2)
.toList

found match {
case actual :: dummy :: Nil =>
Option(LabelNamespaces(actual.toChar, dummy.toChar))
case _ => None
}
}
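
A condensed, standalone sketch of the selection strategy described above: try the preferred pairs in order, then fall back to scanning for the first two unused letters. It uses a plain `Set[Char]` instead of a `BitSet` and only `Character.isLetter` as the validity check, both simplifications relative to the diff; `choose` is a hypothetical name:

```scala
object LabelNamespaceSketch {
  // Preferred (actual, dummy) pairs; the last pair is Greek capital and small lambda.
  private val Preferred = Seq(('Y', 'y'), ('Z', 'z'), ('\u039B', '\u03BB'))

  // Simplified validity check (the diff applies several more conditions).
  private def valid(c: Char): Boolean = Character.isLetter(c)

  def choose(usedNss: Set[String]): Option[(Char, Char)] = {
    // VW distinguishes namespaces by their first character only.
    val used: Set[Char] = usedNss.filter(_.nonEmpty).map(_.charAt(0))

    val preferred = Preferred.find { case (a, d) => !used(a) && !used(d) }

    preferred.orElse {
      // Brute force: the first two valid characters not already in use.
      Iterator
        .range(Char.MinValue.toInt, Char.MaxValue.toInt)
        .map(_.toChar)
        .filter(c => !used(c) && valid(c))
        .take(2)
        .toList match {
          case a :: d :: Nil => Some((a, d))
          case _             => None
        }
    }
  }
}
```

For example, `choose(Set("abc", "Yes"))` skips the first pair because a namespace already starts with `Y` and returns `Some(('Z', 'z'))`. When all three preferred pairs are blocked, the fallback picks the first two unused letters in code point order.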

/**
* Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model.
* @param features (non-label dependent) features shared across all labels.
* @param indices the indices `labels` into the sequence of all labels encountered
* during training.
* @param positiveLabelIndices a predicate telling whether the example should be positively
* associated with a label.
* @param defaultNs the indices into `features` that should be placed in VW's default
* namespace.
* @param namespaces the indices into `features` that should be associated with each
* namespace.
* @param classNs a namespace for features associated with class labels
* @param dummyClassNs a namespace for features associated with dummy class labels
* @return an array to be passed directly to an underlying `VWActionScoresLearner`.
*/
private[aloha] def trainingInput(
features: IndexedSeq[Sparse],
Collaborator

Should these types be more general collections?

indices: sci.IndexedSeq[Int],
positiveLabelIndices: Int => Boolean,
defaultNs: List[Int],
namespaces: List[(String, List[Int])],
classNs: Char,
negativeDummyStr: String,
positiveDummyStr: String
): Array[String] = {

val n = indices.size

// The length of the output array is n + 3.
//
// The first row is the shared features. These are features that are not label dependent.
// Then comes two dummy classes. These are to make the probabilities work out.
// Then come the features for each of the n labels.
val x = new Array[String](n + 3)

val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false)
Collaborator

Can you name the boolean parameter.

Contributor Author

Sure

x(0) = SharedFeatureIndicator + shared

// These string interpolations are computed over and over but will always be the same
// for a given dummyClassNs.
x(1) = negativeDummyStr
x(2) = positiveDummyStr

// This is mutable because we want speed.
var i = 0
while (i < n) {
val labelInd = indices(i)

// TODO or positives.contains(labelInd)?
val dv = if (positiveLabelIndices(i)) PositiveCost else NegativeCost
Collaborator

So the only problem with the way this is written now is that the function as provided doesn't guarantee an O(1) runtime which you must have for performance. I would consider changing it so the reader of this function knows that it's an O(1) algorithm.

Contributor Author

I don't understand. Are you saying that positiveLabelIndices is not guaranteed to be O(1)? That's definitely true, but neither is Set.contains. See Collections Performance Characteristics. I don't think this is an issue as it's a package private function that shouldn't be called outside Aloha.
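
To illustrate the point about `Int => Boolean`: any immutable `Set[Int]`, including a `BitSet`, can be passed directly as the predicate. Per the Scala collections performance table, `BitSet` lookups are constant time and `HashSet` lookups are effectively constant. `costs` below is a hypothetical helper, not Aloha code:

```scala
import scala.collection.immutable.BitSet

object PredicateSketch {
  // Indices of the positive labels for one example (made-up values).
  val positives: BitSet = BitSet(1, 4)

  // The signature promises nothing about lookup cost; that depends
  // entirely on which function the caller supplies.
  def costs(n: Int, isPositive: Int => Boolean): Vector[String] =
    Vector.tabulate(n)(i => if (isPositive(i)) "-1" else "0")
}
```

Here `PredicateSketch.costs(5, PredicateSketch.positives)` compiles because a set of `Int` is itself a function from `Int` to `Boolean`.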

x(i + 3) = s"$labelInd:$dv |$classNs _$labelInd"
i += 1
}

x
}
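
A standalone sketch of the row layout that `trainingInput` builds, using the constants from the diff. The `shared` argument stands in for the output of `VwRowCreator.unlabeledVwInput`, whose exact format is not shown here, so the feature string in the usage note is an assumption:

```scala
object TrainingInputSketch {
  // Constants mirroring the companion object in the diff.
  private val NegDummyClassId = (Int.MaxValue.toLong + 1L).toString
  private val PosDummyClassId = (Int.MaxValue.toLong + 2L).toString

  /** Builds n + 3 rows: shared features, two dummy classes, then one row per label. */
  def rows(
      shared: String,
      numLabels: Int,
      positives: Set[Int],
      classNs: Char,
      dummyClassNs: Char
  ): Array[String] = {
    val header = Array(
      "shared " + shared,                       // label-independent features
      s"$NegDummyClassId:0 |$dummyClassNs N",   // negative dummy class, cost 0
      s"$PosDummyClassId:-1 |$dummyClassNs P"   // positive dummy class, cost -1
    )
    val labelRows = Array.tabulate(numLabels) { i =>
      val cost = if (positives(i)) "-1" else "0"
      s"$i:$cost |$classNs _$i"
    }
    header ++ labelRows
  }
}
```

With `rows("|ns f:1", 3, Set(1), 'Y', 'y')`, the result is the shared line, the two dummy-class lines (`2147483648:0 |y N` and `2147483649:-1 |y P`), and then `0:0 |Y _0`, `1:-1 |Y _1`, `2:0 |Y _2`.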

/**
* Produce a multi-line input to be consumed by the underlying ''CSOAA LDF'' VW model.
* @param features (non-label dependent) features shared across all labels.
* @param indices the indices `labels` into the sequence of all labels encountered
* during training.
* @param defaultNs the indices into `features` that should be placed in VW's default
* namespace.
* @param namespaces the indices into `features` that should be associated with each
* namespace.
* @param classNs a namespace for features associated with class labels
* @return an array to be passed directly to an underlying `VWActionScoresLearner`.
*/
private[aloha] def predictionInput(
features: IndexedSeq[Sparse],
indices: sci.IndexedSeq[Int],
defaultNs: List[Int],
namespaces: List[(String, List[Int])],
classNs: String
): Array[String] = {

val n = indices.size

// Use a (mutable) array (and iteration) for speed.
// The first row is the shared features. These are features that are not label dependent.
// Then come the features for each of the n labels.
val x = new Array[String](n + 1)

val shared = VwRowCreator.unlabeledVwInput(features, defaultNs, namespaces, false)
Collaborator

I see that this called function is tail recursive with the following comment from @deaktator

// RMD 2015-06-12: GOD I HATE THIS CODE!!!  Maybe functionalize it in the future!

Since we're concerned about performance here, is this worth revisiting now?

Contributor Author

This is more performant than an immutable implementation. That comment was in reference to mutability, not performance.

x(0) = SharedFeatureIndicator + shared

var i = 0
while (i < n) {
val labelInd = indices(i)
x(i + 1) = s"$labelInd:0 |$classNs _$labelInd"
i += 1
}

x
}


/**
* A producer that can produce a [[VwMultilabelRowCreator]].
* The requirement for [[RowCreatorProducer]] to only have zero-argument constructors is
* relaxed for this Producer because we don't have a way of generically constructing a
* list of labels. If the labels were encoded in the JSON, then a JsonReader for the label
* type would have to be passed to the constructor. Since the labels can't be encoded
* generically in the JSON, we accept that this Producer is a special case and allow the labels
* to be passed directly. The consequence is that this producer doesn't just rely on the
* dataset specification and the data itself. It also relies on the labels provided to the
* constructor.
*
* @param allLabelsInTrainingSet All of the labels that will be encountered in the training set.
* @param ev$1 reflection information about `K`.
* @tparam A type of input passed to the [[RowCreator]].
* @tparam K the label type.
*/
final class Producer[A, K: RefInfo](allLabelsInTrainingSet: sci.IndexedSeq[K])
extends RowCreatorProducer[A, Array[String], VwMultilabelRowCreator[A, K]]
with RowCreatorProducerName
with VwCovariateProducer[A]
with DvProducer
with SparseCovariateProducer
with CompilerFailureMessages {

override type JsonType = VwMultilabeledJson

/**
* Attempt to parse the JSON AST to an intermediate representation that is used
Collaborator

Can you finish your thought here.

*
* @param json
* @return
*/
override def parse(json: JsValue): Try[VwMultilabeledJson] =
Try { json.convertTo[VwMultilabeledJson] }

/**
* Attempt to produce a Spec.
*
* @param semantics semantics used to make sense of the features in the JsonSpec
* @param jsonSpec a JSON specification to transform into a RowCreator.
* @return
*/
override def getRowCreator(semantics: CompiledSemantics[A], jsonSpec: VwMultilabeledJson): Try[VwMultilabelRowCreator[A, K]] = {
val (covariates, default, nss, normalizer) = getVwData(semantics, jsonSpec)

val spec = for {
cov <- covariates
pos <- positiveLabelsFn(semantics, jsonSpec.positiveLabels)
labelNs <- labelNamespaces(nss)
actualLabelNs = labelNs.labelNs
dummyLabelNs = labelNs.dummyLabelNs
sem = addStringImplicitsToSemantics(semantics, jsonSpec.imports)
} yield new VwMultilabelRowCreator[A, K](allLabelsInTrainingSet, cov, default, nss,
normalizer, pos, actualLabelNs, dummyLabelNs)

spec
}

private[multilabel] def labelNamespaces(nss: List[(String, List[Int])]): Try[LabelNamespaces] = {
val nsNames: Set[String] = nss.map(_._1)(breakOut)
determineLabelNamespaces(nsNames) match {
case Some(ns) => Success(ns)

// If there are so many VW namespaces that all available Unicode characters are taken,
// then a memory error will probably already have occurred.
case None => Failure(new AlohaException(
"Could not find any Unicode characters to use as VW namespaces. Namespaces provided: " +
nsNames.mkString(", ")
))
}
}

private[multilabel] def positiveLabelsFn(
semantics: CompiledSemantics[A],
positiveLabels: String
): Try[GenAggFunc[A, sci.IndexedSeq[K]]] =
getDv[A, sci.IndexedSeq[K]](
semantics, "positiveLabels", Option(positiveLabels), Option(Vector.empty[K]))
}

private[aloha] final case class LabelNamespaces(labelNs: Char, dummyLabelNs: Char)
}
Collaborator

newline

@@ -0,0 +1,23 @@
package com.eharmony.aloha.dataset.vw.multilabel.json

import com.eharmony.aloha.dataset.json.{Namespace, SparseSpec}
import com.eharmony.aloha.dataset.vw.json.VwJsonLike
import spray.json.{DefaultJsonProtocol, RootJsonFormat}

import scala.collection.{immutable => sci}

/**
* Created by ryan.deak on 9/13/17.
*/
final case class VwMultilabeledJson(
imports: sci.Seq[String],
features: sci.IndexedSeq[SparseSpec],
namespaces: Option[Seq[Namespace]] = Some(Nil),
normalizeFeatures: Option[Boolean] = Some(false),
positiveLabels: String)
extends VwJsonLike

object VwMultilabeledJson extends DefaultJsonProtocol {
implicit val labeledVwJsonFormat: RootJsonFormat[VwMultilabeledJson] =
jsonFormat5(VwMultilabeledJson.apply)
}