
Picture this: your daily sales pipeline has been running smoothly for months, generating reports that drive million-dollar business decisions. Then one Tuesday morning, you discover that a schema change in your source system three days ago has been silently corrupting your aggregations. The executive dashboard shows revenue up 40% (celebrated in the Monday meeting), but the reality is a data type conversion error has been treating null values as zeros.
This scenario plays out more often than we'd like to admit. Production data pipelines fail in subtle ways, often long before anyone notices. The difference between pipelines that fail fast and fail obviously versus those that fail silently and corrupt downstream systems comes down to comprehensive testing strategies.
By the end of this lesson, you'll implement a three-layered testing approach that catches errors before they reach production and maintains data quality contracts with downstream consumers.
Prerequisites:
You should be comfortable writing data pipelines in Python, familiar with pytest or similar testing frameworks, and understand SQL basics. We'll use Apache Spark with PySpark for examples, but the testing principles apply to any pipeline technology.
Data pipeline testing requires a fundamentally different approach than application testing. While application tests typically verify business logic with predictable inputs, pipeline tests must handle variable data volumes, evolving schemas, and distributed processing concerns.
Our three-layer approach addresses these challenges: unit tests for individual transformations, integration tests for end-to-end pipeline behavior, and data contracts for schema and quality guarantees.
Each layer serves a specific purpose and catches different types of failures. Let's build this framework step by step.
Unit tests form your first line of defense, validating individual transformations before they interact with real data sources or complex infrastructure.
Start with your simplest components—pure functions that transform DataFrames without external dependencies:
# pipeline/transformations.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, when, regexp_replace
from pyspark.sql.types import DecimalType
def clean_revenue_data(df: DataFrame) -> DataFrame:
    """Clean and standardize revenue data from multiple sources.

    - Strips currency formatting ("$", ",") from ``amount`` and casts it
      to Decimal(10, 2).
    - Collapses source-specific category labels into canonical buckets
      ("Software", "Services"); unknown categories pass through unchanged.
    - Flags transactions above 100,000 for manual review.

    Args:
        df: Raw transactions with string-typed ``amount`` and ``category``.

    Returns:
        DataFrame with columns customer_id, transaction_date, amount,
        product_category, requires_review.
    """
    # Build the cleaned amount expression once and reuse it for both the
    # output column and the review-flag comparison.
    cleaned_amount = regexp_replace(col("amount"), r"[$,]", "").cast(DecimalType(10, 2))
    return df.select(
        col("customer_id"),
        col("transaction_date"),
        # Handle various currency formats
        cleaned_amount.alias("amount"),
        # Standardize product categories
        when(col("category").isin(["SaaS", "Software", "Platform"]), "Software")
        .when(col("category").isin(["Consulting", "Services", "Professional"]), "Services")
        .otherwise(col("category")).alias("product_category"),
        # Flag suspicious transactions. Compare against the *cleaned* amount:
        # the raw string column still contains "$"/"," formatting, which makes
        # Spark's implicit numeric cast yield NULL, so large formatted amounts
        # (e.g. "$150,000") would silently go unflagged.
        when(cleaned_amount > 100000, True).otherwise(False).alias("requires_review")
    )
def calculate_monthly_metrics(df: DataFrame) -> DataFrame:
    """Calculate monthly revenue metrics by product category.

    Args:
        df: Cleaned transactions (output of clean_revenue_data).

    Returns:
        One row per (product_category, month) with total_revenue,
        transaction_count, avg_transaction_size, flagged_transactions.
    """
    # Imported locally: the module header only brings in col/when/
    # regexp_replace, so date_format was an outright NameError and bare
    # sum/count/avg would have resolved to the Python builtins instead of
    # Spark aggregate functions.
    from pyspark.sql.functions import avg, count, date_format, sum as spark_sum

    month_col = date_format(col("transaction_date"), "yyyy-MM").alias("month")
    return df.groupBy("product_category", month_col) \
        .agg(
            spark_sum("amount").alias("total_revenue"),
            count("*").alias("transaction_count"),
            avg("amount").alias("avg_transaction_size"),
            # Count of rows flagged for review within the group.
            spark_sum(when(col("requires_review"), 1).otherwise(0)).alias("flagged_transactions")
        )
Now test these transformations with carefully constructed test cases:
# tests/test_transformations.py
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DateType, DecimalType
from decimal import Decimal
from datetime import date
from pipeline.transformations import clean_revenue_data, calculate_monthly_metrics
@pytest.fixture(scope="session")
def spark():
    """Session-scoped local Spark session shared by all unit tests."""
    # Two shuffle partitions keep aggregations over tiny test DataFrames fast.
    builder = SparkSession.builder.appName("pipeline-tests")
    builder = builder.config("spark.sql.shuffle.partitions", "2")
    return builder.getOrCreate()
class TestRevenueDataCleaning:
    """Unit tests for clean_revenue_data: currency formats, categories, flags."""

    def test_currency_format_standardization(self, spark):
        """Test that various currency formats are properly cleaned."""
        # Arrange: amounts arrive as strings with mixed "$"/"," formatting.
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
            StructField("category", StringType())
        ])
        input_data = [
            ("C001", date(2024, 1, 15), "$1,234.56", "SaaS"),
            ("C002", date(2024, 1, 16), "2345.67", "Consulting"),
            ("C003", date(2024, 1, 17), "$15,000", "Platform"),
        ]
        input_df = spark.createDataFrame(input_data, schema)
        # Act
        result = clean_revenue_data(input_df)
        result_data = result.collect()
        # Assert: formatting stripped, values cast to Decimal(10,2).
        # NOTE(review): these positional asserts assume collect() preserves
        # input row order — that holds for this tiny local DataFrame but is
        # not a general Spark guarantee; keying rows by customer_id would be
        # sturdier.
        assert result_data[0]["amount"] == Decimal('1234.56')
        assert result_data[1]["amount"] == Decimal('2345.67')
        assert result_data[2]["amount"] == Decimal('15000.00')

    def test_category_standardization(self, spark):
        """Test that product categories are properly standardized."""
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
            StructField("category", StringType())
        ])
        # Covers both mapped buckets plus one pass-through category.
        input_data = [
            ("C001", date(2024, 1, 15), "100.00", "SaaS"),
            ("C002", date(2024, 1, 16), "200.00", "Software"),
            ("C003", date(2024, 1, 17), "300.00", "Consulting"),
            ("C004", date(2024, 1, 18), "400.00", "Professional"),
            ("C005", date(2024, 1, 19), "500.00", "Hardware"),
        ]
        input_df = spark.createDataFrame(input_data, schema)
        result = clean_revenue_data(input_df)
        categories = [row["product_category"] for row in result.collect()]
        assert categories.count("Software") == 2  # SaaS and Software both map to Software
        assert categories.count("Services") == 2  # Consulting and Professional both map to Services
        assert categories.count("Hardware") == 1  # Hardware stays as-is

    def test_large_transaction_flagging(self, spark):
        """Test that large transactions are properly flagged for review."""
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
            StructField("category", StringType())
        ])
        # One amount below and one above the 100,000 review threshold.
        input_data = [
            ("C001", date(2024, 1, 15), "50000.00", "Software"),
            ("C002", date(2024, 1, 16), "150000.00", "Software"),
        ]
        input_df = spark.createDataFrame(input_data, schema)
        result = clean_revenue_data(input_df)
        result_data = result.collect()
        assert result_data[0]["requires_review"] == False
        assert result_data[1]["requires_review"] == True
