A distributed version of the DBSCAN algorithm that scales to billions of points can be implemented as a spatial join followed by a connected components graph traversal. If those words made sense to you, there is no need to read further. You can nod your head, furrow your brow, and move on. If they did not, or if you just enjoy pretty gifs of algorithms, keep reading.
What is DBSCAN?
DBSCAN is a clustering algorithm that has some advantages over similar algorithms such as k-means.
- No predefined number of clusters: Unlike k-means, DBSCAN doesn't require specifying the number of clusters beforehand.
- Arbitrary cluster shapes: It can discover clusters of arbitrary shapes, not just spherical ones.
- Outlier detection: DBSCAN naturally identifies outliers as noise points.
- Efficient: With a spatial index, it can run in O(n * log(n)) time on typical data (we will get into this in more detail below).
The advantages over k-means can be seen using the moons dataset.


In the first image, we see that DBSCAN separates each half-moon into its own cluster. In the second, k-means assumes spherical clusters and thus splits the moons down the middle.
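The comparison in the two images can be reproduced with a short scikit-learn sketch. The parameter values here (noise=0.03, eps=0.2, min_samples=5) are assumptions chosen for illustration, not the settings behind the images. The adjusted Rand index (ARI) scores each labeling against the true moon membership, with 1.0 meaning a perfect match.

```python
# Sketch: DBSCAN vs. k-means on two moons (illustrative parameters, not the post's).
from sklearn.cluster import DBSCAN, KMeans
from sklearn.datasets import make_moons
from sklearn.metrics import adjusted_rand_score

X, y = make_moons(n_samples=200, noise=0.03, random_state=42)

# DBSCAN: density-based, no cluster count needed; label -1 marks noise.
db_labels = DBSCAN(eps=0.2, min_samples=5).fit_predict(X)

# k-means: must be told there are 2 clusters and favors roughly spherical blobs.
km_labels = KMeans(n_clusters=2, n_init=10, random_state=42).fit_predict(X)

print("DBSCAN clusters:", sorted(set(db_labels) - {-1}))
print("DBSCAN ARI vs. true moons:", adjusted_rand_score(y, db_labels))
print("k-means ARI vs. true moons:", adjusted_rand_score(y, km_labels))
```

DBSCAN is never told how many clusters to find; the two moons fall out of the density structure. k-means has to be given n_clusters=2 and still draws a straight boundary between its two centers.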
The DBSCAN algorithm
Here is an implementation of DBSCAN:
def dbscan(points, eps, min_points):
    def get_neighbors(i):
        """
        We will talk about possible implementations of this below.
        For now, assume it returns a list of indexes that are <= eps away
        from points[i] (including i itself).
        """
        raise NotImplementedError  # implementations discussed below

    labels = [-1] * len(points)  # All points start as unlabeled (-1)
    cluster = 0
    for i in range(len(points)):
        if labels[i] != -1:  # Skip previously labeled points
            continue
        neighbors = get_neighbors(i)
        if len(neighbors) >= min_points:  # i is a core point: start a new cluster
            labels[i] = cluster
            stack = neighbors[:]  # copy
            while stack:  # Iterate all points in current cluster
                n = stack.pop()
                if labels[n] == -1:  # Update neighbor's label
                    labels[n] = cluster
                    # Find neighbor's neighbors and so on ...
                    new_neighbors = get_neighbors(n)
                    # Keep expanding only if we have enough point density
                    if len(new_neighbors) >= min_points:
                        stack.extend(new_neighbors)
            cluster += 1  # Update for next possible cluster
    return labels  # Points still labeled -1 are noise
Here is a gif of the algorithm adding points to the clusters one-by-one.

You can see that DBSCAN finishes recursing through all points in one cluster before moving on to the next one.
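To make the function above runnable, one option is to fill in get_neighbors with the brute-force scan described in the next paragraph. The sketch below does that (math.dist requires Python 3.8+) and runs it on a tiny hand-made dataset: two tight groups and one far-away point, which should come back as noise (-1).

```python
import math


def dbscan(points, eps, min_points):
    """Single-machine DBSCAN following the loop above, with a brute-force get_neighbors."""

    def get_neighbors(i):
        # O(n) scan: every index whose point is within eps of points[i] (i itself included).
        return [j for j in range(len(points)) if math.dist(points[i], points[j]) <= eps]

    labels = [-1] * len(points)  # -1 = unlabeled / noise
    cluster = 0
    for i in range(len(points)):
        if labels[i] != -1:
            continue
        neighbors = get_neighbors(i)
        if len(neighbors) >= min_points:  # i is a core point: start a new cluster
            labels[i] = cluster
            stack = neighbors[:]
            while stack:
                n = stack.pop()
                if labels[n] == -1:
                    labels[n] = cluster
                    new_neighbors = get_neighbors(n)
                    if len(new_neighbors) >= min_points:  # only core points keep expanding
                        stack.extend(new_neighbors)
            cluster += 1
    return labels


# Two tight groups of four points plus one far-away outlier.
pts = [(0, 0), (0, 0.1), (0.1, 0), (0.1, 0.1),
       (5, 5), (5, 5.1), (5.1, 5), (5.1, 5.1),
       (10, 0)]
print(dbscan(pts, eps=0.2, min_points=3))  # → [0, 0, 0, 0, 1, 1, 1, 1, -1]
```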
In the above implementation, I omitted the get_neighbors function. The most basic implementation iterates through all points and checks whether each one is within eps distance of the given point (here using math.dist for the Euclidean distance):
[j for j in range(len(points)) if math.dist(points[i], points[j]) <= eps]
get_neighbors is called for every point, and each call scans all n points, so this implementation is O(n^2). A better implementation builds a spatial search data structure up front and then uses it to find neighboring points. This can bring the runtime down to O(n log(n)) on typical data (the worst case, e.g. when most points are within eps of each other, is still quadratic). This is the approach used when calling sklearn.cluster.DBSCAN with algorithm='ball_tree'.
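scikit-learn uses a ball tree for this, but a simpler spatial index that shows the same idea is a uniform grid. The sketch below uses that swapped-in technique, not what sklearn does: it hashes each point into a square cell of side eps, so any neighbor within eps must sit in the same cell or one of the 8 surrounding cells. build_grid and grid_neighbors are hypothetical helper names. Each query only examines the points in those nine cells, which is cheap for roughly uniform data, but unlike a tree the grid gives no guarantee when many points pile into one cell.

```python
import math
import random
from collections import defaultdict


def build_grid(points, eps):
    """Bucket point indices into square cells of side eps (a simple spatial hash)."""
    grid = defaultdict(list)
    for idx, (x, y) in enumerate(points):
        grid[(math.floor(x / eps), math.floor(y / eps))].append(idx)
    return grid


def grid_neighbors(points, grid, eps, i):
    """Indices within eps of points[i]; only the 3x3 block of surrounding cells is scanned."""
    x, y = points[i]
    cx, cy = math.floor(x / eps), math.floor(y / eps)
    result = []
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            # .get avoids inserting empty cells into the defaultdict
            for j in grid.get((cx + dx, cy + dy), ()):
                if math.dist(points[i], points[j]) <= eps:
                    result.append(j)
    return result


# Usage: build the index once, then query it for each point.
random.seed(0)
pts = [(random.uniform(0, 10), random.uniform(0, 10)) for _ in range(300)]
grid = build_grid(pts, 0.5)
print(sorted(grid_neighbors(pts, grid, 0.5, 0)))
```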
Why distribute DBSCAN?
Recently, I needed to run DBSCAN on around 2 billion points. If we assume 2 coordinates per point (X, Y) and 8 bytes per coordinate, that is 32GB of memory just for the input points. Once we account for the labels, the recursion stack, and the search tree, we are beyond the range of normal desktop memory. Even if we turned to a high-memory machine in the cloud to overcome our memory woes, the above algorithm is inherently serial, and we would be looking at some monster run times. Can we run on multiple machines at once?
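The memory estimate works out as follows. The 8-byte int64 label per point is an assumption added for illustration; the stack and search tree are left out because their size depends on the data and the index.

```python
n_points = 2_000_000_000  # ~2 billion points
coords_per_point = 2      # x and y
bytes_per_coord = 8       # float64

input_bytes = n_points * coords_per_point * bytes_per_coord
print(input_bytes / 1e9, "GB")  # → 32.0 GB

label_bytes = n_points * 8  # assumption: one int64 label per point
print((input_bytes + label_bytes) / 1e9, "GB")  # → 48.0 GB
```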
Looking at the single node DBSCAN algorithm above, we see it can be broken down into two operations.
1) Find neighbors
2) Grow clusters
I knew that I could find neighboring point pairs by using the ST_Buffer function and a spatial join, but Claude AI helped me realize that those point pairs are edges in a graph. As soon as that lightbulb went off, I realized that the grow clusters operation was just a graph traversal.
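The grow clusters loop only ever looks at neighbor relationships, so it can run directly on a list of neighbor pairs (graph edges). The sketch below rewrites the loop from the dbscan function above as a traversal over an adjacency list; dbscan_from_pairs is a hypothetical name. As in the original loop, only core points (at least min_points neighbors, counting the point itself) keep expanding. The example links two dense groups through a single non-core point 3: DBSCAN assigns point 3 to the first cluster that reaches it and keeps the groups separate. A plain connected components traversal, which lets every vertex expand, would merge all seven points into one component. That is the simplification the pipeline in this post makes, and the two agree on which points are grouped together when every point that has at least one neighbor is a core point.

```python
from collections import defaultdict


def dbscan_from_pairs(n, pairs, min_points):
    """DBSCAN's "grow clusters" step as a traversal of the eps-neighbor graph.

    pairs: (i, j) index pairs within eps of each other, i != j, listed in both directions.
    A vertex is core if it has >= min_points neighbors counting itself.
    """
    adjacency = defaultdict(list)
    for i, j in pairs:
        adjacency[i].append(j)

    def is_core(v):
        return len(adjacency[v]) + 1 >= min_points  # +1: a point is its own neighbor

    labels = [-1] * n
    cluster = 0
    for start in range(n):
        if labels[start] != -1 or not is_core(start):
            continue
        labels[start] = cluster
        stack = [start]
        while stack:  # depth-first traversal; border (non-core) vertices do not expand
            v = stack.pop()
            if not is_core(v):
                continue
            for w in adjacency[v]:
                if labels[w] == -1:
                    labels[w] = cluster
                    stack.append(w)
        cluster += 1
    return labels


# Dense groups {0, 1, 2} and {4, 5, 6} joined through point 3.
edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (4, 5), (4, 6), (5, 6)]
pairs = edges + [(j, i) for i, j in edges]  # the spatial join emits both directions
print(dbscan_from_pairs(7, pairs, min_points=4))  # → [0, 0, 0, 0, 1, 1, 1]
```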
(Note that when I was doing this work in June 2024, this PR hadn't yet been proposed, or at least I hadn't seen it. It is awesome that Sedona will now support this out of the box!)
My Distributed DBSCAN
Using the same half-moons dataset as above, we first load it into a Spark dataframe:
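from sklearn.datasets import make_moons
import pandas as pd
# `sedona` is the SedonaContext created in the full implementation at the end of this post
points, _ = make_moons(n_samples=40, noise=0.04, random_state=42)
df = pd.DataFrame(points, columns=('x', 'y'))  # Convert to Pandas
df_in = sedona.createDataFrame(df)  # Convert to Spark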
We get a dataframe with x & y coordinates in columns:
+-------------------+--------------------+
| x| y|
+-------------------+--------------------+
|-1.0117359658543466|-0.00119354277438...|
|-0.8756687198338035| 0.5025335667470677|
|-0.7947492492906738| 0.6128849766870061|
|0.42308878093983887| -0.3683017585461643|
| 0.8270942237051542| 0.6774467361312289|
+-------------------+--------------------+
only showing top 5 rows
Add a point column (pt) using Sedona's ST_Point constructor and give every row an id.
from pyspark.sql import functions as f
df_pts = df_in \
    .withColumn("pt", f.expr("ST_Point(x, y)")) \
    .withColumn("id", f.monotonically_increasing_id())  # unique, not necessarily consecutive, ids
+-------------------+--------------------+--------------------+---+
| x| y| pt| id|
+-------------------+--------------------+--------------------+---+
|-1.0117359658543466|-0.00119354277438...|POINT (-1.0117359...| 0|
|-0.9839383532692947| 0.28278464510141105|POINT (-0.9839383...| 1|
|-0.9746389335768995| 0.20045980562917187|POINT (-0.9746389...| 2|
|-0.8756687198338035| 0.5025335667470677|POINT (-0.8756687...| 3|
|-0.7947492492906738| 0.6128849766870061|POINT (-0.7947492...| 4|
+-------------------+--------------------+--------------------+---+
only showing top 5 rows
Next, we want to find all pairs of points that are within eps distance of each other. This can be accomplished by joining the dataframe to itself with a spatial join. Sedona builds a search tree on the fly, giving us similar O(n log(n)) performance, but distributed across many machines.
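eps = 0.3
# Pair each point (l) with every other point (r) inside its eps "bubble"; inner join, since isolated points still remain in df_pts
cond = f.expr(f"ST_Contains(ST_Buffer(l.pt, {eps}), r.pt) AND l.id != r.id")
df_self_intersect = df_pts.alias("l").join(df_pts.alias("r"), cond, "inner")
df_self_intersect.select("l.pt", "l.id", "r.pt", "r.id").show(5)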
Outputs point pairs:
+--------------------+---+--------------------+---+
| pt| id| pt| id|
+--------------------+---+--------------------+---+
|POINT (-0.9839383...| 1|POINT (-1.0117359...| 0|
|POINT (-0.9746389...| 2|POINT (-1.0117359...| 0|
|POINT (-1.0117359...| 0|POINT (-0.9839383...| 1|
|POINT (-0.9746389...| 2|POINT (-0.9839383...| 1|
|POINT (-0.8756687...| 3|POINT (-0.9839383...| 1|
+--------------------+---+--------------------+---+
only showing top 5 rows
Let's break down the above join condition a bit more. It uses the ST_Buffer function to create a "bubble" of radius eps around the point on the left side of the join, and the ST_Contains function then returns true only if the right point is inside that bubble. (ST_Buffer approximates the circle with a polygon, so points almost exactly eps apart may land on either side of the cutoff.) Finally, the second part of the join condition excludes pairs where a point is matched with itself, i.e. pairs with the same id.
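In plain Python, the join condition amounts to a distance test on every left/right pair, minus the pairs where a row meets itself. The sketch below (self_join_pairs is a hypothetical name) uses an exact math.dist check, whereas ST_Buffer builds a polygonal approximation of the circle, so results can differ for points almost exactly eps apart. Rows 1 and 2 share coordinates; because this sketch excludes by id rather than by coordinates, they still form a pair. As in the Spark output above, every pair shows up in both directions.

```python
import math


def self_join_pairs(rows, eps):
    """Plain-Python stand-in for the spatial self-join.

    rows: (id, (x, y)) tuples. Returns (l_id, r_id) pairs whose points are within eps,
    excluding a row paired with itself. This O(n^2) loop is only for illustration;
    the spatial join avoids comparing every pair by using a search tree.
    """
    return [
        (l_id, r_id)
        for l_id, l_pt in rows
        for r_id, r_pt in rows
        if l_id != r_id and math.dist(l_pt, r_pt) <= eps
    ]


rows = [(0, (0.0, 0.0)), (1, (0.2, 0.0)), (2, (0.2, 0.0)), (3, (5.0, 5.0))]
print(self_join_pairs(rows, eps=0.3))
# → [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```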
Here is a gif plotting all of the point pairs:

If we do a groupBy("l.id"), we can plot points and their neighbors.
The above gif gives us the intuition that we just need to connect all of the pairs together. The connected components algorithm does exactly this. The distributed version of connected components is non-trivial, but the GraphFrames library provides a very nice implementation.
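On a single machine, connected components over an edge list takes a few lines of union-find. The sketch below (connected_components is a hypothetical name) returns one component id per vertex, similar in shape to the component column that GraphFrames adds; using the smallest vertex id as the component id is a choice made for this sketch, not a claim about GraphFrames' numbering. The example edges connect two dense groups {0, 1, 2} and {4, 5, 6} through vertex 3, so every one of those vertices ends up in component 0, while vertex 7, which has no edges, forms its own component.

```python
def connected_components(vertex_ids, edges):
    """Union-find: map each vertex id to a component id (the smallest id in its component)."""
    parent = {v: v for v in vertex_ids}

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving keeps trees shallow
            v = parent[v]
        return v

    for src, dst in edges:
        root_a, root_b = find(src), find(dst)
        if root_a != root_b:
            # Hang the larger root under the smaller, so each root is its component's min id.
            if root_a < root_b:
                parent[root_b] = root_a
            else:
                parent[root_a] = root_b
    return {v: find(v) for v in vertex_ids}


ids = [0, 1, 2, 3, 4, 5, 6, 7]
edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (4, 5), (4, 6), (5, 6)]
print(connected_components(ids, edges))
# → {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 7}
```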
Run the connectedComponents algorithm:
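from graphframes import GraphFrame
verts = df_pts  # GraphFrames expects an "id" column on the vertices
edges = df_self_intersect.select(f.col("l.id").alias("src"), f.col("r.id").alias("dst"))  # and "src"/"dst" on the edges
g = GraphFrame(verts, edges)
sedona.sparkContext.setCheckpointDir("file:///tmp/checkpoint")  # connectedComponents needs a checkpoint directory
cc = g.connectedComponents()  # vertices plus a "component" column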
Here is a gif of the steps of the connectedComponents algorithm.
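Why the distributed version is non-trivial shows up even in the simplest distributed-friendly scheme, min-label propagation: every vertex repeatedly adopts the smallest label among itself and its neighbors until nothing changes. Each round maps naturally onto a join plus a group-by, but the number of rounds grows with the graph's diameter, so long chains of points need many passes. This is an illustrative sketch only (label_propagation_components is a hypothetical name); GraphFrames ships its own, more efficient algorithm.

```python
def label_propagation_components(vertex_ids, edges):
    """Connected components by synchronous min-label propagation.

    Each round, every vertex takes the smallest label among itself and its neighbors
    (using the previous round's labels). Stops after a round with no changes.
    Returns (labels, number_of_rounds_including_the_final_unchanged_one).
    """
    labels = {v: v for v in vertex_ids}
    rounds = 0
    changed = True
    while changed:
        new_labels = dict(labels)
        for src, dst in edges:
            smaller = min(labels[src], labels[dst])
            if smaller < new_labels[src]:
                new_labels[src] = smaller
            if smaller < new_labels[dst]:
                new_labels[dst] = smaller
        rounds += 1
        changed = new_labels != labels
        labels = new_labels
    return labels, rounds


# A chain 0-1-2-3-4 plus an isolated vertex 5: label 0 needs 4 rounds to reach vertex 4.
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
labels, rounds = label_propagation_components(range(6), edges)
print(labels)  # → {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 5}
print(rounds)  # → 5
```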

Noise points
We can use the make_blobs function to generate a test dataset with some noise points. Dealing with them is just some pretty straightforward SQL/Spark: find all the components whose point count is over the cutoff, then join that back to the connected components output dataframe. Points in components below the cutoff get the noise label -1. If you just wanted to drop the noise points and not label them, you could also do a simple inner join here instead.
min_points = 5
# Components with more than min_points members count as clusters
clusters_above_cutoff = cc.groupBy("component").count().where(f.col("count") > min_points)
# Left join keeps every point; points in small components get a null count -> noise (-1)
result = cc.join(clusters_above_cutoff, ["component"], "left") \
    .select(
        f.when(f.col("count").isNull(), f.lit(-1)).otherwise(f.col("component")).alias("component"),
        f.col("x"),
        f.col("y"),
        f.col("pt"),
        f.col("id"))
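In plain Python, the noise step is a count per component followed by a lookup. The sketch below (label_noise is a hypothetical name) mirrors the Spark code above, including its strict > min_points comparison, which means a component needs at least min_points + 1 points to survive; switch to >= if you want at least min_points.

```python
from collections import Counter


def label_noise(components, min_points):
    """Relabel members of small components as noise (-1).

    components: one component id per point, like the `component` column of cc.
    A component survives only if its size is > min_points, as in the Spark code.
    """
    sizes = Counter(components)  # like cc.groupBy("component").count()
    return [c if sizes[c] > min_points else -1 for c in components]


comps = [0, 0, 0, 0, 0, 0, 7, 7, 9]
print(label_noise(comps, min_points=5))  # → [0, 0, 0, 0, 0, 0, -1, -1, -1]
```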
Final thoughts on AI and coding
I probably would have eventually found a workable solution on my own, but it would have taken me a lot longer. Having an AI assistant tell me, "This is equivalent to the connected components algorithm and here is how you run that using GraphFrames," is an amazing level up. I continue to think that AI will struggle to fully replace developers, at least for the remainder of my career. But I am becoming more confident that the best developers will be some form of cyborg. What do you think? What problems have you used AI to help you level up on? Respond to me on Mastodon.
Full implementation
from sklearn.datasets import make_blobs
from sedona.spark import SedonaContext
from pyspark.sql import functions as f
from graphframes import GraphFrame
import matplotlib.pyplot as plt
import pandas as pd
config = SedonaContext.builder() \
.config('spark.jars.packages',
'org.apache.sedona:sedona-spark-3.5_2.12:1.6.1,org.datasyslab:geotools-wrapper:1.6.1-28.2,graphframes:graphframes:0.8.4-spark3.5-s_2.12') \
.config('spark.jars.repositories', 'https://artifacts.unidata.ucar.edu/repository/unidata-all', ) \
.config('spark.executor.memory', '3g') \
.config('spark.driver.memory', '2g') \
.config("spark.executor.instances", 2) \
.config("spark.executor.cores", 4) \
.getOrCreate()
sedona = SedonaContext.create(config)
def generate_data(n_samples, n_features=2):
    return make_blobs(n_samples=n_samples, n_features=n_features, random_state=42)[0]
X = generate_data(100)
df = pd.DataFrame(X, columns=('x','y')) # Convert to Pandas
df_in = sedona.createDataFrame(df) # Convert to Spark
df_pts = df_in \
.withColumn("pt", f.expr("ST_Point(x, y)")) \
.withColumn("id", f.monotonically_increasing_id())
eps = 1
df_self_intersect = df_pts.alias("l").join(df_pts.alias("r"), f.expr(f"ST_Contains(ST_Buffer(l.pt, {eps}), r.pt) AND l.id != r.id"), "inner")  # isolated points remain as vertices
verts = df_pts
edges = df_self_intersect.select(f.col("l.id").alias("src"), f.col("r.id").alias("dst"))
g = GraphFrame(verts, edges)
sedona.sparkContext.setCheckpointDir("file:///tmp/checkpoint") # Change to hdfs if on cluster
cc = g.connectedComponents()
min_points = 5
clusters_above_cutoff = cc.groupBy("component").count().where(f.col("count") > min_points)
result = cc.join(clusters_above_cutoff, ["component"], "left") \
.select(
f.when(f.col("count").isNull(), f.lit(-1)).otherwise(f.col("component")).alias("component"),
f.col("x"),
f.col("y"),
f.col("pt"),
f.col("id"))
xs, ys, cs = result.select(f.collect_list(f.col('x')), f.collect_list(f.col('y')), f.collect_list(f.col('component'))).first()
plt.figure(figsize=(10, 6))
plt.scatter(xs, ys, c=cs, cmap='viridis')
plt.title('Clustering on Blobs Dataset')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
