Skip to content

Commit

Permalink
refactor: verify liquid legion sketch params when computing variances. (
Browse files Browse the repository at this point in the history
  • Loading branch information
ple13 authored Oct 16, 2024
1 parent ccc4113 commit ae488a6
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@ object Covariances {
val overlapReach =
reachMeasurementCovarianceParams.reach + reachMeasurementCovarianceParams.otherReach -
reachMeasurementCovarianceParams.unionReach

val overlapSamplingWidth =
reachMeasurementCovarianceParams.samplingWidth +
reachMeasurementCovarianceParams.otherSamplingWidth -
reachMeasurementCovarianceParams.unionSamplingWidth
require(overlapSamplingWidth >= 0.0 && overlapSamplingWidth <= 1.0) {
"Overlap sampling width must be greater than or equal to 0 and less than or equal to 1, but" +
" got $overlapSamplingWidth."
}

return overlapReach *
(overlapSamplingWidth /
reachMeasurementCovarianceParams.samplingWidth /
Expand All @@ -55,6 +61,7 @@ object Covariances {
sketchParams: LiquidLegionsSketchParams,
reachMeasurementCovarianceParams: ReachMeasurementCovarianceParams,
): Double {
verifyLiquidLegionsSketchParams(sketchParams)
return LiquidLegions.inflatedReachCovariance(
sketchParams = sketchParams,
reach = reachMeasurementCovarianceParams.reach,
Expand Down Expand Up @@ -89,6 +96,10 @@ object Covariances {
otherWeightedMeasurementVarianceParams.measurementVarianceParams.measurementParams
.vidSamplingInterval,
)
require(unionSamplingWidth >= 0.0 && unionSamplingWidth <= 1.0) {
"The union sampling width must be greater than or equal to 0 and less than or equal to 1, " +
"but got $unionSamplingWidth."
}

val liquidLegionsSketchParams =
when (val methodology = weightedMeasurementVarianceParams.methodology) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,21 @@ object HonestMajorityShareShuffle {
reachNoiseVariance: Double,
): Double {
val vidUniverseSize: Long = ceil(frequencyVectorSize / vidSamplingIntervalWidth).toLong()
require(vidUniverseSize > 1) { "Vid universe size is too small." }
require(vidUniverseSize > 1) {
"Vid universe size must be greater than 1, but got $vidUniverseSize."
}

require(vidSamplingIntervalWidth > 0.0 && vidSamplingIntervalWidth <= 1.0) {
"Vid sampling width must be greater than 0 and less than or equal to 1."
"Vid sampling width must be greater than 0 and less than or equal to 1, but got " +
"$vidSamplingIntervalWidth."
}
require(reach <= vidUniverseSize) {
"Reach must be less than or equal to the size of the Vid universe."
"Reach ($reach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}
require(reachNoiseVariance >= 0) {
"Reach noise variance must be a non-negative value, but got $reachNoiseVariance."
}
require(reachNoiseVariance >= 0) { "Reach noise variance must be a non-negative value." }

val reachVariance =
(vidSamplingIntervalWidth *
Expand Down Expand Up @@ -66,14 +72,18 @@ object HonestMajorityShareShuffle {
val vidSamplingIntervalWidth = frequencyMeasurementParams.vidSamplingInterval.width

val vidUniverseSize = ceil(frequencyVectorSize / vidSamplingIntervalWidth).toLong()
require(vidUniverseSize > 1) { "Vid universe size is too small." }
require(vidUniverseSize > 1) {
"Vid universe size must be greater than 1, but got $vidUniverseSize."
}
require(totalReach <= vidUniverseSize) {
"Reach must be less than or equal to the size of the Vid universe."
"Total reach ($totalReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

val kReach = (kReachRatio * totalReach).toLong()
require(kReach <= vidUniverseSize) {
"kReach must be less than or equal to the size of the Vid universe."
"kReach ($kReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

var kReachVariance =
Expand Down Expand Up @@ -113,14 +123,18 @@ object HonestMajorityShareShuffle {
val vidSamplingIntervalWidth = frequencyMeasurementParams.vidSamplingInterval.width

val vidUniverseSize = ceil(frequencyVectorSize / vidSamplingIntervalWidth).toLong()
require(vidUniverseSize > 1) { "Vid universe size is too small." }
require(vidUniverseSize > 1) {
"Vid universe size must be greater than 1, but got $vidUniverseSize."
}
require(totalReach <= vidUniverseSize) {
"Reach must be less than or equal to the size of the Vid universe."
"Total reach ($totalReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

val kPlusReach = (kReachRatio * totalReach).toLong()
require(kPlusReach <= vidUniverseSize) {
"kPlusReach must be less than or equal to the size of the Vid universe."
"kPlusReach ($kPlusReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

// Gets the reach noise variance from the reach measurement variance and total reach.
Expand Down Expand Up @@ -167,14 +181,18 @@ object HonestMajorityShareShuffle {
val vidSamplingIntervalWidth = frequencyMeasurementParams.vidSamplingInterval.width

val vidUniverseSize = ceil(frequencyVectorSize / vidSamplingIntervalWidth).toLong()
require(vidUniverseSize > 1) { "Vid universe size is too small." }
require(vidUniverseSize > 1) {
"Vid universe size must be greater than 1, but got $vidUniverseSize."
}
require(totalReach <= vidUniverseSize) {
"Reach must be less than or equal to the size of the Vid universe."
"Total reach ($totalReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

val kReach = (kReachRatio * totalReach).toLong()
require(kReach <= vidUniverseSize) {
"kReach must be less than or equal to the size of the Vid universe."
"kReach ($kReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

var kReachVariance =
Expand Down Expand Up @@ -235,14 +253,18 @@ object HonestMajorityShareShuffle {
val vidSamplingIntervalWidth = frequencyMeasurementParams.vidSamplingInterval.width

val vidUniverseSize = ceil(frequencyVectorSize / vidSamplingIntervalWidth).toLong()
require(vidUniverseSize > 1) { "Vid universe size is too small." }
require(vidUniverseSize > 1) {
"Vid universe size must be greater than 1, but got $vidUniverseSize."
}
require(totalReach <= vidUniverseSize) {
"Reach must be less than or equal to the size of the Vid universe."
"Total reach ($totalReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

val kPlusReach = (kPlusReachRatio * totalReach).toLong()
require(kPlusReach <= vidUniverseSize) {
"kPlusReach must be less than or equal to the size of the Vid universe."
"kPlusReach ($kPlusReach) must be less than or equal to the size of the Vid universe " +
"($vidUniverseSize)."
}

// Gets the reach noise variance from the reach measurement variance and total reach.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ data class LiquidLegionsSketchParams(val decayRate: Double, val sketchSize: Long
}
}

/** Verifies that the liquid legions sketch params are valid. */
fun verifyLiquidLegionsSketchParams(sketchParams: LiquidLegionsSketchParams) {
require(sketchParams.sketchSize > 0) {
"Sketch size must be positive, but got ${sketchParams.sketchSize}."
}
require(sketchParams.decayRate > 0) {
"Decay rate must be positive, but got ${sketchParams.decayRate}."
}
}

/** Functions to compute statistics of Liquid Legions sketch based measurements. */
object LiquidLegions {
/** Exponential integral function for negative inputs. */
Expand Down Expand Up @@ -90,6 +100,7 @@ object LiquidLegions {
y: Double,
liquidLegionsSketchParams: LiquidLegionsSketchParams,
): Double {
verifyLiquidLegionsSketchParams(liquidLegionsSketchParams)
if (k < 0.0) {
throw IllegalArgumentException("Invalid inputs: k=$k < 0.")
}
Expand Down Expand Up @@ -128,6 +139,21 @@ object LiquidLegions {
overlapSamplingWidth: Double,
inflation: Double = 0.0,
): Double {
verifyLiquidLegionsSketchParams(sketchParams)
require(reach > 0) { "Reach must be positive, but got $reach." }
require(otherReach > 0) { "Other reach must be positive, but got $otherReach." }
require(samplingWidth > 0.0 && samplingWidth <= 1.0) {
"Sampling width must be greater than 0 and less than or equal to 1, but got $samplingWidth."
}
require(otherSamplingWidth > 0.0 && otherSamplingWidth <= 1.0) {
"Other sampling width must be greater than 0 and less than or equal to 1, but got " +
"$otherSamplingWidth."
}
require(overlapSamplingWidth >= 0.0 && overlapSamplingWidth <= 1.0) {
"Other sampling width must be greater than or equal to 0 and less than or equal to 1, but " +
"got $overlapSamplingWidth."
}

val y1 = max(1.0, reach * samplingWidth)
val y2 = max(1.0, otherReach * otherSamplingWidth)
val y12 = max(1.0, overlapReach * overlapSamplingWidth)
Expand Down Expand Up @@ -160,6 +186,12 @@ object LiquidLegions {
totalReach: Long,
vidSamplingIntervalWidth: Double,
): Double {
verifyLiquidLegionsSketchParams(sketchParams)
require(totalReach > 0) { "Total reach must be positive, but got $totalReach." }
require(vidSamplingIntervalWidth > 0.0 && vidSamplingIntervalWidth <= 1.0) {
"Vid sampling width must be greater than 0 and less than or equal to 1, but got " +
"$vidSamplingIntervalWidth."
}
// Expected sampled reach
val expectedReach = totalReach * vidSamplingIntervalWidth
if (expectedReach < 2.0) {
Expand All @@ -179,6 +211,12 @@ object LiquidLegions {
totalReach: Long,
vidSamplingIntervalWidth: Double,
): Double {
verifyLiquidLegionsSketchParams(sketchParams)
require(totalReach > 0) { "Total reach must be positive, but got $totalReach" }
require(vidSamplingIntervalWidth > 0.0 && vidSamplingIntervalWidth <= 1.0) {
"Vid sampling width must be greater than 0 and less than or equal to 1, but got " +
"$vidSamplingIntervalWidth."
}
// Expected sampled reach
val expectedReach = totalReach * vidSamplingIntervalWidth
// The mathematical formulas below assume the sampled reach >= 3. If sampled reach < 3, the
Expand Down Expand Up @@ -226,6 +264,16 @@ object LiquidLegions {
reachRatio: Double,
vidSamplingIntervalWidth: Double,
): Double {
verifyLiquidLegionsSketchParams(sketchParams)
require(totalReach > 0) { "Total reach must be positive, but got $totalReach." }
require(reachRatio >= 0.0 && reachRatio <= 1.0) {
"Reach ratio must be greater than or equal to 0 and less than or equal to 1, but got " +
"$reachRatio."
}
require(vidSamplingIntervalWidth > 0.0 && vidSamplingIntervalWidth <= 1.0) {
"Vid sampling width must be greater than 0 and less than or equal to 1, but got " +
"$vidSamplingIntervalWidth."
}
val tmp = reachRatio * (1.0 - reachRatio) / totalReach
val expectedRegisterNum =
expectedNumberOfNonDestroyedRegisters(
Expand Down Expand Up @@ -255,6 +303,7 @@ object LiquidLegions {
frequencyNoiseVariance: Double,
relativeFrequencyMeasurementVarianceParams: RelativeFrequencyMeasurementVarianceParams,
): Double {
verifyLiquidLegionsSketchParams(sketchParams)
val (
totalReach: Long,
reachMeasurementVariance: Double,
Expand Down
Loading

0 comments on commit ae488a6

Please sign in to comment.