class TestMonthlyMetrics:
    """Unit tests for calculate_monthly_metrics aggregation logic."""

    def test_monthly_aggregation_accuracy(self, spark):
        """Test that monthly metrics are calculated correctly."""
        # BooleanType is not in this module's import list; bring it in locally.
        from pyspark.sql.types import BooleanType

        # Create test data spanning multiple months and categories.
        # requires_review must be BooleanType: the rows below carry Python
        # bools, and declaring the field as StringType makes createDataFrame
        # raise a type error before the aggregation is even exercised.
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", DecimalType(10, 2)),
            StructField("product_category", StringType()),
            StructField("requires_review", BooleanType())
        ])
        input_data = [
            ("C001", date(2024, 1, 15), Decimal('1000.00'), "Software", False),
            ("C002", date(2024, 1, 16), Decimal('2000.00'), "Software", False),
            ("C003", date(2024, 1, 17), Decimal('150000.00'), "Software", True),
            ("C004", date(2024, 2, 1), Decimal('3000.00'), "Services", False),
            ("C005", date(2024, 2, 2), Decimal('4000.00'), "Services", False),
        ]
        input_df = spark.createDataFrame(input_data, schema)
        result = calculate_monthly_metrics(input_df)
        # Convert to dictionary for order-independent assertions.
        result_dict = {
            (row["product_category"], row["month"]): row
            for row in result.collect()
        }
        # January Software: 1000 + 2000 + 150000, one flagged row.
        jan_software = result_dict[("Software", "2024-01")]
        assert jan_software["total_revenue"] == Decimal('153000.00')
        assert jan_software["transaction_count"] == 3
        assert jan_software["avg_transaction_size"] == Decimal('51000.00')
        assert jan_software["flagged_transactions"] == 1
        # February Services: 3000 + 4000, nothing flagged.
        feb_services = result_dict[("Services", "2024-02")]
        assert feb_services["total_revenue"] == Decimal('7000.00')
        assert feb_services["transaction_count"] == 2
        assert feb_services["avg_transaction_size"] == Decimal('3500.00')
        assert feb_services["flagged_transactions"] == 0
Real pipeline components often interact with databases, APIs, or file systems. Use dependency injection and mocking to test these components in isolation:
# pipeline/loaders.py
from abc import ABC, abstractmethod
from pyspark.sql import DataFrame
import requests
class DataSource(ABC):
    """Abstract interface for loading source data into Spark DataFrames."""

    @abstractmethod
    def load_customer_data(self, start_date: str, end_date: str) -> DataFrame:
        """Return customer data for the given date window.

        Date string format and range inclusivity are not fixed by this
        interface — TODO(review): document a convention (e.g. ISO dates,
        inclusive bounds) before adding more implementations.
        """
        pass
class DatabaseSource(DataSource):
    """DataSource backed by a relational database connection."""

    def __init__(self, connection_string: str):
        # Stored for later use; no connection is opened at construction time.
        self.connection_string = connection_string

    def load_customer_data(self, start_date: str, end_date: str) -> DataFrame:
        """Placeholder — intentionally unimplemented in this example."""
        # In real implementation, this would connect to database
        pass
