Skip to content

Commit

Permalink
Remove some override
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan committed Sep 26, 2024
1 parent 50cbdfe commit 268abb2
Showing 1 changed file with 13 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,21 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
case FloatType | DoubleType => scale.eval().asInstanceOf[Double]
}

override protected def initializeInternal(partitionIndex: Int): Unit = {
@transient private var distribution: GammaDistribution = _

protected def initializeInternal(partitionIndex: Int): Unit = {
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}
@transient private var distribution: GammaDistribution = _

def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

override def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)
def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)

override protected def evalInternal(input: InternalRow): Double = distribution.sample()
protected def evalInternal(input: InternalRow): Double = distribution.sample()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val distributionClassName = classOf[GammaDistribution].getName
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
Expand All @@ -61,25 +62,25 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
}

override def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)
def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)

override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)

override def prettyName: String = "rand_gamma"

override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"

override def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)
def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)

override def dataType: DataType = DoubleType
def dataType: DataType = DoubleType

override def first: Expression = child
def first: Expression = child

override def second: Expression = shape
def second: Expression = shape

override def third: Expression = scale
def third: Expression = scale

override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(child = newFirst, shape = newSecond, scale = newThird)
}

Expand Down

0 comments on commit 268abb2

Please sign in to comment.