[SEDONA-467] Add Optimized join support for ST_DWithin (#1201)
* Add ST_DWithin

* Add documentation for ST_DWithin

* Remove unwanted code

* Removed null check test for ST_DWithin

* Fix EOF lint error

* Add explanation for ST_DWithin

* Remove CRS checking logic in ST_DWithin

* Add optimized join support for ST_DWithin

* Remove test change to resourceFolder

* Remove unnecessary cast to double

* Add broadcast join test

* Add example of ST_DWithin in Optimizer.md
iGN5117 authored Jan 16, 2024
1 parent 59cba11 commit 19c80a0
Showing 5 changed files with 60 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/api/sql/Optimizer.md
@@ -30,6 +30,12 @@
```sql
SELECT *
FROM pointdf, polygondf
WHERE ST_Within(pointdf.pointshape, polygondf.polygonshape)
```

```sql
SELECT *
FROM pointdf, polygondf
WHERE ST_DWithin(pointdf.pointshape, polygondf.polygonshape, 10.0)
```

Spark SQL Physical plan:

```
…
```
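As a usage sketch (not part of this diff), the optimization can be exercised from Scala once Sedona is registered on the session. The view names and geometry column names below simply mirror the SQL example above; the session setup itself is an assumption.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.sedona.spark.SedonaContext

// Sketch only. Assumes `pointdf` and `polygondf` are registered temp views with
// geometry columns `pointshape` and `polygonshape`, as in the SQL example above.
val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
val sedona: SparkSession = SedonaContext.create(spark) // registers Sedona functions and optimizer rules

val joined = sedona.sql(
  """SELECT *
    |FROM pointdf, polygondf
    |WHERE ST_DWithin(pointdf.pointshape, polygondf.polygonshape, 10.0)""".stripMargin)

// With the rule added in this commit, the physical plan should contain a spatial
// join operator rather than a Cartesian product.
joined.explain()
```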
@@ -126,6 +126,12 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
  getJoinDetection(left, right, predicate, Some(extraCondition))
case Some(And(extraCondition, predicate: ST_Predicate)) =>
  getJoinDetection(left, right, predicate, Some(extraCondition))
case Some(ST_DWithin(Seq(leftShape, rightShape, distance))) =>
  Some(JoinQueryDetection(left, right, leftShape, ST_Buffer(Seq(rightShape, distance)),
    SpatialPredicate.INTERSECTS, isGeography = false, condition, None))
case Some(And(ST_DWithin(Seq(leftShape, rightShape, distance)), extraCondition)) =>
  Some(JoinQueryDetection(left, right, leftShape, ST_Buffer(Seq(rightShape, distance)),
    SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(extraCondition)))
case Some(And(extraCondition, ST_DWithin(Seq(leftShape, rightShape, distance)))) =>
  Some(JoinQueryDetection(left, right, leftShape, ST_Buffer(Seq(rightShape, distance)),
    SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(extraCondition)))
// For raster-vector joins
case Some(predicate: RS_Predicate) =>
  getRasterJoinDetection(left, right, predicate, None)
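The three new cases plan `ST_DWithin(leftShape, rightShape, distance)` as an intersects join between `leftShape` and `ST_Buffer(rightShape, distance)`, with `isGeography = false`, i.e. planar distance. A minimal sketch of the equivalence being exploited, written against the DataFrame API; the DataFrames and the `pointshape`/`polygonshape` column names are assumptions mirroring the tests in this commit, not code from the diff.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.expr

// Both joins below should return the same pairs (boundary cases aside, since the
// buffer is a polygonal approximation): the rule above rewrites the first form
// into roughly the second so the existing spatial join machinery can handle it.
def viaDWithin(pointDf: DataFrame, polygonDf: DataFrame, d: Double): DataFrame =
  pointDf.alias("p").join(polygonDf.alias("q"),
    expr(s"ST_DWithin(p.pointshape, q.polygonshape, $d)"))

def viaBufferIntersects(pointDf: DataFrame, polygonDf: DataFrame, d: Double): DataFrame =
  pointDf.alias("p").join(polygonDf.alias("q"),
    expr(s"ST_Intersects(p.pointshape, ST_Buffer(q.polygonshape, $d))"))
```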
@@ -443,6 +443,24 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
      assert(1 == joinDfLeftBroadcast.count()) // raster within its own convexHull

    }

    it("Passed ST_DWithin") {
      val sampleCount = 200
      val distance = 2.0
      val polygonDf = buildPolygonDf.limit(sampleCount).repartition(3)
      val pointDf = buildPointDf.limit(sampleCount).repartition(5)
      val expected = bruteForceDWithin(sampleCount, distance)

      var distanceJoinDF = pointDf.alias("pointDf").join(
        broadcast(polygonDf).alias("polygonDF"),
        expr(s"ST_DWithin(pointDf.pointshape, polygonDf.polygonshape, $distance)"))
      assert(distanceJoinDF.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size == 1)
      assert(distanceJoinDF.count() == expected)

      distanceJoinDF = broadcast(pointDf).alias("pointDf").join(
        polygonDf.alias("polygonDf"),
        expr(s"ST_DWithin(pointDf.pointshape, polygonDf.polygonshape, $distance)"))

      assert(distanceJoinDF.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size == 1)
      assert(distanceJoinDF.count() == expected)
    }
  }

  describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
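The new test exercises both broadcast directions: whichever side carries the `broadcast` hint, the planner should pick `BroadcastIndexJoinExec` and the join should return the same count as the brute-force baseline. A rough sketch of the user-facing pattern, with assumed DataFrames and column names matching the test fixtures:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{broadcast, expr}

// Broadcasting one side asks Sedona to index that side and ship it to every
// executor; `explain()` should then show BroadcastIndexJoinExec in the plan.
def broadcastDWithin(pointDf: DataFrame, polygonDf: DataFrame, d: Double): DataFrame =
  pointDf.alias("p").join(broadcast(polygonDf).alias("q"),
    expr(s"ST_DWithin(p.pointshape, q.polygonshape, $d)"))
```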
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
import com.google.common.math.DoubleMath
import org.apache.log4j.{Level, Logger}
import org.apache.sedona.common.Functions.{frechetDistance, hausdorffDistance}
import org.apache.sedona.common.Predicates.dWithin
import org.apache.sedona.common.sphere.{Haversine, Spheroid}
import org.apache.sedona.spark.SedonaContext
import org.apache.spark.sql.DataFrame
@@ -172,4 +173,17 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
      }).sum
    }).sum
  }

  protected def bruteForceDWithin(sampleCount: Int, distance: Double): Int = {
    val inputPolygon = buildPolygonDf.limit(sampleCount).collect()
    val inputPoint = buildPointDf.limit(sampleCount).collect()

    inputPoint.map(row => {
      val point = row.getAs[org.locationtech.jts.geom.Point](0)
      inputPolygon.map(row => {
        val polygon = row.getAs[org.locationtech.jts.geom.Polygon](0)
        if (dWithin(point, polygon, distance)) 1 else 0
      }).sum
    }).sum
  }
}
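The helper computes the expected match count with a nested loop over the collected rows, using `org.apache.sedona.common.Predicates.dWithin`, the same implementation that backs the SQL function. For reference, a standalone sketch of that predicate on plain JTS geometries; the WKT literals are invented for illustration, not test data from this commit.

```scala
import org.apache.sedona.common.Predicates.dWithin
import org.locationtech.jts.io.WKTReader

val reader = new WKTReader()
// Invented example geometries.
val point = reader.read("POINT (1 1)")
val polygon = reader.read("POLYGON ((0 0, 0 3, 3 3, 3 0, 0 0))")

// The point lies inside the polygon, so its distance to it is 0 and
// dWithin is true for any non-negative distance.
println(dWithin(point, polygon, 0.5)) // true
```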
@@ -502,5 +502,21 @@ class predicateJoinTestScala extends TestBaseScala {
        assert(distanceDefaultIntersectsDF.count() == expectedNoIntersects)
      })
    }

    it("Passed ST_DWithin in a spatial join") {
      val sampleCount = 200
      val distanceCandidates = Seq(1.0, 2.0, 5.0, 10.0)
      val inputPoint = buildPointDf.limit(sampleCount).repartition(5)
      val inputPolygon = buildPolygonDf.limit(sampleCount).repartition(3)

      distanceCandidates.foreach(distance => {
        val expected = bruteForceDWithin(sampleCount, distance)
        val dWithinDf = inputPoint.alias("pointDf").join(
          inputPolygon.alias("polygonDf"),
          expr(s"ST_DWithin(pointDf.pointshape, polygonDf.polygonshape, $distance)"))

        assert(dWithinDf.count() == expected)
        assert(dWithinDf.queryExecution.sparkPlan.collect { case p: RangeJoinExec => p }.size === 1)
      })
    }

  }
}