class EnrichmentService:
    """Thin HTTP client for the customer-segmentation API."""

    def __init__(self, api_key: str, base_url: str, timeout_seconds: float = 10.0):
        """
        Args:
            api_key: Bearer token for the segmentation API.
            base_url: API root, without a trailing slash.
            timeout_seconds: Per-request timeout. Backward-compatible addition:
                without it, requests.get can block indefinitely and hang the
                Spark task that invoked the enrichment UDF.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.timeout_seconds = timeout_seconds

    def get_customer_segment(self, customer_id: str) -> str:
        """Fetch the segment label for one customer.

        Raises requests exceptions (including timeouts) and KeyError if the
        response body lacks a "segment" field — callers are expected to
        handle failures (see enrich_customer_data).
        """
        response = requests.get(
            f"{self.base_url}/customers/{customer_id}/segment",
            headers={"Authorization": f"Bearer {self.api_key}"},
            timeout=self.timeout_seconds
        )
        return response.json()["segment"]
def enrich_customer_data(df: DataFrame, enrichment_service: EnrichmentService) -> DataFrame:
    """Add customer segment information to transaction data.

    Args:
        df: Transactions containing a customer_id column.
        enrichment_service: Client used to look up each customer's segment.

    Returns:
        df with an added customer_segment column ("unknown" on lookup failure).
    """
    # col is imported here alongside udf: this module's header does not import
    # it, so the original code raised NameError at runtime.
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType

    @udf(returnType=StringType())
    def get_segment(customer_id):
        # Any enrichment failure degrades to "unknown" rather than failing
        # the whole job — enrichment is best-effort.
        try:
            return enrichment_service.get_customer_segment(customer_id)
        except Exception:
            return "unknown"

    return df.withColumn("customer_segment", get_segment(col("customer_id")))
Test this with mock dependencies:
# tests/test_loaders.py
import pytest
from unittest.mock import Mock, patch
from pyspark.sql.types import StructType, StructField, StringType, DateType
from datetime import date
from pipeline.loaders import enrich_customer_data, EnrichmentService
class TestCustomerEnrichment:
    """Unit tests for enrich_customer_data using a mocked EnrichmentService."""

    def test_successful_enrichment(self, spark):
        """Test that customer segments are properly added when API calls succeed."""
        # Arrange
        # spec=EnrichmentService makes the mock reject calls to methods the
        # real service doesn't have, keeping it honest against API changes.
        mock_service = Mock(spec=EnrichmentService)
        mock_service.get_customer_segment.side_effect = lambda cid: {
            "C001": "enterprise",
            "C002": "mid-market",
            "C003": "smb"
        }.get(cid, "unknown")
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
        ])
        input_data = [
            ("C001", date(2024, 1, 15), "1000.00"),
            ("C002", date(2024, 1, 16), "2000.00"),
            ("C003", date(2024, 1, 17), "3000.00"),
        ]
        input_df = spark.createDataFrame(input_data, schema)
        # Act
        result = enrich_customer_data(input_df, mock_service)
        result_data = {row["customer_id"]: row["customer_segment"]
                       for row in result.collect()}
        # Assert: each customer got the segment the (mocked) API returned.
        assert result_data["C001"] == "enterprise"
        assert result_data["C002"] == "mid-market"
        assert result_data["C003"] == "smb"
        # NOTE(review): call_count is only observable because local-mode Spark
        # runs the UDF in this process — revisit if tests ever run against a
        # cluster-backed session.
        assert mock_service.get_customer_segment.call_count == 3

    def test_api_failure_handling(self, spark):
        """Test that API failures are handled gracefully."""
        # Arrange: every lookup raises, simulating a dead/slow API.
        mock_service = Mock(spec=EnrichmentService)
        mock_service.get_customer_segment.side_effect = Exception("API timeout")
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
        ])
        input_data = [("C001", date(2024, 1, 15), "1000.00")]
        input_df = spark.createDataFrame(input_data, schema)
        # Act
        result = enrich_customer_data(input_df, mock_service)
        result_data = result.collect()
        # Assert: failures degrade to the "unknown" sentinel, not an exception.
        assert result_data[0]["customer_segment"] == "unknown"
Testing Tip: When testing Spark UDFs that make external calls, be aware that Spark may cache UDF results. Use unique test data or clear the UDF cache between tests if you encounter unexpected behavior.
While unit tests verify individual components, integration tests validate that your entire pipeline works correctly with realistic data volumes and infrastructure constraints.
Integration tests require a test environment that mirrors production without the cost and complexity. Use containerized services and scaled-down data:
# tests/integration/conftest.py
import pytest
import docker
import time
from pyspark.sql import SparkSession
import tempfile
import shutil
@pytest.fixture(scope="session")
def test_postgres():
    """Start a throwaway PostgreSQL container and yield its connection string.

    The container is always stopped at teardown (remove=True deletes it).
    """
    client = docker.from_env()
    container = client.containers.run(
        "postgres:13",
        environment={
            "POSTGRES_DB": "testdb",
            "POSTGRES_USER": "testuser",
            "POSTGRES_PASSWORD": "testpass"
        },
        ports={"5432/tcp": None},  # let Docker pick a free host port
        detach=True,
        remove=True
    )
    try:
        # Wait for PostgreSQL to be ready.
        # TODO(review): replace the fixed sleep with a connect/pg_isready poll
        # for faster and sturdier startup.
        time.sleep(10)
        # container.ports reflects attrs cached at create time, before Docker
        # assigned the random host port — refresh them or the mapping is empty.
        container.reload()
        port = container.ports["5432/tcp"][0]["HostPort"]
        connection_string = f"postgresql://testuser:testpass@localhost:{port}/testdb"
        yield connection_string
    finally:
        # Stop the container even if setup/tests fail after it was started.
        container.stop()
@pytest.fixture(scope="session")
def integration_spark():
    """Spark session configured for integration tests."""
    # More partitions than the unit-test session plus AQE and Kryo, to mimic
    # production-like execution while staying cheap.
    settings = {
        "spark.sql.shuffle.partitions": "4",
        "spark.sql.adaptive.enabled": "true",
        "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    }
    builder = SparkSession.builder.appName("integration-tests")
    for option, value in settings.items():
        builder = builder.config(option, value)
    return builder.getOrCreate()
@pytest.fixture
def temp_data_dir():
    """Yield a temporary directory path, removed automatically at teardown."""
    # tempfile.TemporaryDirectory bundles the mkdtemp + rmtree pair into one
    # context manager, so cleanup cannot be forgotten or skipped.
    with tempfile.TemporaryDirectory() as temp_dir:
        yield temp_dir
Build integration tests that exercise your entire pipeline with realistic data scenarios:
# tests/integration/test_revenue_pipeline.py
import pytest
import pandas as pd
from datetime import date, timedelta
import random
from decimal import Decimal
from pipeline.main import RevenuePipeline
class TestRevenuePipelineIntegration:
    """End-to-end tests running RevenuePipeline against generated parquet data."""

    @pytest.fixture
    def sample_data(self, integration_spark, temp_data_dir):
        """Generate realistic test data that covers edge cases.

        Returns:
            (input_path, total_row_count) for the generated parquet file.
        """
        # Seed the RNG so every run generates the same dataset. Unseeded
        # random data made the downstream count/quality assertions flaky and
        # failures impossible to reproduce.
        random.seed(42)
        # Generate 30 days of transaction data
        start_date = date(2024, 1, 1)
        customers = [f"C{i:03d}" for i in range(1, 51)]  # 50 customers
        categories = ["SaaS", "Software", "Consulting", "Professional", "Hardware"]
        transactions = []
        for day in range(30):
            current_date = start_date + timedelta(days=day)
            # Generate 20-100 transactions per day
            daily_transactions = random.randint(20, 100)
            for _ in range(daily_transactions):
                customer = random.choice(customers)
                category = random.choice(categories)
                # Create realistic amount distribution per category
                if category in ["SaaS", "Software"]:
                    amount = random.uniform(500, 5000)
                elif category in ["Consulting", "Professional"]:
                    amount = random.uniform(2000, 25000)
                else:  # Hardware
                    amount = random.uniform(100, 50000)
                # Occasionally generate large transactions
                if random.random() < 0.05:  # 5% chance
                    amount *= random.uniform(10, 50)
                # Add some data quality issues: a small slice of rows keeps
                # currency formatting that the pipeline must strip.
                if random.random() < 0.02:  # 2% formatted amounts
                    amount_str = f"${amount:,.2f}"
                else:
                    amount_str = f"{amount:.2f}"
                transactions.append({
                    "customer_id": customer,
                    "transaction_date": current_date,
                    "amount": amount_str,
                    "category": category
                })
        # Save to parquet for realistic file-based input
        df = pd.DataFrame(transactions)
        input_path = f"{temp_data_dir}/raw_transactions.parquet"
        df.to_parquet(input_path, index=False)
        return input_path, len(transactions)

    def test_full_pipeline_execution(self, integration_spark, test_postgres, sample_data, temp_data_dir):
        """Test complete pipeline execution with realistic data volume."""
        input_path, expected_count = sample_data
        output_path = f"{temp_data_dir}/processed"
        # Initialize pipeline
        pipeline = RevenuePipeline(
            spark=integration_spark,
            input_path=input_path,
            output_path=output_path,
            database_connection=test_postgres
        )
        # Execute pipeline
        pipeline_result = pipeline.run()
        # Validate pipeline completed successfully
        assert pipeline_result.success is True
        assert pipeline_result.records_processed > 0
        assert pipeline_result.records_processed <= expected_count  # Some may be filtered
        # Validate output data structure
        output_df = integration_spark.read.parquet(output_path)
        # Check schema: output must contain at least the contracted columns.
        expected_columns = ["customer_id", "transaction_date", "amount",
                            "product_category", "requires_review", "month",
                            "total_revenue", "transaction_count"]
        assert set(output_df.columns) >= set(expected_columns)
        # Check data quality
        assert output_df.count() > 0
        assert output_df.filter("amount IS NULL").count() == 0
        assert output_df.filter("product_category IS NULL").count() == 0
        # Validate business logic: every amount over the threshold must be
        # flagged (flags may also come from other rules, hence >=).
        large_transactions = output_df.filter("amount > 100000").count()
        flagged_transactions = output_df.filter("requires_review = true").count()
        assert flagged_transactions >= large_transactions

    def test_pipeline_performance_characteristics(self, integration_spark, sample_data, temp_data_dir):
        """Test that pipeline meets performance requirements."""
        input_path, expected_count = sample_data
        output_path = f"{temp_data_dir}/processed"
        pipeline = RevenuePipeline(
            spark=integration_spark,
            input_path=input_path,
            output_path=output_path
        )
        import time
        start_time = time.time()
        result = pipeline.run()
        execution_time = time.time() - start_time
        # Performance assertions: generous bounds intended to catch
        # order-of-magnitude regressions, not to benchmark precisely.
        assert result.success is True
        assert execution_time < 60  # Should complete within 1 minute for test data
        # Check resource utilization
        assert result.memory_usage_mb < 2048  # Should use less than 2GB
        assert result.cpu_time_seconds < 120  # Should use less than 2 minutes CPU time

    def test_pipeline_error_handling(self, integration_spark, temp_data_dir):
        """Test pipeline behavior when encountering data quality issues."""
        # Create data with various quality issues: null key fields, an
        # unparseable amount, and a missing category.
        bad_data = [
            {"customer_id": None, "transaction_date": "2024-01-01", "amount": "1000.00", "category": "Software"},
            {"customer_id": "C001", "transaction_date": None, "amount": "2000.00", "category": "SaaS"},
            {"customer_id": "C002", "transaction_date": "2024-01-03", "amount": "invalid", "category": "Hardware"},
            {"customer_id": "C003", "transaction_date": "2024-01-04", "amount": "3000.00", "category": None},
        ]
        df = pd.DataFrame(bad_data)
        bad_input_path = f"{temp_data_dir}/bad_data.parquet"
        df.to_parquet(bad_input_path, index=False)
        pipeline = RevenuePipeline(
            spark=integration_spark,
            input_path=bad_input_path,
            output_path=f"{temp_data_dir}/output",
            error_handling="strict"
        )
        result = pipeline.run()
        # Pipeline should complete but report data quality issues
        assert result.success is True
        assert result.data_quality_warnings > 0
        assert result.records_processed < 4  # Some records should be filtered
        # Error records should be captured in a side output for triage.
        error_df = integration_spark.read.parquet(f"{temp_data_dir}/output/errors")
        assert error_df.count() > 0
Integration tests should validate behavior at realistic scales without overwhelming your test environment:
def test_large_dataset_processing(self, integration_spark, temp_data_dir):
    """Test pipeline behavior with larger dataset that exercises partitioning."""
    # Seed so the generated dataset — and therefore the assertions below —
    # are reproducible across runs. (Also dropped an unused numpy import.)
    random.seed(42)
    # Generate 1M records across multiple parquet files so the glob input
    # path actually exercises multi-file, multi-partition reads.
    records_per_partition = 100000
    num_partitions = 10
    for partition in range(num_partitions):
        partition_data = []
        for i in range(records_per_partition):
            partition_data.append({
                "customer_id": f"C{i % 1000:03d}",
                "transaction_date": date(2024, 1, 1) + timedelta(days=i % 90),
                "amount": f"{random.uniform(100, 10000):.2f}",
                "category": random.choice(["SaaS", "Hardware", "Consulting"])
            })
        df = pd.DataFrame(partition_data)
        df.to_parquet(f"{temp_data_dir}/large_data_part_{partition}.parquet", index=False)
    pipeline = RevenuePipeline(
        spark=integration_spark,
        input_path=f"{temp_data_dir}/large_data_part_*.parquet",
        output_path=f"{temp_data_dir}/large_output"
    )
    result = pipeline.run()
    assert result.success is True
    # Every generated row is well-formed, so nothing should be filtered.
    assert result.records_processed == records_per_partition * num_partitions
    # Validate partitioning strategy worked
    output_files = integration_spark.read.parquet(f"{temp_data_dir}/large_output").rdd.getNumPartitions()
    assert output_files >= 4  # Should maintain reasonable parallelism
Performance Testing Note: Integration tests run in CI/CD environments with limited resources. Focus on testing partitioning logic and memory management rather than absolute performance numbers.
Data contracts define explicit expectations about data structure, quality, and semantics between pipeline stages. They serve as both documentation and automated validation.
Start with schema contracts that define expected data structure:
# pipeline/contracts.py
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
from pyspark.sql.types import StructType, StructField, StringType, DateType, DecimalType, BooleanType
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, count, sum as spark_sum, avg, max as spark_max, min as spark_min
import logging
@dataclass
class FieldContract:
    """Contract for one field: logical type, nullability, value constraints."""
    name: str
    # Logical type name ("string", "date", "decimal", "boolean"); mapped to a
    # Spark type by SchemaContract.to_spark_schema.
    data_type: str
    nullable: bool = True
    # Optional constraint dict, e.g. {"non_null": True, "min_value": 0};
    # interpreted by SchemaContract._validate_field_constraints.
    constraints: Optional[Dict[str, Any]] = None
    description: str = ""
@dataclass
class SchemaContract:
    """Contract describing the expected structure of a DataFrame.

    Bundles the field list with helpers to build the equivalent Spark schema
    and to validate an incoming DataFrame against the contract.
    """
    name: str
    version: str
    fields: List[FieldContract]
    description: str = ""

    def to_spark_schema(self) -> StructType:
        """Convert contract to Spark StructType."""
        type_mapping = {
            "string": StringType(),
            "date": DateType(),
            "decimal": DecimalType(10, 2),
            "boolean": BooleanType()
        }
        return StructType([
            StructField(field.name, type_mapping[field.data_type], field.nullable)
            for field in self.fields
        ])

    def validate_schema(self, df: DataFrame) -> "ValidationResult":
        """Validate DataFrame schema against contract.

        Missing contracted fields are errors; unexpected extra fields are
        only warnings so additive upstream changes don't break the pipeline.
        """
        errors = []
        warnings = []
        # Check for missing fields
        df_columns = set(df.columns)
        contract_columns = {field.name for field in self.fields}
        missing_fields = contract_columns - df_columns
        if missing_fields:
            errors.append(f"Missing required fields: {missing_fields}")
        extra_fields = df_columns - contract_columns
        if extra_fields:
            warnings.append(f"Unexpected fields found: {extra_fields}")
        # Check field types and constraints
        for field in self.fields:
            if field.name in df_columns:
                df_field = df.schema[field.name]
                # Type validation would go here
                # (simplified for example)
                # Constraint validation
                if field.constraints:
                    constraint_errors = self._validate_field_constraints(df, field)
                    errors.extend(constraint_errors)
        return ValidationResult(
            contract_name=self.name,
            is_valid=len(errors) == 0,
            errors=errors,
            warnings=warnings
        )

    def _validate_field_constraints(self, df: DataFrame, field: FieldContract) -> List[str]:
        """Validate field-level constraints (non_null, min_value, max_value)."""
        errors = []
        if not field.constraints:
            return errors
        field_col = col(field.name)
        # Non-null constraint
        if field.constraints.get("non_null", False):
            null_count = df.filter(field_col.isNull()).count()
            if null_count > 0:
                errors.append(f"Field {field.name} has {null_count} null values but is marked non_null")
        # Range constraints
        if "min_value" in field.constraints or "max_value" in field.constraints:
            stats = df.agg(
                spark_min(field_col).alias("min_val"),
                spark_max(field_col).alias("max_val")
            ).collect()[0]
            # Spark returns None for min/max of an empty or all-null column;
            # comparing None with a bound raises TypeError, so skip the range
            # check in that case (the non_null check reports the real issue).
            min_bound = field.constraints.get("min_value")
            if min_bound is not None and stats["min_val"] is not None:
                if stats["min_val"] < min_bound:
                    errors.append(f"Field {field.name} has values below minimum {field.constraints['min_value']}")
            max_bound = field.constraints.get("max_value")
            if max_bound is not None and stats["max_val"] is not None:
                if stats["max_val"] > max_bound:
                    errors.append(f"Field {field.name} has values above maximum {field.constraints['max_value']}")
        return errors
# Define contracts for our revenue pipeline
from datetime import date  # date-typed constraint values must compare cleanly with DateType stats

RAW_TRANSACTIONS_CONTRACT = SchemaContract(
    name="raw_transactions",
    version="1.0",
    description="Raw transaction data from source systems",
    fields=[
        FieldContract("customer_id", "string", nullable=False,
                      constraints={"non_null": True},
                      description="Unique customer identifier"),
        # min_value is a datetime.date, not the string "2020-01-01": the
        # constraint is compared against Spark's DateType min/max (which
        # collect as datetime.date), and a str-vs-date comparison raises
        # TypeError at validation time.
        FieldContract("transaction_date", "date", nullable=False,
                      constraints={"non_null": True, "min_value": date(2020, 1, 1)},
                      description="Date of transaction"),
        FieldContract("amount", "string", nullable=False,
                      description="Transaction amount (may contain currency formatting)"),
        FieldContract("category", "string", nullable=False,
                      constraints={"non_null": True},
                      description="Product category")
    ]
)
# Contract for the output of the cleaning stage: amounts are numeric,
# categories standardized, and the review flag is a real boolean.
CLEANED_TRANSACTIONS_CONTRACT = SchemaContract(
    name="cleaned_transactions",
    version="1.0",
    description="Cleaned and standardized transaction data",
    fields=[
        FieldContract("customer_id", "string", nullable=False,
                      constraints={"non_null": True}),
        FieldContract("transaction_date", "date", nullable=False,
                      constraints={"non_null": True}),
        # Cleaning casts the raw string amounts to decimal; negatives are out
        # of contract.
        FieldContract("amount", "decimal", nullable=False,
                      constraints={"non_null": True, "min_value": 0}),
        FieldContract("product_category", "string", nullable=False,
                      constraints={"non_null": True}),
        FieldContract("requires_review", "boolean", nullable=False)
    ]
)
Beyond schema validation, implement quality contracts that check data distribution and business logic:
@dataclass
class QualityContract:
    """A named bundle of quality validations run together on one DataFrame."""
    name: str
    description: str
    validations: List["QualityValidation"]

    def validate(self, df: DataFrame) -> "ValidationResult":
        """Run every configured validation and merge their findings."""
        collected_errors: List[str] = []
        collected_warnings: List[str] = []
        for check in self.validations:
            outcome = check.validate(df)
            collected_errors += outcome.errors
            collected_warnings += outcome.warnings
        # The contract as a whole passes only if no check produced an error.
        return ValidationResult(
            contract_name=self.name,
            is_valid=not collected_errors,
            errors=collected_errors,
            warnings=collected_warnings
        )
class QualityValidation:
    """Base class for data quality validations.

    Subclasses implement validate() and return a ValidationResult; name and
    description identify the check in reports and logs.
    """
    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description

    def validate(self, df: DataFrame) -> "ValidationResult":
        """Run the check against df; subclasses must override."""
        raise NotImplementedError
class UniquenessValidation(QualityValidation):
    """Validate that specified columns contain unique values."""

    def __init__(self, columns: List[str], tolerance: float = 0.0):
        check_name = f"uniqueness_check_{'+'.join(columns)}"
        super().__init__(check_name, f"Check uniqueness of {columns}")
        self.columns = columns
        # Fraction of duplicate rows the check will tolerate before failing.
        self.tolerance = tolerance

    def validate(self, df: DataFrame) -> "ValidationResult":
        """Compare distinct vs. total row counts over the key columns."""
        rows_total = df.count()
        rows_distinct = df.select(*self.columns).distinct().count()
        if rows_total > 0:
            duplicate_rate = (rows_total - rows_distinct) / rows_total
        else:
            duplicate_rate = 0
        found_errors = []
        if duplicate_rate > self.tolerance:
            found_errors.append(
                f"Duplicate rate {duplicate_rate:.2%} exceeds tolerance {self.tolerance:.2%} "
                f"for columns {self.columns}"
            )
        return ValidationResult(
            contract_name=self.name,
            is_valid=not found_errors,
            errors=found_errors,
            warnings=[]
        )
class CompletenessValidation(QualityValidation):
    """Validate that fields meet completeness requirements."""

    def __init__(self, field_requirements: Dict[str, float]):
        super().__init__(
            "completeness_check",
            "Check field completeness requirements"
        )
        # Maps field name -> minimum acceptable non-null ratio (0.0-1.0).
        self.field_requirements = field_requirements

    def validate(self, df: DataFrame) -> "ValidationResult":
        """Measure each required field's non-null ratio against its floor."""
        rows_total = df.count()
        problems = []
        for field, min_completeness in self.field_requirements.items():
            filled_rows = df.filter(col(field).isNotNull()).count()
            completeness = filled_rows / rows_total if rows_total > 0 else 0
            if completeness < min_completeness:
                problems.append(
                    f"Field {field} completeness {completeness:.2%} below requirement {min_completeness:.2%}"
                )
        return ValidationResult(
            contract_name=self.name,
            is_valid=not problems,
            errors=problems,
            warnings=[]
        )
