Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binary Classification Confusion Matrix and AUC Aggregators #633

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

/**
 * A curve represented as a list of confusion matrices, one per
 * decision threshold.
 *
 * @param matrices confusion matrices, one for each threshold
 */
case class ConfusionCurve(matrices: List[ConfusionMatrix])

/**
 * Computes the area under a curve described by a list of (x, y) points,
 * using the trapezoidal rule between consecutive points.
 */
object AreaUnderCurve {
  /**
   * Signed area between the segment joining two points and the x axis.
   *
   * @param left  (x, y) of the first point
   * @param right (x, y) of the second point
   */
  private def trapezoid(left: (Double, Double), right: (Double, Double)): Double = {
    val (x0, y0) = left
    val (x1, y1) = right
    (x1 - x0) * (y0 + y1) / 2.0
  }

  /**
   * @param curve points of the curve, in the order they are connected
   * @return the sum of the trapezoid areas of every consecutive pair of
   *         points; 0.0 when the curve has fewer than two points
   */
  def of(curve: List[(Double, Double)]): Double =
    // Pair each point with its successor; the last point has none and is dropped by zip.
    curve.zip(curve.drop(1)).foldLeft(0.0) {
      case (auc, (left, right)) => auc + trapezoid(left, right)
    }
}

/** Selects which curve BinaryClassificationAUCAggregator integrates. */
sealed trait AUCMetric
/** ROC curve: points are (false positive rate, recall). */
case object ROC extends AUCMetric
/** Precision-recall curve: points are (recall, precision). */
case object PR extends AUCMetric

/**
 * Monoid over ConfusionCurve. Two curves are added position by position,
 * summing the confusion matrices found at the same index. The empty
 * curve is the identity.
 */
case object ConfusionCurveMonoid extends Monoid[ConfusionCurve] {
  def zero: ConfusionCurve = ConfusionCurve(Nil)

  override def plus(left: ConfusionCurve, right: ConfusionCurve): ConfusionCurve = {
    val matrixMonoid = BinaryClassificationConfusionMatrixMonoid
    val empty = matrixMonoid.zero

    // zipAll pads the shorter curve with empty matrices so no entry is lost.
    val paired = left.matrices.zipAll(right.matrices, empty, empty)
    val summed = paired.map { pair => matrixMonoid.plus(pair._1, pair._2) }

    ConfusionCurve(summed)
  }
}

/**
 * AUCAggregator computes the Area Under the Curve
 * for a given metric by sampling along that curve.
 *
 * `samples` evenly spaced thresholds are taken between 0.0 and 1.0.
 * A confusion matrix is computed for each threshold, each matrix is
 * turned into one point of the chosen curve, and the area under those
 * points is computed with the trapezoidal rule.
 *
 * Note this is for Binary Classification tasks.
 *
 * @param metric Which metric (curve) to compute
 * @param samples Number of thresholds to sample, defaults to 100.
 *                Must be at least 2 so that the thresholds span
 *                the interval from 0.0 to 1.0.
 * @throws IllegalArgumentException if samples is less than 2
 */
case class BinaryClassificationAUCAggregator(metric: AUCMetric, samples: Int = 100)
  extends Aggregator[BinaryPrediction, ConfusionCurve, Double]
  with Serializable {

  // With a single sample the step below is a division by zero (NaN threshold),
  // and with none the curve is empty, so reject both up front.
  require(samples >= 2, s"samples must be at least 2, got $samples")

  /** `length` evenly spaced values starting at `a` with step (b - a) / (length - 1). */
  private def linspace(a: Double, b: Double, length: Int): Array[Double] = {
    val increment = (b - a) / (length - 1)
    Array.tabulate(length)(i => a + increment * i)
  }

  private lazy val thresholds = linspace(0.0, 1.0, samples)
  private lazy val aggregators = thresholds.map(BinaryClassificationConfusionMatrixAggregator(_)).toList

  /** Builds one confusion matrix per threshold for a single prediction. */
  def prepare(input: BinaryPrediction): ConfusionCurve = ConfusionCurve(aggregators.map(_.prepare(input)))

  def semigroup: Semigroup[ConfusionCurve] = ConfusionCurveMonoid

  /**
   * Converts every matrix of the curve into a point of the selected metric
   * and returns the area under the resulting points.
   */
  def present(c: ConfusionCurve): Double = {
    // Only used to derive the rates from each matrix; created once for all matrices.
    val scorer = BinaryClassificationConfusionMatrixAggregator()

    // Matrices are stored by ascending threshold; reversing them yields the
    // points from the highest threshold to the lowest.
    val total = c.matrices.map { matrix =>
      val scores = scorer.present(matrix)
      metric match {
        case ROC => (scores.falsePositiveRate, scores.recall)
        case PR => (scores.recall, scores.precision)
      }
    }.reverse

    // Anchor the curve: ROC ends at (1.0, 1.0), PR starts at (0.0, 1.0).
    val combined = metric match {
      case ROC => total ++ List((1.0, 1.0))
      case PR => List((0.0, 1.0)) ++ total
    }

    AreaUnderCurve.of(combined)
  }
}

