PySpark High-Cardinality Aggregation — When count(distinct) Brings Your Cluster to a Halt
Why distinct counts over hundreds of millions of users and group aggregations over billions of rows blow up memory — and what to do about it. The true cost of exact count distinct, plus patterns that keep large-scale aggregation alive: approx_count_distinct (HyperLogLog), two-phase aggregation, and pre-aggregated rollups.
A single SELECT count(distinct user_id) can run for tens of minutes and then die with an OOM. High-cardinality distinct aggregations, such as daily unique visitors (UV) per country, are among the most expensive operations on large datasets, because an exact count distinct has to keep track of every unique value it has seen.
This post explains why distinct aggregation is so expensive, and how to solve it with approximate aggregation (HyperLogLog), two-phase aggregation, and pre-aggregated rollups.
1. Why count(distinct) Is Expensive
```
count(*)          → a single counter per group (cheap)
count(distinct x) → must track every unique value seen (expensive)
```

Conceptually, an exact distinct count has to remember every value it has already seen to answer "have I seen this value before?". Spark implements this by first deduplicating (group, value) pairs with an extra aggregation and shuffle, so both the aggregation state and the shuffle volume grow with the number of unique values. With cardinality in the hundreds of millions, that alone becomes enormous. Grouping by several columns makes it worse: each group needs its own set of unique values, so a user who is active on 30 days is stored 30 times in a country×date aggregation.

```python
from pyspark.sql import functions as F

# Groups × high cardinality → memory explosion:
# a set of unique user_id values is maintained per country×date group
df.groupBy("country", "date").agg(F.countDistinct("user_id"))
```

| Operation | Memory | Shuffle |
|---|---|---|
| count(*) | O(1) per group | Small |
| count(distinct) | O(unique values) per group | Large (global dedup) |
| sum/avg | O(1) per group | Small |
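To make the memory column concrete, here is a small pure-Python simulation (not Spark code) of the per-group state each aggregate keeps. The `simulate_state` helper and the toy data (1,000 users active on all 30 days in one country) are hypothetical; the point is only that a count(*) group holds one integer while a count(distinct) group holds every user it has seen, so the same user is stored once per group.

```python
from collections import defaultdict


def simulate_state(rows):
    """Mimic the per-group state of count(*) vs count(distinct user_id).

    rows: iterable of (country, date, user_id) tuples (hypothetical schema).
    """
    counters = defaultdict(int)  # count(*): one integer per group
    seen = defaultdict(set)      # count(distinct): every unique user_id per group
    for country, date, user_id in rows:
        key = (country, date)
        counters[key] += 1
        seen[key].add(user_id)
    return counters, seen


# 1,000 users, each active on all 30 days in the same country (toy data)
rows = [("KR", day, user) for day in range(30) for user in range(1000)]
counters, seen = simulate_state(rows)

count_state = len(counters)                          # integers kept by count(*)
distinct_state = sum(len(s) for s in seen.values())  # user_ids kept by count(distinct)
print(count_state, distinct_state)  # → 30 30000
```

Only 1,000 users exist, yet the distinct state holds 30,000 entries because every user reappears in every date group.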
2. Solution 1 — approx_count_distinct (HyperLogLog)
Most analytics metrics (UV, reach) tolerate a small margin of error. Whether the number is 1,000,234 or 1,001,050, the business decision is the same. This is where HyperLogLog-based approximate aggregation is the right answer.
```python
from pyspark.sql import functions as F

# Exact (expensive)
df.groupBy("country").agg(F.countDistinct("user_id").alias("uv"))

# Approximate (cheap and fast; default rsd = 0.05, i.e. ~5% relative standard deviation)
df.groupBy("country").agg(F.approx_count_distinct("user_id").alias("uv"))

# Tighter error: rsd is the relative standard deviation (smaller = more accurate, more memory)
df.groupBy("country").agg(F.approx_count_distinct("user_id", rsd=0.02).alias("uv"))
```

Instead of the full set of unique values, HyperLogLog keeps only a small, fixed-size sketch per group, and sketches from different partitions are simply merged, so the extra dedup shuffle of exact distinct disappears. Memory stays constant even when cardinality reaches hundreds of millions.
| | Exact count distinct | approx (HLL) |
|---|---|---|
| Memory | O(unique values) | Fixed (small) |
| Speed | Slow | Fast |
| Accuracy | 100% | rsd (default ~5%) |
| Best for | Billing, settlement | UV, reach, trends |
The decision criterion: do you truly need the exact value? For billing and settlement, compute exactly; for dashboard metrics and trends, approximate. It's common for this one change to make a job dozens of times faster.
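For intuition about why the sketch stays small, below is a deliberately simplified HyperLogLog in pure Python. It is not Spark's implementation (Spark's approx_count_distinct uses HyperLogLog++, which adds bias correction); the `ToyHLL` class, the blake2b hash, and p = 12 are illustrative choices. Each value is hashed, the top p bits pick one of m = 2^p registers, and the register keeps the largest "position of the first 1-bit" seen in the remaining bits. Memory is m small integers regardless of how many values arrive, and the classic error estimate is about 1.04/√m (roughly 1.6% for m = 4096), which is why a smaller rsd costs more memory. Merging two sketches is an element-wise max of registers, which is what makes stored sketches re-aggregatable later.

```python
import hashlib
import math


def hash64(value):
    """Deterministic 64-bit hash of any value (via its string form)."""
    digest = hashlib.blake2b(str(value).encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big")


class ToyHLL:
    def __init__(self, p=12):
        self.p = p
        self.m = 1 << p                # number of registers: fixed memory
        self.registers = [0] * self.m

    def add(self, value):
        h = hash64(value)
        idx = h >> (64 - self.p)                       # top p bits choose a register
        rest = h & ((1 << (64 - self.p)) - 1)          # remaining 64 - p bits
        rank = (64 - self.p) - rest.bit_length() + 1   # 1-based position of the first 1-bit
        self.registers[idx] = max(self.registers[idx], rank)

    def merge(self, other):
        """Union of two sketches: element-wise max of the registers."""
        merged = ToyHLL(self.p)
        merged.registers = [max(a, b) for a, b in zip(self.registers, other.registers)]
        return merged

    def estimate(self):
        alpha = 0.7213 / (1 + 1.079 / self.m)
        raw = alpha * self.m * self.m / sum(2.0 ** -r for r in self.registers)
        zeros = self.registers.count(0)
        if raw <= 2.5 * self.m and zeros:
            return self.m * math.log(self.m / zeros)   # small-range (linear counting) correction
        return raw


sketch = ToyHLL(p=12)
for user_id in range(100_000):
    sketch.add(f"user-{user_id}")

print(len(sketch.registers), round(sketch.estimate()))
print("theoretical rsd:", 1.04 / math.sqrt(sketch.m))
```

Adding another million users would leave `len(sketch.registers)` at 4,096; only the register values change.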
3. Solution 2 — Two-Phase Aggregation (Partial → Final)
Spark already performs aggregation in two phases by default (map-side partial → reduce-side final). But when a few groups hold most of the rows and the map-side partial phase cannot shrink them enough, the final phase piles that data onto a single reducer. In that case, split the phases yourself with a salt.
```python
from pyspark.sql import functions as F

# Add a salt to the skewed group key for partial aggregation, then drop it for the final aggregation
N = 16
salted = df.withColumn("salt", (F.rand() * N).cast("int"))

partial = (salted
    .groupBy("country", "salt")   # spread each hot country across N salt groups
    .agg(F.sum("amount").alias("partial_sum")))

final = (partial
    .groupBy("country")           # drop the salt and add up the partial sums
    .agg(F.sum("partial_sum").alias("total")))
```

Associative aggregations like sum, count, and max can be safely split into two phases this way; the final phase must combine the partial results (sum the partial counts, take the max of the partial maxima) rather than recount them. (For a general treatment of skew, see the separate post "Mastering Data Skew in PySpark".)
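The partial → final pattern can be simulated in plain Python to check the two properties that matter: the salted result equals the direct sum, and the hot group's rows are spread over N partial groups instead of one. The skewed toy data and the `two_phase_sum` helper are hypothetical; in Spark, the different (country, salt) groups can be processed by different reducer tasks.

```python
import random
from collections import defaultdict


def two_phase_sum(rows, n_salts, seed=0):
    """rows: (country, amount) pairs. Returns (totals per country, largest partial group size)."""
    rng = random.Random(seed)
    # Phase 1: partial sums per (country, salt), like groupBy("country", "salt")
    partial = defaultdict(int)
    group_sizes = defaultdict(int)
    for country, amount in rows:
        salt = rng.randrange(n_salts)
        partial[(country, salt)] += amount
        group_sizes[(country, salt)] += 1
    # Phase 2: drop the salt and add up the partial sums, like groupBy("country")
    totals = defaultdict(int)
    for (country, _salt), partial_sum in partial.items():
        totals[country] += partial_sum
    return dict(totals), max(group_sizes.values())


# Skewed toy data: "US" has 100x more rows than "KR"
rows = [("US", 1)] * 100_000 + [("KR", 2)] * 1_000

direct = defaultdict(int)  # single-phase reference result
for country, amount in rows:
    direct[country] += amount

totals, largest = two_phase_sum(rows, n_salts=16)
print(totals, largest)
```

The totals match the single-phase result, while the largest partial group shrinks from 100,000 rows to roughly 100,000 / 16.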
Caution:
count(distinct) cannot be split with a simple two-phase sum. With a random salt, the same user can land in several salt groups, so adding the partial distinct counts overcounts. For distinct skew, use approx or the pre-aggregation approach below.
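A quick plain-Python check shows the overcount: with a random salt per row, one active user lands in several (country, salt) groups and is counted once per group. The toy data (1,000 users with 20 events each) and the variable names are hypothetical.

```python
import random
from collections import defaultdict

rng = random.Random(42)
N = 16
# 1,000 users, each with 20 events in the same country (toy data)
events = [("KR", f"user-{u}") for u in range(1000) for _ in range(20)]

# Phase 1: distinct users per (country, salt), with a random salt per row
partial = defaultdict(set)
for country, user_id in events:
    partial[(country, rng.randrange(N))].add(user_id)

summed = sum(len(users) for users in partial.values())  # naive "two-phase" sum of partial distincts
exact = len({user_id for _, user_id in events})         # true distinct count
print(exact, summed)  # exact is 1000; summed is many times larger
```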
4. Solution 3 — Pre-aggregated Rollups (Pre-aggregation)
If the same aggregation is queried repeatedly, don't scan the raw data every time — aggregate ahead of time into a small rollup table.
```python
from pyspark.sql import functions as F

# Pre-aggregate daily per-country metrics and store them (batch)
daily = (events
    .groupBy("date", "country")
    .agg(
        F.count("*").alias("events"),
        F.approx_count_distinct("user_id").alias("uv"),
        F.sum("amount").alias("revenue")))

# DataFrameWriterV2: append() requires the table to exist (create it once with .create())
daily.writeTo("analytics.daily_metrics").append()
# Dashboards query only the small rollup (no raw-data scans)
```

Monthly and yearly metrics can then be computed by re-aggregating the daily rollup instead of rescanning the raw data, with one important exception.
The Additivity Trap
When re-aggregating rollups, distinct values cannot simply be added. Summing daily UVs does not give you monthly UV (users overlap). The fix is to store HLL sketches and merge them later.
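```python
# Spark 3.5+ (DataSketches-based): store a mergeable HLL sketch per day, not a final UV number
daily_sk = events.groupBy("date", "country").agg(F.hll_sketch_agg("user_id").alias("uv_sketch"))
# Monthly UV: union the daily sketches, then estimate (summing daily UVs would overcount)
monthly = daily_sk.groupBy(F.trunc("date", "month").alias("month"), "country").agg(F.hll_sketch_estimate(F.hll_union_agg("uv_sketch")).alias("uv"))
```

| Metric | Rollup re-aggregation |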
|---|---|
| count, sum | Just add (additive) |
| max, min | Combine with max/min |
| avg | Store sum/count separately, then compute |
| distinct | No simple sum → merge sketches |
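The table's rules can be checked on a toy rollup in plain Python. Each hypothetical daily row keeps additive fields (events, revenue), a max, and a set of user_ids standing in for a mergeable sketch; exact sets are used here only so the result can be compared with the truth, and at scale they would be replaced by HLL sketches. The two tempting shortcuts, summing daily UVs and averaging daily averages, both give wrong answers.

```python
from functools import reduce

# Hypothetical daily rollup rows; the "users" set stands in for a mergeable HLL sketch
daily_rows = [
    {"events": 5, "revenue": 50.0, "max_amount": 20.0, "users": {"a", "b", "c"}},
    {"events": 1, "revenue": 40.0, "max_amount": 40.0, "users": {"b"}},
    {"events": 4, "revenue": 10.0, "max_amount": 5.0, "users": {"c", "d"}},
]

monthly = {
    "events": sum(r["events"] for r in daily_rows),                 # count: just add
    "revenue": sum(r["revenue"] for r in daily_rows),               # sum: just add
    "max_amount": max(r["max_amount"] for r in daily_rows),         # max: combine with max
    "users": reduce(set.union, (r["users"] for r in daily_rows)),   # distinct: merge, never add
}
monthly["avg_amount"] = monthly["revenue"] / monthly["events"]  # avg from the stored sum and count
monthly["uv"] = len(monthly["users"])

# The two tempting shortcuts
wrong_uv = sum(len(r["users"]) for r in daily_rows)  # overlapping users are double-counted
wrong_avg = sum(r["revenue"] / r["events"] for r in daily_rows) / len(daily_rows)
print(monthly["uv"], wrong_uv, monthly["avg_amount"], wrong_avg)  # → 4 6 10.0 17.5
```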
5. Solution 4 — Avoiding collect_set Blowups
When building a "list of unique values per group" with collect_set, a group with many unique values produces a giant array in a single row — and an OOM.
```python
from pyspark.sql import functions as F

# Dangerous: millions of unique values per group packed into one array
df.groupBy("country").agg(F.collect_set("user_id"))

# If you only need the count, use a distinct count instead of collect_set
df.groupBy("country").agg(F.approx_count_distinct("user_id"))

# If you truly need the values, cap the list size or keep them decomposed:
# one row per (country, user_id) instead of one giant array per country
df.select("country", "user_id").distinct()
```

collect_list/collect_set build each group's entire array inside a single task and emit it as one row, so they are dangerous when a group has many values.
6. Diagnosis — Where Things Blow Up
| Symptom | Cause | Remedy |
|---|---|---|
| count distinct is slow / OOM | Exact aggregation at high cardinality | approx_count_distinct |
| Only certain groups are slow | Group skew | Salted two-phase / approx |
| Rollup UV is wrong | Distinct additivity error | Merge sketches |
| collect_set OOM | Giant arrays | Replace with distinct count |
| Same aggregation repeated | Rescanning raw data | Pre-aggregated rollup |
Check the Spark UI for task skew and spill in the final aggregation stage (see the separate post "Debugging Slow PySpark Jobs").
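Before reaching for salting, it helps to confirm which groups are actually hot. In Spark that is a cheap df.groupBy("country").count() sorted by count in descending order; the plain-Python sketch below applies the same idea to a list of group keys and flags groups much larger than the median. The `find_hot_groups` helper and its 10× threshold are hypothetical choices, not a Spark feature.

```python
from collections import Counter
from statistics import median


def find_hot_groups(keys, factor=10):
    """Return groups whose row count exceeds `factor` times the median group size."""
    sizes = Counter(keys)
    typical = median(sizes.values())
    return {key: n for key, n in sizes.items() if n > factor * typical}


# Toy data: one dominant country among 100 small ones
keys = ["US"] * 50_000 + [f"C{i}" for i in range(100) for _ in range(200)]
print(find_hot_groups(keys))  # → {'US': 50000}
```

Comparing against the median rather than the mean keeps one huge group from hiding itself by inflating the baseline.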
7. Summary
| Solution | When |
|---|---|
| approx_count_distinct | Error-tolerant distinct (UV, reach) |
| Two-phase salted aggregation | Group skew in additive aggregations |
| Pre-aggregated rollup | Repeatedly queried metrics |
| Sketch merging | Re-aggregating distinct in rollups |
| Avoiding collect_set | When only the count is needed |
The key insight for high-cardinality aggregation is to first ask whether you really need exactness. An exact distinct count is expensive, but for most analytics metrics a HyperLogLog approximation is good enough. Add two-phase processing for additive aggregations and pre-aggregated rollups, and even aggregations over hundreds of millions of unique values can run reliably with small, fixed memory per group. "count distinct brings the cluster down" is no longer something you just have to live with.
This post is based on Spark 3.5. If you need help designing large-scale metric aggregation and rollup pipelines, feel free to reach out.
— Data Dynamics Engineering Team