Skip to content

Commit

Permalink
WIP: db-based clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
Wiktor Latanowicz committed Mar 14, 2024
1 parent fecb957 commit a1f86f8
Showing 1 changed file with 65 additions and 2 deletions.
67 changes: 65 additions & 2 deletions generic_map_api/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from typing import TYPE_CHECKING, Any

import numpy as np
from django.db.models import QuerySet
from django.contrib.gis.db.models.aggregates import Collect
from django.db import connections
from django.db.models import Count, F, QuerySet
from django.db.models.expressions import Func, Window
from django.db.models.fields import IntegerField
from shapely import wkb
from shapely.geometry import MultiPoint, MultiPolygon, Point
from sklearn.cluster import DBSCAN

Expand All @@ -28,11 +33,69 @@ def find_clusters(self, view: MapFeaturesBaseView, viewport: BaseViewPort, items


class DatabaseClustering(BaseClustering):
@dataclass
class Cluster:
count: int
centroid: Point
shape: MultiPolygon

class StClusterDBSCAN(Func):
function = "ST_ClusterDBSCAN"
output_field = IntegerField()
window_compatible = True

def __init__(self, geometry, eps=1, minpoints=5, **extra):
args = (geometry, eps, minpoints)
super().__init__(*args, **extra)

def find_clusters(self, view: MapFeaturesBaseView, viewport: BaseViewPort, items):
if not isinstance(items, QuerySet):
raise ValueError("Database clustering requires QuerySet on input")

raise NotImplementedError()
sql, sql_params = items.query.sql_with_params()

geometry_field = "position"

items_outside_clusters_raw_sql = f"""
SELECT * FROM (
SELECT *, ST_ClusterDBSCAN({geometry_field}, eps := 3, minpoints := 5) OVER () as cluster_label FROM ({sql}) sq1
) sq2 WHERE cluster_label IS NULL;
"""
items_outside_clusters = items.model.objects.raw(
items_outside_clusters_raw_sql, sql_params
)

for item in items_outside_clusters.iterator():
yield ClusteringOutput(
is_cluster=False,
item=item,
)

clusters_raw_sql = f"""
SELECT COUNT(*) as cluster_item_count, ST_Collect({geometry_field}) as cluster_geometry FROM (
SELECT *, ST_ClusterDBSCAN({geometry_field}, eps := 3, minpoints := 5) OVER () as cluster_label FROM ({sql}) sq1
) sq2 WHERE cluster_label IS NOT NULL GROUP BY cluster_label;
"""

connection = connections[items.db]
with connection.cursor() as c:
c.execute(clusters_raw_sql, sql_params)
while row := c.fetchone():
shape = MultiPolygon([wkb.loads(row[1]).convex_hull])
yield ClusteringOutput(
is_cluster=True,
item=self.Cluster(
count=row[0],
shape=shape,
centroid=shape.centroid,
),
)

for cluster in []:
yield ClusteringOutput(
is_cluster=True,
item=cluster,
)


class BasicClustering:
Expand Down

0 comments on commit a1f86f8

Please sign in to comment.