feat: [+] decimal round with mode & (array) mkString functions
eruizalo committed May 9, 2023
1 parent 8eef33d commit 9d00961
Showing 5 changed files with 171 additions and 52 deletions.
39 changes: 39 additions & 0 deletions core/src/main/scala/doric/syntax/ArrayColumns.scala
@@ -77,6 +77,45 @@ protected trait ArrayColumns {
def list[T](cols: DoricColumn[T]*): DoricColumn[List[T]] =
cols.toList.traverse(_.elem).map(f.array(_: _*)).toDC

/**
* Extension methods for arrays of strings
*
* @group Array Type
*/
implicit class ArrayStringColumnSyntax[F[_]: CollectionType](
private val col: DoricColumn[F[String]]
) {

/**
* Concatenates each string element with the separator column into a new string column
*
* @note even if the array contains null elements, the remaining strings are
* concatenated (a null or empty array yields an empty string)
* @note if the separator column is null, the result will be null
* @example {{{
* df.withColumn("res", colArrayString("col1").mkString(colString("col2")))
* .show(false)
* +------+----+----+
* |col1 |col2|res |
* +------+----+----+
* |[a, b]|, |a,b |
* |[a, b]||| |a||b|
* |[a] |, |a |
* |[] |, | |
* |[a, b]|null|null|
* |null |, | |
* |null |null|null|
* +------+----+----+
* }}}
* @group Array Type
* @see [[org.apache.spark.sql.functions.concat_ws]]
*/
def mkString(sep: StringColumn): StringColumn =
(col.elem, sep.elem)
.mapN((c, s) => new Column(ConcatWs(Seq(s.expr, c.expr))))
.toDC

}

/**
* Extension methods for arrays
*
(diff truncated)
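For context, a minimal usage sketch of the new `mkString`, mirroring the scaladoc example above (the DataFrame and the active `spark` session are assumptions, not part of this commit; `.lit` is the existing doric literal syntax also used in the test below):

```scala
import doric._
import spark.implicits._ // assumes an active SparkSession named `spark`

val df = List((Array("a", "b"), ",")).toDF("col1", "col2")

// Per-row separator taken from another string column
df.withColumn("res", colArrayString("col1").mkString(colString("col2")))

// Constant separator via a literal string column
df.withColumn("res2", colArrayString("col1").mkString(",".lit))
```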
32 changes: 30 additions & 2 deletions core/src/main/scala/doric/syntax/NumericColumns.scala
@@ -4,9 +4,11 @@ package syntax
import cats.implicits._
import doric.DoricColumn.sparkFunction
import doric.types.{CollectionType, NumericType}
import org.apache.spark.sql.catalyst.expressions.{BRound, FormatNumber, FromUnixTime, Rand, Randn, Round, UnaryMinus}
import org.apache.spark.sql.catalyst.expressions.{BRound, Expression, FormatNumber, FromUnixTime, Rand, Randn, Round, RoundBase, UnaryMinus}
import org.apache.spark.sql.{Column, functions => f}

import scala.math.BigDecimal.RoundingMode.RoundingMode

protected trait NumericColumns {

/**
@@ -586,7 +588,8 @@ protected trait NumericColumns {
def round: DoricColumn[T] = column.elem.map(f.round).toDC

/**
* Returns the value of the column e rounded to 0 decimal places with HALF_UP round mode.
* Round the value to `scale` decimal places with HALF_UP round mode
* if `scale` is greater than or equal to 0, or at the integral part when `scale` is less than 0.
*
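* @example an illustrative sketch (not part of this commit): {{{ colDouble("x").round(1.lit) }}}
* rounds 1.25 to 1.3, while a negative scale such as {{{ (-1).lit }}} rounds 15.0 to 20.0.
*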
* @todo decimal type
* @group Numeric Type
@@ -596,6 +599,31 @@
.mapN((c, s) => new Column(Round(c.expr, s.expr)))
.toDC

/**
* DORIC EXCLUSIVE! Round the value to `scale` decimal places with the given round `mode`
* if `scale` is greater than or equal to 0, or at the integral part when `scale` is less than 0.
*
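* @example an illustrative sketch (assumed usage, not shown in this diff):
* {{{
*   df.withColumn("r", colDouble("price").round(2.lit, RoundingMode.CEILING))
* }}}
*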
* @todo decimal type
* @group Numeric Type
*/
def round(scale: IntegerColumn, mode: RoundingMode): DoricColumn[T] = {
case class DoricRound(
child: Expression,
scale: Expression,
mode: RoundingMode
) extends RoundBase(child, scale, mode, s"ROUND_$mode") {
override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression
): DoricRound =
copy(child = newLeft, scale = newRight)
}

(column.elem, scale.elem)
.mapN((c, s) => new Column(DoricRound(c.expr, s.expr, mode)))
.toDC
}

/**
* Returns col1 if it is not NaN, or col2 if col1 is NaN.
*
(diff truncated)
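For context, a minimal usage sketch of the doric-exclusive rounding (the DataFrame, column names, and expected values in the comments are illustrative, not taken from this commit):

```scala
import doric._
import org.apache.spark.sql.SparkSession

import scala.math.BigDecimal.RoundingMode

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = List(2.345, -2.345).toDF("price")

df.withColumn("ceil2", colDouble("price").round(2.lit, RoundingMode.CEILING))
  .withColumn("floor2", colDouble("price").round(2.lit, RoundingMode.FLOOR))
  .show()
// 2.345  -> ceil2 =  2.35, floor2 =  2.34
// -2.345 -> ceil2 = -2.34, floor2 = -2.35
```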
17 changes: 17 additions & 0 deletions core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala
@@ -15,6 +15,23 @@ class ArrayColumnsSpec extends DoricTestElements {

import spark.implicits._

describe("mkString doric function") {
it("should concat be equivalent to concat_ws spark function") {
val df = List(
Array("a", "b"),
Array("a"),
Array.empty[String],
null
).toDF("col1")

df.testColumns2("col1", ",")(
(c, sep) => colArrayString(c).mkString(sep.lit),
(c, sep) => f.concat_ws(sep, f.col(c)),
List(Some("a,b"), Some("a"), Some(""), Some(""))
)
}
}

describe("ArrayOps") {
val result = "result"
val testColumn = "col"
(diff truncated)
(2 more changed files not shown)
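The doric-exclusive `round` has no one-to-one Spark counterpart, so a spec for it would assert against expected literals rather than mirroring a Spark function as the `mkString` test does. A sketch of what such a test could look like (the select-and-collect style and the spec file are assumptions, not shown in this diff):

```scala
import scala.math.BigDecimal.RoundingMode

describe("round with mode doric function") {
  it("should round according to the given mode") {
    val df = List(Some(1.25), Some(-1.25), None).toDF("col1")

    // FLOOR at scale 1: 1.25 -> 1.2, -1.25 -> -1.3, null stays null
    df.select(colDouble("col1").round(1.lit, RoundingMode.FLOOR))
      .as[Option[Double]]
      .collect()
      .toList shouldBe List(Some(1.2), Some(-1.3), None)
  }
}
```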