/** Implicit instances for working with ConfusionCurve values. */
object BinaryClassificationAUC {
// Exposes ConfusionCurveMonoid as the implicit Monoid for ConfusionCurve.
implicit def monoid: Monoid[ConfusionCurve] = ConfusionCurveMonoid
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

/**
 * A single output of a binary classifier: the score it produced
 * together with the true label of the example.
 *
 * @param score Score of the classifier
 * @param label true for the positive class, false for the negative class
 */
case class BinaryPrediction(score: Double, label: Boolean) extends Serializable {
  /** Renders the prediction as "label,score". */
  override def toString: String = label.toString + "," + score.toString
}

/**
 * Confusion matrix for a binary classifier: the four counts that are
 * aggregated over predictions. Every count defaults to 0.
 *
 * NOTE(review): the counts are Int and are summed with plain `+` by the
 * monoid — confirm Int range is sufficient for the expected data volume.
 *
 * @param truePositive  positive examples predicted positive
 * @param falsePositive negative examples predicted positive
 * @param falseNegative positive examples predicted negative
 * @param trueNegative  negative examples predicted negative
 */
case class ConfusionMatrix(
truePositive: Int = 0,
falsePositive: Int = 0,
falseNegative: Int = 0,
trueNegative: Int = 0)
extends Serializable

/**
 * Common statistics derived from an aggregated confusion matrix.
 *
 * @param fscore F Score based on the beta given to the Aggregator
 * @param precision Precision Score
 * @param recall Recall Score
 * @param falsePositiveRate False Positive Rate
 * @param matrix Confusion Matrix the statistics were computed from
 */
case class Scores(
fscore: Double,
precision: Double,
recall: Double,
falsePositiveRate: Double,
matrix: ConfusionMatrix)
extends Serializable

/**
 * Monoid over ConfusionMatrix: matrices are added count by count and the
 * all-zero matrix is the identity.
 */
case object BinaryClassificationConfusionMatrixMonoid extends Monoid[ConfusionMatrix] {
  def zero: ConfusionMatrix = ConfusionMatrix()

  override def plus(left: ConfusionMatrix, right: ConfusionMatrix): ConfusionMatrix =
    ConfusionMatrix(
      truePositive = left.truePositive + right.truePositive,
      falsePositive = left.falsePositive + right.falsePositive,
      falseNegative = left.falseNegative + right.falseNegative,
      trueNegative = left.trueNegative + right.trueNegative)
}

/**
 * A Confusion Matrix Aggregator creates a Confusion Matrix and
 * relevant scores for a given threshold given predictions from
 * a binary classifier.
 *
 * A prediction is counted as positive when its score is greater than
 * or equal to the threshold, and as negative otherwise.
 *
 * @param threshold Threshold to use for the predictions
 * @param beta Beta used in the FScore Calculation.
 */
case class BinaryClassificationConfusionMatrixAggregator(threshold: Double = 0.5, beta: Double = 1.0)
  extends Aggregator[BinaryPrediction, ConfusionMatrix, Scores]
  with Serializable {

  /**
   * Turns one prediction into a matrix with a single count of 1 in the
   * cell selected by the true label and the thresholded score.
   */
  def prepare(input: BinaryPrediction): ConfusionMatrix = {
    // A total comparison: a score equal to the threshold counts as positive,
    // and a NaN score (for which >= is false) counts as negative, so every
    // input maps to exactly one cell instead of failing the match.
    val predictedPositive = input.score >= threshold

    (input.label, predictedPositive) match {
      case (true, true) => ConfusionMatrix(truePositive = 1)
      case (true, false) => ConfusionMatrix(falseNegative = 1)
      case (false, true) => ConfusionMatrix(falsePositive = 1)
      case (false, false) => ConfusionMatrix(trueNegative = 1)
    }
  }

  def semigroup: Semigroup[ConfusionMatrix] =
    BinaryClassificationConfusionMatrixMonoid

  /**
   * Derives precision, recall, false positive rate and F-score from
   * the aggregated matrix.
   *
   * When a denominator is zero the corresponding rate falls back to a
   * fixed value: 1.0 for precision and recall, 0.0 for the false
   * positive rate and 0.0 for the F-score.
   */
  def present(m: ConfusionMatrix): Scores = {
    val tp = m.truePositive.toDouble
    val fp = m.falsePositive.toDouble

    val precDenom = tp + fp
    val precision = if (precDenom > 0.0) tp / precDenom else 1.0

    val recallDenom = tp + m.falseNegative.toDouble
    val recall = if (recallDenom > 0.0) tp / recallDenom else 1.0

    val fpDenom = fp + m.trueNegative.toDouble
    val fpr = if (fpDenom > 0.0) fp / fpDenom else 0.0

    val betaSqr = beta * beta

    val fScoreDenom = (betaSqr * precision) + recall

    // The denominator is zero when precision and recall are both zero,
    // i.e. nothing was classified correctly as positive: the score is 0.0.
    val fscore =
      if (fScoreDenom > 0.0) (1 + betaSqr) * ((precision * recall) / fScoreDenom)
      else 0.0

    Scores(fscore, precision, recall, fpr, m)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

import org.scalacheck.Arbitrary
import org.scalacheck.Gen.choose
import org.scalactic.TolerantNumerics
import org.scalatest.{ Matchers, _ }

/**
 * Checks the monoid laws for ConfusionCurve using curves of varying
 * length whose matrices have all four counts populated.
 */
class ConfusionCurveMonoidLaws extends CheckProperties {
  import BaseProperties._
  import BinaryClassificationAUC._
  import org.scalacheck.Gen

  // Every count is generated, not just truePositive.
  private val matrixGen: Gen[ConfusionMatrix] =
    for {
      tp <- choose(0, 10000)
      fp <- choose(0, 10000)
      fn <- choose(0, 10000)
      tn <- choose(0, 10000)
    } yield ConfusionMatrix(tp, fp, fn, tn)

  // Curves of different lengths (including empty) so that adding curves
  // of unequal length is covered as well.
  implicit val gen: Arbitrary[ConfusionCurve] = Arbitrary {
    for {
      size <- choose(0, 5)
      matrices <- Gen.listOfN(size, matrixGen)
    } yield ConfusionCurve(matrices)
  }

  property("Curve is associative") {
    isAssociative[ConfusionCurve]
  }

  property("Curve is a monoid") {
    monoidLaws[ConfusionCurve]
  }
}

/**
 * Checks the ROC and PR areas computed for a small, fixed set of predictions.
 */
class BinaryClassificationAUCTest extends WordSpec with Matchers {
  val data: List[BinaryPrediction] =
    List(
      BinaryPrediction(0.1, false),
      BinaryPrediction(0.1, true),
      BinaryPrediction(0.4, false),
      BinaryPrediction(0.6, false),
      BinaryPrediction(0.6, true),
      BinaryPrediction(0.6, true),
      BinaryPrediction(0.8, true))

  "BinaryClassificationAUC" should {
    // The expected values are given to three decimals; a tolerance of 0.01
    // absorbs that rounding while still catching a wrong area.
    implicit val doubleEq: org.scalactic.Equality[Double] =
      TolerantNumerics.tolerantDoubleEquality(0.01)

    "return roc auc" in {
      val aggregator = BinaryClassificationAUCAggregator(ROC, samples = 50)
      assert(aggregator(data) === 0.708)
    }

    "return pr auc" in {
      val aggregator = BinaryClassificationAUCAggregator(PR, samples = 50)
      assert(aggregator(data) === 0.833)
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

import org.scalacheck.Arbitrary
import org.scalacheck.Gen.choose
import org.scalactic.TolerantNumerics
import org.scalatest.{ Matchers, WordSpec }

/**
 * Checks the monoid laws for ConfusionMatrix using matrices with all
 * four counts populated.
 */
class BinaryClassificationConfusionMatrixMonoidLaws extends CheckProperties {
  import BaseProperties._

  implicit val semigroup: Monoid[ConfusionMatrix] = BinaryClassificationConfusionMatrixMonoid

  // Every count is generated, so the sum of each field is exercised.
  implicit val gen: Arbitrary[ConfusionMatrix] = Arbitrary {
    for {
      tp <- choose(0, 10000)
      fp <- choose(0, 10000)
      fn <- choose(0, 10000)
      tn <- choose(0, 10000)
    } yield ConfusionMatrix(tp, fp, fn, tn)
  }

  property("ConfusionMatrix is associative") {
    isAssociative[ConfusionMatrix]
  }

  property("ConfusionMatrix is a monoid") {
    monoidLaws[ConfusionMatrix]
  }
}

/**
 * Checks the scores produced at the default threshold for a small,
 * fixed set of predictions.
 */
class BinaryClassificationConfusionMatrixTest extends WordSpec with Matchers {
  val data: List[BinaryPrediction] =
    List(
      BinaryPrediction(0.1, false),
      BinaryPrediction(0.1, true),
      BinaryPrediction(0.4, false),
      BinaryPrediction(0.6, false),
      BinaryPrediction(0.6, true),
      BinaryPrediction(0.6, true),
      BinaryPrediction(0.8, true))

  "BinaryClassificationConfusionMatrix" should {
    // The expected values are given to at most three decimals; a tolerance
    // of 0.01 absorbs that rounding while still catching a wrong score.
    implicit val doubleEq: org.scalactic.Equality[Double] =
      TolerantNumerics.tolerantDoubleEquality(0.01)

    "return a correct confusion matrix" in {
      val aggregator = BinaryClassificationConfusionMatrixAggregator()
      val scored = aggregator(data)

      assert(scored.recall === 0.75)
      assert(scored.precision === 0.75)
      assert(scored.fscore === 0.75)
      assert(scored.falsePositiveRate === 0.333)
    }
  }
}
4 changes: 4 additions & 0 deletions docs/src/main/resources/microsite/data/menu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ options:
url: datatypes/bytes.html
menu_type: data

- title: Binary Classifier AUC
url: datatypes/auc.html
menu_type: data

- title: Decayed Value
url: datatypes/decayed_value.html
menu_type: data
Expand Down
Loading