class BusinessRuleValidation(QualityValidation):
    """Validate a custom business rule expressed as a SQL boolean condition."""

    def __init__(self, rule_name: str, condition_expr: str, description: str, max_violations: int = 0):
        super().__init__(rule_name, description)
        self.condition_expr = condition_expr
        # Number of violating rows tolerated before the check errors.
        self.max_violations = max_violations

    def validate(self, df: DataFrame) -> "ValidationResult":
        """Count rows where the rule does not hold and compare to the budget."""
        # NOTE(review): rows where the condition evaluates to SQL NULL are not
        # matched by NOT(...) and so are not counted as violations — confirm
        # that three-valued-logic behavior is intended for each rule.
        violation_count = df.filter(f"NOT ({self.condition_expr})").count()
        found_errors = []
        if violation_count > self.max_violations:
            found_errors.append(
                f"Business rule '{self.description}' violated {violation_count} times "
                f"(max allowed: {self.max_violations})"
            )
        return ValidationResult(
            contract_name=self.name,
            is_valid=not found_errors,
            errors=found_errors,
            warnings=[]
        )
# Define quality contracts for our pipeline
REVENUE_QUALITY_CONTRACT = QualityContract(
    name="revenue_quality_checks",
    description="Quality validations for revenue data",
    validations=[
        # A customer may legitimately transact more than once per day, so a
        # small duplicate budget is allowed on this key.
        UniquenessValidation(["customer_id", "transaction_date"], tolerance=0.01),  # Allow 1% duplicates
        CompletenessValidation({
            "customer_id": 1.0,  # 100% required
            "transaction_date": 1.0,  # 100% required
            "amount": 0.95,  # 95% required
            "product_category": 0.90  # 90% required
        }),
        BusinessRuleValidation(
            "positive_amounts",
            "amount > 0",
            "All transaction amounts must be positive"
        ),
        BusinessRuleValidation(
            "recent_transactions",
            "transaction_date >= date_sub(current_date(), 365)",
            "Transactions should be within the last year",
            max_violations=100  # Allow some historical data
        )
    ]
)
Integrate contract validation into your pipeline execution:
@dataclass
class ValidationResult:
    """Outcome of one contract validation run."""
    contract_name: str
    is_valid: bool
    errors: List[str]
    warnings: List[str]
    # NOTE(review): ContractEnforcer attaches a checkpoint_name attribute to
    # instances dynamically; it is not declared here.

    def __post_init__(self):
        # Log findings immediately so validation problems surface in pipeline
        # logs even if the caller ignores the returned object.
        if self.errors:
            logging.error(f"Contract {self.contract_name} validation failed: {self.errors}")
        if self.warnings:
            logging.warning(f"Contract {self.contract_name} validation warnings: {self.warnings}")
