Complete Databricks Learning Roadmap & Professional Guide
Table of Contents
- Introduction to Databricks
- Prerequisites & Fundamentals
- Databricks Architecture
- Getting Started with Databricks
- Apache Spark Fundamentals
- DataFrames & Datasets
- Databricks SQL
- Delta Lake
- Data Engineering on Databricks
- ETL/ELT Pipelines
- Machine Learning on Databricks
- Databricks Workflows & Jobs
- Security & Governance
- Performance Optimization
- Advanced Topics
- Real-World Projects
- Interview Preparation
1. Introduction to Databricks
What is Databricks?
Databricks is a unified analytics platform built on Apache Spark that provides data engineering, data science, and business analytics capabilities in one collaborative environment.
Key Features
- Unified Platform: Combines data warehousing, data lakes, and machine learning
- Collaborative Notebooks: Interactive workspace for data teams
- Auto-scaling Clusters: Automatically manages compute resources
- Delta Lake: ACID transactions for data lakes
- MLflow Integration: End-to-end machine learning lifecycle management
- Built on Apache Spark: Distributed processing engine
Use Cases
- Big data processing and analytics
- ETL/ELT pipelines
- Real-time streaming analytics
- Machine learning and AI
- Data warehousing
- Business intelligence
2. Prerequisites & Fundamentals
Required Knowledge
- Programming Languages
- Python (Primary)
- SQL (Essential)
- Scala (Optional but recommended)
- R (Optional)
- Database Concepts
- Relational databases
- SQL queries
- Data modeling
- Normalization
- Big Data Concepts
- Distributed computing
- Parallel processing
- CAP theorem
- Data partitioning
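To see partitioning in action, here is a minimal sketch using the spark session that every Databricks notebook provides:
# Spark splits a DataFrame into partitions that are processed in parallel
df = spark.range(0, 1_000_000)
print(df.rdd.getNumPartitions())           # current partition count
repartitioned = df.repartition(8)          # redistribute into 8 partitions (full shuffle)
print(repartitioned.rdd.getNumPartitions())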
Python Essentials for Databricks
# Basic Python constructs you'll use frequently
# 1. List Comprehensions
numbers = [x**2 for x in range(10)]
# 2. Lambda Functions
square = lambda x: x**2
# 3. Map, Filter, Reduce
from functools import reduce
mapped = list(map(lambda x: x*2, [1,2,3]))
filtered = list(filter(lambda x: x > 5, [1,5,10,15]))
sum_all = reduce(lambda x,y: x+y, [1,2,3,4,5])
# 4. Exception Handling
try:
result = 10 / 0
except ZeroDivisionError as e:
print(f"Error: {e}")
finally:
print("Cleanup operations")
SQL Essentials
-- Basic SQL patterns for Databricks
-- SELECT with WHERE clause
SELECT customer_id, order_date, total_amount
FROM orders
WHERE order_date >= '2024-01-01'
AND total_amount > 100;
-- Aggregations
SELECT
customer_id,
COUNT(*) as order_count,
SUM(total_amount) as total_spent,
AVG(total_amount) as avg_order_value
FROM orders
GROUP BY customer_id
HAVING COUNT(*) > 5;
-- Joins
SELECT
o.order_id,
c.customer_name,
o.total_amount
FROM orders o
INNER JOIN customers c ON o.customer_id = c.customer_id;
-- Window Functions
SELECT
order_id,
customer_id,
total_amount,
ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY order_date DESC) as rn
FROM orders;
3. Databricks Architecture
Platform Components
3.1 Control Plane
- Manages cluster lifecycle
- Notebook management
- Job scheduling
- Security and access control
3.2 Data Plane
- Runs Spark clusters
- Executes workloads
- Stores data (DBFS)
3.3 Key Architecture Elements
Workspace
- Collaborative environment
- Contains notebooks, libraries, experiments
- Organized into folders
Clusters
- Computing resources
- Types:
- All-Purpose Clusters (interactive)
- Job Clusters (automated workloads)
- SQL Warehouses (SQL analytics)
DBFS (Databricks File System)
- Distributed file system
- Abstraction over cloud storage
- Default paths:
- /FileStore/ – User files
- /databricks-datasets/ – Sample datasets
- /user/hive/warehouse/ – Managed tables
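A quick way to explore these locations from a notebook is the built-in dbutils file utilities:
# List the bundled sample datasets and your user files on DBFS
display(dbutils.fs.ls("/databricks-datasets/"))
display(dbutils.fs.ls("/FileStore/"))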
Unity Catalog
- Unified governance solution
- Fine-grained access control
- Data lineage tracking
- Centralized metadata management
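With Unity Catalog enabled, every table is addressed through a three-level namespace (catalog.schema.table). A minimal sketch, assuming a catalog named main and a hypothetical table my_table:
# Browse the metastore and read a table via the three-level namespace
spark.sql("SHOW CATALOGS").show()
spark.sql("SHOW SCHEMAS IN main").show()
df = spark.table("main.default.my_table")  # catalog.schema.table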
Architecture Diagram Concept
+-------------------------------------------------+
|            Control Plane (Databricks)           |
|  +-----------+  +-----------+  +-----------+    |
|  | Notebooks |  |   Jobs    |  | Security  |    |
|  +-----------+  +-----------+  +-----------+    |
+-------------------------------------------------+
                        |
+-------------------------------------------------+
|             Data Plane (Your Cloud)             |
|  +-------------------------------------------+  |
|  |              Spark Clusters               |  |
|  |  +--------+   +--------+   +--------+     |  |
|  |  | Driver |   | Worker |   | Worker |     |  |
|  |  +--------+   +--------+   +--------+     |  |
|  +-------------------------------------------+  |
|                                                 |
|  +-------------------------------------------+  |
|  |   DBFS / Cloud Storage (S3 / ADLS / GCS)  |  |
|  +-------------------------------------------+  |
+-------------------------------------------------+
4. Getting Started with Databricks
4.1 Workspace Setup
Step 1: Create Account
- Sign up for Databricks Community Edition (free tier)
- Or use enterprise cloud provider account
Step 2: Create Cluster
# Cluster configuration (done via UI, but here's what to set)
# - Databricks Runtime: 13.3 LTS or later
# - Worker Type: Standard_DS3_v2 (or equivalent)
# - Min Workers: 1
# - Max Workers: 3 (for auto-scaling)
# - Auto-termination: 120 minutes
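The same settings can also be submitted as a REST payload; here is a minimal sketch against the Clusters API (host, token, and node type are placeholders to adjust for your workspace):
import requests

# Sketch: create the cluster described above via the Clusters API (2.0)
host = "https://<databricks-instance>"
token = "<personal-access-token>"

cluster_spec = {
    "cluster_name": "learning-cluster",
    "spark_version": "13.3.x-scala2.12",
    "node_type_id": "Standard_DS3_v2",
    "autoscale": {"min_workers": 1, "max_workers": 3},
    "autotermination_minutes": 120,
}

resp = requests.post(
    f"{host}/api/2.0/clusters/create",
    headers={"Authorization": f"Bearer {token}"},
    json=cluster_spec,
)
print(resp.json())  # contains the new cluster_id on success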
Step 3: Create Notebook
- Click “Create” > “Notebook”
- Choose language: Python, Scala, SQL, or R
- Attach to cluster
4.2 Notebook Basics
Magic Commands
# Language switching
%python
print("Python code")
%sql
SELECT * FROM my_table LIMIT 10
%scala
val df = spark.read.parquet("/path/to/data")
%md
# This is markdown
## For documentation
%sh
ls -la /dbfs/
%fs
ls /FileStore/
%run ./other_notebook
# Runs another notebook
# Display output
display(df) # Shows DataFrame in table format
displayHTML("<h1>Custom HTML</h1>")
4.3 First PySpark Program
# Create SparkSession (already available as 'spark' in Databricks)
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, avg
# Create sample data
data = [
("Alice", 34, "Engineering"),
("Bob", 45, "Sales"),
("Charlie", 28, "Engineering"),
("David", 39, "HR"),
("Eve", 31, "Sales")
]
columns = ["name", "age", "department"]
# Create DataFrame
df = spark.createDataFrame(data, columns)
# Show data
df.show()
# Basic transformations
result = df.groupBy("department") \
.agg(
count("*").alias("employee_count"),
avg("age").alias("avg_age")
) \
.orderBy("employee_count", ascending=False)
display(result)
# Save as Delta table
result.write.format("delta").mode("overwrite").saveAsTable("department_stats")
5. Apache Spark Fundamentals
5.1 Spark Architecture
Key Components
- Driver: Orchestrates execution, maintains SparkSession
- Executors: Run tasks on worker nodes
- Cluster Manager: Allocates resources
Execution Model
- Driver creates logical plan
- Catalyst optimizer creates physical plan
- Tasks distributed to executors
- Results collected back to driver
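You can watch this happen by asking Spark for the plans it generates before any data is processed:
# Inspect the logical and physical plans Catalyst produces for a query
df = spark.range(0, 1000).filter("id % 2 = 0")
df.explain(True)  # parsed, analyzed, and optimized logical plans plus the physical plan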
5.2 RDD (Resilient Distributed Dataset)
# Creating RDDs (lower-level API, less common now)
rdd = spark.sparkContext.parallelize([1, 2, 3, 4, 5])
# Transformations (lazy)
squared_rdd = rdd.map(lambda x: x**2)
filtered_rdd = squared_rdd.filter(lambda x: x > 10)
# Actions (trigger execution)
result = filtered_rdd.collect()
print(result) # [16, 25]
# Common RDD operations
text_rdd = spark.sparkContext.textFile("/path/to/file.txt")
words = text_rdd.flatMap(lambda line: line.split(" "))
word_pairs = words.map(lambda word: (word, 1))
word_counts = word_pairs.reduceByKey(lambda a, b: a + b)
# Convert RDD to DataFrame (recommended)
df = word_counts.toDF(["word", "count"])
5.3 Transformations vs Actions
Transformations (Lazy)
- map(), filter(), flatMap()
- groupBy(), join(), union()
- select(), where(), distinct()
- Return a new RDD/DataFrame
Actions (Eager)
- collect(), count(), first()
- take(), reduce(), foreach()
- show(), write()
- Trigger execution
5.4 Lazy Evaluation
# Nothing executes yet
df1 = spark.read.parquet("/data/input")
df2 = df1.filter(col("age") > 25)
df3 = df2.select("name", "age")
# Execution happens here
df3.show() # Action triggers computation
# View execution plan
df3.explain(True)
6. DataFrames & Datasets
6.1 Creating DataFrames
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
# Method 1: From list
data = [("Alice", 25, 50000), ("Bob", 30, 60000)]
df = spark.createDataFrame(data, ["name", "age", "salary"])
# Method 2: With explicit schema
schema = StructType([
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
StructField("salary", DoubleType(), True)
])
df = spark.createDataFrame(data, schema)
# Method 3: From file (most common)
# CSV
df_csv = spark.read.csv("/path/to/file.csv", header=True, inferSchema=True)
# JSON
df_json = spark.read.json("/path/to/file.json")
# Parquet (recommended for performance)
df_parquet = spark.read.parquet("/path/to/file.parquet")
# Delta (Databricks optimized)
df_delta = spark.read.format("delta").load("/path/to/delta/table")
# Method 4: From SQL table
df_table = spark.table("database.table_name")
# Method 5: From SQL query
df_query = spark.sql("SELECT * FROM my_table WHERE age > 25")
6.2 DataFrame Operations
from pyspark.sql.functions import *
# Sample DataFrame
data = [
("Alice", 34, "Engineering", 95000, "2020-01-15"),
("Bob", 45, "Sales", 85000, "2019-03-20"),
("Charlie", 28, "Engineering", 78000, "2021-06-10"),
("David", 39, "HR", 72000, "2018-11-05"),
("Eve", 31, "Sales", 88000, "2020-09-18")
]
df = spark.createDataFrame(data, ["name", "age", "dept", "salary", "hire_date"])
# Selection
df.select("name", "dept").show()
df.select(col("name"), col("salary") * 1.1).show()
# Filtering
df.filter(col("age") > 30).show()
df.where((col("dept") == "Engineering") & (col("salary") > 80000)).show()
# Adding columns
df_with_bonus = df.withColumn("bonus", col("salary") * 0.1)
df_with_grade = df.withColumn("grade",
when(col("salary") > 90000, "A")
.when(col("salary") > 80000, "B")
.otherwise("C")
)
# Renaming columns
df_renamed = df.withColumnRenamed("dept", "department")
# Dropping columns
df_dropped = df.drop("hire_date")
# Sorting
df.orderBy("salary", ascending=False).show()
df.sort(col("age").desc(), col("name")).show()
# Aggregations
df.groupBy("dept").agg(
count("*").alias("employee_count"),
avg("salary").alias("avg_salary"),
max("salary").alias("max_salary"),
min("age").alias("min_age")
).show()
# Window functions
from pyspark.sql.window import Window
window_spec = Window.partitionBy("dept").orderBy(col("salary").desc())
df_ranked = df.withColumn("rank", row_number().over(window_spec))
df_ranked.show()
# Distinct
df.select("dept").distinct().show()
# Joins
employees = spark.createDataFrame([
(1, "Alice", 10),
(2, "Bob", 20),
(3, "Charlie", 10)
], ["emp_id", "name", "dept_id"])
departments = spark.createDataFrame([
(10, "Engineering"),
(20, "Sales"),
(30, "HR")
], ["dept_id", "dept_name"])
# Inner join
joined = employees.join(departments, "dept_id", "inner")
# Left join
left_joined = employees.join(departments, "dept_id", "left")
# Complex join conditions
complex_join = employees.join(
departments,
(employees.dept_id == departments.dept_id) & (employees.name.startswith("A")),
"inner"
)
6.3 Working with Different Data Types
from pyspark.sql.functions import *
from pyspark.sql.types import *
# String operations
df_strings = spark.createDataFrame([
("john doe",), ("JANE SMITH",), (" bob ",)
], ["name"])
df_strings.select(
col("name"),
upper("name").alias("upper"),
lower("name").alias("lower"),
initcap("name").alias("title_case"),
trim("name").alias("trimmed"),
length("name").alias("length"),
substring("name", 1, 4).alias("first_4"),
concat(col("name"), lit(" - Employee")).alias("formatted")
).show(truncate=False)
# Date operations
df_dates = spark.createDataFrame([
("2024-01-15",), ("2024-06-20",), ("2023-12-01",)
], ["date_str"])
df_dates.select(
to_date("date_str").alias("date"),
current_date().alias("today"),
datediff(current_date(), to_date("date_str")).alias("days_ago"),
date_add(to_date("date_str"), 30).alias("plus_30_days"),
year(to_date("date_str")).alias("year"),
month(to_date("date_str")).alias("month"),
dayofweek(to_date("date_str")).alias("day_of_week")
).show()
# Timestamp operations
df_with_ts = spark.sql("SELECT current_timestamp() as ts")
df_with_ts.select(
col("ts"),
date_format("ts", "yyyy-MM-dd HH:mm:ss").alias("formatted"),
hour("ts").alias("hour"),
minute("ts").alias("minute")
).show(truncate=False)
# Array operations
df_arrays = spark.createDataFrame([
(1, ["a", "b", "c"]),
(2, ["x", "y"]),
(3, ["p", "q", "r", "s"])
], ["id", "letters"])
df_arrays.select(
col("id"),
col("letters"),
size("letters").alias("array_size"),
array_contains("letters", "b").alias("contains_b"),
explode("letters").alias("letter") # Flattens array
).show()
# JSON operations
df_json = spark.createDataFrame([
(1, '{"name": "Alice", "age": 30}'),
(2, '{"name": "Bob", "age": 25}')
], ["id", "json_str"])
df_json.select(
col("id"),
get_json_object("json_str", "$.name").alias("name"),
get_json_object("json_str", "$.age").alias("age")
).show()
# Null handling
df_nulls = spark.createDataFrame([
("Alice", 30, 50000),
("Bob", None, 60000),
("Charlie", 28, None)
], ["name", "age", "salary"])
# Fill nulls
df_filled = df_nulls.fillna({"age": 0, "salary": 0})
# Drop nulls
df_dropped_nulls = df_nulls.dropna() # Drop rows with any null
df_dropped_all = df_nulls.dropna(how="all") # Drop only if all values are null
# Replace specific values (note: replace() cannot match nulls; use fillna for those)
df_replaced = df_nulls.replace(30, 31, subset=["age"])
# Coalesce (return first non-null value)
df_coalesced = df_nulls.withColumn("salary_clean", coalesce(col("salary"), lit(0)))
6.4 UDFs (User Defined Functions)
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, IntegerType, DoubleType
# Simple UDF
def categorize_age(age):
if age < 25:
return "Young"
elif age < 40:
return "Middle"
else:
return "Senior"
categorize_udf = udf(categorize_age, StringType())
df_with_category = df.withColumn("age_category", categorize_udf(col("age")))
# Pandas UDF (vectorized, much faster)
from pyspark.sql.functions import pandas_udf
import pandas as pd
@pandas_udf(StringType())
def categorize_age_pandas(age: pd.Series) -> pd.Series:
return age.apply(lambda x:
"Young" if x < 25 else "Middle" if x < 40 else "Senior"
)
df_with_category_pandas = df.withColumn("age_category", categorize_age_pandas(col("age")))
# Complex UDF with multiple inputs
@pandas_udf(DoubleType())
def calculate_bonus(salary: pd.Series, age: pd.Series) -> pd.Series:
    # Compute the rate per row: 10% for employees over 35, otherwise 5%
    rate = age.apply(lambda a: 0.1 if a > 35 else 0.05)
    return salary * rate
df_with_bonus = df.withColumn("bonus", calculate_bonus(col("salary"), col("age")))
7. Databricks SQL
7.1 SQL Analytics
-- Create database
CREATE DATABASE IF NOT EXISTS sales_db;
USE sales_db;
-- Create managed table
CREATE TABLE customers (
customer_id INT,
customer_name STRING,
email STRING,
registration_date DATE,
country STRING
) USING DELTA;
-- Create external table
CREATE TABLE IF NOT EXISTS orders_external (
order_id INT,
customer_id INT,
order_date DATE,
amount DOUBLE
)
USING DELTA
LOCATION '/mnt/data/orders/';
-- Insert data
INSERT INTO customers VALUES
(1, 'Alice Johnson', 'alice@example.com', '2023-01-15', 'USA'),
(2, 'Bob Smith', 'bob@example.com', '2023-02-20', 'UK'),
(3, 'Charlie Brown', 'charlie@example.com', '2023-03-10', 'Canada');
-- Query with CTEs
WITH monthly_sales AS (
SELECT
DATE_TRUNC('month', order_date) as month,
SUM(amount) as total_sales,
COUNT(*) as order_count
FROM orders
GROUP BY DATE_TRUNC('month', order_date)
),
customer_stats AS (
SELECT
c.country,
COUNT(DISTINCT c.customer_id) as customer_count,
AVG(o.amount) as avg_order_value
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.country
)
SELECT
cs.country,
cs.customer_count,
cs.avg_order_value,
ms.total_sales
FROM customer_stats cs
CROSS JOIN monthly_sales ms
WHERE ms.month = '2024-01-01';
-- Window functions
SELECT
customer_id,
order_date,
amount,
SUM(amount) OVER (PARTITION BY customer_id ORDER BY order_date) as running_total,
AVG(amount) OVER (PARTITION BY customer_id ORDER BY order_date ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) as moving_avg,
RANK() OVER (PARTITION BY DATE_TRUNC('month', order_date) ORDER BY amount DESC) as monthly_rank
FROM orders;
-- Pivot operations
SELECT * FROM (
SELECT country, YEAR(order_date) as year, amount
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
)
PIVOT (
SUM(amount) as total_sales
FOR year IN (2022, 2023, 2024)
);
7.2 Advanced SQL Features
-- MERGE (Upsert)
MERGE INTO customers AS target
USING updates AS source
ON target.customer_id = source.customer_id
WHEN MATCHED THEN
UPDATE SET target.email = source.email
WHEN NOT MATCHED THEN
INSERT (customer_id, customer_name, email, registration_date, country)
VALUES (source.customer_id, source.customer_name, source.email, source.registration_date, source.country);
-- Time travel (Delta Lake feature)
SELECT * FROM customers VERSION AS OF 5;
SELECT * FROM customers TIMESTAMP AS OF '2024-01-01';
-- Clone tables
CREATE TABLE customers_backup SHALLOW CLONE customers;
CREATE TABLE customers_dev DEEP CLONE customers;
-- Optimize tables
OPTIMIZE customers ZORDER BY (country, registration_date);
-- Vacuum (clean up old files)
VACUUM customers RETAIN 168 HOURS;
-- Describe table
DESCRIBE EXTENDED customers;
DESCRIBE HISTORY customers;
-- Show table properties
SHOW TBLPROPERTIES customers;
-- Constraints
ALTER TABLE customers ADD CONSTRAINT valid_email CHECK (email LIKE '%@%');
ALTER TABLE orders ADD CONSTRAINT positive_amount CHECK (amount > 0);
8. Delta Lake
8.1 Introduction to Delta Lake
Delta Lake is an open-source storage layer that brings ACID transactions to data lakes. It provides:
- ACID transactions
- Scalable metadata handling
- Time travel (data versioning)
- Schema enforcement and evolution
- Audit history
- Unified batch and streaming
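Under the hood, every Delta table keeps a _delta_log directory of JSON commit files; this log is what enables ACID transactions and time travel. A quick way to peek at it (using the table path from the examples below):
# Each committed transaction adds a JSON file to the Delta transaction log
display(dbutils.fs.ls("/mnt/delta/people/_delta_log/"))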
8.2 Creating Delta Tables
from delta.tables import DeltaTable
# Create Delta table from DataFrame
df = spark.createDataFrame([
(1, "Alice", 30),
(2, "Bob", 25),
(3, "Charlie", 35)
], ["id", "name", "age"])
# Save as Delta
df.write.format("delta").mode("overwrite").save("/mnt/delta/people")
# Or save as managed table
df.write.format("delta").mode("overwrite").saveAsTable("people")
# Read Delta table
delta_df = spark.read.format("delta").load("/mnt/delta/people")
# Or read managed table
delta_df = spark.table("people")
8.3 ACID Transactions
# Append data
new_data = spark.createDataFrame([
(4, "David", 28),
(5, "Eve", 32)
], ["id", "name", "age"])
new_data.write.format("delta").mode("append").save("/mnt/delta/people")
# Overwrite with condition
(df.write
.format("delta")
.mode("overwrite")
.option("replaceWhere", "age > 30")
.save("/mnt/delta/people"))
# Upsert using MERGE
from delta.tables import DeltaTable
delta_table = DeltaTable.forPath(spark, "/mnt/delta/people")
updates = spark.createDataFrame([
(2, "Bob Smith", 26), # Update
(6, "Frank", 40) # Insert
], ["id", "name", "age"])
(delta_table.alias("target")
.merge(updates.alias("source"), "target.id = source.id")
.whenMatchedUpdateAll()
.whenNotMatchedInsertAll()
.execute())
# Delete with condition
delta_table.delete("age < 25")
# Update
delta_table.update(
condition = "name = 'Alice'",
set = {"age": "31"}
)
8.4 Time Travel
# Read version
df_v0 = spark.read.format("delta").option("versionAsOf", 0).load("/mnt/delta/people")
# Read by timestamp
df_ts = spark.read.format("delta").option("timestampAsOf", "2024-01-01").load("/mnt/delta/people")
# View history
delta_table = DeltaTable.forPath(spark, "/mnt/delta/people")
delta_table.history().show()
# Restore to previous version
delta_table.restoreToVersion(2)
# Or restore to timestamp
delta_table.restoreToTimestamp("2024-01-01")
8.5 Schema Evolution
# Enable schema evolution
df_new_schema = spark.createDataFrame([
(7, "Grace", 29, "Engineering")
], ["id", "name", "age", "department"])
(df_new_schema.write
.format("delta")
.mode("append")
.option("mergeSchema", "true")
.save("/mnt/delta/people"))
# Schema enforcement (will fail if schema doesn't match)
df_wrong_schema = spark.createDataFrame([
(8, "Henry")
], ["id", "name"])
# This will fail without mergeSchema option
df_wrong_schema.write.format("delta").mode("append").save("/mnt/delta/people")
8.6 Optimization
# Optimize (compaction)
delta_table = DeltaTable.forPath(spark, "/mnt/delta/people")
delta_table.optimize().executeCompaction()
# Z-Ordering (multi-dimensional clustering)
delta_table.optimize().executeZOrderBy("age", "department")
# Vacuum (delete old files)
delta_table.vacuum(168) # Retain 7 days (168 hours)
# Auto-optimize (set on table creation)
spark.sql("""
CREATE TABLE optimized_table (id INT, name STRING)
USING DELTA
TBLPROPERTIES (
'delta.autoOptimize.optimizeWrite' = 'true',
'delta.autoOptimize.autoCompact' = 'true'
)
""")
8.7 Change Data Feed
# Enable Change Data Feed
spark.sql("""
ALTER TABLE people
SET TBLPROPERTIES (delta.enableChangeDataFeed = true)
""")
# Read changes
changes = spark.read.format("delta") \
.option("readChangeFeed", "true") \
.option("startingVersion", 0) \
.table("people")
changes.show()
# Changes have additional columns: _change_type, _commit_version, _commit_timestamp
9. Data Engineering on Databricks
9.1 Medallion Architecture
The Medallion Architecture organizes data into three layers:
Bronze Layer (Raw)
- Ingests raw data as-is
- Minimal transformations
- Preserves history
Silver Layer (Cleaned)
- Validated and cleansed data
- Standardized formats
- Business-level transformations
Gold Layer (Aggregated)
- Business-level aggregates
- Optimized for analytics
- Feature stores for ML
9.2 Bronze Layer Implementation
# Ingest raw JSON files into Bronze layer
from pyspark.sql.functions import current_timestamp, input_file_name
# Read raw data
bronze_df = (spark.readStream
.format("cloudFiles")
.option("cloudFiles.format", "json")
.option("cloudFiles.schemaLocation", "/mnt/bronze/schema")
.load("/mnt/raw/sales/"))
# Add metadata columns
bronze_enriched = (bronze_df
.withColumn("ingestion_timestamp", current_timestamp())
.withColumn("source_file", input_file_name()))
# Write to Bronze Delta table (streaming)
(bronze_enriched.writeStream
.format("delta")
.option("checkpointLocation", "/mnt/bronze/checkpoint")
.trigger(availableNow=True)
.table("bronze_sales"))
9.3 Silver Layer Implementation
from pyspark.sql.functions import col, to_date, regexp_replace, when
# Read from Bronze
bronze_df = spark.readStream.table("bronze_sales")
# Data quality checks and transformations
silver_df = (bronze_df
# Clean data
.withColumn("order_date", to_date(col("order_date_str"), "yyyy-MM-dd"))
.withColumn("amount", col("amount").cast("double"))
.withColumn("email", regexp_replace(col("email"), r"\s+", ""))
# Data quality flags
.withColumn("is_valid_amount", col("amount") > 0)
.withColumn("is_valid_email", col("email").rlike(r"^[\w\.-]+@[\w\.-]+\.\w+$"))
# Filter out invalid records
.filter(col("is_valid_amount") & col("is_valid_email"))
# Drop quality flag columns
.drop("is_valid_amount", "is_valid_email", "order_date_str")
# Add processing timestamp
.withColumn("processed_timestamp", current_timestamp()))
# Write to Silver table
(silver_df.writeStream
.format("delta")
.option("checkpointLocation", "/mnt/silver/checkpoint")
.outputMode("append")
.table("silver_sales"))
9.4 Gold Layer Implementation
from pyspark.sql.functions import sum, avg, count, window
# Read from Silver
silver_df = spark.readStream.table("silver_sales")
# Create business-level aggregations
gold_daily_summary = (silver_df
.groupBy(
window("order_date", "1 day"),
"product_category"
)
.agg(
sum("amount").alias("total_sales"),
avg("amount").alias("avg_order_value"),
count("*").alias("order_count")
)
.select(
col("window.start").alias("date"),
"product_category",
"total_sales",
"avg_order_value",
"order_count"
))
# Write to Gold table
(gold_daily_summary.writeStream
.format("delta")
.option("checkpointLocation", "/mnt/gold/checkpoint")
.outputMode("complete")
.table("gold_daily_sales_summary"))
9.5 Incremental Processing Patterns
# Pattern 1: Watermarking for late data
from pyspark.sql.functions import col, window
stream_df = (spark.readStream
.format("delta")
.table("bronze_events")
.withWatermark("event_timestamp", "1 hour") # Allow 1 hour late data
.groupBy(
window("event_timestamp", "10 minutes"),
"user_id"
)
.count())
# Pattern 2: Incremental batch processing
def process_incremental_batch(start_version, end_version):
changes = (spark.read
.format("delta")
.option("readChangeFeed", "true")
.option("startingVersion", start_version)
.option("endingVersion", end_version)
.table("source_table"))
# Process only changed records
processed = changes.filter(col("_change_type").isin(["insert", "update"]))
# Write to target
(processed.write
.format("delta")
.mode("append")
.saveAsTable("target_table"))
# Pattern 3: Merge for SCD Type 2
from delta.tables import DeltaTable
target = DeltaTable.forName(spark, "dim_customers")
updates = spark.table("customer_updates")
(target.alias("target").merge(
updates.alias("updates"),
"target.customer_id = updates.customer_id AND target.is_current = true"
)
.whenMatchedUpdate(
condition = "target.email != updates.email OR target.address != updates.address",
set = {
"is_current": "false",
"end_date": "current_date()"
}
)
.whenNotMatchedInsert(
values = {
"customer_id": "updates.customer_id",
"email": "updates.email",
"address": "updates.address",
"is_current": "true",
"start_date": "current_date()",
"end_date": "null"
}
).execute())
10. ETL/ELT Pipelines
10.1 Complete ETL Pipeline Example
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
from delta.tables import DeltaTable
from typing import List, Dict
class ETLPipeline:
def __init__(self, spark_session):
self.spark = spark_session
def extract_from_source(self, source_path: str, file_format: str = "csv") -> DataFrame:
"""Extract data from various sources"""
if file_format == "csv":
return (self.spark.read
.option("header", "true")
.option("inferSchema", "true")
.csv(source_path))
elif file_format == "json":
return self.spark.read.json(source_path)
elif file_format == "parquet":
return self.spark.read.parquet(source_path)
elif file_format == "delta":
return self.spark.read.format("delta").load(source_path)
else:
raise ValueError(f"Unsupported format: {file_format}")
def transform_data(self, df: DataFrame, transformations: List[Dict]) -> DataFrame:
"""Apply series of transformations"""
result_df = df
for transform in transformations:
transform_type = transform.get("type")
if transform_type == "filter":
result_df = result_df.filter(transform["condition"])
elif transform_type == "select":
result_df = result_df.select(*transform["columns"])
elif transform_type == "add_column":
result_df = result_df.withColumn(
transform["name"],
expr(transform["expression"])
)
elif transform_type == "rename":
for old_name, new_name in transform["mapping"].items():
result_df = result_df.withColumnRenamed(old_name, new_name)
elif transform_type == "drop_duplicates":
subset = transform.get("subset", None)
result_df = result_df.dropDuplicates(subset)
elif transform_type == "join":
other_df = transform["dataframe"]
result_df = result_df.join(
other_df,
transform["on"],
transform.get("how", "inner")
)
elif transform_type == "aggregate":
result_df = result_df.groupBy(*transform["group_by"]).agg(
*[expr(agg) for agg in transform["aggregations"]]
)
return result_df
def apply_data_quality_checks(self, df: DataFrame, rules: List[Dict]) -> DataFrame:
"""Apply data quality rules"""
quality_df = df
for rule in rules:
rule_name = rule["name"]
condition = rule["condition"]
action = rule.get("action", "flag")
if action == "flag":
quality_df = quality_df.withColumn(
f"dq_{rule_name}",
when(expr(condition), lit(True)).otherwise(lit(False))
)
elif action == "filter":
quality_df = quality_df.filter(condition)
elif action == "quarantine":
# Separate bad records
bad_records = quality_df.filter(f"NOT ({condition})")
quality_df = quality_df.filter(condition)
# Save quarantined records
(bad_records.write
.format("delta")
.mode("append")
.saveAsTable(f"quarantine_{rule_name}"))
return quality_df
def load_to_target(
self,
df: DataFrame,
target_path: str,
mode: str = "append",
partition_by: List[str] = None,
optimize: bool = True
):
"""Load data to target location"""
writer = df.write.format("delta").mode(mode)
if partition_by:
writer = writer.partitionBy(*partition_by)
writer.save(target_path)
if optimize:
DeltaTable.forPath(self.spark, target_path).optimize().executeCompaction()
def run_pipeline(self, config: Dict):
"""Execute complete pipeline"""
# Extract
source_df = self.extract_from_source(
config["source"]["path"],
config["source"]["format"]
)
# Transform
transformed_df = self.transform_data(
source_df,
config["transformations"]
)
# Data Quality
quality_df = self.apply_data_quality_checks(
transformed_df,
config.get("quality_rules", [])
)
# Load
self.load_to_target(
quality_df,
config["target"]["path"],
config["target"].get("mode", "append"),
config["target"].get("partition_by"),
config["target"].get("optimize", True)
)
return quality_df
# Example usage
pipeline = ETLPipeline(spark)
config = {
"source": {
"path": "/mnt/raw/sales/",
"format": "csv"
},
"transformations": [
{
"type": "filter",
"condition": "amount > 0"
},
{
"type": "add_column",
"name": "year",
"expression": "year(order_date)"
},
{
"type": "add_column",
"name": "total_with_tax",
"expression": "amount * 1.08"
}
],
"quality_rules": [
{
"name": "valid_email",
"condition": "email RLIKE '^[\\\\w\\\\.-]+@[\\\\w\\\\.-]+\\\\.\\\\w+",
"action": "flag"
},
{
"name": "positive_amount",
"condition": "amount > 0",
"action": "filter"
}
],
"target": {
"path": "/mnt/processed/sales",
"mode": "overwrite",
"partition_by": ["year"],
"optimize": True
}
}
result = pipeline.run_pipeline(config)
10.2 Orchestrating with Databricks Workflows
# Task 1: Data Ingestion (Notebook 1)
# %run ./configs/pipeline_config
from datetime import datetime
# Define parameters
dbutils.widgets.text("start_date", "2024-01-01")
dbutils.widgets.text("end_date", "2024-01-31")
start_date = dbutils.widgets.get("start_date")
end_date = dbutils.widgets.get("end_date")
# Ingest data
raw_df = spark.read.parquet(f"/mnt/raw/data_{start_date}_to_{end_date}")
raw_df.write.format("delta").mode("append").saveAsTable("bronze.raw_sales")
# Return success
dbutils.notebook.exit("Success")
# Task 2: Data Transformation (Notebook 2)
# Depends on Task 1
bronze_df = spark.table("bronze.raw_sales")
# Apply transformations
silver_df = (bronze_df
.filter("amount > 0")
.withColumn("processed_date", current_date())
.dropDuplicates(["order_id"]))
silver_df.write.format("delta").mode("append").saveAsTable("silver.clean_sales")
dbutils.notebook.exit("Success")
# Task 3: Aggregation (Notebook 3)
# Depends on Task 2
silver_df = spark.table("silver.clean_sales")
gold_df = (silver_df
.groupBy("product_category", "region")
.agg(
sum("amount").alias("total_sales"),
count("*").alias("transaction_count")
))
gold_df.write.format("delta").mode("overwrite").saveAsTable("gold.sales_summary")
dbutils.notebook.exit("Success")
11. Machine Learning on Databricks
11.1 MLflow Integration
import mlflow
import mlflow.sklearn
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pandas as pd
# Load data
df = spark.table("features.customer_churn").toPandas()
# Prepare features
X = df.drop(["customer_id", "churned"], axis=1)
y = df["churned"]
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# Set experiment
mlflow.set_experiment("/Users/your_email/churn_prediction")
# Start MLflow run
with mlflow.start_run(run_name="random_forest_v1"):
# Log parameters
n_estimators = 100
max_depth = 10
mlflow.log_param("n_estimators", n_estimators)
mlflow.log_param("max_depth", max_depth)
# Train model
model = RandomForestClassifier(
n_estimators=n_estimators,
max_depth=max_depth,
random_state=42
)
model.fit(X_train, y_train)
# Make predictions
y_pred = model.predict(X_test)
# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
# Log metrics
mlflow.log_metric("accuracy", accuracy)
mlflow.log_metric("precision", precision)
mlflow.log_metric("recall", recall)
mlflow.log_metric("f1_score", f1)
# Log model
mlflow.sklearn.log_model(model, "model")
# Log feature importance
feature_importance = pd.DataFrame({
'feature': X.columns,
'importance': model.feature_importances_
}).sort_values('importance', ascending=False)
mlflow.log_table(feature_importance, "feature_importance.json")
print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
11.2 Hyperparameter Tuning with Hyperopt
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
from hyperopt.pyll import scope
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import numpy as np
# Define search space
search_space = {
'n_estimators': scope.int(hp.quniform('n_estimators', 50, 500, 50)),
'max_depth': scope.int(hp.quniform('max_depth', 3, 20, 1)),
'min_samples_split': scope.int(hp.quniform('min_samples_split', 2, 20, 1)),
'min_samples_leaf': scope.int(hp.quniform('min_samples_leaf', 1, 10, 1)),
'max_features': hp.choice('max_features', ['sqrt', 'log2', None])
}
# Objective function
def objective(params):
with mlflow.start_run(nested=True):
# Log parameters
mlflow.log_params(params)
# Train model
model = RandomForestClassifier(
n_estimators=int(params['n_estimators']),
max_depth=int(params['max_depth']),
min_samples_split=int(params['min_samples_split']),
min_samples_leaf=int(params['min_samples_leaf']),
max_features=params['max_features'],
random_state=42
)
# Cross-validation
cv_scores = cross_val_score(model, X_train, y_train, cv=5, scoring='f1')
mean_score = np.mean(cv_scores)
# Log metrics
mlflow.log_metric("cv_f1_mean", mean_score)
mlflow.log_metric("cv_f1_std", np.std(cv_scores))
# Return negative score (hyperopt minimizes)
return {'loss': -mean_score, 'status': STATUS_OK}
# Run optimization
with mlflow.start_run(run_name="hyperopt_tuning"):
trials = Trials()
best_params = fmin(
fn=objective,
space=search_space,
algo=tpe.suggest,
max_evals=50,
trials=trials
)
print("Best parameters:", best_params)
mlflow.log_params(best_params)
11.3 Feature Engineering with PySpark
from pyspark.ml.feature import VectorAssembler, StandardScaler, StringIndexer, OneHotEncoder
from pyspark.ml import Pipeline
from pyspark.sql.functions import *
# Load data
df = spark.table("raw.customers")
# Create features
feature_df = (df
# Temporal features
.withColumn("days_since_registration", datediff(current_date(), col("registration_date")))
.withColumn("registration_month", month("registration_date"))
.withColumn("registration_year", year("registration_date"))
# Aggregated features
.join(
spark.table("transactions")
.groupBy("customer_id")
.agg(
sum("amount").alias("total_spent"),
avg("amount").alias("avg_transaction"),
count("*").alias("transaction_count"),
max("transaction_date").alias("last_transaction_date")
),
"customer_id",
"left"
)
.withColumn("days_since_last_transaction",
datediff(current_date(), col("last_transaction_date")))
# Ratio features
.withColumn("avg_monthly_spend",
col("total_spent") / greatest(col("days_since_registration") / 30, lit(1)))
)
# String indexing and one-hot encoding
categorical_cols = ["country", "product_preference", "customer_segment"]
indexed_cols = [f"{col}_index" for col in categorical_cols]
encoded_cols = [f"{col}_encoded" for col in categorical_cols]
# Create indexers
indexers = [
StringIndexer(inputCol=col, outputCol=f"{col}_index")
for col in categorical_cols
]
# Create encoders
encoders = [
OneHotEncoder(inputCol=f"{col}_index", outputCol=f"{col}_encoded")
for col in categorical_cols
]
# Numerical features
numerical_cols = [
"days_since_registration", "total_spent", "avg_transaction",
"transaction_count", "days_since_last_transaction", "avg_monthly_spend"
]
# Assemble features
assembler = VectorAssembler(
inputCols=numerical_cols + encoded_cols,
outputCol="features_unscaled"
)
# Scale features
scaler = StandardScaler(
inputCol="features_unscaled",
outputCol="features",
withStd=True,
withMean=True
)
# Create pipeline
pipeline = Pipeline(stages=indexers + encoders + [assembler, scaler])
# Fit and transform
model = pipeline.fit(feature_df)
transformed_df = model.transform(feature_df)
# Save feature table
(transformed_df
.select("customer_id", "features", "churned")
.write
.format("delta")
.mode("overwrite")
.saveAsTable("features.customer_features"))
11.4 Model Deployment and Serving
import mlflow
from mlflow.tracking import MlflowClient
# Register model
model_name = "churn_prediction_model"
model_uri = "runs:/<run_id>/model"
# Register in MLflow Model Registry
result = mlflow.register_model(model_uri, model_name)
# Transition to production
client = MlflowClient()
client.transition_model_version_stage(
name=model_name,
version=result.version,
stage="Production"
)
# Load model for batch prediction
model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
# Make predictions on a pandas copy of the feature table
new_data = spark.table("features.new_customers")
new_data_pdf = new_data.toPandas()
feature_pdf = new_data_pdf.drop(columns=["customer_id"], errors="ignore")
new_data_pdf["churn_probability"] = model.predict(feature_pdf)
# Convert back to a Spark DataFrame for saving
predictions_df = spark.createDataFrame(new_data_pdf)
# Save predictions
(predictions_df
.write
.format("delta")
.mode("overwrite")
.saveAsTable("predictions.customer_churn"))
# Real-time serving with Model Serving endpoint
# This is done through Databricks UI or API
# After creating endpoint, you can call it:
import requests
import json
endpoint_url = "https://<databricks-instance>/serving-endpoints/<endpoint-name>/invocations"
token = dbutils.secrets.get(scope="<scope>", key="<key>")
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
data = {
"dataframe_records": [
{
"days_since_registration": 365,
"total_spent": 5000,
"avg_transaction": 100,
"transaction_count": 50
}
]
}
response = requests.post(endpoint_url, headers=headers, json=data)
predictions = response.json()
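Creating the serving endpoint itself can also be scripted. A hedged sketch against the Model Serving REST API (field names follow the public serving-endpoints API; verify against the current documentation for your workspace):
# Sketch: create a Model Serving endpoint for the registered model
endpoint_config = {
    "name": "churn-prediction-endpoint",
    "config": {
        "served_models": [{
            "model_name": model_name,
            "model_version": str(result.version),
            "workload_size": "Small",
            "scale_to_zero_enabled": True
        }]
    }
}
create_url = "https://<databricks-instance>/api/2.0/serving-endpoints"
response = requests.post(create_url, headers=headers, json=endpoint_config)
print(response.json())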
12. Databricks Workflows & Jobs
12.1 Creating Jobs Programmatically
from databricks_cli.sdk import ApiClient, JobsService
from databricks_cli.jobs.api import JobsApi
# Initialize API client
api_client = ApiClient(
host="https://<databricks-instance>",
token=dbutils.secrets.get(scope="<scope>", key="token")
)
jobs_api = JobsApi(api_client)
# Define job configuration
job_config = {
"name": "Daily ETL Pipeline",
"tasks": [
{
"task_key": "ingest_data",
"notebook_task": {
"notebook_path": "/Pipelines/01_ingest",
"base_parameters": {
"start_date": "{{job.start_time.date}}",
"environment": "production"
}
},
"new_cluster": {
"spark_version": "13.3.x-scala2.12",
"node_type_id": "i3.xlarge",
"num_workers": 2
}
},
{
"task_key": "transform_data",
"depends_on": [{"task_key": "ingest_data"}],
"notebook_task": {
"notebook_path": "/Pipelines/02_transform"
},
"existing_cluster_id": "<cluster-id>"
},
{
"task_key": "aggregate_data",
"depends_on": [{"task_key": "transform_data"}],
"spark_python_task": {
"python_file": "dbfs:/scripts/aggregate.py",
"parameters": ["--output", "/mnt/gold/aggregates"]
},
"existing_cluster_id": "<cluster-id>"
},
{
"task_key": "data_quality_check",
"depends_on": [{"task_key": "aggregate_data"}],
"sql_task": {
"query": {
"query_id": "<query-id>"
},
"warehouse_id": "<warehouse-id>"
}
}
],
"schedule": {
"quartz_cron_expression": "0 0 2 * * ?", # Daily at 2 AM
"timezone_id": "America/Los_Angeles",
"pause_status": "UNPAUSED"
},
"email_notifications": {
"on_failure": ["data-team@company.com"],
"on_success": ["data-team@company.com"]
},
"timeout_seconds": 7200,
"max_concurrent_runs": 1
}
# Create job
job_id = jobs_api.create_job(job_config)["job_id"]
print(f"Created job with ID: {job_id}")
# Run job immediately
run_id = jobs_api.run_now(job_id)["run_id"]
print(f"Started run with ID: {run_id}")
12.2 Monitoring and Alerting
# Monitor job runs
from databricks_cli.runs.api import RunsApi
runs_api = RunsApi(api_client)
# Get run status
run = runs_api.get_run(run_id)
print(f"Run state: {run['state']['life_cycle_state']}")
print(f"Result: {run['state'].get('result_state', 'RUNNING')}")
# Custom monitoring function
def monitor_critical_job(job_id, check_interval=60):
import time
while True:
runs = jobs_api.list_runs(job_id, active_only=True)
if runs.get("runs"):
for run in runs["runs"]:
run_id = run["run_id"]
state = run["state"]["life_cycle_state"]
if state == "TERMINATED":
result = run["state"]["result_state"]
if result == "FAILED":
# Send alert
send_slack_alert(f"Job {job_id} run {run_id} failed!")
# Get error details
task_runs = run.get("tasks", [])
for task in task_runs:
if task["state"]["result_state"] == "FAILED":
error = task["state"].get("state_message", "Unknown error")
print(f"Task {task['task_key']} failed: {error}")
break
time.sleep(check_interval)
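# The monitor above calls send_slack_alert(), which this guide does not define.
# Minimal sketch using a Slack incoming-webhook URL stored as a secret
# (the scope/key names here are hypothetical):
def send_slack_alert(message):
    import requests
    webhook_url = dbutils.secrets.get(scope="alerts", key="slack-webhook")
    requests.post(webhook_url, json={"text": message})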
# Data quality monitoring
def check_data_quality(table_name, checks):
"""
checks = [
{"name": "row_count", "min": 1000, "max": 1000000},
{"name": "null_check", "column": "email", "max_null_pct": 0.01},
{"name": "duplicates", "columns": ["user_id"], "max_duplicates": 0}
]
"""
df = spark.table(table_name)
results = []
for check in checks:
if check["name"] == "row_count":
count = df.count()
passed = check["min"] <= count <= check["max"]
results.append({
"check": "row_count",
"passed": passed,
"value": count,
"expected": f"{check['min']} - {check['max']}"
})
elif check["name"] == "null_check":
null_count = df.filter(col(check["column"]).isNull()).count()
total_count = df.count()
null_pct = null_count / total_count if total_count > 0 else 0
passed = null_pct <= check["max_null_pct"]
results.append({
"check": f"null_check_{check['column']}",
"passed": passed,
"value": null_pct,
"expected": f"<= {check['max_null_pct']}"
})
elif check["name"] == "duplicates":
dup_count = df.groupBy(*check["columns"]).count().filter("count > 1").count()
passed = dup_count <= check["max_duplicates"]
results.append({
"check": "duplicates",
"passed": passed,
"value": dup_count,
"expected": f"<= {check['max_duplicates']}"
})
return results
# Usage
quality_checks = [
{"name": "row_count", "min": 1000, "max": 10000000},
{"name": "null_check", "column": "email", "max_null_pct": 0.01},
{"name": "duplicates", "columns": ["customer_id"], "max_duplicates": 0}
]
results = check_data_quality("silver.customers", quality_checks)
for result in results:
if not result["passed"]:
print(f"โ Quality check failed: {result['check']}")
print(f" Value: {result['value']}, Expected: {result['expected']}")
else:
print(f"โ
Quality check passed: {result['check']}")
13. Security & Governance
13.1 Unity Catalog
-- Create catalog
CREATE CATALOG IF NOT EXISTS production;
-- Create schema
CREATE SCHEMA IF NOT EXISTS production.sales;
-- Grant permissions
GRANT USE CATALOG ON CATALOG production TO `data-analysts`;
GRANT USE SCHEMA ON SCHEMA production.sales TO `data-analysts`;
GRANT SELECT ON TABLE production.sales.orders TO `data-analysts`;
GRANT MODIFY ON TABLE production.sales.orders TO `data-engineers`;
-- Create external location
CREATE EXTERNAL LOCATION sales_data
URL 's3://company-data/sales/'
WITH (STORAGE CREDENTIAL aws_credentials);
-- Create a table at the external location (governed by Unity Catalog)
CREATE TABLE production.sales.customers (
customer_id BIGINT,
customer_name STRING,
email STRING,
created_at TIMESTAMP
)
USING DELTA
LOCATION 's3://company-data/sales/customers/';
-- Column-level security (Dynamic Views)
CREATE VIEW production.sales.customers_masked AS
SELECT
customer_id,
customer_name,
CASE
WHEN is_member('pii-access') THEN email
ELSE 'REDACTED'
END AS email,
created_at
FROM production.sales.customers;
-- Row-level security
CREATE VIEW production.sales.regional_orders AS
SELECT *
FROM production.sales.orders
WHERE region = current_user_region(); -- Custom function
-- Audit queries
SELECT
event_time,
user_identity,
service_name,
action_name,
request_params
FROM system.access.audit
WHERE action_name = 'getTable'
AND event_date >= current_date() - 7;
13.2 Secrets Management
# Create secret scope (done via CLI or API)
# databricks secrets create-scope --scope my-scope
# Store secrets
# databricks secrets put --scope my-scope --key db-password
# Access secrets in notebooks
db_password = dbutils.secrets.get(scope="my-scope", key="db-password")
aws_access_key = dbutils.secrets.get(scope="my-scope", key="aws-access-key")
# Use in connection strings
jdbc_url = f"jdbc:postgresql://host:5432/db"
connection_properties = {
"user": dbutils.secrets.get(scope="my-scope", key="db-username"),
"password": dbutils.secrets.get(scope="my-scope", key="db-password"),
"driver": "org.postgresql.Driver"
}
df = spark.read.jdbc(url=jdbc_url, table="users", properties=connection_properties)
# Environment-specific configurations
env = dbutils.widgets.get("environment") # dev, staging, prod
config = {
"db_host": dbutils.secrets.get(scope=f"{env}-scope", key="db-host"),
"api_key": dbutils.secrets.get(scope=f"{env}-scope", key="api-key")
}
13.3 Data Lineage and Discovery
# Track lineage with Unity Catalog
from pyspark.sql.functions import current_timestamp
# Add lineage metadata
df_with_lineage = (source_df
.withColumn("source_system", lit("CRM"))
.withColumn("ingestion_timestamp", current_timestamp())
.withColumn("pipeline_version", lit("v2.1.0")))
# Write with lineage tracking
(df_with_lineage
.write
.format("delta")
.mode("append")
.option("mergeSchema", "true")
.saveAsTable("production.sales.customers"))
# Query lineage via SQL
spark.sql("""
SELECT
table_catalog,
table_schema,
table_name,
upstream_tables
FROM system.information_schema.table_lineage
WHERE table_name = 'customers'
""").show()
# Add table comments and tags
spark.sql("""
ALTER TABLE production.sales.customers
SET TBLPROPERTIES (
'description' = 'Customer master data from CRM system',
'data_owner' = 'sales-team@company.com',
'pii' = 'true',
'refresh_frequency' = 'daily'
)
""")
# Add column comments
spark.sql("""
ALTER TABLE production.sales.customers
ALTER COLUMN email
COMMENT 'Customer email address - PII field'
""")
13.4 Access Control Best Practices
# Implement attribute-based access control (ABAC)
from pyspark.sql.functions import current_user
def apply_row_level_security(df, user_groups):
"""Apply row-level security based on user attributes"""
if "admin" in user_groups:
# Admins see everything
return df
elif "regional_manager" in user_groups:
# Regional managers see their region
user_region = get_user_region() # Custom function
return df.filter(col("region") == user_region)
elif "team_lead" in user_groups:
# Team leads see their team
user_team = get_user_team() # Custom function
return df.filter(col("team_id") == user_team)
else:
# Default: see only own records
return df.filter(col("user_id") == current_user())
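# get_user_region() / get_user_team() above are placeholders for custom lookups.
# One possible sketch, assuming a hypothetical security.user_attributes table
# with columns user_name, region, team_id:
def get_user_region():
    row = (spark.table("security.user_attributes")
           .filter(col("user_name") == current_user())
           .select("region")
           .first())
    return row["region"] if row else None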
# Column masking function
def apply_column_masking(df, sensitive_columns, user_groups):
"""Mask sensitive columns based on user permissions"""
result_df = df
if "pii-access" not in user_groups:
for col_name in sensitive_columns:
if col_name == "email":
result_df = result_df.withColumn(
col_name,
regexp_replace(col(col_name), r"(?<=.{2}).*(?=@)", "***")
)
elif col_name == "ssn":
result_df = result_df.withColumn(
col_name,
regexp_replace(col(col_name), r"\d(?=\d{4})", "*")
)
elif col_name == "phone":
result_df = result_df.withColumn(
col_name,
regexp_replace(col(col_name), r"\d(?=\d{4}$)", "*")
)
return result_df
14. Performance Optimization
14.1 Spark Optimization Techniques
# 1. Broadcast Joins for small tables
from pyspark.sql.functions import broadcast
large_df = spark.table("fact_sales") # 100M rows
small_df = spark.table("dim_products") # 10K rows
# Without broadcast (shuffle join)
result = large_df.join(small_df, "product_id")
# With broadcast (faster for small tables < 10MB)
result_optimized = large_df.join(broadcast(small_df), "product_id")
# 2. Partitioning strategies
# Bad: Too many small files
df.write.partitionBy("year", "month", "day", "hour").parquet("/path") # Avoid!
# Good: Reasonable partition size (128MB - 1GB per partition)
df.coalesce(100).write.partitionBy("year", "month").parquet("/path")
# 3. Bucketing for repeated joins
(df.write
.bucketBy(100, "customer_id")
.sortBy("order_date")
.saveAsTable("bucketed_orders"))
# 4. Caching strategically
# Cache only when reusing DataFrames multiple times
df_filtered = df.filter("amount > 1000")
df_filtered.cache()
result1 = df_filtered.groupBy("category").count()
result2 = df_filtered.groupBy("region").sum("amount")
df_filtered.unpersist() # Free memory when done
# 5. Avoid UDFs when possible - use built-in functions
# Bad: UDF
@udf(StringType())
def categorize_slow(amount):
if amount < 100:
return "low"
elif amount < 1000:
return "medium"
else:
return "high"
df.withColumn("category", categorize_slow(col("amount")))
# Good: Built-in functions
df.withColumn("category",
when(col("amount") < 100, "low")
.when(col("amount") < 1000, "medium")
.otherwise("high")
)
# 6. Filter early
# Bad: Filter after expensive operations
result = df.join(other_df, "id").groupBy("category").count().filter("count > 100")
# Good: Filter before expensive operations
df_filtered = df.filter("relevant_condition")
result = df_filtered.join(other_df, "id").groupBy("category").count()
# 7. Use appropriate file formats
# Bad for analytics: CSV, JSON
df.write.csv("/path") # Slow reads, no predicate pushdown
# Good: Parquet or Delta
df.write.format("parquet").save("/path") # Fast, columnar
df.write.format("delta").save("/path") # Fast, ACID, time travel
# 8. Partition pruning
df = spark.read.parquet("/data/partitioned/")
# Ensure filters on partition columns
filtered = df.filter("year = 2024 AND month = 1") # Reads only relevant partitions
# 9. Column pruning
# Bad: Read all columns when only a few are needed
df = spark.read.parquet("/data")
# Good: Read only needed columns
df = spark.read.parquet("/data").select("id", "name") # Automatic column pruning
# 10. Repartition vs Coalesce
# Coalesce: Only reduce partitions (no shuffle)
df.coalesce(10).write.parquet("/path")
# Repartition: Can increase/decrease (full shuffle)
df.repartition(100, "customer_id").write.parquet("/path")
14.2 Delta Lake Optimization
from delta.tables import DeltaTable
# 1. OPTIMIZE with Z-Ordering
delta_table = DeltaTable.forPath(spark, "/mnt/delta/sales")
# Compact small files
delta_table.optimize().executeCompaction()
# Z-Order by frequently filtered columns
delta_table.optimize().executeZOrderBy("customer_id", "order_date")
# 2. Auto-optimize (set at table creation)
spark.sql("""
CREATE TABLE optimized_sales (
order_id BIGINT,
customer_id BIGINT,
amount DOUBLE,
order_date DATE
)
USING DELTA
TBLPROPERTIES (
'delta.autoOptimize.optimizeWrite' = 'true',
'delta.autoOptimize.autoCompact' = 'true'
)
""")
# 3. Data skipping with statistics
# Delta automatically collects min/max statistics
# Query with filters uses statistics to skip files
df = spark.read.format("delta").load("/mnt/delta/sales")
filtered = df.filter((col("order_date") >= "2024-01-01") & (col("amount") > 1000))
# 4. Bloom filter indexes for high-cardinality columns
spark.sql("""
CREATE BLOOMFILTER INDEX ON TABLE sales
FOR COLUMNS(customer_email)
""")
# 5. Liquid clustering (new feature)
spark.sql("""
CREATE TABLE clustered_sales (
order_id BIGINT,
customer_id BIGINT,
product_id BIGINT,
amount DOUBLE
)
USING DELTA
CLUSTER BY (customer_id, product_id)
""")
# 6. Optimize write performance
(df.write
.format("delta")
.option("optimizeWrite", "true") # Automatically optimize during write
.option("dataChange", "false") # Skip change data feed if not needed
.mode("append")
.save("/mnt/delta/sales"))
# 7. Vacuum old files
delta_table.vacuum(168) # Delete files older than 7 days
# Set retention period
spark.sql("""
ALTER TABLE sales
SET TBLPROPERTIES ('delta.deletedFileRetentionDuration' = 'interval 7 days')
""")
# 8. Partition pruning
# Create partitioned table
(df.write
.format("delta")
.partitionBy("year", "month")
.save("/mnt/delta/partitioned_sales"))
# Query with partition filters
result = spark.read.format("delta").load("/mnt/delta/partitioned_sales") \
.filter("year = 2024 AND month = 1")
14.3 Query Performance Tuning
# 1. Analyze query execution plan
df = spark.table("large_table").filter("amount > 1000").groupBy("category").count()
# View physical plan
df.explain()
# View detailed plan with statistics
df.explain(True)
# View formatted plan
df.explain("formatted")
# 2. Use SQL Adaptive Query Execution (AQE)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# 3. Configure shuffle partitions
# Default 200 often too high for small data, too low for large data
spark.conf.set("spark.sql.shuffle.partitions", "auto") # Let AQE decide
# Or set manually based on data size
spark.conf.set("spark.sql.shuffle.partitions", "1000")
# 4. Monitor and tune memory
spark.conf.set("spark.executor.memory", "16g")
spark.conf.set("spark.executor.memoryOverhead", "2g")
spark.conf.set("spark.memory.fraction", "0.8")
spark.conf.set("spark.memory.storageFraction", "0.3")
# 5. Enable dynamic partition pruning
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
# 6. Use columnar cache for repeated queries
spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "true")
spark.conf.set("spark.sql.inMemoryColumnarStorage.batchSize", "10000")
# 7. Collect statistics for cost-based optimization
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS")
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS FOR ALL COLUMNS")
# 8. Monitor with Spark UI metrics
# Check for:
# - Data skew (tasks with widely different durations)
# - Spill to disk (memory pressure)
# - Shuffle read/write sizes
# - GC time (should be < 10% of task time)
14.4 Cluster Configuration Best Practices
# Cluster sizing guidelines
# 1. Driver vs Executor memory
# Driver: 2-8 GB for most workloads
# Executors: Based on data volume
# Small (< 1 TB): 8-16 GB per executor
# Medium (1-10 TB): 16-32 GB per executor
# Large (> 10 TB): 32-64 GB per executor
# 2. Executor cores
# Recommendation: 4-5 cores per executor
# More cores = more tasks per executor
# Fewer cores = more executors = better parallelism
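# Worked example (assumption: i3.2xlarge nodes with 8 vCPUs and 61 GB RAM):
#   8 vCPUs / 4 cores per executor          = 2 executors per node
#   (61 GB - ~10% for OS and overhead) / 2  ≈ 27 GB of memory per executor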
# 3. Instance types
# General purpose: i3.xlarge, i3.2xlarge
# Memory-optimized: r5.xlarge, r5.2xlarge
# Compute-optimized: c5.xlarge, c5.2xlarge
# 4. Auto-scaling configuration
# Min workers: Based on minimum expected load
# Max workers: Based on budget and peak load
# Scale down after: 10-15 minutes of idle time
# 5. Spot instances
# Use for fault-tolerant workloads
# Mix spot and on-demand: 50-70% spot for cost savings
# Example cluster configuration via API
cluster_config = {
"cluster_name": "production-etl-cluster",
"spark_version": "13.3.x-scala2.12",
"node_type_id": "i3.xlarge",
"driver_node_type_id": "i3.2xlarge",
"autoscale": {
"min_workers": 2,
"max_workers": 20
},
"auto_termination_minutes": 120,
"spark_conf": {
"spark.sql.adaptive.enabled": "true",
"spark.sql.adaptive.coalescePartitions.enabled": "true",
"spark.databricks.delta.optimizeWrite.enabled": "true",
"spark.sql.shuffle.partitions": "auto"
},
"aws_attributes": {
"availability": "SPOT_WITH_FALLBACK",
"spot_bid_price_percent": 100,
"first_on_demand": 1,
"ebs_volume_type": "GENERAL_PURPOSE_SSD",
"ebs_volume_count": 1,
"ebs_volume_size": 100
}
}
15. Advanced Topics
15.1 Structured Streaming
# 1. Reading streaming data
from pyspark.sql.functions import col, window, count, avg
# Read from cloud storage with Auto Loader
streaming_df = (spark.readStream
.format("cloudFiles")
.option("cloudFiles.format", "json")
.option("cloudFiles.schemaLocation", "/mnt/schemas/events")
.option("cloudFiles.inferColumnTypes", "true")
.load("/mnt/landing/events/"))
# Read from Kafka
kafka_df = (spark.readStream
.format("kafka")
.option("kafka.bootstrap.servers", "kafka:9092")
.option("subscribe", "orders")
.option("startingOffsets", "latest")
.load())
# Parse Kafka messages
from pyspark.sql.functions import from_json
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
schema = StructType([
StructField("order_id", StringType()),
StructField("customer_id", StringType()),
StructField("amount", DoubleType()),
StructField("timestamp", StringType())
])
parsed_df = kafka_df.select(
from_json(col("value").cast("string"), schema).alias("data")
).select("data.*")
# 2. Stateful streaming operations
# Tumbling window aggregation
windowed_counts = (streaming_df
.withWatermark("timestamp", "10 minutes")
.groupBy(
window("timestamp", "5 minutes"),
"event_type"
)
.count())
# Sliding window
sliding_avg = (streaming_df
.withWatermark("timestamp", "10 minutes")
.groupBy(
window("timestamp", "10 minutes", "5 minutes"),
"user_id"
)
.agg(avg("amount").alias("avg_amount")))
# 3. Stream-stream joins
orders_stream = spark.readStream.table("streaming_orders")
shipments_stream = spark.readStream.table("streaming_shipments")
joined = (orders_stream
.withWatermark("order_time", "1 hour")
.join(
shipments_stream.withWatermark("ship_time", "2 hours"),
expr("""
order_id = shipment_order_id AND
ship_time >= order_time AND
ship_time <= order_time + interval 1 hour
""")
))
# 4. Stream-batch joins
stream_df = spark.readStream.table("streaming_events")
dimension_df = spark.table("dim_products") # Static table
enriched = stream_df.join(dimension_df, "product_id")
# 5. Writing streaming data
# Append mode (default)
(windowed_counts.writeStream
.format("delta")
.outputMode("append")
.option("checkpointLocation", "/mnt/checkpoints/window_counts")
.table("streaming_window_counts"))
# Update mode (for aggregations)
(sliding_avg.writeStream
.format("delta")
.outputMode("update")
.option("checkpointLocation", "/mnt/checkpoints/sliding_avg")
.table("streaming_sliding_avg"))
# Complete mode (entire result table)
(streaming_df.groupBy("category").count()
.writeStream
.format("delta")
.outputMode("complete")
.option("checkpointLocation", "/mnt/checkpoints/category_counts")
.table("streaming_category_counts"))
# 6. Trigger modes
# Process all available data then stop
query = (streaming_df.writeStream
.trigger(availableNow=True)
.format("delta")
.option("checkpointLocation", "/mnt/checkpoints")
.table("output_table"))
# Process every 5 minutes
query = (streaming_df.writeStream
.trigger(processingTime="5 minutes")
.format("delta")
.option("checkpointLocation", "/mnt/checkpoints")
.table("output_table"))
# Continuous processing (experimental, millisecond latency)
# Note: supports only map-like operations and a limited set of sources/sinks (e.g. Kafka), not file sources or Delta
query = (kafka_df.selectExpr("CAST(value AS STRING) AS value")
.writeStream
.trigger(continuous="1 second")
.format("kafka")
.option("kafka.bootstrap.servers", "kafka:9092")
.option("topic", "orders_processed")
.option("checkpointLocation", "/mnt/checkpoints/continuous")
.start())
# 7. Monitoring streaming queries
# Get active streams
for stream in spark.streams.active:
print(f"ID: {stream.id}")
print(f"Status: {stream.status}")
print(f"Recent progress: {stream.recentProgress}")
# Await termination
query.awaitTermination()
# Stop query
query.stop()
# 8. Handling late data and watermarking
streaming_with_watermark = (streaming_df
.withWatermark("event_time", "2 hours") # Allow 2 hours late data
.groupBy(
window("event_time", "1 hour"),
"product_id"
)
.agg(count("*").alias("event_count")))
15.2 Working with Semi-Structured Data
# 1. JSON data
from pyspark.sql.functions import get_json_object, from_json, to_json, schema_of_json, col, lit
# Sample JSON
json_df = spark.read.json("/path/to/json")
# Extract nested fields
df = json_df.select(
col("id"),
get_json_object(col("payload"), "$.user.name").alias("user_name"),
get_json_object(col("payload"), "$.user.age").alias("user_age")
)
# Schema inference
sample_json = '{"user": {"name": "Alice", "age": 30}, "items": [1, 2, 3]}'
schema = schema_of_json(lit(sample_json))
# Parse with schema
parsed = json_df.select(
from_json(col("json_string"), schema).alias("data")
).select("data.*")
# 2. XML data (using the spark-xml library, com.databricks:spark-xml)
xml_df = (spark.read
.format("xml")
.option("rowTag", "book")
.load("/path/to/books.xml"))
# 3. Arrays and Maps
from pyspark.sql.functions import explode, explode_outer, posexplode, map_keys, map_values
# Explode arrays
df_with_array = spark.createDataFrame([
(1, ["a", "b", "c"]),
(2, ["x", "y"])
], ["id", "items"])
exploded = df_with_array.select(
col("id"),
explode("items").alias("item")
)
# Explode with position
pos_exploded = df_with_array.select(
col("id"),
posexplode("items").alias("pos", "item")
)
# Maps
df_with_map = spark.createDataFrame([
(1, {"key1": "value1", "key2": "value2"}),
(2, {"keyA": "valueA"})
], ["id", "properties"])
df_with_map.select(
col("id"),
map_keys("properties").alias("keys"),
map_values("properties").alias("values"),
col("properties")["key1"].alias("key1_value")
).show()
# 4. Nested structures
from pyspark.sql.functions import struct
# Create nested structure
nested_df = df.select(
col("id"),
struct(
col("name"),
col("age"),
col("email")
).alias("user_info")
)
# Access nested fields
nested_df.select(
col("id"),
col("user_info.name"),
col("user_info.age")
).show()
# 5. Complex transformations
from pyspark.sql.functions import transform, filter as array_filter, aggregate
df_arrays = spark.createDataFrame([
(1, [1, 2, 3, 4, 5]),
(2, [10, 20, 30])
], ["id", "numbers"])
# Transform array elements
df_arrays.select(
col("id"),
transform("numbers", lambda x: x * 2).alias("doubled")
).show()
# Filter array elements
df_arrays.select(
col("id"),
array_filter("numbers", lambda x: x > 2).alias("filtered")
).show()
# Aggregate array
df_arrays.select(
col("id"),
aggregate("numbers", lit(0), lambda acc, x: acc + x).alias("sum")
).show()
15.3 Graph Processing with GraphFrames
from graphframes import GraphFrame
# Create vertices DataFrame
vertices = spark.createDataFrame([
("1", "Alice", 34),
("2", "Bob", 36),
("3", "Charlie", 30),
("4", "David", 29),
("5", "Eve", 32)
], ["id", "name", "age"])
# Create edges DataFrame
edges = spark.createDataFrame([
("1", "2", "friend"),
("2", "3", "follow"),
("3", "4", "friend"),
("4", "1", "follow"),
("5", "1", "friend"),
("5", "4", "friend")
], ["src", "dst", "relationship"])
# Create GraphFrame
g = GraphFrame(vertices, edges)
# Query the graph
print("Vertices:")
g.vertices.show()
print("Edges:")
g.edges.show()
# In-degree (followers)
in_degrees = g.inDegrees
in_degrees.show()
# Out-degree (following)
out_degrees = g.outDegrees
out_degrees.show()
# PageRank
results = g.pageRank(resetProbability=0.15, maxIter=10)
results.vertices.select("id", "name", "pagerank").show()
# Find motifs (patterns)
# Find chains of length 2
motifs = g.find("(a)-[e1]->(b); (b)-[e2]->(c)")
motifs.show()
# Connected components (requires a Spark checkpoint directory)
spark.sparkContext.setCheckpointDir("/tmp/graphframes-checkpoints")
cc = g.connectedComponents()
cc.select("id", "name", "component").show()
# Shortest paths
paths = g.shortestPaths(landmarks=["1", "4"])
paths.select("id", "name", "distances").show()
# Breadth-first search
bfs_result = g.bfs("name = 'Alice'", "name = 'David'", maxPathLength=3)
bfs_result.show()
15.4 Delta Live Tables (DLT)
# DLT Pipeline Definition (Python)
import dlt
from pyspark.sql.functions import col, current_timestamp, window, sum, count, avg
# Bronze table - raw data ingestion
@dlt.table(
comment="Raw sales data from source system",
table_properties={
"quality": "bronze",
"pipelines.autoOptimize.zOrderCols": "order_date"
}
)
def bronze_sales():
return (
spark.readStream
.format("cloudFiles")
.option("cloudFiles.format", "json")
.option("cloudFiles.schemaLocation", "/mnt/schema/sales")
.load("/mnt/landing/sales")
.select("*", current_timestamp().alias("ingestion_time"))
)
# Silver table - cleaned and validated
@dlt.table(
comment="Cleaned sales data with quality checks"
)
@dlt.expect_or_drop("valid_amount", "amount > 0")
@dlt.expect_or_drop("valid_date", "order_date IS NOT NULL")
@dlt.expect("valid_email", "email RLIKE '^[\\\\w\\\\.-]+@[\\\\w\\\\.-]+\\\\.\\\\w+", 0.95)
def silver_sales():
return (
dlt.read_stream("bronze_sales")
.filter("amount > 0")
.filter("order_date IS NOT NULL")
.select(
"order_id",
"customer_id",
"amount",
"order_date",
"email",
"product_id"
)
)
# Gold table - business aggregations
@dlt.table(
comment="Daily sales summary by product"
)
def gold_daily_sales():
return (
dlt.read("silver_sales")
.groupBy("order_date", "product_id")
.agg(
sum("amount").alias("total_sales"),
count("*").alias("order_count"),
avg("amount").alias("avg_order_value")
)
)
# View (computed within the pipeline, not persisted to the target schema)
@dlt.view(
comment="Recent high-value orders"
)
def high_value_recent_orders():
return (
dlt.read("silver_sales")
.filter("amount > 1000")
.filter("order_date >= current_date() - 30")
)
# Incremental processing
@dlt.table
def incremental_aggregates():
return (
dlt.read_stream("silver_sales")
.withWatermark("order_date", "1 day")
.groupBy(window("order_date", "1 hour"), "product_id")
.agg(sum("amount").alias("hourly_sales"))
)
# CDC processing: expose the change feed as a view, then APPLY CHANGES into a streaming table
@dlt.view
def customer_updates():
    return (
        dlt.read_stream("bronze_customers")
        .select("customer_id", "name", "email", "updated_at")
    )
# Target table that receives the merged changes
dlt.create_streaming_table("customer_dimension")
# Apply changes (SCD Type 1)
dlt.apply_changes(
    target="customer_dimension",
    source="customer_updates",
    keys=["customer_id"],
    sequence_by="updated_at",
    stored_as_scd_type=1
)
16. Real-World Projects
Project 1: E-Commerce Data Pipeline
"""
Complete e-commerce analytics pipeline
- Ingest from multiple sources
- Process orders, customers, products
- Create analytics-ready datasets
"""
from pyspark.sql.functions import *
from pyspark.sql.window import Window
# 1. Bronze Layer - Ingest raw data
class EcommerceBronzeLayer:
def __init__(self, spark):
self.spark = spark
def ingest_orders(self):
return (self.spark.readStream
.format("cloudFiles")
.option("cloudFiles.format", "json")
.option("cloudFiles.schemaLocation", "/mnt/bronze/schema/orders")
.load("/mnt/landing/orders/")
.withColumn("ingestion_timestamp", current_timestamp())
.writeStream
.format("delta")
.option("checkpointLocation", "/mnt/bronze/checkpoints/orders")
.trigger(availableNow=True)
.table("bronze.orders"))
def ingest_customers(self):
# JDBC source
return (self.spark.read
.format("jdbc")
.option("url", "jdbc:mysql://localhost:3306/crm")
.option("dbtable", "customers")
.option("user", dbutils.secrets.get("mysql", "username"))
.option("password", dbutils.secrets.get("mysql", "password"))
.load()
.write
.format("delta")
.mode("overwrite")
.saveAsTable("bronze.customers"))
def ingest_products(self):
# API source
import requests
import pandas as pd
api_key = dbutils.secrets.get("api", "key")
response = requests.get(
"https://api.example.com/products",
headers={"Authorization": f"Bearer {api_key}"}
)
products_pd = pd.DataFrame(response.json())
products_df = self.spark.createDataFrame(products_pd)
(products_df.write
.format("delta")
.mode("overwrite")
.saveAsTable("bronze.products"))
# 2. Silver Layer - Clean and validate
class EcommerceSilverLayer:
def __init__(self, spark):
self.spark = spark
def process_orders(self):
bronze_orders = self.spark.readStream.table("bronze.orders")
clean_orders = (bronze_orders
# Data quality
.filter(col("order_id").isNotNull())
.filter(col("amount") > 0)
.filter(col("order_date").isNotNull())
# Type conversions
.withColumn("amount", col("amount").cast("double"))
.withColumn("order_date", to_date("order_date"))
# Deduplication
.dropDuplicates(["order_id"])
# Enrichment
.withColumn("order_year", year("order_date"))
.withColumn("order_month", month("order_date"))
.withColumn("order_quarter", quarter("order_date"))
.withColumn("processed_timestamp", current_timestamp())
)
return (clean_orders.writeStream
.format("delta")
.outputMode("append")
.option("checkpointLocation", "/mnt/silver/checkpoints/orders")
.trigger(availableNow=True)
.table("silver.orders"))
def process_customers(self):
bronze_customers = self.spark.table("bronze.customers")
clean_customers = (bronze_customers
# Email validation
.filter(col("email").rlike(r"^[\w\.-]+@[\w\.-]+\.\w+$"))
# Standardization
.withColumn("email", lower(trim(col("email"))))
.withColumn("name", initcap(trim(col("name"))))
# Customer lifetime value calculation
.join(
self.spark.table("silver.orders")
.groupBy("customer_id")
.agg(
sum("amount").alias("lifetime_value"),
count("*").alias("total_orders"),
max("order_date").alias("last_order_date")
),
"customer_id",
"left"
)
.fillna({"lifetime_value": 0, "total_orders": 0})
# Customer segments
.withColumn("customer_segment",
when(col("lifetime_value") > 10000, "VIP")
.when(col("lifetime_value") > 5000, "Premium")
.when(col("lifetime_value") > 1000, "Regular")
.otherwise("New")
)
)
return (clean_customers.write
.format("delta")
.mode("overwrite")
.saveAsTable("silver.customers"))
# 3. Gold Layer - Business aggregations
class EcommerceGoldLayer:
def __init__(self, spark):
self.spark = spark
def create_sales_dashboard(self):
"""Daily sales metrics for dashboards"""
orders = self.spark.table("silver.orders")
customers = self.spark.table("silver.customers")
products = self.spark.table("bronze.products")
sales_metrics = (orders
.join(customers, "customer_id")
.join(products, "product_id")
.groupBy("order_date", "product_category", "customer_segment")
.agg(
sum("amount").alias("total_revenue"),
count("order_id").alias("total_orders"),
countDistinct("customer_id").alias("unique_customers"),
avg("amount").alias("avg_order_value")
)
.withColumn("revenue_per_customer",
col("total_revenue") / col("unique_customers"))
)
return (sales_metrics.write
.format("delta")
.mode("overwrite")
.partitionBy("order_date")
.saveAsTable("gold.daily_sales_dashboard"))
def create_customer_rfm(self):
"""RFM (Recency, Frequency, Monetary) analysis"""
orders = self.spark.table("silver.orders")
rfm = (orders
.groupBy("customer_id")
.agg(
datediff(current_date(), max("order_date")).alias("recency"),
count("*").alias("frequency"),
sum("amount").alias("monetary")
)
.withColumn("r_score",
ntile(5).over(Window.orderBy(col("recency").desc())))
.withColumn("f_score",
ntile(5).over(Window.orderBy(col("frequency"))))
.withColumn("m_score",
ntile(5).over(Window.orderBy(col("monetary"))))
.withColumn("rfm_score",
concat(col("r_score"), col("f_score"), col("m_score")))
.withColumn("customer_value_segment",
when(col("rfm_score").between("444", "555"), "Champions")
.when(col("rfm_score").between("344", "443"), "Loyal")
.when(col("rfm_score").between("244", "343"), "Potential")
.when(col("rfm_score").between("144", "243"), "At Risk")
.otherwise("Lost")
)
)
return (rfm.write
.format("delta")
.mode("overwrite")
.saveAsTable("gold.customer_rfm"))
def create_product_recommendations(self):
"""Product affinity analysis for recommendations"""
orders = self.spark.table("silver.orders")
# Market basket analysis
product_pairs = (orders
.alias("o1")
.join(orders.alias("o2"),
(col("o1.customer_id") == col("o2.customer_id")) &
(col("o1.product_id") < col("o2.product_id")))
.groupBy(col("o1.product_id").alias("product_a"),
col("o2.product_id").alias("product_b"))
.agg(count("*").alias("pair_count"))
.filter("pair_count > 10")
)
return (product_pairs.write
.format("delta")
.mode("overwrite")
.saveAsTable("gold.product_affinity"))
# Execute pipeline
bronze = EcommerceBronzeLayer(spark)
silver = EcommerceSilverLayer(spark)
gold = EcommerceGoldLayer(spark)
# Run all layers
bronze.ingest_orders()
bronze.ingest_customers()
bronze.ingest_products()
silver.process_orders()
silver.process_customers()
gold.create_sales_dashboard()
gold.create_customer_rfm()
gold.create_product_recommendations()
Project 2: Real-Time Fraud Detection
"""
Real-time fraud detection system
- Streaming transaction processing
- ML-based fraud scoring
- Real-time alerting
"""
from pyspark.sql.functions import *
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType
from pyspark.sql.window import Window
import mlflow
class FraudDetectionPipeline:
def __init__(self, spark):
self.spark = spark
self.fraud_model = self.load_model()
def load_model(self):
"""Load pre-trained fraud detection model"""
model_uri = "models:/fraud_detection/Production"
return mlflow.pyfunc.spark_udf(spark, model_uri)
def stream_transactions(self):
"""Read streaming transactions"""
return (self.spark.readStream
.format("kafka")
.option("kafka.bootstrap.servers", "localhost:9092")
.option("subscribe", "transactions")
.option("startingOffsets", "latest")
.load()
.select(
from_json(col("value").cast("string"), self.get_schema()).alias("data")
)
.select("data.*")
)
def get_schema(self):
return StructType([
StructField("transaction_id", StringType()),
StructField("user_id", StringType()),
StructField("amount", DoubleType()),
StructField("merchant_id", StringType()),
StructField("timestamp", TimestampType()),
StructField("location", StringType()),
StructField("device_id", StringType())
])
def engineer_features(self, df):
"""Create fraud detection features"""
# Time-based features
df = df.withColumn("hour_of_day", hour("timestamp"))
df = df.withColumn("day_of_week", dayofweek("timestamp"))
df = df.withColumn("is_weekend",
when(col("day_of_week").isin([1, 7]), 1).otherwise(0))
# User behavior features (windowed aggregations)
# Note: unbounded window functions like these are not supported on streaming DataFrames;
# in a live streaming job, compute them with stateful aggregations or a feature-store lookup
user_window_1h = Window.partitionBy("user_id") \
.orderBy(col("timestamp").cast("long")) \
.rangeBetween(-3600, 0)
df = df.withColumn("txn_count_1h",
count("*").over(user_window_1h))
df = df.withColumn("total_amount_1h",
sum("amount").over(user_window_1h))
df = df.withColumn("avg_amount_1h",
avg("amount").over(user_window_1h))
# Velocity features
user_window_24h = Window.partitionBy("user_id") \
.orderBy(col("timestamp").cast("long")) \
.rangeBetween(-86400, 0)
df = df.withColumn("txn_count_24h",
count("*").over(user_window_24h))
df = df.withColumn("unique_merchants_24h",
countDistinct("merchant_id").over(user_window_24h))
# Amount deviation
df = df.withColumn("amount_vs_avg_ratio",
col("amount") / (col("avg_amount_1h") + 1))
return df
def detect_fraud(self, df):
"""Apply ML model and rule-based detection"""
# Apply ML model
df = df.withColumn("ml_fraud_score",
self.fraud_model(*[col(c) for c in self.get_feature_cols()]))
# Rule-based detection
df = df.withColumn("rule_fraud_flag",
when(
(col("amount") > 10000) |
(col("txn_count_1h") > 10) |
(col("amount_vs_avg_ratio") > 5) |
(col("unique_merchants_24h") > 20),
1
).otherwise(0)
)
# Combined fraud score
df = df.withColumn("fraud_score",
(col("ml_fraud_score") * 0.7 + col("rule_fraud_flag") * 0.3))
# Fraud decision
df = df.withColumn("is_fraud",
when(col("fraud_score") > 0.8, "high_risk")
.when(col("fraud_score") > 0.5, "medium_risk")
.otherwise("low_risk"))
return df
def get_feature_cols(self):
return [
"amount", "hour_of_day", "day_of_week", "is_weekend",
"txn_count_1h", "total_amount_1h", "avg_amount_1h",
"txn_count_24h", "unique_merchants_24h", "amount_vs_avg_ratio"
]
def write_results(self, df):
"""Write results to multiple sinks"""
# High-risk transactions to alerts table
high_risk_query = (df
.filter("is_fraud = 'high_risk'")
.writeStream
.format("delta")
.outputMode("append")
.option("checkpointLocation", "/mnt/checkpoints/fraud_alerts")
.table("fraud_alerts")
)
# All scored transactions
all_query = (df
.writeStream
.format("delta")
.outputMode("append")
.option("checkpointLocation", "/mnt/checkpoints/scored_transactions")
.partitionBy("is_fraud")
.table("scored_transactions")
)
# Real-time metrics
metrics_query = (df
.withWatermark("timestamp", "5 minutes")
.groupBy(
window("timestamp", "1 minute"),
"is_fraud"
)
.agg(
count("*").alias("transaction_count"),
sum("amount").alias("total_amount"),
avg("fraud_score").alias("avg_fraud_score")
)
.writeStream
.format("delta")
.outputMode("update")
.option("checkpointLocation", "/mnt/checkpoints/fraud_metrics")
.table("fraud_detection_metrics")
)
return high_risk_query, all_query, metrics_query
def run(self):
"""Execute fraud detection pipeline"""
# Stream transactions
transactions = self.stream_transactions()
# Engineer features
featured = self.engineer_features(transactions)
# Detect fraud
scored = self.detect_fraud(featured)
# Write results
queries = self.write_results(scored)
return queries
# Run pipeline
fraud_pipeline = FraudDetectionPipeline(spark)
queries = fraud_pipeline.run()
# Monitor
for query in queries:
print(f"Query {query.id} is active: {query.isActive}")
17. Interview Preparation
Common Databricks Interview Questions
Conceptual Questions
Q1: What is the difference between RDD, DataFrame, and Dataset?
Answer:
- RDD (Resilient Distributed Dataset): Low-level API with no Catalyst optimization (tuning is manual), strongly typed in Scala/Java, uses Java/Kryo serialization
- DataFrame: High-level API, automatic optimization via Catalyst, schema-based, untyped (rows), uses Tungsten execution engine
- Dataset: Combines benefits of both, strongly typed, compile-time type safety (Scala/Java only), optimized execution
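A compact way to demonstrate the difference in an interview is the same word count written against both APIs; a minimal illustrative sketch:
# Word count with the RDD API - manual pipeline, no Catalyst optimization
rdd = spark.sparkContext.parallelize(["spark delta", "delta lake"])
rdd_counts = (rdd.flatMap(lambda line: line.split(" "))
                 .map(lambda word: (word, 1))
                 .reduceByKey(lambda a, b: a + b))
print(rdd_counts.collect())
# Same logic with the DataFrame API - optimized by Catalyst and executed by Tungsten
from pyspark.sql.functions import explode, split, col
df = spark.createDataFrame([("spark delta",), ("delta lake",)], ["line"])
df_counts = (df.select(explode(split(col("line"), " ")).alias("word"))
               .groupBy("word")
               .count())
df_counts.show()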
Q2: Explain the Catalyst Optimizer.
Answer:
The Catalyst Optimizer is Spark’s query optimization framework with four phases:
- Analysis: Resolve references, validate types
- Logical Optimization: Predicate pushdown, constant folding, projection pruning
- Physical Planning: Generate multiple physical plans, cost-based optimization
- Code Generation: Generate Java bytecode for execution
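You can inspect what Catalyst produces for any query with explain(); a short sketch, assuming an orders table like the one used in the earlier SQL examples:
# Show parsed, analyzed, and optimized logical plans plus the physical plan
df = spark.table("orders").filter("total_amount > 100").select("customer_id", "total_amount")
df.explain(mode="extended")
# "formatted" prints a more readable physical plan with operator details
df.explain(mode="formatted")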
Q3: What is Delta Lake and why use it?
Answer:
Delta Lake is an open-source storage layer providing:
- ACID transactions: Multiple writes without conflicts
- Time travel: Query historical versions
- Schema enforcement/evolution: Prevent bad data, evolve schema
- Unified batch and streaming: Single source of truth
- Audit history: Complete lineage tracking
- Scalable metadata: Handles petabyte-scale tables
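A short sketch showing a few of these features; the sales table from the optimization section and a hypothetical sales_updates table stand in for real data:
# Time travel: query an older version of a Delta table
spark.sql("SELECT * FROM sales VERSION AS OF 5").show()
spark.sql("SELECT * FROM sales TIMESTAMP AS OF '2024-01-01'").show()
# Audit history: full change log for the table
spark.sql("DESCRIBE HISTORY sales").show(truncate=False)
# Schema evolution on write (sales_updates is illustrative)
(spark.table("sales_updates")
    .write.format("delta")
    .mode("append")
    .option("mergeSchema", "true")
    .saveAsTable("sales"))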
Coding Questions
Q4: Remove duplicates keeping the latest record
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col
# Sample data with duplicates
data = [
(1, "Alice", "2024-01-01"),
(1, "Alice Updated", "2024-01-15"),
(2, "Bob", "2024-01-10"),
(2, "Bob Updated", "2024-01-20"),
(3, "Charlie", "2024-01-05")
]
df = spark.createDataFrame(data, ["id", "name", "date"])
# Solution
window = Window.partitionBy("id").orderBy(col("date").desc())
result = df.withColumn("rn", row_number().over(window)) \
.filter("rn = 1") \
.drop("rn")
result.show()
Q5: Find top N products by sales in each category
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col
# Sample data
sales_data = [
("Electronics", "Laptop", 50000),
("Electronics", "Phone", 80000),
("Electronics", "Tablet", 30000),
("Clothing", "Shirt", 15000),
("Clothing", "Pants", 25000),
("Clothing", "Shoes", 20000)
]
df = spark.createDataFrame(sales_data, ["category", "product", "sales"])
# Top 2 products per category
window = Window.partitionBy("category").orderBy(col("sales").desc())
top_n = df.withColumn("rank", row_number().over(window)) \
.filter("rank <= 2") \
.drop("rank")
top_n.show()
Q6: Calculate running totals
from pyspark.sql.window import Window
from pyspark.sql.functions import sum, col
data = [
("2024-01-01", 100),
("2024-01-02", 150),
("2024-01-03", 200),
("2024-01-04", 120)
]
df = spark.createDataFrame(data, ["date", "amount"])
# Running total
window = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)
result = df.withColumn("running_total", sum("amount").over(window))
result.show()
Q7: Implement SCD Type 2
from delta.tables import DeltaTable
from pyspark.sql.functions import current_date, lit, col
# Current dimension table
current_dim = spark.createDataFrame([
(1, "Alice", "alice@old.com", "2023-01-01", None, True),
(2, "Bob", "bob@example.com", "2023-01-01", None, True)
], ["customer_id", "name", "email", "start_date", "end_date", "is_current"])
current_dim.write.format("delta").mode("overwrite").saveAsTable("dim_customer")
# New updates
updates = spark.createDataFrame([
(1, "Alice", "alice@new.com"), # Email changed
(3, "Charlie", "charlie@example.com") # New customer
], ["customer_id", "name", "email"])
# SCD Type 2 merge
target = DeltaTable.forName(spark, "dim_customer")
# Close expired records
(target.alias("target").merge(
updates.alias("updates"),
"target.customer_id = updates.customer_id AND target.is_current = true"
)
.whenMatchedUpdate(
condition = "target.email != updates.email OR target.name != updates.name",
set = {
"is_current": "false",
"end_date": "current_date()"
}
).execute())
# Insert new records (both updates and new)
new_records = updates.withColumn("start_date", current_date()) \
.withColumn("end_date", lit(None).cast("date")) \
.withColumn("is_current", lit(True))
new_records.write.format("delta").mode("append").saveAsTable("dim_customer")
Q8: Optimize a slow query
# Slow query: SELECT * reads every column, and the aggregation then runs over full-width rows
slow_query = """
SELECT *
FROM large_table
WHERE year = 2024 AND amount > 1000
"""
df = spark.sql(slow_query)
result = df.groupBy("category").agg(sum("amount"))
# Optimized version
# 1. Add partition pruning
# 2. Filter early
# 3. Select only needed columns
# 4. Use broadcast for small dimension tables
optimized = spark.sql("""
WITH filtered AS (
SELECT category, amount
FROM large_table
WHERE year = 2024
AND amount > 1000
)
SELECT category, SUM(amount) as total
FROM filtered
GROUP BY category
""")
# Further optimization: use Delta Lake optimization
spark.sql("OPTIMIZE large_table ZORDER BY (year, category)")
# Enable adaptive query execution
spark.conf.set("spark.sql.adaptive.enabled", "true")
Performance Questions
Q9: How do you handle data skew?
Answer:
# 1. Identify skew
df.groupBy("partition_key").count().orderBy(col("count").desc()).show()
# 2. Solutions:
# A. Salting technique
from pyspark.sql.functions import rand, concat, lit, explode, array, broadcast
salted_df = df.withColumn("salt", (rand() * 10).cast("int"))
salted_df = salted_df.withColumn("salted_key",
concat(col("partition_key"), lit("_"), col("salt")))
# Join with replicated dimension
dim_replicated = dim_df.withColumn("salt", explode(array([lit(i) for i in range(10)])))
dim_replicated = dim_replicated.withColumn("salted_key",
concat(col("dim_key"), lit("_"), col("salt")))
result = salted_df.join(dim_replicated, "salted_key")
# B. Adaptive Query Execution
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# C. Broadcast small tables
result = large_df.join(broadcast(small_df), "key")
# D. Repartition by multiple keys
df.repartition(200, "key1", "key2")
Q10: Explain checkpoint and its importance in streaming
Answer:
# Checkpoint stores:
# 1. Metadata about streaming query
# 2. Committed offsets
# 3. State information for stateful operations
streaming_df = spark.readStream.table("source")
query = (streaming_df
.groupBy("user_id")
.count()
.writeStream
.format("delta")
.option("checkpointLocation", "/mnt/checkpoints/user_counts") # REQUIRED
.outputMode("complete")
.table("user_counts"))
# Importance:
# - Fault tolerance: Resume from last checkpoint after failure
# - Exactly-once processing: Track processed offsets
# - State management: Store stateful operation state
# - Idempotence: Prevent duplicate processing
# Best practices:
# 1. Never change checkpoint location
# 2. Use separate checkpoint per query
# 3. Store on reliable storage (S3, ADLS)
# 4. Don't delete manually unless starting fresh
Key Topics to Master
- Spark Architecture: Driver, executors, cluster manager, DAG, stages, tasks
- DataFrame API: Transformations, actions, UDFs, window functions
- Delta Lake: ACID, time travel, merge, optimize, vacuum
- Performance: Partitioning, caching, broadcast joins, AQE
- Streaming: Structured streaming, watermarking, checkpointing
- MLflow: Experiment tracking, model registry, deployment
- Security: Unity Catalog, secrets, access control
- Best Practices: Medallion architecture, data quality, monitoring
Summary & Next Steps
Learning Path Summary
Weeks 1-2: Foundations
- Databricks workspace and notebooks
- PySpark basics and DataFrames
- SQL fundamentals
Weeks 3-4: Data Engineering
- Delta Lake deep dive
- ETL pipeline development
- Medallion architecture
Weeks 5-6: Advanced Topics
- Structured streaming
- Performance optimization
- Unity Catalog
Weeks 7-8: Machine Learning
- MLflow integration
- Model training and deployment
- Feature engineering
Weeks 9-10: Production
- Workflows and jobs
- Monitoring and alerting
- Security best practices
Weeks 11-12: Projects
- Build real-world projects
- Practice interview questions
- Contribute to open source
Recommended Resources
Official Documentation
- https://docs.databricks.com
- https://spark.apache.org/docs/latest/
- https://delta.io/
Certifications
- Databricks Certified Data Engineer Associate
- Databricks Certified Data Engineer Professional
- Databricks Certified Machine Learning Professional
Practice Platforms
- Databricks Community Edition (free)
- Databricks Academy
- GitHub: Real-world pipeline examples
Books
- “Learning Spark” by Damji et al.
- “Spark: The Definitive Guide” by Chambers & Zaharia
- “Delta Lake: The Definitive Guide”
Final Tips
- Practice Daily: Code every day, even small examples
- Build Projects: Theory alone isn’t enough
- Optimize Everything: Always think about performance
- Document Well: Clear documentation helps teams
- Stay Updated: Databricks ships new features continuously; follow the release notes
- Join Community: Slack, forums, meetups
- Contribute: Open source contributions build expertise
- Test Thoroughly: Data quality is paramount
Good luck with your Databricks journey!