Spark Stages
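A stage is a set of tasks within a job that run the same chain of narrow transformations, one task per partition, in parallel and with no data exchanged between them. Spark creates a stage boundary wherever a shuffle is required. A stage cannot start until every task in its parent stage(s) has finished, because it reads their complete shuffle output.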
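A stage waits only for its own parent stages, so independent stages (for example, the two scans feeding a join) can run at the same time. The toy scheduler below sketches that rule in pure Python. It is not Spark's DAGScheduler: the stage names are hypothetical, and the model assumes each stage becomes ready once all of its parents are done and then runs alongside every other ready stage.

```python
from graphlib import TopologicalSorter

# Toy model: each (hypothetical) stage maps to the parent stages whose
# shuffle output it reads.
stage_parents = {
    "scan_orders": set(),
    "scan_customers": set(),
    "join": {"scan_orders", "scan_customers"},
    "final_agg": {"join"},
}


def schedule_waves(parents):
    """Group stages into waves: a stage is ready only after all of its
    parents have finished, and all ready stages run in the same wave."""
    sorter = TopologicalSorter(parents)
    sorter.prepare()
    waves = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())
        waves.append(ready)
        sorter.done(*ready)  # treat the whole wave as finished
    return waves


print(schedule_waves(stage_parents))
# [['scan_customers', 'scan_orders'], ['join'], ['final_agg']]
```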
What Creates a Stage Boundary
```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("Stage Demo").getOrCreate()
df = spark.read.parquet("sales.parquet")  # assume this yields 100 input partitions

# Stage 1: narrow transformations, 100 tasks running in parallel
stage1 = (
    df.filter(F.col("region") == "APAC")
      .withColumn("vat", F.col("amount") * 0.2)
)

# <-- SHUFFLE BOUNDARY (groupBy is a wide transformation) -->
# Stage 2: aggregate, 200 tasks (spark.sql.shuffle.partitions, default 200)
stage2 = stage1.groupBy("product_category").agg(F.sum("amount").alias("total"))

# <-- SHUFFLE BOUNDARY (global sort) -->
# Stage 3: sort + write
stage3 = stage2.orderBy(F.col("total").desc())
stage3.write.mode("overwrite").parquet("category_totals.parquet")  # hypothetical output path

# Conceptually 3 stages -> 100 + 200 + 200 = 500 tasks, assuming AQE is off.
# In practice AQE (on by default since Spark 3.2) may coalesce shuffle
# partitions, and the global sort first runs a small sampling job to pick
# range boundaries.
```
Two Types of Stages
| Type | Purpose | Output |
|---|---|---|
| ShuffleMapStage | Intermediate: produces shuffle data for a later stage | Shuffle files on the executors' local disks |
| ResultStage | Final: computes the job’s result | Action output (count, show, write) |
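The cutting rule from the first example can be sketched as a toy function in pure Python. It is a simplification, not Spark's planner: operations are hypothetical `(name, is_wide)` pairs, each wide operation opens a new stage with one task per shuffle partition, AQE is assumed off, and details such as partial aggregation before a shuffle are ignored. The last stage is labeled ResultStage and every earlier one ShuffleMapStage, matching the table above.

```python
from dataclasses import dataclass, field


@dataclass
class Stage:
    ops: list = field(default_factory=list)
    num_tasks: int = 0
    kind: str = "ShuffleMapStage"


def split_into_stages(ops, input_partitions, shuffle_partitions=200):
    """Toy model: cut a linear chain of (name, is_wide) operations into stages.

    Narrow operations are pipelined into the current stage; a wide operation
    starts a new stage with one task per shuffle partition.
    """
    stages = [Stage(num_tasks=input_partitions)]
    for name, is_wide in ops:
        if is_wide:
            stages.append(Stage(num_tasks=shuffle_partitions))
        stages[-1].ops.append(name)
    stages[-1].kind = "ResultStage"  # only the final stage produces the action's result
    return stages


# The sales pipeline from the first example, as hypothetical (name, is_wide) pairs.
pipeline = [
    ("read", False), ("filter", False), ("withColumn", False),
    ("groupBy/agg", True), ("orderBy", True), ("write", False),
]
stages = split_into_stages(pipeline, input_partitions=100)
for i, s in enumerate(stages, start=1):
    print(i, s.kind, s.num_tasks, s.ops)
print("total tasks:", sum(s.num_tasks for s in stages))  # total tasks: 500
```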
Stage Dependencies
Stages form their own DAG, and a stage can depend on several parent stages, as with the two sides of a join:
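```python
df_orders = spark.read.parquet("orders.parquet")        # Stage 1a: scan + shuffle write
df_customers = spark.read.parquet("customers.parquet")  # Stage 1b: runs concurrently with 1a

# Assuming neither table is under spark.sql.autoBroadcastJoinThreshold (10 MB by
# default), Spark picks a sort-merge join, so Stages 1a and 1b must both
# complete before Stage 2 (the join) can start.
df_joined = df_orders.join(df_customers, "customer_id")  # Stage 2
df_joined.groupBy("region").sum("amount").show()         # Stage 3
```
Skipped Stages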
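A stage is skipped when the data it would produce already exists. If a later job needs a shuffle whose output an earlier job already wrote, Spark reuses those shuffle files instead of re-running the stages behind them, and the UI lists those stages as skipped (greyed out in the DAG visualization). Caching works differently: a stage that reads a cached DataFrame still runs, but it reads partitions from executor memory instead of going back to the source.

```python
df.cache()
df.count()  # Job 1: scans Parquet and materializes the cache

# Job 2: the scan stage still runs, but it reads the cached partitions
# (marked with a green dot in the DAG visualization) instead of Parquet
df.groupBy("region").sum().show()
```

Cache markers in the DAG confirm that your caching strategy is working; skipped stages confirm that shuffle output is being reused.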
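Stage skipping can be modeled as a lookup of existing shuffle output. The sketch below is a toy in pure Python, not Spark's scheduler: shuffles are identified by a hypothetical integer ID, and it ignores lost or cleaned-up shuffle files (which real Spark would recompute). It replays two jobs over the same plan, as happens when an action is run twice on the same aggregated DataFrame.

```python
def run_job(stages, shuffle_outputs):
    """Toy model of stage skipping.

    `stages` lists (stage_name, shuffle_id) pairs in dependency order;
    shuffle_id is None for the final ResultStage. A ShuffleMapStage whose
    shuffle output is already registered is skipped instead of recomputed.
    """
    report = {}
    for name, shuffle_id in stages:
        if shuffle_id is not None and shuffle_id in shuffle_outputs:
            report[name] = "skipped"
        else:
            report[name] = "ran"
            if shuffle_id is not None:
                shuffle_outputs.add(shuffle_id)  # map output is now available
    return report


outputs = set()
plan = [("scan+partial_agg", 0), ("final_agg", None)]  # hypothetical shuffle id 0
job1 = run_job(plan, outputs)
job2 = run_job(plan, outputs)
print(job1)  # {'scan+partial_agg': 'ran', 'final_agg': 'ran'}
print(job2)  # {'scan+partial_agg': 'skipped', 'final_agg': 'ran'}
```

In PySpark, the analogous pattern is calling an action such as `collect()` twice on the same aggregated DataFrame; the second job typically lists its map stage as skipped.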
Stage Metrics in Spark UI
Open the Jobs tab and click a job to list its stages; click a stage to see its per-task summary metrics:
| Metric | What it Tells You |
|---|---|
| Duration | Wall-clock time for the stage |
| Input Size | Bytes read from storage |
| Shuffle Read | Bytes read from the previous stage's shuffle output (local and remote) |
| Shuffle Write | Bytes written to shuffle files |
| GC Time | Time spent in JVM garbage collection; a high value indicates memory pressure |
| Task Duration (max vs. median) | A large gap indicates data skew |
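The two warning signs in this table can be checked mechanically from per-task numbers. The helper below is a hypothetical sketch in pure Python: the 20% GC limit mirrors the rule of thumb used later in this section, while the skew factor of 5 is an assumed example threshold, not a Spark setting.

```python
from statistics import median


def diagnose_stage(task_durations_ms, gc_times_ms, skew_factor=5.0, gc_share_limit=0.2):
    """Flag skew and GC pressure for one stage (hypothetical thresholds)."""
    findings = []
    med = median(task_durations_ms)
    # Skew: the slowest task takes far longer than the typical (median) task.
    if med > 0 and max(task_durations_ms) / med >= skew_factor:
        findings.append("skew")
    # GC pressure: garbage collection eats a large share of total task time.
    total = sum(task_durations_ms)
    gc_share = sum(gc_times_ms) / total if total else 0.0
    if gc_share > gc_share_limit:
        findings.append("gc_pressure")
    return findings


# Hypothetical per-task metrics copied from a stage's task table.
durations = [1200, 1100, 1300, 1250, 9800]
gc = [100, 90, 120, 110, 3000]
print(diagnose_stage(durations, gc))  # ['skew', 'gc_pressure']
```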
Diagnosing Slow Stages
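```python
from pyspark.sql.functions import broadcast

# High shuffle read/write -> reduce shuffles.
# Broadcast the small side of a join so the large side is not shuffled:
joined = large_df.join(broadcast(small_df), "key")

# Task-time skew -> let AQE split skewed join partitions
# (AQE is on by default since Spark 3.2; skew-join handling also defaults to on)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# Too many small tasks -> fewer, larger shuffle partitions
spark.conf.set("spark.sql.shuffle.partitions", "50")  # default is 200

# GC time > 20% of task time -> memory pressure. spark.memory.* settings are
# static, so pass them when the application starts (spark.conf.set on them is
# rejected in Spark 3.x). Lowering spark.memory.fraction (default 0.6) leaves
# more heap for user objects; the values below are examples:
#   spark-submit --conf spark.memory.fraction=0.5 \
#                --conf spark.memory.storageFraction=0.3 app.py   # storageFraction default 0.5
```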
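When choosing a value for `spark.sql.shuffle.partitions`, one rough approach is to divide a stage's Shuffle Write size by a target partition size. The sketch below is a hypothetical heuristic, not a Spark API; the 128 MiB target is an assumption (for comparison, AQE's own `spark.sql.adaptive.advisoryPartitionSizeInBytes` defaults to 64 MB).

```python
import math


def suggest_shuffle_partitions(shuffle_bytes, target_partition_bytes=128 * 1024**2, minimum=1):
    """Hypothetical heuristic: choose a partition count so each shuffle
    partition holds roughly target_partition_bytes of data."""
    return max(minimum, math.ceil(shuffle_bytes / target_partition_bytes))


# e.g. a stage whose UI shows about 6 GiB of Shuffle Write
print(suggest_shuffle_partitions(6 * 1024**3))  # 48
```

The result would then be applied with `spark.conf.set("spark.sql.shuffle.partitions", "48")`. With AQE enabled, Spark can also coalesce small shuffle partitions on its own, so a fixed value like this matters most when AQE is off.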