From e856a5e526a3c90e06556266158bad5e616bd14d Mon Sep 17 00:00:00 2001 From: MrPowers Date: Thu, 29 Nov 2018 11:53:48 -0800 Subject: [PATCH] Don't import DataFrame extensions in the transformations object --- .../spark/daria/sql/transformations.scala | 17 ++++++++++------- .../spark/daria/sql/TransformationsTest.scala | 5 ++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/transformations.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/transformations.scala index 2c048153..6f711015 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/transformations.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/transformations.scala @@ -1,6 +1,5 @@ package com.github.mrpowers.spark.daria.sql -import com.github.mrpowers.spark.daria.sql.DataFrameExt._ import com.github.mrpowers.spark.daria.sql.functions.truncate import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -48,7 +47,7 @@ object transformations { * }}} */ def sortColumns(order: String = "asc")(df: DataFrame): DataFrame = { - val cols = if (order == "asc") { + val colNames = if (order == "asc") { df.columns.sorted } else if (order == "desc") { df.columns.sorted.reverse @@ -57,7 +56,8 @@ object transformations { s"The sort order must be 'asc' or 'desc'. Your sort order was '$order'." throw new InvalidColumnSortOrderException(message) } - df.reorderColumns(cols) + val cols = colNames.map(col(_)) + df.select(cols: _*) } /** @@ -96,9 +96,12 @@ object transformations { * Example: SomeColumn -> some_column */ def camelCaseToSnakeCaseColumns()(df: DataFrame): DataFrame = - df.renameColumns( - com.github.mrpowers.spark.daria.utils.StringHelpers.camelCaseToSnakeCase - ) + df.columns.foldLeft(df) { (memoDF, colName) => + memoDF.withColumnRenamed( + colName, + com.github.mrpowers.spark.daria.utils.StringHelpers.camelCaseToSnakeCase(colName) + ) + } /** * Title Cases all the columns of a DataFrame @@ -199,7 +202,7 @@ object transformations { def truncateColumns(columnLengths: Map[String, Int])(df: DataFrame): DataFrame = { columnLengths.foldLeft(df) { case (memoDF, (colName, length)) => - if (memoDF.containsColumn(colName)) { + if (memoDF.schema.fieldNames.contains(colName)) { memoDF.withColumn( colName, truncate( diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/TransformationsTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/TransformationsTest.scala index 8c3e804c..d52b1904 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/TransformationsTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/TransformationsTest.scala @@ -1,13 +1,12 @@ package com.github.mrpowers.spark.daria.sql +import utest._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import utest._ +import org.apache.spark.sql.Row import com.github.mrpowers.spark.fast.tests.DataFrameComparer import com.github.mrpowers.spark.fast.tests.ColumnComparer import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ -import com.github.mrpowers.spark.daria.sql.DataFrameExt._ -import org.apache.spark.sql.Row object TransformationsTest extends TestSuite with DataFrameComparer with ColumnComparer with SparkSessionTestWrapper {