Spark Partitions
A partition is the smallest unit of data distribution in Spark — a slice of your dataset that one task processes on one executor core. The number of partitions directly controls the degree of parallelism: too few and you waste cluster resources; too many and task scheduling overhead dominates.
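A back-of-the-envelope model shows what "too few" costs. The sketch below is plain Python, not a Spark API. It assumes every task takes the same time and that each core runs one task at a time, so a stage runs in "waves" of `total_cores` tasks. The function name `waves_and_idle_fraction` is hypothetical.

```python
import math


def waves_and_idle_fraction(num_partitions: int, total_cores: int) -> tuple:
    """Simplified model: one task per partition, equal task durations,
    one task per core at a time. Returns (waves, fraction of core-slots left idle)."""
    waves = math.ceil(num_partitions / total_cores)
    total_slots = waves * total_cores  # core-slots available across all waves
    idle_fraction = 1 - num_partitions / total_slots
    return waves, idle_fraction


# 4 partitions on 8 cores: one wave, half the cores never get work
print(waves_and_idle_fraction(4, 8))   # (1, 0.5)
# 10 partitions on 8 cores: the second wave keeps only 2 of 8 cores busy
print(waves_and_idle_fraction(10, 8))  # (2, 0.375)
# 80 partitions on 8 cores: 10 full waves, no idle slots
print(waves_and_idle_fraction(80, 8))  # (10, 0.0)
```

Real tasks vary in length, so this model understates idle time caused by stragglers, and it ignores the per-task scheduling overhead that makes very large partition counts expensive.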
How Spark Determines Partition Count
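For file sources the partition count is not simply "file size divided by 128 MB". The sketch below is a simplified, plain-Python model of the split-and-pack logic that Spark 3.x applies to splittable files (as we understand its `FilePartition` code): the split size is `min(maxPartitionBytes, max(openCostInBytes, totalBytes / defaultParallelism))`, files are cut into chunks of that size, and chunks are packed greedily, each one also charged `openCostInBytes`. It assumes the defaults of 128 MB for `spark.sql.files.maxPartitionBytes` and 4 MB for `spark.sql.files.openCostInBytes`, and it ignores `spark.sql.files.minPartitionNum`, bucketing, and non-splittable compression. The function name `estimate_file_partitions` is hypothetical.

```python
MiB = 1024 * 1024


def estimate_file_partitions(file_sizes, default_parallelism,
                             max_partition_bytes=128 * MiB,
                             open_cost_bytes=4 * MiB):
    """Approximate number of read partitions for a list of splittable files."""
    total_bytes = sum(size + open_cost_bytes for size in file_sizes)
    bytes_per_core = total_bytes // default_parallelism
    max_split = min(max_partition_bytes, max(open_cost_bytes, bytes_per_core))

    # Cut every file into chunks of at most max_split bytes, largest first
    chunks = []
    for size in file_sizes:
        for offset in range(0, size, max_split):
            chunks.append(min(max_split, size - offset))
    chunks.sort(reverse=True)

    # Greedily pack chunks; each chunk also "costs" open_cost_bytes
    partitions, current_size, current_count = 0, 0, 0
    for chunk in chunks:
        if current_count and current_size + chunk > max_split:
            partitions += 1
            current_size, current_count = 0, 0
        current_size += chunk + open_cost_bytes
        current_count += 1
    if current_count:
        partitions += 1
    return partitions


print(estimate_file_partitions([1024 * MiB], 8))  # 8: one 1 GiB file cut into 128 MiB chunks
print(estimate_file_partitions([MiB] * 100, 8))   # 8: 100 small files packed ~13 per partition
print(estimate_file_partitions([MiB] * 100, 2))   # 4: split size caps at 128 MiB, ~26 files each
```

The small-file case is the one that surprises people: 100 files do not become 100 partitions, because the open cost makes Spark pack them together.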
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("Partitions").getOrCreate()
sc = spark.sparkContext

# Default: sc.defaultParallelism (in cluster mode, the total number of cores across
# all executors, with a minimum of 2; in local mode, the number of local cores)
print(sc.defaultParallelism)  # e.g., 8 with 2 executors × 4 cores each

# From files: splits are sized by spark.sql.files.maxPartitionBytes (default 128 MB),
# and small files are packed together into shared partitions
df = spark.read.parquet("s3://bucket/large-table/")
print(df.rdd.getNumPartitions())  # roughly 1 partition per 128 MB of input

# After a shuffle (groupBy, join, distinct):
# spark.sql.shuffle.partitions (default: 200)
spark.conf.get("spark.sql.shuffle.partitions")  # "200"

# Adaptive Query Execution (Spark 3.x, on by default since 3.2) auto-coalesces shuffle partitions
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

Checking Partition Count
# RDD
rdd = sc.parallelize(range(100), numSlices=10)
print(rdd.getNumPartitions())  # 10

# DataFrame
df = spark.read.parquet("employees.parquet")
print(df.rdd.getNumPartitions())  # Depends on file size

# Rows per partition (runs a job that scans all the data)
df.rdd.mapPartitionsWithIndex(
    lambda i, it: [(i, sum(1 for _ in it))]
).collect()
# [(0, 1234), (1, 1198), (2, 1301), ...]  (partition index, row count)

repartition — Increase or Change Partitions
repartition causes a full shuffle — all data moves across the network. Use it to increase parallelism or distribute data evenly.
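The two flavors of repartition distribute rows differently. This plain-Python sketch models them: hashing a key modulo the partition count (as `repartition(n, col)` does) versus dealing rows out in turn (roughly what `repartition(n)` without columns does). It is an illustration only: `zlib.crc32` stands in for Spark's Murmur3 hash, Spark's round-robin starts each task at its own offset, and the function names are hypothetical.

```python
import zlib


def hash_partition(keys, num_partitions):
    """Assign each key to partition hash(key) % num_partitions.
    zlib.crc32 is a deterministic stand-in for Spark's Murmur3 hash."""
    partitions = [[] for _ in range(num_partitions)]
    for key in keys:
        partitions[zlib.crc32(key.encode("utf-8")) % num_partitions].append(key)
    return partitions


def round_robin_partition(rows, num_partitions):
    """Deal rows out one by one (simplified model of repartition(n) without columns)."""
    partitions = [[] for _ in range(num_partitions)]
    for i, row in enumerate(rows):
        partitions[i % num_partitions].append(row)
    return partitions


departments = ["Engineering"] * 6 + ["Marketing"] * 3 + ["HR"] * 1

by_key = hash_partition(departments, 4)
# Every department lands in exactly one partition, so sizes follow the key counts
for dept in sorted(set(departments)):
    homes = [i for i, part in enumerate(by_key) if dept in part]
    print(dept, "-> partition", homes)

even = round_robin_partition(departments, 4)
print([len(part) for part in even])  # [3, 3, 2, 2]
```

Hash partitioning is what co-locates keys for joins and aggregations, but it also means one large key makes one large partition; round-robin balances sizes but gives no key locality.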
# Repartition to a fixed count (round-robin, no key grouping)
df.repartition(200)

# Repartition by column (rows with the same key go to the same partition)
df.repartition(200, F.col("department"))

# Multiple partition columns
df.repartition(400, F.col("country"), F.col("year"))

# RDD repartition
rdd.repartition(50)
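# When to use repartition:
# - Before a costly operation when upstream partitions are badly unbalanced
#   (repartitioning by a skewed key does not fix key skew; see Handling Data Skew)
# - When writing output with a controlled number of files
# - When a read produced too few partitions for the cluster (e.g., a small input
#   packed into a few partitions, or non-splittable .gz files)

coalesce — Decrease Partitions (No Shuffle)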
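coalesce reduces the partition count by merging existing partitions without a shuffle; Spark prefers to combine partitions that already sit on the same executor, so little or no data crosses the network. It is usually cheaper than repartition for reducing partitions, with two caveats: the merged partitions can be uneven, and because no shuffle boundary is added, the lower parallelism also applies to the upstream transformations in the same stage.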
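Why coalesce can leave partitions uneven is easiest to see with row counts. The sketch below models the case where parent partitions have no locality preferences: as we understand Spark's default coalescer, output partition i then takes the contiguous parent range `[i*n//m, (i+1)*n//m)`. With locality information Spark groups by executor instead, so treat this as an illustration. For contrast it also models a round-robin repartition. Both function names are hypothetical.

```python
def coalesce_sizes(partition_sizes, target):
    """Model of coalesce without locality info: output partition i merges the
    contiguous parent range [i*n//target, (i+1)*n//target); sizes are summed."""
    n = len(partition_sizes)
    if target >= n:
        return list(partition_sizes)  # coalesce cannot increase the partition count
    return [
        sum(partition_sizes[i * n // target:(i + 1) * n // target])
        for i in range(target)
    ]


def repartition_sizes(partition_sizes, target):
    """Model of repartition(target) without columns: a full shuffle that deals
    rows out evenly, so sizes differ by at most one row."""
    total = sum(partition_sizes)
    return [total // target + (1 if i < total % target else 0) for i in range(target)]


# Row counts per partition after a selective filter
after_filter = [1000, 3, 0, 5, 2000, 1, 0, 7]

print(coalesce_sizes(after_filter, 3))     # [1003, 2005, 8]: no shuffle, but uneven
print(repartition_sizes(after_filter, 3))  # [1006, 1005, 1005]: even, but every row moves
```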
# Reduce to fewer partitions (no shuffle)
df.coalesce(4)

# Write a single output file; fine for small outputs, but the whole stage then runs
# in one task, so prefer repartition(1) when the upstream work is expensive
df.coalesce(1).write.parquet("output/")

# RDD coalesce
rdd.coalesce(2)
# When to use coalesce:
# - After filtering away most of the data (many nearly empty partitions remain)
# - Before writing, to reduce the number of output files
# - When you have far more partitions than executor cores

Partition Sizing Guidelines
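The sizing targets in this section combine into a simple rule of thumb. The sketch below uses this section's numbers (128 MB per partition, 3 tasks per core); rounding up to a multiple of the core count is an optional choice that keeps the last wave full, and the function name `recommend_partitions` is hypothetical. When tuning `spark.sql.shuffle.partitions`, pass in the size of the shuffled data rather than the input size.

```python
import math

MiB = 1024 * 1024
GiB = 1024 * MiB


def recommend_partitions(data_bytes, total_cores,
                         target_partition_bytes=128 * MiB, tasks_per_core=3):
    """Rule-of-thumb partition count: enough partitions to keep each near the
    target size, at least tasks_per_core per core, rounded up to a multiple of
    total_cores so the last wave is full."""
    by_size = math.ceil(data_bytes / target_partition_bytes)
    by_cores = total_cores * tasks_per_core
    count = max(by_size, by_cores)
    return math.ceil(count / total_cores) * total_cores


print(recommend_partitions(10 * GiB, 8))   # 80: size-driven (80 >= 24)
print(recommend_partitions(1 * GiB, 8))    # 24: core-driven (8 < 24)
print(recommend_partitions(10 * GiB, 12))  # 84: 80 rounded up to a multiple of 12
```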
# Target: each partition should be ~128 MB when read from storage
# Target: 2-4 tasks per executor core for good CPU utilization

# Example: 10 GB dataset, 8 cores, target 128 MB partitions
# → 10 GB / 128 MB ≈ 80 partitions
# → 8 cores × 3 tasks per core = at least 24 tasks (partitions) in total
# → 80 ≥ 24, so use 80 partitions; only 8 run at once, giving 10 waves
#   (a multiple of the core count avoids a partly idle last wave, but isn't required)
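# Set for your entire session
spark.conf.set("spark.sql.shuffle.partitions", "80")

# Or control a single aggregation by hash-partitioning on the grouping key first:
# Spark can reuse that partitioning, so the groupBy runs with 80 partitions and no
# second shuffle (a plain repartition(80) would not change the groupBy's shuffle count)
df.repartition(80, "category").groupBy("category").sum("revenue")

Handling Data Skew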
Skew means some partitions are much larger than others, so a few tasks run for hours while the rest finish in minutes.
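The idea behind salting can be checked without a cluster. This plain-Python model (not Spark code) sums `value` per key in two steps, first per (key, salt) and then per key, and measures the largest (key, salt) group; since a hash partitioner never splits a group, that size is a lower bound on the largest partition after the shuffle. To keep results deterministic it uses the row index modulo the salt count where Spark code would use `F.rand()`. All function names are hypothetical.

```python
from collections import Counter, defaultdict


def salt_of(row_index, num_salts):
    # Deterministic stand-in for a random salt (Spark code would use F.rand())
    return row_index % num_salts


def largest_group(rows, num_salts=1):
    """Size of the biggest (key, salt) group: a lower bound on the largest partition."""
    counts = Counter((key, salt_of(i, num_salts)) for i, (key, _) in enumerate(rows))
    return max(counts.values())


def aggregate_direct(rows):
    totals = defaultdict(int)
    for key, value in rows:
        totals[key] += value
    return dict(totals)


def aggregate_salted(rows, num_salts):
    # Step 1: partial sums per (key, salt), the work that gets spread across tasks
    partial = defaultdict(int)
    for i, (key, value) in enumerate(rows):
        partial[(key, salt_of(i, num_salts))] += value
    # Step 2: drop the salt and combine the partial sums per original key
    totals = defaultdict(int)
    for (key, _salt), subtotal in partial.items():
        totals[key] += subtotal
    return dict(totals)


# One hot key with 9,000 rows plus 100 ordinary keys with 10 rows each
rows = [("hot", 1)] * 9000 + [(f"k{j}", 1) for j in range(100) for _ in range(10)]

print(largest_group(rows))      # 9000: one task must handle every "hot" row
print(largest_group(rows, 10))  # 900: the hot key is split across 10 groups
print(aggregate_salted(rows, 10) == aggregate_direct(rows))  # True
```

The second step is cheap because it only combines at most `num_salts` partial sums per key.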
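# Detect skew: rows per partition, largest first
(
    df.rdd.mapPartitionsWithIndex(lambda i, it: [(i, sum(1 for _ in it))])
    .toDF(["partition", "row_count"])
    .orderBy(F.col("row_count").desc())
    .show(10)
)

# Partition counts reflect the current layout; to find hot keys before a shuffle, count per key
df.groupBy("skewed_key").count().orderBy(F.desc("count")).show(10)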
# Fix with salting (for groupBy/join skew)

# Add a random salt suffix (0-9) to spread each skewed key over 10 groups
df_salted = df.withColumn(
    "salted_key",
    F.concat(F.col("skewed_key"), F.lit("_"), (F.rand() * 10).cast("int").cast("string")),
)

# Aggregate in two steps: first by salted key, then strip the trailing "_<salt>"
# suffix and re-aggregate (a regex on the suffix keeps keys that contain "_" intact)
intermediate = df_salted.groupBy("salted_key").agg(F.sum("value").alias("partial_sum"))
final = (
    intermediate
    .withColumn("original_key", F.regexp_replace("salted_key", "_[0-9]+$", ""))
    .groupBy("original_key")
    .agg(F.sum("partial_sum").alias("total_value"))
)
# Fix with broadcast join (for join skew when one side is small)
from pyspark.sql.functions import broadcast
df_large.join(broadcast(df_small), "key")
# AQE automatic skew-join handling (Spark 3.x; on by default whenever AQE is enabled)
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

Custom Partitioner (RDD API)
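# Unlike Scala (where you subclass org.apache.spark.Partitioner and override
# numPartitions/getPartition), PySpark's rdd.partitionBy(numPartitions, partitionFunc)
# takes a plain function mapping a key to a partition index; Spark then uses
# partitionFunc(key) % numPartitions.
def make_department_partitioner(num_partitions):
    mapping = {"Engineering": 0, "Marketing": 1, "HR": 2}

    def get_partition(key):
        # Known departments get fixed partitions; any other key goes to the last one
        return mapping.get(key, num_partitions - 1)

    return get_partition


pairs = sc.parallelize([("Engineering", 1), ("Marketing", 2), ("HR", 3)])
partitioned = pairs.partitionBy(3, make_department_partitioner(3))
print(partitioned.glom().collect())
# [[('Engineering', 1)], [('Marketing', 2)], [('HR', 3)]]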
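Because PySpark's `partitionBy` reduces the function's result modulo the partition count, an index that is too large wraps around silently instead of raising an error. This plain-Python model of that bucketing step (not Spark code; `department_partition` and `bucket_by_partition` are hypothetical names) reuses the department mapping from this section and shows the wrap-around with only 2 partitions.

```python
from collections import defaultdict


def department_partition(key, num_partitions=3):
    """Hypothetical partition function: fixed slots for known departments,
    the last partition for everything else."""
    mapping = {"Engineering": 0, "Marketing": 1, "HR": 2}
    return mapping.get(key, num_partitions - 1)


def bucket_by_partition(pairs, num_partitions, partition_func):
    """Model of the bucketing step in RDD.partitionBy:
    partition index = partition_func(key) % num_partitions."""
    buckets = defaultdict(list)
    for key, value in pairs:
        buckets[partition_func(key) % num_partitions].append((key, value))
    return [buckets[i] for i in range(num_partitions)]


pairs = [("Engineering", 1), ("Marketing", 2), ("HR", 3), ("Sales", 4)]

print(bucket_by_partition(pairs, 3, department_partition))
# [[('Engineering', 1)], [('Marketing', 2)], [('HR', 3), ('Sales', 4)]]

# With only 2 partitions, HR's index 2 wraps to 2 % 2 == 0
print(bucket_by_partition(pairs, 2, lambda k: department_partition(k, 2)))
# [[('Engineering', 1), ('HR', 3)], [('Marketing', 2), ('Sales', 4)]]
```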