class ContractEnforcer:
    """Enforces data contracts at pipeline checkpoints."""

    def __init__(self, fail_on_error: bool = True, log_warnings: bool = True):
        # fail_on_error=False records violations instead of raising.
        self.fail_on_error = fail_on_error
        self.log_warnings = log_warnings
        self.validation_results = []

    def validate_schema(self, df: DataFrame, contract: SchemaContract, checkpoint_name: str) -> DataFrame:
        """Validate DataFrame against schema contract."""
        outcome = contract.validate_schema(df)
        outcome.checkpoint_name = checkpoint_name
        self.validation_results.append(outcome)
        if self.fail_on_error and not outcome.is_valid:
            raise ValueError(f"Schema validation failed at {checkpoint_name}: {outcome.errors}")
        return df

    def validate_quality(self, df: DataFrame, contract: QualityContract, checkpoint_name: str) -> DataFrame:
        """Validate DataFrame against quality contract."""
        outcome = contract.validate(df)
        outcome.checkpoint_name = checkpoint_name
        self.validation_results.append(outcome)
        if self.fail_on_error and not outcome.is_valid:
            raise ValueError(f"Quality validation failed at {checkpoint_name}: {outcome.errors}")
        return df

    def get_validation_summary(self) -> Dict[str, Any]:
        """Get summary of all validation results."""
        run_count = len(self.validation_results)
        failure_count = len([r for r in self.validation_results if not r.is_valid])
        pass_rate = (run_count - failure_count) / run_count if run_count > 0 else 0
        return {
            "total_validations": run_count,
            "failed_validations": failure_count,
            "success_rate": pass_rate,
            "results": self.validation_results
        }
