[SEDONA-710] Rename Geostats SQL classes to generic name; merge UdfRegistrator into AbstractCatalog (#1809)

Co-authored-by: jameswillis <[email protected]>
james-willis and jameswillis authored Feb 13, 2025
1 parent ae4b1c2 commit 3a3b8d3
Showing 12 changed files with 305 additions and 283 deletions.
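The renames below surface through `SedonaContext.create`, which remains the public entry point. A minimal sketch of picking up the registered functions after this change (assuming a local-mode session; `ST_Point` is one of the expressions the catalog registers):

```scala
import org.apache.sedona.spark.SedonaContext

// create() injects the optimizer rule, the physical-function strategy, the UDTs,
// and the SQL functions (now via Catalog.registerAll instead of UdfRegistrator).
val config = SedonaContext.builder().master("local[*]").getOrCreate()
val sedona = SedonaContext.create(config)
sedona.sql("SELECT ST_Point(1.0, 2.0) AS geom").show()
```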
@@ -21,12 +21,12 @@ package org.apache.sedona.spark
 import org.apache.sedona.common.utils.TelemetryCollector
 import org.apache.sedona.core.serde.SedonaKryoRegistrator
 import org.apache.sedona.sql.RasterRegistrator
-import org.apache.sedona.sql.UDF.UdfRegistrator
+import org.apache.sedona.sql.UDF.Catalog
 import org.apache.sedona.sql.UDT.UdtRegistrator
 import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.sql.sedona_sql.optimization.{ExtractGeoStatsFunctions, SpatialFilterPushDownForGeoParquet, SpatialTemporalFilterPushDownForStacScan}
-import org.apache.spark.sql.sedona_sql.strategy.geostats.EvalGeoStatsFunctionStrategy
+import org.apache.spark.sql.sedona_sql.optimization._
 import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
+import org.apache.spark.sql.sedona_sql.strategy.physical.function.EvalPhysicalFunctionStrategy
 import org.apache.spark.sql.{SQLContext, SparkSession}
 
 import scala.annotation.StaticAnnotation
@@ -73,20 +73,21 @@ object SedonaContext
       }
     }
 
-    // Support geostats functions
-    if (!sparkSession.experimental.extraOptimizations.contains(ExtractGeoStatsFunctions)) {
-      sparkSession.experimental.extraOptimizations ++= Seq(ExtractGeoStatsFunctions)
+    // Support physical functions
+    if (!sparkSession.experimental.extraOptimizations.contains(ExtractPhysicalFunctions)) {
+      sparkSession.experimental.extraOptimizations ++= Seq(ExtractPhysicalFunctions)
     }
 
     if (!sparkSession.experimental.extraStrategies.exists(
-          _.isInstanceOf[EvalGeoStatsFunctionStrategy])) {
+          _.isInstanceOf[EvalPhysicalFunctionStrategy])) {
       sparkSession.experimental.extraStrategies ++= Seq(
-        new EvalGeoStatsFunctionStrategy(sparkSession))
+        new EvalPhysicalFunctionStrategy(sparkSession))
     }
 
    addGeoParquetToSupportNestedFilterSources(sparkSession)
    RasterRegistrator.registerAll(sparkSession)
    UdtRegistrator.registerAll()
-    UdfRegistrator.registerAll(sparkSession)
+    Catalog.registerAll(sparkSession)
    sparkSession
  }
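The guarded `++=` above is the standard idempotent way to inject rules through Spark's `experimental` hooks, so renaming the rule and strategy does not change the mechanism. A minimal sketch of the same pattern with a hypothetical stand-in rule (`NoOpRule` is illustrative, not part of Sedona or Spark):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Stand-in for ExtractPhysicalFunctions: a rule that leaves the plan untouched.
object NoOpRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// Guard against double-registration, mirroring create() above.
if (!spark.experimental.extraOptimizations.contains(NoOpRule)) {
  spark.experimental.extraOptimizations ++= Seq(NoOpRule)
}
```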

@@ -18,6 +18,7 @@
  */
 package org.apache.sedona.sql.UDF
 
+import org.apache.spark.sql.{SQLContext, SparkSession, functions}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
@@ -74,4 +75,28 @@ abstract class AbstractCatalog {
 
     (functionIdentifier, expressionInfo, functionBuilder)
   }
+
+  def registerAll(sqlContext: SQLContext): Unit = {
+    registerAll(sqlContext.sparkSession)
+  }
+
+  def registerAll(sparkSession: SparkSession): Unit = {
+    Catalog.expressions.foreach { case (functionIdentifier, expressionInfo, functionBuilder) =>
+      sparkSession.sessionState.functionRegistry.registerFunction(
+        functionIdentifier,
+        expressionInfo,
+        functionBuilder)
+    }
+    Catalog.aggregateExpressions.foreach(f =>
+      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)))
+  }
+
+  def dropAll(sparkSession: SparkSession): Unit = {
+    Catalog.expressions.foreach { case (functionIdentifier, _, _) =>
+      sparkSession.sessionState.functionRegistry.dropFunction(functionIdentifier)
+    }
+    Catalog.aggregateExpressions.foreach(f =>
+      sparkSession.sessionState.functionRegistry.dropFunction(
+        FunctionIdentifier(f.getClass.getSimpleName)))
+  }
 }
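Usage of the merged entry points is unchanged apart from the object name; a short sketch (assuming Sedona is on the classpath and the session is otherwise configured). Note both methods iterate `Catalog.expressions` rather than the abstract `expressions` member, preserving the behavior of the deleted `UdfRegistrator`:

```scala
import org.apache.sedona.sql.UDF.Catalog
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
Catalog.registerAll(spark) // registers scalar and aggregate functions
spark.sql("SELECT ST_Point(1.0, 2.0)").show()
Catalog.dropAll(spark) // removes them from the session's FunctionRegistry again
```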
@@ -19,14 +19,12 @@
 package org.apache.sedona.sql.UDF
 
 import org.apache.spark.sql.expressions.Aggregator
-import org.apache.spark.sql.sedona_sql.expressions.{ST_InterpolatePoint, _}
 import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
 import org.apache.spark.sql.sedona_sql.expressions.raster._
+import org.apache.spark.sql.sedona_sql.expressions._
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.operation.buffer.BufferParameters
 
-import scala.collection.mutable.ListBuffer
-
 object Catalog extends AbstractCatalog {
 
   override val expressions: Seq[FunctionDescription] = Seq(
@@ -344,9 +342,5 @@ object Catalog extends AbstractCatalog {
     function[ST_WeightedDistanceBandColumn]())
 
   val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
-    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)
-
-  // Aggregate functions with List as buffer
-  val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], Geometry]] =
-    Seq(new ST_Union_Aggr())
+    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr())
 }
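The separate `aggregateExpressions2` list existed only to carry a different buffer type; it could be folded in because `functions.udaf` (used by `registerAll`) accepts an `Aggregator` with any buffer encoder. A self-contained sketch of that mechanism with a hypothetical aggregator (not a Sedona class):

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator
import scala.collection.mutable.ListBuffer

// Mirrors the shape of ST_Union_Aggr: the buffer is a ListBuffer, not the output type.
class CollectThenSum extends Aggregator[Double, ListBuffer[Double], Double] {
  def zero: ListBuffer[Double] = ListBuffer.empty
  def reduce(buf: ListBuffer[Double], x: Double): ListBuffer[Double] = buf += x
  def merge(a: ListBuffer[Double], b: ListBuffer[Double]): ListBuffer[Double] = a ++= b
  def finish(buf: ListBuffer[Double]): Double = buf.sum
  def bufferEncoder: Encoder[ListBuffer[Double]] = Encoders.kryo[ListBuffer[Double]]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// functions.udaf does not care about the buffer type, so one Seq can hold all aggregators.
spark.udf.register("collect_then_sum", functions.udaf(new CollectThenSum))
```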

This file was deleted.

@@ -20,7 +20,7 @@ package org.apache.sedona.sql.utils
 
 import org.apache.sedona.spark.SedonaContext
 import org.apache.sedona.sql.RasterRegistrator
-import org.apache.sedona.sql.UDF.UdfRegistrator
+import org.apache.sedona.sql.UDF.Catalog
 import org.apache.spark.sql.{SQLContext, SparkSession}
 
 @deprecated("Use SedonaContext instead", "1.4.1")
@@ -44,7 +44,7 @@ object SedonaSQLRegistrator {
     SedonaContext.create(sparkSession, language)
 
   def dropAll(sparkSession: SparkSession): Unit = {
-    UdfRegistrator.dropAll(sparkSession)
+    Catalog.dropAll(sparkSession)
     RasterRegistrator.dropAll(sparkSession)
   }
 }
@@ -23,80 +23,13 @@ import org.apache.sedona.stats.Weighting.{addBinaryDistanceBandColumn, addWeight
 import org.apache.sedona.stats.clustering.DBSCAN.dbscan
 import org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal
 import org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ImplicitCastInputTypes, Literal, ScalarSubquery, Unevaluable}
-import org.apache.spark.sql.execution.{LogicalRDD, SparkPlan}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.functions.{col, struct}
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-
-import scala.reflect.ClassTag
-
-// We mark ST_GeoStatsFunction as non-deterministic to avoid the filter push-down optimization pass
-// duplicates the ST_GeoStatsFunction when pushing down aliased ST_GeoStatsFunction through a
-// Project operator. This will make ST_GeoStatsFunction being evaluated twice.
-trait ST_GeoStatsFunction
-    extends Expression
-    with ImplicitCastInputTypes
-    with Unevaluable
-    with Serializable {
-
-  final override lazy val deterministic: Boolean = false
-
-  override def nullable: Boolean = true
-
-  private final lazy val sparkSession = SparkSession.getActiveSession.get
-
-  protected final lazy val geometryColumnName = getInputName(0, "geometry")
-
-  protected def getInputName(i: Int, fieldName: String): String = children(i) match {
-    case ref: AttributeReference => ref.name
-    case _ =>
-      throw new IllegalArgumentException(
-        f"$fieldName argument must be a named reference to an existing column")
-  }
-
-  protected def getInputNames(i: Int, fieldName: String): Seq[String] = children(
-    i).dataType match {
-    case StructType(fields) => fields.map(_.name)
-    case _ => throw new IllegalArgumentException(f"$fieldName argument must be a struct")
-  }
-
-  protected def getResultName(resultAttrs: Seq[Attribute]): String = resultAttrs match {
-    case Seq(attr) => attr.name
-    case _ => throw new IllegalArgumentException("resultAttrs must have exactly one attribute")
-  }
-
-  protected def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame
-
-  protected def getScalarValue[T](i: Int, name: String)(implicit ct: ClassTag[T]): T = {
-    children(i) match {
-      case Literal(l: T, _) => l
-      case _: Literal =>
-        throw new IllegalArgumentException(f"$name must be an instance of ${ct.runtimeClass}")
-      case s: ScalarSubquery =>
-        s.eval() match {
-          case t: T => t
-          case _ =>
-            throw new IllegalArgumentException(
-              f"$name must be an instance of ${ct.runtimeClass}")
-        }
-      case _ => throw new IllegalArgumentException(f"$name must be a scalar value")
-    }
-  }
-
-  def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow] = {
-    val df = doExecute(
-      Dataset.ofRows(sparkSession, LogicalRDD(plan.output, plan.execute())(sparkSession)),
-      resultAttrs)
-    df.queryExecution.toRdd
-  }
-
-}
-
-case class ST_DBSCAN(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_DBSCAN(children: Seq[Expression]) extends DataframePhysicalFunction {
 
   override def dataType: DataType = StructType(
     Seq(StructField("isCore", BooleanType), StructField("cluster", LongType)))
@@ -107,7 +40,9 @@ case class ST_DBSCAN(children: Seq[Expression]) extends ST_GeoStatsFunction {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     require(
       !dataframe.columns.contains("__isCore"),
       "__isCore is a reserved name by the dbscan algorithm. Please rename the columns before calling the ST_DBSCAN function.")
@@ -129,7 +64,7 @@ case class ST_DBSCAN(children: Seq[Expression]) extends ST_GeoStatsFunction {
   }
 }
 
-case class ST_LocalOutlierFactor(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_LocalOutlierFactor(children: Seq[Expression]) extends DataframePhysicalFunction {
 
   override def dataType: DataType = DoubleType
 
@@ -139,7 +74,9 @@ case class ST_LocalOutlierFactor(children: Seq[Expression]) extends ST_GeoStatsFunction {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     localOutlierFactor(
       dataframe,
       getScalarValue[Int](1, "k"),
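`ST_LocalOutlierFactor` returns a plain double; `k` arrives as children(1) and is read with `getScalarValue`, so it must be a literal or scalar subquery. A hedged sketch (the trailing boolean is assumed to be the `useSpheroid` flag, per the docs):

```scala
// Same assumptions as the ST_DBSCAN sketch above.
val scored = sedona.sql(
  "SELECT id, ST_LocalOutlierFactor(geometry, 20, false) AS lof FROM points")
scored.orderBy(org.apache.spark.sql.functions.desc("lof")).show()
```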
@@ -150,7 +87,7 @@ case class ST_LocalOutlierFactor(children: Seq[Expression]) extends ST_GeoStatsFunction {
   }
 }
 
-case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_GLocal(children: Seq[Expression]) extends DataframePhysicalFunction {
 
   override def dataType: DataType = StructType(
     Seq(
@@ -172,7 +109,9 @@ case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     gLocal(
       dataframe,
       getInputName(0, "x"),
@@ -187,7 +126,8 @@ case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
   }
 }
 
-case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_BinaryDistanceBandColumn(children: Seq[Expression])
+    extends DataframePhysicalFunction {
   override def dataType: DataType = ArrayType(
     StructType(
       Seq(StructField("neighbor", children(5).dataType), StructField("value", DoubleType))))
@@ -198,7 +138,9 @@ case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends ST_GeoStatsFunction {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     val attributeNames = getInputNames(5, "attributes")
     require(attributeNames.nonEmpty, "attributes must have at least one column")
     require(
@@ -217,7 +159,8 @@ case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends ST_GeoStatsFunction {
   }
 }
 
-case class ST_WeightedDistanceBandColumn(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_WeightedDistanceBandColumn(children: Seq[Expression])
+    extends DataframePhysicalFunction {
 
   override def dataType: DataType = ArrayType(
     StructType(
@@ -237,7 +180,9 @@ case class ST_WeightedDistanceBandColumn(children: Seq[Expression]) extends ST_GeoStatsFunction {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     val attributeNames = getInputNames(7, "attributes")
     require(attributeNames.nonEmpty, "attributes must have at least one column")
     require(
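The two distance-band functions produce the array of `(neighbor, value)` structs that `ST_GLocal` consumes as its weights; `attributes` is children(5) here (children(7) for the weighted variant), hence the struct passed as the sixth argument. A hedged end-to-end sketch (parameter order taken from the Sedona docs; `tbl`, `x`, and the thresholds are assumptions):

```scala
// Build binary distance-band weights, then compute the local Getis-Ord statistic.
val gi = sedona.sql("""
  SELECT
    id,
    ST_GLocal(
      x,
      ST_BinaryDistanceBandColumn(geometry, 1.0, true, true, false, struct(x)),
      false) AS g
  FROM tbl""")
gi.show()
```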