Summary
What this post covers: A complete, runnable implementation guide for domain-adaptive time-series anomaly detection in PyTorch, comprising nine files (eight Python scripts plus requirements.txt) that implement DANN, MMD, and CORAL on top of a CNN-LSTM encoder for multi-channel sensor data.
Key insights:
- Domain shift between machines, sensors, factories, or seasons routinely reduces industrial anomaly-detection AUROC from approximately 0.95 on the source to roughly 0.6 on the target, and relabeling each new domain is economically infeasible because anomalies are rare.
- Three domain-adaptation losses cover the practical design space: DANN (adversarial, most flexible), MMD (kernel-based moment matching, simpler and more stable), and CORAL (second-order statistic alignment, with minimal hyperparameter overhead).
- A CNN-LSTM hybrid encoder with a shared feature extractor and separate anomaly and domain heads is a strong default architecture for multi-channel time series. The CNN captures local waveform shape and the LSTM captures temporal dependencies.
- Progressive lambda scheduling, in which the domain-adaptation weight is ramped from 0 toward 1 over training, is the single most important training practice. Without it the adversarial signal destabilizes feature learning.
- Domain adaptation succeeds only when source and target share the same underlying anomaly mechanisms but differ in superficial signal characteristics. Fundamentally different failure modes still require labeled target data through semi-supervised adaptation.
Main topics: Introduction: The Domain Shift Problem in Anomaly Detection, Project Structure and Setup, Configuration and Hyperparameters, Generating Realistic Synthetic Data, Dataset Classes and Data Loading, The Core Model Architecture, Loss Functions: DANN, MMD, and CORAL, The Main Training Script, Evaluation and Metrics, Utility Functions, Running the Full Pipeline, Understanding the Results, Adapting to Your Own Data, Common Issues and Solutions, Putting It Together, References.
Introduction: The Domain Shift Problem in Anomaly Detection
Consider an engineer who has spent six months collecting labeled anomaly data from a CNC milling machine on the factory floor, painstakingly tagging every spindle vibration spike, every thermal drift event, and every bearing degradation signature. The resulting anomaly detection model attains 0.95 AUROC on that machine. The company subsequently acquires a second milling machine from the same manufacturer and model line, differing only in production year. The model is deployed, and the AUROC falls to 0.62—barely better than a coin flip.
This is the domain shift problem, one of the most costly difficulties in industrial machine learning. The statistical distribution of sensor readings differs between machines, factories, sensor brands, and even seasons. Noise floors vary, baseline amplitudes drift, and the boundary between “normal” and “anomalous” deforms in subtle ways. A carefully trained model becomes essentially unusable the moment it leaves its original domain.
The conventional solution is to label data in each new domain. However, labeling anomaly data is exceptionally expensive: anomalies are rare by definition, and expert annotators are scarce. A more attractive approach is to transfer anomaly-detection knowledge from a labeled source domain (machine A) to an unlabeled target domain (machine B) without re-collecting labels.
This is precisely what domain adaptation provides. By training a model to learn features that are invariant across domains (features that capture the essence of “anomaly” regardless of which machine produced the signal), a practitioner can detect anomalies in new domains with little or no labeled target data. The adversarial form of the technique was popularized in computer vision by the DANN paper of Ganin et al. (2016), but its application to time-series anomaly detection remains underexplored in practice, even though it is highly relevant to industrial deployment.
This post is not a theoretical survey. It is a complete, runnable implementation guide. Readers who follow it through will obtain the nine files listed in the next section (eight Python scripts plus requirements.txt), which implement three domain adaptation strategies on top of a CNN-LSTM hybrid encoder for multi-channel time-series anomaly detection: DANN (Domain-Adversarial Neural Networks), MMD (Maximum Mean Discrepancy), and CORAL (CORrelation ALignment). Every script is complete, with no omissions or pseudocode.
The implementation proceeds below.
Project Structure and Setup
Before writing any code, it is useful to establish a clean project layout. Each file has a single responsibility, which makes the codebase easier to understand and adapt to a specific use case.
da-anomaly-detection/
├── config.py # Hyperparameters and configuration
├── dataset.py # Dataset classes and data loading
├── model.py # Model architecture (encoder, classifier, discriminator)
├── losses.py # Loss function definitions (DANN, MMD, CORAL)
├── train.py # Main training script with domain adaptation
├── evaluate.py # Evaluation and metrics
├── utils.py # Utility functions (seeding, checkpoints, plotting)
├── generate_synthetic_data.py # Generate example data for testing
├── requirements.txt # Dependencies
├── data/ # Generated or real data goes here
├── checkpoints/ # Saved model weights
└── results/ # Evaluation outputs, plots, metrics
The first step is to create the directory and install dependencies.
mkdir -p da-anomaly-detection/{data,checkpoints,results}
cd da-anomaly-detection
requirements.txt
torch>=2.0.0
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
tqdm>=4.65.0
pip install -r requirements.txt
# Optional: install a CUDA 12.1 build of PyTorch from the PyTorch wheel index
pip install torch --index-url https://download.pytorch.org/whl/cu121
Configuration and Hyperparameters
Centralizing configuration prevents magic numbers from being scattered across the codebase. A Python dataclass is used here so that the IDE provides autocompletion and type checking without additional effort.
config.py
"""
config.py — Centralized configuration for domain-adaptive anomaly detection.
All hyperparameters live here. Override via CLI arguments in train.py.
"""
from dataclasses import dataclass, field
import torch
import os


@dataclass
class Config:
    """All hyperparameters and paths for the DA anomaly detection pipeline."""

    # --- Data Parameters ---
    num_features: int = 6  # Number of sensor channels
    window_size: int = 64  # Sliding window length (timesteps)
    stride: int = 16  # Stride for sliding window
    train_ratio: float = 0.8  # Train/val split ratio

    # --- Model Architecture ---
    cnn_channels: list = field(default_factory=lambda: [32, 64, 128])
    cnn_kernel_sizes: list = field(default_factory=lambda: [7, 5, 3])
    lstm_hidden_dim: int = 128
    lstm_num_layers: int = 2
    latent_dim: int = 128  # Dimension of the shared feature space
    classifier_hidden_dim: int = 64
    discriminator_hidden_dim: int = 64
    dropout: float = 0.3

    # --- Training Parameters ---
    batch_size: int = 64
    learning_rate: float = 1e-3
    discriminator_lr: float = 1e-3
    weight_decay: float = 1e-4
    epochs: int = 100
    patience: int = 15  # Early stopping patience

    # --- Domain Adaptation Parameters ---
    adaptation_method: str = "dann"  # 'dann', 'mmd', or 'coral'
    lambda_domain: float = 1.0  # Max domain loss weight
    lambda_recon: float = 0.5  # Reconstruction loss weight
    lambda_cls: float = 1.0  # Classification loss weight
    gamma: float = 10.0  # DANN lambda schedule steepness
    mmd_kernel_bandwidth: list = field(
        default_factory=lambda: [0.01, 0.1, 1.0, 10.0, 100.0]
    )

    # --- Anomaly Scoring ---
    alpha: float = 0.7  # Weight for classifier score vs recon error
    anomaly_threshold_percentile: float = 95.0

    # --- Paths ---
    data_dir: str = "data"
    checkpoint_dir: str = "checkpoints"
    results_dir: str = "results"

    # --- Device and Reproducibility ---
    seed: int = 42
    device: str = ""

    def __post_init__(self):
        # Pick the device automatically unless one was given explicitly
        if not self.device:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.results_dir, exist_ok=True)
One hyperparameter deserves particular attention: lambda_domain. If it is too high, the model loses its ability to classify anomalies; if it is too low, domain adaptation has no effect. The progressive scheduling in the training script (the DANN lambda schedule) addresses this by starting low and ramping upward.
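The ramp is commonly implemented with the sigmoid-shaped schedule from Ganin et al., lambda(p) = lambda_max * (2 / (1 + exp(-gamma * p)) - 1), where p in [0, 1] is the fraction of training completed. The sketch below assumes that form and reuses the gamma and lambda_domain defaults from config.py. The function name dann_lambda is hypothetical, and the training script may compute the value differently.

```python
import math


def dann_lambda(progress: float, gamma: float = 10.0, lambda_max: float = 1.0) -> float:
    """Sigmoid-shaped ramp from 0 (progress=0) toward lambda_max (progress=1)."""
    # Clamp so that values slightly outside [0, 1] do not break the schedule
    progress = min(max(progress, 0.0), 1.0)
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0)


if __name__ == "__main__":
    epochs = 100
    for epoch in (0, 10, 25, 50, 99):
        p = epoch / (epochs - 1)
        print(f"epoch {epoch:3d}: lambda = {dann_lambda(p):.3f}")
```

A larger gamma makes the ramp steeper, so the domain loss reaches its full weight earlier in training; a smaller gamma gives the classifier more time to settle first.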
Generating Realistic Synthetic Data
Before working with proprietary data, a sandbox dataset is necessary. The script below generates two-domain synthetic time-series data with realistic characteristics: seasonal patterns, trends, multiple anomaly types, and domain-specific differences in noise, amplitude, and baseline offset. The source domain receives full labels, the target training set has no labels (which simulates the realistic scenario), and the target test set retains labels for evaluation purposes.
generate_synthetic_data.py
"""
generate_synthetic_data.py — Generate realistic two-domain time-series data
with injected anomalies for testing domain adaptation.
Simulates 6-channel sensor data (e.g., 3 joints x [torque, position]) from
two different machines with different noise/amplitude characteristics.
"""
import argparse
import os
import numpy as np
import pandas as pd


def generate_base_signal(n_samples: int, num_features: int, seed: int = 42) -> np.ndarray:
    """Generate a base multi-channel time-series with realistic patterns."""
    rng = np.random.RandomState(seed)
    t = np.arange(n_samples)
    signals = np.zeros((n_samples, num_features))
    for ch in range(num_features):
        freq1 = 0.002 + ch * 0.001
        freq2 = 0.01 + ch * 0.003
        phase1 = rng.uniform(0, 2 * np.pi)
        phase2 = rng.uniform(0, 2 * np.pi)
        # Seasonal component
        seasonal = 2.0 * np.sin(2 * np.pi * freq1 * t + phase1)
        # Higher-frequency oscillation
        oscillation = 0.8 * np.sin(2 * np.pi * freq2 * t + phase2)
        # Slow trend (alternating sign per channel)
        trend = 0.0005 * t * ((-1) ** ch)
        # Combine
        signals[:, ch] = seasonal + oscillation + trend
    return signals


def inject_anomalies(
    signals: np.ndarray,
    anomaly_ratio: float = 0.05,
    seed: int = 42
) -> tuple:
    """
    Inject multiple anomaly types into signals.
    Returns (modified_signals, labels) where labels[i]=1 means anomaly.
    """
    rng = np.random.RandomState(seed)
    n_samples, num_features = signals.shape
    labels = np.zeros(n_samples, dtype=int)
    modified = signals.copy()
    anomaly_types = ["spike", "drift", "level_shift", "frequency_change"]
    segment_length = 20
    # anomaly_ratio is a fraction of timesteps, so the number of segments is the
    # timestep budget divided by the segment length. Spikes label one timestep,
    # so the realized ratio stays at or below anomaly_ratio.
    n_anomalies = max(1, int(n_samples * anomaly_ratio / segment_length))
    # Draw distinct slots on a segment_length grid so segments cannot overlap
    n_slots = n_samples // segment_length
    slots = rng.choice(n_slots, size=min(n_anomalies, n_slots), replace=False)
    starts = slots * segment_length
    for i, start in enumerate(starts):
        end = start + segment_length
        a_type = anomaly_types[i % len(anomaly_types)]
        channel = rng.randint(0, num_features)
        if a_type == "spike":
            spike_pos = start + rng.randint(0, segment_length)
            magnitude = rng.uniform(5, 10) * (1 if rng.random() > 0.5 else -1)
            modified[spike_pos, channel] += magnitude
            labels[spike_pos] = 1
        elif a_type == "drift":
            drift = np.linspace(0, rng.uniform(3, 6), segment_length)
            modified[start:end, channel] += drift
            labels[start:end] = 1
        elif a_type == "level_shift":
            shift = rng.uniform(3, 7) * (1 if rng.random() > 0.5 else -1)
            modified[start:end, channel] += shift
            labels[start:end] = 1
        elif a_type == "frequency_change":
            t_seg = np.arange(segment_length)
            high_freq = 2.0 * np.sin(2 * np.pi * 0.15 * t_seg)
            modified[start:end, channel] += high_freq
            labels[start:end] = 1
    return modified, labels


def apply_domain_transform(
    signals: np.ndarray,
    noise_scale: float = 0.3,
    amplitude_scale: float = 1.0,
    baseline_offset: float = 0.0,
    seed: int = 42
) -> np.ndarray:
    """Apply domain-specific transformations to simulate a different machine."""
    rng = np.random.RandomState(seed)
    transformed = signals.copy()
    n_samples, num_features = transformed.shape
    # Per-channel amplitude scaling and baseline offset
    for ch in range(num_features):
        ch_amp = amplitude_scale * rng.uniform(0.8, 1.2)
        ch_offset = baseline_offset + rng.uniform(-0.5, 0.5)
        transformed[:, ch] = transformed[:, ch] * ch_amp + ch_offset
    # Add domain-specific noise
    noise = rng.normal(0, noise_scale, transformed.shape)
    transformed += noise
    return transformed


def generate_dataset(
    n_samples: int,
    num_features: int,
    anomaly_ratio: float,
    noise_scale: float,
    amplitude_scale: float,
    baseline_offset: float,
    seed: int
) -> pd.DataFrame:
    """Generate a complete dataset with signals, anomalies, and domain transform."""
    base = generate_base_signal(n_samples, num_features, seed=seed)
    with_anomalies, labels = inject_anomalies(base, anomaly_ratio, seed=seed + 1)
    transformed = apply_domain_transform(
        with_anomalies,
        noise_scale=noise_scale,
        amplitude_scale=amplitude_scale,
        baseline_offset=baseline_offset,
        seed=seed + 2
    )
    columns = [f"sensor_{i}" for i in range(num_features)]
    df = pd.DataFrame(transformed, columns=columns)
    df["label"] = labels
    df["timestamp"] = pd.date_range("2024-01-01", periods=n_samples, freq="s")
    return df


def main():
    parser = argparse.ArgumentParser(
        description="Generate synthetic two-domain time-series data."
    )
    parser.add_argument("--output_dir", type=str, default="data",
                        help="Output directory for CSV files")
    parser.add_argument("--n_samples", type=int, default=20000,
                        help="Number of samples per dataset")
    parser.add_argument("--num_features", type=int, default=6,
                        help="Number of sensor channels")
    parser.add_argument("--anomaly_ratio", type=float, default=0.05,
                        help="Fraction of timesteps with anomalies")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print("Generating source domain data (Machine A)...")
    source_full = generate_dataset(
        n_samples=args.n_samples,
        num_features=args.num_features,
        anomaly_ratio=args.anomaly_ratio,
        noise_scale=0.2,
        amplitude_scale=1.0,
        baseline_offset=0.0,
        seed=args.seed
    )
    split_idx = int(len(source_full) * 0.7)
    source_train = source_full.iloc[:split_idx].reset_index(drop=True)
    source_test = source_full.iloc[split_idx:].reset_index(drop=True)

    print("Generating target domain data (Machine B)...")
    target_full = generate_dataset(
        n_samples=args.n_samples,
        num_features=args.num_features,
        anomaly_ratio=args.anomaly_ratio,
        noise_scale=0.5,  # Higher noise
        amplitude_scale=1.4,  # Different amplitude
        baseline_offset=2.0,  # Shifted baseline
        seed=args.seed + 100
    )
    split_idx_t = int(len(target_full) * 0.7)
    target_train = target_full.iloc[:split_idx_t].reset_index(drop=True)
    target_test = target_full.iloc[split_idx_t:].reset_index(drop=True)
    # Remove labels from target train (unsupervised in target domain)
    target_train_unlabeled = target_train.drop(columns=["label"])

    # Save all files
    source_train.to_csv(os.path.join(args.output_dir, "source_train.csv"), index=False)
    source_test.to_csv(os.path.join(args.output_dir, "source_test.csv"), index=False)
    target_train_unlabeled.to_csv(os.path.join(args.output_dir, "target_train.csv"), index=False)
    target_test.to_csv(os.path.join(args.output_dir, "target_test.csv"), index=False)

    print(f"\nDatasets saved to {args.output_dir}/")
    print(f" source_train.csv: {len(source_train)} samples, "
          f"{source_train['label'].sum()} anomalies ({source_train['label'].mean()*100:.1f}%)")
    print(f" source_test.csv: {len(source_test)} samples, "
          f"{source_test['label'].sum()} anomalies ({source_test['label'].mean()*100:.1f}%)")
    print(f" target_train.csv: {len(target_train_unlabeled)} samples (no labels)")
    print(f" target_test.csv: {len(target_test)} samples, "
          f"{target_test['label'].sum()} anomalies ({target_test['label'].mean()*100:.1f}%)")


if __name__ == "__main__":
    main()
The script can be executed directly.
python generate_synthetic_data.py --output_dir data/ --n_samples 20000
The script produces four CSV files. The source data is fully labeled. The target training data is unlabeled, which reflects the central premise of domain adaptation. The target test data is labeled so that the effectiveness of adaptation can be measured.
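To make the simulated shift concrete, the sketch below applies the Machine B settings from main() (amplitude 1.4, offset 2.0, noise 0.5) to a single sine channel and compares summary statistics. It is an illustration under simplifying assumptions: it uses only the standard library, it omits the per-channel jitter that apply_domain_transform adds, and the helper name shift_channel is hypothetical.

```python
import math
import random
import statistics


def shift_channel(signal, amplitude_scale, baseline_offset, noise_scale, seed=0):
    """Affine transform plus Gaussian noise, mimicking a second machine."""
    rng = random.Random(seed)
    return [amplitude_scale * v + baseline_offset + rng.gauss(0.0, noise_scale)
            for v in signal]


# One sine channel over a whole number of periods, so its mean is close to zero
source = [2.0 * math.sin(2 * math.pi * 0.002 * t) for t in range(5000)]
target = shift_channel(source, amplitude_scale=1.4, baseline_offset=2.0, noise_scale=0.5)

src_mean, src_std = statistics.fmean(source), statistics.pstdev(source)
tgt_mean, tgt_std = statistics.fmean(target), statistics.pstdev(target)

# z-score the target mean with *source* statistics, as a source-fitted normalizer would
tgt_z_mean = (tgt_mean - src_mean) / src_std

print(f"source mean={src_mean:.2f} std={src_std:.2f}")
print(f"target mean={tgt_mean:.2f} std={tgt_std:.2f}")
print(f"target mean after source-fitted z-scoring: {tgt_z_mean:.2f}")
```

The target channel stays well away from zero mean after source-fitted normalization, which is the residual shift the adaptation losses are meant to absorb.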
Dataset Classes and Data Loading
Time-series anomaly detection operates on windows, that is, fixed-length slices of the signal. The code below handles windowing, normalization (fit on source data and applied across all data), and optional data augmentation. The create_data_loaders function builds separate source and target loaders so that batches from both domains can be drawn together during training.
dataset.py
"""
dataset.py — PyTorch Dataset classes for time-series domain adaptation.
Handles sliding-window creation, normalization, augmentation, and
paired source-target batch generation.
"""
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class TimeSeriesDataset(Dataset):
    """
    Sliding-window dataset for multi-channel time-series.
    Args:
        data: numpy array of shape (n_samples, num_features)
        labels: numpy array of shape (n_samples,) or None for unlabeled data
        window_size: number of timesteps per window
        stride: step between consecutive windows
        transform: optional callable for data augmentation
    """

    def __init__(
        self,
        data: np.ndarray,
        labels: np.ndarray = None,
        window_size: int = 64,
        stride: int = 16,
        transform=None
    ):
        self.data = data.astype(np.float32)
        self.labels = labels
        self.window_size = window_size
        self.stride = stride
        self.transform = transform
        # Precompute valid window start indices
        self.indices = list(range(0, len(data) - window_size + 1, stride))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        start = self.indices[idx]
        end = start + self.window_size
        window = self.data[start:end]  # (window_size, num_features)
        if self.transform is not None:
            window = self.transform(window)
        # Transpose to (num_features, window_size) for Conv1d
        window_tensor = torch.tensor(window, dtype=torch.float32).T
        if self.labels is not None:
            # Window label = 1 if any timestep in window is anomalous
            window_label = float(self.labels[start:end].max())
            return window_tensor, torch.tensor(window_label, dtype=torch.float32)
        else:
            # -1 marks an unlabeled window
            return window_tensor, torch.tensor(-1.0, dtype=torch.float32)


class Normalizer:
    """
    Fit on source training data, transform all data.
    Uses per-channel mean and std normalization.
    """

    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, data: np.ndarray):
        """Compute mean and std from training data."""
        self.mean = data.mean(axis=0)
        self.std = data.std(axis=0)
        # Prevent division by zero
        self.std[self.std < 1e-8] = 1.0
        return self

    def transform(self, data: np.ndarray) -> np.ndarray:
        """Apply normalization."""
        return (data - self.mean) / self.std

    def fit_transform(self, data: np.ndarray) -> np.ndarray:
        """Fit and transform in one step."""
        self.fit(data)
        return self.transform(data)


class JitterTransform:
    """Add random Gaussian noise for data augmentation."""

    def __init__(self, sigma: float = 0.03):
        self.sigma = sigma

    def __call__(self, window: np.ndarray) -> np.ndarray:
        noise = np.random.normal(0, self.sigma, window.shape).astype(np.float32)
        return window + noise


class ScalingTransform:
    """Random per-channel amplitude scaling for data augmentation."""

    def __init__(self, sigma: float = 0.1):
        self.sigma = sigma

    def __call__(self, window: np.ndarray) -> np.ndarray:
        factor = np.random.normal(1.0, self.sigma, (1, window.shape[1])).astype(np.float32)
        return window * factor


class ComposeTransforms:
    """Chain multiple transforms together."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, window: np.ndarray) -> np.ndarray:
        for t in self.transforms:
            window = t(window)
        return window


def load_csv_data(filepath: str, has_labels: bool = True):
    """
    Load a CSV file and separate features from labels.
    Returns:
        data: numpy array (n_samples, num_features)
        labels: numpy array (n_samples,) or None
    """
    df = pd.read_csv(filepath)
    # Drop non-feature columns (label and timestamp)
    feature_cols = [c for c in df.columns if c not in ("label", "timestamp")]
    data = df[feature_cols].values.astype(np.float32)
    labels = df["label"].values.astype(np.float32) if (has_labels and "label" in df.columns) else None
    return data, labels


def create_data_loaders(config) -> tuple:
    """
    Create all data loaders for domain adaptation training.
    Returns (loaders, normalizer), where loaders is a dict with keys:
        'source_train', 'source_test', 'target_train', 'target_test'
    """
    import os
    # Load raw data
    source_train_data, source_train_labels = load_csv_data(
        os.path.join(config.data_dir, "source_train.csv"), has_labels=True
    )
    source_test_data, source_test_labels = load_csv_data(
        os.path.join(config.data_dir, "source_test.csv"), has_labels=True
    )
    target_train_data, _ = load_csv_data(
        os.path.join(config.data_dir, "target_train.csv"), has_labels=False
    )
    target_test_data, target_test_labels = load_csv_data(
        os.path.join(config.data_dir, "target_test.csv"), has_labels=True
    )
    # Normalize: fit on source train only
    normalizer = Normalizer()
    source_train_data = normalizer.fit_transform(source_train_data)
    source_test_data = normalizer.transform(source_test_data)
    target_train_data = normalizer.transform(target_train_data)
    target_test_data = normalizer.transform(target_test_data)
    # Optional augmentation for training
    train_transform = ComposeTransforms([
        JitterTransform(sigma=0.03),
        ScalingTransform(sigma=0.1),
    ])
    # Create datasets
    source_train_ds = TimeSeriesDataset(
        source_train_data, source_train_labels,
        window_size=config.window_size, stride=config.stride,
        transform=train_transform
    )
    source_test_ds = TimeSeriesDataset(
        source_test_data, source_test_labels,
        window_size=config.window_size, stride=config.stride
    )
    target_train_ds = TimeSeriesDataset(
        target_train_data, labels=None,
        window_size=config.window_size, stride=config.stride,
        transform=train_transform
    )
    target_test_ds = TimeSeriesDataset(
        target_test_data, target_test_labels,
        window_size=config.window_size, stride=config.stride
    )
    # Create loaders
    loaders = {
        "source_train": DataLoader(
            source_train_ds, batch_size=config.batch_size,
            shuffle=True, drop_last=True, num_workers=0
        ),
        "source_test": DataLoader(
            source_test_ds, batch_size=config.batch_size,
            shuffle=False, num_workers=0
        ),
        "target_train": DataLoader(
            target_train_ds, batch_size=config.batch_size,
            shuffle=True, drop_last=True, num_workers=0
        ),
        "target_test": DataLoader(
            target_test_ds, batch_size=config.batch_size,
            shuffle=False, num_workers=0
        ),
    }
    return loaders, normalizer
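As a quick check on the windowing arithmetic in TimeSeriesDataset, the sketch below reproduces the start-index computation and the "any anomalous timestep" labeling rule in plain Python. The helper names are hypothetical, and the 14,000-sample length assumes the 70% source split of the default 20,000-sample run.

```python
def window_starts(n_samples: int, window_size: int = 64, stride: int = 16) -> list:
    """Start indices of every full window, matching range(0, n - w + 1, stride)."""
    return list(range(0, n_samples - window_size + 1, stride))


def window_label(labels: list, start: int, window_size: int = 64) -> int:
    """A window is anomalous if any timestep inside it is anomalous."""
    return int(max(labels[start:start + window_size]))


starts = window_starts(14000)
# Closed form for the same count: floor((n - w) / s) + 1
assert len(starts) == (14000 - 64) // 16 + 1

# A single anomalous timestep marks every window that covers it
labels = [0] * 14000
labels[1000] = 1
flagged = [s for s in starts if window_label(labels, s) == 1]
print(len(starts), "windows;", len(flagged), "flagged by one anomalous timestep")
```

Because the stride is a quarter of the window length, each timestep falls into up to four windows, so the window-level anomaly rate is noticeably higher than the timestep-level rate in the CSV files.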
The Core Model Architecture
The model architecture lies at the heart of the system. It comprises four components that operate in concert: a shared encoder that processes time-series windows into a fixed-size feature vector; an anomaly classifier that predicts normal versus anomaly; a reconstruction decoder that reconstructs the original input and provides an auxiliary anomaly signal; and a domain discriminator that attempts to identify which domain produced a given feature vector. The essential ingredient is the Gradient Reversal Layer (GRL), which during backpropagation reverses the sign of gradients flowing from the domain discriminator to the encoder. This compels the encoder to learn features that are maximally uninformative about domain identity, which is precisely the domain-invariant representation required.
Architecture:
                         ┌─── Anomaly Classifier (binary: normal/anomaly)
Input → Shared Encoder ──┤
(time-series)            ├─── Reconstruction Decoder (autoencoder branch)
                         └─── Domain Discriminator (with gradient reversal)
model.py
"""
model.py — Domain-adaptive anomaly detection model architecture.
Components:
- GradientReversalLayer: reverses gradients for adversarial domain adaptation
- SharedEncoder: CNN + BiLSTM feature extractor
- AnomalyClassifier: binary classification head
- ReconstructionDecoder: autoencoder branch for reconstruction-based scoring
- DomainDiscriminator: adversarial domain classification head
- DomainAdaptiveAnomalyDetector: full model combining all components
"""
import torch
import torch.nn as nn
from torch.autograd import Function


class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer (GRL), Ganin et al., 2016.
    Forward pass: identity.
    Backward pass: negate gradients and scale by lambda.
    """

    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; lambda_val itself gets no gradient
        return -ctx.lambda_val * grad_output, None


class GradientReversalLayer(nn.Module):
    """Module wrapper for the gradient reversal function."""

    def __init__(self, lambda_val: float = 1.0):
        super().__init__()
        self.lambda_val = lambda_val

    def set_lambda(self, lambda_val: float):
        self.lambda_val = lambda_val

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_val)


class SharedEncoder(nn.Module):
    """
    1D-CNN + Bidirectional LSTM encoder for multi-channel time-series.
    Input shape: (batch, num_features, window_size)
    Output shape: (batch, latent_dim)
    """

    def __init__(
        self,
        num_features: int = 6,
        cnn_channels: list = None,
        cnn_kernel_sizes: list = None,
        lstm_hidden_dim: int = 128,
        lstm_num_layers: int = 2,
        latent_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        if cnn_channels is None:
            cnn_channels = [32, 64, 128]
        if cnn_kernel_sizes is None:
            cnn_kernel_sizes = [7, 5, 3]
        # Build CNN layers (odd kernels with padding ks // 2 preserve length)
        cnn_layers = []
        in_channels = num_features
        for out_ch, ks in zip(cnn_channels, cnn_kernel_sizes):
            cnn_layers.extend([
                nn.Conv1d(in_channels, out_ch, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_channels = out_ch
        self.cnn = nn.Sequential(*cnn_layers)
        # Bidirectional LSTM on top of CNN features
        self.lstm = nn.LSTM(
            input_size=cnn_channels[-1],
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_num_layers > 1 else 0.0,
        )
        # Project to latent space
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden_dim * 2, latent_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )
        self.latent_dim = latent_dim

    def forward(self, x):
        """
        Args:
            x: (batch, num_features, window_size)
        Returns:
            latent: (batch, latent_dim)
        """
        # CNN: (batch, cnn_channels[-1], window_size)
        cnn_out = self.cnn(x)
        # Transpose for LSTM: (batch, window_size, cnn_channels[-1])
        lstm_in = cnn_out.permute(0, 2, 1)
        # LSTM: (batch, window_size, lstm_hidden*2)
        lstm_out, _ = self.lstm(lstm_in)
        # The forward direction has seen the whole window at the last timestep,
        # the backward direction at the first timestep. Taking lstm_out[:, -1, :]
        # alone would give a backward half that has seen only one step, so the
        # two full-window summaries are concatenated instead.
        hidden = self.lstm.hidden_size
        last_hidden = torch.cat(
            [lstm_out[:, -1, :hidden], lstm_out[:, 0, hidden:]], dim=1
        )
        # Project to latent space
        latent = self.fc(last_hidden)
        return latent


class AnomalyClassifier(nn.Module):
    """
    Binary classification head: normal (0) vs anomaly (1).
    Input: (batch, latent_dim)
    Output: (batch, 1), raw logit (apply a sigmoid to obtain a probability)
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, latent):
        return self.net(latent)


class ReconstructionDecoder(nn.Module):
    """
    Decoder that reconstructs the original input from latent features.
    Uses LSTM + transposed Conv1d layers.
    Input: (batch, latent_dim)
    Output: (batch, num_features, window_size)
    """

    def __init__(
        self,
        latent_dim: int = 128,
        num_features: int = 6,
        window_size: int = 64,
        lstm_hidden_dim: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.window_size = window_size
        self.num_features = num_features
        self.lstm_hidden_dim = lstm_hidden_dim
        # Expand latent to sequence
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, lstm_hidden_dim),
            nn.ReLU(inplace=True),
        )
        # LSTM decoder
        self.lstm = nn.LSTM(
            input_size=lstm_hidden_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        # Transposed convolutions to reconstruct (stride 1, so length is kept)
        self.deconv = nn.Sequential(
            nn.ConvTranspose1d(lstm_hidden_dim, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.ConvTranspose1d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose1d(32, num_features, kernel_size=3, padding=1),
        )

    def forward(self, latent):
        """
        Args:
            latent: (batch, latent_dim)
        Returns:
            reconstruction: (batch, num_features, window_size)
        """
        # Repeat the latent vector at every timestep: (batch, window_size, lstm_hidden)
        expanded = self.fc(latent).unsqueeze(1).repeat(1, self.window_size, 1)
        # LSTM decode
        lstm_out, _ = self.lstm(expanded)
        # Transpose for Conv1d: (batch, lstm_hidden, window_size)
        conv_in = lstm_out.permute(0, 2, 1)
        # Reconstruct
        reconstruction = self.deconv(conv_in)
        return reconstruction


class DomainDiscriminator(nn.Module):
    """
    Domain classification head with Gradient Reversal Layer.
    Classifies whether features came from source (0) or target (1) domain.
    Input: (batch, latent_dim)
    Output: (batch, 1), domain logit
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 64, dropout: float = 0.3):
        super().__init__()
        self.grl = GradientReversalLayer(lambda_val=1.0)
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def set_lambda(self, lambda_val: float):
        self.grl.set_lambda(lambda_val)

    def forward(self, latent):
        reversed_features = self.grl(latent)
        return self.net(reversed_features)


class DomainAdaptiveAnomalyDetector(nn.Module):
    """
    Full domain-adaptive anomaly detection model.
    Combines encoder, anomaly classifier, reconstruction decoder,
    and domain discriminator.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = SharedEncoder(
            num_features=config.num_features,
            cnn_channels=config.cnn_channels,
            cnn_kernel_sizes=config.cnn_kernel_sizes,
            lstm_hidden_dim=config.lstm_hidden_dim,
            lstm_num_layers=config.lstm_num_layers,
            latent_dim=config.latent_dim,
            dropout=config.dropout,
        )
        self.classifier = AnomalyClassifier(
            latent_dim=config.latent_dim,
            hidden_dim=config.classifier_hidden_dim,
            dropout=config.dropout,
        )
        self.decoder = ReconstructionDecoder(
            latent_dim=config.latent_dim,
            num_features=config.num_features,
            window_size=config.window_size,
            lstm_hidden_dim=config.lstm_hidden_dim,
            dropout=config.dropout,
        )
        self.discriminator = DomainDiscriminator(
            latent_dim=config.latent_dim,
            hidden_dim=config.discriminator_hidden_dim,
            dropout=config.dropout,
        )

    def set_domain_lambda(self, lambda_val: float):
        """Update the GRL lambda for progressive scheduling."""
        self.discriminator.set_lambda(lambda_val)

    def forward(self, x):
        """
        Full forward pass.
        Args:
            x: (batch, num_features, window_size)
        Returns:
            anomaly_logits: (batch, 1), raw logits for anomaly classification
            reconstruction: (batch, num_features, window_size), reconstructed input
            domain_logits: (batch, 1), raw logits for domain classification
            latent_features: (batch, latent_dim), shared latent representation
        """
        latent = self.encoder(x)
        anomaly_logits = self.classifier(latent)
        reconstruction = self.decoder(latent)
        domain_logits = self.discriminator(latent)
        return anomaly_logits, reconstruction, domain_logits, latent
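The effect of the gradient reversal layer on the two sets of parameters can be written compactly. The notation below is introduced here for illustration and does not appear in the code: z is the latent vector, θ_f the encoder parameters, θ_d the discriminator parameters, L_d the domain classification loss, η the learning rate, and λ the GRL coefficient set through set_domain_lambda. Plain gradient descent is assumed for readability.

```latex
\begin{aligned}
\text{forward:}\quad & R_\lambda(z) = z \\
\text{backward:}\quad & \frac{\partial R_\lambda}{\partial z} = -\lambda I \\
\text{discriminator step:}\quad & \theta_d \leftarrow \theta_d - \eta\,\frac{\partial L_d}{\partial \theta_d} \\
\text{encoder step (domain term only):}\quad & \theta_f \leftarrow \theta_f + \eta\,\lambda\,\frac{\partial L_d}{\partial \theta_f}
\end{aligned}
```

The discriminator descends the domain loss while the encoder ascends it, scaled by λ. This is why a single backward pass is enough for the adversarial game, and why λ = 0 switches the adversarial signal off entirely.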
Loss Functions: DANN, MMD, and CORAL
Domain adaptation is not a single technique but a family of techniques, each with distinct strengths. The implementation below supports three approaches, selectable through a single configuration flag (adaptation_method in config.py). DANN uses adversarial training based on the discriminator. MMD directly minimizes the statistical distance between source and target feature distributions through a kernel formulation. CORAL aligns the second-order statistics (covariance matrices) of the two domains.
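To give a feel for the kernel formulation before the PyTorch version, the sketch below computes the biased estimate of squared MMD on tiny 2-D point sets using only the standard library. It is an illustration, not the losses.py implementation: the function names are hypothetical, the bandwidth list is shortened, and the kernel follows the same convention as the module, exp(-||a - b||^2 / (2 * bw)) summed over bandwidths.

```python
import math


def multi_gaussian_kernel(a, b, bandwidths=(0.1, 1.0, 10.0)):
    """Sum of Gaussian kernels exp(-||a - b||^2 / (2 * bw)) over several bandwidths."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return sum(math.exp(-sq_dist / (2.0 * bw)) for bw in bandwidths)


def mean_kernel(xs, ys):
    """Average kernel value over all pairs (x, y)."""
    total = sum(multi_gaussian_kernel(x, y) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd2(xs, ys):
    """Biased estimate of squared MMD: E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]."""
    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


source = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
near = [(x + 0.1, y + 0.1) for x, y in source]   # small domain shift
far = [(x + 2.0, y + 2.0) for x, y in source]    # large domain shift

print(f"identical: {mmd2(source, source):.4f}")
print(f"near:      {mmd2(source, near):.4f}")
print(f"far:       {mmd2(source, far):.4f}")
```

The value is zero for identical sets and grows as the target set moves away from the source, which is the quantity the MMD loss drives down in latent space. Summing several bandwidths keeps the estimate sensitive to both small and large shifts.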
losses.py
"""
losses.py — Loss functions for domain-adaptive anomaly detection.
Includes:
- AnomalyDetectionLoss (BCE for anomaly classification)
- ReconstructionLoss (MSE for autoencoder)
- DomainAdversarialLoss (BCE for domain discrimination)
- MMDLoss (Maximum Mean Discrepancy with Gaussian kernel)
- CORALLoss (CORrelation ALignment)
- CombinedLoss (weighted combination of all losses)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class AnomalyDetectionLoss(nn.Module):
"""Binary cross-entropy loss for anomaly classification."""
def __init__(self):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
def forward(self, logits, labels):
"""
Args:
logits: (batch, 1) raw anomaly logits
labels: (batch,) binary labels (0=normal, 1=anomaly)
"""
# BCEWithLogitsLoss expects float targets, so cast in case labels arrive as integers
return self.bce(logits.squeeze(-1), labels.float())
class ReconstructionLoss(nn.Module):
"""MSE loss between input and reconstruction."""
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, reconstruction, original):
"""
Args:
reconstruction: (batch, num_features, window_size)
original: (batch, num_features, window_size)
"""
return self.mse(reconstruction, original)
class DomainAdversarialLoss(nn.Module):
"""BCE loss for domain classification (used with GRL for DANN)."""
def __init__(self):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
def forward(self, domain_logits, domain_labels):
"""
Args:
domain_logits: (batch, 1) raw domain logits
domain_labels: (batch,) domain labels (0=source, 1=target)
"""
return self.bce(domain_logits.squeeze(-1), domain_labels)
class MMDLoss(nn.Module):
"""
Maximum Mean Discrepancy loss with multi-scale Gaussian kernel.
Measures the distance between source and target feature distributions
in a reproducing kernel Hilbert space (RKHS).
"""
def __init__(self, kernel_bandwidths: list = None):
super().__init__()
if kernel_bandwidths is None:
self.kernel_bandwidths = [0.01, 0.1, 1.0, 10.0, 100.0]
else:
self.kernel_bandwidths = kernel_bandwidths
def gaussian_kernel(self, x, y):
"""
Compute multi-scale Gaussian kernel matrices for x and y.
Args:
x: (n, d) tensor
y: (m, d) tensor
Returns:
k_xx: (n, n), k_yy: (m, m), k_xy: (n, m) kernel matrices,
each summed over all bandwidths
"""
# Pairwise squared distances via ||a||^2 + ||b||^2 - 2 * a.b
xx = torch.mm(x, x.t())
yy = torch.mm(y, y.t())
xy = torch.mm(x, y.t())
# Squared norms as a column (n, 1) for x and a row (1, m) for y,
# so broadcasting also works when n != m (e.g. a short final batch)
sq_x = xx.diag().unsqueeze(1)
sq_y = yy.diag().unsqueeze(0)
# Clamp to remove tiny negative values caused by floating-point error
dxx = (sq_x + sq_x.t() - 2.0 * xx).clamp(min=0.0)
dyy = (sq_y.t() + sq_y - 2.0 * yy).clamp(min=0.0)
dxy = (sq_x + sq_y - 2.0 * xy).clamp(min=0.0)
k_xx = torch.zeros_like(xx)
k_yy = torch.zeros_like(yy)
k_xy = torch.zeros_like(xy)
# Each bandwidth bw plays the role of sigma^2 in exp(-d^2 / (2 * sigma^2))
for bw in self.kernel_bandwidths:
k_xx += torch.exp(-dxx / (2.0 * bw))
k_yy += torch.exp(-dyy / (2.0 * bw))
k_xy += torch.exp(-dxy / (2.0 * bw))
return k_xx, k_yy, k_xy
def forward(self, source_features, target_features):
"""
Compute MMD^2 between source and target feature distributions.
Args:
source_features: (n, d) latent features from source domain
target_features: (m, d) latent features from target domain
Returns:
mmd_loss: scalar
"""
n = source_features.size(0)
m = target_features.size(0)
k_xx, k_yy, k_xy = self.gaussian_kernel(source_features, target_features)
mmd = (k_xx.sum() / (n * n)
+ k_yy.sum() / (m * m)
- 2.0 * k_xy.sum() / (n * m))
return mmd
class CORALLoss(nn.Module):
"""
CORrelation ALignment loss.
Aligns the second-order statistics (covariance matrices) of
source and target feature distributions.
"""
def __init__(self):
super().__init__()
def forward(self, source_features, target_features):
"""
Compute CORAL loss.
Args:
source_features: (n, d) latent features from source domain
target_features: (m, d) latent features from target domain
Returns:
coral_loss: scalar
"""
d = source_features.size(1)
n_s = source_features.size(0)
n_t = target_features.size(0)
# Compute covariance matrices
source_centered = source_features - source_features.mean(dim=0, keepdim=True)
target_centered = target_features - target_features.mean(dim=0, keepdim=True)
cov_source = (source_centered.t() @ source_centered) / (n_s - 1)
cov_target = (target_centered.t() @ target_centered) / (n_t - 1)
# Frobenius norm of covariance difference
diff = cov_source - cov_target
coral_loss = (diff * diff).sum() / (4 * d * d)
return coral_loss
class CombinedLoss(nn.Module):
"""
Combines anomaly detection, reconstruction, and domain adaptation losses.
total_loss = lambda_cls * anomaly_loss
+ lambda_recon * recon_loss
+ lambda_domain * domain_loss
The domain_loss component uses DANN, MMD, or CORAL depending on config.
"""
def __init__(self, config):
super().__init__()
self.anomaly_loss_fn = AnomalyDetectionLoss()
self.recon_loss_fn = ReconstructionLoss()
self.dann_loss_fn = DomainAdversarialLoss()
self.mmd_loss_fn = MMDLoss(kernel_bandwidths=config.mmd_kernel_bandwidth)
self.coral_loss_fn = CORALLoss()
self.lambda_cls = config.lambda_cls
self.lambda_recon = config.lambda_recon
self.lambda_domain = config.lambda_domain
self.method = config.adaptation_method
def forward(
self,
anomaly_logits,
anomaly_labels,
reconstruction,
original,
domain_logits=None,
domain_labels=None,
source_features=None,
target_features=None,
current_lambda=None,
):
"""
Compute combined loss.
Args:
anomaly_logits: (batch, 1) anomaly classification logits (source only)
anomaly_labels: (batch,) anomaly labels (source only)
reconstruction: (batch, num_features, window_size) reconstruction
original: (batch, num_features, window_size) original input
domain_logits: (batch, 1) domain logits (DANN only)
domain_labels: (batch,) domain labels (DANN only)
source_features: (n, d) source latent features (MMD/CORAL)
target_features: (m, d) target latent features (MMD/CORAL)
current_lambda: float — current domain adaptation weight
Returns:
total_loss, loss_dict (breakdown of individual losses)
"""
domain_weight = current_lambda if current_lambda is not None else self.lambda_domain
# Anomaly classification loss (source only)
cls_loss = self.anomaly_loss_fn(anomaly_logits, anomaly_labels)
# Reconstruction loss (both domains)
recon_loss = self.recon_loss_fn(reconstruction, original)
# Domain adaptation loss
if self.method == "dann" and domain_logits is not None:
domain_loss = self.dann_loss_fn(domain_logits, domain_labels)
elif self.method == "mmd" and source_features is not None:
domain_loss = self.mmd_loss_fn(source_features, target_features)
elif self.method == "coral" and source_features is not None:
domain_loss = self.coral_loss_fn(source_features, target_features)
else:
domain_loss = torch.tensor(0.0, device=anomaly_logits.device)
total_loss = (
self.lambda_cls * cls_loss
+ self.lambda_recon * recon_loss
+ domain_weight * domain_loss
)
loss_dict = {
"total": total_loss.item(),
"classification": cls_loss.item(),
"reconstruction": recon_loss.item(),
"domain": domain_loss.item(),
}
return total_loss, loss_dict
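As a quick sanity check on the two statistical losses, the sketch below restates the MMD and CORAL formulas in NumPy and applies them to synthetic feature batches. It is an illustrative restatement, not part of the nine scripts: the function names mmd2_gaussian and coral, the batch sizes, and the toy Gaussian features are all assumptions made for this example. Both losses should be close to zero when the two batches come from the same distribution and clearly larger when the target batch is shifted and rescaled.

```python
import numpy as np


def mmd2_gaussian(x, y, bandwidths=(0.01, 0.1, 1.0, 10.0, 100.0)):
    """Biased MMD^2 estimate with a sum of Gaussian kernels exp(-d^2 / (2 * bw))."""

    def kernel_mean(a, b):
        # Pairwise squared distances via broadcasting; row counts may differ.
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return sum(np.exp(-d2 / (2.0 * bw)) for bw in bandwidths).mean()

    return kernel_mean(x, x) + kernel_mean(y, y) - 2.0 * kernel_mean(x, y)


def coral(x, y):
    """Squared Frobenius norm of the covariance difference, scaled by 1 / (4 d^2)."""
    d = x.shape[1]
    diff = np.cov(x, rowvar=False) - np.cov(y, rowvar=False)
    return (diff ** 2).sum() / (4 * d * d)


rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(256, 8))   # toy "source" latent features
same = rng.normal(0.0, 1.0, size=(200, 8))     # same distribution, smaller batch
shifted = rng.normal(1.5, 2.0, size=(200, 8))  # shifted mean and larger variance

mmd_same, mmd_shifted = mmd2_gaussian(source, same), mmd2_gaussian(source, shifted)
coral_same, coral_shifted = coral(source, same), coral(source, shifted)

print(f"MMD^2  same={mmd_same:.4f}  shifted={mmd_shifted:.4f}")
print(f"CORAL  same={coral_same:.4f}  shifted={coral_shifted:.4f}")
```

The biased MMD estimate is never exactly zero on finite batches because the diagonal kernel entries contribute a small positive term, so it is the gap between the two cases that matters rather than the absolute value.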
The Main Training Script
The training script ties the components together. Each step jointly optimizes the anomaly classifier on labeled source data, the reconstruction decoder on both domains, and the domain discriminator (adversarially) on both domains. The lambda schedule ramps up the strength of domain adaptation over training, following the formula from the original DANN paper: λ_p = 2 / (1 + exp(-γ · p)) - 1, where p is training progress from 0 to 1 and γ controls how quickly the ramp saturates.
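To get a feel for the shape of this schedule, the short sketch below evaluates the formula at a few values of p. The function name dann_lambda is chosen for this example only, and γ = 10 is assumed because it is the default argument of compute_dann_lambda in train.py; the value stored in the configuration may differ. The expression is algebraically equal to tanh(γ · p / 2), which the last line checks numerically.

```python
import math


def dann_lambda(p: float, gamma: float = 10.0) -> float:
    """Schedule value for training progress p in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0


for p in (0.0, 0.1, 0.25, 0.5, 1.0):
    print(f"p={p:.2f}  lambda={dann_lambda(p):.4f}")

# The same curve written as a hyperbolic tangent
assert abs(dann_lambda(0.3) - math.tanh(10.0 * 0.3 / 2.0)) < 1e-12
```

With γ = 10 the weight is exactly 0 at the start, reaches roughly 0.46 after 10% of training, and is already close to 0.99 at the halfway point, so most of the ramp happens in the first quarter of the run. A smaller γ spreads the ramp over more epochs.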
train.py
"""
train.py — Main training script for domain-adaptive anomaly detection.
Supports three adaptation methods: DANN, MMD, CORAL.
Uses progressive lambda scheduling for stable training.
"""
import argparse
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from losses import CombinedLoss
from utils import (
set_seed,
EarlyStopping,
save_checkpoint,
MetricLogger,
)
def compute_dann_lambda(epoch: int, total_epochs: int, gamma: float = 10.0) -> float:
"""
Progressive lambda schedule from the DANN paper (Ganin et al., 2016).
Ramps from 0 to 1 over training using a sigmoid-like schedule.
lambda_p = 2 / (1 + exp(-gamma * p)) - 1, where p = epoch / total_epochs
"""
p = epoch / total_epochs
return float(2.0 / (1.0 + np.exp(-gamma * p)) - 1.0)
def train_one_epoch(
model,
source_loader,
target_loader,
criterion,
optimizer,
device,
epoch,
total_epochs,
config,
):
"""Train for one epoch with domain adaptation."""
model.train()
epoch_losses = {"total": 0, "classification": 0, "reconstruction": 0, "domain": 0}
n_batches = 0
# Compute current domain adaptation lambda
current_lambda = compute_dann_lambda(epoch, total_epochs, config.gamma) * config.lambda_domain
# Set the GRL lambda in the model
model.set_domain_lambda(current_lambda)
# Iterate over source batches; restart the target loader whenever it is exhausted
target_iter = iter(target_loader)
for source_batch, source_labels in source_loader:
# Get target batch (cycle if exhausted)
try:
target_batch, _ = next(target_iter)
except StopIteration:
target_iter = iter(target_loader)
target_batch, _ = next(target_iter)
source_batch = source_batch.to(device)
source_labels = source_labels.to(device)
target_batch = target_batch.to(device)
# Determine actual batch sizes (may differ)
bs_s = source_batch.size(0)
bs_t = target_batch.size(0)
# Forward pass: source domain
s_anomaly_logits, s_recon, s_domain_logits, s_latent = model(source_batch)
# Forward pass: target domain
t_anomaly_logits, t_recon, t_domain_logits, t_latent = model(target_batch)
# Combine reconstructions and originals for loss
all_recon = torch.cat([s_recon, t_recon], dim=0)
all_original = torch.cat([source_batch, target_batch], dim=0)
# Domain labels: 0 for source, 1 for target
domain_labels = torch.cat([
torch.zeros(bs_s, device=device),
torch.ones(bs_t, device=device),
])
all_domain_logits = torch.cat([s_domain_logits, t_domain_logits], dim=0)
# Compute combined loss
total_loss, loss_dict = criterion(
anomaly_logits=s_anomaly_logits,
anomaly_labels=source_labels,
reconstruction=all_recon,
original=all_original,
domain_logits=all_domain_logits,
domain_labels=domain_labels,
source_features=s_latent,
target_features=t_latent,
current_lambda=current_lambda,
)
# Backprop
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# Accumulate losses
for key in epoch_losses:
epoch_losses[key] += loss_dict[key]
n_batches += 1
# Average losses
for key in epoch_losses:
epoch_losses[key] /= max(n_batches, 1)
epoch_losses["lambda"] = current_lambda
return epoch_losses
@torch.no_grad()
def validate(model, loader, criterion, device, config):
"""Validate on a labeled dataset (source test or target test)."""
model.eval()
all_logits = []
all_labels = []
total_recon_loss = 0
n_batches = 0
for batch, labels in loader:
batch = batch.to(device)
labels = labels.to(device)
anomaly_logits, recon, _, latent = model(batch)
recon_loss = nn.MSELoss()(recon, batch)
all_logits.append(anomaly_logits.squeeze(-1).cpu())
all_labels.append(labels.cpu())
total_recon_loss += recon_loss.item()
n_batches += 1
all_logits = torch.cat(all_logits)
all_labels = torch.cat(all_labels)
# Compute metrics
probs = torch.sigmoid(all_logits)
preds = (probs > 0.5).float()
accuracy = (preds == all_labels).float().mean().item()
from sklearn.metrics import roc_auc_score, f1_score
try:
auroc = roc_auc_score(all_labels.numpy(), probs.numpy())
except ValueError:
auroc = 0.5 # Only one class present
f1 = f1_score(all_labels.numpy(), preds.numpy(), zero_division=0)
return {
"accuracy": accuracy,
"auroc": auroc,
"f1": f1,
"recon_loss": total_recon_loss / max(n_batches, 1),
}
def main():
parser = argparse.ArgumentParser(description="Train domain-adaptive anomaly detector")
parser.add_argument("--method", type=str, default="dann",
choices=["dann", "mmd", "coral"],
help="Domain adaptation method")
parser.add_argument("--epochs", type=int, default=None)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--lr", type=float, default=None)
parser.add_argument("--lambda_domain", type=float, default=None)
parser.add_argument("--lambda_recon", type=float, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--device", type=str, default=None)
args = parser.parse_args()
# Build config with CLI overrides
config = Config()
config.adaptation_method = args.method
if args.epochs is not None:
config.epochs = args.epochs
if args.batch_size is not None:
config.batch_size = args.batch_size
if args.lr is not None:
config.learning_rate = args.lr
if args.lambda_domain is not None:
config.lambda_domain = args.lambda_domain
if args.lambda_recon is not None:
config.lambda_recon = args.lambda_recon
if args.seed is not None:
config.seed = args.seed
if args.data_dir is not None:
config.data_dir = args.data_dir
if args.device is not None:
config.device = args.device
# Setup
set_seed(config.seed)
device = torch.device(config.device)
print(f"Using device: {device}")
print(f"Adaptation method: {config.adaptation_method}")
print(f"Epochs: {config.epochs}, Batch size: {config.batch_size}, LR: {config.learning_rate}")
# Data
print("\nLoading data...")
loaders, normalizer = create_data_loaders(config)
print(f"Source train batches: {len(loaders['source_train'])}")
print(f"Target train batches: {len(loaders['target_train'])}")
# Model
model = DomainAdaptiveAnomalyDetector(config).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {total_params:,}")
# Optimizer (single optimizer for simplicity; separate LRs via param groups)
optimizer = Adam([
{"params": model.encoder.parameters(), "lr": config.learning_rate},
{"params": model.classifier.parameters(), "lr": config.learning_rate},
{"params": model.decoder.parameters(), "lr": config.learning_rate},
{"params": model.discriminator.parameters(), "lr": config.discriminator_lr},
], weight_decay=config.weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs, eta_min=1e-6)
# Loss
criterion = CombinedLoss(config)
# Early stopping
early_stopping = EarlyStopping(patience=config.patience, mode="max")
# Logging
logger = MetricLogger(config.results_dir)
# Training loop
best_target_auroc = 0.0
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)
for epoch in range(config.epochs):
start_time = time.time()
# Train
train_losses = train_one_epoch(
model, loaders["source_train"], loaders["target_train"],
criterion, optimizer, device, epoch, config.epochs, config
)
# Validate on source test
source_metrics = validate(model, loaders["source_test"], criterion, device, config)
# Evaluate on target test (the real metric we care about)
target_metrics = validate(model, loaders["target_test"], criterion, device, config)
scheduler.step()
elapsed = time.time() - start_time
# Log
logger.log(epoch, train_losses, source_metrics, target_metrics)
# Print progress
if epoch % 5 == 0 or epoch == config.epochs - 1:
print(
f"Epoch {epoch:3d}/{config.epochs} ({elapsed:.1f}s) | "
f"Loss: {train_losses['total']:.4f} "
f"[cls={train_losses['classification']:.4f}, "
f"rec={train_losses['reconstruction']:.4f}, "
f"dom={train_losses['domain']:.4f}] | "
f"λ={train_losses['lambda']:.3f} | "
f"Src AUROC: {source_metrics['auroc']:.4f} | "
f"Tgt AUROC: {target_metrics['auroc']:.4f}"
)
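# Save best model (based on target AUROC; this uses target test labels,
# so the selected checkpoint's target score is an optimistic estimate)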
if target_metrics["auroc"] > best_target_auroc:
best_target_auroc = target_metrics["auroc"]
save_checkpoint(
model, optimizer, epoch, target_metrics,
os.path.join(config.checkpoint_dir, "best_model.pt")
)
# Early stopping on target AUROC
if early_stopping.step(target_metrics["auroc"]):
print(f"\nEarly stopping triggered at epoch {epoch}")
break
print("\n" + "=" * 60)
print(f"Training complete. Best target AUROC: {best_target_auroc:.4f}")
print(f"Best model saved to: {config.checkpoint_dir}/best_model.pt")
print("=" * 60)
# Save training curves
logger.save()
logger.plot_training_curves()
if __name__ == "__main__":
main()
Evaluation and Metrics
After training, rigorous evaluation on the target domain is required. The evaluation script computes standard anomaly-detection metrics, combines classifier and reconstruction scores, implements multiple threshold strategies, and produces diagnostic plots. This is the stage at which the success of domain adaptation can be assessed.
evaluate.py
"""
evaluate.py — Evaluation script for domain-adaptive anomaly detection.
Loads a trained model and evaluates on target domain test data.
Computes AUROC, AUPRC, F1, precision, recall.
Generates diagnostic plots and saves results to JSON.
"""
import argparse
import json
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import (
roc_auc_score,
average_precision_score,
f1_score,
precision_score,
recall_score,
accuracy_score,
confusion_matrix,
roc_curve,
precision_recall_curve,
)
from config import Config
from dataset import create_data_loaders
from model import DomainAdaptiveAnomalyDetector
from utils import set_seed, load_checkpoint
def compute_anomaly_scores(model, loader, device, alpha=0.7):
"""
Compute anomaly scores combining classifier output and reconstruction error.
anomaly_score = alpha * classifier_prob + (1 - alpha) * normalized_recon_error
Returns:
scores: numpy array of anomaly scores
labels: numpy array of ground truth labels
recon_errors: numpy array of per-sample reconstruction errors
classifier_probs: numpy array of classifier probabilities
latent_features: numpy array of latent features (for t-SNE)
"""
model.eval()
all_probs = []
all_labels = []
all_recon_errors = []
all_latent = []
with torch.no_grad():
for batch, labels in loader:
batch = batch.to(device)
anomaly_logits, recon, _, latent = model(batch)
# Classifier probability
probs = torch.sigmoid(anomaly_logits.squeeze(-1))
# Per-sample reconstruction error (mean across features and time)
recon_error = ((recon - batch) ** 2).mean(dim=(1, 2))
all_probs.append(probs.cpu().numpy())
all_labels.append(labels.numpy())
all_recon_errors.append(recon_error.cpu().numpy())
all_latent.append(latent.cpu().numpy())
all_probs = np.concatenate(all_probs)
all_labels = np.concatenate(all_labels)
all_recon_errors = np.concatenate(all_recon_errors)
all_latent = np.concatenate(all_latent)
# Normalize reconstruction errors to [0, 1]
re_min, re_max = all_recon_errors.min(), all_recon_errors.max()
if re_max - re_min > 1e-8:
norm_recon = (all_recon_errors - re_min) / (re_max - re_min)
else:
norm_recon = np.zeros_like(all_recon_errors)
# Combined anomaly score
scores = alpha * all_probs + (1 - alpha) * norm_recon
return scores, all_labels, all_recon_errors, all_probs, all_latent
def find_optimal_threshold(labels, scores):
"""Find the threshold that maximizes F1 score."""
thresholds = np.linspace(0, 1, 200)
best_f1 = 0
best_thresh = 0.5
for thresh in thresholds:
preds = (scores >= thresh).astype(int)
f1 = f1_score(labels, preds, zero_division=0)
if f1 > best_f1:
best_f1 = f1
best_thresh = thresh
return best_thresh, best_f1
def compute_all_metrics(labels, scores, threshold):
"""Compute all evaluation metrics at a given threshold."""
preds = (scores >= threshold).astype(int)
metrics = {
"auroc": float(roc_auc_score(labels, scores)),
"auprc": float(average_precision_score(labels, scores)),
"f1": float(f1_score(labels, preds, zero_division=0)),
"precision": float(precision_score(labels, preds, zero_division=0)),
"recall": float(recall_score(labels, preds, zero_division=0)),
"accuracy": float(accuracy_score(labels, preds)),
"threshold": float(threshold),
}
# Fix the label order so the matrix is 2x2 even if one class is absent
cm = confusion_matrix(labels, preds, labels=[0, 1])
metrics["confusion_matrix"] = cm.tolist()
metrics["true_negatives"] = int(cm[0, 0])
metrics["false_positives"] = int(cm[0, 1])
metrics["false_negatives"] = int(cm[1, 0])
metrics["true_positives"] = int(cm[1, 1])
return metrics
def plot_roc_curve(labels, scores, save_path):
"""Plot and save ROC curve."""
fpr, tpr, _ = roc_curve(labels, scores)
auroc = roc_auc_score(labels, scores)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, "b-", linewidth=2, label=f"AUROC = {auroc:.4f}")
ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Random")
ax.set_xlabel("False Positive Rate", fontsize=12)
ax.set_ylabel("True Positive Rate", fontsize=12)
ax.set_title("ROC Curve — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"ROC curve saved to {save_path}")
def plot_pr_curve(labels, scores, save_path):
"""Plot and save Precision-Recall curve."""
precision, recall, _ = precision_recall_curve(labels, scores)
auprc = average_precision_score(labels, scores)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall, precision, "r-", linewidth=2, label=f"AUPRC = {auprc:.4f}")
baseline = labels.sum() / len(labels)
ax.axhline(y=baseline, color="k", linestyle="--", alpha=0.5, label=f"Baseline = {baseline:.3f}")
ax.set_xlabel("Recall", fontsize=12)
ax.set_ylabel("Precision", fontsize=12)
ax.set_title("Precision-Recall Curve — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"PR curve saved to {save_path}")
def plot_score_distribution(labels, scores, threshold, save_path):
"""Plot anomaly score distribution for normal vs anomaly samples."""
fig, ax = plt.subplots(figsize=(10, 6))
normal_scores = scores[labels == 0]
anomaly_scores = scores[labels == 1]
ax.hist(normal_scores, bins=50, alpha=0.6, color="steelblue", label="Normal", density=True)
ax.hist(anomaly_scores, bins=50, alpha=0.6, color="indianred", label="Anomaly", density=True)
ax.axvline(x=threshold, color="black", linestyle="--", linewidth=2,
label=f"Threshold = {threshold:.3f}")
ax.set_xlabel("Anomaly Score", fontsize=12)
ax.set_ylabel("Density", fontsize=12)
ax.set_title("Anomaly Score Distribution — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"Score distribution saved to {save_path}")
def plot_reconstruction_error(recon_errors, labels, save_path):
"""Plot reconstruction error over sample index, colored by label."""
fig, ax = plt.subplots(figsize=(14, 5))
indices = np.arange(len(recon_errors))
normal_mask = labels == 0
anomaly_mask = labels == 1
ax.scatter(indices[normal_mask], recon_errors[normal_mask],
s=2, alpha=0.4, c="steelblue", label="Normal")
ax.scatter(indices[anomaly_mask], recon_errors[anomaly_mask],
s=8, alpha=0.8, c="indianred", label="Anomaly")
ax.set_xlabel("Sample Index", fontsize=12)
ax.set_ylabel("Reconstruction Error", fontsize=12)
ax.set_title("Reconstruction Error Over Time — Target Domain", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(save_path, dpi=150)
plt.close(fig)
print(f"Reconstruction error plot saved to {save_path}")
def main():
parser = argparse.ArgumentParser(description="Evaluate domain-adaptive anomaly detector")
parser.add_argument("--checkpoint", type=str,
default="checkpoints/best_model.pt",
help="Path to model checkpoint")
parser.add_argument("--data_dir", type=str, default="data",
help="Data directory")
parser.add_argument("--results_dir", type=str, default="results",
help="Output directory for results")
parser.add_argument("--alpha", type=float, default=0.7,
help="Weight for classifier score vs recon error")
parser.add_argument("--method", type=str, default="dann",
choices=["dann", "mmd", "coral"])
parser.add_argument("--device", type=str, default="")
args = parser.parse_args()
config = Config()
config.data_dir = args.data_dir
config.results_dir = args.results_dir
config.adaptation_method = args.method
if args.device:
config.device = args.device
set_seed(config.seed)
device = torch.device(config.device)
os.makedirs(config.results_dir, exist_ok=True)
print(f"Device: {device}")
print(f"Loading checkpoint: {args.checkpoint}")
# Load model
model = DomainAdaptiveAnomalyDetector(config).to(device)
checkpoint = load_checkpoint(args.checkpoint, model, device=device)
print(f"Loaded model from epoch {checkpoint.get('epoch', '?')}")
# Load data
loaders, normalizer = create_data_loaders(config)
# --- Evaluate on target test set ---
print("\n--- Target Domain Evaluation ---")
scores, labels, recon_errors, probs, latent_features = compute_anomaly_scores(
model, loaders["target_test"], device, alpha=args.alpha
)
# Find optimal threshold
optimal_thresh, optimal_f1 = find_optimal_threshold(labels, scores)
print(f"Optimal threshold: {optimal_thresh:.4f} (F1 = {optimal_f1:.4f})")
# Percentile-based threshold
percentile_thresh = np.percentile(scores, config.anomaly_threshold_percentile)
print(f"Percentile ({config.anomaly_threshold_percentile}%) threshold: {percentile_thresh:.4f}")
# Compute metrics at optimal threshold
metrics_optimal = compute_all_metrics(labels, scores, optimal_thresh)
metrics_optimal["threshold_method"] = "f1_optimal"
# Compute metrics at percentile threshold
metrics_percentile = compute_all_metrics(labels, scores, percentile_thresh)
metrics_percentile["threshold_method"] = "percentile"
# Print results
print(f"\n{'Metric':<20} {'F1-Optimal':>12} {'Percentile':>12}")
print("-" * 46)
for key in ["auroc", "auprc", "f1", "precision", "recall", "accuracy"]:
print(f"{key:<20} {metrics_optimal[key]:>12.4f} {metrics_percentile[key]:>12.4f}")
# Also evaluate on source test for comparison
print("\n--- Source Domain Evaluation (baseline) ---")
src_scores, src_labels, _, _, src_latent = compute_anomaly_scores(
model, loaders["source_test"], device, alpha=args.alpha
)
src_thresh, _ = find_optimal_threshold(src_labels, src_scores)
src_metrics = compute_all_metrics(src_labels, src_scores, src_thresh)
print(f"Source AUROC: {src_metrics['auroc']:.4f}, F1: {src_metrics['f1']:.4f}")
# --- Generate plots ---
print("\nGenerating plots...")
plot_roc_curve(labels, scores, os.path.join(config.results_dir, "roc_curve.png"))
plot_pr_curve(labels, scores, os.path.join(config.results_dir, "pr_curve.png"))
plot_score_distribution(labels, scores, optimal_thresh,
os.path.join(config.results_dir, "score_distribution.png"))
plot_reconstruction_error(recon_errors, labels,
os.path.join(config.results_dir, "recon_error.png"))
# --- Save results ---
results = {
"method": config.adaptation_method,
"alpha": args.alpha,
"target_metrics_optimal": metrics_optimal,
"target_metrics_percentile": metrics_percentile,
"source_metrics": src_metrics,
}
results_path = os.path.join(config.results_dir, "evaluation_results.json")
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to {results_path}")
if __name__ == "__main__":
main()
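The two threshold strategies used in evaluate.py behave quite differently, and a toy example makes the contrast easier to see. The sketch below fuses synthetic classifier probabilities and reconstruction errors with the same alpha-weighted formula, then compares an F1-optimal threshold (which needs labels) with a percentile threshold (which does not). Everything here is an assumption for illustration: the Beta and Gamma toy distributions, the 5% anomaly rate, the helper names fuse and f1_at, and the 95th percentile, which stands in for whatever anomaly_threshold_percentile is set to in the configuration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_normal, n_anomaly = 950, 50
labels = np.concatenate([np.zeros(n_normal, dtype=int), np.ones(n_anomaly, dtype=int)])

# Hypothetical stand-ins for model outputs (not produced by the real model)
probs = np.concatenate([rng.beta(2, 8, n_normal), rng.beta(6, 3, n_anomaly)])
recon = np.concatenate([rng.gamma(2.0, 0.05, n_normal), rng.gamma(4.0, 0.10, n_anomaly)])


def fuse(probs, recon, alpha=0.7):
    """alpha * classifier probability + (1 - alpha) * min-max normalized error."""
    span = recon.max() - recon.min()
    norm = (recon - recon.min()) / span if span > 1e-8 else np.zeros_like(recon)
    return alpha * probs + (1 - alpha) * norm


def f1_at(labels, scores, thresh):
    """F1 score when every sample with score >= thresh is flagged."""
    preds = scores >= thresh
    tp = int(np.sum(preds & (labels == 1)))
    fp = int(np.sum(preds & (labels == 0)))
    fn = int(np.sum(~preds & (labels == 1)))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


scores = fuse(probs, recon)

# Strategy 1: grid search for the best F1 (requires ground-truth labels)
grid = np.linspace(0, 1, 200)
f1_thresh = max(grid, key=lambda t: f1_at(labels, scores, t))

# Strategy 2: flag the top 5% of scores (label-free)
pct_thresh = np.percentile(scores, 95.0)

for name, t in (("f1_optimal", f1_thresh), ("percentile", pct_thresh)):
    flagged = float((scores >= t).mean())
    print(f"{name:<11} thresh={t:.3f}  F1={f1_at(labels, scores, t):.3f}  flagged={flagged:.1%}")
```

The percentile threshold is the one that can be set on an unlabeled target domain, but it implicitly assumes an anomaly rate: it always flags the same fraction of samples regardless of how many true anomalies are present. The F1-optimal threshold is best read as an upper bound on what a well-chosen threshold could achieve.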
Utility Functions
The utility module handles reproducibility, early stopping, checkpointing, metric logging, and visualization, including t-SNE plots of feature distributions.
utils.py
"""
utils.py — Utility functions for the DA anomaly detection pipeline.
Includes:
- Seed setting for reproducibility
- EarlyStopping class
- Checkpoint save/load
- MetricLogger with JSON output and plotting
- t-SNE visualization of domain features
"""
import os
import random
import json
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def set_seed(seed: int = 42):
"""Set random seeds for reproducibility across all libraries."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class EarlyStopping:
"""
Early stopping to halt training when a metric stops improving.
Args:
patience: number of epochs to wait before stopping
mode: 'min' or 'max' — whether lower or higher is better
min_delta: minimum improvement to count as progress
"""
def __init__(self, patience: int = 15, mode: str = "max", min_delta: float = 1e-4):
self.patience = patience
self.mode = mode
self.min_delta = min_delta
self.counter = 0
self.best_value = None
def step(self, value: float) -> bool:
"""
Check if training should stop.
Args:
value: current metric value
Returns:
True if training should stop
"""
if self.best_value is None:
self.best_value = value
return False
if self.mode == "max":
improved = value > self.best_value + self.min_delta
else:
improved = value < self.best_value - self.min_delta
if improved:
self.best_value = value
self.counter = 0
else:
self.counter += 1
return self.counter >= self.patience
def save_checkpoint(model, optimizer, epoch, metrics, filepath):
"""Save model checkpoint."""
# dirname is "" for a bare filename, which os.makedirs rejects; fall back to "."
os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"metrics": metrics,
}, filepath)
def load_checkpoint(filepath, model, optimizer=None, device="cpu"):
"""Load model checkpoint."""
checkpoint = torch.load(filepath, map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
if optimizer is not None and "optimizer_state_dict" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
return checkpoint
class MetricLogger:
"""
Logs training metrics to memory and saves them to JSON.
Also generates training curve plots.
"""
def __init__(self, output_dir: str = "results"):
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
self.history = {
"epoch": [],
"train_total_loss": [],
"train_cls_loss": [],
"train_recon_loss": [],
"train_domain_loss": [],
"train_lambda": [],
"source_auroc": [],
"source_f1": [],
"target_auroc": [],
"target_f1": [],
}
def log(self, epoch, train_losses, source_metrics, target_metrics):
"""Record one epoch of metrics."""
self.history["epoch"].append(epoch)
self.history["train_total_loss"].append(train_losses["total"])
self.history["train_cls_loss"].append(train_losses["classification"])
self.history["train_recon_loss"].append(train_losses["reconstruction"])
self.history["train_domain_loss"].append(train_losses["domain"])
self.history["train_lambda"].append(train_losses.get("lambda", 0))
self.history["source_auroc"].append(source_metrics["auroc"])
self.history["source_f1"].append(source_metrics["f1"])
self.history["target_auroc"].append(target_metrics["auroc"])
self.history["target_f1"].append(target_metrics["f1"])
def save(self):
"""Save metrics history to JSON."""
path = os.path.join(self.output_dir, "training_history.json")
with open(path, "w") as f:
json.dump(self.history, f, indent=2)
print(f"Training history saved to {path}")
def plot_training_curves(self):
"""Generate and save training curve plots."""
epochs = self.history["epoch"]
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Loss curves
ax = axes[0, 0]
ax.plot(epochs, self.history["train_total_loss"], label="Total", linewidth=2)
ax.plot(epochs, self.history["train_cls_loss"], label="Classification", linewidth=1.5)
ax.plot(epochs, self.history["train_recon_loss"], label="Reconstruction", linewidth=1.5)
ax.plot(epochs, self.history["train_domain_loss"], label="Domain", linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Losses")
ax.legend()
ax.grid(True, alpha=0.3)
# AUROC
ax = axes[0, 1]
ax.plot(epochs, self.history["source_auroc"], label="Source AUROC", linewidth=2)
ax.plot(epochs, self.history["target_auroc"], label="Target AUROC", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("AUROC")
ax.set_title("AUROC Over Training")
ax.legend()
ax.grid(True, alpha=0.3)
# F1
ax = axes[1, 0]
ax.plot(epochs, self.history["source_f1"], label="Source F1", linewidth=2)
ax.plot(epochs, self.history["target_f1"], label="Target F1", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("F1 Score")
ax.set_title("F1 Score Over Training")
ax.legend()
ax.grid(True, alpha=0.3)
# Lambda schedule
ax = axes[1, 1]
ax.plot(epochs, self.history["train_lambda"], label="Domain λ", linewidth=2,
color="purple")
ax.set_xlabel("Epoch")
ax.set_ylabel("Lambda Value")
ax.set_title("Domain Adaptation Lambda Schedule")
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
path = os.path.join(self.output_dir, "training_curves.png")
fig.savefig(path, dpi=150)
plt.close(fig)
print(f"Training curves saved to {path}")
def plot_tsne_features(
    source_features: np.ndarray,
    target_features: np.ndarray,
    save_path: str,
    title: str = "t-SNE Feature Visualization",
    max_samples: int = 2000,
):
    """
    Create t-SNE plot showing source vs target feature distributions.

    Args:
        source_features: (n, d) source latent features
        target_features: (m, d) target latent features
        save_path: path to save the plot
        title: plot title
        max_samples: max samples per domain (for speed)
    """
    from sklearn.manifold import TSNE

    # Subsample if needed
    if len(source_features) > max_samples:
        idx = np.random.choice(len(source_features), max_samples, replace=False)
        source_features = source_features[idx]
    if len(target_features) > max_samples:
        idx = np.random.choice(len(target_features), max_samples, replace=False)
        target_features = target_features[idx]

    # Combine and run t-SNE
    combined = np.concatenate([source_features, target_features], axis=0)
    n_source = len(source_features)
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    embedded = tsne.fit_transform(combined)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(embedded[:n_source, 0], embedded[:n_source, 1],
               s=10, alpha=0.5, c="steelblue", label="Source")
    ax.scatter(embedded[n_source:, 0], embedded[n_source:, 1],
               s=10, alpha=0.5, c="indianred", label="Target")
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"t-SNE plot saved to {save_path}")
Running the Full Pipeline
With all nine scripts in place, the complete workflow from data generation to final evaluation proceeds as follows. The commands below should be executed in order from the da-anomaly-detection/ directory.
Step-by-Step Commands
# Step 1: Install dependencies
pip install -r requirements.txt
# Step 2: Generate synthetic two-domain data
python generate_synthetic_data.py --output_dir data/ --n_samples 20000
# Step 3: Train with DANN (Domain-Adversarial Neural Network)
python train.py --method dann --epochs 100 --batch_size 64 --lr 0.001
# Step 4: Evaluate on target domain
python evaluate.py --checkpoint checkpoints/best_model.pt --data_dir data/ --method dann
# (Optional) Step 5: Train with MMD instead
python train.py --method mmd --epochs 100 --batch_size 64
# (Optional) Step 6: Train with CORAL instead
python train.py --method coral --epochs 100 --batch_size 64
Each training run reports progress every five epochs, saves the best model checkpoint based on target-domain AUROC, and writes training curves to the results/ directory. The evaluation script generates ROC curves, PR curves, score distribution histograms, and reconstruction-error time plots.
Understanding the Results
Once the pipeline has been executed, a results/evaluation_results.json file contains the numerical outputs. Interpreting those numbers and determining whether domain adaptation is actually helping requires familiarity with the relevant metrics.
Interpreting the Evaluation Metrics
AUROC (Area Under the ROC Curve) is the primary metric. It expresses the probability that a randomly chosen anomaly scores higher than a randomly chosen normal sample. An AUROC of 0.5 corresponds to random performance and 1.0 to perfect discrimination. For domain adaptation to be regarded as successful, the target-domain AUROC should be significantly higher than the no-adaptation baseline (training only on source data and evaluating on target data without adaptation).
AUPRC (Area Under the Precision-Recall Curve) is more informative when anomalies are rare. In a highly imbalanced dataset with a 1 percent anomaly rate, AUROC can look favorable even when most alerts are false alarms, because the false positive rate is measured against the very large normal class. AUPRC is built on precision, so it penalizes those false positives more strongly.
F1 Score is the harmonic mean of precision and recall computed at the optimal threshold. It provides a single value that balances false positives and false negatives. For industrial applications, recall (not missing anomalies) is typically prioritized over precision, since some false alarms are tolerable.
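The pairwise definition of AUROC and the threshold search behind a best-case F1 score can be checked by hand on a toy example. The sketch below is pure Python and is not taken from evaluate.py; the names `pairwise_auroc` and `best_f1` are illustrative, and a real evaluation would normally rely on a metrics library rather than these quadratic loops.

```python
def pairwise_auroc(scores, labels):
    """AUROC as P(anomaly score > normal score); ties count as 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one anomaly and one normal sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def best_f1(scores, labels):
    """Try every observed score as a threshold; return (best F1, threshold)."""
    best = (0.0, None)
    for thr in sorted(set(scores)):
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
        fn = sum(1 for s, y in zip(scores, labels) if s < thr and y == 1)
        if tp == 0:
            continue  # F1 is zero (or undefined) without true positives
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        if f1 > best[0]:
            best = (f1, thr)
    return best


# Toy data: 2 anomalies among 8 windows
scores = [0.1, 0.2, 0.15, 0.3, 0.8, 0.25, 0.6, 0.5]
labels = [0, 0, 0, 0, 1, 0, 0, 1]
auroc = pairwise_auroc(scores, labels)
f1, thr = best_f1(scores, labels)
print(f"AUROC={auroc:.3f}  best F1={f1:.3f} at threshold {thr}")
```

In this toy set, one normal window (score 0.6) outranks one anomaly (score 0.5), so 11 of the 12 anomaly-normal pairs are ordered correctly. Note that choosing the threshold on the same data used for reporting gives an optimistic F1; a held-out split is the safer choice.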
What Good vs. Bad Domain Adaptation Looks Like
| Scenario | Source AUROC | Target AUROC (no adapt) | Target AUROC (with DA) | Interpretation |
|---|---|---|---|---|
| Successful adaptation | 0.95 | 0.62 | 0.87 | Domain adaptation recovered most performance |
| Negative transfer | 0.95 | 0.65 | 0.58 | DA made things worse; domains may be too different |
| No domain shift | 0.93 | 0.91 | 0.92 | Little domain shift exists; DA not needed |
| Partial adaptation | 0.95 | 0.55 | 0.72 | DA helps but gap remains; try tuning or more target data |
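One way to compare the rows of this table on a common scale is the fraction of the source-to-target gap that adaptation closes. The helper below is a hypothetical convenience, not part of the nine scripts, and the name `gap_recovered` is an assumption made for illustration.

```python
def gap_recovered(source_auroc, target_no_adapt, target_with_da):
    """Fraction of the source-to-target AUROC gap closed by adaptation.

    1.0 means target performance matches source, 0.0 means no change,
    and negative values indicate negative transfer.
    """
    gap = source_auroc - target_no_adapt
    if gap <= 0:
        raise ValueError("no source-to-target gap to recover")
    return (target_with_da - target_no_adapt) / gap


# Rows from the table above: (source, target without DA, target with DA)
scenarios = {
    "Successful adaptation": (0.95, 0.62, 0.87),
    "Negative transfer": (0.95, 0.65, 0.58),
    "No domain shift": (0.93, 0.91, 0.92),
    "Partial adaptation": (0.95, 0.55, 0.72),
}
for name, row in scenarios.items():
    print(f"{name:<22} gap recovered: {gap_recovered(*row):+.2f}")
```

When the initial gap is small, as in the "No domain shift" row, the ratio is dominated by noise and the absolute AUROC values are the more meaningful comparison.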
Interpreting t-SNE Plots
The t-SNE visualization is the most intuitive diagnostic tool available. It should be applied to the latent features before and after domain adaptation.
- Before adaptation: Two distinct clusters typically appear, with source samples grouped in one region and target samples in another. This visual separation confirms that domain shift exists in the data.
- After successful adaptation: The source and target clusters overlap substantially. The encoder has learned features that appear consistent regardless of which domain produced the input. If the anomaly classifier performs well on source features, it should now perform well on the overlapping target features as well.
- After failed adaptation: Clusters remain separated, or in more severe cases the representation collapses to a single point, indicating that the encoder has collapsed to a trivial, uninformative feature rather than an aligned one.
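t-SNE plots are qualitative, and their appearance depends on perplexity and random seed, so a numeric check on the same latent features can be a useful complement. The sketch below is an illustrative addition rather than part of the nine scripts: `domain_separability` and `mean_feature_std` are hypothetical helper names, and the nearest-centroid classifier is a deliberately crude stand-in for a trained domain probe.

```python
import numpy as np


def domain_separability(source_features, target_features):
    """Accuracy of a nearest-centroid domain classifier on latent features.

    About 0.5 means the two domains are hard to tell apart (well aligned);
    values near 1.0 mean they still occupy separate regions.
    """
    mu_s = source_features.mean(axis=0)
    mu_t = target_features.mean(axis=0)
    feats = np.concatenate([source_features, target_features], axis=0)
    is_target = np.concatenate([
        np.zeros(len(source_features), dtype=bool),
        np.ones(len(target_features), dtype=bool),
    ])
    d_s = np.linalg.norm(feats - mu_s, axis=1)  # distance to source centroid
    d_t = np.linalg.norm(feats - mu_t, axis=1)  # distance to target centroid
    pred_target = d_t < d_s
    return float((pred_target == is_target).mean())


def mean_feature_std(features):
    """Average per-dimension standard deviation; near zero suggests collapse."""
    return float(features.std(axis=0).mean())


# Synthetic stand-ins for encoder outputs (assumed shape: samples x latent_dim)
rng = np.random.default_rng(0)
src = rng.normal(0.0, 1.0, size=(500, 16))
tgt_shifted = rng.normal(3.0, 1.0, size=(500, 16))  # clearly separated domain
tgt_aligned = rng.normal(0.0, 1.0, size=(500, 16))  # same distribution as src

sep_shifted = domain_separability(src, tgt_shifted)
sep_aligned = domain_separability(src, tgt_aligned)
print("separated domains:", sep_shifted)
print("aligned domains:  ", sep_aligned)
print("feature spread:   ", mean_feature_std(src))
```

Because the centroids are estimated on the same samples that are then classified, the aligned case lands slightly above 0.5 rather than exactly on it. A spread close to zero alongside a separability near 0.5 points to collapse rather than genuine alignment.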
When to Use DANN, MMD, or CORAL
| Method | Mechanism | Strengths | Weaknesses | Best For |
|---|---|---|---|---|
| DANN | Adversarial training via GRL | Powerful; learns complex alignment | Unstable training; sensitive to hyperparameters | Large domain shifts; enough training data |
| MMD | Kernel-based distribution matching | Stable training; mathematically principled | Expensive for large batches; kernel selection matters | Moderate domain shifts; limited compute |
| CORAL | Covariance matrix alignment | Simple; fast; no extra hyperparameters | Only matches second-order statistics | Small domain shifts; quick baseline |
Adapting to Custom Data
The synthetic data set serves only as a sandbox. The following steps describe how to integrate proprietary time-series data with minimal code changes.
Modifying dataset.py for a Specific Data Format
The CSV files should follow this structure: each row corresponds to a timestep, and each column other than label and timestamp corresponds to a sensor channel. The column names are unimportant as long as label and timestamp are named correctly or absent entirely. For data that uses a different format, the load_csv_data() function can be modified as follows.
# Example: your data has columns named 'temp_1', 'temp_2', 'vibration_x', etc.
# and uses 'anomaly' instead of 'label'
def load_csv_data(filepath, has_labels=True):
    df = pd.read_csv(filepath)
    exclude = ["anomaly", "timestamp", "machine_id", "date"]
    feature_cols = [c for c in df.columns if c not in exclude]
    data = df[feature_cols].values.astype(np.float32)
    labels = df["anomaly"].values.astype(np.float32) if has_labels else None
    return data, labels
Adjusting Model Dimensions
For data with a different number of channels, only num_features in config.py needs to change. The model adjusts automatically. For different sampling rates, the window_size should be adjusted; as a rule of thumb, the window should span roughly one cycle of the normal operating pattern. For a machine cycling every 5 seconds sampled at 100 Hz, window_size=500 is appropriate. For slow processes such as daily patterns at hourly sampling, window_size=24 is appropriate.
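The rule of thumb reduces to multiplying the cycle length by the sampling rate. The helper below is a minimal sketch of that arithmetic; `window_size_for_cycle` is a hypothetical name and is not defined in config.py.

```python
def window_size_for_cycle(cycle_seconds, sample_rate_hz):
    """Number of timesteps spanning one operating cycle."""
    if cycle_seconds <= 0 or sample_rate_hz <= 0:
        raise ValueError("cycle length and sampling rate must be positive")
    return max(1, round(cycle_seconds * sample_rate_hz))


# 5-second machine cycle sampled at 100 Hz
fast_window = window_size_for_cycle(5, 100)
# Daily pattern (86,400 s) sampled once per hour (1/3600 Hz)
slow_window = window_size_for_cycle(24 * 3600, 1 / 3600)
print(fast_window, slow_window)
```

The two calls reproduce the values quoted above, 500 and 24.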
Handling Class Imbalance
Real anomaly data is heavily imbalanced, often with anomaly rates of 1 percent or less. Three strategies are effective within this codebase.
- Weighted BCE loss: Replace BCEWithLogitsLoss() with BCEWithLogitsLoss(pos_weight=torch.tensor([19.0])), where 19.0 is the ratio of normal to anomaly samples at a 5 percent anomaly rate. At a 1 percent anomaly rate the corresponding ratio is 99.0.
- Focal loss: Down-weights easy, well-classified examples so that training concentrates on the hard ones. Use it in place of the BCE term in AnomalyDetectionLoss.
- Oversampling: Use PyTorch’s WeightedRandomSampler to oversample anomaly windows in the source training loader.
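Both the positive-class weight and the sampler weights can be derived directly from the source labels instead of being hard-coded. The sketch below is pure Python so that the arithmetic is visible; the function names are hypothetical and do not appear in the codebase.

```python
def pos_weight_from_labels(labels):
    """Ratio of normal to anomalous samples, for use as a BCE pos_weight."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0:
        raise ValueError("no anomalies in labels; pos_weight is undefined")
    return n_neg / n_pos


def balanced_sample_weights(labels):
    """Per-sample weights so that each class is drawn about equally often."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("both classes must be present")
    return [1.0 / n_pos if y == 1 else 1.0 / n_neg for y in labels]


# 5 percent anomaly rate: 50 anomalous windows out of 1,000
labels = [1] * 50 + [0] * 950
pos_weight = pos_weight_from_labels(labels)   # 950 / 50
weights = balanced_sample_weights(labels)
anomaly_mass = sum(w for w, y in zip(weights, labels) if y == 1)
print(pos_weight, anomaly_mass / sum(weights))
```

The resulting `pos_weight` can be wrapped as `torch.tensor([pos_weight])` for `BCEWithLogitsLoss`, and `weights` can be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`. Combining both at full strength tends to over-correct, so one of the two is usually enough.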
Hyperparameter Tuning Guide
The hyperparameters below are ordered by sensitivity, with the most sensitive listed first.
- lambda_domain (0.1–2.0): The most sensitive parameter. Excessively high values cause the encoder to learn domain-invariant features that are uninformative for anomaly detection. Excessively low values prevent any adaptation. A value of 0.5 is a reasonable starting point.
- learning_rate (1e-4–1e-2): Standard neural-network tuning. Cosine annealing is recommended.
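- window_size (32–256 as a typical range): Should capture sufficient context for anomalies to be visible. The one-cycle rule of thumb under Adjusting Model Dimensions takes precedence when it yields a value outside this range.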
- latent_dim (64–256): Higher values provide more capacity but increase the risk of overfitting.
- alpha (0.5–0.9): Controls the mixture used in anomaly scoring. Higher values place more weight on the classifier output; lower values emphasize reconstruction error.
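The effect of alpha is easiest to see on a single window where the two signals disagree. The sketch below assumes the score is a convex combination of the classifier probability and a min-max scaled reconstruction error; the exact formula and normalization used in the evaluation script may differ, and `anomaly_score`, `recon_min`, and `recon_max` are illustrative names.

```python
def anomaly_score(cls_prob, recon_error, recon_min, recon_max, alpha=0.7):
    """Blend classifier probability with min-max scaled reconstruction error."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    span = recon_max - recon_min
    recon_scaled = 0.0 if span <= 0 else (recon_error - recon_min) / span
    recon_scaled = min(1.0, max(0.0, recon_scaled))  # clip unseen extremes
    return alpha * cls_prob + (1.0 - alpha) * recon_scaled


# The classifier is unsure (0.4) but the reconstruction error is high (0.9)
for alpha in (0.5, 0.7, 0.9):
    score = anomaly_score(0.4, 0.9, recon_min=0.0, recon_max=1.0, alpha=alpha)
    print(f"alpha={alpha}: score={score:.2f}")
```

As alpha rises from 0.5 to 0.9, the combined score for this window falls from 0.65 to 0.45, which shows how a high alpha can suppress anomalies that only the reconstruction branch detects. In this sketch `recon_min` and `recon_max` are assumed to come from source validation data.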
Common Issues and Solutions
Domain adaptation training is known to be sensitive to configuration choices. The reference table below lists problems that practitioners frequently encounter and the corresponding remedies.
| Problem | Symptom | Cause | Solution |
|---|---|---|---|
| Discriminator mode collapse | Domain loss sits at ~0.69 (ln 2) from the first epochs while source and target features stay separated | Discriminator outputs 0.5 for everything | Increase discriminator LR; add more layers; reduce GRL lambda |
| Training instability | Loss oscillates wildly or diverges | Lambda too high too early | Use progressive lambda schedule; reduce learning rate; clip gradients more aggressively (lower max norm) |
| Negative transfer | Target AUROC decreases with DA | Domains are too different or share no useful structure | Reduce lambda_domain; try CORAL (less aggressive); verify domains share anomaly types |
| High false positive rate | Good recall but terrible precision | Threshold too low; recon error noisy | Increase alpha (trust classifier more); use percentile threshold; add recon error smoothing |
| Source AUROC drops during DA | Classification degrades on source | Domain-invariant features lose discriminative power | Increase lambda_cls; reduce lambda_domain; train classifier longer before starting DA |
| Out of memory (GPU) | CUDA OOM error | Batch size or model too large | Reduce batch_size; reduce latent_dim; use gradient accumulation |
| MMD loss is NaN | NaN in training | Kernel bandwidth mismatch with feature scale | Normalize features; adjust kernel_bandwidths in config; add epsilon to kernel computation |
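Several remedies in the table come down to the progressive lambda schedule. The sketch below uses the sigmoid ramp proposed by Ganin et al. (2016), scaled by a maximum weight. Whether the training script uses exactly this form is an assumption, and `progressive_lambda`, `lambda_max`, and `gamma` are illustrative names.

```python
import math


def progressive_lambda(epoch, total_epochs, lambda_max=1.0, gamma=10.0):
    """Sigmoid ramp from 0 toward lambda_max as training progresses."""
    p = min(1.0, max(0.0, epoch / total_epochs))  # training progress in [0, 1]
    return lambda_max * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)


schedule = [progressive_lambda(e, 100, lambda_max=0.5) for e in range(101)]
for epoch in (0, 10, 25, 50, 100):
    print(f"epoch {epoch:>3}: lambda = {schedule[epoch]:.4f}")
```

The weight starts at exactly zero, so the encoder first learns task-relevant features, and then approaches `lambda_max` without quite reaching it. A smaller `gamma` stretches the ramp over more epochs, which is one option when training oscillates early.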
Putting It Together
The preceding sections constitute a complete, end-to-end implementation of domain-adaptive time-series anomaly detection. A brief recapitulation and discussion of next steps follow.
The nine scripts cover the full pipeline: generating realistic synthetic data with domain shift, constructing a CNN-LSTM encoder with multi-head outputs, implementing three domain-adaptation strategies (DANN, MMD, and CORAL), training with progressive lambda scheduling, and evaluating with comprehensive metrics and diagnostic plots. Every script is complete and runnable as written.
The central insight is straightforward but consequential. Rather than requiring expensive labeled data in each new domain, a model can be trained to learn domain-invariant features: representations that capture the essence of “anomaly” regardless of which machine, factory, or sensor produced the signal. The Gradient Reversal Layer is the elegant mechanism that enables this adversarial training within a single unified model, while MMD and CORAL provide simpler and more stable alternatives.
Three directions are particularly promising for further development. First, semi-supervised adaptation: when even 5 to 10 percent of the target-domain data can be labeled, a supervised loss on those labeled target samples can be added alongside the unsupervised domain alignment, with substantial improvements in performance. Second, multi-source adaptation: when data are available from machines A, B, and C, adaptation to machine D can combine knowledge from all three sources rather than only one. Third, continual adaptation: in production, the target domain drifts over time as machines age and wear; periodic or online re-adaptation keeps the model current.
Domain adaptation is not a universal solution. It performs best when domains share the same underlying anomaly mechanisms but differ in superficial signal characteristics, which is the prevailing scenario in industrial settings. When it succeeds, it can save months of labeling effort and accelerate the deployment of anomaly detection to new equipment. The code provided in this guide contains everything needed to begin experimenting with proprietary data immediately.
References
- Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. (2016). “Domain-Adversarial Training of Neural Networks.” Journal of Machine Learning Research, 17(59), 1-35.
- Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., and Smola, A. J. (2012). “A Kernel Two-Sample Test.” Journal of Machine Learning Research, 13, 723-773.
- Sun, B. and Saenko, K. (2016). “Deep CORAL: Correlation Alignment for Deep Domain Adaptation.” Proceedings of the European Conference on Computer Vision (ECCV) Workshops.
- Ragab, M., Lu, Z., Chen, Z., Wu, M., Kwoh, C. K., and Li, X. (2023). “Time-Series Domain Adaptation: A Survey.” arXiv preprint.
- Chalapathy, R. and Chawla, S. (2019). “Deep Learning for Anomaly Detection: A Survey.” arXiv preprint.
- PyTorch Documentation. “Extending torch.autograd—Custom Function.”