Wicked Smart Data
Pipeline Testing: Unit Tests, Integration Tests, and Data Contracts

Data Engineering | ⚡ Practitioner | 26 min read | Apr 10, 2026 (updated Apr 10, 2026)
Table of Contents
  • Prerequisites
  • The Three-Layer Testing Strategy
  • Building Effective Unit Tests for Pipeline Components
  • Testing Pure Transformations
  • Testing Components with External Dependencies
  • Integration Testing: Validating End-to-End Pipeline Behavior
  • Setting Up Test Infrastructure
  • Testing Complete Pipeline Workflows
  • Testing with Realistic Data Volumes
  • Implementing Data Contracts for Pipeline Reliability
  • Defining Schema Contracts
  • Quality Contracts and Data Validation

Picture this: your daily sales pipeline has been running smoothly for months, generating reports that drive million-dollar business decisions. Then one Tuesday morning, you discover that a schema change in your source system three days ago has been silently corrupting your aggregations. The executive dashboard shows revenue up 40% (celebrated in the Monday meeting), but the reality is a data type conversion error has been treating null values as zeros.

This scenario plays out more often than we'd like to admit. Production data pipelines fail in subtle ways, often long before anyone notices. What separates pipelines that fail fast and loudly from those that fail silently and corrupt downstream systems is a comprehensive testing strategy.

By the end of this lesson, you'll implement a three-layered testing approach that catches errors before they reach production and maintains data quality contracts with downstream consumers.

What you'll learn:

  • Build unit tests that validate individual pipeline components in isolation
  • Design integration tests that verify end-to-end pipeline behavior with realistic data volumes
  • Implement data contracts that define and enforce expectations between pipeline stages
  • Structure a testing framework that scales with pipeline complexity
  • Debug common testing scenarios and interpret test failures in production contexts

Prerequisites

You should be comfortable writing data pipelines in Python, familiar with pytest or similar testing frameworks, and understand SQL basics. We'll use Apache Spark with PySpark for examples, but the testing principles apply to any pipeline technology.

The Three-Layer Testing Strategy

Data pipeline testing requires a fundamentally different approach than application testing. While application tests typically verify business logic with predictable inputs, pipeline tests must handle variable data volumes, evolving schemas, and distributed processing concerns.

Our three-layer approach addresses these challenges:

  1. Unit tests verify individual transformations with controlled inputs
  2. Integration tests validate pipeline behavior with realistic data and infrastructure
  3. Data contracts enforce expectations about data structure and quality between pipeline stages

Each layer serves a specific purpose and catches different types of failures. Let's build this framework step by step.
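One way to keep the three layers separately runnable is to register a pytest marker per layer, so CI can run the fast unit layer on every commit and reserve the slower integration layer for a later stage. The sketch below assumes a `tests/conftest.py` and marker names `unit`, `integration`, and `contract`; those names are our choice, not something pytest defines.

```python
# tests/conftest.py (sketch): register one pytest marker per testing layer.
# The marker names are illustrative; pytest does not define them.
LAYER_MARKERS = {
    "unit": "fast tests of single transformations on a local Spark session",
    "integration": "end-to-end tests that need containers or larger data",
    "contract": "schema and quality contract checks between pipeline stages",
}


def pytest_configure(config):
    """pytest hook: declare the markers so `pytest --strict-markers` accepts them."""
    for name, description in LAYER_MARKERS.items():
        config.addinivalue_line("markers", f"{name}: {description}")
```

Tests then opt in with a decorator such as `@pytest.mark.integration`, and `pytest -m "not integration"` runs everything except the slow layer.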

Building Effective Unit Tests for Pipeline Components

Unit tests form your first line of defense, validating individual transformations before they interact with real data sources or complex infrastructure.

Testing Pure Transformations

Start with your simplest components—pure functions that transform DataFrames without external dependencies:

# pipeline/transformations.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import (
    avg, col, count, date_format, regexp_replace, when, sum as spark_sum
)
from pyspark.sql.types import DecimalType

def clean_revenue_data(df: DataFrame) -> DataFrame:
    """Clean and standardize revenue data from multiple sources."""
    # Handle various currency formats: strip "$" and "," before casting.
    # Build the expression once so the review flag below uses the cleaned value.
    cleaned_amount = regexp_replace(col("amount"), r"[$,]", "").cast(DecimalType(10, 2))
    return df.select(
        col("customer_id"),
        col("transaction_date"),
        cleaned_amount.alias("amount"),
        # Standardize product categories
        when(col("category").isin(["SaaS", "Software", "Platform"]), "Software")
        .when(col("category").isin(["Consulting", "Services", "Professional"]), "Services")
        .otherwise(col("category")).alias("product_category"),
        # Flag suspicious transactions. Comparing col("amount") here would read the
        # raw string column, so a value like "$150,000.00" would never be flagged.
        when(cleaned_amount > 100000, True).otherwise(False).alias("requires_review")
    )

def calculate_monthly_metrics(df: DataFrame) -> DataFrame:
    """Calculate monthly revenue metrics by category."""
    return df.groupBy("product_category", 
                     date_format(col("transaction_date"), "yyyy-MM").alias("month")) \
             .agg(
                 spark_sum("amount").alias("total_revenue"),
                 count("*").alias("transaction_count"),
                 avg("amount").alias("avg_transaction_size"),
                 spark_sum(when(col("requires_review"), 1).otherwise(0)).alias("flagged_transactions")
             )
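The expected values in the tests that follow are worked out by hand. One way to make them auditable is a small pure-Python mirror of the per-row cleaning rules that you can call while writing assertions. This is a sketch under stated assumptions: Spark's default non-ANSI casting (unparseable or overflowing amounts become null), half-up rounding to two decimal places, and a review flag computed from the cleaned amount. The names `expected_amount` and `expected_row` are hypothetical.

```python
# Pure-Python mirror of clean_revenue_data's per-row rules, used only to derive
# expected values for tests. Assumes Spark's default (non-ANSI) cast behavior:
# unparseable or overflowing amounts become null, represented here by None.
import re
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP

CATEGORY_MAP = {
    "SaaS": "Software", "Software": "Software", "Platform": "Software",
    "Consulting": "Services", "Services": "Services", "Professional": "Services",
}
REVIEW_THRESHOLD = Decimal("100000")
DECIMAL_10_2_LIMIT = Decimal("100000000")  # DECIMAL(10,2) allows 8 integer digits


def expected_amount(raw):
    """Strip '$' and ',' and round to cents; None where Spark's cast would give null."""
    if raw is None:
        return None
    try:
        value = Decimal(re.sub(r"[$,]", "", raw))
        if not value.is_finite():
            return None
        value = value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    except InvalidOperation:
        return None
    return value if abs(value) < DECIMAL_10_2_LIMIT else None


def expected_row(raw_amount, category):
    """Expected cleaned fields for one raw row, assuming the flag uses the cleaned amount."""
    amount = expected_amount(raw_amount)
    return {
        "amount": amount,
        "product_category": CATEGORY_MAP.get(category, category),
        # A null amount makes the Spark comparison null, so otherwise(False) applies.
        "requires_review": amount is not None and amount > REVIEW_THRESHOLD,
    }
```

Keeping this mirror deliberately simple matters: if it grows as complex as the Spark code, it stops being an independent check.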

Now test these transformations with carefully constructed test cases:

# tests/test_transformations.py
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DateType, DecimalType
from decimal import Decimal
from datetime import date

from pipeline.transformations import clean_revenue_data, calculate_monthly_metrics

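# Shared fixture: place it in tests/conftest.py so other test modules
# (such as tests/test_loaders.py) can request `spark` too.
@pytest.fixture(scope="session")
def spark():
    session = SparkSession.builder \
        .master("local[2]") \
        .appName("pipeline-tests") \
        .config("spark.sql.shuffle.partitions", "2") \
        .getOrCreate()
    yield session
    # Release the local Spark resources once the whole test session ends
    session.stop()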

class TestRevenueDataCleaning:
    
    def test_currency_format_standardization(self, spark):
        """Test that various currency formats are properly cleaned."""
        # Arrange
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
            StructField("category", StringType())
        ])
        
        input_data = [
            ("C001", date(2024, 1, 15), "$1,234.56", "SaaS"),
            ("C002", date(2024, 1, 16), "2345.67", "Consulting"),
            ("C003", date(2024, 1, 17), "$15,000", "Platform"),
        ]
        
        input_df = spark.createDataFrame(input_data, schema)
        
        # Act
        result = clean_revenue_data(input_df)
        result_data = result.collect()
        
        # Assert
        assert result_data[0]["amount"] == Decimal('1234.56')
        assert result_data[1]["amount"] == Decimal('2345.67')
        assert result_data[2]["amount"] == Decimal('15000.00')
    
    def test_category_standardization(self, spark):
        """Test that product categories are properly standardized."""
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
            StructField("category", StringType())
        ])
        
        input_data = [
            ("C001", date(2024, 1, 15), "100.00", "SaaS"),
            ("C002", date(2024, 1, 16), "200.00", "Software"),
            ("C003", date(2024, 1, 17), "300.00", "Consulting"),
            ("C004", date(2024, 1, 18), "400.00", "Professional"),
            ("C005", date(2024, 1, 19), "500.00", "Hardware"),
        ]
        
        input_df = spark.createDataFrame(input_data, schema)
        result = clean_revenue_data(input_df)
        
        categories = [row["product_category"] for row in result.collect()]
        
        assert categories.count("Software") == 2  # SaaS and Software both map to Software
        assert categories.count("Services") == 2   # Consulting and Professional both map to Services
        assert categories.count("Hardware") == 1   # Hardware stays as-is
    
    def test_large_transaction_flagging(self, spark):
        """Test that large transactions are properly flagged for review."""
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
            StructField("category", StringType())
        ])
        
        input_data = [
            ("C001", date(2024, 1, 15), "50000.00", "Software"),
            ("C002", date(2024, 1, 16), "150000.00", "Software"),
        ]
        
        input_df = spark.createDataFrame(input_data, schema)
        result = clean_revenue_data(input_df)
        result_data = result.collect()
        
        assert result_data[0]["requires_review"] is False
        assert result_data[1]["requires_review"] is True

class TestMonthlyMetrics:
    
    def test_monthly_aggregation_accuracy(self, spark):
        """Test that monthly metrics are calculated correctly."""
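        from pyspark.sql.types import BooleanType
        
        # Create test data spanning multiple months and categories.
        # requires_review holds booleans, matching what clean_revenue_data
        # produces, so it is declared as BooleanType rather than StringType.
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", DecimalType(10,2)),
            StructField("product_category", StringType()),
            StructField("requires_review", BooleanType())
        ])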
        
        input_data = [
            ("C001", date(2024, 1, 15), Decimal('1000.00'), "Software", False),
            ("C002", date(2024, 1, 16), Decimal('2000.00'), "Software", False),
            ("C003", date(2024, 1, 17), Decimal('150000.00'), "Software", True),
            ("C004", date(2024, 2, 1), Decimal('3000.00'), "Services", False),
            ("C005", date(2024, 2, 2), Decimal('4000.00'), "Services", False),
        ]
        
        input_df = spark.createDataFrame(input_data, schema)
        result = calculate_monthly_metrics(input_df)
        
        # Convert to dictionary for easier assertions
        result_dict = {
            (row["product_category"], row["month"]): row 
            for row in result.collect()
        }
        
        # Test January Software metrics
        jan_software = result_dict[("Software", "2024-01")]
        assert jan_software["total_revenue"] == Decimal('153000.00')
        assert jan_software["transaction_count"] == 3
        assert jan_software["avg_transaction_size"] == Decimal('51000.00')
        assert jan_software["flagged_transactions"] == 1
        
        # Test February Services metrics
        feb_services = result_dict[("Services", "2024-02")]
        assert feb_services["total_revenue"] == Decimal('7000.00')
        assert feb_services["transaction_count"] == 2
        assert feb_services["avg_transaction_size"] == Decimal('3500.00')
        assert feb_services["flagged_transactions"] == 0
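Hand-computed figures like `153000.00` and `51000.00` are easy to get wrong when test data changes. A plain-Python oracle that recomputes the same aggregates gives you a second, independent derivation to compare against. This is a sketch: `expected_monthly_metrics` is a hypothetical helper, and it assumes no null amounts, because Spark's `sum` and `avg` skip nulls while `count("*")` counts every row.

```python
# Plain-Python oracle for calculate_monthly_metrics, useful for cross-checking
# hand-computed expectations. Assumes no null amounts: Spark's sum/avg skip
# nulls while count("*") counts every row, and this sketch does not model that.
from collections import defaultdict
from datetime import date
from decimal import Decimal


def expected_monthly_metrics(rows):
    """rows: iterable of (transaction_date, amount, product_category, requires_review)."""
    groups = defaultdict(list)
    for txn_date, amount, category, flagged in rows:
        groups[(category, txn_date.strftime("%Y-%m"))].append((amount, flagged))

    metrics = {}
    for key, items in groups.items():
        total = sum((amount for amount, _ in items), Decimal("0"))
        metrics[key] = {
            "total_revenue": total,
            "transaction_count": len(items),
            "avg_transaction_size": total / len(items),
            "flagged_transactions": sum(1 for _, flagged in items if flagged),
        }
    return metrics


# The same rows as test_monthly_aggregation_accuracy, without customer_id.
SAMPLE_ROWS = [
    (date(2024, 1, 15), Decimal("1000.00"), "Software", False),
    (date(2024, 1, 16), Decimal("2000.00"), "Software", False),
    (date(2024, 1, 17), Decimal("150000.00"), "Software", True),
    (date(2024, 2, 1), Decimal("3000.00"), "Services", False),
    (date(2024, 2, 2), Decimal("4000.00"), "Services", False),
]
```

In the Spark test you could then compare each output row against `expected_monthly_metrics(SAMPLE_ROWS)` instead of literals, so the expectation and the implementation are computed two different ways.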

Testing Components with External Dependencies

Real pipeline components often interact with databases, APIs, or file systems. Use dependency injection and mocking to test these components in isolation:

# pipeline/loaders.py
from abc import ABC, abstractmethod

import requests
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

class DataSource(ABC):
    @abstractmethod
    def load_customer_data(self, start_date: str, end_date: str) -> DataFrame:
        pass

class DatabaseSource(DataSource):
    def __init__(self, connection_string: str):
        self.connection_string = connection_string
    
    def load_customer_data(self, start_date: str, end_date: str) -> DataFrame:
        # Placeholder: a real implementation would read from the database here
        raise NotImplementedError("Database loading is omitted from this example")

class EnrichmentService:
    def __init__(self, api_key: str, base_url: str, timeout: float = 5.0):
        self.api_key = api_key
        self.base_url = base_url
        self.timeout = timeout
    
    def get_customer_segment(self, customer_id: str) -> str:
        response = requests.get(
            f"{self.base_url}/customers/{customer_id}/segment",
            headers={"Authorization": f"Bearer {self.api_key}"},
            timeout=self.timeout,  # without a timeout, a hung call can block the task indefinitely
        )
        response.raise_for_status()  # surface HTTP errors instead of parsing an error body
        return response.json()["segment"]

def enrich_customer_data(df: DataFrame, enrichment_service: EnrichmentService) -> DataFrame:
    """Add customer segment information to transaction data.

    The UDF runs in Spark's Python workers and makes one API call per row,
    so `enrichment_service` is pickled and shipped to those workers.
    """
    @udf(returnType=StringType())
    def get_segment(customer_id):
        try:
            return enrichment_service.get_customer_segment(customer_id)
        except Exception:
            # Degrade gracefully: any lookup failure yields "unknown"
            return "unknown"
    
    return df.withColumn("customer_segment", get_segment(col("customer_id")))

Test this with mock dependencies:

# tests/test_loaders.py
import pytest
from unittest.mock import Mock
from pyspark.sql.types import StructType, StructField, StringType, DateType
from datetime import date

from pipeline.loaders import enrich_customer_data, EnrichmentService

class TestCustomerEnrichment:
    
    def test_successful_enrichment(self, spark):
        """Test that customer segments are properly added when API calls succeed."""
        # Arrange
        mock_service = Mock(spec=EnrichmentService)
        mock_service.get_customer_segment.side_effect = lambda cid: {
            "C001": "enterprise",
            "C002": "mid-market",
            "C003": "smb"
        }.get(cid, "unknown")
        
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
        ])
        
        input_data = [
            ("C001", date(2024, 1, 15), "1000.00"),
            ("C002", date(2024, 1, 16), "2000.00"),
            ("C003", date(2024, 1, 17), "3000.00"),
        ]
        
        input_df = spark.createDataFrame(input_data, schema)
        
        # Act
        result = enrich_customer_data(input_df, mock_service)
        result_data = {row["customer_id"]: row["customer_segment"] 
                      for row in result.collect()}
        
        # Assert
        assert result_data["C001"] == "enterprise"
        assert result_data["C002"] == "mid-market"
        assert result_data["C003"] == "smb"
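        # Don't assert on mock_service.get_customer_segment.call_count here: the
        # UDF runs in Spark's Python workers against a pickled copy of the mock,
        # so those calls are never recorded on this driver-side object.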
    
    def test_api_failure_handling(self, spark):
        """Test that API failures are handled gracefully."""
        # Arrange
        mock_service = Mock(spec=EnrichmentService)
        mock_service.get_customer_segment.side_effect = Exception("API timeout")
        
        schema = StructType([
            StructField("customer_id", StringType()),
            StructField("transaction_date", DateType()),
            StructField("amount", StringType()),
        ])
        
        input_data = [("C001", date(2024, 1, 15), "1000.00")]
        input_df = spark.createDataFrame(input_data, schema)
        
        # Act
        result = enrich_customer_data(input_df, mock_service)
        result_data = result.collect()
        
        # Assert
        assert result_data[0]["customer_segment"] == "unknown"

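Testing Tip: Spark serializes a Python UDF together with the objects it captures (here, the service or its mock) and executes it in separate Python worker processes. The captured object must survive pickling, and calls made in the workers are never recorded on the mock in your test process, so assertions on call counts or call arguments will not reflect them. Spark also treats UDFs as deterministic by default, so the optimizer may invoke one more or fewer times than the query suggests. Assert on the output DataFrame, and test the per-row logic as a plain Python function.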
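Because the UDF body runs in worker processes, its fallback behavior is easiest to pin down without Spark at all. A sketch, assuming you refactor the body of `get_segment` into a module-level function; the name `lookup_segment` and its `default` parameter are hypothetical:

```python
# Hypothetical refactor: the per-row body of enrich_customer_data's UDF as a
# plain function, so its fallback behavior can be tested without Spark.
from unittest.mock import Mock


def lookup_segment(service, customer_id, default="unknown"):
    """Return the segment from `service`, or `default` if the lookup raises."""
    try:
        return service.get_customer_segment(customer_id)
    except Exception:
        # Timeouts, HTTP errors, and malformed responses all degrade to the default.
        return default


# In-process mocks: unlike inside a UDF, calls on them are recorded here.
ok_service = Mock()
ok_service.get_customer_segment.return_value = "enterprise"

failing_service = Mock()
failing_service.get_customer_segment.side_effect = TimeoutError("API timeout")
```

The UDF body then reduces to `return lookup_segment(enrichment_service, customer_id)`, and the Spark-level test only needs to check the resulting `customer_segment` column.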

Integration Testing: Validating End-to-End Pipeline Behavior

While unit tests verify individual components, integration tests validate that your entire pipeline works correctly with realistic data volumes and infrastructure constraints.

Setting Up Test Infrastructure

Integration tests require a test environment that mirrors production without the cost and complexity. Use containerized services and scaled-down data:

# tests/integration/conftest.py
import pytest
import docker
import time
from pyspark.sql import SparkSession
import tempfile
import shutil

@pytest.fixture(scope="session")
def test_postgres():
    """Start a PostgreSQL container for integration tests."""
    client = docker.from_env()
    
    container = client.containers.run(
        "postgres:13",
        environment={
            "POSTGRES_DB": "testdb",
            "POSTGRES_USER": "testuser",
            "POSTGRES_PASSWORD": "testpass"
        },
        ports={"5432/tcp": None},  # None lets Docker pick a free host port
        detach=True,
        remove=True  # delete the container once it stops
    )
    
    try:
        # Wait for PostgreSQL to be ready (a crude fixed delay)
        time.sleep(10)
        
        # Refresh the container's attributes; the host port mapping is only
        # visible after the container has started
        container.reload()
        port = container.ports["5432/tcp"][0]["HostPort"]
        connection_string = f"postgresql://testuser:testpass@localhost:{port}/testdb"
        
        yield connection_string
    finally:
        # Stop the container even if setup or a test fails
        container.stop()

@pytest.fixture(scope="session")
def integration_spark():
    """Spark session configured for integration tests."""
    return SparkSession.builder \
        .appName("integration-tests") \
        .config("spark.sql.shuffle.partitions", "4") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .getOrCreate()

@pytest.fixture
def temp_data_dir():
    """Temporary directory for test data files."""
    temp_dir = tempfile.mkdtemp()
    yield temp_dir
    shutil.rmtree(temp_dir)
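A fixed `time.sleep(10)` either wastes time on fast machines or is too short on a loaded CI runner. A common alternative is to poll until the mapped port accepts connections, with an overall deadline. A minimal stdlib sketch; the helper name `wait_for_port` and its defaults are our own:

```python
# Poll until a TCP port accepts connections instead of sleeping a fixed time.
import socket
import time


def wait_for_port(host, port, timeout=30.0, interval=0.5):
    """Return True once host:port accepts a TCP connection, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Each attempt gets its own short connect timeout
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Refused or timed out: the server isn't listening yet
            time.sleep(interval)
    return False
```

In the fixture, after `container.reload()`, you would call `wait_for_port("localhost", int(port))` and fail the setup if it returns False. An open port is a weaker signal than a successful query, so retrying a trivial `SELECT 1` through your database driver is the stricter readiness check.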

Testing Complete Pipeline Workflows

Build integration tests that exercise your entire pipeline with realistic data scenarios:

# tests/integration/test_revenue_pipeline.py
import pytest
import pandas as pd
from datetime import date, timedelta
import random
from decimal import Decimal

from pipeline.main import RevenuePipeline

class TestRevenuePipelineIntegration:
    
    @pytest.fixture
    def sample_data(self, integration_spark, temp_data_dir):
        """Generate realistic test data that covers edge cases."""
        
        # Seed the RNG so a failing run can be reproduced exactly
        random.seed(42)
        
        # Generate 30 days of transaction data
        start_date = date(2024, 1, 1)
        customers = [f"C{i:03d}" for i in range(1, 51)]  # 50 customers
        categories = ["SaaS", "Software", "Consulting", "Professional", "Hardware"]
        
        transactions = []
        for day in range(30):
            current_date = start_date + timedelta(days=day)
            
            # Generate 20-100 transactions per day
            daily_transactions = random.randint(20, 100)
            
            for _ in range(daily_transactions):
                customer = random.choice(customers)
                category = random.choice(categories)
                
                # Create realistic amount distribution
                if category in ["SaaS", "Software"]:
                    amount = random.uniform(500, 5000)
                elif category in ["Consulting", "Professional"]:
                    amount = random.uniform(2000, 25000)
                else:  # Hardware
                    amount = random.uniform(100, 50000)
                
                # Occasionally generate large transactions
                if random.random() < 0.05:  # 5% chance
                    amount *= random.uniform(10, 50)
                
                # Vary the input format: ~2% of amounts carry currency
                # formatting (e.g. "$1,234.56") that cleaning must strip
                if random.random() < 0.02:
                    amount_str = f"${amount:,.2f}"
                else:
                    amount_str = f"{amount:.2f}"
                
                transactions.append({
                    "customer_id": customer,
                    "transaction_date": current_date,
                    "amount": amount_str,
                    "category": category
                })
        
        # Save to parquet for realistic file-based input
        df = pd.DataFrame(transactions)
        input_path = f"{temp_data_dir}/raw_transactions.parquet"
        df.to_parquet(input_path, index=False)
        
        return input_path, len(transactions)
    
    def test_full_pipeline_execution(self, integration_spark, test_postgres, sample_data, temp_data_dir):
        """Test complete pipeline execution with realistic data volume."""
        input_path, expected_count = sample_data
        output_path = f"{temp_data_dir}/processed"
        
        # Initialize pipeline
        pipeline = RevenuePipeline(
            spark=integration_spark,
            input_path=input_path,
            output_path=output_path,
            database_connection=test_postgres
        )
        
        # Execute pipeline
        pipeline_result = pipeline.run()
        
        # Validate pipeline completed successfully
        assert pipeline_result.success is True
        assert pipeline_result.records_processed > 0
        assert pipeline_result.records_processed <= expected_count  # Some may be filtered
        
        # Validate output data structure
        output_df = integration_spark.read.parquet(output_path)
        
        # Check schema
        expected_columns = ["customer_id", "transaction_date", "amount", 
                          "product_category", "requires_review", "month", 
                          "total_revenue", "transaction_count"]
        assert set(output_df.columns) >= set(expected_columns)
        
        # Check data quality
        assert output_df.count() > 0
        assert output_df.filter("amount IS NULL").count() == 0
        assert output_df.filter("product_category IS NULL").count() == 0
        
        # Validate business logic
        large_transactions = output_df.filter("amount > 100000").count()
        flagged_transactions = output_df.filter("requires_review = true").count()
        assert flagged_transactions >= large_transactions
    
    def test_pipeline_performance_characteristics(self, integration_spark, sample_data, temp_data_dir):
        """Test that pipeline meets performance requirements."""
        input_path, expected_count = sample_data
        output_path = f"{temp_data_dir}/processed"
        
        pipeline = RevenuePipeline(
            spark=integration_spark,
            input_path=input_path,
            output_path=output_path
        )
        
        import time
        # perf_counter is a monotonic, high-resolution clock meant for durations
        start_time = time.perf_counter()
        result = pipeline.run()
        execution_time = time.perf_counter() - start_time
        
        # Performance assertions
        assert result.success is True
        assert execution_time < 60  # Should complete within 1 minute for test data
        
        # Check resource utilization
        assert result.memory_usage_mb < 2048  # Should use less than 2GB
        assert result.cpu_time_seconds < 120   # Should use less than 2 minutes CPU time
    
    def test_pipeline_error_handling(self, integration_spark, temp_data_dir):
        """Test pipeline behavior when encountering data quality issues."""
        
        # Create data with various quality issues
        bad_data = [
            {"customer_id": None, "transaction_date": "2024-01-01", "amount": "1000.00", "category": "Software"},
            {"customer_id": "C001", "transaction_date": None, "amount": "2000.00", "category": "SaaS"},
            {"customer_id": "C002", "transaction_date": "2024-01-03", "amount": "invalid", "category": "Hardware"},
            {"customer_id": "C003", "transaction_date": "2024-01-04", "amount": "3000.00", "category": None},
        ]
        
        df = pd.DataFrame(bad_data)
        bad_input_path = f"{temp_data_dir}/bad_data.parquet"
        df.to_parquet(bad_input_path, index=False)
        
        pipeline = RevenuePipeline(
            spark=integration_spark,
            input_path=bad_input_path,
            output_path=f"{temp_data_dir}/output",
            error_handling="strict"
        )
        
        result = pipeline.run()
        
        # Pipeline should complete but report data quality issues
        assert result.success is True
        assert result.data_quality_warnings > 0
        assert result.records_processed < 4  # Some records should be filtered
        
        # Error records should be captured
        error_df = integration_spark.read.parquet(f"{temp_data_dir}/output/errors")
        assert error_df.count() > 0
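The error-handling test assumes the pipeline routes bad rows to an `errors` output instead of silently dropping them. The rules inside `RevenuePipeline` aren't shown, so the following is only one hypothetical way to express that split: each rejected record keeps a list of reasons an error table could store. The function name and the reason strings are our own.

```python
# Hypothetical quarantine step: split raw records into valid rows and rejected
# rows annotated with reasons, the shape an "errors" output could store.
from decimal import Decimal, InvalidOperation

REQUIRED_FIELDS = ("customer_id", "transaction_date", "amount", "category")


def split_valid_and_errors(records):
    """Return (valid, errors); each error record gains an 'error_reasons' list."""
    valid, errors = [], []
    for record in records:
        reasons = [f"missing {name}" for name in REQUIRED_FIELDS if record.get(name) is None]
        amount = record.get("amount")
        if amount is not None:
            try:
                # Same format tolerance as the cleaning step: ignore "$" and ","
                Decimal(str(amount).replace("$", "").replace(",", ""))
            except InvalidOperation:
                reasons.append(f"unparseable amount {amount!r}")
        if reasons:
            errors.append({**record, "error_reasons": reasons})
        else:
            valid.append(record)
    return valid, errors
```

Recording reasons, rather than only a count, is what makes a `data_quality_warnings` figure actionable: someone can query the error output and see which rule rejected each row.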

Testing with Realistic Data Volumes

Integration tests should validate behavior at realistic scales without overwhelming your test environment:

# Add this method to TestRevenuePipelineIntegration; it relies on that module's
# imports (pd, random, date, timedelta) and on RevenuePipeline.
def test_large_dataset_processing(self, integration_spark, temp_data_dir):
    """Test pipeline behavior with a larger dataset that exercises partitioning."""
    
    # Generate 1M records across 10 input files (seeded for reproducibility)
    random.seed(42)
    
    records_per_partition = 100000
    num_partitions = 10
    
    for partition in range(num_partitions):
        partition_data = []
        for i in range(records_per_partition):
            partition_data.append({
                "customer_id": f"C{i % 1000:03d}",
                "transaction_date": date(2024, 1, 1) + timedelta(days=i % 90),
                "amount": f"{random.uniform(100, 10000):.2f}",
                "category": random.choice(["SaaS", "Hardware", "Consulting"])
            })
        
        df = pd.DataFrame(partition_data)
        df.to_parquet(f"{temp_data_dir}/large_data_part_{partition}.parquet", index=False)
    
    pipeline = RevenuePipeline(
        spark=integration_spark,
        input_path=f"{temp_data_dir}/large_data_part_*.parquet",
        output_path=f"{temp_data_dir}/large_output"
    )
    
    result = pipeline.run()
    
    assert result.success is True
    assert result.records_processed == records_per_partition * num_partitions
    
    # Validate that the output reads back with reasonable parallelism. The
    # read-side partition count depends on file sizes, available cores, and
    # spark.sql.files.maxPartitionBytes, not only on how the output was written.
    num_read_partitions = integration_spark.read.parquet(f"{temp_data_dir}/large_output").rdd.getNumPartitions()
    assert num_read_partitions >= 4

Performance Testing Note: Integration tests run in CI/CD environments with limited resources. Focus on testing partitioning logic and memory management rather than absolute performance numbers.

Implementing Data Contracts for Pipeline Reliability

Data contracts define explicit expectations about data structure, quality, and semantics between pipeline stages. They serve as both documentation and automated validation.
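Because a contract is an explicit, versioned definition, you can also diff two versions in CI and catch breaking changes before they ship, such as the unannounced schema change in the opening scenario. The sketch below works on plain `(name, data_type, nullable)` tuples so it stands alone; the compatibility rules (removing a field, changing its type, or making it nullable breaks readers, while adding fields does not) are our assumption about what downstream consumers rely on.

```python
# Sketch: flag changes between two contract versions that could break readers.
# Fields are (name, data_type, nullable) tuples. Added fields are treated as
# compatible; removals, type changes, and non-null -> nullable are breaking.


def breaking_changes(old_fields, new_fields):
    old = {name: (dtype, nullable) for name, dtype, nullable in old_fields}
    new = {name: (dtype, nullable) for name, dtype, nullable in new_fields}
    problems = []
    for name, (dtype, nullable) in old.items():
        if name not in new:
            problems.append(f"field {name!r} was removed")
            continue
        new_dtype, new_nullable = new[name]
        if new_dtype != dtype:
            problems.append(f"field {name!r} changed type {dtype} -> {new_dtype}")
        if new_nullable and not nullable:
            problems.append(f"field {name!r} became nullable")
    return problems
```

With the contract classes in this section, the tuples could come from `[(f.name, f.data_type, f.nullable) for f in contract.fields]`, and a non-empty result would fail the build unless the version number was bumped deliberately.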

Defining Schema Contracts

Start with schema contracts that define expected data structure:

# pipeline/contracts.py
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
from pyspark.sql.types import StructType, StructField, StringType, DateType, DecimalType, BooleanType
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, count, sum as spark_sum, avg, max as spark_max, min as spark_min
import logging

@dataclass
class FieldContract:
    name: str
    data_type: str
    nullable: bool = True
    constraints: Optional[Dict[str, Any]] = None
    description: str = ""

@dataclass
class SchemaContract:
    name: str
    version: str
    fields: List[FieldContract]
    description: str = ""
    
    def to_spark_schema(self) -> StructType:
        """Convert contract to Spark StructType."""
        type_mapping = {
            "string": StringType(),
            "date": DateType(),
            "decimal": DecimalType(10, 2),
            "boolean": BooleanType()
        }
        
        return StructType([
            StructField(field.name, type_mapping[field.data_type], field.nullable)
            for field in self.fields
        ])
    
    def validate_schema(self, df: DataFrame) -> "ValidationResult":
        """Validate DataFrame schema against contract."""
        errors = []
        warnings = []
        
        # Check for missing fields
        df_columns = set(df.columns)
        contract_columns = {field.name for field in self.fields}
        
        missing_fields = contract_columns - df_columns
        if missing_fields:
            errors.append(f"Missing required fields: {missing_fields}")
        
        extra_fields = df_columns - contract_columns
        if extra_fields:
            warnings.append(f"Unexpected fields found: {extra_fields}")
        
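        # Check field types and constraints
        expected_schema = self.to_spark_schema()
        for field in self.fields:
            if field.name in df_columns:
                # Type validation: compare the actual Spark type with the contract's
                actual_type = df.schema[field.name].dataType
                expected_type = expected_schema[field.name].dataType
                if actual_type != expected_type:
                    errors.append(
                        f"Field {field.name} has type {actual_type.simpleString()}, "
                        f"expected {expected_type.simpleString()}"
                    )
                
                # Constraint validation
                if field.constraints:
                    constraint_errors = self._validate_field_constraints(df, field)
                    errors.extend(constraint_errors)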
        
        return ValidationResult(
            contract_name=self.name,
            is_valid=len(errors) == 0,
            errors=errors,
            warnings=warnings
        )
    
    def _validate_field_constraints(self, df: DataFrame, field: FieldContract) -> List[str]:
        """Validate field-level constraints."""
        errors = []
        
        if not field.constraints:
            return errors
        
        field_col = col(field.name)
        
        # Non-null constraint
        if field.constraints.get("non_null", False):
            null_count = df.filter(field_col.isNull()).count()
            if null_count > 0:
                errors.append(f"Field {field.name} has {null_count} null values but is marked non_null")
        
        # Range constraints: count violating rows in Spark. This lets Spark compare
        # literals such as "2020-01-01" against a DATE column (comparing a Python
        # date with a str would raise), and null values are simply not counted.
        min_value = field.constraints.get("min_value")
        if min_value is not None:
            below_min = df.filter(field_col < min_value).count()
            if below_min > 0:
                errors.append(f"Field {field.name} has {below_min} values below minimum {min_value}")
        
        max_value = field.constraints.get("max_value")
        if max_value is not None:
            above_max = df.filter(field_col > max_value).count()
            if above_max > 0:
                errors.append(f"Field {field.name} has {above_max} values above maximum {max_value}")
        
        return errors

# Define contracts for our revenue pipeline
RAW_TRANSACTIONS_CONTRACT = SchemaContract(
    name="raw_transactions",
    version="1.0",
    description="Raw transaction data from source systems",
    fields=[
        FieldContract("customer_id", "string", nullable=False, 
                     constraints={"non_null": True},
                     description="Unique customer identifier"),
        FieldContract("transaction_date", "date", nullable=False,
                     constraints={"non_null": True, "min_value": "2020-01-01"},
                     description="Date of transaction"),
        FieldContract("amount", "string", nullable=False,
                     description="Transaction amount (may contain currency formatting)"),
        FieldContract("category", "string", nullable=False,
                     constraints={"non_null": True},
                     description="Product category")
    ]
)

CLEANED_TRANSACTIONS_CONTRACT = SchemaContract(
    name="cleaned_transactions",
    version="1.0",
    description="Cleaned and standardized transaction data",
    fields=[
        FieldContract("customer_id", "string", nullable=False, 
                     constraints={"non_null": True}),
        FieldContract("transaction_date", "date", nullable=False,
                     constraints={"non_null": True}),
        FieldContract("amount", "decimal", nullable=False,
                     constraints={"non_null": True, "min_value": 0}),
        FieldContract("product_category", "string", nullable=False,
                     constraints={"non_null": True}),
        FieldContract("requires_review", "boolean", nullable=False)
    ]
)
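The contract classes construct and return a `ValidationResult`, whose definition this section doesn't show. If you are assembling the code to run it, a minimal dataclass matching how it is constructed here (`contract_name`, `is_valid`, `errors`, `warnings`) could look like this; the `raise_if_invalid` helper is a hypothetical addition. Because the contract methods only reference the class when they are called, it can live anywhere in `contracts.py`.

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class ValidationResult:
    """Outcome of one contract check, matching how the contracts construct it."""
    contract_name: str
    is_valid: bool
    errors: List[str] = dataclasses.field(default_factory=list)
    warnings: List[str] = dataclasses.field(default_factory=list)

    def raise_if_invalid(self) -> None:
        """Hypothetical helper: stop the pipeline if the contract was violated."""
        if not self.is_valid:
            raise ValueError(
                f"Contract {self.contract_name!r} failed: " + "; ".join(self.errors)
            )
```

Calling `CLEANED_TRANSACTIONS_CONTRACT.validate_schema(df).raise_if_invalid()` between stages turns a contract violation into a failed run, rather than a silent error that surfaces later on a dashboard.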

Quality Contracts and Data Validation

Beyond schema validation, implement quality contracts that check data distribution and business logic:

@dataclass
class QualityContract:
    name: str
    description: str
    validations: List["QualityValidation"]
    
    def validate(self, df: DataFrame) -> "ValidationResult":
        """Run all quality validations on the DataFrame."""
        errors = []
        warnings = []
        
        for validation in self.validations:
            result = validation.validate(df)
            errors.extend(result.errors)
            warnings.extend(result.warnings)
        
        return ValidationResult(
            contract_name=self.name,
            is_valid=len(errors) == 0,
            errors=errors,
            warnings=warnings
        )

class QualityValidation:
    """Base class for data quality validations."""
    
    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description
    
    def validate(self, df: DataFrame) -> "ValidationResult":
        raise NotImplementedError

class UniquenessValidation(QualityValidation):
    """Validate that specified columns contain unique values."""
    
    def __init__(self, columns: List[str], tolerance: float = 0.0):
        super().__init__(
            f"uniqueness_check_{'+'.join(columns)}",
            f"Check uniqueness of {columns}"
        )
        self.columns = columns
        self.tolerance = tolerance  # Allow some percentage of duplicates
    
    def validate(self, df: DataFrame) -> "ValidationResult":
        total_count = df.count()
        distinct_count = df.select(*self.columns).distinct().count()
        
        duplicate_rate = (total_count - distinct_count) / total_count if total_count > 0 else 0
        
        errors = []
        if duplicate_rate > self.tolerance:
            errors.append(
                f"Duplicate rate {duplicate_rate:.2%} exceeds tolerance {self.tolerance:.2%} "
                f"for columns {self.columns}"
            )
        
        return ValidationResult(
            contract_name=self.name,
            is_valid=len(errors) == 0,
            errors=errors,
            warnings=[]
        )

class CompletenessValidation(QualityValidation):
    """Validate that fields meet completeness requirements."""
    
    def __init__(self, field_requirements: Dict[str, float]):
        super().__init__(
            "completeness_check",
            "Check field completeness requirements"
        )
        self.field_requirements = field_requirements  # field -> minimum completeness ratio
    
    def validate(self, df: DataFrame) -> "ValidationResult":
        total_count = df.count()
        errors = []
        
        for field, min_completeness in self.field_requirements.items():
            non_null_count = df.filter(col(field).isNotNull()).count()
            completeness = non_null_count / total_count if total_count > 0 else 0
            
            if completeness < min_completeness:
                errors.append(
                    f"Field {field} completeness {completeness:.2%} below requirement {min_completeness:.2%}"
                )
        
        return ValidationResult(
            contract_name=self.name,
            is_valid=len(errors) == 0,
            errors=errors,
            warnings=[]
        )

class BusinessRuleValidation(QualityValidation):
    """Validate custom business rules."""
    
    def __init__(self, rule_name: str, condition_expr: str, description: str, max_violations: int = 0):
        super().__init__(rule_name, description)
        self.condition_expr = condition_expr
        self.max_violations = max_violations
    
    def validate(self, df: DataFrame) -> "ValidationResult":
        violations = df.filter(f"NOT ({self.condition_expr})").count()
        
        errors = []
        if violations > self.max_violations:
            errors.append(
                f"Business rule '{self.description}' violated {violations} times "
                f"(max allowed: {self.max_violations})"
            )
        
        return ValidationResult(
            contract_name=self.name,
            is_valid=len(errors) == 0,
            errors=errors,
            warnings=[]
        )
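One subtlety in `BusinessRuleValidation`: SQL uses three-valued logic, so when a rule evaluates to NULL (for example `amount > 0` on a row whose `amount` is NULL), `NOT (rule)` is also NULL, `filter` drops the row, and it is not counted as a violation. The sketch below demonstrates this with Python's built-in `sqlite3` standing in for Spark SQL (both follow standard SQL NULL semantics for these expressions) and shows a `coalesce` guard that counts NULL results as violations; in Spark SQL the guard would be written `NOT coalesce((rule), false)`. Whether NULLs should count is a policy choice, since this article's contracts handle missing values separately through `CompletenessValidation`. `count_violations` is a hypothetical helper.

```python
import sqlite3


def count_violations(conn: sqlite3.Connection, rule: str, null_is_violation: bool) -> int:
    """Count rows in the txn table that do not satisfy `rule`."""
    # coalesce(..., 0) turns a NULL rule result into false, so NOT(...) becomes
    # true and the row is counted; a bare NOT(rule) silently skips NULL results.
    expr = f"NOT coalesce(({rule}), 0)" if null_is_violation else f"NOT ({rule})"
    return conn.execute(f"SELECT count(*) FROM txn WHERE {expr}").fetchone()[0]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE txn (amount REAL)")
conn.executemany("INSERT INTO txn VALUES (?)", [(10.0,), (-5.0,), (None,)])

print(count_violations(conn, "amount > 0", null_is_violation=False))  # 1: only -5.0
print(count_violations(conn, "amount > 0", null_is_violation=True))   # 2: -5.0 and NULL
```

Rules that already mention NULL explicitly, such as `revenue >= 0 OR revenue IS NULL`, behave the same either way, because they never evaluate to NULL.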

# Define quality contracts for our pipeline
REVENUE_QUALITY_CONTRACT = QualityContract(
    name="revenue_quality_checks",
    description="Quality validations for revenue data",
    validations=[
        UniquenessValidation(["customer_id", "transaction_date"], tolerance=0.01),  # Allow 1% duplicates
        CompletenessValidation({
            "customer_id": 1.0,      # 100% required
            "transaction_date": 1.0,  # 100% required
            "amount": 0.95,          # 95% required
            "product_category": 0.90  # 90% required
        }),
        BusinessRuleValidation(
            "positive_amounts",
            "amount > 0",
            "All transaction amounts must be positive"
        ),
        BusinessRuleValidation(
            "recent_transactions",
            "transaction_date >= date_sub(current_date(), 365)",
            "Transactions should be within the last year",
            max_violations=100  # Allow some historical data
        )
    ]
)
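The tolerances in `REVENUE_QUALITY_CONTRACT` are plain ratios. This pure-Python sketch reproduces the duplicate-rate and completeness arithmetic used by `UniquenessValidation` and `CompletenessValidation` on in-memory rows, which makes it easy to reason about where a threshold trips. `duplicate_rate` and `completeness` are illustrative names, not part of the pipeline code.

```python
from typing import Any, Dict, List, Sequence


def duplicate_rate(rows: List[Dict[str, Any]], key_columns: Sequence[str]) -> float:
    """(total - distinct keys) / total, the same formula as UniquenessValidation."""
    if not rows:
        return 0.0
    distinct_keys = {tuple(row[c] for c in key_columns) for row in rows}
    return (len(rows) - len(distinct_keys)) / len(rows)


def completeness(rows: List[Dict[str, Any]], column: str) -> float:
    """Share of rows where `column` is not None, as in CompletenessValidation."""
    if not rows:
        return 0.0
    return sum(row[column] is not None for row in rows) / len(rows)


# 200 rows: 12 with a missing amount, plus one exact duplicate of the last row
rows = [
    {"customer_id": f"c{i}", "transaction_date": "2024-01-01",
     "amount": None if i < 12 else 10.0}
    for i in range(199)
]
rows.append(dict(rows[-1]))

rate = duplicate_rate(rows, ["customer_id", "transaction_date"])
amount_completeness = completeness(rows, "amount")
print(f"duplicate rate {rate:.2%} (limit 1.00%): {'pass' if rate <= 0.01 else 'fail'}")
print(f"amount completeness {amount_completeness:.2%} (min 95.00%): "
      f"{'pass' if amount_completeness >= 0.95 else 'fail'}")
```

Here one duplicate in 200 rows is a 0.50% duplicate rate, well inside the 1% tolerance, while 12 missing amounts leave 94.00% completeness, just below the 95% requirement.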

Contract Enforcement in Pipelines

Integrate contract validation into your pipeline execution:

@dataclass
class ValidationResult:
    contract_name: str
    is_valid: bool
    errors: List[str]
    warnings: List[str]
    checkpoint_name: str = ""  # Filled in by ContractEnforcer at each checkpoint
    
    def __post_init__(self):
        if self.errors:
            logging.error(f"Contract {self.contract_name} validation failed: {self.errors}")
        if self.warnings:
            logging.warning(f"Contract {self.contract_name} validation warnings: {self.warnings}")

class ContractEnforcer:
    """Enforces data contracts at pipeline checkpoints."""
    
    def __init__(self, fail_on_error: bool = True, log_warnings: bool = True):
        self.fail_on_error = fail_on_error
        self.log_warnings = log_warnings
        self.validation_results = []
    
    def validate_schema(self, df: DataFrame, contract: SchemaContract, checkpoint_name: str) -> DataFrame:
        """Validate DataFrame against schema contract."""
        result = contract.validate_schema(df)
        result.checkpoint_name = checkpoint_name
        self.validation_results.append(result)
        
        if not result.is_valid and self.fail_on_error:
            raise ValueError(f"Schema validation failed at {checkpoint_name}: {result.errors}")
        
        return df
    
    def validate_quality(self, df: DataFrame, contract: QualityContract, checkpoint_name: str) -> DataFrame:
        """Validate DataFrame against quality contract."""
        result = contract.validate(df)
        result.checkpoint_name = checkpoint_name
        self.validation_results.append(result)
        
        if not result.is_valid and self.fail_on_error:
            raise ValueError(f"Quality validation failed at {checkpoint_name}: {result.errors}")
        
        return df
    
    def get_validation_summary(self) -> Dict[str, Any]:
        """Get summary of all validation results."""
        total_validations = len(self.validation_results)
        failed_validations = sum(1 for r in self.validation_results if not r.is_valid)
        
        return {
            "total_validations": total_validations,
            "failed_validations": failed_validations,
            "success_rate": (total_validations - failed_validations) / total_validations if total_validations > 0 else 0,
            "results": self.validation_results
        }

# Updated pipeline with contract enforcement
class RevenuePipeline:
    
    def __init__(self, spark: SparkSession, input_path: str, output_path: str, 
                 enforce_contracts: bool = True):
        self.spark = spark
        self.input_path = input_path
        self.output_path = output_path
        self.enforcer = ContractEnforcer(fail_on_error=enforce_contracts)
    
    def run(self):
        """Execute pipeline with contract validation."""
        
        # Load raw data
        raw_df = self.spark.read.parquet(self.input_path)
        
        # Validate input contract
        raw_df = self.enforcer.validate_schema(
            raw_df, RAW_TRANSACTIONS_CONTRACT, "raw_input"
        )
        
        # Clean data
        cleaned_df = clean_revenue_data(raw_df)
        
        # Validate cleaned data contract
        cleaned_df = self.enforcer.validate_schema(
            cleaned_df, CLEANED_TRANSACTIONS_CONTRACT, "cleaned_data"
        )
        cleaned_df = self.enforcer.validate_quality(
            cleaned_df, REVENUE_QUALITY_CONTRACT, "cleaned_data"
        )
        
        # Calculate metrics
        metrics_df = calculate_monthly_metrics(cleaned_df)
        
        # Save results
        metrics_df.write.mode("overwrite").parquet(self.output_path)
        
        return PipelineResult(
            success=True,
            validation_summary=self.enforcer.get_validation_summary(),
            records_processed=cleaned_df.count()
        )
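Stripped of Spark, the enforcement pattern comes down to three moves: run a check at each named checkpoint, record the outcome, and either raise at once (fail-fast) or carry on and report at the end. The sketch below shows that control flow with plain Python lists and callables; `Checkpointer`, `CheckpointResult`, and `no_negative_amounts` are hypothetical names, not the `ContractEnforcer` API.

```python
from dataclasses import dataclass, field
from typing import Callable, List

Check = Callable[[list], List[str]]  # a check returns a list of error messages


@dataclass
class CheckpointResult:
    checkpoint: str
    errors: List[str]


@dataclass
class Checkpointer:
    fail_on_error: bool = True
    results: List[CheckpointResult] = field(default_factory=list)

    def check(self, data: list, check: Check, checkpoint: str) -> list:
        errors = check(data)
        self.results.append(CheckpointResult(checkpoint, errors))  # record before raising
        if errors and self.fail_on_error:
            raise ValueError(f"Validation failed at {checkpoint}: {errors}")
        return data  # pass data through so checks can be chained inline

    def failed(self) -> List[str]:
        return [r.checkpoint for r in self.results if r.errors]


def no_negative_amounts(amounts: list) -> List[str]:
    return [f"negative amount {a}" for a in amounts if a < 0]


# Warn-only mode: every checkpoint runs and failures are collected
lenient = Checkpointer(fail_on_error=False)
lenient.check([5, -1], no_negative_amounts, "raw_input")
lenient.check([5], no_negative_amounts, "cleaned_data")
print(lenient.failed())  # ['raw_input']

# Fail-fast mode: the first failing checkpoint stops the run
strict = Checkpointer()
try:
    strict.check([5, -1], no_negative_amounts, "raw_input")
except ValueError as exc:
    print(exc)
```

Recording the result before raising means the summary still shows which checkpoint stopped a fail-fast run.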

Hands-On Exercise: Building a Complete Testing Framework

Now let's build a complete testing framework for a realistic scenario: a customer analytics pipeline that processes user behavior data and generates segmentation insights.

The Pipeline

We'll create a pipeline that:

  1. Loads raw event data from multiple sources
  2. Cleans and standardizes the data
  3. Calculates user engagement metrics
  4. Assigns customer segments based on behavior
  5. Outputs segmented customer profiles
# pipeline/customer_analytics.py
from pyspark.sql import DataFrame, SparkSession
# Spark's min/max are aliased so they don't collide with Python's built-ins
from pyspark.sql.functions import (
    col, count, countDistinct, sum as spark_sum, avg, when, datediff, current_date, lag,
    min as spark_min, max as spark_max,
)
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, DoubleType

class CustomerAnalyticsPipeline:
    
    def __init__(self, spark: SparkSession):
        self.spark = spark
    
    def load_raw_events(self, events_path: str) -> DataFrame:
        """Load raw user event data."""
        return self.spark.read.parquet(events_path)
    
    def clean_events(self, df: DataFrame) -> DataFrame:
        """Clean and standardize event data."""
        return df.filter(
            col("user_id").isNotNull() &
            col("event_type").isNotNull() &
            col("timestamp").isNotNull()
        ).select(
            col("user_id"),
            col("event_type"),
            col("timestamp"),
            when(col("revenue").isNull(), 0.0).otherwise(col("revenue")).alias("revenue"),
            col("session_id")
        )
    
    def calculate_user_metrics(self, df: DataFrame) -> DataFrame:
        """Calculate engagement metrics per user."""
        window_spec = Window.partitionBy("user_id").orderBy("timestamp")
        
        # Add session indicators
        df_with_sessions = df.withColumn(
            "prev_timestamp",
            lag("timestamp").over(window_spec)
        ).withColumn(
            "session_gap_minutes",
            (col("timestamp").cast("long") - col("prev_timestamp").cast("long")) / 60
        ).withColumn(
            "new_session",
            when((col("session_gap_minutes") > 30) | col("prev_timestamp").isNull(), 1).otherwise(0)
        )
        
        return df_with_sessions.groupBy("user_id").agg(
            count("*").alias("total_events"),
            spark_sum("new_session").alias("session_count"),
            spark_sum("revenue").alias("total_revenue"),
            avg("revenue").alias("avg_revenue_per_event"),
            datediff(current_date(), spark_min("timestamp")).alias("days_since_first_event"),
            datediff(current_date(), spark_max("timestamp")).alias("days_since_last_event"),
            countDistinct("event_type").alias("unique_event_types")
        )
    
    def assign_segments(self, df: DataFrame) -> DataFrame:
        """Assign customer segments based on behavior metrics."""
        return df.withColumn(
            "customer_segment",
            when(
                (col("total_revenue") > 1000) & (col("session_count") > 50),
                "high_value"
            ).when(
                (col("total_revenue") > 100) & (col("session_count") > 10),
                "medium_value"
            ).when(
                col("days_since_last_event") > 30,
                "dormant"
            ).when(
                col("days_since_first_event") < 7,
                "new_user"
            ).otherwise("low_engagement")
        )
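The session logic in `calculate_user_metrics` opens a new session on a user's first event or after a gap of more than 30 minutes. Here is the same rule in plain Python over a list of timestamps, handy for working out expected values by hand before writing a Spark test. `count_sessions` is an illustrative helper, not part of the pipeline.

```python
from datetime import datetime, timedelta
from typing import List

SESSION_GAP = timedelta(minutes=30)


def count_sessions(timestamps: List[datetime]) -> int:
    """The first event opens a session; any gap strictly greater than 30 minutes opens another."""
    sessions = 0
    previous = None
    for ts in sorted(timestamps):
        if previous is None or ts - previous > SESSION_GAP:
            sessions += 1
        previous = ts
    return sessions


base = datetime(2024, 1, 1, 10, 0)
user1 = [base, base + timedelta(minutes=5), base + timedelta(minutes=10),
         base + timedelta(hours=2), base + timedelta(hours=2, minutes=5)]
print(count_sessions(user1))  # 2
```

Because the comparison is strict, a gap of exactly 30 minutes stays in the same session, matching the `> 30` condition in the Spark code.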

Unit Tests

# tests/test_customer_analytics.py
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from datetime import datetime, timedelta
from pipeline.customer_analytics import CustomerAnalyticsPipeline

class TestCustomerAnalyticsPipeline:
    
    def test_event_cleaning(self, spark):
        """Test that event cleaning properly handles data quality issues."""
        
        schema = StructType([
            StructField("user_id", StringType()),
            StructField("event_type", StringType()),
            StructField("timestamp", TimestampType()),
            StructField("revenue", DoubleType()),
            StructField("session_id", StringType())
        ])
        
        # Include various data quality issues
        test_data = [
            ("user1", "login", datetime(2024, 1, 1, 10, 0), 0.0, "s1"),
            ("user2", "purchase", datetime(2024, 1, 1, 11, 0), 25.99, "s2"),
            (None, "click", datetime(2024, 1, 1, 12, 0), 0.0, "s3"),  # Null user_id
            ("user3", None, datetime(2024, 1, 1, 13, 0), 0.0, "s4"),  # Null event_type
            ("user4", "purchase", None, 50.00, "s5"),  # Null timestamp
            ("user5", "purchase", datetime(2024, 1, 1, 14, 0), None, "s6"),  # Null revenue
        ]
        
        input_df = spark.createDataFrame(test_data, schema)
        pipeline = CustomerAnalyticsPipeline(spark)
        
        result = pipeline.clean_events(input_df)
        result_data = result.collect()
        
        # Should keep valid records and handle null revenue
        assert len(result_data) == 3
        
        # Check that null revenue is converted to 0.0
        revenue_null_record = [r for r in result_data if r["user_id"] == "user5"][0]
        assert revenue_null_record["revenue"] == 0.0
    
    def test_user_metrics_calculation(self, spark):
        """Test user engagement metrics calculation."""
        
        schema = StructType([
            StructField("user_id", StringType()),
            StructField("event_type", StringType()),
            StructField("timestamp", TimestampType()),
            StructField("revenue", DoubleType()),
            StructField("session_id", StringType())
        ])
        
        # Create test data for two users with different patterns
        base_time = datetime(2024, 1, 1, 10, 0)
        test_data = [
            # User1: Active user with multiple sessions
            ("user1", "login", base_time, 0.0, "s1"),
            ("user1", "click", base_time + timedelta(minutes=5), 0.0, "s1"),
            ("user1", "purchase", base_time + timedelta(minutes=10), 99.99, "s1"),
            ("user1", "login", base_time + timedelta(hours=2), 0.0, "s2"),  # New session (>30min gap)
            ("user1", "purchase", base_time + timedelta(hours=2, minutes=5), 149.99, "s2"),
            
            # User2: Less active user
            ("user2", "login", base_time, 0.0, "s3"),
            ("user2", "click", base_time + timedelta(minutes=2), 0.0, "s3"),
        ]
        
        input_df = spark.createDataFrame(test_data, schema)
        pipeline = CustomerAnalyticsPipeline(spark)
        
        result = pipeline.calculate_user_metrics(input_df)
        result_dict = {row["user_id"]: row for row in result.collect()}
        
        # Validate user1 metrics
        user1 = result_dict["user1"]
        assert user1["total_events"] == 5
        assert user1["session_count"] == 2
        assert user1["total_revenue"] == pytest.approx(249.98)  # Floating-point sums aren't exact
        assert user1["unique_event_types"] == 3
        
        # Validate user2 metrics
        user2 = result_dict["user2"]
        assert user2["total_events"] == 2
        assert user2["session_count"] == 1
        assert user2["total_revenue"] == 0.0
        assert user2["unique_event_types"] == 2
    
    def test_customer_segmentation(self, spark):
        """Test customer segment assignment logic."""
        
        schema = StructType([
            StructField("user_id", StringType()),
            StructField("total_events", IntegerType()),
            StructField("session_count", IntegerType()),
            StructField("total_revenue", DoubleType()),
            StructField("days_since_first_event", IntegerType()),
            StructField("days_since_last_event", IntegerType()),
        ])
        
        test_data = [
            ("high_value_user", 100, 75, 2500.00, 90, 1),      # High value
            ("medium_value_user", 50, 25, 500.00, 60, 2),      # Medium value  
            ("dormant_user", 30, 8, 80.00, 120, 45),           # Dormant (below medium_value thresholds)
            ("new_user", 5, 2, 50.00, 3, 1),                   # New user
            ("low_engagement_user", 10, 3, 25.00, 30, 5),      # Low engagement
        ]
        
        input_df = spark.createDataFrame(test_data, schema)
        pipeline = CustomerAnalyticsPipeline(spark)
        
        result = pipeline.assign_segments(input_df)
        result_dict = {row["user_id"]: row["customer_segment"] for row in result.collect()}
        
        assert result_dict["high_value_user"] == "high_value"
        assert result_dict["medium_value_user"] == "medium_value"
        assert result_dict["dormant_user"] == "dormant"
        assert result_dict["new_user"] == "new_user"
        assert result_dict["low_engagement_user"] == "low_engagement"
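Because `when(...).when(...)` returns the first matching branch, the order of the segment rules matters: a user who qualifies for `medium_value` is never labeled `dormant`, however long they have been inactive. A plain-Python mirror of `assign_segments` (a hypothetical helper for reasoning about test data, not pipeline code) makes the precedence easy to check:

```python
def assign_segment(total_revenue: float, session_count: int,
                   days_since_first_event: int, days_since_last_event: int) -> str:
    """Mirror of the assign_segments rule chain: the first matching rule wins."""
    if total_revenue > 1000 and session_count > 50:
        return "high_value"
    if total_revenue > 100 and session_count > 10:
        return "medium_value"
    if days_since_last_event > 30:
        return "dormant"
    if days_since_first_event < 7:
        return "new_user"
    return "low_engagement"


# Inactive for 45 days, but revenue and sessions clear the medium_value bar first
print(assign_segment(200.00, 15, 120, 45))  # medium_value
print(assign_segment(80.00, 8, 120, 45))    # dormant
```

When writing segment test cases, pick values that fail every earlier rule in the chain, or the row will land in a different segment than intended.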

Integration Tests

# tests/integration/test_customer_analytics_integration.py
import pytest
from datetime import datetime, timedelta
import random
import pandas as pd
from pipeline.customer_analytics import CustomerAnalyticsPipeline

class TestCustomerAnalyticsIntegration:
    
    @pytest.fixture
    def realistic_event_data(self, temp_data_dir):
        """Generate realistic user event data."""
        random.seed(42)  # Fixed seed so a failing run can be reproduced
        
        users = [f"user_{i:04d}" for i in range(1, 1001)]  # 1000 users
        event_types = ["login", "click", "view", "purchase", "logout"]
        
        events = []
        base_date = datetime(2024, 1, 1)
        
        for user in users:
            # Each user has different activity patterns
            user_activity_level = random.choice(["high", "medium", "low"])
            
            if user_activity_level == "high":
                num_sessions = random.randint(20, 50)
                events_per_session = random.randint(10, 30)
            elif user_activity_level == "medium":
                num_sessions = random.randint(5, 20)
                events_per_session = random.randint(5, 15)
            else:  # low
                num_sessions = random.randint(1, 5)
                events_per_session = random.randint(1, 5)
            
            current_date = base_date
            
            for session in range(num_sessions):
                session_id = f"{user}_s{session}"
                session_start = current_date + timedelta(
                    days=random.randint(0, 90),
                    hours=random.randint(8, 20),
                    minutes=random.randint(0, 59)
                )
                
                for event_num in range(events_per_session):
                    event_time = session_start + timedelta(minutes=event_num * random.randint(1, 5))
                    event_type = random.choice(event_types)
                    
                    # Purchases have revenue
                    revenue = 0.0
                    if event_type == "purchase":
                        revenue = random.uniform(10, 500)
                    
                    events.append({
                        "user_id": user,
                        "event_type": event_type,
                        "timestamp": event_time,
                        "revenue": revenue,
                        "session_id": session_id
                    })
        
        # Save to parquet
        df = pd.DataFrame(events)
        events_path = f"{temp_data_dir}/events.parquet"
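        # Write microsecond timestamps: Spark may not read the nanosecond
        # timestamps that recent pyarrow versions write by default
        df.to_parquet(events_path, index=False, coerce_timestamps="us", allow_truncated_timestamps=True)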
        
        return events_path, len(events)
    
    def test_full_pipeline_execution(self, integration_spark, realistic_event_data, temp_data_dir):
        """Test complete pipeline with realistic data."""
        
        events_path, expected_events = realistic_event_data
        
        pipeline = CustomerAnalyticsPipeline(integration_spark)
        
        # Execute pipeline steps
        raw_events = pipeline.load_raw_events(events_path)
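        cleaned_events = pipeline.clean_events(raw_events)
        # The generator never emits null user_id, event_type, or timestamp values,
        # so cleaning should keep every generated event
        assert cleaned_events.count() == expected_events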
        user_metrics = pipeline.calculate_user_metrics(cleaned_events)
        segmented_users = pipeline.assign_segments(user_metrics)
        
        # Save results
        output_path = f"{temp_data_dir}/customer_segments"
        segmented_users.write.mode("overwrite").parquet(output_path)
        
        # Validate results
        result_df = integration_spark.read.parquet(output_path)
        result_data = result_df.collect()
        
        # Should have processed all users
        assert len(result_data) == 1000
        
        # Every segment label must be a known one (not every segment has to appear)
        segments = {row["customer_segment"] for row in result_data}
        expected_segments = {"high_value", "medium_value", "dormant", "new_user", "low_engagement"}
        assert segments.issubset(expected_segments)
        
        # Validate segment logic with spot checks
        high_value_users = [r for r in result_data if r["customer_segment"] == "high_value"]
        for user in high_value_users[:5]:  # Check first 5
            assert user["total_revenue"] > 1000
            assert user["session_count"] > 50
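The fixture above draws from the module-level `random` functions, whose shared state any other code in the test process can advance. A dedicated `random.Random` instance with a fixed seed keeps generated data identical across runs regardless of what else uses `random`. This is a sketch under the assumption that you can pass a generator or seed into your data builder; `generate_events` is a simplified, hypothetical stand-in for the fixture's generator.

```python
import random
from datetime import datetime, timedelta
from typing import Dict, List

EVENT_TYPES = ["login", "click", "view", "purchase", "logout"]


def generate_events(num_users: int, seed: int) -> List[Dict]:
    """Generate synthetic events from a private, seeded random number generator."""
    rng = random.Random(seed)  # independent of the global random state
    base_date = datetime(2024, 1, 1)
    events = []
    for u in range(1, num_users + 1):
        user = f"user_{u:04d}"
        for session in range(rng.randint(1, 5)):
            start = base_date + timedelta(days=rng.randint(0, 90), hours=rng.randint(8, 20))
            for n in range(rng.randint(1, 5)):
                event_type = rng.choice(EVENT_TYPES)
                events.append({
                    "user_id": user,
                    "event_type": event_type,
                    "timestamp": start + timedelta(minutes=n * rng.randint(1, 5)),
                    # Only purchases carry revenue
                    "revenue": round(rng.uniform(10, 500), 2) if event_type == "purchase" else 0.0,
                    "session_id": f"{user}_s{session}",
                })
    return events


first = generate_events(num_users=50, seed=42)
random.random()  # advancing the global generator does not affect the seeded one
second = generate_events(num_users=50, seed=42)
print(first == second)  # True
```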

Data Contracts

# Define contracts for customer analytics pipeline
from pipeline.contracts import SchemaContract, FieldContract, QualityContract
from pipeline.contracts import UniquenessValidation, CompletenessValidation, BusinessRuleValidation

RAW_EVENTS_CONTRACT = SchemaContract(
    name="raw_user_events",
    version="1.0",
    description="Raw user event data from multiple sources",
    fields=[
        FieldContract("user_id", "string", nullable=False,
                     constraints={"non_null": True}),
        FieldContract("event_type", "string", nullable=False,
                     constraints={"non_null": True}),
        FieldContract("timestamp", "timestamp", nullable=False,
                     constraints={"non_null": True}),
        FieldContract("revenue", "decimal", nullable=True),
        FieldContract("session_id", "string", nullable=True)
    ]
)

USER_METRICS_CONTRACT = SchemaContract(
    name="user_engagement_metrics",
    version="1.0",
    description="Calculated user engagement metrics",
    fields=[
        FieldContract("user_id", "string", nullable=False),
        FieldContract("total_events", "integer", nullable=False,
                     constraints={"min_value": 1}),
        FieldContract("session_count", "integer", nullable=False,
                     constraints={"min_value": 1}),
        FieldContract("total_revenue", "decimal", nullable=False,
                     constraints={"min_value": 0}),
        FieldContract("customer_segment", "string", nullable=False)
    ]
)

EVENTS_QUALITY_CONTRACT = QualityContract(
    name="events_quality_checks",
    description="Quality validations for event data",
    validations=[
        CompletenessValidation({
            "user_id": 1.0,
            "event_type": 1.0,
            "timestamp": 1.0
        }),
        BusinessRuleValidation(
            "valid_event_types",
            "event_type IN ('login', 'click', 'view', 'purchase', 'logout')",
            "Event types must be from allowed list"
        ),
        BusinessRuleValidation(
            "non_negative_revenue",
            "revenue >= 0 OR revenue IS NULL",
            "Revenue must be non-negative"
        )
    ]
)

Common Mistakes & Troubleshooting

Testing with Insufficient Data Variety

Mistake: Writing tests that only cover "happy path" scenarios with clean, predictable data.

Problem: Real production data contains edge cases, malformed values, and unexpected patterns that can break your pipeline.

Solution: Systematically include problematic data in your tests:

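def test_with_edge_cases(self, spark):
    """Test pipeline behavior with various edge cases."""
    from pyspark.sql.functions import lit
    
    schema = StructType([
        StructField("user_id", StringType()),
        StructField("event_type", StringType()),
        StructField("timestamp", TimestampType()),
        StructField("revenue", DoubleType()),
    ])
    
    edge_case_data = [
        # Boundary values
        ("user1", "purchase", datetime(2024, 1, 1), 0.01),     # Minimum purchase
        ("user2", "purchase", datetime(2024, 1, 1), 999999.99), # Maximum purchase
        
        # Special characters and encoding issues
        ("user_ñoël", "click", datetime(2024, 1, 1), 0.0),     # Unicode characters
        ("user with spaces", "view", datetime(2024, 1, 1), 0.0), # Spaces in ID
        
        # Timing edge cases
        ("user3", "login", datetime(1970, 1, 1), 0.0),         # Unix epoch
        ("user4", "logout", datetime(2099, 12, 31), 0.0),      # Far future
        
        # Duplicate events
        ("user5", "purchase", datetime(2024, 1, 1, 10, 0), 100.0),
        ("user5", "purchase", datetime(2024, 1, 1, 10, 0), 100.0), # Exact duplicate
    ]
    
    # clean_events() also selects session_id, so add it as an all-null column
    input_df = spark.createDataFrame(edge_case_data, schema).withColumn(
        "session_id", lit(None).cast("string")
    )
    result = CustomerAnalyticsPipeline(spark).clean_events(input_df)
    
    # None of these rows has a null user_id, event_type, or timestamp, so the
    # pipeline should keep all of them (including the duplicate) without crashing
    assert result.count() == len(edge_case_data)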

Ignoring Performance in Integration Tests

Mistake: Writing integration tests that don't validate performance characteristics under realistic data volumes.

Problem: Pipelines that work fine with small test data can hide excessive memory use, inefficient joins, or poor partitioning that only show up at scale.

Solution: Include performance assertions in your integration tests:

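def test_memory_efficiency(self, integration_spark, temp_data_dir):
    """Test that pipeline doesn't exceed memory limits with large datasets."""
    
    # NOTE: generate_test_data, get_spark_memory_usage, and build_pipeline are
    # hypothetical helpers you must implement for your environment, and the
    # result object is assumed to expose records_processed and output_df.
    
    # Create dataset that's large enough to test memory management
    # but not so large it breaks CI/CD
    large_dataset = generate_test_data(rows=1000000)
    pipeline = build_pipeline(integration_spark, output_dir=temp_data_dir)
    
    initial_memory = get_spark_memory_usage(integration_spark)
    
    result = pipeline.run(large_dataset)
    
    # Measured after the run: this is memory still held, not a true peak,
    # unless your helper reports a high-water mark
    peak_memory = get_spark_memory_usage(integration_spark)
    memory_increase = peak_memory - initial_memory
    
    # Assert memory usage stays within reasonable bounds
    assert memory_increase < 1024 * 1024 * 1024  # Less than 1 GiB increase
    assert result.records_processed == 1000000
    
    # Validate that data was properly partitioned
    output_partitions = result.output_df.rdd.getNumPartitions()
    assert 4 <= output_partitions <= 16  # Reasonable partitioning for this data size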

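Reading Spark memory from inside a test depends on how your cluster exposes metrics, so a simpler performance guard is a wall-clock budget. The sketch below swaps in a timing check rather than the memory measurement above: it wraps a step in a context manager and fails if the step runs over budget. `time_budget` is a hypothetical helper, and real budgets should come from observed CI timings with generous headroom.

```python
import time
from contextlib import contextmanager


@contextmanager
def time_budget(seconds: float, label: str = "step"):
    """Raise AssertionError if the wrapped block takes longer than `seconds`."""
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    assert elapsed <= seconds, f"{label} took {elapsed:.2f}s, budget {seconds:.2f}s"


# A fast step stays within its budget
with time_budget(1.0, "sum of squares"):
    total = sum(i * i for i in range(100_000))

# A slow step trips the assertion
try:
    with time_budget(0.01, "sleep"):
        time.sleep(0.05)
except AssertionError as exc:
    print(exc)
```

If the wrapped block raises, its exception propagates unchanged and the timing assertion is skipped, so a crash is never misreported as a slow run.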
Brittle Contract Validation

Mistake: Writing data contracts that are too rigid and break with minor, acceptable changes.

Problem: Overly strict contracts create false positives and make the pipeline fragile to normal data evolution.

Solution: Design contracts with appropriate tolerances:

# Too rigid
BRITTLE_CONTRACT = QualityContract(
    name="brittle_user_checks",
    description="Fails on any missing user_id",
    validations=[
        CompletenessValidation({"user_id": 1.0})  # Requires 100% completeness
    ]
)

# More resilient
RESILIENT_CONTRACT = QualityContract(
    name="resilient_user_checks",
    description="Tolerates small amounts of missing or unusual user_ids",
    validations=[
        CompletenessValidation({"user_id": 0.99}),  # Allow 1% missing values
        BusinessRuleValidation(
            "user_id_format",
            "user_id RLIKE '^user_[0-9]+$'",
            "User IDs should follow expected format",
            max_violations=100  # Allow some format variations
        )
    ]
)

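The `user_id_format` rule tolerates up to 100 non-matching IDs rather than failing on the first one. A pure-Python equivalent using `re.fullmatch` (roughly equivalent to Spark's `RLIKE` with the anchored pattern above) shows how some of the edge-case IDs from the earlier test would be scored; `check_format` is a hypothetical helper.

```python
import re
from typing import Iterable, List, Tuple

USER_ID_PATTERN = re.compile(r"user_[0-9]+")


def check_format(user_ids: Iterable[str], max_violations: int) -> Tuple[bool, List[str]]:
    """Return (passed, offending_ids); the check passes while violations <= max_violations."""
    offending = [uid for uid in user_ids if not USER_ID_PATTERN.fullmatch(uid)]
    return len(offending) <= max_violations, offending


ids = ["user_1", "user_0042", "user_ñoël", "user with spaces", "user5"]
passed, offending = check_format(ids, max_violations=100)
print(passed, offending)  # True ['user_ñoël', 'user with spaces', 'user5']
print(check_format(ids, max_violations=0)[0])  # False
```

Returning the offending IDs alongside the verdict keeps a tolerant contract informative: it passes, but you can still log what it let through.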
Insufficient Test Environment Isolation

Mistake: Running tests that interfere with each other due to shared state or resources.

Problem: Tests become flaky and hard to debug when they depend on execution order or previous test results.

Solution: Properly isolate test environments:

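import uuid

import pytest
from pyspark.sql import SparkSession

@pytest.fixture
def isolated_spark_session():
    """Create isolated Spark session for each test."""
    # getOrCreate() returns any session that is already active (static configs
    # such as the warehouse dir are then ignored), so this only gives real
    # isolation when no other fixture has started a session.
    spark = SparkSession.builder \
        .appName(f"test-{uuid.uuid4()}") \
        .config("spark.sql.warehouse.dir", f"/tmp/spark-warehouse-{uuid.uuid4()}") \
        .getOrCreate()
    
    yield spark
    
    # Clean up
    spark.stop()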
    
@pytest.fixture(autouse=True)
def clean_temp_tables(spark):
    """Automatically clean temporary tables after each test."""
    yield
    
    # Clean up any temporary tables created during the test
    for table in spark.catalog.listTables():
        if table.isTemporary:
            spark.catalog.dropTempView(table.name)
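The same isolation idea applies to anything written to disk: give each test its own scratch directory and delete it afterwards, so output from one test can never be read by another. pytest's built-in `tmp_path` fixture does this for you; the stdlib sketch below shows the mechanism with `tempfile.TemporaryDirectory`, and `isolated_workdir` is a hypothetical helper name.

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def isolated_workdir(prefix: str = "pipeline-test-"):
    """Yield a fresh directory that is deleted, with its contents, on exit."""
    with tempfile.TemporaryDirectory(prefix=prefix) as path:
        yield path


with isolated_workdir() as first, isolated_workdir() as second:
    # Each test writes to its own location, so outputs cannot collide
    with open(os.path.join(first, "output.txt"), "w") as f:
        f.write("test A")
    print(first != second, os.listdir(second))  # True []

print(os.path.exists(first))  # False: removed when the block exited
```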

Debugging Contract Failures in Production

When data contracts fail in production, you need systematic approaches to understand and fix the issues:

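import json
import os

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, count, lit, when

def debug_contract_failure(df: DataFrame, contract: QualityContract, output_path: str):
    """Generate detailed diagnostics for contract failures."""
    
    # Take a reproducible random sample for inspection. This samples all
    # records, not only the ones that violated the contract.
    sample_size = 1000
    records_sample = df.sample(fraction=0.1, seed=42).limit(sample_size).cache()
    
    # Count nulls in every column with one aggregation instead of one job per column
    null_counts = df.select(
        [count(when(col(c).isNull(), lit(1))).alias(c) for c in df.columns]
    ).first().asDict()
    
    # Generate data profile, including the contract's own error messages
    profile = {
        "contract": contract.name,
        "validation_errors": contract.validate(df).errors,
        "total_records": df.count(),
        "schema": df.schema.json(),
        "null_counts": null_counts,
        "sample_values": {
            field.name: [row[field.name] for row in records_sample.select(field.name).distinct().limit(10).collect()]
            for field in df.schema.fields if field.dataType.simpleString() == "string"
        }
    }
    
    # Save diagnostic information (open() requires output_path to be a local path)
    os.makedirs(output_path, exist_ok=True)
    with open(f"{output_path}/contract_failure_diagnostics.json", "w") as f:
        json.dump(profile, f, indent=2, default=str)
    
    # Save the record sample next to the diagnostics
    records_sample.write.mode("overwrite").parquet(f"{output_path}/records_sample")
    
    return profile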

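The profile built by `debug_contract_failure` boils down to a few aggregates per column. A pure-Python version over a list of dict records can be handy for inspecting a small extract locally; `profile_records` and its output keys are illustrative assumptions that loosely mirror the Spark version (for example, it only lists sample values for columns that contain strings).

```python
import json
from typing import Any, Dict, List


def profile_records(records: List[Dict[str, Any]], max_samples: int = 10) -> Dict[str, Any]:
    """Summarize row count, per-column null counts, and sample distinct string values."""
    columns = sorted({key for record in records for key in record})
    # A key missing from a record counts as null for that record
    null_counts = {c: sum(1 for r in records if r.get(c) is None) for c in columns}
    sample_values = {}
    for c in columns:
        distinct = []
        for r in records:
            value = r.get(c)
            if isinstance(value, str) and value not in distinct:
                distinct.append(value)
            if len(distinct) == max_samples:
                break
        if distinct:
            sample_values[c] = distinct
    return {"total_records": len(records), "null_counts": null_counts,
            "sample_values": sample_values}


records = [
    {"user_id": "user_1", "event_type": "click", "revenue": 0.0},
    {"user_id": None, "event_type": "purchase", "revenue": 25.0},
    {"user_id": "user_1", "event_type": "purchse", "revenue": None},  # typo worth spotting
]
print(json.dumps(profile_records(records), indent=2))
```

Distinct sample values are often where a contract failure becomes obvious: here the misspelled `purchse` event type stands out immediately.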
Summary & Next Steps

You've now built a comprehensive three-layer testing strategy that will catch pipeline failures before they impact production systems. Here's what you've accomplished:

Unit Tests: You can isolate and validate individual transformations, handling edge cases and testing business logic with controlled inputs. These tests run fast and provide immediate feedback during development.

Integration Tests: You can validate end-to-end pipeline behavior with realistic data volumes and infrastructure, catching performance issues and integration problems before deployment.

Data Contracts: You can define and enforce explicit expectations about data structure and quality, creating reliable interfaces between pipeline stages and downstream consumers.

The framework you've built scales with pipeline complexity and provides clear diagnostic information when things go wrong. Your pipelines will fail fast and fail obviously rather than silently corrupting data.

Immediate Next Steps

  1. **

Learning Path: Data Pipeline Fundamentals
