Spark Transformations
A transformation in Spark produces a new RDD or DataFrame from an existing one without immediately running any computation. Spark records the transformation in a lineage graph (DAG) and defers execution until an action is called. This lazy approach lets Spark optimize the entire chain of operations before a single byte of data moves.
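The deferred-execution idea can be sketched with a toy model in plain Python. This is not Spark's implementation: `LazyDataset` and `traced_square` are hypothetical names invented for illustration, and the "lineage" here is just a tuple of recorded steps that an action replays.

```python
class LazyDataset:
    """Toy stand-in for an RDD: transformations record steps, actions run them."""

    def __init__(self, source, steps=()):
        self._source = source
        self._steps = tuple(steps)  # the recorded lineage

    def map(self, fn):
        # Transformation: return a new dataset with one more recorded step
        return LazyDataset(self._source, self._steps + (("map", fn),))

    def filter(self, predicate):
        # Transformation: nothing is computed here either
        return LazyDataset(self._source, self._steps + (("filter", predicate),))

    def collect(self):
        """Action: replay the recorded steps over the source data."""
        records = iter(self._source)
        for kind, fn in self._steps:
            records = map(fn, records) if kind == "map" else filter(fn, records)
        return list(records)


calls = []  # records every value the mapped function actually sees


def traced_square(x):
    calls.append(x)
    return x * x


squares = LazyDataset([1, 2, 3, 4]).filter(lambda x: x % 2 == 0).map(traced_square)
calls_before_action = len(calls)  # 0: building the chain ran no user code
result = squares.collect()        # [4, 16]: the action triggers the work
```

Because the whole chain is known before anything runs, the filter is applied first and `traced_square` only ever sees the two surviving values.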
Narrow vs Wide Transformations
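The distinction between narrow and wide transformations is one of Spark’s most important architectural concepts, because it determines where shuffles and stage boundaries occur: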
| Property | Narrow | Wide |
|---|---|---|
| Data movement | Each input partition feeds at most one output partition | Each output partition may depend on many input partitions |
| Network shuffle | None | Yes, data is redistributed across executors |
| Stage boundary | No | Yes, a new stage starts after the shuffle |
| Examples | map, filter, flatMap, mapPartitions | groupByKey, reduceByKey, join, distinct, repartition |
A shuffle involves serialization, disk I/O, and network transfer, so minimizing wide transformations is usually the most effective way to improve performance.
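To make the table concrete, here is a minimal plain-Python sketch in which partitions are modeled as a list of lists. The helper names `narrow_map` and `wide_group_by_key` are hypothetical, and the `key % n` partitioner is a simplifying assumption for integer keys, not the partitioner Spark uses.

```python
from collections import defaultdict


def narrow_map(partitions, fn):
    """Narrow: output partition i depends only on input partition i."""
    return [[fn(record) for record in part] for part in partitions]


def wide_group_by_key(partitions, num_output_partitions):
    """Wide: records are routed by key, so they may change partition."""
    shuffled = [defaultdict(list) for _ in range(num_output_partitions)]
    moved = 0
    for src_index, part in enumerate(partitions):
        for key, value in part:
            dst_index = key % num_output_partitions  # toy partitioner (assumption)
            if dst_index != src_index:
                moved += 1  # in this model, a record that leaves its partition
            shuffled[dst_index][key].append(value)
    return [dict(p) for p in shuffled], moved


partitions = [[(1, "a"), (2, "b")], [(1, "c"), (4, "d")]]

# Narrow: every record stays in the partition it started in
upper = narrow_map(partitions, lambda kv: (kv[0], kv[1].upper()))

# Wide: both input partitions contribute to each output partition
grouped, moved = wide_group_by_key(partitions, num_output_partitions=2)
print(upper)    # [[(1, 'A'), (2, 'B')], [(1, 'C'), (4, 'D')]]
print(grouped)  # [{2: ['b'], 4: ['d']}, {1: ['a', 'c']}]
print(moved)    # 2
```

In the wide case all values for a key must end up together, which is why two of the four records leave their original partition. On a real cluster that movement is the shuffle.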
Element-Level Transformations (Narrow)
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("Transformations").getOrCreate()
sc = spark.sparkContext

# map: one element in, one element out
numbers = sc.parallelize([1, 2, 3, 4, 5])
squared = numbers.map(lambda x: x ** 2)  # [1, 4, 9, 16, 25]

# filter: keep elements matching a predicate
evens = numbers.filter(lambda x: x % 2 == 0)  # [2, 4]

# flatMap: one element in, zero or more out
sentences = sc.parallelize(["hello world", "apache spark is fast"])
words = sentences.flatMap(lambda s: s.split())
# ["hello", "world", "apache", "spark", "is", "fast"]

# mapPartitions: process an entire partition at once (more efficient than map for setup costs)
def process_partition(records):
    # Open DB connection once per partition, not per record
    for record in records:
        yield record.upper()

result = sc.parallelize(["a", "b", "c"]).mapPartitions(process_partition)

DataFrame Transformations
data = [("Alice", "Eng", 95000), ("Bob", "Mkt", 72000), ("Carol", "Eng", 110000)]
df = spark.createDataFrame(data, ["name", "dept", "salary"])

# select: column projection
df.select("name", "salary")
df.select(F.col("name"), (F.col("salary") * 1.1).alias("bumped_salary"))

# filter / where
df.filter(F.col("salary") > 80000)
df.where("dept = 'Eng'")
df.filter((F.col("dept") == "Eng") & (F.col("salary") > 90000))

# withColumn: add or replace a column
df.withColumn("is_senior", F.col("salary") > 100000)
df.withColumn("dept_upper", F.upper(F.col("dept")))

# drop: remove columns
df.drop("dept")

# distinct: remove duplicate rows (wide: requires a shuffle)
df.distinct()
df.dropDuplicates(["dept"])  # Deduplicate on specific columns

# orderBy / sort (wide: a global sort requires a shuffle)
df.orderBy(F.col("salary").desc())
df.sort("dept", F.col("salary").desc())

# limit
df.limit(100)

Aggregation Transformations (Wide)
# groupBy + aggregate
df.groupBy("dept").agg(
    F.count("*").alias("headcount"),
    F.avg("salary").alias("avg_salary"),
    F.max("salary").alias("max_salary"),
)

# pivot
quarterly = spark.createDataFrame([
    ("Q1", "Electronics", 50000),
    ("Q2", "Electronics", 60000),
    ("Q1", "Clothing", 30000),
    ("Q2", "Clothing", 35000),
], ["quarter", "category", "revenue"])

quarterly.groupBy("category").pivot("quarter", ["Q1", "Q2"]).sum("revenue").show()
# +-----------+-----+-----+
# |   category|   Q1|   Q2|
# +-----------+-----+-----+
# |Electronics|50000|60000|
# |   Clothing|30000|35000|
# +-----------+-----+-----+

Key-Value RDD Transformations
pairs = sc.parallelize([("apple", 3), ("banana", 5), ("apple", 2), ("cherry", 1)])

# reduceByKey: combines values per key within each partition before the shuffle (efficient)
pairs.reduceByKey(lambda a, b: a + b)
# [("apple", 5), ("banana", 5), ("cherry", 1)] (order may vary)

# groupByKey: collects all values per key (can cause OOM for large datasets)
pairs.groupByKey().mapValues(list)
# [("apple", [3, 2]), ("banana", [5]), ("cherry", [1])]

# aggregateByKey: custom combine logic within and between partitions
pairs.aggregateByKey(
    zeroValue=0,
    seqFunc=lambda acc, v: acc + v,  # Within a partition
    combFunc=lambda a, b: a + b,     # Between partitions
)

# sortByKey
pairs.sortByKey(ascending=True)

# keys / values
pairs.keys()    # ["apple", "banana", "apple", "cherry"]
pairs.values()  # [3, 5, 2, 1]

Transformation Chaining
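Spark handles long chains well. For DataFrames, the Catalyst optimizer and whole-stage code generation fuse adjacent narrow transformations into a single pass over each partition, and a new stage begins only at a wide transformation: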
# The three narrow transformations fuse into one stage; orderBy starts a new one
result = (
    df
    .filter(F.col("dept") == "Eng")            # Narrow
    .withColumn("tax", F.col("salary") * 0.3)  # Narrow
    .select("name", "salary", "tax")           # Narrow
    .orderBy("salary")                         # Wide (a global sort requires a shuffle)
    .limit(5)
)
result.show()
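The single-pass behavior of the chain above can be modeled with plain Python generators. This is only an analogy, under the assumption that one list stands in for one partition. Spark's real mechanism is generated code for DataFrames and iterator pipelining for RDDs, and the helper names `keep_eng` and `add_tax` are hypothetical.

```python
log = []  # order in which each step touches each row


def keep_eng(rows):
    # Narrow step: decides row by row, needs no other rows
    for row in rows:
        log.append(("filter", row["name"]))
        if row["dept"] == "Eng":
            yield row


def add_tax(rows):
    # Narrow step: derives a new field from the current row only
    for row in rows:
        log.append(("withColumn", row["name"]))
        yield {**row, "tax": row["salary"] * 0.3}


partition = [
    {"name": "Alice", "dept": "Eng", "salary": 95000},
    {"name": "Bob", "dept": "Mkt", "salary": 72000},
    {"name": "Carol", "dept": "Eng", "salary": 110000},
]

# The narrow steps are chained lazily, so each row flows through both of them
# before the next row is read: one pass, no intermediate list.
pipelined = add_tax(keep_eng(partition))

# The sort cannot emit its first row until it has seen every row,
# which mirrors why orderBy ends the pipelined stage.
top = sorted(pipelined, key=lambda r: r["salary"])[:5]
names = [row["name"] for row in top]  # ['Alice', 'Carol']
```

The log shows the interleaving: Alice is filtered and then taxed before Bob is even read, and Bob never reaches `add_tax` at all.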