diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 495078af..44ad7e0f 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -1,9 +1,12 @@ package com.github.mrpowers.spark.daria.sql.types import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.StructType import org.apache.spark.sql.functions._ +import scala.reflect.runtime.universe._ + object StructTypeHelpers { def flattenSchema(schema: StructType, delimiter: String = ".", prefix: String = null): Array[Column] = { @@ -27,4 +30,19 @@ object StructTypeHelpers { }) } + /** + * gets a StructType from a Scala type and + * transforms field names from camel case to snake case + */ + def schemaFor[T: TypeTag]: StructType = { + val struct = ScalaReflection.schemaFor[T] + .dataType.asInstanceOf[StructType] + + struct.copy(fields = + struct.fields.map { field => + field.copy(name = com.github.mrpowers.spark.daria.utils.StringHelpers.camelCaseToSnakeCase(field.name)) + } + ) + } + } diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpersTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpersTest.scala index 387c994d..0c3b7973 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpersTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpersTest.scala @@ -71,6 +71,20 @@ object StructTypeHelpersTest extends TestSuite { } + 'schemaFor - { + "gets schema from a scala Type" - { + val actualSchema = StructTypeHelpers.schemaFor[FooBar] + val expectedSchema = StructType(List( + StructField("foo", IntegerType, false), + StructField("bar", StringType), + StructField("foo_bar", ArrayType(IntegerType, false)) + )) + + assert(actualSchema == expectedSchema) + } + } + } + case class FooBar(foo: Int, bar: String, fooBar: Array[Int]) }