Skip to content

Commit

Permalink
fixed documentation of AUC calculation, and removed a hack
Browse files Browse the repository at this point in the history
  • Loading branch information
athawk81 committed Jun 13, 2016
1 parent 8793f2e commit 102457c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
be accompanied by a bump in version number, regardless of how minor the change.
0.10.1 -->

<version> 0.10.11</version>
<version> 0.10.12</version>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,16 @@ public ArrayList<AUCPoint> getAUCPointsFromData(List<AUCData> aucDataList) {
double falseNegatives = 0;

ArrayList<AUCPoint> aucPoints = new ArrayList<>();
double threshold = 0.0;
double thresholdForPositiveClassification = 0.0;
//start at upper right of ROC CURVE where everything is a positive
for (AUCData aucData : aucDataList) {
if (aucData.getClassification().equals(positiveClassification)) {
truePositives += aucData.getWeight();
} else {
falsePositives += aucData.getWeight();
}
}
//add 0,0 since we won't get it if we always predict 0.0
//add 1,1 since we won't get it if we always predict 0.0
aucPoints.add(getAUCPoint(truePositives, falsePositives, trueNegatives, falseNegatives));
//iterate through each data point updating all points that are changed by the threshold
int startIndex = 0;
Expand All @@ -93,13 +94,24 @@ public ArrayList<AUCPoint> getAUCPointsFromData(List<AUCData> aucDataList) {
probabilityOfNext = aucData.getProbabilityOfPositiveClassification();
}

//now compute the non 0,0 ROC curve points
//now compute the non endpoint ROC curve points
for (int i = startIndex; i< aucDataList.size(); i++) {
//each computed probability of positive classification is used as a threshold (in ascending order
//which maps to the the upper right of the ROC curve.

// At each threshold, we know that at most one data point changed to be classified
// as a negative (and thus know the complete count of TPs FPs, TN, FN at that ROC point

//note, we make the threshold inclusive, in the sense that points are labeled positives if they are
//less than the threshold

AUCData aucData = aucDataList.get(i);
double probability = aucData.getProbabilityOfPositiveClassification();
if (threshold != probability && probability!=0.0) {

//no need to double count
if (thresholdForPositiveClassification != probability && probability!=0.0) {
aucPoints.add(getAUCPoint(truePositives, falsePositives, trueNegatives, falseNegatives));
threshold = probability;
thresholdForPositiveClassification = probability;
}
//point is a positive but with the new threshold, we predict it is negative
if (aucData.getClassification().equals(positiveClassification)) {
Expand All @@ -119,7 +131,7 @@ public ArrayList<AUCPoint> getAUCPointsFromData(List<AUCData> aucDataList) {
if (truePositives !=0 && falsePositives !=0) {
aucPoints.add(getAUCPoint(truePositives, falsePositives, trueNegatives, falseNegatives));
}
// (0,0)
// (1,1)
aucPoints.add(getAUCPoint(0, 0, trueNegatives, falseNegatives));
return aucPoints;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ protected List<String> alternativeGetCandidateAttributesWithIgnoringApplied(Bran
}
int numTrialAttributes = (int)((1.0-ignoreProb)*candidates.size());

//p_of_a_single_colision at i= frac_in_return_set (fi). The expected num collisions at point i is ei.
// ei = p of just 1 collision = p(1-p), prob of just 2 collisions: = (1-p)p^2, n collisions = (1-p)p^n
//e_i = expected num collisions sum(n*p(1-p)^n, n=0, n=N), N=in collection.=> e_ = p/(1-p) *d/dB( sum B^n, n=1, n=N+1), B=1-p
//p/(1-p)*d/dB[(1-B^(N+1))/(1-B)]=
//sum(ei, i=1, num_elements_to_be_returned),

//O(N) way of shuffling the attributes to make all permutations equally likely.
Collections.shuffle(candidates);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public Optional<? extends Branch<VC>> getBranch(Branch<VC> parent, AttributeStat
return Optional.absent();
}
SplittingUtils.SplitScore splitScore = splitScoreOptional.get();
//TODO: make this 2 a hyper-parameter as it is leads to better performance in cases tested, but may not generalize
splitScore.score=splitScore.score*2; //2.3 1.7
//TODO: make a hyper-parameter for alpha on the following line as it is leads to better performance in cases tested, but may not generalize
//splitScore.score=splitScore.score*alpha; //value around 2 often works well.
double bestThreshold = (Double)attributeStats.getStatsOnEachValue().get(splitScore.indexOfLastValueCounterInTrueSet).getAttrVal();
return createBranch(parent, attributeStats, splitScore, bestThreshold);
}
Expand Down

0 comments on commit 102457c

Please sign in to comment.