From ddad75bc8173e178187e73fb2e014f871f16ac0a Mon Sep 17 00:00:00 2001 From: Hyerin Park Date: Mon, 25 Mar 2024 06:19:17 +0000 Subject: [PATCH] LibConfigGenerator: add handling when LibConfigDomain is empty --- .../scala/fhetest/Generate/AbsProgram.scala | 17 -- .../Generate/AbsProgramGenerator.scala | 21 +- .../fhetest/Generate/LibConfigGenerator.scala | 208 ++++++++++-------- .../scala/fhetest/Generate/ValidFilter.scala | 19 +- src/main/scala/fhetest/Phase/Generate.scala | 31 --- 5 files changed, 141 insertions(+), 155 deletions(-) diff --git a/src/main/scala/fhetest/Generate/AbsProgram.scala b/src/main/scala/fhetest/Generate/AbsProgram.scala index ccf4395..4e27347 100644 --- a/src/main/scala/fhetest/Generate/AbsProgram.scala +++ b/src/main/scala/fhetest/Generate/AbsProgram.scala @@ -17,23 +17,6 @@ case class AbsProgram( case Mul(_, _) | MulP(_, _) => true; case _ => false } - // TODO: Change these filters to assertions? - // lazy val isValid: Boolean = - // mulDepthIsSmall(mulDepth, encParams.mulDepth) && - // firstModSizeIsLargest(libConfig.firstModSize, libConfig.scalingModSize) && - // modSizeIsUpto60bits(libConfig.firstModSize, libConfig.scalingModSize) && - // openFHEBFVModuli( - // libConfig.scheme, - // libConfig.firstModSize, - // libConfig.scalingModSize, - // ) && - // ringDimIsPowerOfTwo(encParams.ringDim) && - // plainModIsPositive(encParams.plainMod) && - // plainModEnableBatching(encParams.plainMod, encParams.ringDim) && - // lenIsLessThanRingDim(len, encParams.ringDim, libConfig.scheme) && - // boundIsLessThanPowerOfModSize(bound, libConfig.firstModSize) && - // boundIsLessThanPlainMod(bound, encParams.plainMod) - def stringify: String = absStmts.map(_.stringify()).mkString("") def assignRandValues(): AbsProgram = { diff --git a/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala b/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala index e6763ac..4fc5a5c 100644 --- a/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala +++ b/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala @@ -36,18 +36,24 @@ case class ExhaustiveGenerator(encType: ENC_TYPE, validFilter: Boolean) for { stmt <- allAbsStmts libConfigGen <- libConfigGens + stmts = List(stmt) + libConfigOpt = libConfigGen(stmts) + if libConfigOpt.isDefined } yield { - val stmts = List(stmt) - AbsProgram(stmts, libConfigGen(stmts)) + val libConfig = libConfigOpt.get + AbsProgram(stmts, libConfig) } case _ => for { stmt <- allAbsStmts program <- allAbsProgramsOfSize(n - 1) libConfigGen <- libConfigGens + stmts = stmt :: program.absStmts + libConfigOpt = libConfigGen(stmts) + if libConfigOpt.isDefined } yield { - val stmts = stmt :: program.absStmts - AbsProgram(stmts, libConfigGen(stmts)) + val libConfig = libConfigOpt.get + AbsProgram(stmts, libConfig) } } LazyList.from(1).flatMap(allAbsProgramsOfSize) @@ -72,9 +78,12 @@ case class RandomGenerator(encType: ENC_TYPE, validFilter: Boolean) for { len <- randomLength libConfigGen <- libConfigGens + stmts = randomAbsStmtsOfSize(len) + libConfigOpt = libConfigGen(stmts) + if libConfigOpt.isDefined } yield { - val stmts = randomAbsStmtsOfSize(len) - AbsProgram(stmts, libConfigGen(stmts)) + val libConfig = libConfigOpt.get + AbsProgram(stmts, libConfig) } } } diff --git a/src/main/scala/fhetest/Generate/LibConfigGenerator.scala b/src/main/scala/fhetest/Generate/LibConfigGenerator.scala index 2621fa4..07c2298 100644 --- a/src/main/scala/fhetest/Generate/LibConfigGenerator.scala +++ b/src/main/scala/fhetest/Generate/LibConfigGenerator.scala @@ -1,8 +1,9 @@ package fhetest.Generate import fhetest.Utils.* -import scala.util.Random import fhetest.Generate.Utils.combinations +import scala.util.Random +import scala.util.control.Breaks._ val ringDimCandidates: List[Int] = // also in ValidFilter List(8192, 16384, 32768) @@ -38,7 +39,7 @@ def getLibConfigUniverse(scheme: Scheme) = LibConfigDomain( ) trait LibConfigGenerator(encType: ENC_TYPE) { - def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] + def getLibConfigGenerators(): LazyList[List[AbsStmt] => Option[LibConfig]] val validFilters = classOf[ValidFilter].getDeclaredClasses.toList .filter { cls => classOf[ValidFilter] @@ -48,7 +49,7 @@ trait LibConfigGenerator(encType: ENC_TYPE) { case class ValidLibConfigGenerator(encType: ENC_TYPE) extends LibConfigGenerator(encType) { - def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] = { + def getLibConfigGenerators(): LazyList[List[AbsStmt] => Option[LibConfig]] = { val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => { val randomScheme = if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) @@ -83,108 +84,129 @@ case class InvalidLibConfigGenerator(encType: ENC_TYPE) // TODO: currently generate only 1 test case for each class // val numOfTC = 10 val allCombinations_lazy = LazyList.from(allCombinations) - def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] = for { - combination <- allCombinations_lazy - } yield { - val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => { - val randomScheme = - if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) - else Scheme.CKKS - val libConfigUniverse = getLibConfigUniverse(randomScheme) - val filteredLibConfigDomain = validFilters.foldLeft(libConfigUniverse)({ - case (curLibConfigDomain, curValidFilter) => { - val curValidFilterIdx = validFilters.indexOf(curValidFilter) - val inInValid = combination.contains(curValidFilterIdx) - val constructor = curValidFilter.getDeclaredConstructors.head - constructor.setAccessible(true) - val f = constructor - .newInstance(curLibConfigDomain, !inInValid) - .asInstanceOf[ValidFilter] - f.getFilteredLibConfigDomain() + def getLibConfigGenerators(): LazyList[List[AbsStmt] => Option[LibConfig]] = + for { + combination <- allCombinations_lazy + } yield { + println(combination) + val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => { + val randomScheme = + if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) + else Scheme.CKKS + val libConfigUniverse = getLibConfigUniverse(randomScheme) + val filteredLibConfigDomain = validFilters.foldLeft(libConfigUniverse)({ + case (curLibConfigDomain, curValidFilter) => { + val curValidFilterIdx = validFilters.indexOf(curValidFilter) + val inInValid = combination.contains(curValidFilterIdx) + val constructor = curValidFilter.getDeclaredConstructors.head + constructor.setAccessible(true) + val f = constructor + .newInstance(curLibConfigDomain, !inInValid) + .asInstanceOf[ValidFilter] + f.getFilteredLibConfigDomain() + } + }) + val res = randomLibConfigFromDomain( + false, + absStmts, + randomScheme, + filteredLibConfigDomain, + ) + res match { + case None => println("NO DOMAIN") + case Some(_) => () } - }) - randomLibConfigFromDomain( - false, - absStmts, - randomScheme, - filteredLibConfigDomain, - ) + res + } + libConfigGeneratorFromAbsStmts } - libConfigGeneratorFromAbsStmts - } } -// TODO: No handling for empty domain def randomLibConfigFromDomain( validFilter: Boolean, absStmts: List[AbsStmt], randomScheme: Scheme, filteredLibConfigDomain: LibConfigDomain, -): LibConfig = { - val randomRingDim = Random.shuffle(filteredLibConfigDomain.ringDim).head - val randomMulDepth = { - val realMulDepth: Int = absStmts.count { - case Mul(_, _) | MulP(_, _) => true; case _ => false +): Option[LibConfig] = { + var result: Option[LibConfig] = None + breakable { + def getRandomElementOrBreak[T](list: List[T]): T = { + val elem = + if (list.nonEmpty) Some(Random.shuffle(list).head) + else None + elem getOrElse { break } } - Random.shuffle((filteredLibConfigDomain.mulDepth)(realMulDepth)).head - } - val randomPlainMod = - Random.shuffle((filteredLibConfigDomain.plainMod)(randomRingDim)).head - val randomFirstModSize = - Random - .shuffle((filteredLibConfigDomain.firstModSize)(randomScheme)) - .head - val randomScalingModSize = Random - .shuffle( - (filteredLibConfigDomain.scalingModSize)(randomScheme)( + val randomRingDim = getRandomElementOrBreak(filteredLibConfigDomain.ringDim) + val randomMulDepth = { + val realMulDepth: Int = absStmts.count { + case Mul(_, _) | MulP(_, _) => true; case _ => false + } + println(s"realMulDepth: $realMulDepth") + getRandomElementOrBreak( + (filteredLibConfigDomain.mulDepth)(realMulDepth), + ) + } + val randomPlainMod = getRandomElementOrBreak( + (filteredLibConfigDomain.plainMod)(randomRingDim), + ) + val randomFirstModSize = getRandomElementOrBreak( + (filteredLibConfigDomain.firstModSize)(randomScheme), + ) + val randomScalingModSize = getRandomElementOrBreak( + (filteredLibConfigDomain.scalingModSize)(randomScheme)(randomFirstModSize), + ) + val randomSecurityLevel = + getRandomElementOrBreak(filteredLibConfigDomain.securityLevel) + val randomScalingTechnique = getRandomElementOrBreak( + (filteredLibConfigDomain.scalingTechnique)(randomScheme), + ) + val randomLenOpt: Option[Int] = { + val upper = + (filteredLibConfigDomain.lenMax)(randomScheme)(randomRingDim) + val lower = + (filteredLibConfigDomain.lenMin)(randomScheme)(randomRingDim) + if (lower > upper) break + else Some(Random.between(lower, upper + 1)) + } + val randomBoundOpt: Option[Int | Double] = { + val upper = (filteredLibConfigDomain.boundMax)(randomScheme)( + randomPlainMod, + )(randomFirstModSize) + val lower = (filteredLibConfigDomain.boundMin)(randomScheme)( + randomPlainMod, + )(randomFirstModSize) + lower match { + case li: Int => + upper match { + case ui: Int => + if (li > ui) break else Some(Random.between(li, ui + 1)) + case _ => Some(Random.between(1, 100000 + 1)) // unreachable + } + case ld: Double => + upper match { + case ud: Int => + if (ld > ud) break else Some(Random.between(ld, ud)) + case _ => Some(Random.between(1, math.pow(2, 64))) // unreachable + } + } + } + val randomRotateBoundOpt: Option[Int] = + val r = getRandomElementOrBreak(filteredLibConfigDomain.rotateBound) + Some(r) + + result = Some( + LibConfig( + randomScheme, + EncParams(randomRingDim, randomMulDepth, randomPlainMod), randomFirstModSize, + randomScalingModSize, + randomSecurityLevel, + randomScalingTechnique, + randomLenOpt, + randomBoundOpt, + randomRotateBoundOpt, ), ) - .head - val randomSecurityLevel = - Random.shuffle(filteredLibConfigDomain.securityLevel).head - val randomScalingTechnique = Random - .shuffle((filteredLibConfigDomain.scalingTechnique)(randomScheme)) - .head - val randomLenOpt: Option[Int] = { - val upper = - (filteredLibConfigDomain.lenMax)(randomScheme)(randomRingDim) - val lower = - (filteredLibConfigDomain.lenMin)(randomScheme)(randomRingDim) - Some(Random.between(lower, upper + 1)) } - val randomBoundOpt: Option[Int | Double] = { - val upper = (filteredLibConfigDomain.boundMax)(randomScheme)( - randomPlainMod, - )(randomFirstModSize) - val lower = (filteredLibConfigDomain.boundMin)(randomScheme)( - randomPlainMod, - )(randomFirstModSize) - lower match { - case li: Int => - upper match { - case ui: Int => Some(Random.between(li, ui + 1)) - case _ => Some(Random.between(1, 100000 + 1)) // unreachable - } - case ld: Double => - upper match { - case ud: Int => Some(Random.between(ld, ud)) - case _ => Some(Random.between(1, math.pow(2, 64))) // unreachable - } - } - } - val randomRotateBoundOpt: Option[Int] = - Some(Random.shuffle(filteredLibConfigDomain.rotateBound).head) - - LibConfig( - randomScheme, - EncParams(randomRingDim, randomMulDepth, randomPlainMod), - randomFirstModSize, - randomScalingModSize, - randomSecurityLevel, - randomScalingTechnique, - randomLenOpt, - randomBoundOpt, - randomRotateBoundOpt, - ) + result } diff --git a/src/main/scala/fhetest/Generate/ValidFilter.scala b/src/main/scala/fhetest/Generate/ValidFilter.scala index ef54367..24a0a5c 100644 --- a/src/main/scala/fhetest/Generate/ValidFilter.scala +++ b/src/main/scala/fhetest/Generate/ValidFilter.scala @@ -10,17 +10,17 @@ import fhetest.Checker.schemeDecoder // * defined & used in LibConfigGenerator // * automatically arranged in alphabetical order // val validFilters = List( -// FilterBoundIsLessThanPlainMod, -// FilterBoundIsLessThanPowerOfModSize, -// FilterFirstModSizeIsLargest, -// FilterLenIsLessThanRingDim, -// FilterModSizeIsBeteween14And60bits, -// FilterMulDepthIsEnough, -// FilterOpenFHEBFVModuli, +// FilterBoundIsLessThanPlainMod, // 0 +// FilterBoundIsLessThanPowerOfModSize, // 1 +// FilterFirstModSizeIsLargest, // 2 +// FilterLenIsLessThanRingDim, // 3 +// FilterModSizeIsBeteween14And60bits, // 4 +// FilterMulDepthIsEnough, // 5 +// FilterOpenFHEBFVModuli, // 6 // FilterPlainModEnableBatching, /* commented */ // FilterPlainModIsPositive, /* commented */ // FilterRingDimIsPowerOfTwo, /* commented */ -// FilterScalingTechniqueByScheme +// FilterScalingTechniqueByScheme // 7 // ) trait ValidFilter(prev: LibConfigDomain, validFilter: Boolean) { @@ -62,6 +62,9 @@ object ValidFilter { ) } + // TODO: There are 2 options for this implementation + // * Current implementation filters scalingModSize which is not greater than firstModSize + // * Another option is to filter firstModSize to be not smaller than scalingModeSize // def firstModSizeIsLargest(firstModSize: Int, scalingModSize: Int): Boolean = // scalingModSize <= firstModSize case class FilterFirstModSizeIsLargest( diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index 50a8503..ada9e39 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -46,38 +46,7 @@ case class Generate( val adjusted = assigned.adjustScale(encType) adjusted } - val resultAbsPrograms: LazyList[AbsProgram] = adjustedAbsPrograms - // val resultAbsPrograms: LazyList[AbsProgram] = if (validFilter) { - // adjustedAbsPrograms.filter(_.isValid) - // } else { - // // val numOfValidFilter = 10 - // // val programsWithEquivClasses: LazyList[(AbsProgram, List[Boolean])] = - // // adjustedAbsPrograms.map({ pgm => - // // (pgm, pgm.getInvalidEquivClassList()) - // // }) - // // def filterSequencially( - // // absPrograms: LazyList[(AbsProgram, List[Boolean])], - // // idx: Int, - // // ): LazyList[AbsProgram] = - // // if (absPrograms.isEmpty) - // // LazyList.empty // unreachable - // // else if (idx == numOfValidFilter) filterSequencially(absPrograms, 0) - // // else { - // // val (pgm, equivClassList) = absPrograms.head - // // val equivClass = equivClassList.apply(idx) - // // if (equivClass) - // // pgm #:: filterSequencially(absPrograms.tail, idx + 1) - // // else filterSequencially(absPrograms, idx + 1) - // // } - // // filterSequencially(programsWithEquivClasses, 0) - - // val equivClassIdx = LazyList.from(0) - // adjustedAbsPrograms - // .zip(equivClassIdx) - // .filter { case (pgm, idx) => pgm.invalidEquivClass(idx) } - // .map(_._1) - // } val takenResultAbsPrograms = nOpt match { case Some(n) => resultAbsPrograms.take(n) case None => resultAbsPrograms