diff --git a/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala b/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala index d4322e5..2445065 100644 --- a/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala +++ b/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala @@ -8,9 +8,11 @@ import org.apache.spark.util.Utils object functions { private def withExpr(expr: Expression): Column = Column(expr) - def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random") - def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale) - def randGamma(): Column = randGamma(1.0, 1.0) + def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random") + def randGamma(seed: Column, shape: Column, scale: Column): Column = withExpr(RandGamma(seed.expr, shape.expr, scale.expr)).alias("gamma_random") + def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale) + def randGamma(shape: Column, scale: Column): Column = randGamma(lit(Utils.random.nextLong), shape, scale) + def randGamma(): Column = randGamma(1.0, 1.0) def randLaplace(seed: Long, mu: Double, beta: Double): Column = { val mu_ = lit(mu) diff --git a/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala b/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala index 5147c13..b529c1e 100644 --- a/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala +++ b/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala @@ -27,6 +27,28 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar assert(math.abs(gammaMean - 4.0) < 0.5) assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5) } + + "has correct mean and standard deviation from shape/scale column" - { + val sourceDF = spark + .range(100000) + 
.withColumn("shape", lit(2.0)) + .withColumn("scale", lit(2.0)) + .select(randGamma(col("shape"), col("scale"))) + val stats = sourceDF + .agg( + mean("gamma_random").as("mean"), + stddev("gamma_random").as("stddev") + ) + .collect()(0) + + val gammaMean = stats.getAs[Double]("mean") + val gammaStddev = stats.getAs[Double]("stddev") + + // Gamma distribution with shape=2.0 and scale=2.0 has mean=4.0 and stddev=sqrt(8.0) + assert(gammaMean > 0) + assert(math.abs(gammaMean - 4.0) < 0.5) + assert(math.abs(gammaStddev - math.sqrt(8.0)) < 0.5) + } } 'rand_laplace - {