# Updated pipeline with contract enforcement
class RevenuePipeline:
    """Revenue pipeline that validates data contracts at each checkpoint."""

    def __init__(self, spark: SparkSession, input_path: str, output_path: str,
                 enforce_contracts: bool = True):
        self.spark = spark
        self.input_path = input_path
        self.output_path = output_path
        # enforce_contracts=False downgrades contract violations from hard
        # failures to recorded results (see ContractEnforcer.fail_on_error).
        self.enforcer = ContractEnforcer(fail_on_error=enforce_contracts)

    def run(self):
        """Execute pipeline with contract validation.

        Checkpoints, in order: raw schema -> cleaning -> cleaned schema ->
        cleaned quality -> monthly metrics -> parquet output.
        """
        # Load raw data
        raw_df = self.spark.read.parquet(self.input_path)
        # Validate input contract
        raw_df = self.enforcer.validate_schema(
            raw_df, RAW_TRANSACTIONS_CONTRACT, "raw_input"
        )
        # Clean data
        cleaned_df = clean_revenue_data(raw_df)
        # Validate cleaned data contract
        cleaned_df = self.enforcer.validate_schema(
            cleaned_df, CLEANED_TRANSACTIONS_CONTRACT, "cleaned_data"
        )
        cleaned_df = self.enforcer.validate_quality(
            cleaned_df, REVENUE_QUALITY_CONTRACT, "cleaned_data"
        )
        # Calculate metrics
        metrics_df = calculate_monthly_metrics(cleaned_df)
        # Save results
        metrics_df.write.mode("overwrite").parquet(self.output_path)
        # NOTE(review): cleaned_df.count() triggers a second evaluation of the
        # cleaned DataFrame after the write — consider caching it if input
        # volumes grow.
        return PipelineResult(
            success=True,
            validation_summary=self.enforcer.get_validation_summary(),
            records_processed=cleaned_df.count()
        )
