Domain Adaptation for Time-Series Anomaly Detection: Complete Implementation Guide with Full Training Scripts

Published April 6, 2026 · Updated May 27, 2026 · 56 min read

Summary

What this post covers: A complete, runnable implementation guide for domain-adaptive time-series anomaly detection in PyTorch, comprising nine production-ready scripts that implement DANN, MMD, and CORAL on top of a CNN-LSTM encoder for multi-channel sensor data.

Key insights:

  • Domain shift between machines, sensors, factories, or seasons can sharply degrade an industrial anomaly detector. In the scenario described below, AUROC falls from 0.95 on the source machine to 0.62 on the target. Relabeling each new domain is economically infeasible because anomalies are rare.
  • Three domain-adaptation losses cover the practical design space: DANN (adversarial, most flexible), MMD (kernel-based moment matching, simpler and more stable), and CORAL (second-order statistic alignment, with minimal hyperparameter overhead).
  • A CNN-LSTM hybrid encoder with a shared feature extractor and separate anomaly and domain heads is a strong default architecture for multi-channel time series. The CNN captures local waveform shape and the LSTM captures temporal dependencies.
  • Progressive lambda scheduling, in which the domain-adaptation weight is ramped from 0 toward 1 over training, is the single most important training practice. Without it the adversarial signal destabilizes feature learning.
  • Domain adaptation succeeds only when source and target share the same underlying anomaly mechanisms but differ in superficial signal characteristics. Fundamentally different failure modes still require labeled target data through semi-supervised adaptation.

Main topics: Introduction: The Domain Shift Problem in Anomaly Detection, Project Structure and Setup, Configuration and Hyperparameters, Generating Realistic Synthetic Data, Dataset Classes and Data Loading, The Core Model Architecture, Loss Functions: DANN, MMD, and CORAL, The Main Training Script, Evaluation and Metrics, Utility Functions, Running the Full Pipeline, Understanding the Results, Adapting to Your Own Data, Common Issues and Solutions, Putting It Together, References.

Introduction: The Domain Shift Problem in Anomaly Detection

Consider an engineer who has spent six months collecting labeled anomaly data from a CNC milling machine on the factory floor, painstakingly tagging every spindle vibration spike, every thermal drift event, and every bearing degradation signature. The resulting anomaly detection model attains 0.95 AUROC on that machine. The company subsequently acquires a second milling machine from the same manufacturer and model line, differing only in production year. The model is deployed, and the AUROC falls to 0.62—barely better than a coin flip.

This is the domain shift problem, one of the most costly difficulties in industrial machine learning. The statistical distribution of sensor readings differs between machines, factories, sensor brands, and even seasons. Noise floors vary, baseline amplitudes drift, and the boundary between “normal” and “anomalous” deforms in subtle ways. A carefully trained model becomes essentially unusable the moment it leaves its original domain.

The conventional solution is to label data in each new domain. However, labeling anomaly data is exceptionally expensive: anomalies are rare by definition, and expert annotators are scarce. A more attractive approach is to transfer anomaly-detection knowledge from a labeled source domain (machine A) to an unlabeled target domain (machine B) without re-collecting labels.

This is precisely what domain adaptation provides. By training a model to learn features that are invariant across domains, that is, features capturing the essence of "anomaly" regardless of which machine produced the signal, an analyst can detect anomalies in new domains with little or no labeled target data. Adversarial domain adaptation was popularized in computer vision by the DANN paper of Ganin et al. (2016), but its application to time-series anomaly detection remains comparatively underexplored in practice, even though it is highly relevant to industrial deployment.

This post is not a theoretical survey. It is a complete, runnable implementation guide. Readers who follow it through will obtain nine production-ready Python scripts that implement three domain adaptation strategies—DANN (Domain-Adversarial Neural Networks), MMD (Maximum Mean Discrepancy), and CORAL (CORrelation ALignment)—on top of a CNN-LSTM hybrid encoder for multi-channel time-series anomaly detection. Every script is complete, with no omissions or pseudocode.

The implementation proceeds below.

[Figure: Domain Shift: Source vs. Target Distribution. The labeled source domain (Machine A) and the unlabeled target domain (Machine B) are separated by a domain gap. Legend: source normal, target normal, known anomaly, unlabeled (anomaly?).]

Project Structure and Setup

Before writing any code, it is useful to establish a clean project layout. Each file has a single responsibility, which makes the codebase easier to understand and adapt to a specific use case.

da-anomaly-detection/
├── config.py                    # Hyperparameters and configuration
├── dataset.py                   # Dataset classes and data loading
├── model.py                     # Model architecture (encoder, classifier, discriminator)
├── losses.py                    # Loss function definitions (DANN, MMD, CORAL)
├── train.py                     # Main training script with domain adaptation
├── evaluate.py                  # Evaluation and metrics
├── utils.py                     # Utility functions (seeding, checkpoints, plotting)
├── generate_synthetic_data.py   # Generate example data for testing
├── requirements.txt             # Dependencies
├── data/                        # Generated or real data goes here
├── checkpoints/                 # Saved model weights
└── results/                     # Evaluation outputs, plots, metrics

The first step is to create the directory and install dependencies.

mkdir -p da-anomaly-detection/{data,checkpoints,results}
cd da-anomaly-detection

requirements.txt

torch>=2.0.0
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
tqdm>=4.65.0

Then install the dependencies:

pip install -r requirements.txt

Tip: On systems with a CUDA-capable GPU, install PyTorch with CUDA support for substantially faster training: pip install torch --index-url https://download.pytorch.org/whl/cu121

Configuration and Hyperparameters

Centralizing configuration prevents magic numbers from being scattered across the codebase. A Python dataclass is used here so that editors and static type checkers can offer autocompletion and type hints for every field without additional effort.

config.py

"""
config.py — Centralized configuration for domain-adaptive anomaly detection.
All hyperparameters live here. Override via CLI arguments in train.py.
"""

from dataclasses import dataclass, field
import torch
import os


@dataclass
class Config:
    """All hyperparameters and paths for the DA anomaly detection pipeline."""

    # --- Data Parameters ---
    num_features: int = 6           # Number of sensor channels
    window_size: int = 64           # Sliding window length (timesteps)
    stride: int = 16                # Stride for sliding window
    train_ratio: float = 0.8        # Train/val split ratio

    # --- Model Architecture ---
    cnn_channels: list = field(default_factory=lambda: [32, 64, 128])
    cnn_kernel_sizes: list = field(default_factory=lambda: [7, 5, 3])
    lstm_hidden_dim: int = 128
    lstm_num_layers: int = 2
    latent_dim: int = 128           # Dimension of the shared feature space
    classifier_hidden_dim: int = 64
    discriminator_hidden_dim: int = 64
    dropout: float = 0.3

    # --- Training Parameters ---
    batch_size: int = 64
    learning_rate: float = 1e-3
    discriminator_lr: float = 1e-3
    weight_decay: float = 1e-4
    epochs: int = 100
    patience: int = 15              # Early stopping patience

    # --- Domain Adaptation Parameters ---
    adaptation_method: str = "dann"  # 'dann', 'mmd', or 'coral'
    lambda_domain: float = 1.0       # Max domain loss weight
    lambda_recon: float = 0.5        # Reconstruction loss weight
    lambda_cls: float = 1.0          # Classification loss weight
    gamma: float = 10.0              # DANN lambda schedule steepness
    mmd_kernel_bandwidth: list = field(
        default_factory=lambda: [0.01, 0.1, 1.0, 10.0, 100.0]
    )

    # --- Anomaly Scoring ---
    alpha: float = 0.7              # Weight for classifier score vs recon error
    anomaly_threshold_percentile: float = 95.0

    # --- Paths ---
    data_dir: str = "data"
    checkpoint_dir: str = "checkpoints"
    results_dir: str = "results"

    # --- Device and Reproducibility ---
    seed: int = 42
    device: str = ""

    def __post_init__(self):
        if not self.device:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.results_dir, exist_ok=True)
Key Takeaway: The most sensitive hyperparameter in domain adaptation is lambda_domain. Too high, and the model loses its ability to classify anomalies. Too low, and domain adaptation has no effect. The progressive scheduling in the training script (the DANN lambda schedule) addresses this by starting low and ramping upward.
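The training script appears later in this guide; as a point of reference, the ramp proposed by Ganin et al. (2016) is lambda(p) = lambda_max * (2 / (1 + exp(-gamma * p)) - 1), where p in [0, 1] is training progress. The sketch below is a minimal version of that formula, assuming progress is measured as the fraction of optimizer steps completed and that Config's `gamma` and `lambda_domain` play the roles of gamma and lambda_max. The helper name `dann_lambda` is hypothetical, and train.py may compute progress differently.

```python
import math


def dann_lambda(progress: float, gamma: float = 10.0, lambda_max: float = 1.0) -> float:
    """DANN-style ramp: 0 at progress=0, approaching lambda_max as progress -> 1.

    Hypothetical helper. `progress` is assumed to be the fraction of training
    completed, e.g. global_step / total_steps.
    """
    p = min(max(progress, 0.0), 1.0)  # clamp to [0, 1]
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)


if __name__ == "__main__":
    # Illustrative numbers only: 100 epochs x 200 steps per epoch
    epochs, steps_per_epoch = 100, 200
    total_steps = epochs * steps_per_epoch
    for epoch in (0, 10, 25, 50, 99):
        progress = epoch * steps_per_epoch / total_steps
        print(f"epoch {epoch:3d}: lambda = {dann_lambda(progress):.3f}")
```

Because 2 / (1 + exp(-x)) - 1 equals tanh(x / 2), the default gamma = 10 makes the ramp front-loaded: lambda is about 0.46 after 10% of training and about 0.85 after 25%. A smaller gamma spreads the ramp over more of the run.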

Generating Realistic Synthetic Data

Before moving to proprietary data, it helps to have a sandbox dataset. The script below generates two-domain synthetic time-series data with realistic characteristics: seasonal patterns, trends, multiple anomaly types, and domain-specific differences in noise, amplitude, and baseline offset. The source domain receives full labels, the target training set has no labels (which simulates the realistic scenario), and the target test set retains labels for evaluation purposes.

generate_synthetic_data.py

"""
generate_synthetic_data.py — Generate realistic two-domain time-series data
with injected anomalies for testing domain adaptation.

Simulates 6-channel sensor data (e.g., 3 joints x [torque, position]) from
two different machines with different noise/amplitude characteristics.
"""

import argparse
import os
import numpy as np
import pandas as pd


def generate_base_signal(n_samples: int, num_features: int, seed: int = 42) -> np.ndarray:
    """Generate a base multi-channel time-series with realistic patterns."""
    rng = np.random.RandomState(seed)
    t = np.arange(n_samples)
    signals = np.zeros((n_samples, num_features))

    for ch in range(num_features):
        freq1 = 0.002 + ch * 0.001
        freq2 = 0.01 + ch * 0.003
        phase1 = rng.uniform(0, 2 * np.pi)
        phase2 = rng.uniform(0, 2 * np.pi)

        # Seasonal component
        seasonal = 2.0 * np.sin(2 * np.pi * freq1 * t + phase1)
        # Higher-frequency oscillation
        oscillation = 0.8 * np.sin(2 * np.pi * freq2 * t + phase2)
        # Slow trend
        trend = 0.0005 * t * ((-1) ** ch)
        # Combine
        signals[:, ch] = seasonal + oscillation + trend

    return signals


def inject_anomalies(
    signals: np.ndarray,
    anomaly_ratio: float = 0.05,
    seed: int = 42
) -> tuple:
    """
    Inject multiple anomaly types into signals.
    Returns (modified_signals, labels) where labels[i]=1 means anomaly.
    """
    rng = np.random.RandomState(seed)
    n_samples, num_features = signals.shape
    labels = np.zeros(n_samples, dtype=int)
    modified = signals.copy()

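    segment_length = 20
    anomaly_types = ["spike", "drift", "level_shift", "frequency_change"]

    # anomaly_ratio is a fraction of timesteps, so convert it to a number of
    # segments (each drift/shift/frequency event labels segment_length steps).
    n_anomalies = max(1, int(n_samples * anomaly_ratio / segment_length))

    # Choose random non-overlapping segments by sampling aligned slots
    n_slots = n_samples // segment_length
    starts = rng.choice(n_slots, size=n_anomalies, replace=False) * segment_length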

    for i, start in enumerate(starts):
        end = start + segment_length
        a_type = anomaly_types[i % len(anomaly_types)]
        channel = rng.randint(0, num_features)

        if a_type == "spike":
            spike_pos = start + rng.randint(0, segment_length)
            magnitude = rng.uniform(5, 10) * (1 if rng.random() > 0.5 else -1)
            modified[spike_pos, channel] += magnitude
            labels[spike_pos] = 1

        elif a_type == "drift":
            drift = np.linspace(0, rng.uniform(3, 6), segment_length)
            modified[start:end, channel] += drift
            labels[start:end] = 1

        elif a_type == "level_shift":
            shift = rng.uniform(3, 7) * (1 if rng.random() > 0.5 else -1)
            modified[start:end, channel] += shift
            labels[start:end] = 1

        elif a_type == "frequency_change":
            t_seg = np.arange(segment_length)
            high_freq = 2.0 * np.sin(2 * np.pi * 0.15 * t_seg)
            modified[start:end, channel] += high_freq
            labels[start:end] = 1

    return modified, labels


def apply_domain_transform(
    signals: np.ndarray,
    noise_scale: float = 0.3,
    amplitude_scale: float = 1.0,
    baseline_offset: float = 0.0,
    seed: int = 42
) -> np.ndarray:
    """Apply domain-specific transformations to simulate a different machine."""
    rng = np.random.RandomState(seed)
    transformed = signals.copy()
    n_samples, num_features = transformed.shape

    # Per-channel amplitude scaling
    for ch in range(num_features):
        ch_amp = amplitude_scale * rng.uniform(0.8, 1.2)
        ch_offset = baseline_offset + rng.uniform(-0.5, 0.5)
        transformed[:, ch] = transformed[:, ch] * ch_amp + ch_offset

    # Add domain-specific noise
    noise = rng.normal(0, noise_scale, transformed.shape)
    transformed += noise

    return transformed


def generate_dataset(
    n_samples: int,
    num_features: int,
    anomaly_ratio: float,
    noise_scale: float,
    amplitude_scale: float,
    baseline_offset: float,
    seed: int
) -> pd.DataFrame:
    """Generate a complete dataset with signals, anomalies, and domain transform."""
    base = generate_base_signal(n_samples, num_features, seed=seed)
    with_anomalies, labels = inject_anomalies(base, anomaly_ratio, seed=seed + 1)
    transformed = apply_domain_transform(
        with_anomalies,
        noise_scale=noise_scale,
        amplitude_scale=amplitude_scale,
        baseline_offset=baseline_offset,
        seed=seed + 2
    )

    columns = [f"sensor_{i}" for i in range(num_features)]
    df = pd.DataFrame(transformed, columns=columns)
    df["label"] = labels
    df["timestamp"] = pd.date_range("2024-01-01", periods=n_samples, freq="s")
    return df


def main():
    parser = argparse.ArgumentParser(
        description="Generate synthetic two-domain time-series data."
    )
    parser.add_argument("--output_dir", type=str, default="data",
                        help="Output directory for CSV files")
    parser.add_argument("--n_samples", type=int, default=20000,
                        help="Number of samples per dataset")
    parser.add_argument("--num_features", type=int, default=6,
                        help="Number of sensor channels")
    parser.add_argument("--anomaly_ratio", type=float, default=0.05,
                        help="Fraction of timesteps with anomalies")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    print("Generating source domain data (Machine A)...")
    source_full = generate_dataset(
        n_samples=args.n_samples,
        num_features=args.num_features,
        anomaly_ratio=args.anomaly_ratio,
        noise_scale=0.2,
        amplitude_scale=1.0,
        baseline_offset=0.0,
        seed=args.seed
    )
    split_idx = int(len(source_full) * 0.7)
    source_train = source_full.iloc[:split_idx].reset_index(drop=True)
    source_test = source_full.iloc[split_idx:].reset_index(drop=True)

    print("Generating target domain data (Machine B)...")
    target_full = generate_dataset(
        n_samples=args.n_samples,
        num_features=args.num_features,
        anomaly_ratio=args.anomaly_ratio,
        noise_scale=0.5,           # Higher noise
        amplitude_scale=1.4,       # Different amplitude
        baseline_offset=2.0,       # Shifted baseline
        seed=args.seed + 100
    )
    split_idx_t = int(len(target_full) * 0.7)
    target_train = target_full.iloc[:split_idx_t].reset_index(drop=True)
    target_test = target_full.iloc[split_idx_t:].reset_index(drop=True)

    # Remove labels from target train (unsupervised in target domain)
    target_train_unlabeled = target_train.drop(columns=["label"])

    # Save all files
    source_train.to_csv(os.path.join(args.output_dir, "source_train.csv"), index=False)
    source_test.to_csv(os.path.join(args.output_dir, "source_test.csv"), index=False)
    target_train_unlabeled.to_csv(os.path.join(args.output_dir, "target_train.csv"), index=False)
    target_test.to_csv(os.path.join(args.output_dir, "target_test.csv"), index=False)

    print(f"\nDatasets saved to {args.output_dir}/")
    print(f"  source_train.csv: {len(source_train)} samples, "
          f"{source_train['label'].sum()} anomalies ({source_train['label'].mean()*100:.1f}%)")
    print(f"  source_test.csv:  {len(source_test)} samples, "
          f"{source_test['label'].sum()} anomalies ({source_test['label'].mean()*100:.1f}%)")
    print(f"  target_train.csv: {len(target_train_unlabeled)} samples (no labels)")
    print(f"  target_test.csv:  {len(target_test)} samples, "
          f"{target_test['label'].sum()} anomalies ({target_test['label'].mean()*100:.1f}%)")


if __name__ == "__main__":
    main()

The script can be executed directly.

python generate_synthetic_data.py --output_dir data/ --n_samples 20000

The script produces four CSV files. The source data is fully labeled. The target training data is unlabeled, which reflects the central premise of domain adaptation. The target test data is labeled so that the effectiveness of adaptation can be measured.
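Before training, it can help to quantify how far apart the two domains actually are. The sketch below, a hypothetical helper named `domain_gap_summary`, compares per-channel statistics of two DataFrames shaped like the generated CSVs; it assumes feature columns are named `sensor_*`, as the generator writes them. The mean shift is expressed in units of the source standard deviation, so values well away from 0, or standard-deviation ratios well away from 1, point to a gap the adaptation loss will need to close.

```python
import pandas as pd


def domain_gap_summary(source: pd.DataFrame, target: pd.DataFrame) -> pd.DataFrame:
    """Per-channel statistics for two domains (hypothetical helper).

    Assumes feature columns are named 'sensor_<i>'; 'label' and 'timestamp'
    columns are ignored.
    """
    cols = [c for c in source.columns if c.startswith("sensor_")]
    src_mean, src_std = source[cols].mean(), source[cols].std()
    tgt_mean, tgt_std = target[cols].mean(), target[cols].std()
    return pd.DataFrame({
        "src_mean": src_mean,
        "tgt_mean": tgt_mean,
        # Shift of the target mean, measured in source standard deviations
        "shift_in_src_std": (tgt_mean - src_mean) / src_std,
        # > 1 means the target channel is more spread out than the source
        "std_ratio": tgt_std / src_std,
    })
```

Against the generated files, a call such as `domain_gap_summary(pd.read_csv("data/source_train.csv"), pd.read_csv("data/target_train.csv"))` returns one row per sensor channel.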

Dataset Classes and Data Loading

Time-series anomaly detection operates on windows, that is, fixed-length slices of the signal. The dataset class below handles windowing, normalization (fit on the source training data and applied to every split), and optional data augmentation. The create_data_loaders function then builds separate source and target loaders, so that a source batch and a target batch can be drawn side by side during training.
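The windowing rule is easy to reason about with a vectorized equivalent. The sketch below uses NumPy's `sliding_window_view` (available since NumPy 1.20) to produce the same windows and window labels as `TimeSeriesDataset`, in the same `(num_features, window_size)` layout; `make_windows` is a hypothetical helper, not part of dataset.py. It also makes the window count explicit: floor((n_samples - window_size) / stride) + 1, which for the 14,000-row source training split gives 872 windows with the default window_size = 64 and stride = 16.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def make_windows(data: np.ndarray, labels: np.ndarray, window_size: int = 64, stride: int = 16):
    """Vectorized sliding windows (hypothetical helper mirroring TimeSeriesDataset).

    data:   (n_samples, num_features)
    labels: (n_samples,) with 1 marking an anomalous timestep
    Returns windows (n_windows, num_features, window_size) and labels (n_windows,).
    """
    # Windowing along the time axis appends the window dimension last,
    # which yields the channels-first layout that Conv1d expects.
    windows = sliding_window_view(data, window_size, axis=0)[::stride]
    # A window is anomalous if any timestep inside it is anomalous
    window_labels = sliding_window_view(labels, window_size)[::stride].max(axis=1)
    return windows, window_labels


n_samples, num_features = 14000, 6
data = np.random.default_rng(0).normal(size=(n_samples, num_features)).astype(np.float32)
labels = np.zeros(n_samples, dtype=np.float32)
labels[5000:5020] = 1.0  # one 20-step anomalous segment
windows, window_labels = make_windows(data, labels)
print(windows.shape)             # (872, 6, 64)
print(int(window_labels.sum()))  # 5 windows overlap the segment
```

The returned arrays are read-only views into the original data, so call `.copy()` before applying in-place augmentation. Note also how one 20-step event marks five overlapping windows as anomalous; the window-level anomaly rate is therefore noticeably higher than the timestep-level rate.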

dataset.py

"""
dataset.py — PyTorch Dataset classes for time-series domain adaptation.

Handles sliding-window creation, normalization, augmentation, and
creation of separate source and target data loaders.
"""

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class TimeSeriesDataset(Dataset):
    """
    Sliding-window dataset for multi-channel time-series.

    Args:
        data: numpy array of shape (n_samples, num_features)
        labels: numpy array of shape (n_samples,) or None for unlabeled data
        window_size: number of timesteps per window
        stride: step between consecutive windows
        transform: optional callable for data augmentation
    """

    def __init__(
        self,
        data: np.ndarray,
        labels: np.ndarray = None,
        window_size: int = 64,
        stride: int = 16,
        transform=None
    ):
        self.data = data.astype(np.float32)
        self.labels = labels
        self.window_size = window_size
        self.stride = stride
        self.transform = transform

        # Precompute valid window start indices
        self.indices = list(range(0, len(data) - window_size + 1, stride))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        start = self.indices[idx]
        end = start + self.window_size
        window = self.data[start:end]  # (window_size, num_features)

        if self.transform is not None:
            window = self.transform(window)

        # Transpose to (num_features, window_size) for Conv1d
        window_tensor = torch.tensor(window, dtype=torch.float32).T

        if self.labels is not None:
            # Window label = 1 if any timestep in window is anomalous
            window_label = float(self.labels[start:end].max())
            return window_tensor, torch.tensor(window_label, dtype=torch.float32)
        else:
            return window_tensor, torch.tensor(-1.0, dtype=torch.float32)


class Normalizer:
    """
    Fit on source training data, transform all data.
    Uses per-channel mean and std normalization.
    """

    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, data: np.ndarray):
        """Compute mean and std from training data."""
        self.mean = data.mean(axis=0)
        self.std = data.std(axis=0)
        # Prevent division by zero
        self.std[self.std < 1e-8] = 1.0
        return self

    def transform(self, data: np.ndarray) -> np.ndarray:
        """Apply normalization."""
        return (data - self.mean) / self.std

    def fit_transform(self, data: np.ndarray) -> np.ndarray:
        """Fit and transform in one step."""
        self.fit(data)
        return self.transform(data)


class JitterTransform:
    """Add random Gaussian noise for data augmentation."""

    def __init__(self, sigma: float = 0.03):
        self.sigma = sigma

    def __call__(self, window: np.ndarray) -> np.ndarray:
        noise = np.random.normal(0, self.sigma, window.shape).astype(np.float32)
        return window + noise


class ScalingTransform:
    """Random per-channel amplitude scaling for data augmentation."""

    def __init__(self, sigma: float = 0.1):
        self.sigma = sigma

    def __call__(self, window: np.ndarray) -> np.ndarray:
        factor = np.random.normal(1.0, self.sigma, (1, window.shape[1])).astype(np.float32)
        return window * factor


class ComposeTransforms:
    """Chain multiple transforms together."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, window: np.ndarray) -> np.ndarray:
        for t in self.transforms:
            window = t(window)
        return window


def load_csv_data(filepath: str, has_labels: bool = True):
    """
    Load a CSV file and separate features from labels.

    Returns:
        data: numpy array (n_samples, num_features)
        labels: numpy array (n_samples,) or None
    """
    df = pd.read_csv(filepath)
    # Drop non-numeric columns like timestamp
    feature_cols = [c for c in df.columns if c not in ("label", "timestamp")]
    data = df[feature_cols].values.astype(np.float32)
    labels = df["label"].values.astype(np.float32) if (has_labels and "label" in df.columns) else None
    return data, labels


def create_data_loaders(config) -> tuple:
    """
    Create all data loaders for domain adaptation training.

    Returns:
        loaders: dict with keys 'source_train', 'source_test',
                 'target_train', 'target_test'
        normalizer: the Normalizer fitted on the source training data
    """
    import os

    # Load raw data
    source_train_data, source_train_labels = load_csv_data(
        os.path.join(config.data_dir, "source_train.csv"), has_labels=True
    )
    source_test_data, source_test_labels = load_csv_data(
        os.path.join(config.data_dir, "source_test.csv"), has_labels=True
    )
    target_train_data, _ = load_csv_data(
        os.path.join(config.data_dir, "target_train.csv"), has_labels=False
    )
    target_test_data, target_test_labels = load_csv_data(
        os.path.join(config.data_dir, "target_test.csv"), has_labels=True
    )

    # Normalize: fit on source train only
    normalizer = Normalizer()
    source_train_data = normalizer.fit_transform(source_train_data)
    source_test_data = normalizer.transform(source_test_data)
    target_train_data = normalizer.transform(target_train_data)
    target_test_data = normalizer.transform(target_test_data)

    # Optional augmentation for training
    train_transform = ComposeTransforms([
        JitterTransform(sigma=0.03),
        ScalingTransform(sigma=0.1),
    ])

    # Create datasets
    source_train_ds = TimeSeriesDataset(
        source_train_data, source_train_labels,
        window_size=config.window_size, stride=config.stride,
        transform=train_transform
    )
    source_test_ds = TimeSeriesDataset(
        source_test_data, source_test_labels,
        window_size=config.window_size, stride=config.stride
    )
    target_train_ds = TimeSeriesDataset(
        target_train_data, labels=None,
        window_size=config.window_size, stride=config.stride,
        transform=train_transform
    )
    target_test_ds = TimeSeriesDataset(
        target_test_data, target_test_labels,
        window_size=config.window_size, stride=config.stride
    )

    # Create loaders
    loaders = {
        "source_train": DataLoader(
            source_train_ds, batch_size=config.batch_size,
            shuffle=True, drop_last=True, num_workers=0
        ),
        "source_test": DataLoader(
            source_test_ds, batch_size=config.batch_size,
            shuffle=False, num_workers=0
        ),
        "target_train": DataLoader(
            target_train_ds, batch_size=config.batch_size,
            shuffle=True, drop_last=True, num_workers=0
        ),
        "target_test": DataLoader(
            target_test_ds, batch_size=config.batch_size,
            shuffle=False, num_workers=0
        ),
    }

    return loaders, normalizer
Caution: The normalizer should always be fit on the source training data alone. Fitting it on target data lets preprocessing absorb part of the domain shift, which obscures what the adaptation method itself contributes; if the target test split is included, evaluation statistics leak into training and inflate the reported metrics.
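A consequence of fitting on the source alone is that target data keeps part of its shift after normalization, and closing that residual gap is the adaptation loss's job. The toy sketch below uses synthetic Gaussian data (not the generated CSVs) with a target scaled by 1.4 and offset by 2.0, loosely echoing the generator's `amplitude_scale` and `baseline_offset` for Machine B, and applies the same per-channel standardization as `Normalizer`.

```python
import numpy as np

rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(5000, 2))
# Toy target domain: larger amplitude and shifted baseline
target = 1.4 * rng.normal(0.0, 1.0, size=(5000, 2)) + 2.0

# Same statistics as Normalizer.fit, computed on the source only
mean, std = source.mean(axis=0), source.std(axis=0)
std[std < 1e-8] = 1.0

source_norm = (source - mean) / std
target_norm = (target - mean) / std
print("source:", source_norm.mean(axis=0).round(2), source_norm.std(axis=0).round(2))
print("target:", target_norm.mean(axis=0).round(2), target_norm.std(axis=0).round(2))
```

The source channels come out at zero mean and unit variance, while the target channels stay near mean 2 and standard deviation 1.4: normalization alone does not remove the domain shift.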

The Core Model Architecture

The model architecture lies at the heart of the system. It comprises four components that operate together:

  • a shared encoder that maps each time-series window to a fixed-size feature vector;
  • an anomaly classifier that predicts normal versus anomaly;
  • a reconstruction decoder that reconstructs the original input and provides an auxiliary anomaly signal;
  • a domain discriminator that attempts to identify which domain produced a given feature vector.

The essential ingredient is the Gradient Reversal Layer (GRL). It acts as the identity in the forward pass, but during backpropagation it negates the gradients flowing from the domain discriminator into the encoder and scales them by lambda. The discriminator still learns to tell the domains apart, while the encoder is pushed in the opposite direction, toward features that are maximally uninformative about domain identity. That is precisely the domain-invariant representation required.

Architecture:
                        ┌─── Anomaly Classifier (binary: normal/anomaly)
Input → Shared Encoder ─┤
  (time-series)         ├─── Reconstruction Decoder (autoencoder branch)
                        └─── Domain Discriminator (with gradient reversal)
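The effect of the GRL can be checked in isolation. The sketch below re-creates the gradient reversal idea in a few lines; it is a standalone illustration rather than an import of model.py, and `GradReverse` is a hypothetical name. It confirms that the forward pass is the identity while the gradient reaching the input is multiplied by -lambda.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; gradient times -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is needed for lambda_val, hence the trailing None
        return -ctx.lambda_val * grad_output, None


x = torch.ones(3, requires_grad=True)
y = GradReverse.apply(x, 0.5)
loss = (2.0 * y).sum()  # without reversal, d(loss)/dx would be +2 everywhere
loss.backward()

print(y.detach())  # tensor([1., 1., 1.])
print(x.grad)      # tensor([-1., -1., -1.])
```

In the full model, the discriminator's own weights receive ordinary gradients and keep minimizing the domain-classification loss, while everything upstream of the GRL (the encoder) is updated in the direction that increases it.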

model.py

"""
model.py — Domain-adaptive anomaly detection model architecture.

Components:
  - GradientReversalLayer: reverses gradients for adversarial domain adaptation
  - SharedEncoder: CNN + BiLSTM feature extractor
  - AnomalyClassifier: binary classification head
  - ReconstructionDecoder: autoencoder branch for reconstruction-based scoring
  - DomainDiscriminator: adversarial domain classification head
  - DomainAdaptiveAnomalyDetector: full model combining all components
"""

import torch
import torch.nn as nn
from torch.autograd import Function


class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer (GRL) — Ganin et al., 2016.
    Forward pass: identity.
    Backward pass: negate gradients and scale by lambda.
    """

    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_val * grad_output, None


class GradientReversalLayer(nn.Module):
    """Module wrapper for the gradient reversal function."""

    def __init__(self, lambda_val: float = 1.0):
        super().__init__()
        self.lambda_val = lambda_val

    def set_lambda(self, lambda_val: float):
        self.lambda_val = lambda_val

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_val)


class SharedEncoder(nn.Module):
    """
    1D-CNN + Bidirectional LSTM encoder for multi-channel time-series.

    Input shape:  (batch, num_features, window_size)
    Output shape: (batch, latent_dim)
    """

    def __init__(
        self,
        num_features: int = 6,
        cnn_channels: list = None,
        cnn_kernel_sizes: list = None,
        lstm_hidden_dim: int = 128,
        lstm_num_layers: int = 2,
        latent_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        if cnn_channels is None:
            cnn_channels = [32, 64, 128]
        if cnn_kernel_sizes is None:
            cnn_kernel_sizes = [7, 5, 3]

        # Build CNN layers
        cnn_layers = []
        in_channels = num_features
        for out_ch, ks in zip(cnn_channels, cnn_kernel_sizes):
            cnn_layers.extend([
                nn.Conv1d(in_channels, out_ch, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_channels = out_ch
        self.cnn = nn.Sequential(*cnn_layers)

        # Bidirectional LSTM on top of CNN features
        self.lstm = nn.LSTM(
            input_size=cnn_channels[-1],
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_num_layers > 1 else 0.0,
        )

        # Project to latent space
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden_dim * 2, latent_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )
        self.latent_dim = latent_dim

    def forward(self, x):
        """
        Args:
            x: (batch, num_features, window_size)
        Returns:
            latent: (batch, latent_dim)
        """
        # CNN: (batch, cnn_channels[-1], window_size)
        cnn_out = self.cnn(x)
        # Transpose for LSTM: (batch, window_size, cnn_channels[-1])
        lstm_in = cnn_out.permute(0, 2, 1)
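        # LSTM: h_n has shape (num_layers * 2, batch, lstm_hidden_dim)
        _, (h_n, _) = self.lstm(lstm_in)
        # Concatenate the final forward state (h_n[-2]) and the final backward
        # state (h_n[-1]) of the top layer. Indexing lstm_out[:, -1, :] would
        # give the backward direction only a single timestep of context.
        last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)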
        # Project to latent space
        latent = self.fc(last_hidden)
        return latent


class AnomalyClassifier(nn.Module):
    """
    Binary classification head: normal (0) vs anomaly (1).

    Input:  (batch, latent_dim)
    Output: (batch, 1), raw logit (apply torch.sigmoid for a probability)
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, latent):
        return self.net(latent)


class ReconstructionDecoder(nn.Module):
    """
    Decoder that reconstructs the original input from latent features.
    Uses LSTM + transposed Conv1d layers.

    Input:  (batch, latent_dim)
    Output: (batch, num_features, window_size)
    """

    def __init__(
        self,
        latent_dim: int = 128,
        num_features: int = 6,
        window_size: int = 64,
        lstm_hidden_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.window_size = window_size
        self.num_features = num_features
        self.lstm_hidden_dim = lstm_hidden_dim

        # Expand latent to sequence
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, lstm_hidden_dim),
            nn.ReLU(inplace=True),
        )

        # LSTM decoder
        self.lstm = nn.LSTM(
            input_size=lstm_hidden_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=1,
            batch_first=True,
        )

        # Transposed convolutions to reconstruct
        self.deconv = nn.Sequential(
            nn.ConvTranspose1d(lstm_hidden_dim, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.ConvTranspose1d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose1d(32, num_features, kernel_size=3, padding=1),
        )

    def forward(self, latent):
        """
        Args:
            latent: (batch, latent_dim)
        Returns:
            reconstruction: (batch, num_features, window_size)
        """
        batch_size = latent.size(0)
        # Expand to sequence
        expanded = self.fc(latent).unsqueeze(1).repeat(1, self.window_size, 1)
        # LSTM decode
        lstm_out, _ = self.lstm(expanded)
        # Transpose for Conv1d: (batch, lstm_hidden, window_size)
        conv_in = lstm_out.permute(0, 2, 1)
        # Reconstruct
        reconstruction = self.deconv(conv_in)
        return reconstruction


class DomainDiscriminator(nn.Module):
    """
    Domain classification head with Gradient Reversal Layer.
    Classifies whether features came from source (0) or target (1) domain.

    Input:  (batch, latent_dim)
    Output: (batch, 1) — domain logit
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        self.grl = GradientReversalLayer(lambda_val=1.0)
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def set_lambda(self, lambda_val: float):
        self.grl.set_lambda(lambda_val)

    def forward(self, latent):
        reversed_features = self.grl(latent)
        return self.net(reversed_features)


class DomainAdaptiveAnomalyDetector(nn.Module):
    """
    Full domain-adaptive anomaly detection model.
    Combines encoder, anomaly classifier, reconstruction decoder,
    and domain discriminator.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = SharedEncoder(
            num_features=config.num_features,
            cnn_channels=config.cnn_channels,
            cnn_kernel_sizes=config.cnn_kernel_sizes,
            lstm_hidden_dim=config.lstm_hidden_dim,
            lstm_num_layers=config.lstm_num_layers,
            latent_dim=config.latent_dim,
            dropout=config.dropout,
        )
        self.classifier = AnomalyClassifier(
            latent_dim=config.latent_dim,
            hidden_dim=config.classifier_hidden_dim,
            dropout=config.dropout,
        )
        self.decoder = ReconstructionDecoder(
            latent_dim=config.latent_dim,
            num_features=config.num_features,
            window_size=config.window_size,
            lstm_hidden_dim=config.lstm_hidden_dim,
            dropout=config.dropout,
        )
        self.discriminator = DomainDiscriminator(
            latent_dim=config.latent_dim,
            hidden_dim=config.discriminator_hidden_dim,
            dropout=config.dropout,
        )

    def set_domain_lambda(self, lambda_val: float):
        """Update the GRL lambda for progressive scheduling."""
        self.discriminator.set_lambda(lambda_val)

    def forward(self, x):
        """
        Full forward pass.

        Args:
            x: (batch, num_features, window_size)

        Returns:
            anomaly_logits:  (batch, 1) — raw logits for anomaly classification
            reconstruction:  (batch, num_features, window_size) — reconstructed input
            domain_logits:   (batch, 1) — raw logits for domain classification
            latent_features: (batch, latent_dim) — shared latent representation
        """
        latent = self.encoder(x)
        anomaly_logits = self.classifier(latent)
        reconstruction = self.decoder(latent)
        domain_logits = self.discriminator(latent)
        return anomaly_logits, reconstruction, domain_logits, latent
Key Takeaway: The Gradient Reversal Layer consists of only two lines of custom autograd code, yet it is the component that makes DANN adversarial. The forward pass is the identity. The backward pass negates the gradient, scaled by the current λ. The domain discriminator is therefore trained normally to tell source from target, while the encoder receives the reversed gradient and is pushed toward features the discriminator cannot separate, that is, domain-invariant features.
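To see the mechanism in isolation, here is a minimal standalone sketch of gradient reversal built on torch.autograd.Function. It is an illustration, not necessarily identical to the GradientReversalLayer defined earlier in model.py; the names GradReverse and grad_reverse are hypothetical.

```python
import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient flowing back toward the encoder.
        # The second return value is the "gradient" for lambda_val, which is not a tensor.
        return -ctx.lambda_val * grad_output, None


def grad_reverse(x, lambda_val=1.0):
    return GradReverse.apply(x, lambda_val)


# The forward value is unchanged, but the gradient flips sign and is scaled by lambda.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = grad_reverse(x, lambda_val=0.5)
y.sum().backward()
print(y.detach())  # tensor([1., 2., 3.])
print(x.grad)      # tensor([-0.5000, -0.5000, -0.5000])
```

Without the reversal, x.grad would be all ones; with it, any layer upstream of the GRL is updated to increase, rather than decrease, the downstream loss.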

Loss Functions: DANN, MMD, and CORAL

Domain adaptation is a family of techniques rather than a single method, and each member has distinct strengths. The implementation below supports three approaches, selected through a single configuration flag, adaptation_method. DANN trains the encoder adversarially against the domain discriminator through the gradient reversal layer. MMD directly minimizes a kernel-based distance between the source and target feature distributions. CORAL aligns the second-order statistics (covariance matrices) of the two domains.
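As a rough numerical illustration of what the MMD term measures, the following NumPy sketch computes the same biased multi-bandwidth estimate used by MMDLoss (a sum of Gaussian kernels exp(-||a - b||^2 / (2 · bw))) on synthetic 2-D features. The helper names and bandwidths are illustrative assumptions. The sketch deliberately uses different sample sizes for source and target, because real source and target batches can differ in size.

```python
import numpy as np


def pairwise_sq_dists(a, b):
    """Squared Euclidean distances between rows of a (n, d) and b (m, d)."""
    a_sq = (a ** 2).sum(axis=1)[:, None]  # (n, 1)
    b_sq = (b ** 2).sum(axis=1)[None, :]  # (1, m)
    return np.maximum(a_sq + b_sq - 2.0 * a @ b.T, 0.0)


def mmd2(x, y, bandwidths=(0.1, 1.0, 10.0)):
    """Biased MMD^2 estimate with a sum of Gaussian kernels."""
    def k(a, b):
        d = pairwise_sq_dists(a, b)
        return sum(np.exp(-d / (2.0 * bw)) for bw in bandwidths)

    # mean() divides by n*n, m*m and n*m respectively
    return k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean()


rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(200, 2))
target_same = rng.normal(0.0, 1.0, size=(150, 2))   # same distribution, different n
target_shift = rng.normal(1.5, 1.0, size=(150, 2))  # mean-shifted target

print(f"MMD^2 same:    {mmd2(source, target_same):.4f}")
print(f"MMD^2 shifted: {mmd2(source, target_shift):.4f}")
```

The shifted target produces a clearly larger value than the same-distribution target, which is the signal the MMD loss pushes the encoder to reduce.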

losses.py

"""
losses.py — Loss functions for domain-adaptive anomaly detection.

Includes:
  - AnomalyDetectionLoss (BCE for anomaly classification)
  - ReconstructionLoss (MSE for autoencoder)
  - DomainAdversarialLoss (BCE for domain discrimination)
  - MMDLoss (Maximum Mean Discrepancy with Gaussian kernel)
  - CORALLoss (CORrelation ALignment)
  - CombinedLoss (weighted combination of all losses)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class AnomalyDetectionLoss(nn.Module):
    """Binary cross-entropy loss for anomaly classification."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, labels):
        """
        Args:
            logits: (batch, 1) raw anomaly logits
            labels: (batch,) binary labels (0=normal, 1=anomaly)
        """
        # BCEWithLogitsLoss requires float targets; cast in case the dataset yields ints
        return self.bce(logits.squeeze(-1), labels.float())


class ReconstructionLoss(nn.Module):
    """MSE loss between input and reconstruction."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, reconstruction, original):
        """
        Args:
            reconstruction: (batch, num_features, window_size)
            original: (batch, num_features, window_size)
        """
        return self.mse(reconstruction, original)


class DomainAdversarialLoss(nn.Module):
    """BCE loss for domain classification (used with GRL for DANN)."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, domain_logits, domain_labels):
        """
        Args:
            domain_logits: (batch, 1) raw domain logits
            domain_labels: (batch,) domain labels (0=source, 1=target)
        """
        return self.bce(domain_logits.squeeze(-1), domain_labels)


class MMDLoss(nn.Module):
    """
    Maximum Mean Discrepancy loss with multi-scale Gaussian kernel.

    Measures the distance between source and target feature distributions
    in a reproducing kernel Hilbert space (RKHS).
    """

    def __init__(self, kernel_bandwidths: list = None):
        super().__init__()
        if kernel_bandwidths is None:
            self.kernel_bandwidths = [0.01, 0.1, 1.0, 10.0, 100.0]
        else:
            self.kernel_bandwidths = kernel_bandwidths

    def gaussian_kernel(self, x, y):
        """
        Compute multi-scale Gaussian kernel matrices within and between x and y.

        Args:
            x: (n, d) tensor
            y: (m, d) tensor
        Returns:
            k_xx: (n, n), k_yy: (m, m), k_xy: (n, m) kernel matrices,
            each summed over all bandwidths
        """
        # Squared row norms
        x_sq = (x * x).sum(dim=1)  # (n,)
        y_sq = (y * y).sum(dim=1)  # (m,)

        # Pairwise squared distances: ||a - b||^2 = |a|^2 + |b|^2 - 2 a.b
        # Broadcasting keeps shapes correct even when n != m (e.g. a short
        # final batch); clamp removes small negative values from round-off.
        dxx = (x_sq[:, None] + x_sq[None, :] - 2.0 * x @ x.t()).clamp(min=0.0)
        dyy = (y_sq[:, None] + y_sq[None, :] - 2.0 * y @ y.t()).clamp(min=0.0)
        dxy = (x_sq[:, None] + y_sq[None, :] - 2.0 * x @ y.t()).clamp(min=0.0)

        k_xx = torch.zeros_like(dxx)
        k_yy = torch.zeros_like(dyy)
        k_xy = torch.zeros_like(dxy)

        for bw in self.kernel_bandwidths:
            k_xx = k_xx + torch.exp(-dxx / (2.0 * bw))
            k_yy = k_yy + torch.exp(-dyy / (2.0 * bw))
            k_xy = k_xy + torch.exp(-dxy / (2.0 * bw))

        return k_xx, k_yy, k_xy

    def forward(self, source_features, target_features):
        """
        Compute MMD^2 between source and target feature distributions.

        Args:
            source_features: (n, d) latent features from source domain
            target_features:  (m, d) latent features from target domain
        Returns:
            mmd_loss: scalar
        """
        n = source_features.size(0)
        m = target_features.size(0)

        k_xx, k_yy, k_xy = self.gaussian_kernel(source_features, target_features)

        mmd = (k_xx.sum() / (n * n)
               + k_yy.sum() / (m * m)
               - 2.0 * k_xy.sum() / (n * m))

        return mmd


class CORALLoss(nn.Module):
    """
    CORrelation ALignment loss.

    Aligns the second-order statistics (covariance matrices) of
    source and target feature distributions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, source_features, target_features):
        """
        Compute CORAL loss.

        Args:
            source_features: (n, d) latent features from source domain
            target_features:  (m, d) latent features from target domain
        Returns:
            coral_loss: scalar
        """
        d = source_features.size(1)
        n_s = source_features.size(0)
        n_t = target_features.size(0)

        # Compute covariance matrices
        source_centered = source_features - source_features.mean(dim=0, keepdim=True)
        target_centered = target_features - target_features.mean(dim=0, keepdim=True)

        cov_source = (source_centered.t() @ source_centered) / (n_s - 1)
        cov_target = (target_centered.t() @ target_centered) / (n_t - 1)

        # Frobenius norm of covariance difference
        diff = cov_source - cov_target
        coral_loss = (diff * diff).sum() / (4 * d * d)

        return coral_loss


class CombinedLoss(nn.Module):
    """
    Combines anomaly detection, reconstruction, and domain adaptation losses.

    total_loss = lambda_cls * anomaly_loss
               + lambda_recon * recon_loss
               + lambda_domain * domain_loss

    The domain_loss component uses DANN, MMD, or CORAL depending on config.
    """

    def __init__(self, config):
        super().__init__()
        self.anomaly_loss_fn = AnomalyDetectionLoss()
        self.recon_loss_fn = ReconstructionLoss()
        self.dann_loss_fn = DomainAdversarialLoss()
        self.mmd_loss_fn = MMDLoss(kernel_bandwidths=config.mmd_kernel_bandwidth)
        self.coral_loss_fn = CORALLoss()

        self.lambda_cls = config.lambda_cls
        self.lambda_recon = config.lambda_recon
        self.lambda_domain = config.lambda_domain
        self.method = config.adaptation_method

    def forward(
        self,
        anomaly_logits,
        anomaly_labels,
        reconstruction,
        original,
        domain_logits=None,
        domain_labels=None,
        source_features=None,
        target_features=None,
        current_lambda=None,
    ):
        """
        Compute combined loss.

        Args:
            anomaly_logits: (batch, 1) anomaly classification logits (source only)
            anomaly_labels: (batch,) anomaly labels (source only)
            reconstruction: (batch, num_features, window_size) reconstruction
            original: (batch, num_features, window_size) original input
            domain_logits: (batch, 1) domain logits (DANN only)
            domain_labels: (batch,) domain labels (DANN only)
            source_features: (n, d) source latent features (MMD/CORAL)
            target_features: (m, d) target latent features (MMD/CORAL)
            current_lambda: float — current domain adaptation weight

        Returns:
            total_loss, loss_dict (breakdown of individual losses)
        """
        domain_weight = current_lambda if current_lambda is not None else self.lambda_domain

        # Anomaly classification loss (source only)
        cls_loss = self.anomaly_loss_fn(anomaly_logits, anomaly_labels)

        # Reconstruction loss (both domains)
        recon_loss = self.recon_loss_fn(reconstruction, original)

        # Domain adaptation loss
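        if self.method == "dann" and domain_logits is not None:
            domain_loss = self.dann_loss_fn(domain_logits, domain_labels)
            # The GRL already scales the reversed gradient by the scheduled lambda.
            # Weighting the loss by lambda as well would give the encoder a lambda^2
            # signal and leave the discriminator untrained early on (lambda ~ 0),
            # so DANN uses unit weight here, as in Ganin et al.
            domain_weight = 1.0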
        elif self.method == "mmd" and source_features is not None:
            domain_loss = self.mmd_loss_fn(source_features, target_features)
        elif self.method == "coral" and source_features is not None:
            domain_loss = self.coral_loss_fn(source_features, target_features)
        else:
            domain_loss = torch.tensor(0.0, device=anomaly_logits.device)

        total_loss = (
            self.lambda_cls * cls_loss
            + self.lambda_recon * recon_loss
            + domain_weight * domain_loss
        )

        loss_dict = {
            "total": total_loss.item(),
            "classification": cls_loss.item(),
            "reconstruction": recon_loss.item(),
            "domain": domain_loss.item(),
        }

        return total_loss, loss_dict
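One property of CORAL worth checking is that it compares only covariance matrices, so it does not respond to a pure mean offset between domains, whereas MMD with a Gaussian kernel does. The NumPy sketch below applies the same formula as CORALLoss, ||C_s - C_t||_F^2 / (4 d^2), to synthetic features; coral is an illustrative helper name.

```python
import numpy as np


def coral(source, target):
    """CORAL loss: squared Frobenius distance between covariances, scaled by 1/(4 d^2)."""
    d = source.shape[1]
    cov_s = np.cov(source, rowvar=False)  # (n - 1) normalization, as in CORALLoss
    cov_t = np.cov(target, rowvar=False)
    diff = cov_s - cov_t
    return float((diff ** 2).sum() / (4 * d * d))


rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(500, 4))

shifted = source + 3.0  # mean offset only: identical covariance
scaled = source * 2.0   # variance change: covariance multiplied by 4

print(f"CORAL, mean shift: {coral(source, shifted):.6f}")
print(f"CORAL, scaling:    {coral(source, scaled):.6f}")
```

The mean-shifted copy gives a loss of essentially zero, while the rescaled copy does not. If the dominant difference between two domains is a constant sensor offset, the CORAL term alone will not penalize it.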

The Main Training Script

The training script integrates the entire system. Each iteration jointly optimizes the anomaly classifier on labeled source data, the reconstruction decoder on both domains, and the domain-adaptation term on both domains (adversarially, through the discriminator, when DANN is selected). The strength of domain adaptation increases progressively over training according to the schedule from the original DANN paper: λ_p = 2 / (1 + exp(-γ · p)) - 1, where p is training progress from 0 to 1 and γ controls how quickly λ_p approaches 1. In this script p is computed once per epoch as epoch / total_epochs.
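A quick way to see how fast this schedule ramps is to evaluate it directly. Note that 2 / (1 + exp(-x)) - 1 equals tanh(x / 2), so with γ = 10 the weight already exceeds 0.76 at 20% of training, and with per-epoch progress the first epoch runs with λ = 0. The sketch below reuses the formula from compute_dann_lambda; total_epochs = 50 and the per-iteration variant with steps_per_epoch are illustrative assumptions, not part of train.py.

```python
import numpy as np


def dann_lambda(p, gamma=10.0):
    """DANN schedule: 2 / (1 + exp(-gamma * p)) - 1, for progress p in [0, 1]."""
    return float(2.0 / (1.0 + np.exp(-gamma * p)) - 1.0)


total_epochs = 50  # illustrative value
# Per-epoch progress, as in train.py: p = epoch / total_epochs
for epoch in (0, 1, 5, 10, 25, 49):
    p = epoch / total_epochs
    print(f"epoch {epoch:2d}  p={p:.2f}  lambda={dann_lambda(p):.4f}")

# Hypothetical per-iteration progress gives a smoother ramp within each epoch
steps_per_epoch = 100
total_steps = total_epochs * steps_per_epoch
step = 150  # halfway through epoch 1
print(f"step {step}: lambda={dann_lambda(step / total_steps):.4f}")
```

Because most of the ramp happens early, a larger γ makes adaptation kick in sooner; a smaller γ keeps the encoder focused on the supervised source task for longer.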

train.py

"""
train.py — Main training script for domain-adaptive anomaly detection.

Supports three adaptation methods: DANN, MMD, CORAL.
Uses progressive lambda scheduling for stable training.
"""

import argparse
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from losses import CombinedLoss
from utils import (
    set_seed,
    EarlyStopping,
    save_checkpoint,
    MetricLogger,
)


def compute_dann_lambda(epoch: int, total_epochs: int, gamma: float = 10.0) -> float:
    """
    Progressive lambda schedule from the DANN paper (Ganin et al., 2016).
    Ramps from 0 to 1 over training using a sigmoid-like schedule.

    lambda_p = 2 / (1 + exp(-gamma * p)) - 1, where p = epoch / total_epochs
    """
    p = epoch / total_epochs
    return float(2.0 / (1.0 + np.exp(-gamma * p)) - 1.0)


def train_one_epoch(
    model,
    source_loader,
    target_loader,
    criterion,
    optimizer,
    device,
    epoch,
    total_epochs,
    config,
):
    """Train for one epoch with domain adaptation."""
    model.train()
    epoch_losses = {"total": 0, "classification": 0, "reconstruction": 0, "domain": 0}
    n_batches = 0

    # Compute current domain adaptation lambda
    current_lambda = compute_dann_lambda(epoch, total_epochs, config.gamma) * config.lambda_domain

    # Set the GRL lambda in the model
    model.set_domain_lambda(current_lambda)

    # Iterate over the source loader; restart the target iterator whenever it runs out
    target_iter = iter(target_loader)

    for source_batch, source_labels in source_loader:
        # Get target batch (cycle if exhausted)
        try:
            target_batch, _ = next(target_iter)
        except StopIteration:
            target_iter = iter(target_loader)
            target_batch, _ = next(target_iter)

        source_batch = source_batch.to(device)
        source_labels = source_labels.to(device)
        target_batch = target_batch.to(device)

        # Determine actual batch sizes (may differ)
        bs_s = source_batch.size(0)
        bs_t = target_batch.size(0)

        # Forward pass: source domain
        s_anomaly_logits, s_recon, s_domain_logits, s_latent = model(source_batch)

        # Forward pass: target domain
        t_anomaly_logits, t_recon, t_domain_logits, t_latent = model(target_batch)

        # Combine reconstructions and originals for loss
        all_recon = torch.cat([s_recon, t_recon], dim=0)
        all_original = torch.cat([source_batch, target_batch], dim=0)

        # Domain labels: 0 for source, 1 for target
        domain_labels = torch.cat([
            torch.zeros(bs_s, device=device),
            torch.ones(bs_t, device=device),
        ])
        all_domain_logits = torch.cat([s_domain_logits, t_domain_logits], dim=0)

        # Compute combined loss
        total_loss, loss_dict = criterion(
            anomaly_logits=s_anomaly_logits,
            anomaly_labels=source_labels,
            reconstruction=all_recon,
            original=all_original,
            domain_logits=all_domain_logits,
            domain_labels=domain_labels,
            source_features=s_latent,
            target_features=t_latent,
            current_lambda=current_lambda,
        )

        # Backprop
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Accumulate losses
        for key in epoch_losses:
            epoch_losses[key] += loss_dict[key]
        n_batches += 1

    # Average losses
    for key in epoch_losses:
        epoch_losses[key] /= max(n_batches, 1)

    epoch_losses["lambda"] = current_lambda
    return epoch_losses


@torch.no_grad()
def validate(model, loader, criterion, device, config):
    """Validate on a labeled dataset (source test or target test)."""
    model.eval()
    all_logits = []
    all_labels = []
    total_recon_loss = 0
    n_batches = 0

    for batch, labels in loader:
        batch = batch.to(device)
        labels = labels.to(device)

        anomaly_logits, recon, _, latent = model(batch)
        recon_loss = nn.MSELoss()(recon, batch)

        all_logits.append(anomaly_logits.squeeze(-1).cpu())
        all_labels.append(labels.cpu())
        total_recon_loss += recon_loss.item()
        n_batches += 1

    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)

    # Compute metrics
    probs = torch.sigmoid(all_logits)
    preds = (probs > 0.5).float()
    accuracy = (preds == all_labels).float().mean().item()

    from sklearn.metrics import roc_auc_score, f1_score
    try:
        auroc = roc_auc_score(all_labels.numpy(), probs.numpy())
    except ValueError:
        auroc = 0.5  # Only one class present
    f1 = f1_score(all_labels.numpy(), preds.numpy(), zero_division=0)

    return {
        "accuracy": accuracy,
        "auroc": auroc,
        "f1": f1,
        "recon_loss": total_recon_loss / max(n_batches, 1),
    }


def main():
    parser = argparse.ArgumentParser(description="Train domain-adaptive anomaly detector")
    parser.add_argument("--method", type=str, default="dann",
                        choices=["dann", "mmd", "coral"],
                        help="Domain adaptation method")
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--lambda_domain", type=float, default=None)
    parser.add_argument("--lambda_recon", type=float, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    # Build config with CLI overrides
    config = Config()
    config.adaptation_method = args.method
    if args.epochs is not None:
        config.epochs = args.epochs
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.lr is not None:
        config.learning_rate = args.lr
    if args.lambda_domain is not None:
        config.lambda_domain = args.lambda_domain
    if args.lambda_recon is not None:
        config.lambda_recon = args.lambda_recon
    if args.seed is not None:
        config.seed = args.seed
    if args.data_dir is not None:
        config.data_dir = args.data_dir
    if args.device is not None:
        config.device = args.device

    # Setup
    set_seed(config.seed)
    device = torch.device(config.device)
    print(f"Using device: {device}")
    print(f"Adaptation method: {config.adaptation_method}")
    print(f"Epochs: {config.epochs}, Batch size: {config.batch_size}, LR: {config.learning_rate}")

    # Data
    print("\nLoading data...")
    loaders, normalizer = create_data_loaders(config)
    print(f"Source train batches: {len(loaders['source_train'])}")
    print(f"Target train batches: {len(loaders['target_train'])}")

    # Model
    model = DomainAdaptiveAnomalyDetector(config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {total_params:,}")

    # Optimizer (single optimizer for simplicity; separate LRs via param groups)
    optimizer = Adam([
        {"params": model.encoder.parameters(), "lr": config.learning_rate},
        {"params": model.classifier.parameters(), "lr": config.learning_rate},
        {"params": model.decoder.parameters(), "lr": config.learning_rate},
        {"params": model.discriminator.parameters(), "lr": config.discriminator_lr},
    ], weight_decay=config.weight_decay)

    scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs, eta_min=1e-6)

    # Loss
    criterion = CombinedLoss(config)

    # Early stopping
    early_stopping = EarlyStopping(patience=config.patience, mode="max")

    # Logging
    logger = MetricLogger(config.results_dir)

    # Training loop
    best_target_auroc = 0.0
    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)

    for epoch in range(config.epochs):
        start_time = time.time()

        # Train
        train_losses = train_one_epoch(
            model, loaders["source_train"], loaders["target_train"],
            criterion, optimizer, device, epoch, config.epochs, config
        )

        # Validate on source test
        source_metrics = validate(model, loaders["source_test"], criterion, device, config)

        # Evaluate on target test (the real metric we care about)
        target_metrics = validate(model, loaders["target_test"], criterion, device, config)

        scheduler.step()

        elapsed = time.time() - start_time

        # Log
        logger.log(epoch, train_losses, source_metrics, target_metrics)

        # Print progress
        if epoch % 5 == 0 or epoch == config.epochs - 1:
            print(
                f"Epoch {epoch:3d}/{config.epochs} ({elapsed:.1f}s) | "
                f"Loss: {train_losses['total']:.4f} "
                f"[cls={train_losses['classification']:.4f}, "
                f"rec={train_losses['reconstruction']:.4f}, "
                f"dom={train_losses['domain']:.4f}] | "
                f"λ={train_losses['lambda']:.3f} | "
                f"Src AUROC: {source_metrics['auroc']:.4f} | "
                f"Tgt AUROC: {target_metrics['auroc']:.4f}"
            )

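        # Save best model (based on target AUROC). This uses target test labels,
        # which the synthetic benchmark provides; with genuinely unlabeled target
        # data this oracle selection is not available and another criterion is needed.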
        if target_metrics["auroc"] > best_target_auroc:
            best_target_auroc = target_metrics["auroc"]
            save_checkpoint(
                model, optimizer, epoch, target_metrics,
                os.path.join(config.checkpoint_dir, "best_model.pt")
            )

        # Early stopping on target AUROC
        if early_stopping.step(target_metrics["auroc"]):
            print(f"\nEarly stopping triggered at epoch {epoch}")
            break

    print("\n" + "=" * 60)
    print(f"Training complete. Best target AUROC: {best_target_auroc:.4f}")
    print(f"Best model saved to: {config.checkpoint_dir}/best_model.pt")
    print("=" * 60)

    # Save training curves
    logger.save()
    logger.plot_training_curves()


if __name__ == "__main__":
    main()
Tip: The metric of primary interest is the target AUROC, not the source AUROC. Source AUROC indicates only that the model can classify anomalies where labels are available, which is the expected baseline. Target AUROC reveals whether domain adaptation is actually transferring anomaly-detection knowledge to the unlabeled domain.
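Because AUROC does not depend on a decision threshold, it isolates the question of whether the model ranks target anomalies above target normals from the separate question of where to put the threshold. It equals the probability that a randomly chosen anomaly receives a higher score than a randomly chosen normal sample, with ties counted as one half. The small NumPy sketch below computes it that way; pairwise_auroc is an illustrative name, and its O(n · m) pairwise comparison is meant only for small sanity checks against sklearn.metrics.roc_auc_score.

```python
import numpy as np


def pairwise_auroc(labels, scores):
    """AUROC as P(score_anomaly > score_normal), counting ties as 0.5."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUROC needs at least one anomaly and one normal sample")
    diff = pos[:, None] - neg[None, :]  # (n_pos, n_neg) pairwise differences
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return float(wins / (len(pos) * len(neg)))


labels = np.array([0, 0, 0, 0, 1, 1])
scores = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.90])
print(pairwise_auroc(labels, scores))  # 0.875
```

Here one normal sample (0.80) outranks one anomaly (0.70), so 7 of the 8 anomaly/normal pairs are ordered correctly.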

Evaluation and Metrics

After training, rigorous evaluation on the target domain is required. The evaluation script computes standard anomaly-detection metrics, combines classifier and reconstruction scores, implements multiple threshold strategies, and produces diagnostic plots. This is the stage at which the success of domain adaptation can be assessed.
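The score fusion used in compute_anomaly_scores, alpha · classifier_prob + (1 - alpha) · min-max normalized reconstruction error, can be checked on toy arrays. The sketch below mirrors that formula and pairs it with a percentile threshold that flags the top fraction of scores. The percentile rule is a hypothetical label-free strategy shown for contrast, since an F1-optimal threshold needs target labels that are usually unavailable in deployment; the function names and contamination value are illustrative.

```python
import numpy as np


def fuse_scores(classifier_probs, recon_errors, alpha=0.7):
    """alpha * classifier probability + (1 - alpha) * min-max normalized recon error."""
    probs = np.asarray(classifier_probs, dtype=float)
    errs = np.asarray(recon_errors, dtype=float)
    span = errs.max() - errs.min()
    norm = (errs - errs.min()) / span if span > 1e-8 else np.zeros_like(errs)
    return alpha * probs + (1.0 - alpha) * norm


def percentile_threshold(scores, contamination=0.05):
    """Label-free threshold: flag the top `contamination` fraction of scores."""
    return float(np.quantile(scores, 1.0 - contamination))


probs = np.array([0.05, 0.10, 0.20, 0.90, 0.15])
errors = np.array([0.2, 0.3, 0.2, 1.2, 0.7])
scores = fuse_scores(probs, errors, alpha=0.7)
threshold = percentile_threshold(scores, contamination=0.2)
flags = scores >= threshold
print(np.round(scores, 3), round(threshold, 3), flags)
```

The fifth sample has a low classifier probability but a high reconstruction error, so fusion raises its score above the first three; only the fourth sample clears the threshold.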

evaluate.py

"""
evaluate.py — Evaluation script for domain-adaptive anomaly detection.

Loads a trained model and evaluates on target domain test data.
Computes AUROC, AUPRC, F1, precision, recall.
Generates diagnostic plots and saves results to JSON.
"""

import argparse
import json
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
    confusion_matrix,
    roc_curve,
    precision_recall_curve,
)

from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from utils import set_seed, load_checkpoint


def compute_anomaly_scores(model, loader, device, alpha=0.7):
    """
    Compute anomaly scores combining classifier output and reconstruction error.

    anomaly_score = alpha * classifier_prob + (1 - alpha) * normalized_recon_error

    Returns:
        scores: numpy array of anomaly scores
        labels: numpy array of ground truth labels
        recon_errors: numpy array of per-sample reconstruction errors
        classifier_probs: numpy array of classifier probabilities
        latent_features: numpy array of latent features (for t-SNE)
    """
    model.eval()
    all_probs = []
    all_labels = []
    all_recon_errors = []
    all_latent = []

    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(device)
            anomaly_logits, recon, _, latent = model(batch)

            # Classifier probability
            probs = torch.sigmoid(anomaly_logits.squeeze(-1))

            # Per-sample reconstruction error (mean across features and time)
            recon_error = ((recon - batch) ** 2).mean(dim=(1, 2))

            all_probs.append(probs.cpu().numpy())
            all_labels.append(labels.numpy())
            all_recon_errors.append(recon_error.cpu().numpy())
            all_latent.append(latent.cpu().numpy())

    all_probs = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)
    all_recon_errors = np.concatenate(all_recon_errors)
    all_latent = np.concatenate(all_latent)

    # Normalize reconstruction errors to [0, 1]
    re_min, re_max = all_recon_errors.min(), all_recon_errors.max()
    if re_max - re_min > 1e-8:
        norm_recon = (all_recon_errors - re_min) / (re_max - re_min)
    else:
        norm_recon = np.zeros_like(all_recon_errors)

    # Combined anomaly score
    scores = alpha * all_probs + (1 - alpha) * norm_recon

    return scores, all_labels, all_recon_errors, all_probs, all_latent


def find_optimal_threshold(labels, scores):
    """Find the threshold that maximizes F1 score."""
    thresholds = np.linspace(0, 1, 200)
    best_f1 = 0
    best_thresh = 0.5

    for thresh in thresholds:
        preds = (scores >= thresh).astype(int)
        f1 = f1_score(labels, preds, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = thresh

    return best_thresh, best_f1


def compute_all_metrics(labels, scores, threshold):
    """Compute all evaluation metrics at a given threshold."""
    preds = (scores >= threshold).astype(int)
    metrics = {
        "auroc": float(roc_auc_score(labels, scores)),
        "auprc": float(average_precision_score(labels, scores)),
        "f1": float(f1_score(labels, preds, zero_division=0)),
        "precision": float(precision_score(labels, preds, zero_division=0)),
        "recall": float(recall_score(labels, preds, zero_division=0)),
        "accuracy": float(accuracy_score(labels, preds)),
        "threshold": float(threshold),
    }

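    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent from labels and preds
    cm = confusion_matrix(labels, preds, labels=[0, 1])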
    metrics["confusion_matrix"] = cm.tolist()
    metrics["true_negatives"] = int(cm[0, 0])
    metrics["false_positives"] = int(cm[0, 1])
    metrics["false_negatives"] = int(cm[1, 0])
    metrics["true_positives"] = int(cm[1, 1])

    return metrics


def plot_roc_curve(labels, scores, save_path):
    """Plot and save ROC curve."""
    fpr, tpr, _ = roc_curve(labels, scores)
    auroc = roc_auc_score(labels, scores)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, "b-", linewidth=2, label=f"AUROC = {auroc:.4f}")
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Random")
    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)
    ax.set_title("ROC Curve — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"ROC curve saved to {save_path}")


def plot_pr_curve(labels, scores, save_path):
    """Plot and save Precision-Recall curve."""
    precision, recall, _ = precision_recall_curve(labels, scores)
    auprc = average_precision_score(labels, scores)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall, precision, "r-", linewidth=2, label=f"AUPRC = {auprc:.4f}")
    baseline = labels.sum() / len(labels)
    ax.axhline(y=baseline, color="k", linestyle="--", alpha=0.5, label=f"Baseline = {baseline:.3f}")
    ax.set_xlabel("Recall", fontsize=12)
    ax.set_ylabel("Precision", fontsize=12)
    ax.set_title("Precision-Recall Curve — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"PR curve saved to {save_path}")


def plot_score_distribution(labels, scores, threshold, save_path):
    """Plot anomaly score distribution for normal vs anomaly samples."""
    fig, ax = plt.subplots(figsize=(10, 6))

    normal_scores = scores[labels == 0]
    anomaly_scores = scores[labels == 1]

    ax.hist(normal_scores, bins=50, alpha=0.6, color="steelblue", label="Normal", density=True)
    ax.hist(anomaly_scores, bins=50, alpha=0.6, color="indianred", label="Anomaly", density=True)
    ax.axvline(x=threshold, color="black", linestyle="--", linewidth=2,
               label=f"Threshold = {threshold:.3f}")
    ax.set_xlabel("Anomaly Score", fontsize=12)
    ax.set_ylabel("Density", fontsize=12)
    ax.set_title("Anomaly Score Distribution — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"Score distribution saved to {save_path}")


def plot_reconstruction_error(recon_errors, labels, save_path):
    """Plot reconstruction error over sample index, colored by label."""
    fig, ax = plt.subplots(figsize=(14, 5))

    indices = np.arange(len(recon_errors))
    normal_mask = labels == 0
    anomaly_mask = labels == 1

    ax.scatter(indices[normal_mask], recon_errors[normal_mask],
               s=2, alpha=0.4, c="steelblue", label="Normal")
    ax.scatter(indices[anomaly_mask], recon_errors[anomaly_mask],
               s=8, alpha=0.8, c="indianred", label="Anomaly")
    ax.set_xlabel("Sample Index", fontsize=12)
    ax.set_ylabel("Reconstruction Error", fontsize=12)
    ax.set_title("Reconstruction Error Over Time — Target Domain", fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"Reconstruction error plot saved to {save_path}")


def main():
    parser = argparse.ArgumentParser(description="Evaluate domain-adaptive anomaly detector")
    parser.add_argument("--checkpoint", type=str,
                        default="checkpoints/best_model.pt",
                        help="Path to model checkpoint")
    parser.add_argument("--data_dir", type=str, default="data",
                        help="Data directory")
    parser.add_argument("--results_dir", type=str, default="results",
                        help="Output directory for results")
    parser.add_argument("--alpha", type=float, default=0.7,
                        help="Weight for classifier score vs recon error")
    parser.add_argument("--method", type=str, default="dann",
                        choices=["dann", "mmd", "coral"])
    parser.add_argument("--device", type=str, default="")
    args = parser.parse_args()

    config = Config()
    config.data_dir = args.data_dir
    config.results_dir = args.results_dir
    config.adaptation_method = args.method
    if args.device:
        config.device = args.device

    set_seed(config.seed)
    device = torch.device(config.device)
    os.makedirs(config.results_dir, exist_ok=True)

    print(f"Device: {device}")
    print(f"Loading checkpoint: {args.checkpoint}")

    # Load model
    model = DomainAdaptiveAnomalyDetector(config).to(device)
    checkpoint = load_checkpoint(args.checkpoint, model, device=device)
    print(f"Loaded model from epoch {checkpoint.get('epoch', '?')}")

    # Load data
    loaders, normalizer = create_data_loaders(config)

    # --- Evaluate on target test set ---
    print("\n--- Target Domain Evaluation ---")
    scores, labels, recon_errors, probs, latent_features = compute_anomaly_scores(
        model, loaders["target_test"], device, alpha=args.alpha
    )

    # Find optimal threshold
    optimal_thresh, optimal_f1 = find_optimal_threshold(labels, scores)
    print(f"Optimal threshold: {optimal_thresh:.4f} (F1 = {optimal_f1:.4f})")

    # Percentile-based threshold
    percentile_thresh = np.percentile(scores, config.anomaly_threshold_percentile)
    print(f"Percentile ({config.anomaly_threshold_percentile}%) threshold: {percentile_thresh:.4f}")

    # Compute metrics at optimal threshold
    metrics_optimal = compute_all_metrics(labels, scores, optimal_thresh)
    metrics_optimal["threshold_method"] = "f1_optimal"

    # Compute metrics at percentile threshold
    metrics_percentile = compute_all_metrics(labels, scores, percentile_thresh)
    metrics_percentile["threshold_method"] = "percentile"

    # Print results
    print(f"\n{'Metric':<20} {'F1-Optimal':>12} {'Percentile':>12}")
    print("-" * 46)
    for key in ["auroc", "auprc", "f1", "precision", "recall", "accuracy"]:
        print(f"{key:<20} {metrics_optimal[key]:>12.4f} {metrics_percentile[key]:>12.4f}")

    # Also evaluate on source test for comparison
    print("\n--- Source Domain Evaluation (baseline) ---")
    src_scores, src_labels, _, _, src_latent = compute_anomaly_scores(
        model, loaders["source_test"], device, alpha=args.alpha
    )
    src_thresh, _ = find_optimal_threshold(src_labels, src_scores)
    src_metrics = compute_all_metrics(src_labels, src_scores, src_thresh)
    print(f"Source AUROC: {src_metrics['auroc']:.4f}, F1: {src_metrics['f1']:.4f}")

    # --- Generate plots ---
    print("\nGenerating plots...")
    plot_roc_curve(labels, scores, os.path.join(config.results_dir, "roc_curve.png"))
    plot_pr_curve(labels, scores, os.path.join(config.results_dir, "pr_curve.png"))
    plot_score_distribution(labels, scores, optimal_thresh,
                           os.path.join(config.results_dir, "score_distribution.png"))
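    plot_reconstruction_error(recon_errors, labels,
                             os.path.join(config.results_dir, "recon_error.png"))
    # Visualize source vs target latent features to check domain overlap
    from utils import plot_tsne_features  # harmless if already imported at the top
    plot_tsne_features(src_latent, latent_features,
                       os.path.join(config.results_dir, "tsne_features.png"),
                       title="t-SNE of Latent Features: Source vs Target")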

    # --- Save results ---
    results = {
        "method": config.adaptation_method,
        "alpha": args.alpha,
        "target_metrics_optimal": metrics_optimal,
        "target_metrics_percentile": metrics_percentile,
        "source_metrics": src_metrics,
    }
    results_path = os.path.join(config.results_dir, "evaluation_results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {results_path}")


if __name__ == "__main__":
    main()

Utility Functions

The utility module handles reproducibility, early stopping, checkpointing, metric logging, and visualization, including t-SNE plots of feature distributions.

utils.py

"""
utils.py — Utility functions for the DA anomaly detection pipeline.

Includes:
  - Seed setting for reproducibility
  - EarlyStopping class
  - Checkpoint save/load
  - MetricLogger with JSON output and plotting
  - t-SNE visualization of domain features
"""

import os
import random
import json
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def set_seed(seed: int = 42):
    """Set random seeds for reproducibility across all libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class EarlyStopping:
    """
    Early stopping to halt training when a metric stops improving.

    Args:
        patience: number of epochs to wait before stopping
        mode: 'min' or 'max' — whether lower or higher is better
        min_delta: minimum improvement to count as progress
    """

    def __init__(self, patience: int = 15, mode: str = "max", min_delta: float = 1e-4):
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.counter = 0
        self.best_value = None

    def step(self, value: float) -> bool:
        """
        Check if training should stop.

        Args:
            value: current metric value
        Returns:
            True if training should stop
        """
        if self.best_value is None:
            self.best_value = value
            return False

        if self.mode == "max":
            improved = value > self.best_value + self.min_delta
        else:
            improved = value < self.best_value - self.min_delta

        if improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1

        return self.counter >= self.patience


def save_checkpoint(model, optimizer, epoch, metrics, filepath):
    """Save model checkpoint."""
    dirname = os.path.dirname(filepath)
    if dirname:  # os.makedirs("") raises when filepath has no directory part
        os.makedirs(dirname, exist_ok=True)
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "metrics": metrics,
    }, filepath)


def load_checkpoint(filepath, model, optimizer=None, device="cpu"):
    """Load model checkpoint."""
    checkpoint = torch.load(filepath, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint


class MetricLogger:
    """
    Logs training metrics to memory and saves them to JSON.
    Also generates training curve plots.
    """

    def __init__(self, output_dir: str = "results"):
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.history = {
            "epoch": [],
            "train_total_loss": [],
            "train_cls_loss": [],
            "train_recon_loss": [],
            "train_domain_loss": [],
            "train_lambda": [],
            "source_auroc": [],
            "source_f1": [],
            "target_auroc": [],
            "target_f1": [],
        }

    def log(self, epoch, train_losses, source_metrics, target_metrics):
        """Record one epoch of metrics."""
        self.history["epoch"].append(epoch)
        self.history["train_total_loss"].append(train_losses["total"])
        self.history["train_cls_loss"].append(train_losses["classification"])
        self.history["train_recon_loss"].append(train_losses["reconstruction"])
        self.history["train_domain_loss"].append(train_losses["domain"])
        self.history["train_lambda"].append(train_losses.get("lambda", 0))
        self.history["source_auroc"].append(source_metrics["auroc"])
        self.history["source_f1"].append(source_metrics["f1"])
        self.history["target_auroc"].append(target_metrics["auroc"])
        self.history["target_f1"].append(target_metrics["f1"])

    def save(self):
        """Save metrics history to JSON."""
        path = os.path.join(self.output_dir, "training_history.json")
        with open(path, "w") as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved to {path}")

    def plot_training_curves(self):
        """Generate and save training curve plots."""
        epochs = self.history["epoch"]

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Loss curves
        ax = axes[0, 0]
        ax.plot(epochs, self.history["train_total_loss"], label="Total", linewidth=2)
        ax.plot(epochs, self.history["train_cls_loss"], label="Classification", linewidth=1.5)
        ax.plot(epochs, self.history["train_recon_loss"], label="Reconstruction", linewidth=1.5)
        ax.plot(epochs, self.history["train_domain_loss"], label="Domain", linewidth=1.5)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training Losses")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # AUROC
        ax = axes[0, 1]
        ax.plot(epochs, self.history["source_auroc"], label="Source AUROC", linewidth=2)
        ax.plot(epochs, self.history["target_auroc"], label="Target AUROC", linewidth=2)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("AUROC")
        ax.set_title("AUROC Over Training")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # F1
        ax = axes[1, 0]
        ax.plot(epochs, self.history["source_f1"], label="Source F1", linewidth=2)
        ax.plot(epochs, self.history["target_f1"], label="Target F1", linewidth=2)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("F1 Score")
        ax.set_title("F1 Score Over Training")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Lambda schedule
        ax = axes[1, 1]
        ax.plot(epochs, self.history["train_lambda"], label="Domain λ", linewidth=2,
                color="purple")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Lambda Value")
        ax.set_title("Domain Adaptation Lambda Schedule")
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        path = os.path.join(self.output_dir, "training_curves.png")
        fig.savefig(path, dpi=150)
        plt.close(fig)
        print(f"Training curves saved to {path}")


def plot_tsne_features(
    source_features: np.ndarray,
    target_features: np.ndarray,
    save_path: str,
    title: str = "t-SNE Feature Visualization",
    max_samples: int = 2000,
):
    """
    Create t-SNE plot showing source vs target feature distributions.

    Args:
        source_features: (n, d) source latent features
        target_features: (m, d) target latent features
        save_path: path to save the plot
        title: plot title
        max_samples: max samples per domain (for speed)
    """
    from sklearn.manifold import TSNE

    # Subsample if needed
    if len(source_features) > max_samples:
        idx = np.random.choice(len(source_features), max_samples, replace=False)
        source_features = source_features[idx]
    if len(target_features) > max_samples:
        idx = np.random.choice(len(target_features), max_samples, replace=False)
        target_features = target_features[idx]

    # Combine and run t-SNE
    combined = np.concatenate([source_features, target_features], axis=0)
    n_source = len(source_features)

    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    embedded = tsne.fit_transform(combined)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(embedded[:n_source, 0], embedded[:n_source, 1],
               s=10, alpha=0.5, c="steelblue", label="Source")
    ax.scatter(embedded[n_source:, 0], embedded[n_source:, 1],
               s=10, alpha=0.5, c="indianred", label="Target")
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"t-SNE plot saved to {save_path}")

Running the Full Pipeline

With all nine scripts in place, the complete workflow from data generation to final evaluation proceeds as follows. The commands below should be executed in order from the da-anomaly-detection/ directory.

Step-by-Step Commands

# Step 1: Install dependencies
pip install -r requirements.txt

# Step 2: Generate synthetic two-domain data
python generate_synthetic_data.py --output_dir data/ --n_samples 20000

# Step 3: Train with DANN (Domain-Adversarial Neural Network)
python train.py --method dann --epochs 100 --batch_size 64 --lr 0.001

# Step 4: Evaluate on target domain
python evaluate.py --checkpoint checkpoints/best_model.pt --data_dir data/ --method dann

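# (Optional) Step 5: Train with MMD instead, then evaluate it
python train.py --method mmd --epochs 100 --batch_size 64
python evaluate.py --checkpoint checkpoints/best_model.pt --data_dir data/ --method mmd --results_dir results/mmd

# (Optional) Step 6: Train with CORAL instead, then evaluate it
python train.py --method coral --epochs 100 --batch_size 64
python evaluate.py --checkpoint checkpoints/best_model.pt --data_dir data/ --method coral --results_dir results/coral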

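Each training run reports progress every five epochs, saves the best model checkpoint based on target-domain AUROC, and writes training curves to the results/ directory. Selecting checkpoints by target AUROC relies on target labels, which the synthetic data provides but a genuinely unlabeled target domain does not; in that case, select on a small labeled target validation set or on a source-side criterion instead. The evaluation script generates ROC curves, PR curves, score distribution histograms, and reconstruction-error plots.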
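The optional steps can also be scripted. The sketch below is a hypothetical driver (not one of the nine scripts) that trains and evaluates CORAL, MMD, and DANN in turn, writes each evaluation to its own results/<method>/ folder, and collects the target AUROC from evaluation_results.json. It assumes, as Step 4 implies, that train.py saves to checkpoints/best_model.pt, so each method is evaluated before the next training run overwrites that file. It passes only flags that appear in the commands above and in evaluate.py's argument parser.

```python
"""Hypothetical driver: train and evaluate each DA method in turn.

Assumes train.py writes checkpoints/best_model.pt (as Step 4 implies) and
uses only the flags shown in the commands above.
"""
import json
import os
import subprocess
import sys

METHODS = ["coral", "mmd", "dann"]  # simplest method first


def build_commands(method, epochs=100, batch_size=64, data_dir="data"):
    """Return the train argv, evaluate argv, and results folder for one method."""
    results_dir = os.path.join("results", method)  # keep each method's outputs apart
    train_cmd = [sys.executable, "train.py", "--method", method,
                 "--epochs", str(epochs), "--batch_size", str(batch_size)]
    eval_cmd = [sys.executable, "evaluate.py", "--method", method,
                "--checkpoint", "checkpoints/best_model.pt",
                "--data_dir", data_dir, "--results_dir", results_dir]
    return train_cmd, eval_cmd, results_dir


def run_all(methods=METHODS):
    """Train, evaluate, and collect the target AUROC for each method."""
    summary = {}
    for method in methods:
        train_cmd, eval_cmd, results_dir = build_commands(method)
        subprocess.run(train_cmd, check=True)
        # Evaluate right away, before the next run replaces best_model.pt
        subprocess.run(eval_cmd, check=True)
        with open(os.path.join(results_dir, "evaluation_results.json")) as f:
            results = json.load(f)
        summary[method] = results["target_metrics_optimal"]["auroc"]
    return summary
```

Calling run_all() from the da-anomaly-detection/ directory returns a dictionary mapping each method name to its target-domain AUROC at the F1-optimal threshold.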

Figure: Domain Adaptation Implementation Pipeline. Labeled source sensor time series and an unlabeled target sensor stream pass through a CNN-LSTM encoder (feature extraction), are aligned with DANN, MMD, or CORAL, and feed an anomaly detector (classifier plus reconstruction score) that produces alerts and results (AUROC, F1, plots).

Understanding the Results

Once the pipeline has been executed, a results/evaluation_results.json file contains the numerical outputs. Interpreting those numbers and determining whether domain adaptation is actually helping requires familiarity with the relevant metrics.

Interpreting the Evaluation Metrics

AUROC (Area Under the ROC Curve) is the primary metric. It expresses the probability that a randomly chosen anomaly scores higher than a randomly chosen normal sample. An AUROC of 0.5 corresponds to random performance and 1.0 to perfect discrimination. For domain adaptation to be regarded as successful, the target-domain AUROC should be significantly higher than the no-adaptation baseline (training only on source data and evaluating on target data without adaptation).
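The probabilistic definition can be checked directly. This sketch compares every anomaly score against every normal score, counting ties as one half, and returns the same value as scikit-learn's roc_auc_score. pairwise_auroc is a hypothetical helper written for illustration, and the synthetic scores are not pipeline output. Its memory grows with the number of anomalies times the number of normal samples, so roc_auc_score remains the practical choice for real evaluation.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def pairwise_auroc(labels, scores):
    """AUROC as P(random anomaly scores higher than random normal); ties count 1/2."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    # Compare every anomaly against every normal sample
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))


# Synthetic example: 500 normal and 50 anomalous scores
rng = np.random.default_rng(0)
labels = np.r_[np.zeros(500), np.ones(50)]
scores = np.r_[rng.normal(0.0, 1.0, 500), rng.normal(1.5, 1.0, 50)]
print(pairwise_auroc(labels, scores), roc_auc_score(labels, scores))  # the two values agree
```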

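AUPRC (Area Under the Precision-Recall Curve) is more informative when anomalies are rare. In a highly imbalanced dataset with a 1 percent anomaly rate, AUROC can look strong even when false alarms far outnumber true detections, because the false positive rate is normalized by the large number of normal samples: a 5 percent false positive rate on 9,900 normal windows means about 495 false alarms against at most 100 true anomalies. AUPRC measures precision directly, so it penalizes those false positives much more strongly.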
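A small simulation makes the imbalance effect concrete. The scores below are synthetic (Gaussian normal and anomaly scores at a 1 percent anomaly rate), chosen only to illustrate the point: AUROC comes out high while AUPRC is much lower, and at a threshold that gives a 5 percent false positive rate the false alarms outnumber the true detections several times over.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

# Synthetic scores with a 1 percent anomaly rate (illustrative, not pipeline output)
rng = np.random.default_rng(42)
n_normal, n_anomaly = 9900, 100
labels = np.r_[np.zeros(n_normal), np.ones(n_anomaly)]
scores = np.r_[rng.normal(0.0, 1.0, n_normal), rng.normal(2.5, 1.0, n_anomaly)]

auroc = roc_auc_score(labels, scores)
auprc = average_precision_score(labels, scores)

# Threshold that flags 5 percent of normal samples (5 percent false positive rate)
threshold = np.quantile(scores[labels == 0], 0.95)
pred = scores >= threshold
false_pos = int(np.sum(pred & (labels == 0)))
true_pos = int(np.sum(pred & (labels == 1)))

print(f"AUROC={auroc:.3f}  AUPRC={auprc:.3f}")
print(f"At 5% FPR: {true_pos} true alarms vs {false_pos} false alarms")
```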

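F1 Score is the harmonic mean of precision and recall at a chosen threshold; evaluate.py reports it at both the F1-optimal threshold and the percentile threshold. It provides a single value that balances false positives and false negatives. Because the F1-optimal threshold is selected using the test labels, the resulting F1 should be treated as an optimistic upper bound; the percentile threshold is closer to what a label-free deployment can achieve. For industrial applications, recall (not missing anomalies) is typically prioritized over precision, since some false alarms are tolerable.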
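The F1-optimal threshold search can be sketched with scikit-learn's precision_recall_curve. f1_optimal_threshold is a hypothetical stand-in; evaluate.py's own find_optimal_threshold is not shown in this section and may search thresholds differently. Predictions are taken as score >= threshold, matching scikit-learn's convention.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def f1_optimal_threshold(labels, scores):
    """Hypothetical stand-in for find_optimal_threshold: return (threshold, best F1)."""
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # precision/recall have one more entry than thresholds: the final (P=1, R=0) point
    p, r = precision[:-1], recall[:-1]
    f1 = 2 * p * r / np.clip(p + r, 1e-12, None)  # guard against 0/0
    best = int(np.argmax(f1))
    return float(thresholds[best]), float(f1[best])
```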

What Good vs. Bad Domain Adaptation Looks Like

| Scenario | Source AUROC | Target AUROC (no adapt) | Target AUROC (with DA) | Interpretation |
|---|---|---|---|---|
| Successful adaptation | 0.95 | 0.62 | 0.87 | Domain adaptation recovered most performance |
| Negative transfer | 0.95 | 0.65 | 0.58 | DA made things worse; domains may be too different |
| No domain shift | 0.93 | 0.91 | 0.92 | Little domain shift exists; DA not needed |
| Partial adaptation | 0.95 | 0.55 | 0.72 | DA helps but gap remains; try tuning or more target data |

Figure: Detection Accuracy, Before vs. After Domain Adaptation. Bar chart of AUROC (source, no adaptation, with DA) for the successful-adaptation, negative-transfer, and partial-adaptation scenarios in the table above.

Interpreting t-SNE Plots

The t-SNE visualization is the most intuitive diagnostic tool available. It should be applied to the latent features before and after domain adaptation.

  • Before adaptation: Two distinct clusters typically appear, with source samples grouped in one region and target samples in another. This visual separation confirms that domain shift exists in the data.
  • After successful adaptation: The source and target clusters overlap substantially. The encoder has learned features that appear consistent regardless of which domain produced the input. If the anomaly classifier performs well on source features, it should now perform well on the overlapping target features as well.
  • After failed adaptation: Clusters remain separated, or in more severe cases the features of both domains collapse into a single tight cluster, which indicates that the encoder achieved domain invariance by discarding the information the anomaly classifier needs.
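t-SNE plots are qualitative and sensitive to perplexity, so a numeric check of overlap can be useful alongside them. One option, similar in spirit to the proxy A-distance used in the domain-adaptation literature, is to train a simple classifier to tell source features from target features: cross-validated accuracy near 0.5 suggests overlapping domains, and accuracy near 1.0 suggests they remain separable. The helper below is a hypothetical sketch using a linear probe (a nonlinear probe would be a stricter test); its inputs would be latent features such as those returned by compute_anomaly_scores in evaluate.py.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def domain_separability(source_features, target_features, folds=5):
    """Cross-validated accuracy of a linear source-vs-target classifier.

    About 0.5 means the domains overlap in feature space; about 1.0 means
    they are still linearly separable.
    """
    X = np.concatenate([source_features, target_features], axis=0)
    y = np.r_[np.zeros(len(source_features)), np.ones(len(target_features))]
    clf = LogisticRegression(max_iter=1000)
    return float(cross_val_score(clf, X, y, cv=folds, scoring="accuracy").mean())
```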

When to Use DANN, MMD, or CORAL

| Method | Mechanism | Strengths | Weaknesses | Best For |
|---|---|---|---|---|
| DANN | Adversarial training via a gradient reversal layer (GRL) | Powerful; learns complex alignment | Unstable training; sensitive to hyperparameters | Large domain shifts; ample training data |
| MMD | Kernel-based distribution matching | Stable training; mathematically principled | Cost grows quadratically with batch size; kernel bandwidth choice matters | Moderate domain shifts; limited compute |
| CORAL | Covariance matrix alignment | Simple; fast; no kernel or adversarial hyperparameters (only the loss weight) | Matches only second-order statistics | Small domain shifts; quick baseline |

Tip: Begin with CORAL, which is the simplest and fastest method, to establish a baseline. If the resulting gap remains too large, proceed to MMD. Where maximum performance is required and some training instability is acceptable, use DANN with careful lambda scheduling.
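For reference, the three alignment mechanisms reduce to short PyTorch functions. This is a minimal sketch that may differ in details from the guide's losses.py: a gradient reversal layer for DANN, the Deep CORAL loss of Sun and Saenko (squared Frobenius distance between source and target feature covariances, scaled by 1/(4d²)), and a biased squared-MMD estimate with a sum of Gaussian kernels. The bandwidths tuple is an assumed default; in practice it should match the scale of the features.

```python
import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambd in backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed for x, None for lambd
        return -ctx.lambd * grad_output, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)


def coral_loss(source, target):
    """Deep CORAL: ||C_s - C_t||_F^2 / (4 d^2) for (n, d) feature batches."""
    d = source.size(1)

    def covariance(x):
        xc = x - x.mean(dim=0, keepdim=True)
        return xc.t() @ xc / (x.size(0) - 1)

    diff = covariance(source) - covariance(target)
    return (diff * diff).sum() / (4 * d * d)


def mmd_loss(source, target, bandwidths=(0.5, 1.0, 2.0, 4.0)):
    """Biased squared MMD with a sum of Gaussian (RBF) kernels."""

    def kernel(a, b):
        # Pairwise squared distances without a square root, shape (n, m)
        sq_dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
        return sum(torch.exp(-sq_dist / (2 * s * s)) for s in bandwidths)

    mmd2 = (kernel(source, source).mean() + kernel(target, target).mean()
            - 2 * kernel(source, target).mean())
    return mmd2.clamp_min(0.0)  # guard against tiny negative values from rounding
```

In a DANN step, the encoder output passes through grad_reverse(features, lambd) before the domain head, so minimizing the domain classification loss trains the discriminator while pushing the encoder toward features the discriminator cannot separate. The MMD sketch computes squared distances directly, which avoids differentiating through a square root at zero distance.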

Adapting to Custom Data

The synthetic dataset serves only as a sandbox. The following steps describe how to integrate proprietary time-series data with minimal code changes.

Modifying dataset.py for a Specific Data Format

The CSV files should follow this structure: each row corresponds to a timestep, and each column other than label and timestamp corresponds to a sensor channel. The column names are unimportant as long as label and timestamp are named correctly or absent entirely. For data that uses a different format, the load_csv_data() function can be modified as follows.

import numpy as np
import pandas as pd

# Example: your data has columns named 'temp_1', 'temp_2', 'vibration_x', etc.
# and uses 'anomaly' instead of 'label'
def load_csv_data(filepath, has_labels=True):
    df = pd.read_csv(filepath)
    # Non-sensor columns to drop; adjust this list to match your files
    exclude = ["anomaly", "timestamp", "machine_id", "date"]
    feature_cols = [c for c in df.columns if c not in exclude]
    data = df[feature_cols].to_numpy(dtype=np.float32)  # (timesteps, channels)
    labels = df["anomaly"].to_numpy(dtype=np.float32) if has_labels else None
    return data, labels

Adjusting Model Dimensions

For data with a different number of channels, only num_features in config.py needs to change. The model adjusts automatically. For different sampling rates, the window_size should be adjusted; as a rule of thumb, the window should span roughly one cycle of the normal operating pattern. For a machine cycling every 5 seconds sampled at 100 Hz, window_size=500 is appropriate. For slow processes such as daily patterns at hourly sampling, window_size=24 is appropriate.
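The rule of thumb is simple arithmetic: window_size is the cycle duration multiplied by the sampling rate, so 5 s at 100 Hz gives 500 samples and a 24 h cycle at one sample per hour gives 24. The sketch below also slices a (timesteps, channels) array into overlapping windows. Both helpers are hypothetical, and the window-labeling rule (a window is anomalous if any timestep in it is anomalous) is an assumption that may not match dataset.py.

```python
import numpy as np


def window_size_for(cycle_seconds, sample_rate_hz):
    """Number of samples covering one operating cycle."""
    return max(1, int(round(cycle_seconds * sample_rate_hz)))


def make_windows(data, labels, window_size, stride=1):
    """Slice a (timesteps, channels) array into overlapping windows.

    Assumption for illustration: a window is labeled anomalous if any of
    its timesteps is anomalous.
    """
    if len(data) < window_size:
        raise ValueError("series is shorter than one window")
    starts = range(0, len(data) - window_size + 1, stride)
    windows = np.stack([data[s:s + window_size] for s in starts])
    window_labels = None
    if labels is not None:
        window_labels = np.array([labels[s:s + window_size].max() for s in starts],
                                 dtype=np.float32)
    return windows, window_labels
```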

Handling Class Imbalance

Real anomaly data is heavily imbalanced, often with anomaly rates of 1 percent or less. Three strategies are effective within this codebase.

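  1. Weighted BCE loss: Replace BCEWithLogitsLoss() with BCEWithLogitsLoss(pos_weight=torch.tensor([19.0])), where 19.0 is the ratio of normal to anomaly samples (19.0 corresponds to a 5 percent anomaly rate; at 1 percent the ratio would be 99.0). The pos_weight tensor must be on the same device as the model.
  2. Focal loss: Down-weights easy, well-classified examples, which in imbalanced data are mostly normal windows. Replace the BCE in AnomalyDetectionLoss.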
  3. Oversampling: Use PyTorch’s WeightedRandomSampler to oversample anomaly windows in the source training loader.
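The three strategies can be sketched in a few lines of PyTorch. pos_weight_from_labels, focal_loss, and make_balanced_sampler are hypothetical helpers, not functions from this codebase. The focal loss is the standard binary form, (1 - p_t)^gamma times BCE with an optional alpha class weight, written on logits so it could stand in for BCEWithLogitsLoss inside AnomalyDetectionLoss.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler


def pos_weight_from_labels(labels):
    """pos_weight for BCEWithLogitsLoss: number of normal / number of anomaly samples."""
    labels = torch.as_tensor(labels, dtype=torch.float32)
    n_pos = labels.sum()
    return ((labels.numel() - n_pos) / n_pos.clamp_min(1.0)).reshape(1)


def focal_loss(logits, targets, gamma=2.0, alpha=None):
    """Binary focal loss on logits; gamma=0 and alpha=None reduce it to plain BCE."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigns to the true class
    loss = (1.0 - p_t) ** gamma * bce
    if alpha is not None:
        # alpha weights anomalies, (1 - alpha) weights normal samples
        loss = torch.where(targets == 1, alpha * loss, (1.0 - alpha) * loss)
    return loss.mean()


def make_balanced_sampler(labels):
    """Sampler that draws anomalies and normal windows equally often on average."""
    labels = torch.as_tensor(labels, dtype=torch.long)
    class_counts = torch.bincount(labels, minlength=2).clamp_min(1)
    sample_weights = (1.0 / class_counts.double())[labels]
    return WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
```

The sampler would be passed as DataLoader(source_train_dataset, batch_size=64, sampler=make_balanced_sampler(window_labels)), where source_train_dataset and window_labels stand for the source training dataset and its per-window labels; shuffle must be left unset when a sampler is given.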

Hyperparameter Tuning Guide

The hyperparameters below are ordered by sensitivity, with the most sensitive listed first.

  1. lambda_domain (0.1–2.0): The most sensitive parameter. Excessively high values cause the encoder to learn domain-invariant features that are uninformative for anomaly detection. Excessively low values prevent any adaptation. A value of 0.5 is a reasonable starting point.
  2. learning_rate (1e-4–1e-2): Standard neural-network tuning. Cosine annealing is recommended.
  3. window_size (32–256): Should capture sufficient context for anomalies to be visible.
  4. latent_dim (64–256): Higher values provide more capacity but increase the risk of overfitting.
  5. alpha (0.5–0.9): Controls the mixture used in anomaly scoring. Higher values place more weight on the classifier output; lower values emphasize reconstruction error.
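Two of these knobs are easy to pin down in code. The progressive schedule below is the one from Ganin et al. (2016), lambda_p = 2 / (1 + exp(-gamma * p)) - 1 with gamma = 10 and p the fraction of training completed; lambda_max is a hypothetical scaling knob (for example, set to lambda_domain), and train.py's own schedule may differ. The scoring helper assumes the classifier probability and a min-max scaled reconstruction error are mixed linearly by alpha; how compute_anomaly_scores in evaluate.py actually normalizes the two terms is not shown in this section.

```python
import math

import numpy as np


def dann_lambda(progress, gamma=10.0, lambda_max=1.0):
    """Progressive domain weight: lambda_max * (2 / (1 + exp(-gamma * p)) - 1).

    progress = current_step / total_steps, clipped to [0, 1].
    """
    p = min(max(progress, 0.0), 1.0)
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)


def combined_anomaly_score(probs, recon_errors, alpha=0.7):
    """alpha * classifier probability + (1 - alpha) * min-max scaled recon error.

    Assumption: both terms are put on a [0, 1] scale before mixing.
    """
    probs = np.asarray(probs, dtype=float)
    recon = np.asarray(recon_errors, dtype=float)
    span = recon.max() - recon.min()
    recon_scaled = (recon - recon.min()) / span if span > 0 else np.zeros_like(recon)
    return alpha * probs + (1.0 - alpha) * recon_scaled


# Example: lambda ramps from 0 toward lambda_max=0.5 over 100 epochs
schedule = [round(dann_lambda(epoch / 100, lambda_max=0.5), 3) for epoch in (0, 10, 25, 50, 100)]
```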

Common Issues and Solutions

Domain adaptation training is known to be sensitive to configuration choices. The reference table below lists problems that practitioners frequently encounter and the corresponding remedies.

| Problem | Symptom | Cause | Solution |
|---|---|---|---|
| Discriminator mode collapse | Domain loss stuck at ~0.69 (ln 2) from the first epochs (late in training, ln 2 is the expected equilibrium) | Discriminator outputs 0.5 for everything | Increase discriminator LR; add more layers; reduce GRL lambda |
| Training instability | Loss oscillates wildly or diverges | Lambda too high too early | Use a progressive lambda schedule; reduce learning rate; clip gradients more tightly (lower max norm) |
| Negative transfer | Target AUROC decreases with DA | Domains are too different or share no useful structure | Reduce lambda_domain; try CORAL (less aggressive); verify that domains share anomaly types |
| High false positive rate | Good recall but poor precision | Threshold too low; noisy reconstruction error | Increase alpha (trust classifier more); use percentile threshold; smooth the reconstruction error |
| Source AUROC drops during DA | Classification degrades on source | Domain-invariant features lose discriminative power | Increase lambda_cls; reduce lambda_domain; train classifier longer before starting DA |
| Out of memory (GPU) | CUDA OOM error | Batch size or model too large | Reduce batch_size; reduce latent_dim; use gradient accumulation |
| MMD loss is NaN | NaN in training | Kernel bandwidth mismatch with feature scale | Normalize features; adjust kernel_bandwidths in config; add epsilon to kernel computation |

Caution: Domain adaptation assumes that the source and target domains share the same anomaly types and differ only in feature distributions. When the target domain exhibits fundamentally different anomaly mechanisms (not merely different sensor characteristics), domain adaptation will not help, and at least some labeled target data is required through semi-supervised adaptation.
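Two remedies from the table above, tighter gradient clipping and gradient accumulation, fit naturally in one training loop. The helper below is a hypothetical sketch on a generic model and loss, not the loop in train.py, which combines several losses. Dividing each loss by accum_steps makes the accumulated gradient equal to the average over accum_steps mini-batches, so the effective batch size is batch_size times accum_steps while activation memory is only that of one mini-batch.

```python
import torch
from torch import nn


def train_epoch_accumulated(model, batches, optimizer, loss_fn,
                            accum_steps=4, max_norm=1.0):
    """One epoch with gradient accumulation and gradient-norm clipping.

    Returns the total gradient norm measured before each clipping step.
    """
    model.train()
    optimizer.zero_grad()
    grad_norms = []
    for step, (x, y) in enumerate(batches, start=1):
        # Scale so the accumulated gradient averages over accum_steps batches
        loss = loss_fn(model(x), y) / accum_steps
        loss.backward()
        if step % accum_steps == 0:
            # A lower max_norm means stronger clipping
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            grad_norms.append(float(norm))
            optimizer.step()
            optimizer.zero_grad()
    optimizer.zero_grad()  # discard gradients from a trailing partial group
    return grad_norms
```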

Putting It Together

The preceding sections constitute a complete, end-to-end implementation of domain-adaptive time-series anomaly detection. A brief recapitulation and discussion of next steps follow.

The nine scripts cover the full pipeline: generating realistic synthetic data with domain shift, constructing a CNN-LSTM encoder with multi-head outputs, implementing three domain-adaptation strategies (DANN, MMD, and CORAL), training with progressive lambda scheduling, and evaluating with comprehensive metrics and diagnostic plots. Every script is complete and runnable as written.

The central insight is straightforward but consequential. Rather than requiring expensive labeled data in each new domain, a model can be trained to learn domain-invariant features: representations that capture the essence of “anomaly” regardless of which machine, factory, or sensor produced the signal. The Gradient Reversal Layer is the elegant mechanism that enables this adversarial training within a single unified model, while MMD and CORAL provide simpler and more stable alternatives.

Three directions are particularly promising for further development. First, semi-supervised adaptation: when even 5 to 10 percent of the target-domain data can be labeled, a supervised loss on those labeled target samples can be added alongside the unsupervised domain alignment, which can improve performance substantially. Second, multi-source adaptation: when data are available from machines A, B, and C, adaptation to machine D can combine knowledge from all three sources rather than only one. Third, continual adaptation: in production, the target domain drifts over time as machines age and wear; periodic or online re-adaptation keeps the model current.

Domain adaptation is not a universal solution. It performs best when domains share the same underlying anomaly mechanisms but differ in superficial signal characteristics, which is the prevailing scenario in industrial settings. When it succeeds, it can save months of labeling effort and accelerate the deployment of anomaly detection to new equipment. The code provided in this guide contains everything needed to begin experimenting with proprietary data immediately.

References

  1. Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. (2016). “Domain-Adversarial Training of Neural Networks.” Journal of Machine Learning Research, 17(59), 1-35.
  2. Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., and Smola, A. J. (2012). “A Kernel Two-Sample Test.” Journal of Machine Learning Research, 13, 723-773.
  3. Sun, B. and Saenko, K. (2016). “Deep CORAL: Correlation Alignment for Deep Domain Adaptation.” Proceedings of the European Conference on Computer Vision (ECCV) Workshops.
  4. Ragab, M., Lu, Z., Chen, Z., Wu, M., Kwoh, C. K., and Li, X. (2023). “Time-Series Domain Adaptation: A Survey.” arXiv preprint.
  5. Chalapathy, R. and Chawla, S. (2019). “Deep Learning for Anomaly Detection: A Survey.” arXiv preprint.
  6. PyTorch Documentation. “Extending torch.autograd—Custom Function.”
