Spark Persistence and Caching
When you use a DataFrame or RDD in more than one action, Spark by default recomputes it from its lineage every time, re-reading the source data and re-running all transformations. Caching stores the result in memory or on disk after the first computation, so later actions read the stored copy instead of recomputing it.
cache() vs persist()
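cache() is shorthand for persist() with the default storage level. For DataFrames in Spark 3.x that level is MEMORY_AND_DISK (exposed in PySpark as StorageLevel.MEMORY_AND_DISK_DESER); for RDDs, cache() uses MEMORY_ONLY instead: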
```python
from pyspark import StorageLevel
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Cache Demo").getOrCreate()
df = spark.read.parquet("large-table.parquet")

# cache(): uses the default DataFrame storage level (memory, spilling to disk)
df.cache()

# persist(): explicit storage level control. Choose ONE level per DataFrame;
# persisting an already-cached DataFrame again is ignored (Spark only logs a
# warning), so call unpersist() first if you need to change the level.
# df.persist(StorageLevel.MEMORY_ONLY)
# df.persist(StorageLevel.MEMORY_AND_DISK)
# df.persist(StorageLevel.DISK_ONLY)
# df.persist(StorageLevel.MEMORY_ONLY_2)  # Each cached partition stored on 2 nodes

# Both are LAZY: nothing is stored until an action runs a job
df.count()  # Computes every partition and caches it (show() alone would cache only the partitions it reads)
df.show()   # Reads from the cache, no recomputation
```
Storage Levels
| Storage Level | Memory | Disk | Replication | When to Use |
|---|---|---|---|---|
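| MEMORY_ONLY | ✅ | ❌ | 1× | Data fits in memory; fastest access (partitions that don't fit are recomputed when needed) |
| MEMORY_AND_DISK | ✅ | ✅ (spill) | 1× | Default for DataFrames; safe choice for most workloads |
| DISK_ONLY | ❌ | ✅ | 1× | Memory is tight, access is infrequent |
| MEMORY_ONLY_2 | ✅ | ❌ | 2× | Losing an executor should not force recomputation |
| MEMORY_AND_DISK_2 | ✅ | ✅ (spill) | 2× | Same as above, for data larger than memory |
| OFF_HEAP | Off-heap | ✅ (spill) | 1× | Reduce JVM GC pressure; requires spark.memory.offHeap.enabled and spark.memory.offHeap.size |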
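Spark renders a storage level as five positional flags, for example `StorageLevel(True, True, False, True, 1)`. To map such output (from a notebook or a log) back onto the columns of the table, a small decoder helps. This is a hypothetical helper, not part of PySpark: the names `StorageFlags`, `parse_storage_level` and `describe` are illustrative, and it assumes the field order `(useDisk, useMemory, useOffHeap, deserialized, replication)` used by PySpark's `StorageLevel` constructor.

```python
import re
from typing import NamedTuple


class StorageFlags(NamedTuple):
    # Field order follows PySpark's StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)
    use_disk: bool
    use_memory: bool
    use_off_heap: bool
    deserialized: bool
    replication: int


_PATTERN = re.compile(
    r"StorageLevel\((True|False), (True|False), (True|False), (True|False), (\d+)\)"
)


def parse_storage_level(text: str) -> StorageFlags:
    """Decode the repr of a StorageLevel into named flags (hypothetical helper)."""
    match = _PATTERN.fullmatch(text.strip())
    if match is None:
        raise ValueError(f"not a StorageLevel repr: {text!r}")
    *flags, replication = match.groups()
    return StorageFlags(*(f == "True" for f in flags), int(replication))


def describe(flags: StorageFlags) -> str:
    """Summarize the flags in the same terms as the storage-level table."""
    places = []
    if flags.use_memory:
        places.append("off-heap memory" if flags.use_off_heap else "memory")
    if flags.use_disk:
        places.append("disk")
    kind = "deserialized" if flags.deserialized else "serialized"
    return f"{' + '.join(places) or 'nowhere'}, {kind}, {flags.replication}x"


flags = parse_storage_level("StorageLevel(True, True, False, True, 1)")
print(describe(flags))  # memory + disk, deserialized, 1x
```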
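```python
# Check the storage level in use (here: after df.cache() on a DataFrame)
print(repr(df.storageLevel))
# StorageLevel(True, True, False, True, 1)
# i.e. (useDisk, useMemory, useOffHeap, deserialized, replication)
```
When to Cache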
Cache when a dataset is:
- Used more than once in the same job
- Expensive to recompute (reads from S3/GCS, complex joins)
- Fits in cluster memory (a partial spill to disk is acceptable; if most of it spills, the gain shrinks to reading local disk instead of recomputing)
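Whether a dataset "fits in cluster memory" can be estimated before caching. The sketch below assumes Spark's unified memory manager with its default settings: 300 MB reserved per executor heap, `spark.memory.fraction = 0.6` for the shared execution/storage pool, and `spark.memory.storageFraction = 0.5` as the part of that pool that execution cannot evict. The function name is hypothetical, it ignores off-heap memory and overhead, and the figure to compare against is the in-memory size of the cached data (shown in the Spark UI Storage tab after caching once), which is often larger than the Parquet files on disk.

```python
from typing import Tuple

RESERVED_MB = 300  # heap Spark sets aside before applying spark.memory.fraction


def cache_capacity_mb(num_executors: int,
                      executor_memory_mb: int,
                      memory_fraction: float = 0.6,
                      storage_fraction: float = 0.5) -> Tuple[float, float]:
    """Rough cluster-wide memory available for cached blocks (hypothetical helper).

    Returns (protected_mb, upper_bound_mb):
    - protected_mb: storage memory that execution cannot evict
    - upper_bound_mb: the whole unified pool, usable for caching only
      while execution is not using it
    """
    unified_per_executor = (executor_memory_mb - RESERVED_MB) * memory_fraction
    upper = unified_per_executor * num_executors
    protected = upper * storage_fraction
    return protected, upper


protected, upper = cache_capacity_mb(num_executors=10, executor_memory_mb=8192)
print(f"protected ~ {protected:,.0f} MB, upper bound ~ {upper:,.0f} MB")
```

For 10 executors with an 8 GB heap this gives roughly 23 GB of protected storage memory and about 46 GB at most, so a dataset that expands to much more than that will spill or be evicted.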
```python
from pyspark.sql import functions as F

# GOOD: cache before multiple uses
df_base = spark.read.parquet("s3://bucket/events/").filter(F.col("year") == 2025)
df_base.cache()
df_base.count()  # Trigger caching

report_a = df_base.groupBy("region").sum("revenue")
report_b = df_base.groupBy("category").count()
report_c = df_base.filter(F.col("amount") > 1000)

report_a.show()  # From cache
report_b.show()  # From cache
report_c.show()  # From cache

df_base.unpersist()  # Free memory when done
```
When NOT to Cache
```python
# BAD: caching a dataset used only once wastes memory and adds the cost of writing the cache
df.cache()
df.count()  # Only use

# BAD: caching before filtering stores rows you will never read
df_big.cache()
df_filtered = df_big.filter(...)

# GOOD: filter first, cache the smaller result
df_filtered = df_big.filter(...)
df_filtered.cache()
```
Unpersisting
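Cached blocks stay on the executors until they are released, evicted, or the application ends, so an exception halfway through a series of reports can leave them behind. One way to guarantee the release is a small context manager around the cache, materialize, and unpersist steps. This is a sketch, not a PySpark API: the name `cached` and its `storage_level` parameter are hypothetical, and it relies only on the DataFrame methods `cache()`, `persist()`, `count()` and `unpersist()`.

```python
from contextlib import contextmanager


@contextmanager
def cached(df, storage_level=None):
    """Hypothetical helper: keep df persisted for the with-block, then always unpersist it.

    storage_level: optional pyspark StorageLevel; None means df.cache().
    """
    if storage_level is None:
        df.cache()
    else:
        df.persist(storage_level)
    try:
        df.count()  # Materialize every partition up front
        yield df
    finally:
        df.unpersist()  # Runs even if the block (or count()) raises


# Usage with a real DataFrame:
# with cached(df_base) as base:
#     base.groupBy("region").sum("revenue").show()
#     base.groupBy("category").count().show()
```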
Always release cached data when no longer needed:
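```python
df.unpersist()                 # Non-blocking: may return before the blocks are removed
# df.unpersist(blocking=True)  # Alternative: wait until every block is removed

print(df.is_cached)  # False after unpersist()

# Clear everything cached through DataFrames and catalog tables at once
# (RDDs persisted with rdd.cache() are not affected)
spark.catalog.clearCache()
```
Caching Named Tables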
```python
df.createOrReplaceTempView("sales")
spark.catalog.cacheTable("sales")    # Lazy, like df.cache(); SQL "CACHE TABLE sales" caches eagerly
spark.catalog.isCached("sales")      # True
spark.catalog.uncacheTable("sales")
```
RDD Checkpointing
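For iterative algorithms where lineage grows very deep, checkpoint() writes the RDD to the checkpoint directory on reliable storage (for example HDFS or S3) and truncates its lineage, so later jobs and failure recovery no longer replay the whole chain. checkpoint() only marks the RDD: it must be called before any action has run on that RDD, and the data is written when the next action completes. Persisting the RDD first avoids computing it twice (once for the action, once for the checkpoint write):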
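```python
spark.sparkContext.setCheckpointDir("hdfs://namenode/spark-checkpoints/")

# rdd, transform and predicate stand for your own RDD and functions
previous = None
for i in range(100):
    rdd = rdd.flatMap(transform).filter(predicate)
    if i % 10 == 0:
        rdd.cache()       # Keep it so the checkpoint write reuses the computed data
        rdd.checkpoint()  # Truncate lineage (mark before any action on this RDD)
        rdd.count()       # Materialize: runs the job and writes the checkpoint
        if previous is not None:
            previous.unpersist()  # Older copy is no longer needed once lineage is cut
        previous = rdd
```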