Now let's build a complete testing framework for a realistic scenario: a customer analytics pipeline that processes user behavior data and generates segmentation insights.
We'll create a pipeline that loads raw user behavior events, cleans them, computes per-user engagement metrics (sessions, revenue, recency), and assigns each customer to a behavioral segment.
# pipeline/customer_analytics.py
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, count, sum as spark_sum, avg, when, datediff, current_date, lag
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, DoubleType
class CustomerAnalyticsPipeline:
    """Pipeline that cleans user events, computes per-user engagement metrics,
    and assigns behavior-based customer segments."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def load_raw_events(self, events_path: str) -> DataFrame:
        """Load raw user event data."""
        return self.spark.read.parquet(events_path)

    def clean_events(self, df: DataFrame) -> DataFrame:
        """Clean and standardize event data.

        Drops rows missing any identifying field and defaults missing
        revenue to 0.0 so downstream sums/averages stay numeric.
        """
        return df.filter(
            col("user_id").isNotNull() &
            col("event_type").isNotNull() &
            col("timestamp").isNotNull()
        ).select(
            col("user_id"),
            col("event_type"),
            col("timestamp"),
            when(col("revenue").isNull(), 0.0).otherwise(col("revenue")).alias("revenue"),
            col("session_id")
        )

    def calculate_user_metrics(self, df: DataFrame) -> DataFrame:
        """Calculate engagement metrics per user."""
        # Imported locally: these aggregate functions are not in the module's
        # import list. Bare min/max would resolve to the Python builtins
        # (builtin min("timestamp") returns a *character* of the string, not a
        # column aggregate) and countDistinct was a NameError.
        from pyspark.sql.functions import countDistinct, max as spark_max, min as spark_min

        window_spec = Window.partitionBy("user_id").orderBy("timestamp")
        # A gap of more than 30 minutes between consecutive events — or the
        # very first event for a user — marks the start of a new session.
        df_with_sessions = df.withColumn(
            "prev_timestamp",
            lag("timestamp").over(window_spec)
        ).withColumn(
            "session_gap_minutes",
            (col("timestamp").cast("long") - col("prev_timestamp").cast("long")) / 60
        ).withColumn(
            "new_session",
            when((col("session_gap_minutes") > 30) | col("prev_timestamp").isNull(), 1).otherwise(0)
        )
        return df_with_sessions.groupBy("user_id").agg(
            count("*").alias("total_events"),
            spark_sum("new_session").alias("session_count"),
            spark_sum("revenue").alias("total_revenue"),
            avg("revenue").alias("avg_revenue_per_event"),
            datediff(current_date(), spark_min("timestamp")).alias("days_since_first_event"),
            datediff(current_date(), spark_max("timestamp")).alias("days_since_last_event"),
            countDistinct("event_type").alias("unique_event_types")
        )

    def assign_segments(self, df: DataFrame) -> DataFrame:
        """Assign customer segments based on behavior metrics.

        Branch order matters: the first matching condition wins, so the
        value-based segments take precedence over dormant/new_user.
        """
        return df.withColumn(
            "customer_segment",
            when(
                (col("total_revenue") > 1000) & (col("session_count") > 50),
                "high_value"
            ).when(
                (col("total_revenue") > 100) & (col("session_count") > 10),
                "medium_value"
            ).when(
                col("days_since_last_event") > 30,
                "dormant"
            ).when(
                col("days_since_first_event") < 7,
                "new_user"
            ).otherwise("low_engagement")
        )
# tests/test_customer_analytics.py
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from datetime import datetime, timedelta
from pipeline.customer_analytics import CustomerAnalyticsPipeline
class TestCustomerAnalyticsPipeline:
    """Unit tests for the individual CustomerAnalyticsPipeline transformations."""

    @staticmethod
    def _raw_event_schema():
        """Schema shared by the raw-event fixtures below."""
        return StructType([
            StructField("user_id", StringType()),
            StructField("event_type", StringType()),
            StructField("timestamp", TimestampType()),
            StructField("revenue", DoubleType()),
            StructField("session_id", StringType()),
        ])

    def test_event_cleaning(self, spark):
        """Cleaning drops rows missing required fields and zero-fills revenue."""
        rows = [
            ("user1", "login", datetime(2024, 1, 1, 10, 0), 0.0, "s1"),
            ("user2", "purchase", datetime(2024, 1, 1, 11, 0), 25.99, "s2"),
            (None, "click", datetime(2024, 1, 1, 12, 0), 0.0, "s3"),       # null user_id
            ("user3", None, datetime(2024, 1, 1, 13, 0), 0.0, "s4"),       # null event_type
            ("user4", "purchase", None, 50.00, "s5"),                      # null timestamp
            ("user5", "purchase", datetime(2024, 1, 1, 14, 0), None, "s6"),  # null revenue
        ]
        cleaned = CustomerAnalyticsPipeline(spark).clean_events(
            spark.createDataFrame(rows, self._raw_event_schema())
        )
        kept = cleaned.collect()

        # Only the three rows with all required fields survive.
        assert len(kept) == 3

        # Null revenue is converted to 0.0 rather than dropped.
        by_user = {r["user_id"]: r for r in kept}
        assert by_user["user5"]["revenue"] == 0.0

    def test_user_metrics_calculation(self, spark):
        """Sessionization and aggregate metrics come out right per user."""
        t0 = datetime(2024, 1, 1, 10, 0)
        rows = [
            # user1: two sessions separated by a >30-minute gap
            ("user1", "login", t0, 0.0, "s1"),
            ("user1", "click", t0 + timedelta(minutes=5), 0.0, "s1"),
            ("user1", "purchase", t0 + timedelta(minutes=10), 99.99, "s1"),
            ("user1", "login", t0 + timedelta(hours=2), 0.0, "s2"),
            ("user1", "purchase", t0 + timedelta(hours=2, minutes=5), 149.99, "s2"),
            # user2: a single short session, no revenue
            ("user2", "login", t0, 0.0, "s3"),
            ("user2", "click", t0 + timedelta(minutes=2), 0.0, "s3"),
        ]
        metrics = CustomerAnalyticsPipeline(spark).calculate_user_metrics(
            spark.createDataFrame(rows, self._raw_event_schema())
        )
        by_user = {row["user_id"]: row for row in metrics.collect()}

        expectations = {
            "user1": {"total_events": 5, "session_count": 2,
                      "total_revenue": 249.98, "unique_event_types": 3},
            "user2": {"total_events": 2, "session_count": 1,
                      "total_revenue": 0.0, "unique_event_types": 2},
        }
        for user, expected in expectations.items():
            for metric, value in expected.items():
                assert by_user[user][metric] == value

    def test_customer_segmentation(self, spark):
        """Each behavioral profile maps to its expected segment."""
        metric_schema = StructType([
            StructField("user_id", StringType()),
            StructField("total_events", IntegerType()),
            StructField("session_count", IntegerType()),
            StructField("total_revenue", DoubleType()),
            StructField("days_since_first_event", IntegerType()),
            StructField("days_since_last_event", IntegerType()),
        ])
        #            user_id             events sessions revenue  first last
        profiles = [
            ("high_value_user", 100, 75, 2500.00, 90, 1),
            ("medium_value_user", 50, 25, 500.00, 60, 2),
            ("dormant_user", 30, 15, 200.00, 120, 45),
            ("new_user", 5, 2, 50.00, 3, 1),
            ("low_engagement_user", 10, 3, 25.00, 30, 5),
        ]
        segmented = CustomerAnalyticsPipeline(spark).assign_segments(
            spark.createDataFrame(profiles, metric_schema)
        )
        actual = {row["user_id"]: row["customer_segment"] for row in segmented.collect()}

        assert actual == {
            "high_value_user": "high_value",
            "medium_value_user": "medium_value",
            "dormant_user": "dormant",
            "new_user": "new_user",
            "low_engagement_user": "low_engagement",
        }
# tests/integration/test_customer_analytics_integration.py
import pytest
from datetime import datetime, timedelta
import random
import pandas as pd
from pipeline.customer_analytics import CustomerAnalyticsPipeline
class TestCustomerAnalyticsIntegration:
    """End-to-end pipeline tests against a realistic synthetic event set."""

    @pytest.fixture
    def realistic_event_data(self, temp_data_dir):
        """Generate realistic user event data for 1000 users and persist it to parquet."""
        users = [f"user_{i:04d}" for i in range(1, 1001)]  # 1000 users
        event_types = ["login", "click", "view", "purchase", "logout"]
        # (sessions range, events-per-session range) for each activity level
        activity_profiles = {
            "high": ((20, 50), (10, 30)),
            "medium": ((5, 20), (5, 15)),
            "low": ((1, 5), (1, 5)),
        }
        window_start = datetime(2024, 1, 1)

        events = []
        for user in users:
            # Each user has a different activity pattern.
            level = random.choice(["high", "medium", "low"])
            session_range, per_session_range = activity_profiles[level]
            num_sessions = random.randint(*session_range)
            events_per_session = random.randint(*per_session_range)

            for session in range(num_sessions):
                session_id = f"{user}_s{session}"
                session_start = window_start + timedelta(
                    days=random.randint(0, 90),
                    hours=random.randint(8, 20),
                    minutes=random.randint(0, 59),
                )
                for event_num in range(events_per_session):
                    event_time = session_start + timedelta(minutes=event_num * random.randint(1, 5))
                    event_type = random.choice(event_types)
                    # Only purchase events carry revenue.
                    revenue = random.uniform(10, 500) if event_type == "purchase" else 0.0
                    events.append({
                        "user_id": user,
                        "event_type": event_type,
                        "timestamp": event_time,
                        "revenue": revenue,
                        "session_id": session_id,
                    })

        events_path = f"{temp_data_dir}/events.parquet"
        pd.DataFrame(events).to_parquet(events_path, index=False)
        return events_path, len(events)

    def test_full_pipeline_execution(self, integration_spark, realistic_event_data, temp_data_dir):
        """Run every stage end-to-end and sanity-check the persisted segments."""
        events_path, _expected_events = realistic_event_data
        pipeline = CustomerAnalyticsPipeline(integration_spark)

        # Chain the pipeline stages exactly as production would.
        segments_df = pipeline.assign_segments(
            pipeline.calculate_user_metrics(
                pipeline.clean_events(pipeline.load_raw_events(events_path))
            )
        )
        output_path = f"{temp_data_dir}/customer_segments"
        segments_df.write.mode("overwrite").parquet(output_path)

        persisted = integration_spark.read.parquet(output_path).collect()

        # Every generated user produces at least one event, so all 1000 survive.
        assert len(persisted) == 1000

        # Segments must come from the known vocabulary.
        observed = {row["customer_segment"] for row in persisted}
        allowed = {"high_value", "medium_value", "dormant", "new_user", "low_engagement"}
        assert observed.issubset(allowed)

        # Spot-check that high-value users really meet the segment thresholds.
        high_value = [r for r in persisted if r["customer_segment"] == "high_value"]
        for row in high_value[:5]:
            assert row["total_revenue"] > 1000
            assert row["session_count"] > 50
# Define contracts for customer analytics pipeline
from pipeline.contracts import SchemaContract, FieldContract, QualityContract
from pipeline.contracts import UniquenessValidation, CompletenessValidation, BusinessRuleValidation

# Input contract: shape the raw event feed must have before any processing.
RAW_EVENTS_CONTRACT = SchemaContract(
    name="raw_user_events",
    version="1.0",
    description="Raw user event data from multiple sources",
    fields=[
        # Identity and ordering fields are mandatory — sessionization depends on them.
        FieldContract("user_id", "string", nullable=False,
                      constraints={"non_null": True}),
        FieldContract("event_type", "string", nullable=False,
                      constraints={"non_null": True}),
        FieldContract("timestamp", "timestamp", nullable=False,
                      constraints={"non_null": True}),
        # Revenue and session_id may legitimately be absent (e.g. non-purchase events).
        FieldContract("revenue", "decimal", nullable=True),
        FieldContract("session_id", "string", nullable=True)
    ]
)

# Output contract: guarantees made to downstream consumers of user metrics.
USER_METRICS_CONTRACT = SchemaContract(
    name="user_engagement_metrics",
    version="1.0",
    description="Calculated user engagement metrics",
    fields=[
        FieldContract("user_id", "string", nullable=False),
        # Every emitted user has at least one event and one session by construction.
        FieldContract("total_events", "integer", nullable=False,
                      constraints={"min_value": 1}),
        FieldContract("session_count", "integer", nullable=False,
                      constraints={"min_value": 1}),
        FieldContract("total_revenue", "decimal", nullable=False,
                      constraints={"min_value": 0}),
        FieldContract("customer_segment", "string", nullable=False)
    ]
)

# Quality contract: row-level validations applied to cleaned event data.
EVENTS_QUALITY_CONTRACT = QualityContract(
    name="events_quality_checks",
    description="Quality validations for event data",
    validations=[
        # Required fields must be 100% populated after cleaning.
        CompletenessValidation({
            "user_id": 1.0,
            "event_type": 1.0,
            "timestamp": 1.0
        }),
        BusinessRuleValidation(
            "valid_event_types",
            "event_type IN ('login', 'click', 'view', 'purchase', 'logout')",
            "Event types must be from allowed list"
        ),
        BusinessRuleValidation(
            "non_negative_revenue",
            "revenue >= 0 OR revenue IS NULL",
            "Revenue must be non-negative"
        )
    ]
)
Mistake: Writing tests that only cover "happy path" scenarios with clean, predictable data.
Problem: Real production data contains edge cases, malformed values, and unexpected patterns that can break your pipeline.
Solution: Systematically include problematic data in your tests:
def test_with_edge_cases(self, spark):
"""Test pipeline behavior with various edge cases."""
edge_case_data = [
# Boundary values
("user1", "purchase", datetime(2024, 1, 1), 0.01), # Minimum purchase
("user2", "purchase", datetime(2024, 1, 1), 999999.99), # Maximum purchase
# Special characters and encoding issues
("user_ñoël", "click", datetime(2024, 1, 1), 0.0), # Unicode characters
("user with spaces", "view", datetime(2024, 1, 1), 0.0), # Spaces in ID
# Timing edge cases
("user3", "login", datetime(1970, 1, 1), 0.0), # Unix epoch
("user4", "logout", datetime(2099, 12, 31), 0.0), # Far future
# Duplicate events
("user5", "purchase", datetime(2024, 1, 1, 10, 0), 100.0),
("user5", "purchase", datetime(2024, 1, 1, 10, 0), 100.0), # Exact duplicate
]
# Your pipeline should handle these gracefully without crashing
Mistake: Writing integration tests that don't validate performance characteristics under realistic data volumes.
Problem: Pipelines that work fine with test data may have memory leaks, inefficient joins, or poor partitioning that only shows up at scale.
Solution: Include performance assertions in your integration tests:
def test_memory_efficiency(self, integration_spark, temp_data_dir):
    """Run the pipeline at scale and assert memory and partitioning stay sane."""
    # Large enough to exercise memory management, small enough for CI/CD.
    large_dataset = generate_test_data(rows=1000000)

    baseline_memory = get_spark_memory_usage(integration_spark)
    result = pipeline.run(large_dataset)
    peak_memory = get_spark_memory_usage(integration_spark)

    # Memory growth during the run must stay below 1 GB.
    assert peak_memory - baseline_memory < 1024 * 1024 * 1024
    assert result.records_processed == 1000000

    # Output should be neither one giant partition nor over-split.
    partitions = result.output_df.rdd.getNumPartitions()
    assert 4 <= partitions <= 16
Mistake: Writing data contracts that are too rigid and break with minor, acceptable changes.
Problem: Overly strict contracts create false positives and make the pipeline fragile to normal data evolution.
Solution: Design contracts with appropriate tolerances:
# Too rigid: a single unexpected null anywhere fails the entire pipeline run.
BRITTLE_CONTRACT = QualityContract(
    validations=[
        CompletenessValidation({"user_id": 1.0})  # Requires 100% completeness
    ]
)

# More resilient: tolerances absorb normal, acceptable data drift
# without letting genuine quality regressions through.
RESILIENT_CONTRACT = QualityContract(
    validations=[
        CompletenessValidation({"user_id": 0.99}),  # Allow 1% missing values
        BusinessRuleValidation(
            "user_id_format",
            "user_id RLIKE '^user_[0-9]+$'",
            "User IDs should follow expected format",
            max_violations=100  # Allow some format variations
        )
    ]
)
Mistake: Running tests that interfere with each other due to shared state or resources.
Problem: Tests become flaky and hard to debug when they depend on execution order or previous test results.
Solution: Properly isolate test environments:
@pytest.fixture
def isolated_spark_session():
    """Yield a throwaway Spark session with a unique app name and warehouse dir."""
    session = (
        SparkSession.builder
        .appName(f"test-{uuid.uuid4()}")
        .config("spark.sql.warehouse.dir", f"/tmp/spark-warehouse-{uuid.uuid4()}")
        .getOrCreate()
    )
    yield session
    # Teardown: stop the session so tests never share executor state.
    session.stop()
@pytest.fixture(autouse=True)
def clean_temp_tables(spark):
    """Drop any temporary views a test leaves behind so tests stay independent."""
    yield
    # Collect the names first, then drop, so we never mutate while iterating.
    leftover_views = [t.name for t in spark.catalog.listTables() if t.isTemporary]
    for view_name in leftover_views:
        spark.catalog.dropTempView(view_name)
When data contracts fail in production, you need systematic approaches to understand and fix the issues:
def debug_contract_failure(df: DataFrame, contract: QualityContract, output_path: str):
    """Generate detailed diagnostics for contract failures.

    Writes a JSON data profile and a parquet sample of records under
    ``output_path`` and returns the profile dict.

    NOTE(review): ``contract`` is currently unused — the sample is drawn from
    the whole DataFrame, not just the records that violated the contract.
    Filtering to true violations would require the contract's predicate API.
    """
    import json
    from pyspark.sql.functions import col, count, when

    # Sample records for manual inspection (~10% of rows, capped at 1000).
    sample_size = 1000
    record_sample = df.sample(fraction=0.1).limit(sample_size)

    # Compute all null counts in a single aggregation pass. The original ran
    # df.filter(...).count() per column, i.e. one full scan per field.
    null_count_row = df.select([
        count(when(col(field.name).isNull(), 1)).alias(field.name)
        for field in df.schema.fields
    ]).collect()[0]

    profile = {
        "total_records": df.count(),
        "schema": df.schema.json(),
        "null_counts": {
            field.name: null_count_row[field.name]
            for field in df.schema.fields
        },
        # Up to 10 distinct sample values for each string column.
        "sample_values": {
            field.name: [
                row[field.name]
                for row in record_sample.select(field.name).distinct().limit(10).collect()
            ]
            for field in df.schema.fields
            if field.dataType.simpleString() == "string"
        },
    }

    # Persist diagnostics next to the failing run's output for post-mortem review.
    with open(f"{output_path}/contract_failure_diagnostics.json", "w") as f:
        json.dump(profile, f, indent=2, default=str)

    record_sample.write.mode("overwrite").parquet(f"{output_path}/failing_records_sample")
    return profile
You've now built a comprehensive three-layer testing strategy that will catch pipeline failures before they impact production systems. Here's what you've accomplished:
Unit Tests: You can isolate and validate individual transformations, handling edge cases and testing business logic with controlled inputs. These tests run fast and provide immediate feedback during development.
Integration Tests: You can validate end-to-end pipeline behavior with realistic data volumes and infrastructure, catching performance issues and integration problems before deployment.
Data Contracts: You can define and enforce explicit expectations about data structure and quality, creating reliable interfaces between pipeline stages and downstream consumers.
The framework you've built scales with pipeline complexity and provides clear diagnostic information when things go wrong. Your pipelines will fail fast and fail obviously rather than silently corrupting data.
Learning Path: Data Pipeline Fundamentals