Author: kongastral

  • What Is a Hook in AI? Lifecycle, PyTorch, and Webhook Patterns

    The term “hook” in the context of artificial intelligence will elicit different responses depending on the audience. The agent-framework engineer typically refers to a shell command that fires before Claude Code runs a tool. The deep-learning researcher has in mind a Python callback registered on a neural network layer to capture activations. The MLOps engineer envisions an HTTP POST that lands in Slack the moment a training run finishes. The same term covers three distinct mechanisms, three distinct audiences, and three distinct sets of debugging considerations.

    This overloading is not accidental: all three variants share the same underlying idea, namely a callback that fires at a defined point in another system’s execution. Treating them as interchangeable, however, is a frequent source of confusion. Advice to “use a hook” carries little practical value without specifying which variant is intended. The present guide therefore draws the boundaries explicitly and then accompanies each variant with working code.

    Summary

    What this post covers: The word “hook” in AI refers to at least three distinct mechanisms — agent lifecycle hooks (Claude Code and similar frameworks), model introspection hooks (PyTorch forward and backward callbacks), and MLOps event hooks (webhooks fired by training jobs and model registries). This post defines each, shows working code, and gives you a decision framework for picking the right one.

    Key insights:

    • Claude Code exposes 12 lifecycle events and a small number of handler types, with exit code 2 reserved as the “block this action” signal that feeds stderr back to Claude as an error message.
    • PyTorch hooks come in three core flavors — register_forward_pre_hook, register_forward_hook, and register_full_backward_hook — each with a fixed signature and a RemovableHandle you must call .remove() on to avoid leaks.
    • MLOps webhooks are just HTTP POSTs with HMAC signatures, but they amplify failures: a slow receiver can block a model registry, and a missing signature check lets anyone who can reach the endpoint forge events and trigger whatever the receiver automates.
    • The three flavors are not interchangeable — picking the wrong one (a PyTorch hook to enforce safety, a webhook for activation extraction) leads to brittle systems that fight their own runtime.
    • Hooks are powerful precisely because they don’t require modifying the host system, but the same property makes them invisible — discoverability and audit logging matter as much as the hook code itself.

    Main topics: three different things people mean by “hook” in AI; lifecycle hooks (the agent-lifecycle flavor); a working Claude Code hooks example; model introspection hooks (the PyTorch flavor); a working PyTorch hooks example; event hooks (the MLOps webhook flavor); when to use which kind of hook; common pitfalls.

    Three different things people mean by “hook” in AI

    Vocabulary first, then code. The three variants of “hook” in AI share the same skeletal definition—a user-supplied callback that fires at a defined point in another system’s execution—but they differ in every operationally important respect: where the callback runs, which process owns it, whether it can block the host, and what data it observes.
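    The shared skeleton can be sketched in a few lines of Python. The Host class and the event names before_work and after_work are hypothetical and belong to no real framework; the sketch only shows the common shape: a host fires user-supplied callbacks at fixed points without its own logic being edited.

```python
from collections import defaultdict
from typing import Any, Callable

Callback = Callable[[dict[str, Any]], None]


class Host:
    """Hypothetical host system that exposes two hook points around its work."""

    def __init__(self) -> None:
        self._hooks: dict[str, list[Callback]] = defaultdict(list)

    def register(self, event: str, callback: Callback) -> Callable[[], None]:
        """Attach a callback to an event; return a function that detaches it."""
        self._hooks[event].append(callback)
        return lambda: self._hooks[event].remove(callback)

    def _fire(self, event: str, payload: dict[str, Any]) -> None:
        # Iterate over a copy so a callback may detach itself safely.
        for callback in list(self._hooks[event]):
            callback(payload)

    def run(self, value: int) -> int:
        self._fire("before_work", {"value": value})
        result = value * 2  # the host's own logic, untouched by hook authors
        self._fire("after_work", {"value": value, "result": result})
        return result


seen: list[str] = []
host = Host()
remove = host.register("after_work", lambda p: seen.append(f"result={p['result']}"))

host.run(21)   # the hook fires once
remove()       # detach the hook
host.run(1)    # nothing fires now
print(seen)    # ['result=42']
```

    The three variants discussed in this guide differ mainly in what stands in for register and _fire: a settings file and a subprocess, a method on nn.Module, or a URL and an HTTP POST.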

    A lifecycle hook fires at a specific moment in an agent’s session loop. The canonical example is Claude Code’s PreToolUse event, which fires after the model has decided to invoke a tool but before the tool actually executes. The hook is a separate process—a shell command, an HTTP endpoint, or an MCP server—that the agent invokes with structured JSON describing the intended action. The hook may approve, modify, or block the action through its exit code or response. Lifecycle hooks exist because agent runtimes require extensibility points that do not necessitate forking the agent itself.

    A model introspection hook is an in-process Python callback registered on a neural network module. PyTorch’s register_forward_hook is the canonical case: a function is supplied, and PyTorch calls that function every time the module’s forward() runs, passing the module, its input, and its output. The hook lives in the same process as the model, runs synchronously within the autograd graph (the system that tracks tensor operations for gradient computation), and may read or even modify tensors on the fly. Such hooks exist because researchers need to inspect a model without rewriting its source code.

    An event hook, usually called a webhook in MLOps contexts, is an HTTP POST issued by one service to another when a defined event occurs—a training run completes, a model is promoted to production, or a drift detector exceeds a threshold. The hook receiver lives in an entirely different process (often on a different host or behind a load balancer), authenticates via a shared secret with HMAC (a cryptographic signature method that proves the message was not tampered with), and runs asynchronously with respect to the event source. Webhooks exist because MLOps stacks are heterogeneous and require a low-friction mechanism for distributing events across systems.

    Three observations render this taxonomy useful rather than pedantic. First, the audiences scarcely overlap: the researcher confronting a vanishing gradient and the platform engineer integrating a model registry both rely on “hooks,” but their tooling, vocabulary, and failure modes have little in common. Second, the level of trust required differs sharply: a PyTorch hook runs inside the process and is implicitly trusted; a Claude Code hook executes shell commands and is trusted but auditable; a webhook crosses a network boundary and must therefore authenticate. Third, the cost of misclassification scales accordingly: an errant PyTorch hook leaks memory, an errant Claude Code hook may erase a file, and an errant webhook handler may broadcast secrets. Selecting the right variant is not merely a stylistic choice; it defines the security boundary of the entire feature.

    The figure below summarises the taxonomy:

    Figure: Three meanings of “hook” in AI. Same skeletal idea (callback at a defined point), three operational realities.

    • Lifecycle hook (Claude Code, agent frameworks). Fires at: agent session events. Runs in: separate process (shell/HTTP). Can block? Yes (exit code 2). Typical user: agent builders, safety teams. Example: block rm -rf, auto-format after Edit.
    • Introspection hook (PyTorch, TensorFlow, JAX). Fires at: forward / backward pass. Runs in: same process, sync, in graph. Can block? No, but can modify tensors. Typical user: researchers, model debuggers. Example: capture activations, log gradient norms.
    • Event hook / webhook (MLflow, W&B, model registry). Fires at: infra/business events. Runs in: remote service, async, HTTP. Can block? Indirectly (timeout, retries). Typical user: MLOps, platform teams. Example: Slack alert on training failure.

    All three are “callbacks at a defined point”, but they share nothing else. Pick by problem type, not by name.

    Key Takeaway: Readers interested in only one variant may proceed directly to the relevant section. Agent builders should consult the sections on lifecycle hooks and the Claude Code example. Deep-learning practitioners should refer to the PyTorch sections. MLOps engineers should focus on the webhook section. The decision-framework section at the end is intended for all readers.

    Lifecycle hooks: the agent-lifecycle flavor

    Lifecycle hooks are the most recent of the three variants to enter the AI lexicon, largely because agent frameworks themselves are recent. The mechanism is straightforward: an agent runtime defines a small set of events that mark notable moments in its operation, and handlers are registered to fire when those events occur.

    Claude Code, the CLI agent developed by Anthropic, exposes twelve such events in its current hooks system (per the official documentation at code.claude.com/docs/en/hooks, as of 2026-05-25). The events span the full session arc, from SessionStart when the agent boots, through UserPromptSubmit when the user submits input, to PreToolUse and PostToolUse that wrap every tool call, and finally to Stop and SessionEnd. Each event passes structured JSON to the handler describing the current operation, and the handler may respond with text (returned to Claude as additional context), a block decision, or simply an exit code.

    The significance of this mechanism is as follows: without hooks, customising an agent’s behaviour requires either writing a custom tool (a heavy approach) or relying on a CLAUDE.md instruction (an unreliable one). Hooks provide a third option—deterministic, code-enforced policy that fires regardless of the model’s decisions. If a hook returns exit code 2 on a PreToolUse for any Bash call matching /rm -rf \//, the tool will not run. The model is not merely asked not to run it; the tool will not run. This distinction constitutes the entire value proposition.

    Figure: Claude Code session lifecycle and hook insertion points. Each event lets you register a handler that fires at that exact moment.

    Main flow: SessionStart (agent boots) → UserPromptSubmit (you hit enter) → PreToolUse (can block) → tool runs → PostToolUse (log, format, scan) → Notification (tool permissions, etc.) → Stop (can block) → SessionEnd (cleanup).

    Other events in the 12:

    • SubagentStop: fires when a spawned sub-agent finishes.
    • PreCompact: fires before context is compacted (your chance to save state).
    • PreRespond: fires before Claude streams its reply (modify or annotate output).
    • Plus additional events for slash commands, file edits, and session restoration.

    Check the official docs for the current authoritative list.

    The twelve events may be categorised by responsibility as follows:

    Event | When it fires | Can block? | Typical use case
    SessionStart | Agent boots up | No | Inject project context, set env vars
    UserPromptSubmit | After you hit enter | Yes | Validate prompt, expand templates
    PreToolUse | Before any tool runs | Yes | Safety check, dry-run preview
    PostToolUse | After tool returns | No | Auto-format, log, scan output
    Notification | Permission prompts, etc. | No | Forward to phone, log audit trail
    Stop | Claude finishes its turn | Yes | Force continuation, run tests
    SubagentStop | A sub-agent finishes | Yes | Collect sub-agent artifacts
    SessionEnd | Session terminates | No | Final cleanup, session summary
    PreCompact | Before context compaction | No | Persist scratchpad to disk
    PreRespond | Before reply streams | Yes | Redact, annotate, classify
    Edit/file events | On file modifications | No | Format, lint, version control
    Slash command events | On /command invocation | Varies | Custom command preprocessing

     

    The names matter because matching is partially name-based. A hook configuration in .claude/settings.json specifies an event name and an optional matcher (a regular expression tested against the tool name for tool-related events), followed by a list of handlers. The handler contains the code that executes.

    Handler Types and Where They Run

    Claude Code’s hooks system currently supports four handler types per the official documentation (as of 2026-05-25; readers should consult the latest reference for the authoritative list, as this area continues to evolve). The three most commonly encountered are described below:

    Figure: Claude Code hook handler types. Input flow (event JSON) → handler → output flow (exit code plus stdout/stderr, or an HTTP response).

    • Command: a shell command on local disk. Input: JSON on stdin. Output: stdout/stderr plus exit code. Exit semantics: 0 = ok; 2 = block (stderr → Claude); any other non-zero = warn. Pros: simple, no server needed. Cons: shell-injection risk if naive; cold-start cost per call.
    • HTTP: a POST to a web endpoint. Input: JSON in the request body. Output: JSON in the HTTP response. Response semantics: 200 + {action: "allow"}; 200 + {action: "block"}; 5xx or timeout = error. Pros: central policy, multi-user. Cons: network latency in the hot path; availability dependency.
    • MCP: a Model Context Protocol server. Input: MCP request message. Output: MCP response message. Semantics: structured tool-like reply; streaming supported; capability-negotiated. Pros: reuses MCP tooling and infrastructure. Cons: more setup than Command; harder to debug ad hoc.

    The handler type should be chosen on the basis of the desired operational profile rather than syntactic preference:

    Handler type | Best for | Security posture | When to pick
    Command | Local, per-developer policies | Runs as the local user; care required with untrusted arguments | Default for solo or single-machine use
    HTTP | Team-wide central policy | Use TLS and auth header; isolate the receiver | When a single policy must be enforced across many developers
    MCP | Integration with existing MCP servers | Inherits the MCP server security model | When MCP infrastructure is already in operation and consistency is required

     

    Readers new to MCP may find the Model Context Protocol primer a useful companion. Hooks and MCP servers represent two of the principal extensibility surfaces in modern agent runtimes, and they frequently operate in concert.

    Exit Code Semantics for Command Handlers

    The Command handler’s contract is small but precise. According to the official hooks documentation (as of 2026-05-25):

    • Exit 0: success. Stdout is captured but treated as informational, and Claude proceeds normally.
    • Exit 2: blocking error. Stderr is returned to Claude as an error message. For PreToolUse this blocks the tool call entirely; for Stop it forces continuation. This is the appropriate code for deterministic prevention.
    • Other non-zero: warning. The event is logged but not blocked, which is useful for soft policy (“not recommended, but permitted”).

    Figure: PreToolUse hook flow (how Claude Code decides whether your hook lets a tool run).

    • Claude decides to call a tool, e.g. Bash: "rm -rf /tmp/x".
    • Matcher checked: PreToolUse + matcher="Bash" → handler is selected.
    • Handler invoked with the JSON event payload on stdin: {"tool_name":"Bash","tool_input":{"command":"…"}}.
    • Handler runs, reads the JSON, and decides on an exit code.
    • exit 0: tool runs as planned; stdout → context (optional addendum).
    • exit 1 (or other non-zero): tool runs anyway; warning logged; stderr captured.
    • exit 2: tool is BLOCKED; stderr → Claude as error message.

    Caution: Exit 1 should not be conflated with exit 2. Many shell scripts exit with code 1 on any error condition. When the intent is to block, the script must specifically exit with code 2. A hook that uses set -e and then crashes will exit non-zero but probably not with code 2, so the tool will run anyway and only a warning will be logged. Blocking paths should be tested explicitly.
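    The exit-code contract can be made explicit by writing the decision as a function that returns the code, which also makes the blocking path testable without a live agent. The sketch below is a Python equivalent of a Command handler under stated assumptions: the tool_input.command field path mirrors the one used by this guide's shell example, while the BLOCKED list and the main and run helpers are hypothetical names introduced for illustration.

```python
import io
import json
import re
from typing import TextIO

# Illustrative deny-list only; a real policy would be broader and tested.
BLOCKED = [re.compile(r"rm\s+-rf?\s+/(\s|$)")]


def main(stdin: TextIO, stderr: TextIO) -> int:
    """Return the exit code the hook process should finish with."""
    try:
        event = json.load(stdin)
    except json.JSONDecodeError:
        print("hook: unreadable payload", file=stderr)
        return 1  # non-blocking failure: a warning is logged, the tool still runs
    command = event.get("tool_input", {}).get("command", "")
    for pattern in BLOCKED:
        if pattern.search(command):
            print(f"Blocked by policy: {pattern.pattern}", file=stderr)
            return 2  # the only exit code that blocks the tool call
    return 0


def run(payload: str) -> int:
    """Simulate one hook invocation without touching the real stdin."""
    return main(io.StringIO(payload), io.StringIO())


print(run('{"tool_input": {"command": "ls -la"}}'))    # 0
print(run('{"tool_input": {"command": "rm -rf /"}}'))  # 2
print(run("not json"))                                  # 1

# In an actual hook script the last line would be:
#   sys.exit(main(sys.stdin, sys.stderr))
```

    Returning 1 on a malformed payload makes the hook fail open. A stricter policy could return 2 there instead, at the cost of blocking every tool call whenever the payload format changes.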

    A Working Claude Code Hooks Example

    Concrete code follows. The .claude/settings.json file below configures three hooks: a PreToolUse safety check, a PostToolUse auto-formatter, and a SessionStart context injector.

    {
      "hooks": {
        "PreToolUse": [
          {
            "matcher": "Bash",
            "hooks": [
              {
                "type": "command",
                "command": ".claude/hooks/safety-check.sh"
              }
            ]
          }
        ],
        "PostToolUse": [
          {
            "matcher": "Edit|Write",
            "hooks": [
              {
                "type": "command",
                "command": ".claude/hooks/auto-format.sh"
              },
              {
                "type": "http",
                "url": "https://hooks.internal.example.com/claude-edit",
                "headers": {
                  "Authorization": "Bearer ${CLAUDE_HOOK_TOKEN}"
                }
              }
            ]
          }
        ],
        "SessionStart": [
          {
            "hooks": [
              {
                "type": "command",
                "command": ".claude/hooks/inject-context.sh"
              }
            ]
          }
        ]
      }
    }

    Note the two-handler array on PostToolUse: hooks compose. Both execute, and their outputs are aggregated. The matcher is a regular expression matched against the tool name; Edit|Write means the hook fires for either tool.
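    How a matcher narrows an event down to a set of handlers can be sketched as follows. This is a simplified model, not Claude Code's implementation: the RULES structure and the select_handlers function are hypothetical, and the use of a full match (rather than a substring search) is an assumption that should be checked against the official reference.

```python
import re
from typing import Optional

# Hypothetical, simplified view of a hooks configuration:
# event name -> list of (matcher or None, handler names).
RULES: dict[str, list[tuple[Optional[str], list[str]]]] = {
    "PostToolUse": [
        ("Edit|Write", ["auto-format.sh", "audit-endpoint"]),
        ("Bash", ["bash-log.sh"]),
    ],
    "SessionStart": [(None, ["inject-context.sh"])],
}


def select_handlers(event: str, tool_name: Optional[str] = None) -> list[str]:
    """Collect every handler whose matcher accepts the tool name."""
    selected: list[str] = []
    for matcher, handlers in RULES.get(event, []):
        # No matcher means the entry applies to every occurrence of the event.
        if matcher is None:
            selected.extend(handlers)
        # Assumption: the pattern must match the whole tool name.
        elif tool_name is not None and re.fullmatch(matcher, tool_name):
            selected.extend(handlers)
    return selected


print(select_handlers("PostToolUse", "Write"))  # ['auto-format.sh', 'audit-endpoint']
print(select_handlers("PostToolUse", "Read"))   # []
print(select_handlers("SessionStart"))          # ['inject-context.sh']
```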

    PreToolUse Safety Hook in Bash

    The shell script below blocks dangerous rm patterns and writes an audit log of every Bash invocation. It reads the event JSON from stdin (using jq for parsing) and exits with code 2 and an explanatory stderr message when a risky pattern is observed.

    #!/usr/bin/env bash
    # .claude/hooks/safety-check.sh
    # Blocks dangerous rm patterns; audits all Bash invocations.
    set -uo pipefail
    
    PAYLOAD=$(cat)
    CMD=$(echo "$PAYLOAD" | jq -r '.tool_input.command // empty')
    
    # Audit log first — we want every attempt recorded.
    mkdir -p .claude/audit
    echo "$(date -u +%FT%TZ)  $CMD" >> .claude/audit/bash.log
    
    # Block obvious destructive patterns.
    DANGEROUS_PATTERNS=(
      'rm[[:space:]]+-rf?[[:space:]]+/($|[[:space:]])'
      'rm[[:space:]]+-rf?[[:space:]]+/\*'
      'rm[[:space:]]+-rf?[[:space:]]+~'
      ':\(\)\{[[:space:]]*:\|:&[[:space:]]*\};:'  # fork bomb
      'mkfs\.'
      'dd[[:space:]]+if=/dev/(zero|random|urandom)[[:space:]]+of=/dev/sd'
    )
    
    for pat in "${DANGEROUS_PATTERNS[@]}"; do
      if [[ "$CMD" =~ $pat ]]; then
        echo "Blocked: command matches dangerous pattern '$pat'" >&2
        echo "If you really need to run this, do it manually outside Claude." >&2
        exit 2
      fi
    done
    
    # Also block writes to anything under /etc or /usr without sudo prompting.
    if [[ "$CMD" =~ (^|[[:space:]])(rm|mv|cp|tee|>)[[:space:]].*(/etc/|/usr/) ]]; then
      echo "Blocked: write to system path detected." >&2
      exit 2
    fi
    
    exit 0
    

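    The pattern list is intentionally short, because long pattern lists provide a false sense of security. The more dependable defence is the audit log: even when a command is not blocked, a record of Claude’s attempted Bash commands remains available. A plain local file is not tamper-evident by itself, so the log should be shipped to append-only storage if that property is required.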

    PostToolUse Auto-Formatter

    #!/usr/bin/env bash
    # .claude/hooks/auto-format.sh
    # Runs Ruff / Prettier / gofmt on any file Claude just edited.
    set -euo pipefail
    
    PAYLOAD=$(cat)
    FILE=$(echo "$PAYLOAD" | jq -r '.tool_input.file_path // .tool_input.path // empty')
    
    if [[ -z "$FILE" ]] || [[ ! -f "$FILE" ]]; then
      exit 0
    fi
    
    case "$FILE" in
      *.py)        ruff format "$FILE" 2>/dev/null || true ;;
      *.ts|*.tsx)  npx prettier --write "$FILE" 2>/dev/null || true ;;
      *.js|*.jsx)  npx prettier --write "$FILE" 2>/dev/null || true ;;
      *.json)      npx prettier --write "$FILE" 2>/dev/null || true ;;
      *.go)        gofmt -w "$FILE" 2>/dev/null || true ;;
    esac
    
    # PostToolUse is not blocking — exit 0 even on format failure.
    exit 0
    

    Note the || true: a missing formatter should not cause the hook to fail. On PostToolUse, exit code 2 cannot undo anything (the tool has already run); at most it feeds stderr back to Claude, and exit code 1 still produces noise in the agent’s view.

    HTTP PostToolUse Hook (FastAPI Receiver)

    For team-wide policy or central observability, an HTTP hook is preferable to a per-machine command. A minimal FastAPI receiver is shown below:

    """Webhook receiver for Claude Code PostToolUse events.
    
    Run: uvicorn receiver:app --host 0.0.0.0 --port 8080
    """
    import hashlib
    import hmac
    import json
    import logging
    import os
    from datetime import datetime, timezone
    
    from fastapi import FastAPI, Header, HTTPException, Request
    
    app = FastAPI()
    log = logging.getLogger("claude_hook")
    logging.basicConfig(level=logging.INFO)
    
    SECRET = os.environ["CLAUDE_HOOK_SECRET"].encode("utf-8")
    
    
    def verify_signature(body: bytes, signature: str) -> bool:
        """HMAC-SHA256 signature check — prevents spoofed events."""
        expected = hmac.new(SECRET, body, hashlib.sha256).hexdigest()
        return hmac.compare_digest(expected, signature or "")
    
    
    @app.post("/claude-edit")
    async def claude_edit(
        request: Request,
        authorization: str | None = Header(default=None),
        x_signature: str | None = Header(default=None),
    ):
        body = await request.body()
    
        if not verify_signature(body, x_signature or ""):
            raise HTTPException(status_code=401, detail="bad signature")
    
        event = json.loads(body)
        log.info(
            "edit by %s on %s at %s",
            event.get("session_id", "?"),
            event.get("tool_input", {}).get("file_path", "?"),
            datetime.now(timezone.utc).isoformat(),
        )
    
        # Return JSON the agent can use. An empty body is fine for fire-and-forget.
        return {"status": "logged"}
    

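    The signature check is important. Without it, any party able to reach the endpoint can fabricate “Claude edited /etc/passwd” events. The receiver expects the hex-encoded HMAC-SHA256 of the raw request body in an X-Signature header, keyed with the shared secret held in CLAUDE_HOOK_SECRET on both sides. Note that the settings.json example above sends only a bearer token, so the sender (or a proxy in front of it) must also attach that signature; otherwise the receiver rejects every request with 401.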
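    The sender's half of the signature scheme can be sketched as follows, which is also a convenient way to test the receiver by hand. The X-Signature header name matches the x_signature parameter of the receiver above; the sign_body helper, the example secret, and the payload values are hypothetical.

```python
import hashlib
import hmac
import json


def sign_body(secret: bytes, body: bytes) -> str:
    """Hex-encoded HMAC-SHA256 of the exact bytes that will be sent."""
    return hmac.new(secret, body, hashlib.sha256).hexdigest()


secret = b"example-secret-not-for-production"
body = json.dumps(
    {"session_id": "s-1", "tool_input": {"file_path": "app.py"}}
).encode("utf-8")

signature = sign_body(secret, body)
headers = {"Content-Type": "application/json", "X-Signature": signature}

# The receiver recomputes the digest over the raw body and compares in constant time.
print(hmac.compare_digest(signature, sign_body(secret, body)))      # True

# Any change to the body invalidates the signature.
tampered = body.replace(b"app.py", b"/etc/passwd")
print(hmac.compare_digest(signature, sign_body(secret, tampered)))  # False
```

    The signature must be computed over the bytes actually transmitted. Re-serialising the JSON on either side (different key order or whitespace) changes the digest, which is why the receiver reads request.body() before parsing.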

    SessionStart Context Injector

    #!/usr/bin/env bash
    # .claude/hooks/inject-context.sh
    # Adds current git status, branch, and any TODO.md to Claude's session context.
    set -euo pipefail
    
    cat <<EOF
    Session starting at $(date -u +%FT%TZ).
    Current branch: $(git branch --show-current 2>/dev/null || echo 'not a git repo')
    Modified files:
    $(git status --short 2>/dev/null || echo 'none')
    
    TODOs in repo:
    $(test -f TODO.md && head -20 TODO.md || echo 'no TODO.md')
    EOF
    
    exit 0
    

    Whatever the hook prints on stdout becomes part of the session’s context: the model receives it before the first user prompt. This event is arguably underused, since it provides Claude with project-specific situational awareness without enlarging CLAUDE.md.

    For further information on customising Claude Code’s behaviour beyond hooks, refer to the custom commands guide and the skills primer. Hooks fire automatically, whereas commands and skills are user-invoked. Together, these three mechanisms cover most extension scenarios.

    Model Introspection Hooks: the PyTorch Variant

    The context now changes. Setting agents aside, consider a Python process holding a PyTorch nn.Module in which the behaviour of tensors flowing through the module must be observed. Typical use cases include capturing activations for a probing experiment, logging gradient magnitudes to debug a training run, and clipping gradients per layer for an ablation study.

    PyTorch’s nn.Module class exposes a small set of hook registration methods that address these requirements without modifying the module’s forward code. The three most commonly used methods are described below:

    API | Signature | Fires when | Typical use case
    register_forward_pre_hook | hook(module, input) | Before module.forward() runs | Modify or inspect inputs
    register_forward_hook | hook(module, input, output) | After module.forward() returns | Capture activations, inspect outputs
    register_full_backward_hook | hook(module, grad_input, grad_output) | After gradients computed for module | Log/clip gradients, debug training

     

    All three methods return a RemovableHandle. This handle should be retained, and handle.remove() should be called when the hook is no longer required. Otherwise the hook keeps firing on every forward (or backward) pass for as long as the module lives. In a long-running training job, this is a performance cost and, if the hook stores tensors, a memory leak.
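    The firing order of the three hook types, and the effect of removing their handles, can be checked on a single layer. The sketch below assumes a recent PyTorch release; the events list and the tiny nn.Linear layer are illustrative choices.

```python
import torch
import torch.nn as nn

events: list[str] = []
layer = nn.Linear(4, 2)

# list.append returns None, so none of these hooks alters inputs, outputs, or gradients.
handles = [
    layer.register_forward_pre_hook(lambda m, inp: events.append("forward_pre")),
    layer.register_forward_hook(lambda m, inp, out: events.append("forward")),
    layer.register_full_backward_hook(lambda m, gin, gout: events.append("backward")),
]

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
print(events)  # ['forward_pre', 'forward', 'backward']

# After removal the hooks no longer fire.
for handle in handles:
    handle.remove()
events.clear()
layer(x).sum().backward()
print(events)  # []
```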

    Figure: PyTorch forward hook (synchronous, in-process, sees the (module, input, output) tuple). Flow: input tensor x of shape (B, C, H, W) from upstream → Module.forward(x), e.g. nn.Conv2d, nn.Linear, a ResNet block, or a transformer layer, computes y = f(x) → forward hook fires as hook(module, input, output), where your code can read tensors, save copies, log statistics, or modify the output → output y continues downstream.

    Key properties:

    • The hook runs synchronously inside the autograd graph (gradient-tracking system), so the overhead is real.
    • Returning a non-None value from the hook replaces the output (advanced use, easy to break things).
    • Detach tensors before storing (output.detach().clone()) to avoid blowing up memory with the graph.

    The backward hook operates similarly but in the reverse direction. After loss.backward() propagates gradients back through the graph, the backward hook fires for each module that has one registered, receiving the gradients flowing into and out of that module:

    Figure: PyTorch backward hook (fires during loss.backward(), in reverse order through the graph). Flow: loss.backward() starts at the scalar loss and walks the graph in reverse → the module, during backprop, computes ∂L/∂x from ∂L/∂y (grad_input from grad_output, via the chain rule) → backward hook fires as hook(module, grad_input, grad_output), where your code can read gradient norms, detect exploding or vanishing gradients, or return a clipped grad_input.

    Differences from the forward hook:

    • grad_input and grad_output are TUPLES (one entry per tensor argument), so index carefully.
    • Use register_full_backward_hook, not the deprecated register_backward_hook (unreliable for modules that perform several operations).
    • The hook must not modify its arguments in place; returning a new grad_input tuple replaces what flows further upstream.

    The distinction between register_backward_hook (deprecated) and register_full_backward_hook (current) is a small but consequential point that wastes considerable time when overlooked. For modules whose forward pass performs more than one operation, the deprecated version could report gradients for only a subset of the module’s inputs and outputs, so grad_input did not reliably describe the module as a whole. The full_ variant should always be preferred.

    For readers approaching this material from outside deep learning, brief definitions are provided. The forward pass is the computation that transforms inputs into outputs—for example, running an image through ResNet to obtain class scores. The backward pass is the reverse computation that determines the contribution of each parameter to the loss, using the chain rule of calculus. Autograd is PyTorch’s gradient-tracking machinery, which records every operation performed on a tensor so that those operations can be replayed in reverse when loss.backward() is called. A gradient is the vector of partial derivatives of the loss with respect to each parameter; it is the signal that informs the optimiser of the direction in which to adjust each weight. Hooks permit observation and modification of any of these quantities at module boundaries without altering the module’s source code.

    A Working PyTorch Hooks Example

    Three concrete tasks are demonstrated below: capturing activations from a ResNet block, logging gradient norms per layer to detect training instability, and clipping gradients per layer to study the effect on a small training run.

    Activation Extraction for Probing or Visualisation

    Consider a scenario in which a pretrained ResNet-50 is available and the feature map following layer4 for an input image is required—perhaps to feed into a linear probe, perhaps to visualise the network’s response. Modifying the ResNet source code is undesirable, and a forward hook is the appropriate tool.

    Figure: Activation extraction with a forward hook (capture intermediate features from a frozen pretrained model).

    • Step 1: register the hook. captures = {}; handle = model.layer4.register_forward_hook(lambda m, i, o: captures.update(layer4=o.detach()))
    • Forward pass: image input (1, 3, 224, 224), PIL → tensor → ResNet-50: conv1 → bn1 → relu → layer1 → layer2 → layer3 → layer4 (hook attached here) → final logits (1, 1000), which we discard.
    • Step 2: after the forward pass, captures["layer4"] holds the activation. Shape: (1, 2048, 7, 7) for a 224×224 input, a 2048-channel feature map. It is detached from the autograd graph (we used .detach() to avoid keeping forward state alive) and is now usable for a linear probe, CAM visualization, feature similarity search, or dataset embedding.
    • Step 3: call handle.remove() when done. Forget this and you leak the hook.

    import torch
    import torchvision.models as models
    from torchvision import transforms
    from PIL import Image
    
    # Pretrained ResNet-50, eval mode.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    model.eval()
    
    # Where we will stash the activation.
    captures: dict[str, torch.Tensor] = {}
    
    def grab_layer4(module: torch.nn.Module,
                    inp: tuple[torch.Tensor, ...],
                    out: torch.Tensor) -> None:
        """Forward hook — copy the output, detach, store."""
        captures["layer4"] = out.detach().clone()
    
    # Register on the layer4 stack (a Sequential of three Bottleneck blocks).
    handle = model.layer4.register_forward_hook(grab_layer4)
    
    try:
        # Standard ImageNet preprocessing.
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        img = Image.open("dog.jpg").convert("RGB")
        x = preprocess(img).unsqueeze(0)
    
        with torch.no_grad():
            _ = model(x)   # we discard logits; we want the captured activation
    
        act = captures["layer4"]
        print(f"layer4 activation shape: {tuple(act.shape)}")
        # → layer4 activation shape: (1, 2048, 7, 7)
    
        # Now use `act` for whatever downstream analysis you want.
    finally:
        # ALWAYS remove the hook when done.
        handle.remove()
    
    Tip: The try/finally pattern is important. If downstream code raises an exception, a dangling hook will quietly increase memory pressure on the next inference. Registrations should be wrapped in a context manager if this pattern is used frequently.
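    The context-manager form of the same pattern might look like the sketch below. The capture_output helper is a hypothetical name, and the small nn.Sequential model stands in for any module. The final line reads the private _forward_hooks attribute purely to show that the hook is gone; production code should not depend on it.

```python
from contextlib import contextmanager
from typing import Iterator

import torch
import torch.nn as nn


@contextmanager
def capture_output(module: nn.Module) -> Iterator[dict[str, torch.Tensor]]:
    """Register a forward hook for the duration of the with-block only."""
    store: dict[str, torch.Tensor] = {}

    def hook(mod: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        # Detach and copy so the stored tensor does not keep the graph alive.
        store["output"] = output.detach().clone()

    handle = module.register_forward_hook(hook)
    try:
        yield store
    finally:
        handle.remove()  # runs even if the body raises


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

with capture_output(model[0]) as captured:
    with torch.no_grad():
        model(torch.randn(2, 8))

print(tuple(captured["output"].shape))  # (2, 16)
print(len(model[0]._forward_hooks))     # 0
```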

    Logging Gradient Norms with a Backward Hook

    Gradient explosions are easier to diagnose when norms can be observed per layer. A few lines of backward hook code reduce this to a single-line printout per step:

    import torch
    import torch.nn as nn
    
    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(128, 256)
            self.fc2 = nn.Linear(256, 256)
            self.fc3 = nn.Linear(256, 10)
    
        def forward(self, x):
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            return self.fc3(x)
    
    model = SmallNet()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    
    # Track grad norms by layer name.
    grad_norms: dict[str, float] = {}
    handles = []
    
    def make_hook(name: str):
        def hook(module, grad_input, grad_output):
            # grad_output is a tuple of grads w.r.t. each output tensor.
            # We log the L2 norm of the first one as a simple health metric.
            if grad_output[0] is not None:
                grad_norms[name] = grad_output[0].norm().item()
        return hook
    
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_full_backward_hook(make_hook(name)))
    
    # Fake training step.
    x = torch.randn(32, 128)
    y = torch.randint(0, 10, (32,))
    
    for step in range(3):
        optimizer.zero_grad()
        logits = model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        print(f"step {step}: " + ", ".join(f"{k}={v:.4f}" for k, v in grad_norms.items()))
    
    # Cleanup.
    for h in handles:
        h.remove()
    

    A typical output line takes the form step 0: fc3=0.1382, fc2=0.0573, fc1=0.0421 (the values are illustrative, and the layers appear in the order their hooks fire during backpropagation, last layer first). If norms expand by orders of magnitude between steps, or fall to zero for a layer that should be learning, the source of the problem is readily identifiable. This pattern is also common in transformer training: instrumenting attention and MLP blocks separately follows the same approach, simply across more modules. For further discussion of training-stack instrumentation, see the LLM training guide.
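    The per-layer clipping task named at the start of this section can be sketched with the same API. A full backward hook must not modify its arguments in place, so the sketch returns a new grad_input tuple instead. The make_clip_hook helper and the max_norm value are hypothetical, and the approach is one possible reading of "clipping per layer": it rescales the gradient that flows upstream from the module and leaves the module's own weight gradients untouched.

```python
import torch
import torch.nn as nn


def make_clip_hook(max_norm: float):
    """Build a backward hook that rescales grad_input to at most max_norm (L2)."""
    def hook(module, grad_input, grad_output):
        if all(g is None for g in grad_input):
            return None  # nothing upstream requires a gradient
        clipped = []
        for g in grad_input:
            if g is None:
                clipped.append(None)
                continue
            # Scale factor is 1.0 when the norm is already within bounds.
            scale = torch.clamp(max_norm / (g.norm() + 1e-12), max=1.0)
            clipped.append(g * scale)
        return tuple(clipped)  # replaces grad_input for upstream modules
    return hook


torch.manual_seed(0)
layer = nn.Linear(4, 3)
handle = layer.register_full_backward_hook(make_clip_hook(max_norm=1e-3))

x = torch.randn(8, 4, requires_grad=True)
loss = (layer(x) ** 2).sum() * 100.0  # deliberately large gradients
loss.backward()

print(x.grad.norm().item() <= 1e-3 + 1e-6)  # True
handle.remove()
```

    To bound the parameter gradients themselves, torch.nn.utils.clip_grad_norm_ applied to a layer's parameters after loss.backward() is the more conventional tool.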

    Event Hooks: the MLOps Webhook Variant

    The third variant operates at an entirely different level. Webhooks are not located within an agent or a model; they connect services. When a training job finishes, that fact must reach a dashboard, a notification service, a downstream pipeline, and a model registry. Webhooks are the mechanism by which this distribution occurs without each service polling the others.

    The pattern is consistent across MLflow, Weights & Biases, HuggingFace, AWS SageMaker, and most model registries: when a defined event occurs, the source service sends an HTTP POST to a configured URL, with a JSON body describing the event and an HMAC signature in a header. The receiver verifies the signature, processes the event, and returns a 2xx status code (or signals failure and waits for a retry).

    [Figure: MLOps webhook flow. A training-job event (epoch 50 complete, val_acc=0.912, run_id=abc123) causes MLflow or W&B to fire the registered webhook: an HTTP POST with a JSON body and an HMAC in the X-Signature header. The receiver (FastAPI, Cloud Run, Lambda, or anything that speaks HTTP) verifies the HMAC, returns 200, and fans out to Slack (human notification, e.g. "run abc123 hit 0.912"), PagerDuty (paged escalation, failure events only), and an internal dashboard (append to time series, trigger eval pipeline).]

    Readers familiar with GitHub webhooks will recognise the design of MLOps webhooks: it is essentially the same. The header names vary by service (MLflow uses one name, Weights & Biases another, and so on), but the structure is invariant.

    Common events that vendors expose as webhooks include run.started, run.finished, and run.failed from training trackers; model.version.created, model.version.staged, and model.version.promoted from model registries; dataset.uploaded or dataset.versioned from data platforms; drift.detected or alert.fired from monitoring systems; and increasingly evaluation.completed from automated evaluation services. Each event is accompanied by a stable JSON schema, fixed per major version, and a payload-signing scheme that almost invariably follows the GitHub pattern: a SHA-256 HMAC of the raw body, hex-encoded, in a single header.
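    The signing scheme described above can be sketched from the sender's side. This is a minimal illustration using only the Python standard library; the secret, the event body, and the sha256= prefix are assumptions for the example, and the exact header name and prefix should be taken from the specific vendor's documentation.

```python
import hashlib
import hmac
import json


def sign_body(secret: bytes, raw_body: bytes) -> str:
    """Return 'sha256=<hex digest>' computed over the exact bytes sent."""
    digest = hmac.new(secret, raw_body, hashlib.sha256).hexdigest()
    return "sha256=" + digest


secret = b"example-shared-secret"  # hypothetical; load from a secret store in practice
event = {"event": "run.finished", "run": {"run_id": "abc123"}}

# Serialise once and sign those exact bytes. Re-serialising on the receiver
# may reorder keys or change whitespace, which would change the digest.
raw_body = json.dumps(event, separators=(",", ":")).encode("utf-8")
header_value = sign_body(secret, raw_body)

# The receiver recomputes the digest over the raw bytes it received.
assert hmac.compare_digest(header_value, sign_body(secret, raw_body))
# Any change to the body, even one byte of whitespace, changes the signature.
assert header_value != sign_body(secret, raw_body + b" ")
print(header_value[:7], len(header_value))
```

    The point of signing the raw body is that the receiver must verify before parsing: the digest is defined over bytes, not over the decoded JSON object.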

    One small but consequential decision concerns the location of the receiver. A long-running FastAPI application on a virtual machine places operational responsibility on the team when it fails outside business hours. A serverless function (Lambda, Cloud Run, Vercel) delegates availability to the platform and is charged per call, which is generally cheaper for low-volume webhook traffic. Most MLOps teams adopt serverless solutions for fan-out webhooks and reserve dedicated services for high-throughput hot paths such as real-time inference logging. The pattern is identical in either case; what differs is the operational profile.

    A webhook receiver for an MLflow “run completed” event, with strict HMAC checking, is shown below:

    """Receiver for MLflow run-completed webhooks.
    
    POST body shape (illustrative — check your MLflow version):
    {
      "event": "run.finished",
      "run": {
        "run_id": "abc123",
        "experiment_id": "42",
        "status": "FINISHED",
        "metrics": {"val_accuracy": 0.912, "val_loss": 0.231}
      },
      "timestamp": "2026-05-25T14:23:11Z"
    }
    """
    import hashlib
    import hmac
    import json
    import os
    from fastapi import FastAPI, Header, HTTPException, Request
    
    app = FastAPI()
    MLFLOW_SECRET = os.environ["MLFLOW_WEBHOOK_SECRET"].encode("utf-8")
    SLACK_WEBHOOK = os.environ.get("SLACK_WEBHOOK_URL")
    
    
    def verify(body: bytes, signature_header: str) -> bool:
        """MLflow-style: 'sha256=<hex digest>'."""
        if not signature_header.startswith("sha256="):
            return False
        expected = hmac.new(MLFLOW_SECRET, body, hashlib.sha256).hexdigest()
        return hmac.compare_digest(expected, signature_header[len("sha256="):])
    
    
    @app.post("/mlflow/run-finished")
    async def run_finished(
        request: Request,
        x_mlflow_signature: str = Header(default=""),
    ):
        body = await request.body()
        if not verify(body, x_mlflow_signature):
            # Constant-time compare above; reject fast here.
            raise HTTPException(status_code=401, detail="bad signature")
    
        event = json.loads(body)
        run = event["run"]
        metrics = run.get("metrics", {})
        val_acc = metrics.get("val_accuracy")
    
        # Fan out: alert humans on Slack only above a threshold so we don't spam.
        if SLACK_WEBHOOK and val_acc is not None and val_acc > 0.90:
            import httpx
            msg = f"Run {run['run_id']} finished with val_accuracy={val_acc:.3f}"
            async with httpx.AsyncClient(timeout=5.0) as client:
                await client.post(SLACK_WEBHOOK, json={"text": msg})
    
        # Acknowledge — keep response tiny; the sender may impose a timeout.
        return {"ok": True}
    

    Three points warrant attention. The signature check uses hmac.compare_digest rather than ==; the latter leaks timing information that allows an attacker to recover the signature byte by byte. The Slack call uses a short timeout, because a slow Slack response should not hold the MLflow connection open and trigger MLflow’s own timeout-and-retry behaviour. Finally, the receiver returns quickly: for any work heavier than a Slack notification, the operation should be pushed onto a queue and acknowledged immediately.

    Webhook-adjacent patterns also appear in workflow orchestrators. Airflow’s on_success_callback and on_failure_callback serve the same purpose as a webhook, although they are in-process Python callbacks rather than HTTP POSTs. The Airflow orchestration guide describes how those callbacks compose with cross-system webhooks.

    Selecting the Appropriate Hook: A Decision Framework

    The three variants should by this point be clearly distinct. The remaining question is operational: given a problem, which variant should be selected? The matrix below provides guidance:

    Decision Matrix: Which Hook for Which Problem (YES = best fit; qualified entries such as "via HTTP hook" = workable but suboptimal; "wrong ..." or "no visibility" = wrong tool)

    Problem | Lifecycle (Claude Code) | PyTorch (introspection) | Webhook
    Safety enforcement (agents) | YES | wrong layer | via HTTP hook
    Activation extraction | wrong scope | YES | wrong process
    Training-complete alert | wrong scope | possible but odd | YES
    Gradient debugging | wrong layer | YES | no visibility
    Auto-format after edit | YES (PostToolUse) | wrong layer | wrong process
    Model registry promotion | wrong scope | wrong scope | YES

    A simple rule covers most cases: select the hook variant that operates at the same level as the entity to be observed or modified. Tensor flow occurs inside the model process and calls for PyTorch hooks. Agent decisions occur inside the agent runtime and call for lifecycle hooks. The training-job lifecycle spans services and calls for webhooks. Mixing levels works occasionally but usually creates more integration work than it eliminates.

    A side-by-side reference for rapid selection follows:

    Property | Lifecycle (Claude Code) | PyTorch introspection | Webhook (MLOps)
    Fires when | Agent session events | Per forward/backward call | Infra/business events
    Runs where | Shell, HTTP, or MCP | Same Python process, sync | Remote HTTP service, async
    Blocks execution? | Yes (exit 2) | No, but can modify tensors | Indirectly (timeout/retry)
    Language | Any (shell, Python, Go) | Python only | Any (it’s just HTTP)
    Typical user | Agent builders, safety teams | Researchers, model debuggers | MLOps, platform teams
    Auth model | Filesystem perms (Command) or bearer (HTTP) | In-process trust | HMAC signature

     

    Common Pitfalls

    Each variant has its own failure modes. Awareness of these failure modes in advance saves considerable debugging time.

    Claude Code Hook Pitfalls

    Shell injection through tool arguments. A PreToolUse hook receives JSON containing whatever Claude intends to execute. Naively interpolating fields into a shell command—such as echo "$CMD" | grep ...—exposes a path to remote code execution from prompt-injection-style attacks. JSON should always be parsed with jq or another proper parser, never with string slicing.
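    The same advice applies when the hook is written in Python rather than shell. The following is a minimal sketch of a PreToolUse handler that treats the event strictly as data. The field names tool_name and tool_input.command are assumptions about the event JSON and should be checked against the Claude Code version in use; the blocklist is a hypothetical policy, and checking only the first token is deliberately simplistic.

```python
import json
import shlex
import sys

BLOCKED_PROGRAMS = {"rm", "mkfs", "dd"}  # hypothetical policy, for illustration


def decide(raw_event: str) -> tuple[int, str]:
    """Return (exit_code, message) for a PreToolUse event given as JSON text."""
    try:
        event = json.loads(raw_event)
    except json.JSONDecodeError:
        return 2, "hook could not parse event JSON; blocking to be safe"
    if not isinstance(event, dict) or event.get("tool_name") != "Bash":
        return 0, ""
    command = (event.get("tool_input") or {}).get("command", "")
    try:
        tokens = shlex.split(command)
    except ValueError:
        return 2, "hook could not tokenise command; blocking to be safe"
    # The tokens are inspected as data. The command string is never executed
    # or interpolated into another shell command by this hook.
    if tokens and tokens[0] in BLOCKED_PROGRAMS:
        return 2, f"blocked by policy: {tokens[0]}"
    return 0, ""


def main() -> None:
    """Entry point when installed as a hook: read stdin, report, exit."""
    code, message = decide(sys.stdin.read())
    if message:
        print(message, file=sys.stderr)  # stderr is what gets fed back on exit 2
    sys.exit(code)


sample = json.dumps({"tool_name": "Bash", "tool_input": {"command": "rm -rf build"}})
print(decide(sample))
```

    Keeping the decision in a pure function (decide) and the process exit in main makes the policy unit-testable without spawning the agent; a script installed as a hook would end with a call to main().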

    Infinite hook loops. A PostToolUse hook that itself uses Claude to summarise the output, where Claude then invokes tools to summarise, and those tools trigger the PostToolUse hook again, produces a stack that is typically discovered at an inconvenient hour. Hooks should be terminal: they observe but do not re-invoke the agent.

    Exit-code confusion. Bash’s set -e exits non-zero on any failure but not necessarily with code 2. If a hook’s safety-check command crashes for an unrelated reason, the tool will run anyway because the exit code is not the blocking value. When blocking matters, the script should exit with code 2 explicitly.
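    One way to make the blocking value explicit is to wrap the safety check so that every failure mode maps to exit code 2. The sketch below assumes a fail-closed policy (a check that cannot run counts as a block); the function name and the stand-in check commands are illustrative.

```python
import subprocess
import sys

BLOCK = 2  # the exit code this guide describes as "block this action"


def run_check_fail_closed(check_cmd: list[str], timeout_s: float = 10.0) -> int:
    """Run a safety check and map every outcome to an explicit exit code.

    Returns 0 only when the check ran and succeeded. A failure, crash,
    timeout, or missing executable all become BLOCK rather than whatever
    non-zero code the underlying command happened to produce.
    """
    try:
        result = subprocess.run(check_cmd, timeout=timeout_s, capture_output=True)
    except (OSError, subprocess.TimeoutExpired) as exc:
        print(f"safety check did not complete: {exc}", file=sys.stderr)
        return BLOCK
    if result.returncode != 0:
        print(f"safety check failed with code {result.returncode}", file=sys.stderr)
        return BLOCK
    return 0


# Stand-in checks: one that exits 1 (as a `set -e` script might), one that
# succeeds, and one whose executable does not exist.
exits_one = run_check_fail_closed([sys.executable, "-c", "raise SystemExit(1)"])
succeeds = run_check_fail_closed([sys.executable, "-c", "pass"])
missing = run_check_fail_closed(["/nonexistent/checker"])
print(exits_one, succeeds, missing)
```

    A hook script would pass the returned value to sys.exit. Whether fail-closed is the right default depends on the hook: for a formatter it is probably too strict, for a safety gate it is usually what is wanted.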

    The hook is not versioned with the agent. Hook semantics evolve. A handler that worked under one Claude Code version may break under another (renamed fields in the event JSON, new required fields, and so on). Hook scripts should be pinned to the agent version against which they were tested, and re-tested after upgrades.

    PyTorch Hook Pitfalls

    Failing to call handle.remove(). This is the most common bug. A leaked forward hook is difficult to detect: the model continues to function, but more slowly, and memory usage drifts upwards. handle.remove() should be treated like close() and written on the same line as the registration where possible, or wrapped in a context manager.
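    The context-manager approach can be sketched as follows. The helper name forward_hooks is hypothetical; register_forward_hook and the handle's remove() method are the standard PyTorch calls.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def forward_hooks(modules, hook):
    """Register `hook` on each module and always remove the handles on exit."""
    handles = [m.register_forward_hook(hook) for m in modules]
    try:
        yield handles
    finally:
        # Runs even if the body raises, so no hook outlives the with-block.
        for h in handles:
            h.remove()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
calls = []


def record_shape(module, inputs, output):
    calls.append(tuple(output.shape))


linears = [m for m in model if isinstance(m, nn.Linear)]
with forward_hooks(linears, record_shape):
    model(torch.randn(3, 4))

model(torch.randn(3, 4))  # hooks are gone, so nothing more is recorded
print(calls)
```

    Because removal sits in a finally clause, an exception inside the instrumented forward pass cannot leave hooks behind.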

    Storing tensors with the graph attached. Storing output rather than output.detach() retains the entire computation graph leading to that output. On a fifty-layer model the consequences are severe. Tensors should always be detached, and usually cloned, before storage.
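    The difference is visible on a single layer. In this small sketch the hook stores both forms side by side purely for comparison; real code would keep only the detached copy.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
stored = {}


def save_activation(module, inputs, output):
    # Kept only to show the problem: this tensor still references its graph.
    stored["raw"] = output
    # detach() drops the graph reference; clone() gives independent storage,
    # so later in-place ops on the live tensor cannot alter the saved copy.
    stored["safe"] = output.detach().clone()


handle = layer.register_forward_hook(save_activation)
layer(torch.randn(2, 4))
handle.remove()

print(stored["raw"].grad_fn is not None)  # graph still attached
print(stored["safe"].grad_fn is None)     # no graph retained
print(stored["safe"].requires_grad)
```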

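    Hooks added in __init__ versus registered post hoc. Hook behaviour under model copying (copy.deepcopy and torch.nn.parallel.replicate are both common in distributed training) is easy to get wrong. A hook registered from outside the module is typically a closure over external state such as a results dict; any hook entry that travels with a copy still writes to the original’s state, and the original RemovableHandle does not remove it from the copy. A hook that the module installs on itself in __init__ as a bound method is rebound to the new instance by copy.deepcopy, so each copy records into its own state. If the training launcher copies or replicates the model, registration should occur inside the module, and the resulting behaviour should be confirmed on the PyTorch version in use.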

    Overhead in tight loops. Every hook adds Python-level overhead per call. This is acceptable for offline analysis but problematic in a training loop with tens of thousands of iterations per epoch. Hooks should be registered only on the modules of interest, only for the steps of interest, and removed immediately afterwards.

    For training-loop instrumentation that extends beyond gradient logging, the self-supervised learning guide presents similar patterns applied to representation extraction during pretraining.

    Webhook Pitfalls

    Timeout amplification. The sender (MLflow, Weights & Biases, or the model registry in question) typically imposes a short timeout, often five or ten seconds. If the receiver performs any slow operation inline—a database write, a slow Slack call, or ML inference—events will be missed and retries triggered. The recommended pattern is to receive quickly, queue the work, and return a 2xx status code.
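    The receive, queue, acknowledge pattern can be sketched with the standard library alone. This is a minimal illustration, assuming an in-process asyncio.Queue and a stand-in slow_fanout coroutine; a production receiver would normally use a durable queue so that events survive a restart.

```python
import asyncio


async def slow_fanout(event: dict, processed: list) -> None:
    await asyncio.sleep(0.05)  # stand-in for a database write or Slack call
    processed.append(event["id"])


async def worker(queue: asyncio.Queue, processed: list) -> None:
    while True:
        event = await queue.get()
        try:
            await slow_fanout(event, processed)
        finally:
            queue.task_done()


async def handle_webhook(queue: asyncio.Queue, event: dict) -> int:
    """What the HTTP handler does: enqueue, then acknowledge immediately."""
    queue.put_nowait(event)
    return 202


async def demo():
    queue: asyncio.Queue = asyncio.Queue()
    processed: list = []
    worker_task = asyncio.create_task(worker(queue, processed))
    statuses = [await handle_webhook(queue, {"id": i}) for i in range(3)]
    done_at_ack_time = list(processed)  # snapshot: the acks did not wait
    await queue.join()                  # drain before shutting down
    worker_task.cancel()
    return statuses, done_at_ack_time, processed


statuses, done_at_ack_time, processed = asyncio.run(demo())
print(statuses, done_at_ack_time, processed)
```

    All three requests are acknowledged before any slow work has finished, which is the property that keeps the sender's timeout from firing.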

    Missing signature verification. An unverified webhook endpoint is a public remote-code-execution risk if the handler performs any privileged operation with the payload. HMAC should be verified on every request, compared with hmac.compare_digest, and the source IP should not be relied upon.

    At-least-once semantics. Almost every webhook sender retries on failure, so the receiver will observe the same event more than once. The handler must be idempotent: the same event delivered twice should not double-count, double-notify, or double-promote.
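    A common way to obtain idempotency is to deduplicate on an event identifier. The sketch below assumes the sender includes a unique event_id field (a hypothetical name; many senders provide an equivalent delivery ID) and uses an in-memory set where a real receiver would use a durable store.

```python
class IdempotentHandler:
    """Process each event at most once, keyed by a sender-supplied event ID."""

    def __init__(self) -> None:
        self._seen: set[str] = set()      # use a durable store in production
        self.promotions: list[str] = []

    def handle(self, event: dict) -> str:
        event_id = event["event_id"]      # hypothetical field name
        if event_id in self._seen:
            return "duplicate"            # still acknowledge, but do no work
        self.promotions.append(event["model_version"])
        # Mark as seen only after the side effect has succeeded, so a crash
        # mid-way leads to a retry rather than a silently dropped event.
        self._seen.add(event_id)
        return "processed"


handler = IdempotentHandler()
event = {"event_id": "evt-1", "model_version": "v7"}
first = handler.handle(event)
second = handler.handle(event)  # the sender's retry of the same delivery
print(first, second, handler.promotions)
```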

    Replay attacks. Even with HMAC, a captured request can be replayed. A timestamp should be included in the signature payload (most senders do this already), and events older than a small window should be rejected.

    Caution: Across all three variants, the most common silent failure is the same: the hook is in place but is not actually executing. A misconfigured matcher, a leftover handle, or a webhook endpoint that senders no longer reach can all produce this outcome. Observability should be added through audit logs and gauge metrics on hook invocation counts so that a non-firing hook is detected.

    Frequently Asked Questions

    Are Claude Code hooks the same as MCP servers?

    No. MCP servers extend what an agent can do by exposing new tools, resources, and prompts that the agent can call. Hooks extend the agent’s lifecycle by inserting policy at predefined moments. Both can be used simultaneously; a common pattern is an MCP server that provides project context paired with a PreToolUse hook that enforces safety on the agent’s tool calls. The two systems are complementary rather than redundant.

    Does register_forward_hook affect gradients?

    It can. If the hook returns a tensor in place of None, that tensor replaces the module’s output for the remainder of the forward pass, and gradients flow through the replacement during backpropagation. If the hook only reads tensors and returns None, gradients are unaffected. The same applies to backward hooks: returning a modified grad_input tuple replaces what propagates further back. For read-only inspection, the hook should return nothing.
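    Both cases can be observed on a single linear layer. This is a small sketch; the hook names are illustrative and the input is fixed to ones so the result does not depend on random values.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 3)
x = torch.ones(2, 3)


def read_only(module, inputs, output):
    return None  # the module's output is left untouched


def zero_out(module, inputs, output):
    return output * 0.0  # the returned tensor replaces the module's output


h = layer.register_forward_hook(read_only)
layer(x).sum().backward()
grad_read_only = layer.weight.grad.clone()
h.remove()

layer.zero_grad()
h = layer.register_forward_hook(zero_out)
layer(x).sum().backward()
grad_replaced = layer.weight.grad.clone()
h.remove()

print(grad_read_only.abs().sum().item() > 0)   # normal gradients
print(grad_replaced.abs().sum().item() == 0)   # gradients flowed through the replacement
```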

    Can webhooks block a training job?

    Indirectly. If a model-registry promotion event has a configured webhook receiver that times out, some registries pause the promotion pending retries while others fail the promotion entirely. In either case, the system being hooked into determines whether a slow receiver can stall the workflow. The documentation for the specific service should be consulted. As a general rule, webhooks should be treated as fire-and-forget signals rather than synchronous gates.

    What is the relationship between hooks and callbacks?

    The terms are largely synonymous, with a difference of connotation. “Callback” implies a function registered by the user, often for a single defined moment. “Hook” implies a registered extension point exposed by the host system, often one of many. PyTorch documentation uses “hook”; asyncio documentation uses “callback”; the underlying concept is the same. In MLOps, Airflow uses “callback” (on_success_callback) while GitHub uses “webhook”—the same pattern, expressed in different vocabulary.

    Are there security risks specific to lifecycle hooks?

    Yes, three principal risks. First, hooks run with the agent’s privileges, which usually corresponds to the user’s shell, so a bug in a hook script can cause real damage on the machine. Second, hook payloads contain whatever the model intends to do, including potentially adversarial content arising from prompt injection; naive shell interpolation is dangerous. Third, hooks are invisible: a colleague inspecting an agent session will not see the hook fire unless it is logged. Audit logging and code review for hook scripts are as important as for production code. The harness engineering guide covers the broader threat model.

    Conclusion

    “Hook” in AI is a small term performing three distinct functions. Lifecycle hooks allow deterministic policy to be inserted into an agent’s session without forking the agent. Model introspection hooks allow tensor flow to be read or modified without forking the model. Event hooks allow services to communicate about significant moments without polling. The mechanisms share a name and a skeletal definition—a callback at a defined point—but they differ in process, language, blocking semantics, security model, and audience.

    The practical guidance reduces to three rules. First, select the variant that matches the layer at which the problem resides; agent safety should not be enforced with a PyTorch hook, nor should activations be extracted via a webhook. Second, treat hook code as production code: review it, audit it, log it, and version it alongside the system it extends. Third, recall that hooks are powerful precisely because they are invisible to the host; that invisibility is also their principal failure mode, so observability should be built in to detect when a hook ceases to fire.

    One habit worth taking from this guide is the following: whenever the advice to “use a hook” appears in documentation or in a blog post, the appropriate first question is which variant. The answer almost always determines the correct design.

  • How to Train Open-Source LLMs in 2026: Qwen3.6, Qwen3.5, GPT-OSS

    Two years ago, training a large language model required either renting time at a research lab or accepting that fine-tuning was the preserve of billion-dollar companies. By May 2026, Qwen3.6-27B can be taken from a Hugging Face download to a domain-specialised model on a single rented H100 for less than fifteen dollars. The tools have changed. The underlying mathematics has not, but the population of those who use it has expanded. This article describes how to train an open-source LLM in practice today: what hardware is required, which model to choose, how to format the data so that the trainer does not silently discard it, and how to place the result behind a serving endpoint that responds in milliseconds.

    Summary

    What this post covers: A working 2026 playbook for fine-tuning open-source LLMs using three concrete anchors — the dense Qwen3.6-27B, the MoE Qwen3.5-122B-A10B, and OpenAI’s GPT-OSS-120B — from environment setup through deployment.

    Key insights:

    • QLoRA on a single H100 (80GB) now fine-tunes a 27B dense model in 8 to 12 hours for $10 to $16 of cloud rental, retaining 80 to 90 percent of full fine-tuning quality.
    • MoE models like Qwen3.5-122B-A10B (10B active) and GPT-OSS-120B (5.1B active) need VRAM to hold all 122B or 117B weights, even though per-token compute is small — the “active parameter” headline number is a runtime FLOPs claim, not a memory one.
    • Chat-template mismatch between training and inference is the single most common cause of a “trained but acts untrained” model — Qwen’s <|im_start|> markers and GPT-OSS’s harmony format are not interchangeable.
    • GPT-OSS-120B ships post-trained with MXFP4 quantization on the MoE weights, which is why a 117B-total-parameter model fits in a single 80GB H100 at inference time.
    • For anything past 70B at full precision, FSDP2 or DeepSpeed ZeRO-3 sharding is no longer optional — single-node training caps out around 32B dense in FP16 even on H200 (141GB) hardware.

    Main topics: The State of Open-Source LLM Training in 2026, Meet the Three Anchor Models, Choosing Full Fine-Tune LoRA or QLoRA, Setting Up the Training Environment, Preparing the Dataset, The Actual Training Run, Evaluation That Isn’t Theatre, Deployment, Common Pitfalls and Debugging.

    The State of Open-Source LLM Training in 2026

    The open-source LLM landscape in May 2026 bears little resemblance to that of early 2024. Two structural shifts have transformed what a single engineer can accomplish alone.

    The first shift is architectural. Mixture-of-Experts (MoE) models, in which each token activates only a small subset of total parameters, have become the dominant configuration for any model larger than 30B. A dense model uses every weight on every token; an MoE model uses a router to direct each token to a small fraction of “expert” sub-networks. Qwen3.5-122B-A10B has 122B total parameters but only approximately 10B active per forward pass. GPT-OSS-120B contains 117B total parameters with 5.1B active. The runtime FLOPs resemble those of a small model; the VRAM footprint does not.

    The second shift concerns post-training tooling. QLoRA, in which the base weights are frozen at 4-bit NF4 (NormalFloat-4, a quantisation format optimised for the distribution of neural network weights) and only a small low-rank adapter is trained, has moved from a research curiosity in 2023 to the default starting point in 2026. LoRA (Low-Rank Adaptation) retains 90 to 95 per cent of full fine-tuning performance. QLoRA retains 80 to 90 per cent while reducing VRAM by approximately 75 per cent compared with FP16.

    The practical implication is as follows: a 7B model whose FP16 weights alone occupy approximately 14GB can be fine-tuned in 5 to 6GB under QLoRA. A 70B model, whose FP16 weights occupy approximately 140GB, fits in 46GB. The hardware threshold has dropped sufficiently that the question has shifted from whether training is affordable to what should be trained.

    [Figure: Three Open-Source LLMs at a Glance (May 2026). Qwen3.6-27B (dense, multimodal): best for single-GPU fine-tuning, multimodal agents, and long-context tasks. Qwen3.5-122B-A10B (sparse MoE): best for cheap inference, scale via tensor parallelism, and reasoning workloads. GPT-OSS-120B (MoE, MXFP4 native): best for single 80GB GPU serving, reasoning near o4-mini, and drop-in OpenAI replacement. Parameter counts, attention type, context length, release date, and licence for each model appear in the comparison table in the next section.]

    The implications for a practitioner intending to train a model today are as follows: prosumer hardware—a single H100 or H200, or even a 48GB consumer card such as the RTX 6000 Ada—can handle QLoRA on models up to 70B. Beyond that point, multi-GPU LoRA or sharded full fine-tuning is required. Specific recipes for each scenario are presented below.

    Pretraining from scratch—the 2.1 million H100-hour run that produced GPT-OSS-120B—remains out of reach for almost all practitioners. Within reach, however, is taking one of these three checkpoints and adapting it to a particular dataset, domain, or task. This is what “training an open-source LLM” means in practice in 2026.

    Key Takeaway: Training in 2026 almost always means fine-tuning a released checkpoint. The interesting choice is not pretraining versus fine-tuning but rather which fine-tuning method and which base model to use.

    The Three Anchor Models

    Three models cover the practical range of what is fine-tuned today: a dense 27B model that fits comfortably on prosumer hardware, a sparse 122B model that requires cluster-class memory but inexpensive compute, and a 117B MoE model that ships pre-quantised to fit on a single 80GB card.

    Qwen3.6-27B

    Released on 22 April 2026 by Alibaba’s Qwen team. Dense: every one of the 27 billion parameters participates in every forward pass. It uses Gated DeltaNet, a hybrid attention scheme that combines a linear-attention path (constant memory cost per token) with traditional softmax self-attention. The linear path handles long-range context, while the softmax path preserves short-range precision.

    Native context is 262,144 tokens, extensible to one million via position-encoding extrapolation. The model is natively multimodal: the same checkpoint accepts images and text. A “Thinking Preservation” mechanism maintains a chain-of-thought reasoning mode and a fast non-thinking mode within a single set of weights.

    Benchmark figures from the Qwen team include SWE-bench Verified 77.2 (compared with Qwen3.5-397B-A17B at 76.2), SWE-bench Pro 53.5 (compared with 50.9), Terminal-Bench 2.0 59.3 (compared with 52.5), and SkillsBench 48.2 (compared with 30.0). A 27B dense model surpassing its 397B MoE predecessor on code-related work is the kind of result that re-establishes the importance of architecture choice.

    The model can be downloaded from the QwenLM/Qwen3.6 official repository or the Hugging Face Qwen/Qwen3.6-27B mirror. The licence is Apache 2.0: commercial use is permitted with attribution.

    Qwen3.5-122B-A10B

    Released on 24 February 2026. A sparse MoE: 122 billion total parameters, approximately 10 billion active per forward pass. The “A10B” suffix denotes the active-parameter count. Each token is routed through a small subset of experts, while the remainder of the network remains idle for that token.

    The model shares the Gated DeltaNet hybrid attention of Qwen3.6-27B and the same 262K native context, extensible to 1M+. It is text-only at this size. The MoE structure means inference compute resembles that of a 10B model, but VRAM must still hold all 122B weights, because the router cannot determine in advance which expert any given token will require.

    This is the appropriate model when strong quality is required alongside inexpensive per-token serving. The active-parameter count determines latency and energy cost; the total parameter count determines hardware purchasing decisions. The trade-off is frequently misunderstood on first encounter.
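    The split between the two numbers can be made concrete with a short calculation. This is a back-of-the-envelope sketch: the parameter counts come from this article, while the figure of roughly 2 FLOPs per active parameter per token is a common rule of thumb rather than a measured value.

```python
def weights_gb(total_params: float, bits_per_param: float) -> float:
    """Memory needed to hold all weights, in GB (1 GB = 1e9 bytes)."""
    return total_params * bits_per_param / 8 / 1e9


def forward_gflops_per_token(active_params: float) -> float:
    """Rule of thumb: about 2 FLOPs per active parameter per token."""
    return 2 * active_params / 1e9


# (total parameters, active parameters per token), as quoted in this article.
models = {
    "Qwen3.6-27B": (27e9, 27e9),
    "Qwen3.5-122B-A10B": (122e9, 10e9),
    "GPT-OSS-120B": (117e9, 5.1e9),
}
for name, (total, active) in models.items():
    print(f"{name}: {weights_gb(total, 16):.0f} GB of FP16 weights, "
          f"{forward_gflops_per_token(active):.1f} GFLOPs per token")
```

    Memory scales with the first element of each pair and per-token compute with the second, which is why the 122B MoE needs several times the VRAM of the 27B dense model while doing less arithmetic per token.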

    GPT-OSS-120B

    Part of OpenAI’s first open-weight LLM release since GPT-2 (2019), published in August 2025. The model contains 117 billion total parameters with 5.1 billion active, under an Apache 2.0 licence. It was trained on NVIDIA H100 GPUs using PyTorch with custom Triton kernels. The training run consumed 2.1 million H100-hours, which at $2 per hour in cloud pricing represents approximately $4.2 million in compute alone.

    What makes GPT-OSS-120B unusual is that it ships post-trained with MXFP4 quantisation on the MoE weights. MXFP4 is a 4-bit floating-point format with a shared scale per micro-block. Because the bulk of the parameter count resides in the MoE expert layers, quantising those layers to 4-bit reduces the on-disk and in-VRAM footprint sufficiently to fit on a single 80GB GPU (H100 or AMD MI300X). The non-expert layers remain at higher precision.
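    The idea of a shared scale per micro-block can be illustrated in a few lines of plain Python. This is a simplified sketch, not the production kernel: it assumes the layout of the OCP Microscaling specification (blocks of 32 elements, 4-bit E2M1 values, one shared power-of-two scale stored in 8 bits), and the rounding and scale selection here are deliberately naive.

```python
import math

FP4_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 magnitudes
BLOCK_SIZE = 32  # elements that share one scale


def quantize_block(block):
    """Return (scale, codes): one power-of-two scale plus an FP4 value per element."""
    max_abs = max(abs(v) for v in block)
    if max_abs == 0.0:
        return 1.0, [0.0] * len(block)
    # Power-of-two scale chosen so the largest element lands near the top of
    # the FP4 range (the largest power of two representable in E2M1 is 4 = 2**2).
    scale = 2.0 ** (math.floor(math.log2(max_abs)) - 2)
    codes = []
    for v in block:
        mag = min(abs(v) / scale, FP4_MAGNITUDES[-1])  # clamp to 6.0
        nearest = min(FP4_MAGNITUDES, key=lambda m: abs(m - mag))
        codes.append(math.copysign(nearest, v))
    return scale, codes


def dequantize_block(scale, codes):
    return [scale * c for c in codes]


block = [0.01 * i for i in range(-16, 16)]  # 32 made-up weights
scale, codes = quantize_block(block)
restored = dequantize_block(scale, codes)
max_error = max(abs(a - b) for a, b in zip(block, restored))
bits_per_weight = (BLOCK_SIZE * 4 + 8) / BLOCK_SIZE
print(scale, max_error <= scale, bits_per_weight)
```

    Under these assumptions each weight costs 4.25 bits on average instead of 16, which is where the footprint reduction for the expert layers comes from.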

    Reported benchmarks indicate near-parity with OpenAI’s o4-mini on core reasoning. For a model that can run on a single rented GPU, this is a notable result. The model card and weights are available at huggingface.co/openai/gpt-oss-120b; the official repository is at github.com/openai/gpt-oss; the launch announcement is at openai.com/index/introducing-gpt-oss.

    Attribute | Qwen3.6-27B | Qwen3.5-122B-A10B | GPT-OSS-120B
    Total params | 27B | 122B | 117B
    Active params | 27B (dense) | ~10B | 5.1B
    Architecture | Dense, Gated DeltaNet | MoE, Gated DeltaNet | MoE, grouped-query attn
    License | Apache 2.0 | Apache 2.0 | Apache 2.0
    Release date | 2026-04-22 | 2026-02-24 | August 2025
    Native context | 262K (extensible to 1M) | 262K (extensible to 1M+) | 128K
    Multimodal | Yes (vision + text) | Text only | Text only
    Download | HF: Qwen/Qwen3.6-27B | HF: Qwen/Qwen3.5-122B-A10B | HF: openai/gpt-oss-120b

     

    Choosing Full Fine-Tune, LoRA, or QLoRA

    Three fine-tuning methods cover essentially the entire field. They occupy positions along a cost-versus-quality spectrum, and the appropriate choice depends on the volume of available data and the degree to which the target domain differs from the base model’s training distribution.

    Full fine-tuning updates every parameter. It requires approximately four times the model’s memory footprint during training: model weights, gradients, optimizer states (two for AdamW: first and second moment), and activations. A 7B model requires approximately 14GB in FP16 for weights alone; with optimizer states and gradients, peak usage approaches 60GB.
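    The arithmetic behind these figures can be sketched component by component. This is a rough estimate under a stated simplification: weights, gradients, and both AdamW moments are all assumed to be stored at 2 bytes per value, and activations (which depend on batch size and sequence length) are left out. Mixed-precision setups that keep optimizer state in FP32 will need more.

```python
def full_finetune_estimate_gb(params: float, bytes_per_value: int = 2) -> dict[str, float]:
    """Rough per-component memory for full fine-tuning with AdamW, in GB.

    Simplification: every component is stored at the same precision;
    activations are extra and workload-dependent.
    """
    per_copy = params * bytes_per_value / 1e9
    parts = {
        "weights": per_copy,
        "gradients": per_copy,
        "adam_first_moment": per_copy,
        "adam_second_moment": per_copy,
    }
    parts["total_before_activations"] = sum(parts.values())
    return parts


for size in (7e9, 70e9):
    estimate = full_finetune_estimate_gb(size)
    print(f"{size / 1e9:.0f}B: {estimate['weights']:.0f} GB weights, "
          f"{estimate['total_before_activations']:.0f} GB before activations")
```

    For 7B this gives 56GB before activations, consistent with a peak approaching 60GB; for 70B it gives 560GB, which is why full fine-tuning at that size requires a multi-GPU node.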

    LoRA (Low-Rank Adaptation) freezes the base weights and inserts trainable low-rank matrices into the attention projection layers. Instead of updating the full weight matrix W (for example, 4096×4096 = approximately 16.7M parameters), two small matrices B (4096×r) and A (r×4096) are trained, where r is typically 8, 16, or 32. The model effectively learns ΔW = B·A, which is added to the frozen W at inference. For r = 16, this amounts to approximately 131K trainable parameters per layer rather than 16.7M, roughly 128 times fewer.
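    The parameter counts quoted above follow directly from the matrix shapes. A short check, using only the dimensions given in the text (the helper name is illustrative):

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in B (d_out x r) plus A (r x d_in)."""
    return d_out * r + r * d_in


d = 4096
dense = d * d  # parameters in the full weight matrix W
for r in (8, 16, 32):
    lora = lora_trainable_params(d, d, r)
    print(f"r={r}: {lora} trainable vs {dense} dense, {dense // lora}x fewer")
```

    Doubling the rank doubles the trainable parameters, so the reduction factor halves each time; at r = 16 it is exactly 128.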

    QLoRA extends LoRA further. The frozen base weights are quantised to 4-bit NF4 (NormalFloat-4, designed to match the typical Gaussian distribution of neural network weights), and LoRA adapters sit on top in FP16 or BF16. The weights are de-quantised on the fly only during forward and backward passes. Memory consumption decreases by approximately 75 per cent compared with FP16 training.

    [Figure: Cost vs quality spectrum for fine-tuning methods, ordered from lower to higher VRAM and cost, with quality retention expressed as a percentage of full fine-tuning: prompting / RAG (~0 VRAM, ~60-70%), QLoRA (80-90%), LoRA (90-95%), full fine-tuning (100%, baseline). VRAM and cost figures for the three training methods are given in the table below.]

    Method | VRAM (7B) | VRAM (70B) | Wall time (1 H100) | Cost (cloud) | Quality retention
    Full FT | ~60 GB | ~560 GB (needs 8×H100) | 24-48h on 8×H100 | $250-510 | 100% (baseline)
    LoRA | ~16 GB | ~160 GB (2-4 GPUs) | 10-15h | $20-40 | 90-95%
    QLoRA | ~6 GB | ~46 GB (1 H100/H200) | 8-12h | $10-16 | 80-90%

     

    [Figure: How LoRA and QLoRA work. The frozen base weights W₀ (d × d, e.g. 4096 × 4096 = 16.7M parameters; FP16 for LoRA, 4-bit NF4 for QLoRA) receive no gradient and carry no optimizer state. The trainable update is B (d × r, initialised to zero) multiplied by A (r × d, Gaussian initialisation), giving the effective weight W = W₀ + B·A. For r = 16, B and A hold about 65K parameters each: 131K trainable versus 16.7M dense, roughly 128× fewer per attention projection. Quantising W₀ to NF4 (QLoRA only) saves about 75% of VRAM.]

    The practical selection heuristic is to begin with QLoRA. If quality is insufficient after a sweep over rank, learning rate, and data size, the next step is LoRA. Full fine-tuning should be reserved for cases in which the domain shift is so substantial that the base model’s representation is genuinely wrong—for example, a model trained predominantly on English required to operate in a low-resource language. The 80 to 90 per cent quality retention of QLoRA is sufficient for the majority of production tasks.

    Tip: A LoRA rank (r) of 16 serves as a sensible default. It should be increased to 32 or 64 only if the task differs substantially from the base model’s training distribution. Higher rank consumes more VRAM and, for most domains, rarely provides benefits beyond r = 16.

    [Figure: VRAM budget by model and mode, in GB. Reference lines: H100 = 80GB, H200 = 141GB.]

    Model | Inference FP16 | Inference 4-bit | QLoRA training | Full FT training (peak)
    Qwen3.6-27B | 54 | 14 | 22 | ~270
    Qwen3.5-122B | 244 | 62 | ~80 | ~600+
    GPT-OSS-120B | 234 | 35* | ~75 | ~560
    * MXFP4 native

    It is worth noting that GPT-OSS-120B’s 4-bit inference figure (approximately 35 GB) is substantially lower than Qwen3.5-122B’s 62 GB despite similar total parameter counts. This is the advantage of MXFP4-native quantisation. Qwen3.5 must be quantised after training (AWQ or GPTQ), incurring some additional accuracy loss; GPT-OSS-120B was post-trained with the 4-bit format already in mind.

    Setting Up the Training Environment

    Three years ago, this section would have been considerably more complex: CUDA versions, PyTorch builds, mismatched Triton, and broken bitsandbytes. In May 2026 the process remains finicky, but the recipe is more stable.

    The requirements are CUDA 12.6 or newer (CUDA 12.8 ships well with the H100/H200 SXM5 drivers), cuDNN 9.5 or newer, PyTorch 2.7 stable or 2.8 nightly, and recent versions of transformers, peft, accelerate, trl, bitsandbytes, and vllm. Flash Attention 3 requires Hopper (H100/H200) or newer; on Ampere (A100), Flash Attention 2 is the fallback.

    The cleanest approach uses a Docker container that pins all of these versions. Building locally is the second-cleanest option. Operating in a bare Python environment invites an evening of debugging mismatched CUDA symbols. Containerising the training environment with a known-good base image, typically nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04, is the standard approach.

    A working pyproject.toml for a fine-tuning project as of May 2026 is shown below:

    [project]
    name = "llm-finetune"
    version = "0.1.0"
    requires-python = ">=3.11"
    dependencies = [
        "torch==2.7.0",
        "transformers==4.50.2",
        "peft==0.14.1",
        "bitsandbytes==0.46.0",
        "accelerate==1.4.0",
        "trl==0.16.0",
        "datasets==3.5.0",
        "unsloth==2026.5.3",
        "flash-attn==3.0.1",
        "vllm==0.9.2",
        "wandb==0.19.5",
        "sentencepiece==0.2.0",
        "tiktoken==0.7.0",
        "lm-eval==0.4.7",
    ]
    
    [tool.uv]
    index-strategy = "unsafe-best-match"
    
    [[tool.uv.index]]
    name = "pytorch-cuda128"
    url = "https://download.pytorch.org/whl/cu128"
    

    A Dockerfile producing a known-good training image is shown below:

    FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04
    
    ENV DEBIAN_FRONTEND=noninteractive \
        PYTHONUNBUFFERED=1 \
        PIP_NO_CACHE_DIR=1 \
        HF_HOME=/workspace/.cache/huggingface \
        TORCH_CUDA_ARCH_LIST="9.0;10.0"
    
    RUN apt-get update && apt-get install -y --no-install-recommends \
            python3 python3-venv python3-pip git curl ca-certificates \
            build-essential ninja-build cmake \
        && rm -rf /var/lib/apt/lists/*
    
    RUN curl -LsSf https://astral.sh/uv/install.sh | sh
    ENV PATH="/root/.local/bin:${PATH}"
    
    WORKDIR /workspace
    COPY pyproject.toml uv.lock ./
    RUN uv sync --frozen --no-dev
    
    # Flash Attention 3 needs to compile against the installed torch
    RUN uv pip install --no-build-isolation flash-attn==3.0.1
    
    COPY . .
    
    CMD ["uv", "run", "python", "-m", "train"]
    

    The framework landscape in 2026 is as follows: TRL is HuggingFace’s official trainer for SFT (supervised fine-tuning) and reinforcement learning post-training. Axolotl is a YAML-config layer on top of TRL that handles much of the data-preparation boilerplate. Unsloth is a Triton-optimised custom kernel package that claims up to twice the training speed and 60 per cent lower VRAM consumption through hand-tuned kernels, and is now stable enough for production use. torchtitan is Meta’s reference scaffolding for large-scale pretraining and full fine-tuning with FSDP2.

    Framework | Primary use case | Scaling target | Ergonomics | Recent activity
    TRL | SFT, DPO, GRPO, PPO | 1-8 GPUs, single node | Python API, flexible | Very active
    Axolotl | SFT, DPO with YAML config | 1-8 GPUs | YAML, low boilerplate | Active
    Unsloth | Single-GPU QLoRA/LoRA, speed | 1 GPU (multi-GPU in 2025) | Drop-in trainer wrapper | Very active
    torchtitan | Full FT, pretraining at scale | Multi-node FSDP2 | Lower-level, reference impl | Active

    For the large majority of fine-tuning work, the appropriate choice is TRL (or Unsloth-wrapped TRL for single-GPU work). torchtitan should be reserved for full fine-tuning at 70B and above, where FSDP2 sharding becomes mandatory.

    Preparing the Dataset

    The training loop is the straightforward part. Dataset preparation is where most fine-tuning runs silently fail. The model trains for hours and the loss curve appears acceptable, but at evaluation the outputs are wrong, off-topic, or in the wrong format. The problem is almost always located upstream of the trainer.

    A modern instruction-tuning dataset consists of a list of conversations. Each conversation is a list of turns, and each turn has a role (system, user, assistant) and content. JSONL is the standard on-disk format, with one JSON object per line. Parquet is the high-performance alternative for large datasets requiring columnar compression.

    A single training example for Qwen3.6-27B appears as follows:

    {"messages": [
      {"role": "system", "content": "You are an industrial controls expert."},
      {"role": "user", "content": "Explain why a PID loop on a cobot joint might oscillate after a payload change."},
      {"role": "assistant", "content": "When the end-effector payload changes, the effective inertia at the joint changes. The PID gains were tuned for the original inertia. With higher effective mass, the proportional term applies the same torque correction to a slower-responding system, which can overshoot the setpoint and drive sustained oscillation. Two practical fixes: gain scheduling (load the right PID set when the gripper picks up the part) or an adaptive controller (online inertia estimation). The first is easier to validate in safety review."}
    ]}
    

    The tokenizer then applies the model’s chat template (a Jinja-style template defined inside tokenizer_config.json) to convert that list of turns into a single tokenised sequence with the model’s special tokens. For Qwen3.6, the chat template wraps each turn in <|im_start|>role\ncontent<|im_end|>. For GPT-OSS-120B, the harmony format is used, which delimits turns with <|start|>, <|message|>, and <|end|> tokens and adds channel markers. These are not interchangeable. A model trained with the wrong template and served with the correct one will behave as though it had never been trained.

    Figure: Chat Template: From Conversation to Training Sequence. Structured messages (system, user, assistant) pass through apply_chat_template() and tokenizer.encode() to produce the Qwen chat template output, in which each turn is wrapped in <|im_start|>role ... <|im_end|>. Loss mask: system and user tokens are set to ignore_index = -100; assistant tokens are trained normally. GPT-OSS uses the harmony format, not <|im_start|>; templates are not portable.

    The standard loss-masking pattern is as follows: the model is trained to predict assistant tokens, but the loss is masked (set to -100, the standard ignore_index for PyTorch’s CrossEntropyLoss) on system and user tokens. It is undesirable to teach the model to generate user messages.
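    The masking rule can be sketched in a few lines. The example below assumes a per-token list of role tags is already available; that list, the function name, and the token ids are hypothetical and exist only for illustration. In practice the trainer derives the assistant spans from the chat template.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def build_labels(token_ids, token_roles):
    """Copy token ids into labels, masking every non-assistant position."""
    if len(token_ids) != len(token_roles):
        raise ValueError("token_ids and token_roles must have the same length")
    return [
        tok if role == "assistant" else IGNORE_INDEX
        for tok, role in zip(token_ids, token_roles)
    ]


# Toy sequence: the ids and the per-token role tags are made up for illustration.
token_ids = [11, 12, 13, 21, 22, 31, 32, 33]
token_roles = ["system"] * 3 + ["user"] * 2 + ["assistant"] * 3

labels = build_labels(token_ids, token_roles)
print(labels)  # [-100, -100, -100, -100, -100, 31, 32, 33]
```

    Positions labelled -100 contribute nothing to the loss, so gradients come only from the assistant's tokens.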

    A representative data-loading pipeline for Qwen3.6-27B, using the HuggingFace datasets library, is shown below:

    from datasets import load_dataset
    from transformers import AutoTokenizer
    
    MODEL_ID = "Qwen/Qwen3.6-27B"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    
    def format_example(example):
        """Apply Qwen's chat template and return the formatted text (not yet tokenised)."""
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        return {"text": text}
    
    ds = load_dataset("json", data_files="data/train.jsonl", split="train")
    ds = ds.map(format_example, remove_columns=ds.column_names)
    
    # Train/eval split with a fixed seed for reproducibility
    split = ds.train_test_split(test_size=0.05, seed=42)
    train_ds, eval_ds = split["train"], split["test"]
    
    print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
    print("Sample formatted text:")
    print(train_ds[0]["text"][:500])
    

    Before training, two additional passes should be performed on the dataset. First, deduplication: exact-match dedup is inexpensive (a hash per example), while MinHash or SimHash near-dedup catches paraphrases. Duplicates make the loss curve look better than the model really is (especially when copies land on both sides of the train/eval split) and bias the model toward memorising common patterns.
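    A minimal sketch of the exact-match pass is shown below, assuming the messages format used earlier. The normalisation (lower-casing and collapsing whitespace) is a design choice made here for illustration; near-duplicate detection with MinHash or SimHash would need a dedicated library and is not shown.

```python
import hashlib


def example_key(example: dict) -> str:
    """Hash the conversation after light normalisation (case and whitespace)."""
    parts = []
    for turn in example["messages"]:
        content = " ".join(turn["content"].lower().split())
        parts.append(turn["role"] + ":" + content)
    return hashlib.sha256("\n".join(parts).encode("utf-8")).hexdigest()


def dedup_exact(examples):
    """Keep the first occurrence of each distinct conversation, preserving order."""
    seen, unique = set(), []
    for ex in examples:
        key = example_key(ex)
        if key not in seen:
            seen.add(key)
            unique.append(ex)
    return unique


rows = [
    {"messages": [{"role": "user", "content": "What is a PID loop?"}]},
    {"messages": [{"role": "user", "content": "what is  a PID loop?"}]},  # same after normalisation
    {"messages": [{"role": "user", "content": "What is gain scheduling?"}]},
]
print(len(dedup_exact(rows)))  # 2
```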

    Second, a contamination check: it must be ensured that none of the training data overlaps with the evaluation benchmarks. If the evaluation is MMLU and the training data was scraped from Common Crawl, there is a real probability that MMLU questions are present. A substring search of evaluation questions against the training set should be conducted, with any matches removed.

    When data preparation is sufficiently complex to warrant orchestration, Airflow data pipelines are a suitable fit, as the dedup, contamination check, and tokenisation steps map well to a directed acyclic graph.

    Caution: The most common training failure is also the most silent: chat template mismatch. The text fed to the trainer should always be checked against the output of tokenizer.apply_chat_template to confirm that it matches the format expected by the model. The first 1000 characters of a formatted example should be printed before any long run.

    The Actual Training Run

    Three concrete recipes are presented below, covering the three anchor models across three hardware budgets. Each provides a known-working starting point from which learning rate, rank, and data mixture may be tuned.

    Figure: End-to-End Training Pipeline. (1) Data prep: dedup and filter, hours to days, offline. (2) Tokenize: chat template, minutes, cached. (3) Forward: compute logits, ~50-200 ms/step. (4) Loss: backward pass and gradients, ~70-300 ms/step. (5) Optimizer: AdamW step, ~10-30 ms/step. Steps 3 to 5 repeat for N steps per epoch. (6) Eval: held-out set loss every N steps. (7) Checkpoint: save adapter or weights every K steps or on best eval. (8) Benchmark: lm-eval-harness at end of training. Total wall time (QLoRA 27B, single H100, 50K examples, 3 epochs): ~8-12 hours end-to-end; per step ~150-400 ms; eval every 500 steps; checkpoint every 1000 steps.

    Recipe 1: QLoRA on Qwen3.6-27B, Single H100 (80GB)

    This is the most accessible setup. One rented H100 from Lambda Labs, RunPod, or a comparable cloud provider costs approximately $1.80 to $2.50 per hour as of May 2026. With 50,000 training examples and three epochs, the target wall time is eight to twelve hours, for a total bill of roughly $14 to $30. This is the recipe most teams actually use.

    # train_qlora_qwen36.py
    import torch
    from transformers import (
        AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    )
    from peft import LoraConfig, prepare_model_for_kbit_training
    from trl import SFTConfig, SFTTrainer
    from datasets import load_dataset
    
    MODEL_ID = "Qwen/Qwen3.6-27B"
    OUTPUT_DIR = "out/qwen36-27b-qlora"
    
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",        # NormalFloat-4
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,   # nested quantization of the quant constants
    )
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    tokenizer.padding_side = "right"  # important: right-pad for SFT
    
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_3",
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)
    model.config.use_cache = False  # cache is not used during training; saves VRAM
    
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,             # alpha/r = 2 is a common starting ratio
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
    
    train_ds = load_dataset("json", data_files="data/train.jsonl", split="train")
    # eval.jsonl is a separate file; the "json" loader exposes a single file as split "train"
    eval_ds  = load_dataset("json", data_files="data/eval.jsonl",  split="train")
    
    sft_config = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,   # effective batch = 16
        gradient_checkpointing=True,     # trade compute for VRAM
        learning_rate=2e-4,              # LoRA-typical; full FT would use ~1e-5
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        optim="paged_adamw_8bit",        # 8-bit optimizer to save more VRAM
        bf16=True,
        max_seq_length=4096,
        packing=True,                    # pack short examples to maximize GPU use
        eval_strategy="steps",
        eval_steps=500,
        save_steps=1000,
        save_total_limit=3,
        logging_steps=20,
        report_to="wandb",
        seed=42,
    )
    
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,  # TRL >= 0.12 name for the former tokenizer= argument
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=peft_config,
    )
    
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    

    The principal design choices in the script merit explanation:

    • NF4 with double quantisation: NF4 quantises the weights themselves; double quantisation additionally quantises the per-block scaling constants, saving a further approximately 0.4 bits per parameter on average.
    • Gradient checkpointing: activations are recomputed during the backward pass rather than stored. Only the activations at checkpoint boundaries are kept, so activation memory grows roughly with the square root of the number of layers rather than linearly (the saving is in depth, not in sequence length), at a cost of roughly 30 per cent additional compute. The trade is almost always worthwhile for LoRA and QLoRA.
    • Gradient accumulation: with a per-device batch size of 2 and accumulation steps of 8, the effective batch is 16. This is useful when VRAM constrains the per-step batch but the optimisation signal of a larger batch is desired.
    • Paged AdamW 8-bit: optimiser states (first and second moments) are held at 8-bit precision and paged to CPU memory when the GPU runs short. This reduces optimiser-state memory by a factor of four compared with FP32 AdamW.
    • Packing: concatenates multiple short examples into one sequence up to max_seq_length. Without packing, padding to 4096 tokens wastes most of the compute on short examples.
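    The packing idea in the last bullet can be illustrated with a greedy sketch. This is not TRL's implementation; the function name, the end-of-sequence id, and the toy token ids are assumptions chosen to show why packed batches waste fewer padded positions.

```python
def pack_greedy(sequences, max_len, eos_id):
    """Concatenate tokenised examples into blocks of at most max_len tokens.

    Each example is followed by eos_id so the boundary stays visible.
    Examples longer than max_len are truncated for simplicity.
    """
    blocks, current = [], []
    for seq in sequences:
        item = (list(seq) + [eos_id])[:max_len]
        if current and len(current) + len(item) > max_len:
            blocks.append(current)
            current = []
        current.extend(item)
    if current:
        blocks.append(current)
    return blocks


# Toy token ids; 0 stands in for the end-of-sequence token (an assumption).
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
packed = pack_greedy(examples, max_len=8, eos_id=0)
print(packed)  # [[1, 2, 3, 0, 4, 5, 0], [6, 7, 8, 9, 0, 10, 0]]
```

    Padded individually to length 8, the four examples would occupy 32 positions; packed, they occupy two blocks of 8, with only two padding slots in total.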

    Recipe 2: Multi-GPU LoRA on Qwen3.5-122B-A10B

    122B total parameters corresponds to approximately 244GB in FP16 for the weights alone. Two H200s (141GB each, 282GB combined) or four H100s (320GB combined) have enough combined memory to hold the weights, though with limited headroom for activations, adapter gradients, and optimiser state. The accelerate configuration below therefore uses FSDP2 to shard the model across eight GPUs.

    # accelerate_config_fsdp.yaml
    compute_environment: LOCAL_MACHINE
    distributed_type: FSDP
    mixed_precision: bf16
    num_processes: 8
    num_machines: 1
    machine_rank: 0
    gpu_ids: all
    
    fsdp_config:
      fsdp_version: 2
      fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
      fsdp_transformer_layer_cls_to_wrap: Qwen3MoeDecoderLayer
      fsdp_sharding_strategy: FULL_SHARD
      fsdp_state_dict_type: SHARDED_STATE_DICT
      fsdp_offload_params: false
      fsdp_use_orig_params: true
      fsdp_sync_module_states: true
      fsdp_cpu_ram_efficient_loading: true
      fsdp_activation_checkpointing: true
    

    Launch the run with: accelerate launch --config_file accelerate_config_fsdp.yaml train_lora_qwen35.py

    The training script is structurally similar to Recipe 1, with three changes: no BitsAndBytesConfig (LoRA rather than QLoRA), device_map=None (FSDP manages placement), and per-device batch size reduced to 1 with accumulation steps increased to maintain an effective batch of approximately 32. Wall time for 50K examples over three epochs on 8× H100 is approximately 18 to 24 hours.
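    The effective batch size is the product of per-device batch size, accumulation steps, and GPU count. The sketch below checks the numbers for both recipes; the accumulation value of 4 for Recipe 2 is an assumption chosen here to reach the stated target of approximately 32, since the text does not give it explicitly.

```python
def effective_batch(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Sequences contributing to one optimiser step across all devices."""
    return per_device * grad_accum * num_gpus


# Recipe 1: one GPU, per-device batch 2, accumulation 8.
print(effective_batch(per_device=2, grad_accum=8, num_gpus=1))  # 16

# Recipe 2: eight GPUs, per-device batch 1, assumed accumulation of 4.
print(effective_batch(per_device=1, grad_accum=4, num_gpus=8))  # 32
```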

    Figure: FSDP2 / ZeRO-3: Sharding Across GPUs. Under naive data parallelism (DDP), each of the four GPUs holds the full parameters, gradients, and optimiser state; under FSDP2 / ZeRO-3, each GPU holds 1/N of each. Per-GPU memory for a 70B model in BF16 on 4 GPUs: DDP (no sharding) ~560 GB (overflows an 80 GB H100 by 7×); ZeRO-2 (gradients and optimiser state sharded) ~280 GB (still overflows); FSDP2 / ZeRO-3 ~140 GB (fits on H200, tight on H100); FSDP2 on 8 GPUs ~70 GB (fits comfortably on H100).

    Recipe 3: Multi-Node Full Fine-Tune on GPT-OSS-120B

    Full fine-tuning a 117B MoE is genuinely expensive. The model weights in BF16 alone occupy approximately 234GB. With the addition of gradients, optimiser states (AdamW keeps two FP32 moments per parameter, 8 bytes per parameter in total, approximately 940GB), and activations, cluster-class memory is required. The lower bound is 32 H100 GPUs across four nodes, using torchtitan with FSDP2 sharding across all 32 GPUs and tensor parallelism within each node.
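    The arithmetic behind these figures can be written out as a short sketch. It covers the static training state only and assumes BF16 gradients; activations, communication buffers, and any FP32 master copy of the weights are not counted, so the real requirement is higher.

```python
PARAMS = 117e9  # total parameters of GPT-OSS-120B, as stated above

BYTES_BF16 = 2
BYTES_FP32 = 4

weights_gb = PARAMS * BYTES_BF16 / 1e9        # BF16 weights
grads_gb = PARAMS * BYTES_BF16 / 1e9          # assumption: BF16 gradients
adam_gb = PARAMS * 2 * BYTES_FP32 / 1e9       # two FP32 moments per parameter
total_gb = weights_gb + grads_gb + adam_gb    # activations not included

num_gpus = 32
cluster_gb = num_gpus * 80                    # 32 H100s at 80 GB each

print(f"weights ~{weights_gb:.0f} GB, gradients ~{grads_gb:.0f} GB, AdamW ~{adam_gb:.0f} GB")
print(f"static state ~{total_gb:.0f} GB of {cluster_gb} GB cluster memory")
print(f"per GPU under ideal sharding ~{total_gb / num_gpus:.1f} GB before activations")
```

    Even with perfect sharding, more than half of each GPU's 80 GB is taken before a single activation is stored, which is why smaller clusters are not a realistic option.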

    For most use cases this is not the appropriate path. Beyond the cost, full fine-tuning risks eroding the post-training calibration and safety tuning baked into the released checkpoint. The pragmatic path for GPT-OSS-120B is LoRA with rank 32, with the adapter applied to attention and MoE expert gate projections only.

    Setup | Combined VRAM | What it can train
    Single H100 QLoRA | 80 GB | Up to ~70B with QLoRA; Qwen3.6-27B comfortably
    Single H200 QLoRA | 141 GB | Up to ~120B with QLoRA; comfortable 70B LoRA
    2× H200 LoRA | 282 GB | Full LoRA on Qwen3.5-122B-A10B with FSDP2
    8× H100 LoRA | 640 GB | LoRA on any model up to ~200B with sharding
    8× H100 full FT | 640 GB | Full FT up to ~70B with FSDP2 + activation checkpointing
    32× H100 multi-node | 2,560 GB | Full FT on 120B+ MoE; small pretraining runs

    Across all three recipes, the choice of optimiser matters more than is commonly appreciated. AdamW with a cosine learning rate schedule and 3 per cent warm-up is the strong default. For LoRA, the learning rate is typically 1e-4 to 2e-4—substantially higher than the 1e-5 to 5e-5 used for full fine-tuning—because LoRA’s adapter layers begin near zero and require larger steps to learn meaningful deltas. Checkpoints should be saved every 1000 steps. Adapter-only (PEFT) checkpoints are preferable to full-model checkpoints; they are approximately one hundred times smaller.
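    For reference, the shape of a cosine schedule with 3 per cent warm-up can be sketched as follows. This is a plain-Python illustration of the curve, not the scheduler code used by the trainer; the total step count is a hypothetical value.

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float, warmup_ratio: float = 0.03) -> float:
    """Linear warm-up to peak_lr, then cosine decay to zero."""
    warmup_steps = max(1, round(total_steps * warmup_ratio))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


TOTAL = 1000  # hypothetical number of optimiser steps
for s in (0, 15, 30, 515, 1000):
    print(s, f"{lr_at(s, TOTAL, peak_lr=2e-4):.2e}")
```

    The rate climbs linearly over the first 30 steps, peaks at 2e-4, passes through half the peak at the midpoint of the decay, and reaches zero at the final step.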

    For systematic optimisation of learning rate and rank, Bayesian hyperparameter optimisation with Gaussian processes is efficient. Random search is acceptable when the additional complexity is not warranted; grid search is almost never worthwhile for LoRA.

    Substantive Evaluation

    Most fine-tuning evaluation amounts to theatre. The model is trained, training loss decreases, an “evaluation” runs on a sliver of the training set (or the same data slightly shuffled), and the team declares success. The model is then deployed to production, where it underperforms.

    Substantive evaluation requires three properties: the evaluation data must not have been observed during training; the evaluation metric must measure the actual task rather than a proxy; and the metric must be reproducible across runs.

    For general language understanding and reasoning, the standard benchmarks are MMLU (multi-task language understanding across 57 subjects), HumanEval (function-completion code), GSM8K (grade-school mathematics word problems), and MT-Bench (multi-turn instruction following, judged by a strong LLM). For code-heavy use cases, SWE-bench Verified and Terminal-Bench 2.0 are the current standards.

    The community-standard tool is lm-evaluation-harness from EleutherAI, which runs the model against a registered benchmark suite in a reproducible manner:

    lm_eval \
      --model hf \
      --model_args pretrained=out/qwen36-27b-qlora,trust_remote_code=True \
      --tasks mmlu,gsm8k,humaneval \
      --batch_size auto \
      --output_path eval_results/qwen36-qlora.json
    

    The contamination problem is real and frequently neglected. If the training data was scraped from the public web, there is a non-trivial probability that benchmark questions are present. The decontamination check consists of an n-gram (typically 8-gram) overlap test between the training set and each benchmark’s question text, with any matches removed from training. Without this check, evaluation scores represent an upper bound that obscures the effect of contamination.
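    A minimal sketch of the 8-gram overlap test is given below. The tokenisation (lower-cased word tokens) and the policy of dropping any training text that shares a single n-gram with a benchmark question are simplifying assumptions; the example strings are invented and do not come from any real benchmark.

```python
import re


def ngrams(text: str, n: int = 8) -> set:
    """Lower-cased word n-grams of a text."""
    words = re.findall(r"\w+", text.lower())
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}


def decontaminate(train_texts, benchmark_questions, n: int = 8):
    """Drop training texts sharing any n-gram with a benchmark question."""
    bench = set()
    for q in benchmark_questions:
        bench |= ngrams(q, n)
    clean = [t for t in train_texts if not (ngrams(t, n) & bench)]
    return clean, len(train_texts) - len(clean)


# Made-up strings for illustration; not taken from any real benchmark.
benchmark = ["A train leaves the station at nine and travels north at sixty miles per hour"]
train = [
    "Q: a train leaves the station at nine and travels north at sixty miles per hour. How far?",
    "Explain why a PID loop can oscillate after the payload on a robot joint changes.",
]
clean, removed = decontaminate(train, benchmark)
print(removed)  # 1
```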

    Figure: Reading the Training Loss Curve (loss against training steps). Healthy: monotonic decline, with eval loss tracking train loss. Overfitting: eval loss rises while train loss keeps falling; stop early, add data, or regularise. Loss spike: likely a bad data batch; inspect the batch at that step and deduplicate. Gradient explosion (loss goes to NaN): learning rate too high or no clipping; lower the learning rate and add gradient clipping. Set grad_clip=1.0 as a default and rerun from the last good checkpoint.

    Beyond standard benchmarks, a domain-specific evaluation set should be held out, constructed from realistic prompts drawn from the actual use case. Benchmark suites measure general capability; a custom evaluation set measures whether the model performs better at the relevant task. The two metrics frequently disagree, and the custom set is the one that ultimately matters.

    Tip: Construct the held-out evaluation set before fine-tuning begins, and store it at a separate file path that the training code cannot access. The temptation to inspect and “improve” the evaluation set after a poor run is a silent destroyer of meaningful evaluation.

    Deployment

    When training is complete, the adapter or full checkpoint resides in a directory and must be served.

    The two standard serving stacks in 2026 are vLLM and SGLang. vLLM has the broadest support and is the production default for most teams. SGLang is faster for structured-output workloads (JSON, regex-constrained generation) and provides superior RadixAttention KV-cache reuse for repeated-prefix workloads such as RAG and multi-turn chat.

    Both implement continuous batching, a serving technique that keeps the GPU saturated by dynamically inserting new requests into the batch as existing requests complete, rather than waiting for the whole batch to finish. The throughput multiplier of continuous batching over static batching is typically a factor of three to five, sometimes more.

    Figure: Deployment Serving Stack. The checkpoint (PEFT adapter + base model) is optionally quantised (AWQ / GPTQ / MXFP4 / FP8), then served by vLLM / SGLang with continuous batching and a block-based PagedAttention KV cache, and exposed to clients through an OpenAI-compatible API. Throughput relative to static batching on the same GPU, measured in tokens/second for 256 concurrent request streams: static 1×; continuous batching ~3×; with PagedAttention and prefix cache ~5× or more on RAG workloads. Quantisation trade-offs at inference time: FP16 / BF16 is the quality baseline at roughly 4× the VRAM of 4-bit weights; AWQ (Activation-aware Weight Quantisation) is 4-bit with ~0.5pp quality loss and fast kernels in vLLM; GPTQ is 4-bit, post-training, slightly lower quality than AWQ but with broader compatibility; MXFP4 is 4-bit floating point with a block scale, the format GPT-OSS-120B was trained with, and the cleanest precision/cost trade.

    For a fine-tuned Qwen3.6-27B served on a single H100, the launch command is as follows:

    vllm serve Qwen/Qwen3.6-27B \
      --host 0.0.0.0 \
      --port 8000 \
      --max-model-len 32768 \
      --dtype bfloat16 \
      --enable-lora \
      --lora-modules my-adapter=out/qwen36-27b-qlora \
      --gpu-memory-utilization 0.92 \
      --enable-prefix-caching \
      --tensor-parallel-size 1
    

    The serving endpoint exposes an OpenAI-compatible API at http://localhost:8000/v1. On the client side, the standard OpenAI SDK works unchanged once base_url points at the server:

    from openai import OpenAI
    
    client = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="EMPTY",  # vLLM ignores the key by default
    )
    
    response = client.chat.completions.create(
        model="my-adapter",
        messages=[
            {"role": "system", "content": "You are an industrial controls expert."},
            {"role": "user",   "content": "What causes oscillation after a payload change on a cobot joint?"},
        ],
        temperature=0.2,
        max_tokens=512,
    )
    
    print(response.choices[0].message.content)
    

    If the deployment forms part of a larger application, the serving pods may be run on Kubernetes with a GPU-aware scheduler. For tool-augmented workflows, vLLM supports tool calling, but it has to be switched on at launch with --enable-auto-tool-choice and a --tool-call-parser that matches the model (the hermes parser handles Hermes-style JSON output); the vLLM documentation lists the parser appropriate for Qwen3.6 and GPT-OSS. For broader integrations, the Model Context Protocol (MCP) is emerging as the de facto integration standard for tool-using LLM applications.

    Common Pitfalls and Debugging

    Most training failures derive from a small set of recurring mistakes. Awareness of these in advance saves substantial debugging time.

    Chat template mismatch. Previously noted, but worth repeating because it is the most common silent failure. The training-time template and the inference-time template must be identical. A fully tokenised example with special tokens visible (tokenizer.decode(input_ids, skip_special_tokens=False)) should be printed before beginning any long run.

    Out-of-memory mid-training. The loss curve appears acceptable for 5,000 steps, after which a single long sequence in a batch exceeds the activation memory budget. The remedy is to lower max_seq_length, enable packing=True with a sequence cap, or reduce per-device batch size and increase gradient accumulation to compensate.

    Tokenizer drift. The base model has been loaded with one tokenizer revision and inference performed with another, causing the vocabulary or special-token IDs to shift. The tokenizer commit hash should be locked explicitly: AutoTokenizer.from_pretrained(MODEL_ID, revision="abc123def...").

    Loss spikes. A large upward jump in loss at a specific step almost always indicates a bad batch: corrupted data, a tokenisation error on a single example, or an unusually long sequence. The data at that step should be inspected. If such spikes are rare, gradient clipping (max_grad_norm=1.0) should be added and training resumed from the last good checkpoint.
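    Finding the step to inspect can be automated from the logged losses. The sketch below flags any step whose loss exceeds a multiple of the median over the preceding window; the window size and threshold factor are hypothetical defaults, and the log is synthetic.

```python
from statistics import median


def find_spikes(losses, window=20, factor=1.5):
    """Return indices where the loss exceeds factor x the median of the preceding window."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = median(losses[i - window:i])
        if losses[i] > factor * baseline:
            spikes.append(i)
    return spikes


# Synthetic log: a slow decline with one injected spike at index 30.
log = [2.0 - 0.01 * i for i in range(50)]
log[30] = 4.0
print(find_spikes(log))  # [30]
```

    Using the median rather than the mean keeps a single spike from distorting the baseline for the steps that follow it.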

    Evaluation/training distribution mismatch. Training loss is low, while evaluation loss is high and fails to improve. The evaluation set is drawn from a different distribution from the training set. Either the evaluation set should be drawn from the same source as the training data (with a fresh seed split), or the gap should be accepted as a measure of generalisation rather than a training failure.

    Gradient explosion. Loss diverges to NaN within a few steps. The learning rate is too high for the task, gradient clipping has been omitted, or the data contain an extreme outlier in numerical features. Training should restart with learning_rate halved and max_grad_norm=1.0.
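    What max_grad_norm=1.0 does can be shown with a small sketch of clipping by global norm. It uses plain Python lists in place of tensors and follows the usual formulation (rescale all gradients by max_norm divided by the total norm when the norm is too large); the function name and values are illustrative.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm


# Two toy "parameter tensors" as flat lists; the values are made up.
grads = [[3.0, 4.0], [12.0]]
clipped, norm_before = clip_by_global_norm(grads)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before)           # 13.0
print(round(norm_after, 3))  # 1.0
```

    The direction of the update is preserved; only its length is capped, which is why clipping tames an outlier batch without changing what the optimiser is trying to do.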

    MoE-specific: expert collapse. Specific to MoE training (Qwen3.5-122B, GPT-OSS-120B). The router learns to route everything to one or two experts, and the remainder of the model atrophies. The mitigation is an auxiliary load-balancing loss. It should not be assumed to be active: in Hugging Face MoE implementations it is typically controlled by configuration fields such as output_router_logits and router_aux_loss_coef, so the effective configuration should be checked before a long run rather than discovered to have been silently overridden.
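    To show what such a loss measures, the sketch below implements the Switch Transformer formulation for top-1 routing. This is an assumption made for illustration: the exact auxiliary loss used by a given model or library (particularly with top-k routing) may differ. The router probabilities are made up.

```python
def load_balancing_loss(router_probs, num_experts):
    """Switch-Transformer-style auxiliary loss for top-1 routing.

    router_probs: one probability list (length num_experts) per token.
    Returns num_experts * sum_i(f_i * P_i), where f_i is the fraction of tokens
    whose top choice is expert i and P_i is the mean probability given to expert i.
    """
    n_tokens = len(router_probs)
    f = [0.0] * num_experts
    p = [0.0] * num_experts
    for probs in router_probs:
        top = max(range(num_experts), key=lambda i: probs[i])
        f[top] += 1.0 / n_tokens
        for i in range(num_experts):
            p[i] += probs[i] / n_tokens
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))


balanced = [[0.7, 0.1, 0.1, 0.1], [0.1, 0.7, 0.1, 0.1],
            [0.1, 0.1, 0.7, 0.1], [0.1, 0.1, 0.1, 0.7]]
collapsed = [[0.97, 0.01, 0.01, 0.01]] * 4
print(round(load_balancing_loss(balanced, 4), 3))   # 1.0
print(round(load_balancing_loss(collapsed, 4), 3))  # 3.88
```

    The loss sits at its minimum of 1.0 when tokens are spread evenly and approaches the number of experts as routing collapses onto one of them, so a rising value in the logs is an early warning.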

    Caution: Training should always be launched with W&B (or an equivalent) logging enabled, and the loss curve should be reviewed every few hundred steps. Detecting a failure in the first hour costs an hour; detecting it at the twelve-hour evaluation costs a day and the cloud bill.

    FAQ

    Can these models be fine-tuned on a consumer GPU such as an RTX 4090?

    Qwen3.6-27B can be fine-tuned on a 4090 with QLoRA. The 24GB of VRAM on a 4090 is tight but workable with gradient checkpointing, a paged 8-bit optimiser, and a short sequence length (approximately 2048 tokens). Qwen3.5-122B-A10B and GPT-OSS-120B require at least 80GB of VRAM, which corresponds to H100/H200/MI300X-class hardware. The released GPT-OSS-120B can be served (though not trained) on a single 80GB card due to MXFP4 quantisation.

    How much data is actually required?

    Less than is commonly expected. For domain adaptation with LoRA or QLoRA, 5,000 to 20,000 high-quality examples are sufficient for most domains. Quality matters considerably more than quantity: a tightly curated 10,000-example set consistently outperforms a noisy 100,000-example set. For format adaptation (teaching the model a new structured output schema), 1,000 to 2,000 examples often suffice.

    How does this compare with using a managed API?

    The two represent different problem spaces. Managed APIs (OpenAI, Anthropic) excel in convenience and access to the latest models. Self-hosted fine-tuned models excel in cost per million tokens at scale, data sovereignty, custom domain adaptation, and predictable cost (no per-call billing). The crossover point is typically around 100M tokens per month; below this, managed services are usually preferable, and above it, self-hosted is usually cheaper.

    What is the quality difference between LoRA and full fine-tuning?

    LoRA retains 90 to 95 per cent of full fine-tuning quality across most tasks. QLoRA retains 80 to 90 per cent. The remaining gap is largest on tasks requiring substantial representational shift from the base model—for example, fine-tuning an English-pretrained model to operate fluently in a low-resource language. For typical instruction tuning, code adaptation, or structured-output tasks, the gap is sufficiently small that the cost savings of LoRA dominate.

    Should continued pretraining precede instruction tuning?

    Only when the domain is genuinely far from the base model’s training distribution—medical literature, legal contracts in a non-English language, or highly specialised scientific notation. For most domains, the base model has sufficient coverage that instruction tuning alone closes the gap. Continued pretraining is expensive and easily mishandled, with the principal risk being catastrophic forgetting of the base model’s general competence.

    Conclusion

    Training open-source LLMs in 2026 is no longer the closed activity it was two years ago. The combination of Apache 2.0 base models with frontier-class reasoning (GPT-OSS-120B approaching o4-mini), QLoRA on a single rented GPU, and serving infrastructure capable of handling thousands of concurrent users on commodity hardware has placed production-grade LLM customisation within reach of any team with a modest budget and a clear use case.

    The three anchor models cover the practical range: Qwen3.6-27B for the single-GPU dense workflow, Qwen3.5-122B-A10B for inexpensive MoE serving when multi-GPU capacity is available, and GPT-OSS-120B for single-GPU serving of a frontier-class reasoner enabled by MXFP4. None of these is universally “best”; each addresses different questions about hardware, latency, and quality.

    The principal challenge is no longer the technology; it is the data—assembling, deduplicating, formatting, and contamination-checking a dataset that actually teaches the model the intended behaviour. The trainer runs in eight hours. The dataset takes eight weeks. Planning should be adjusted accordingly.

  • Kubernetes Pods Explained: Why Connecting to a Database Pod Is Hard

    This article examines the architecture of Kubernetes pods and explains why directly connecting an external client to a database running inside a pod is more involved than the equivalent task with a standalone Docker container. The discussion is grounded in the networking model that Kubernetes uses and in the Service abstraction that the model requires. A common experience for engineers new to Kubernetes is that kubectl exec into a Postgres pod followed by psql -h localhost works as expected, while a parallel attempt from a developer laptop, using the pod IP reported by kubectl get pods -o wide, times out with no error. The credentials, the database, and the apparent network are the same, yet the second connection never completes. This outcome is not the result of a defect; it is a direct consequence of how the platform is designed, and understanding the design is the first step toward connecting to in-cluster databases in a reliable manner.

    Summary

    What this post covers: A practical, code-first explanation of Kubernetes pods, the flat-IP networking model that makes the cluster tick, and the specific reasons that connecting a database container to clients outside its pod is harder than running docker run -p 5432:5432 postgres.

    Key insights:

    • Pod IPs are ephemeral; the moment a pod restarts, the address you memorized is gone, which is why hard-coded connection strings break in ways that look like network failures.
    • ClusterIP — the default Service type — only exists inside the cluster, so the IP that kubectl get svc shows you is unreachable from a laptop without explicit forwarding.
    • Stateful workloads like Postgres need StatefulSets and PersistentVolumeClaims, not plain Deployments, or you will lose data the first time a pod reschedules to another node.
    • kubectl port-forward is wonderful for local development and dangerous in production: it tunnels through the API server, is gated only by RBAC, and bypasses NetworkPolicies.
    • Kubernetes 1.36, released in April 2026, promoted User Namespaces, Mutating Admission Policies, and Fine-Grained Kubelet API Authorization to GA, all of which tighten the security defaults that govern who can talk to what inside a cluster.

    Main topics: Why Kubernetes Exists in the First Place, The Pod: Smaller Than a VM, Bigger Than a Container, The Flat Networking Model and Its Implications, Services: How Pods Actually Find Each Other, Why Direct Connections to a Database Pod Fail, Three Reliable Connection Patterns, Kubernetes 1.36 and What Changed in 2026

    Why Kubernetes Exists in the First Place

    Docker addressed a genuine packaging problem by allowing an application and its dependencies to be shipped as a single reproducible image that behaves consistently on a developer laptop, a continuous integration runner, and a production virtual machine. Readers who have not yet worked through the container model will find the Docker containers, from dev to production guide a useful prerequisite for the material that follows. Once an organization operates several dozen containers distributed across several dozen servers, however, Docker alone becomes insufficient. Several questions arise that the single-host model does not answer: which host should run a given container, what should happen if that host fails outside business hours, how can a new version be rolled out without dropping in-flight requests, how do containers on one host discover containers on another, and how should compute and memory budgets be enforced.

    Kubernetes is the answer that became the industry consensus. It originated inside Google as a re-implementation of the Borg system and was open-sourced in 2014. Kubernetes is best understood as a cluster operating system: the operator declares the desired state of the workload, and a chain of controllers continuously reconciles the actual state of the cluster with that declaration. The unit that Kubernetes manages is not a container directly but a pod, a wrapper around one or more tightly coupled containers that share a network identity and storage. All higher-level objects, including Deployments, Services, StatefulSets, Jobs, and CronJobs, are abstractions that ultimately specify which pods should exist, where they should run, and how they should be exposed.

    Figure: Kubernetes cluster architecture. Control plane: kube-apiserver (REST front door, auth and admission, talks to everything), etcd (key/value store holding all cluster state), scheduler (picks a node for each pod from resources, taints, and affinity rules), and controller-manager (reconciliation loops for ReplicaSets, Deployments, nodes, endpoints, and Jobs). Three worker nodes each run kubelet, a container runtime, and kube-proxy, and host pods with 10.244.x.x addresses, including postgres-0 (StatefulSet) on node 2 and postgres-1 (replica) on node 3. All node-to-control-plane traffic is mediated by the API server; pods talk to each other directly through the CNI overlay.

    The control plane performs the coordination function, while worker nodes execute the workload. Each worker runs three components in its base layer. The kubelet is the agent that receives instructions from the API server and applies them to the local node. The container runtime, which is now almost always containerd or CRI-O, executes the containers themselves; direct support for Docker Engine as a runtime (the dockershim) was deprecated in version 1.20 and removed in version 1.24. The kube-proxy process programs the kernel iptables or IPVS rules that route Service IPs to actual pod endpoints. Pods are scheduled on top of these three components.

    The Pod: Smaller Than a VM, Bigger Than a Container

    A pod is the smallest deployable unit in Kubernetes. The simplest pod runs a single container, but the abstraction exists precisely because the smallest unit that an engineer sometimes needs to ship is more than one container. A common configuration places a primary application container alongside a sidecar container that handles logging, TLS termination, metric scraping, or database proxying. All containers in the same pod share a single network namespace, which means they can communicate over localhost, and they can share filesystem volumes. They are always scheduled together onto the same node, and their lifecycles are linked: they are created together, restarted together, and terminated together.

    Figure: Anatomy of a pod. Pod api-789f-bc4 has one IP, one DNS name, and one lifecycle, and runs three containers: app (FastAPI on :8000, image app:1.4.2, talks to the sidecar over localhost:6432 and writes logs to /var/log/app), a pgbouncer sidecar (connection pool on :6432 that forwards to postgres.db.svc:5432), and a log-tail sidecar (vector or fluent-bit, image log-agent:3.0, reads /var/log/app through a shared volume and ships to Loki over Service DNS). The containers share one network namespace (the same 10.244.2.5 and the same loopback), so they reach each other on localhost and appear outside the pod as a single IP. Shared volumes: an in-memory emptyDir at /var/log/app and a configMap at /etc/app/config, mounted into all three containers.

    The manifest below shows a minimal pod specification. Three details are worth noting. The apiVersion: v1 field indicates that pods are part of the core Kubernetes API. The specification contains a single container running an Nginx image. The manifest sets no restartPolicy, so the default of Always applies: the kubelet restarts a crashed container in place, but nothing recreates the pod itself. Bare pods are therefore not self-healing: when the node fails, the pod is lost with it. They are rarely used in production and are useful primarily for one-off tests and as a teaching device.

    # pod.yaml — the simplest possible pod
    apiVersion: v1
    kind: Pod
    metadata:
      name: hello-pod
      labels:
        app: hello
    spec:
      containers:
      - name: web
        image: nginx:1.27
        ports:
        - containerPort: 80
        resources:
          requests:
            cpu: "50m"
            memory: "64Mi"
          limits:
            cpu: "200m"
            memory: "128Mi"
    

    The object that an operator actually deploys is a Deployment, a controller that maintains a desired number of identical pods, handles rolling updates, and recreates pods when nodes fail. The Deployment owns a ReplicaSet, which in turn owns the pods. Operators rarely reference pods directly. Instead, they reference the Deployment, and Kubernetes manages the underlying pods.

    # deployment.yaml — a real workload
    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: api
      labels:
        app: api
    spec:
      replicas: 3
      selector:
        matchLabels:
          app: api
      strategy:
        type: RollingUpdate
        rollingUpdate:
          maxSurge: 1
          maxUnavailable: 0
      template:
        metadata:
          labels:
            app: api
        spec:
          containers:
          - name: api
            image: ghcr.io/acme/api:1.4.2
            ports:
            - containerPort: 8000
            readinessProbe:
              httpGet: { path: /healthz, port: 8000 }
              initialDelaySeconds: 5
            livenessProbe:
              httpGet: { path: /livez, port: 8000 }
              initialDelaySeconds: 15
            env:
            - name: DB_HOST
              value: "postgres.db.svc.cluster.local"
            - name: DB_PORT
              value: "5432"
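
    The maxSurge and maxUnavailable fields in the manifest above bound how many pods exist during a rollout. The sketch below works through that arithmetic, assuming the documented rounding rules (a percentage rounds up for maxSurge and down for maxUnavailable). The helper names are hypothetical and are not part of any Kubernetes client library.

```python
import math


def resolve(value, replicas, round_up):
    """Turn an int or a percentage string such as "25%" into a pod count."""
    if isinstance(value, str) and value.endswith("%"):
        exact = replicas * int(value[:-1]) / 100
        return math.ceil(exact) if round_up else math.floor(exact)
    return int(value)


def rollout_bounds(replicas, max_surge, max_unavailable):
    """Return (minimum available pods, maximum total pods) during a rolling update."""
    surge = resolve(max_surge, replicas, round_up=True)                # surge rounds up
    unavailable = resolve(max_unavailable, replicas, round_up=False)   # unavailable rounds down
    return replicas - unavailable, replicas + surge


# Values from the Deployment above: replicas=3, maxSurge=1, maxUnavailable=0
print(rollout_bounds(3, 1, 0))
# A percentage-based variant for comparison (illustrative values)
print(rollout_bounds(10, "25%", "25%"))
```

    With the manifest's values, the rollout never drops below three available pods and never exceeds four in total, which is why it can proceed without losing capacity.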
    
    Tip: Always set both requests, which are the resources the scheduler reserves for the pod, and limits, which are the ceilings the kernel enforces. With neither set, the scheduler treats the pod as requiring no resources and may place it on an already saturated node. Without limits, a runaway process can starve every other workload on the host.
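
    The quantity strings used for requests and limits, such as "50m" and "64Mi", are compact notations for fractions of a CPU core and for bytes. The sketch below converts the common suffixes. It is a simplified illustration rather than the full Kubernetes quantity grammar, and the function names are hypothetical.

```python
from decimal import Decimal

BINARY = {"Ki": 2**10, "Mi": 2**20, "Gi": 2**30, "Ti": 2**40}
DECIMAL = {"k": 10**3, "M": 10**6, "G": 10**9, "T": 10**12}


def cpu_cores(quantity: str) -> Decimal:
    """Convert a CPU quantity to cores; the "m" suffix means thousandths of a core."""
    if quantity.endswith("m"):
        return Decimal(quantity[:-1]) / 1000
    return Decimal(quantity)


def memory_bytes(quantity: str) -> int:
    """Convert a memory quantity to bytes; handles binary and decimal suffixes only."""
    for suffix, factor in BINARY.items():
        if quantity.endswith(suffix):
            return int(Decimal(quantity[:-2]) * factor)
    for suffix, factor in DECIMAL.items():
        if quantity.endswith(suffix):
            return int(Decimal(quantity[:-1]) * factor)
    return int(Decimal(quantity))


# Values from the pod manifest earlier in this section
print(cpu_cores("50m"), cpu_cores("200m"))
print(memory_bytes("64Mi"), memory_bytes("128Mi"))
```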

    The Flat Networking Model and Its Implications

    The Kubernetes networking model rests on four rules that appear simple on the surface but carry substantial implications for how traffic flows inside a cluster:

    1. Every pod gets its own IP address, drawn from a cluster-wide CIDR range that does not overlap with the node IPs.
    2. Pods on the same node can communicate without NAT.
    3. Pods on different nodes can communicate without NAT.
    4. The IP a pod sees as its own is the same IP that other pods see when they talk to it.

    The fourth point carries the most weight. In a typical single-host Docker configuration, a container holds a private IP on a bridge network, and outbound traffic is translated through the host using NAT. Kubernetes deliberately avoids this arrangement. Every pod is a first-class participant in a single flat network, regardless of the physical machine that hosts it. The component that implements this property is a CNI plugin (Container Network Interface), of which several mature implementations exist, including Calico, Cilium, Flannel, Weave, the AWS VPC CNI, and the native plugin used by GKE. These plugins implement the same contract but differ in their mechanisms, which range from overlays based on VXLAN tunnels, to BGP route advertisement, to native cloud routing, to eBPF-based data planes.
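
    The address plan behind the flat model can be checked with Python's standard ipaddress module. The sketch below assumes the example ranges used in this article (pods in 10.244.0.0/16, nodes in 192.168.10.0/24); a real cluster's ranges depend on how the CNI plugin was configured.

```python
import ipaddress

# Example ranges from this article; real clusters will differ.
POD_CIDR = ipaddress.ip_network("10.244.0.0/16")
NODE_CIDR = ipaddress.ip_network("192.168.10.0/24")


def classify(ip: str) -> str:
    """Say which address space an IP belongs to in this example cluster."""
    addr = ipaddress.ip_address(ip)
    if addr in POD_CIDR:
        return "pod"
    if addr in NODE_CIDR:
        return "node"
    return "outside"


# Rule 1: the pod range must not overlap the node range.
print(POD_CIDR.overlaps(NODE_CIDR))
for ip in ("10.244.1.5", "10.244.3.7", "192.168.10.12", "8.8.8.8"):
    print(ip, classify(ip))
```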

    Figure: Flat pod-to-pod networking (no NAT, one CIDR). Nodes A, B, and C (192.168.10.11 to 192.168.10.13, node CIDR 192.168.10.0/24) each run a CNI plugin such as Calico or Cilium that programs kernel routes and connects pods in 10.244.0.0/16 through VXLAN, IPIP, BGP, or native routing across the underlay network. When pod-a1 (10.244.1.5) talks to postgres-0 (10.244.3.7), the packet carries src=10.244.1.5 and dst=10.244.3.7 with no SNAT and no DNAT, so postgres-0 sees the real source IP of pod-a1. This is the property that makes mTLS, audit logs, and network policies meaningful.

    This flat addressing scheme is convenient for application code, which simply connects to an IP address and proceeds, but it is demanding for operators who must reason about traffic flows. Every pod is mutually addressable inside the cluster, which means that in the absence of explicit policies, every pod is able to reach every database, every cache, and every internal API. This property becomes important in a later section, which explains why directly addressing a database pod is more fragile than it appears.

    Services: How Pods Actually Find Each Other

    Because a pod receives a new IP whenever it is recreated, pod IPs cannot appear in a connection string. The Kubernetes solution to this problem is a Service, a stable virtual IP and DNS name that load-balances traffic to a set of pods identified by labels. Individual pods come and go, but the Service remains. Inside the cluster, every Service automatically receives a DNS name of the form service-name.namespace.svc.cluster.local, which is resolved by CoreDNS, the built-in DNS resolver of the cluster.
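
    The naming scheme also explains why shorter names such as postgres.db work from inside a pod. The sketch below assumes the default cluster domain cluster.local and the search list that Kubernetes normally writes into a pod's /etc/resolv.conf; the function names are hypothetical.

```python
def service_fqdn(service: str, namespace: str, cluster_domain: str = "cluster.local") -> str:
    """Build the fully qualified DNS name of a Service."""
    return f"{service}.{namespace}.svc.{cluster_domain}"


def search_domains(pod_namespace: str, cluster_domain: str = "cluster.local") -> list:
    """Search list a pod in pod_namespace typically receives (an assumption of this sketch)."""
    return [
        f"{pod_namespace}.svc.{cluster_domain}",
        f"svc.{cluster_domain}",
        cluster_domain,
    ]


def candidates(name: str, pod_namespace: str) -> list:
    """Names the resolver tries, in order, for a short name used inside a pod."""
    return [f"{name}.{domain}" for domain in search_domains(pod_namespace)]


print(service_fqdn("postgres", "db"))
# From a pod in the "app" namespace, "postgres.db" expands until one candidate matches.
for name in candidates("postgres.db", "app"):
    print(name, "<- matches" if name == service_fqdn("postgres", "db") else "")
```

    A pod in the db namespace could use the bare name postgres, because the first search domain already supplies db.svc.cluster.local.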

    Figure: Service types and where each one is reachable from. ClusterIP (default) is reachable from other pods and from a node terminal, but not from a laptop or the internet. NodePort is reachable on every node IP at a port in the range 30000-32767. LoadBalancer provides a public IP from the cloud provider (NLB/ALB/CLB) for any TCP/UDP port, at roughly $15-25 per month per load balancer. Ingress provides host- and path-based HTTP/HTTPS routing on ports 80 and 443 behind a single load balancer, with TLS terminated there, and does not carry Postgres traffic.

    | Service type | Scope | Typical use | Port range | Downside |
    |---|---|---|---|---|
    | ClusterIP | Inside cluster only | Databases, caches, internal APIs | Any (virtual) | Unreachable from outside without help |
    | NodePort | Every node’s IP + high port | On-prem clusters, debugging | 30000–32767 | Ugly URLs, every node exposes it |
    | LoadBalancer | Public IP from cloud LB | Non-HTTP services to internet | Any TCP/UDP | Costs money, one LB per Service |
    | Ingress (a separate resource, not a Service type) | L7 HTTP/HTTPS routing | Web apps, REST/gRPC over HTTP | 80, 443 | HTTP only; will not route Postgres |

     

    Figure: How a pod finds a Service, step by step. (1) The client pod api-789-bc4 (10.244.1.5), configured with DB_HOST=postgres.db.svc, sends a DNS query to CoreDNS, which runs as a pod in kube-system behind 10.96.0.10 (kube-dns). (2) CoreDNS returns the ClusterIP of the postgres Service, 10.96.42.7, a virtual IP with no host behind it. (3) The client opens a TCP connection to 10.96.42.7:5432. (4) The iptables, IPVS, or nftables rules that kube-proxy programs on each node from the Services and EndpointSlices it watches rewrite the destination (DNAT) to one of the backing pods labeled app=postgres: postgres-0 (10.244.3.7), postgres-1 (10.244.4.7), or postgres-2 (10.244.5.7). The rewrite happens entirely in the kernel, with no userspace proxy hop. EndpointSlice objects keep the backend set up to date as pods come and go.

    The manifest below defines a ClusterIP Service for a Postgres pod. The selector field is worth attention, because the Service matches by labels rather than by name. Any pod that carries the label app: postgres in the same namespace automatically becomes a backend.

    # service.yaml — ClusterIP for Postgres
    apiVersion: v1
    kind: Service
    metadata:
      name: postgres
      namespace: db
    spec:
      type: ClusterIP            # default, can omit
      selector:
        app: postgres
      ports:
      - name: pg
        port: 5432              # the Service port
        targetPort: 5432        # the container port
        protocol: TCP
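
    The label-based matching that the selector performs can be modelled in a few lines. The sketch below is an illustration of the equality-based rule only (every selector key must be present on the pod with the same value); the pod names and labels are hypothetical, and a real Service additionally routes only to pods that are ready.

```python
def selects(selector: dict, pod_labels: dict) -> bool:
    """True when every key/value pair in the selector appears in the pod's labels."""
    return all(pod_labels.get(key) == value for key, value in selector.items())


# Hypothetical pods in the "db" namespace
pods = {
    "postgres-0": {"app": "postgres", "role": "primary"},
    "postgres-1": {"app": "postgres", "role": "replica"},
    "pgadmin-7c9": {"app": "pgadmin"},
}

service_selector = {"app": "postgres"}
backends = sorted(name for name, labels in pods.items() if selects(service_selector, labels))
print(backends)
```

    Extra labels on a pod do not prevent a match, which is why a stray pod carrying app: postgres would silently become a backend of this Service.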
    

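    From any other pod inside the cluster, the command psql -h postgres.db.svc.cluster.local -p 5432 will connect, provided no NetworkPolicy blocks it. The same command issued from a developer laptop fails: the name resolves only through the cluster’s CoreDNS, and the ClusterIP behind it is not routable from outside the cluster. The next section examines the reasons for this gap in detail.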

    Why Direct Connections to a Database Pod Fail

    The assumption that breaks down for engineers coming from a plain Docker workflow is the idea that a container can be reached as long as one knows its IP address and port. In Kubernetes, almost every element of that assumption is incorrect. The pod IP is real but private, ephemeral, and exists only on a network shared between the nodes of a single cluster. The container port is open inside the container but is not automatically exposed at any higher layer. There is no host-level port-publishing equivalent to docker run -p 5432:5432; the field hostPort exists but is discouraged for production use. The following paragraphs examine each failure mode in turn.

    Figure: Why “psql -h 10.244.3.7” from a laptop hangs. The laptop (192.168.0.42) sends a SYN toward 10.244.3.7, but the internet or cloud VPC routes only to node IPs; 10.244.0.0/16 is not advertised externally, so the packet is dropped and the client times out. Even if the packet reached a node, kube-proxy holds no Service entry for a raw pod IP and would not DNAT it. Inside the cluster, 10.244.3.7 is reachable only until postgres-0 restarts and becomes, for example, 10.244.3.18, so a connection string pinned to the old address fails on the next deploy. Six failure modes that most people hit before giving up: (1) the pod IP changed because the pod restarted, so the old IP belongs to nothing; (2) a ClusterIP Service exists but the client is outside the cluster, so there is no external route; (3) a NetworkPolicy denies all ingress to the db namespace by default, so even valid traffic is dropped; (4) Postgres is bound to 127.0.0.1 inside the container, so it is not listening on the pod IP; (5) pg_hba.conf rejects the source CIDR, so the TCP handshake succeeds but authentication fails; (6) a cloud security group blocks the node port even when NodePort is configured correctly.

    Pod IPs are ephemeral. Whenever a pod is replaced, for any reason ranging from a node reboot, to a manual kubectl rollout restart, to an eviction by the kubelet, to preemption by the scheduler, the new pod receives a new IP address from the address pool that the CNI manages. A container that restarts in place, for example after a failed liveness probe, keeps the pod IP; only a replacement pod receives a new one. Any client that retains a reference to the previous IP is now communicating with nothing, or in the worst case, with whatever pod has been allocated the recycled address. This is the reason pod IPs should never be written into a configuration file. The correct address to record is a Service DNS name, which CoreDNS resolves at lookup time.
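
    On the client side, the practical consequence is to hold a hostname and let each new connection resolve it, rather than caching an IP. The sketch below uses only the standard socket module; localhost stands in for a Service name so that the example runs anywhere, and the function names are hypothetical.

```python
import socket


def resolve(host: str, port: int) -> list:
    """Look the name up now and return the distinct addresses it currently maps to."""
    infos = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    return sorted({info[4][0] for info in infos})


def connect(host: str, port: int, timeout: float = 5.0) -> socket.socket:
    """Open a TCP connection; create_connection resolves the name on every call."""
    return socket.create_connection((host, port), timeout=timeout)


# Inside a cluster the host would be a Service name such as postgres.db.svc.cluster.local.
print(resolve("localhost", 5432))
```

    Connection pools deserve the same care: a pool that reconnects by hostname recovers after a pod is replaced, whereas one that stored the resolved address does not.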

    The ClusterIP is not visible from outside the cluster. The Service IP that kubectl get svc reports, such as 10.96.42.7 in the earlier example, is a virtual IP. It does not belong to any physical or virtual network interface and exists only as an entry in the iptables, IPVS, or nftables rules that kube-proxy maintains on each node. A laptop outside the cluster has no route to the Service CIDR 10.96.0.0/12, and even a statically added route would not help, because no kernel outside the cluster contains the rules required to translate that virtual address.

    Pods do not use the host network by default. Setting hostNetwork: true on a pod causes the container to share the network namespace of the node, with the consequence that the container port maps directly to a port on the node. This configuration is used by CNI agents, node-exporter, and similar infrastructure components. Applying it to a database, however, is poor practice: IP isolation is lost, port collisions become possible, and any node failure takes the database with it, since the address is tied to a specific host and cannot be moved.

    NetworkPolicies can explicitly deny traffic. When the cluster runs a CNI that supports NetworkPolicy, which most modern plugins do, operators can write rules such as “only pods labeled role: api in the app namespace may connect to pods labeled app: postgres in the db namespace on port 5432.” When a default-deny baseline is in place and no allow rule has been written, all traffic is dropped. When no policies are present at all, all traffic is permitted, which presents its own security concerns.

    # networkpolicy.yaml — only the api can talk to postgres
    apiVersion: networking.k8s.io/v1
    kind: NetworkPolicy
    metadata:
      name: postgres-allow-api
      namespace: db
    spec:
      podSelector:
        matchLabels:
          app: postgres
      policyTypes:
      - Ingress
      ingress:
      - from:
        - namespaceSelector:
            matchLabels:
              kubernetes.io/metadata.name: app   # set automatically on every namespace
          podSelector:
            matchLabels:
              role: api
        ports:
        - protocol: TCP
          port: 5432
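
    The evaluation rule behind the policy above (allow everything until some policy selects the pod, then allow only what a rule names) can be modelled as follows. This is a simplified sketch of ingress evaluation with equality selectors only; the data structures are hypothetical and do not mirror the Kubernetes API objects exactly.

```python
def matches(selector: dict, labels: dict) -> bool:
    """Equality-based label match: every selector pair must be present."""
    return all(labels.get(k) == v for k, v in selector.items())


def ingress_allowed(policies, dst_labels, src_ns_labels, src_pod_labels, port) -> bool:
    """Decide whether a connection to a destination pod is admitted."""
    selecting = [p for p in policies if matches(p["podSelector"], dst_labels)]
    if not selecting:
        return True  # no policy selects the pod, so all ingress is allowed
    for policy in selecting:
        for rule in policy["ingress"]:
            peer_ok = any(
                # namespaceSelector and podSelector in one peer are ANDed together
                matches(peer["namespaceSelector"], src_ns_labels)
                and matches(peer["podSelector"], src_pod_labels)
                for peer in rule["from"]
            )
            if peer_ok and port in rule["ports"]:
                return True
    return False  # selected by a policy, but no rule admits this traffic


policy = {
    "podSelector": {"app": "postgres"},
    "ingress": [{
        "from": [{
            # label that Kubernetes sets automatically on every namespace
            "namespaceSelector": {"kubernetes.io/metadata.name": "app"},
            "podSelector": {"role": "api"},
        }],
        "ports": [5432],
    }],
}

app_ns = {"kubernetes.io/metadata.name": "app"}
print(ingress_allowed([policy], {"app": "postgres"}, app_ns, {"role": "api"}, 5432))
print(ingress_allowed([policy], {"app": "postgres"}, app_ns, {"role": "worker"}, 5432))
print(ingress_allowed([], {"app": "postgres"}, app_ns, {"role": "worker"}, 5432))
```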
    

    The container port is not automatically exposed at the node level. Docker users are accustomed to -p 5432:5432, which binds a host port to a container port. Kubernetes provides no equivalent automatic mapping. The containerPort field in a pod specification is documentation: it informs operators and tooling that the container intends to listen on the indicated port, but it does not open a path through any higher layer. External reachability requires a Service of the appropriate type and, in cloud environments, a security group rule that permits traffic to whichever node port the cloud load balancer or NodePort uses.

    Databases are stateful, and stateful pods require stateful controllers. A plain Deployment treats its pods as interchangeable replicas with generated names. When node 2 becomes unhealthy, the Deployment simply creates a replacement pod on another node: every replica references the same PersistentVolumeClaim if the template names one, and the replacement starts with an empty data directory if it names none. A database instead requires a StatefulSet, which assigns each pod a stable identity such as postgres-0 or postgres-1, a stable per-pod DNS name served by a headless Service, and a stable PersistentVolumeClaim that remains attached to the same ordinal across reschedules. A misconfiguration in this area is a common cause of data loss for teams new to running databases on Kubernetes.

    The request path is long, and any single weak link breaks it. When an external client reaches a pod, the request typically traverses the following sequence: client, public DNS, cloud load balancer, node IP, iptables DNAT, pod IP, container port, Postgres listener, pg_hba.conf check, and finally authentication. A misconfiguration at any stage, such as an incorrect TLS certificate, a security group blocking the load balancer health check, a pg_hba.conf rule that denies the source CIDR, or a Postgres listener bound to 127.0.0.1 inside the container rather than 0.0.0.0, produces a connection failure that appears identical to a network problem from the perspective of the client.
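
    Because these failures look alike from the client, it helps to separate them by stage. The sketch below distinguishes a DNS failure, a timeout, and a refused connection using only the standard socket module. It stops at the TCP layer, so pg_hba.conf and authentication errors are out of its scope, and the function name and return labels are hypothetical.

```python
import socket


def diagnose(host: str, port: int, timeout: float = 3.0) -> str:
    """Classify how far a TCP connection attempt gets."""
    try:
        infos = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    except socket.gaierror:
        return "dns-failure"        # the name does not resolve from here
    family, socktype, proto, _, sockaddr = infos[0]
    with socket.socket(family, socktype, proto) as sock:
        sock.settimeout(timeout)
        try:
            sock.connect(sockaddr)
        except socket.timeout:
            return "timeout"        # packets dropped: no route, firewall, or NetworkPolicy
        except ConnectionRefusedError:
            return "refused"        # host reachable, nothing listening on that address
        except OSError:
            return "unreachable"    # the local network stack rejected the attempt
    return "tcp-ok"                 # handshake done; remaining failures are TLS or auth


if __name__ == "__main__":
    print(diagnose("localhost", 5432))
```

    Read against the failure table in this section, a timeout points at routing, security groups, or NetworkPolicy, while a refusal points at the listener, such as Postgres bound to 127.0.0.1.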

    | Failure mode | Symptom | Root cause | Proper workaround |
    |---|---|---|---|
    | Pod IP in connection string | Works for hours, then suddenly times out after a restart | CNI re-allocated IP to a different pod | Use Service DNS name (postgres.db.svc.cluster.local) |
    | Laptop connecting to ClusterIP | TCP timeout, no error | No route from laptop to Service CIDR | Use kubectl port-forward or a bastion |
    | Default-deny NetworkPolicy | Within-cluster traffic also dropped | No explicit allow rule for the source | Write a targeted ingress NetworkPolicy |
    | Postgres bound to 127.0.0.1 | Connection refused even inside cluster | listen_addresses not set to * | Fix postgresql.conf in the image/ConfigMap |
    | Pod rescheduled, lost data | Tables empty after a node failure | Deployment used instead of StatefulSet, no PVC | StatefulSet + PVC + headless Service |
    | pg_hba.conf rejects source | “no pg_hba.conf entry for host” error | Pod CIDR not allowed | Add cluster pod CIDR to pg_hba.conf |
    | LoadBalancer reachable but SG blocks | Timeout from internet | Cloud security group does not allow 5432 | Open SG to client IPs, lock to known sources |

     

    Caution: Operators tempted to expose a production database to the public internet through a LoadBalancer should reconsider whether such exposure is necessary. The preferred design is to keep the database internal to the cluster and to route application traffic through a hardened API tier. An internet-facing Postgres listener on port 5432 is among the most heavily attacked surfaces on the public internet.

    Three Reliable Connection Patterns

    Three legitimate patterns exist for connecting a client to a database that runs in a pod. The appropriate choice depends primarily on where the client runs and on how long it needs the connection.

    Figure: Three patterns; pick the one that matches the client. (A) In-cluster app, ClusterIP + DNS: the app pod sets DB_HOST=postgres.db.svc, CoreDNS resolves it to ClusterIP 10.96.42.7:5432, and traffic lands on the postgres-0 pod at 10.244.3.7:5432. Best for production app traffic, CronJobs, Airflow DAGs, and message workers. (B) Local developer, kubectl port-forward: psql on the laptop connects to localhost:5432, and an SPDY tunnel through the API server reaches the postgres-0 pod directly, with no Service involved. Best for debugging, migrations, and one-off admin queries; never for production traffic. (C) External app, LoadBalancer + TLS + auth: the external app connects to postgres.example.com:5432 through a cloud NLB whose security group allows only the client CIDR, and reaches the postgres pod through a Service, with TLS and strong authentication required. Best for an analytics replica only; otherwise route through an API tier instead.

    Pattern A: In-cluster application to in-cluster database

    This pattern is the default and the most reliable choice. The application pod sets DB_HOST=postgres.db.svc.cluster.local as an environment variable and opens a connection. CoreDNS resolves the name, kube-proxy translates the virtual IP into the address of a real pod through DNAT, and the connection succeeds. Pod restarts on either side remain transparent because every endpoint is named rather than pinned to a specific IP. This is also the pattern that Airflow workloads adopt when they run with the KubernetesExecutor described in the Apache Airflow data pipeline orchestration guide, in which each task is launched as a pod that reaches the database through a Service. The same pattern applies to dbt jobs running on Kubernetes and to Kafka consumer workloads running in pods.

    Pattern B: Local developer to in-cluster database

    The command kubectl port-forward opens a tunnel from a local port on a developer machine, through the Kubernetes API server, to a port on a pod. It is intended for development and one-off administrative tasks. The example below targets first the postgres-0 pod directly and then the postgres Service; both objects are defined in the StatefulSet subsection later in this article:

    # forward localhost:5432 to the postgres-0 pod's port 5432
    kubectl port-forward -n db pod/postgres-0 5432:5432
    
    # Or forward through the headless Service to whichever endpoint is selected
    kubectl port-forward -n db svc/postgres 5432:5432
    
    # Now from another terminal, on your laptop:
    psql -h localhost -p 5432 -U app -d production
    

    The Python client below connects through the forwarded port. The connection string specifies localhost, which is correct on the developer laptop. Inside the cluster, the same code would instead specify postgres.db.svc.cluster.local.

    # dev_query.py — assumes "kubectl port-forward" is running
    import os
    import psycopg2
    from psycopg2.extras import RealDictCursor
    
    # Local dev: connect through kubectl port-forward
    # In production (in-cluster), DB_HOST would be postgres.db.svc.cluster.local
    DB_HOST = os.environ.get("DB_HOST", "localhost")
    DB_PORT = int(os.environ.get("DB_PORT", "5432"))
    DB_NAME = os.environ.get("DB_NAME", "production")
    DB_USER = os.environ.get("DB_USER", "app")
    DB_PASS = os.environ["DB_PASS"]  # required, no default
    
    def fetch_recent_orders(limit: int = 50):
        """Read the most recent orders (example dev-time query)."""
        conn = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASS,
            connect_timeout=5,
            sslmode="require",   # enforce TLS even on port-forward; the server must have ssl enabled
        )
        try:
            # "with conn" only wraps a transaction; psycopg2 does not close the connection on exit
            with conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(
                    "SELECT id, customer_id, total_cents, created_at "
                    "FROM orders ORDER BY created_at DESC LIMIT %s",
                    (limit,),
                )
                return cur.fetchall()
        finally:
            conn.close()   # release the server-side session explicitly
    
    if __name__ == "__main__":
        rows = fetch_recent_orders()
        for row in rows:
            print(row)
    
    Caution: kubectl port-forward bypasses NetworkPolicies because the tunnel travels through the API server and the kubelet rather than as pod-to-pod traffic. Any user who is allowed to create pods/portforward in the namespace can reach the database, regardless of the NetworkPolicy configuration. This permission should therefore be treated as a form of production database access and subjected to audit logging.

    Pattern C: External application to in-cluster database

    This is the pattern about which most teams should hesitate. When an application outside the cluster needs to read from or write to the database, the preferred architecture is almost always to expose an API over HTTP or gRPC through an Ingress with TLS and authentication, and to let the API mediate access to the database. Legitimate cases for direct external access nevertheless exist, including analytics tools, business intelligence dashboards, and replication to external systems. In those cases the pattern takes the following shape: a Service of type LoadBalancer backed by the database pods, fronted by a cloud network load balancer, with the security group restricted to specific client CIDRs, mandatory TLS, and a credential rotation policy. When a managed database such as Amazon RDS, Google Cloud SQL, or Aurora can be substituted, that option is usually preferable. Operating Postgres inside Kubernetes is technically feasible, but it represents a significant operational commitment.

    The StatefulSet plus headless Service pattern

    Figure: StatefulSet plus headless Service for a database. The headless Service (clusterIP: None) makes postgres.db.svc.cluster.local resolve to all pod IPs, one DNS A record per pod, and adds the per-pod names postgres-0.postgres.db.svc, postgres-1.postgres.db.svc, and postgres-2.postgres.db.svc. postgres-0 is the primary and accepts writes; postgres-1 and postgres-2 are read-only streaming replicas. All three run image postgres:16.3, and each keeps its own PersistentVolumeClaim (data-postgres-0, data-postgres-1, data-postgres-2; storageClass gp3-ssd, 200 GiB, access mode RWO) that stays with the same ordinal across reschedules. Writes go to postgres-0.postgres.db.svc; reads can fan out to all three.

    A headless Service is the object produced when clusterIP: None is set in the specification. Rather than allocating a virtual IP, this configuration produces DNS A records, with one record per pod backend. When combined with a StatefulSet, the result is a set of stable per-pod hostnames, such as postgres-0.postgres.db.svc.cluster.local and postgres-1.postgres.db.svc.cluster.local. This naming arrangement is precisely what a primary-replica database deployment requires. The application directs writes to the hostname of the primary and reads to the hostname of any replica.

    # headless service + statefulset for postgres
    apiVersion: v1
    kind: Service
    metadata:
      name: postgres
      namespace: db
      labels:
        app: postgres
    spec:
      clusterIP: None          # headless — no virtual IP
      selector:
        app: postgres
      ports:
      - name: pg
        port: 5432
        targetPort: 5432
    ---
    apiVersion: apps/v1
    kind: StatefulSet
    metadata:
      name: postgres
      namespace: db
    spec:
      serviceName: postgres    # MUST match the headless Service name
      replicas: 3
      selector:
        matchLabels:
          app: postgres
      template:
        metadata:
          labels:
            app: postgres
        spec:
          terminationGracePeriodSeconds: 30
          containers:
          - name: postgres
            image: postgres:16.3
            ports:
            - containerPort: 5432
              name: pg
            env:
            - name: POSTGRES_PASSWORD
              valueFrom:
                secretKeyRef:
                  name: postgres-secret
                  key: password
            - name: PGDATA
              value: /var/lib/postgresql/data/pgdata
            volumeMounts:
            - name: data
              mountPath: /var/lib/postgresql/data
            readinessProbe:
              exec:
                command: ["pg_isready", "-U", "postgres"]
              initialDelaySeconds: 10
              periodSeconds: 5
            resources:
              requests:
                cpu: "500m"
                memory: "1Gi"
              limits:
                cpu: "2"
                memory: "4Gi"
      volumeClaimTemplates:
      - metadata:
          name: data
        spec:
          accessModes: ["ReadWriteOnce"]
          storageClassName: gp3-ssd
          resources:
            requests:
              storage: 200Gi
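
    The per-pod hostnames that this manifest produces follow a fixed pattern, which the short Python sketch below reproduces. The function name is illustrative, `cluster.local` is assumed to be the cluster domain, and treating ordinal 0 as the primary mirrors the diagram above rather than anything Kubernetes enforces.

```python
# Sketch: derive the stable per-pod DNS names for a StatefulSet behind a
# headless Service. Pattern: <statefulset>-<ordinal>.<service>.<namespace>.svc.<domain>
from typing import List


def statefulset_pod_hosts(
    name: str,
    service: str,
    namespace: str,
    replicas: int,
    cluster_domain: str = "cluster.local",  # assumption: default cluster domain
) -> List[str]:
    """Return one fully qualified hostname per pod ordinal."""
    return [
        f"{name}-{i}.{service}.{namespace}.svc.{cluster_domain}"
        for i in range(replicas)
    ]


hosts = statefulset_pod_hosts("postgres", "postgres", "db", 3)

# Assumption for illustration: ordinal 0 is the primary, the rest are replicas.
primary_host = hosts[0]
replica_hosts = hosts[1:]

for host in hosts:
    print(host)
```

    An application would point its write connection at `primary_host` and spread reads over `replica_hosts`. With an operator in charge of failover, the primary may move to another ordinal, so the operator's own read-write Service is usually the safer target.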
    

    Production databases almost always benefit from a purpose-built operator layered on top of this scaffolding, such as CloudNativePG, the postgres-operator developed by Zalando, or Crunchy PGO. These operators handle primary election, streaming replication, backups, point-in-time recovery, and rolling minor-version upgrades. Selecting an appropriate database backend is a separate concern; the database comparison for preprocessed time-series data serves as a useful companion reference for that decision.

    Key Takeaway: Pod IPs are an internal implementation detail of the cluster and should never serve as the target of a client connection. Inside the cluster, use Service DNS names. From a developer laptop, use kubectl port-forward. For external clients, use a managed load balancer, or preferably an API tier placed in front of the database. Stateful workloads should always combine a StatefulSet, a PersistentVolumeClaim, and a headless Service.

    Kubernetes 1.36 and What Changed in 2026

    Kubernetes 1.36 is the most recent minor release as of this writing in May 2026, and it continues the project’s emphasis on stronger security defaults and on first-class support for AI workloads. According to the official release page (Source: kubernetes.io/releases, as of 2026-05-20), the project actively maintains release branches for the three most recent minor versions, currently 1.34, 1.35, and 1.36. Version 1.33 entered maintenance on 2026-04-28 and reaches end of life on 2026-06-28. The release cadence is rapid enough that operators running anything older than 1.33 are already outside the supported window.

    Source: kubernetes.io/releases, as of 2026-05-20
    | Version | Released | Status | Key features |
    | --- | --- | --- | --- |
    | 1.36 | April 2026 | Latest, fully supported | User Namespaces GA, Mutating Admission Policies GA, Fine-Grained Kubelet API Authorization GA; 70 enhancements total (18 GA / 25 Beta / 25 Alpha) |
    | 1.35 | December 2025 | Supported | DRA improvements for GPU scheduling, Topology-aware routing refinements |
    | 1.34 | August 2025 | Supported | VolumeAttributesClass GA, Direct Service Return + overlay networking in Windows kube-proxy |
    | 1.33 | April 2025 | Maintenance only (EOL 2026-06-28) | Sidecar containers GA, in-place pod resize beta |

     

    The promotion of User Namespaces to general availability is the most prominent security change in 1.36. When user namespaces are enabled, the root user inside a container is mapped to an unprivileged user on the host. This arrangement substantially reduces the impact of a container escape: even when an attacker compromises a container running as UID 0, they emerge on the host as a high-numbered unprivileged user, such as UID 100000, with no special privileges. For database pods specifically, a compromised Postgres container no longer translates directly into root access on the node. In combination with seccomp and AppArmor profiles, this change closes one of the long-standing gaps between Kubernetes security and traditional virtual machine isolation.
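
    The UID translation described above can be illustrated with a small Python sketch. It parses text in the three-column format the Linux kernel exposes in /proc/<pid>/uid_map (first ID inside the namespace, first ID on the host, range length). The sample mapping `0 100000 65536` is an assumption chosen to match the UID 100000 example; the range a kubelet actually assigns will differ per pod.

```python
# Sketch: translate a container UID to a host UID using uid_map-style entries.
from typing import List, Optional, Tuple

UidMapEntry = Tuple[int, int, int]  # (inside_start, outside_start, length)


def parse_uid_map(text: str) -> List[UidMapEntry]:
    """Parse lines of the form '<inside> <outside> <length>'."""
    entries = []
    for line in text.strip().splitlines():
        inside, outside, length = (int(field) for field in line.split())
        entries.append((inside, outside, length))
    return entries


def to_host_uid(container_uid: int, entries: List[UidMapEntry]) -> Optional[int]:
    """Return the host UID for a container UID, or None if it is unmapped."""
    for inside, outside, length in entries:
        if inside <= container_uid < inside + length:
            return outside + (container_uid - inside)
    return None


# Hypothetical mapping: container UIDs 0..65535 map to host UIDs 100000..165535.
sample_map = parse_uid_map("0 100000 65536")

root_on_host = to_host_uid(0, sample_map)        # container root
service_on_host = to_host_uid(999, sample_map)   # an arbitrary service UID
unmapped = to_host_uid(65536, sample_map)        # outside the mapped range

print(root_on_host, service_on_host, unmapped)
```

    Under this mapping, a process that is root inside the container runs as host UID 100000, which holds no privileges on the node.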

    Mutating Admission Policies, also promoted to general availability, bring declarative mutations expressed in the Common Expression Language (CEL) to the admission chain, replacing many uses of webhook-based mutating admission controllers. Operators can now write policies that, for example, automatically inject sidecar containers, attach labels, set default resource requests, or enforce image-registry rules, without operating a separate webhook server. The result is less infrastructure to maintain and fewer failure modes when a webhook becomes unavailable.

    Fine-Grained Kubelet API Authorization, now generally available, lets the kubelet authorize requests to its own API at the level of individual endpoints rather than gating nearly everything behind the broad nodes/proxy permission. This change matters for hardening: monitoring and logging tools that previously required nodes/proxy, which also permits far riskier operations, can instead be granted narrower read-only subresources such as nodes/healthz, nodes/pods, and nodes/configz.

    Beyond security, version 1.36 continues to invest in AI workload support. It introduces refinements to Dynamic Resource Allocation (DRA) for GPU scheduling, adds support for accelerator partitioning, and improves the ability of the scheduler to handle long-running training jobs alongside short-lived inference pods. The trajectory is clear: the pattern of Kubernetes as an AI platform, which grew rapidly in 2024 and 2025 as model-serving workloads migrated off bespoke infrastructure, has been a first-class concern for two consecutive release cycles. For language and runtime choices when developing operators or controllers around these new APIs, the Python and Rust comparison provides a useful framing. The controller-runtime ecosystem in Go remains dominant, but Rust-based operators are gaining ground for performance-sensitive components.

    Frequently Asked Questions

    Can a pod have more than one container?

    Yes, and it is a common design pattern. The most frequent reason is the sidecar, a helper container that does logging, TLS termination, service-mesh proxying (Envoy in Istio, linkerd2-proxy in Linkerd), or connection pooling. All containers in a pod share a network namespace and can share volumes, but they remain separate processes with separate filesystems. Use multiple containers when their lifecycles are genuinely coupled. If the answer to “can these scale independently?” is yes, they belong in separate pods.

    Why not just expose every database pod with a NodePort and connect directly?

    NodePort opens the same port on every node in the cluster, in the 30000–32767 range, and routes it to whichever pod backs the Service. Three problems: the port numbers are non-standard so client tooling fights you, every node becomes an attack surface for the database, and you still need a cloud security group or firewall rule to control who can hit those ports. NodePort is fine for on-prem clusters without a cloud LB or for very specific debug scenarios. It is not a substitute for proper Service architecture.

    Is kubectl port-forward safe to use in production?

    It is safe to use, but it should not be how production traffic flows. The tunnel runs through the API server and consumes API-server resources. It bypasses NetworkPolicy — if you can port-forward, you can connect, regardless of how strict your in-cluster policies are. RBAC controls who can use it, and you should treat pods/portforward on a database namespace as a sensitive verb subject to audit. For production traffic, use a real Service.

    What is the difference between a StatefulSet and a Deployment?

    A Deployment treats pods as interchangeable. It will scale up by spinning up new pods with random suffix names, scale down by killing any of them, and roll updates in parallel. A StatefulSet maintains ordered, named pods (name-0, name-1, name-2) that always come up in order, always shut down in reverse order, and each get their own stable PersistentVolumeClaim. Use Deployment for stateless apps. Use StatefulSet for anything that has identity — databases, message brokers, ZooKeeper, distributed coordination services. Kafka brokers running in Kubernetes are a textbook StatefulSet workload.

    Should I actually run my database in Kubernetes, or use a managed service?

    For most teams below the scale of needing a database engineer on the org chart, managed (RDS, Cloud SQL, Aurora, AlloyDB, Spanner) is the right answer. Operating a stateful workload well — backups, point-in-time recovery, minor-version upgrades, failover, performance tuning, observability — is a continuous engineering investment that managed services amortize across thousands of customers. Run databases in your cluster when you have a real reason: cost at scale, regulatory data residency, latency requirements that make a separate database tier unworkable, or a database that managed offerings do not provide. The operator ecosystem (CloudNativePG and friends) makes this much more tractable than it was five years ago, but it is still real work.

    Conclusion

    Connecting to a database that runs in a Kubernetes pod feels harder than it should because Kubernetes is solving a different problem than many engineers initially assume. It is not an elaborate replacement for docker run. It is a cluster operating system whose entire networking model is designed around the principle that pods communicate with other pods through stable abstractions, and external clients reach applications through carefully chosen entry points. The pod IP revealed by kubectl get pods -o wide is a debugging convenience rather than an address suitable for client traffic. The ClusterIP shown by kubectl get svc is a virtual construct held together by iptables rules. The correct address for production traffic originating inside the cluster is a DNS name served by CoreDNS and backed by a Service whose membership the controllers maintain. The correct address from outside the cluster is whatever the LoadBalancer, Ingress, or bastion-host configuration specifies, and it is never a pod IP.

    Three points are worth retaining from this discussion. First, kubectl port-forward is well suited to development workflows and unsuited to production traffic. Second, stateful workloads require a StatefulSet, a PersistentVolumeClaim, and a headless Service in combination, or data loss is likely. Third, in Kubernetes 1.36 and beyond, security defaults are tightening, with User Namespaces reaching general availability as the most consequential change, which benefits anyone running databases in pods. Even with these improvements, however, the number of ways in which a connection between an external client and an in-cluster database can fail remains large enough that exposing Postgres directly to the public internet is almost always inferior to placing an API tier in front of the database. The recommended approach is to build the conservative, layered version first, and to reserve more aggressive shortcuts for cases that genuinely warrant them.

  • Who Owns Anthropic? Public Company Stakes and Investor Map in 2026

    Published: 2026-05-17

    This analysis examines the publicly-traded routes to Anthropic exposure as of May 2026. Anthropic is reportedly in talks for a new funding round at a valuation of $900 billion or more, with some reports placing the upper bound near $950 billion (Bloomberg, as of 2026-05-12). That figure, set against a $380 billion post-money valuation only three months earlier (Anthropic press release, as of 2026-02-12), revives a recurrent question for public-market investors: Anthropic is private and not directly investable as a stock, and the relevant question is therefore which publicly-traded companies offer the most accessible exposure to its growth.

    The analysis is US-focused, anchored on Amazon (AMZN, NASDAQ), Alphabet (GOOGL, NASDAQ), NVIDIA (NVDA, NASDAQ), and Microsoft (MSFT, NASDAQ) as the four public-stock holders of meaningful Anthropic stakes. The horizon considered is short-term (one to three months) and mid-term (six to twelve months). The institutional and venture capital base, which includes GIC, Coatue, Sequoia, ICONIQ, Fidelity, BlackRock-affiliated funds, Goldman Sachs Alternatives, JPMorganChase, and others, receives a brief context section. The analytical weight nevertheless remains on the four public anchors because they are the only Anthropic investors that retail participants can trade directly.

    Summary

    What this post covers: A May 2026 mapping of the publicly-traded routes to Anthropic exposure, namely Amazon, Alphabet, NVIDIA, and Microsoft, sized against the reported $900 billion-plus funding talks, with materiality estimates, scenario conditions, and the institutional capital stack behind them. The discussion is provided for informational purposes only and does not constitute investment advice.

    Key insights:

    • Anthropic re-rated from a $380 billion post-Series G valuation in February 2026 to talks at $900 billion or more by May 2026, while annualised revenue scaled to approximately $30 billion in April 2026 from approximately $1 billion at the start of 2025. The re-rating has, to date, been matched by ARR growth.
    • Amazon (with up to approximately $33 billion committed) and Alphabet (with up to approximately $40 billion committed and a reported stake of approximately 14 percent) are the two material public-stock proxies. NVIDIA and Microsoft hold substantially smaller positions, at up to $10 billion and $5 billion respectively.
    • According to Fortune (April 2026), approximately half of the year-over-year increase in Google’s and Amazon’s AI-related profits in Q1 2026 came from mark-to-market accounting gains on the Anthropic stake, not from operating revenue. This makes the stake itself a matter of investor scrutiny rather than a footnote.
    • Implied marks, for example Alphabet’s notional figure of approximately $126 billion at a $900 billion valuation, are unrealised, lumpy, and contingent on a future liquidity event. Balance-sheet carrying values under US GAAP will typically be below the implied mark.
    • Across the upside, downside, and neutral scenarios, the data tend toward the neutral case: Anthropic remains private, mark-to-market gains oscillate with each round, and the translation into operating income depends primarily on cloud-segment growth at Amazon, Alphabet, and Microsoft rather than on the stake itself.

    Main topics: why Anthropic’s investor base matters to public-market investors, valuation and revenue context, the strategic public-company investors, the institutional and VC capital stack, materiality of Anthropic exposure inside each public stock, three conditional scenarios, limitations, and FAQ.

    Key Takeaways:
    • Anthropic is reportedly in talks for a new round at a $900 billion-plus valuation, up from $380 billion post-Series G in February 2026 (Bloomberg, as of 2026-05-12; Anthropic press release, as of 2026-02-12).
    • Annualized revenue reached approximately $30 billion in April 2026, up from roughly $1 billion at the start of 2025 (VentureBeat / SaaStr coverage citing Anthropic disclosures, as of 2026-04).
    • Amazon and Alphabet are the two largest strategic shareholders; combined potential commitments exceed $70 billion (Fortune, as of 2026-04-30; Data Center Dynamics, as of 2026; TechFundingNews, as of 2026).
    • NVIDIA and Microsoft joined in November 2025 with commitments of up to $10 billion and $5 billion respectively, alongside a $30 billion Anthropic Azure compute purchase (Microsoft blog, as of 2025-11-18).
    • For retail investors, the public-stock holders of Anthropic stakes are the only accessible exposure, but the valuation marks discussed below are unrealized, lumpy, and contingent on a future liquidity event.

    Why Anthropic’s Investor Base Matters to Public-Market Investors

    Anthropic is a private company. Its equity does not trade on a public exchange, and secondary-market access to private AI lab shares is restricted to qualified institutional buyers. For the typical retail investor, the only means of gaining exposure to Anthropic’s revenue and valuation trajectory is to hold a publicly-listed company that owns a stake. That set has narrowed to a small group of US large caps, and each of them now derives a measurable portion of recent reported earnings, AI optionality, or both, from its Anthropic position.

    The question is more than academic. A Fortune article published on 2026-04-30 reported that approximately half of the year-over-year increase in Google’s and Amazon’s AI-related profits in Q1 2026 came from accounting gains tied to their Anthropic stakes rather than from operating revenue. This disclosure has elevated the stake itself to a matter of investor scrutiny rather than a footnote (Fortune, as of 2026-04-30). When mark-to-market accounting, namely the practice of revaluing an asset on the balance sheet to its current implied price, is the dominant contributor to a reported earnings surprise, the durability of those earnings depends on whether the underlying private valuation holds.

    The post that follows examines the four anchor names in order of committed capital, sets out the institutional capital stack for context, and then quantifies how much of each anchor’s market capitalisation is attributable to Anthropic exposure under reasonable assumptions. Readers tracking adjacent AI compute exposure may find the AMD prospects versus NVIDIA 2026 analysis and the broader NVIDIA, AMD, and Intel semiconductor stock comparison useful as parallel framings on the accelerator side of the same AI capital cycle.

    The Anthropic Valuation and Revenue Context

    Anthropic’s valuation trajectory over the past nine months provides the denominator for every stake calculation that follows. Four data points are material, and the gaps between them indicate the pace of the re-rating.

    Data: Anthropic press releases / Bloomberg, as of 2026-05-17.

    | Round | Date | Amount Raised | Post-Money Valuation | Lead(s) |
    | --- | --- | --- | --- | --- |
    | Series F | Sep 2025 | $13B | $183B | ICONIQ (lead); Fidelity, Lightspeed (co-lead) |
    | Strategic round | Nov 2025 | Up to $15B (MSFT + NVDA) | ~$350B (reported) | Microsoft, NVIDIA (strategic partners) |
    | Series G | Feb 12, 2026 | $30B | $380B | GIC, Coatue (lead); D. E. Shaw, Dragoneer, Founders Fund, ICONIQ, MGX (co-lead) |
    | Current talks | May 2026 | At least $30B (up to $50B reported) | $900B-$950B (reported) | Dragoneer, Greenoaks, Sequoia, Altimeter (reported co-lead) |

     

    The May 2026 round is still at the talks stage rather than closed, so the $900 billion figure should be interpreted as a market-clearing indication rather than a confirmed mark (Bloomberg, as of 2026-05-12). Several definitional notes are useful before proceeding. Post-money valuation is the implied total equity value of the company immediately after a financing closes, including the new capital raised. Annualised run-rate revenue (ARR) is the most recent monthly or quarterly revenue extrapolated to a full year, and it is the metric that Anthropic and its strategic partners have used in public commentary.

    The ARR trajectory explains why the valuation has compounded so rapidly. Anthropic’s CEO Dario Amodei has publicly described 80-fold annualised growth in the first quarter of 2026, and the underlying ARR figures support that ratio when measured against the start of 2025.

    Data: VentureBeat, SaaStr, MindStudio coverage citing Anthropic disclosures, as of 2026-04.

    | Date | Annualized Revenue (ARR) | Multiplier from baseline |
    | --- | --- | --- |
    | Start of 2025 | ~$1B | 1x |
    | August 2025 | $5B | 5x |
    | End of 2025 | $9B | 9x |
    | April 2026 | ~$30B | ~30x in 16 months |
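
    The multipliers in the table, and the steady monthly growth rate they would imply, can be reproduced with a few lines of Python. The ARR figures are the approximate values from the table. The 16-month span is taken from the last row, and the constant-rate compounding is a simplifying assumption, since actual growth was not smooth.

```python
# Sketch: growth multipliers and an implied constant monthly growth rate.
arr_billions = [
    ("Start of 2025", 1.0),
    ("August 2025", 5.0),
    ("End of 2025", 9.0),
    ("April 2026", 30.0),
]

baseline = arr_billions[0][1]

# Multiplier of each data point relative to the start-of-2025 baseline.
multipliers = {label: value / baseline for label, value in arr_billions}

# Constant monthly rate r that satisfies baseline * (1 + r) ** months == final.
months = 16  # as stated in the table's last row
final = arr_billions[-1][1]
monthly_rate = (final / baseline) ** (1 / months) - 1

for label, multiple in multipliers.items():
    print(f"{label}: {multiple:.0f}x")
print(f"Implied constant monthly growth: {monthly_rate:.1%}")
```

    A 30x increase over 16 months corresponds to roughly 24 percent compound growth per month under this simplification.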

     

    Approximately 70 to 75 percent of revenue is reported to come from API consumption rather than from consumer subscriptions. Claude Code reached a $2.5 billion run-rate by February 2026, up from the $1 billion run-rate it hit within six months of its mid-2025 launch (VentureBeat and SaaStr coverage citing Anthropic disclosures, as of 2026-04). Anthropic disclosed more than 300,000 business customers in October 2025, and Amazon reported that more than 100,000 customers were running Claude on Amazon Bedrock as of April 2026 (Fortune, as of 2026-04-30).

    Key Takeaway: A $900 billion private valuation on approximately $30 billion of ARR implies a multiple of roughly 30x. The multiple is high in absolute terms but consistent with the pricing of the most recent funding rounds, provided that growth is sustained. The multiple compresses rapidly if the next ARR print disappoints.
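
    The arithmetic behind the revenue multiple, and what compression would mean for the implied valuation, can be sketched in Python. The valuation range and ARR come from the figures above. The 20x compressed multiple is purely hypothetical and is included only to show the sensitivity.

```python
# Sketch: valuation-to-ARR multiple and the effect of multiple compression.
def revenue_multiple(valuation_bn: float, arr_bn: float) -> float:
    """Valuation divided by annualised run-rate revenue (both in $ billions)."""
    if arr_bn <= 0:
        raise ValueError("ARR must be positive")
    return valuation_bn / arr_bn


def implied_valuation(arr_bn: float, multiple: float) -> float:
    """Valuation implied by applying a revenue multiple to a given ARR."""
    return arr_bn * multiple


ARR_BN = 30.0  # approximate April 2026 ARR

low_multiple = revenue_multiple(900.0, ARR_BN)   # lower end of reported talks
high_multiple = revenue_multiple(950.0, ARR_BN)  # upper end of reported talks

# Hypothetical scenario: ARR holds at $30B but investors pay only 20x.
compressed_valuation = implied_valuation(ARR_BN, 20.0)

print(f"Multiple at $900B: {low_multiple:.1f}x")
print(f"Multiple at $950B: {high_multiple:.1f}x")
print(f"Valuation at a hypothetical 20x: ${compressed_valuation:.0f}B")
```

    Because the ARR figure is approximate, the multiple is approximate as well: each additional $1 billion of ARR lowers the multiple at $900 billion by about one turn.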

    The Strategic Public-Company Investors

    Four publicly-listed companies hold the largest disclosed positions in Anthropic. Each entered for different strategic reasons, including cloud distribution, model availability on a platform, and technology co-design, and each has structured its commitment so that the headline number includes both funded and milestone-tied components. The figures below distinguish the two where the disclosure allows.

    Amazon (AMZN, NASDAQ)

    Amazon is the largest single investor in Anthropic by committed capital. The original commitment was up to $8 billion, and the more recent expansion added a further $5 billion investment together with an option for up to $20 billion more tied to commercial milestones, which brings the total potential commitment to approximately $33 billion (Fortune, as of 2026-04-30; TechFundingNews, as of 2026). The strategic anchor is Amazon Web Services: Anthropic models are first-class citizens on Amazon Bedrock, the managed foundation-model service, and more than 100,000 customers now run Claude on that platform (Fortune, as of 2026-04-30).

    The most frequently cited mark on the stake is drawn from the same Fortune article, which reported that Amazon’s original $8 billion investment was worth more than $70 billion based on the implied valuation following the Series G close and the subsequent talks (Fortune, as of 2026-04-30). The figure represents paper rather than realised value. The same article observed that “half of Google’s and Amazon’s blowout AI profits came from a stake in Anthropic — not from their actual business” in the Q1 2026 earnings disclosure cycle (Fortune, as of 2026-04-30). The phrasing is the source’s, not a forecast.

    Alphabet (GOOGL, NASDAQ)

    Alphabet is the second-largest disclosed shareholder. Data Center Dynamics reported an estimated 14 percent stake based on the company’s filings and round disclosures (Data Center Dynamics, as of 2026). The most recently committed tranche was $10 billion at the $350 billion valuation reported in late 2025, with up to a further $30 billion to follow if Anthropic meets performance milestones, for a total potential commitment of $40 billion (Data Center Dynamics, as of 2026; TheStreet, as of 2026; Silicon Republic, as of 2026; TechFundingNews, as of 2026).

    Claude is also available on Google Cloud’s Vertex AI platform. The strategic logic mirrors that of Amazon: a frontier model is placed within the company’s own cloud distribution layer to capture both inference compute revenue and incremental enterprise account stickiness. Alphabet’s position is unusual in that Google operates its own competing internal model family (Gemini). The Anthropic stake therefore functions as both a hedge and a distribution position rather than as a substitute for first-party model development.

    NVIDIA (NVDA, NASDAQ)

    NVIDIA committed up to $10 billion as part of the November 2025 strategic round (Microsoft blog, as of 2025-11-18; Bloomberg, as of 2025-11; CNBC, as of 2025-11). The investment includes a technology partnership component: co-design and engineering work intended to optimise Anthropic’s models for NVIDIA’s architectures, and, conversely, to inform NVIDIA’s roadmap with frontier-model workload patterns.

    From a public-market exposure standpoint, NVIDIA is the most operationally connected of the four anchors. Anthropic’s training and inference compute already relies heavily on NVIDIA hardware purchased through hyperscaler partners, so the equity stake compounds an existing demand relationship. The exposure is therefore less concerned with a future liquidity event for Anthropic and more concerned with a self-reinforcing dynamic in which Anthropic growth drives additional NVIDIA accelerator demand.

    Microsoft (MSFT, NASDAQ)

    Microsoft committed up to $5 billion in November 2025, the smallest of the four anchors by direct investment size but paired with the largest commercial commitment in the opposite direction (Microsoft blog, as of 2025-11-18). Anthropic agreed to purchase $30 billion of Azure compute capacity, with additional compute available up to 1 gigawatt, a unit of electrical power consumption used to size data-centre deployments. Claude (Sonnet 4.5, Opus 4.1, Haiku 4.5) was added to Microsoft Foundry, the company’s enterprise model catalogue.

    The structural result is that Claude is now the only frontier model available on all three major cloud platforms, namely AWS, Azure, and Google Cloud, while Microsoft remains a major holder of competing OpenAI economics. For Microsoft shareholders, the Anthropic stake is small relative to the Azure compute commitment in the opposite direction, and the value to the equity story rests more on Azure revenue capture than on the $5 billion investment itself.

    Caution: Commitment numbers in the public disclosures combine funded investment with milestone-tied options. The headline “$33 billion” for Amazon or “$40 billion” for Alphabet represents a maximum potential commitment, not cash deployed. Investors evaluating accounting marks should separate the funded base from the option overhang when constructing their models.
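
    One way to keep the two components apart in a model is sketched below in Python. The split uses only the amounts described above for Amazon and Alphabet. Treating the non-milestone amounts as the base is an assumption, because the disclosures say "up to" and do not state how much cash has actually been deployed. NVIDIA and Microsoft are left out because their commitments are not itemised here.

```python
# Sketch: separate the base commitment from the milestone-tied option overhang.
from dataclasses import dataclass


@dataclass(frozen=True)
class Commitment:
    investor: str
    base_bn: float       # non-milestone portion, $ billions (assumed base)
    milestone_bn: float  # milestone-tied option, $ billions

    @property
    def headline_bn(self) -> float:
        """Maximum potential commitment quoted in headlines."""
        return self.base_bn + self.milestone_bn

    @property
    def option_share(self) -> float:
        """Fraction of the headline number that depends on milestones."""
        return self.milestone_bn / self.headline_bn


commitments = [
    # Amazon: up to $8B original + $5B expansion, plus up to $20B in options.
    Commitment("Amazon", base_bn=8.0 + 5.0, milestone_bn=20.0),
    # Alphabet: $10B tranche, plus up to $30B tied to performance milestones.
    Commitment("Alphabet", base_bn=10.0, milestone_bn=30.0),
]

combined_headline_bn = sum(c.headline_bn for c in commitments)

for c in commitments:
    print(f"{c.investor}: headline ${c.headline_bn:.0f}B, "
          f"{c.option_share:.0%} milestone-tied")
print(f"Combined headline: ${combined_headline_bn:.0f}B")
```

    On these inputs, well over half of each headline figure is contingent, which is why a mark built on the full headline amount overstates the capital actually at work.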

    Behind the Public Names: The Institutional and VC Capital Stack

    Outside the four public-company anchors, Anthropic’s cap table reads as a comprehensive list of sovereign wealth, growth equity, and crossover funds. The Series G announcement on 2026-02-12 named GIC and Coatue as leads, with co-leads D. E. Shaw Ventures, Dragoneer, Founders Fund, ICONIQ, and MGX (Anthropic press release, as of 2026-02-12). Significant investors disclosed in the same release included Accel, Addition, Alpha Wave Global, Altimeter, AMP PBC, Appaloosa LP, Baillie Gifford, Bessemer Venture Partners, BlackRock affiliates, Blackstone, D1 Capital, Fidelity, General Catalyst, Greenoaks, Goldman Sachs Alternatives, Insight Partners, Jane Street, JPMorganChase (Security and Resiliency Initiative and Growth Equity Partners), Lightspeed, Menlo Ventures, Morgan Stanley Investment Management, NX1 Capital, Qatar Investment Authority, Sands Capital, Sequoia, Temasek, TowerBrook, TPG, Whale Rock Capital, and XN.

    The Series F announcement in September 2025 added the prior layer: ICONIQ (lead), Fidelity Management and Research Co., and Lightspeed Venture Partners (co-leads), with significant investors including Altimeter, Baillie Gifford, BlackRock-affiliated funds, Blackstone, Coatue, D1 Capital Partners, General Atlantic, General Catalyst, GIC, Goldman Sachs Alternatives, Insight Partners, Jane Street, Ontario Teachers’ Pension Plan, Qatar Investment Authority, TPG, T. Rowe Price, WCM Investment Management, and XN (Anthropic press release, as of 2025-09).

    Salesforce Ventures, the corporate venture arm of Salesforce (CRM, NYSE), is also a publicly acknowledged backer of Anthropic. No specific dollar amount for Salesforce’s stake has been confirmed in the primary sources reviewed for this analysis. Readers should not interpret the absence of a number as evidence that the stake is either small or large.

    The practical implication for public-market investors is that the most direct exposures within the institutional stack, namely GIC, Qatar Investment Authority, Temasek, and Ontario Teachers’ Pension Plan, are sovereign or pension vehicles that are inaccessible to retail buyers. Several other names (BlackRock affiliates, Fidelity, T. Rowe Price, JPMorganChase, and Goldman Sachs Alternatives) belong to publicly-listed financial groups, but the Anthropic positions sit within private alternatives sleeves rather than on the listed parent’s balance sheet in a manner that would materially move the stock. The exposure is real but small relative to those groups’ total assets.

    The Materiality of Anthropic Exposure Within Each Public Stock

    The four anchor stakes can be approximately sized against each company’s market capitalisation to gauge how much of the equity story the Anthropic position represents. The implied marks below use the funded-and-committed totals together with reported stake percentages where disclosed. Forward implied marks use $900 billion, the lower bound of the reported $900 billion to $950 billion range for the current talks (Bloomberg, as of 2026-05-12).

    Data: Fortune, Microsoft Blog, Bloomberg, Data Center Dynamics; commitments include both funded and milestone-tied amounts; figures rounded.

    | Investor (Ticker) | Total Committed (incl. milestone-tied) | Latest Reported Mark / Stake | Strategic Notes |
    | --- | --- | --- | --- |
    | Amazon (AMZN, NASDAQ) | Up to ~$33B | $8B funded base reported worth over $70B (Fortune, 2026-04-30) | Largest single investor; Claude on Amazon Bedrock; 100,000+ customers on Bedrock |
    | Alphabet (GOOGL, NASDAQ) | Up to ~$40B | ~14% stake reported; at $900B valuation implies ~$126B mark | Second largest; Claude on Google Cloud Vertex AI; competes with own Gemini family |
    | NVIDIA (NVDA, NASDAQ) | Up to $10B | No public mark separately disclosed | Nov 2025 round; technology co-design partnership |
    | Microsoft (MSFT, NASDAQ) | Up to $5B | No public mark separately disclosed | Anthropic committed $30B Azure compute purchase + up to 1 GW; Claude on Microsoft Foundry |

     

    The implied marks require substantial caveats. First, the statement that the stake is worth $X billion assumes that the headline private valuation is realisable, which requires a future liquidity event such as an initial public offering, a tender, or a secondary sale at a comparable price. Mark-to-market accounting on private-company stakes is sensitive to comparable transactions, and a single down-round at a smaller AI lab can compress the entire cohort. Second, milestone-tied commitments convert into incremental equity only if Anthropic meets the underlying triggers, which are not disclosed in detail. Third, in the case of Alphabet, the 14 percent reported stake (Data Center Dynamics, as of 2026) at a $900 billion valuation produces a notional $126 billion figure, but the actual carrying value on Alphabet’s balance sheet may use a different methodology under US GAAP accounting standards for equity-method or fair-value investments. Investors examining Alphabet’s 10-Q filings will likely find the disclosed carrying value below the implied valuation mark.
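
    The notional figures in this section follow from simple multiplication, sketched below in Python. The 14 percent stake and the valuation points come from the sources cited above. Holding the stake percentage constant across valuations is an assumption, since new rounds normally dilute existing holders, so the outputs are upper-bound illustrations rather than estimates of carrying value.

```python
# Sketch: notional stake value at different headline valuations.
def implied_mark_bn(stake_fraction: float, valuation_bn: float) -> float:
    """Notional value of a stake, in $ billions, at a given company valuation."""
    if not 0.0 <= stake_fraction <= 1.0:
        raise ValueError("stake_fraction must be between 0 and 1")
    return stake_fraction * valuation_bn


ALPHABET_STAKE = 0.14  # reported estimate; assumed undiluted for illustration

# Series G post-money, then the low and high ends of the reported talks.
valuations_bn = [380.0, 900.0, 950.0]
alphabet_marks = {v: implied_mark_bn(ALPHABET_STAKE, v) for v in valuations_bn}

# Amazon: reported value of more than $70B on the original $8B investment.
amazon_min_multiple_on_cost = 70.0 / 8.0

for valuation, mark in alphabet_marks.items():
    print(f"Alphabet notional mark at ${valuation:.0f}B: ${mark:.1f}B")
print(f"Amazon minimum multiple on original cost: "
      f"{amazon_min_multiple_on_cost:.2f}x")
```

    The spread between the $380 billion and $900 billion rows shows how much of the notional mark rests on a round that has not yet closed.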

    The materiality question also depends on the denominator. Against Amazon’s, Alphabet’s, NVIDIA’s, and Microsoft’s market capitalisations, each measured in trillions of US dollars in 2026, the Anthropic positions are large in absolute terms but represent a single-digit percentage of equity value for each. The Fortune disclosure regarding Q1 2026 earnings is the clearest signal that these positions are beginning to constitute more than rounding errors, particularly for Amazon and Alphabet (Fortune, as of 2026-04-30).

    Readers considering how concentrated AI exposure should be within a portfolio may find the discussion in concentration versus diversification for serious investors useful, since the four anchor names share more risk factors than a casual basket implies. For investors considering how to enter at current levels, the framing in dollar-cost averaging versus lump-sum investing addresses the timing-risk side of the same question.

    Three Conditional Scenarios for Whether Stake Value Translates to Shareholder Returns

    The directional question of whether these Anthropic stakes translate into shareholder returns for Amazon, Alphabet, NVIDIA, and Microsoft does not admit a binary answer. The conditions for each direction can nevertheless be specified concretely.

    Upside Conditions

    Anthropic completes a future liquidity event (an initial public offering, a secondary tender, or a strategic transaction) at or above the current $900 billion implied range, which allows the four public-company holders either to retain a publicly-marked position or to realise partial gains. ARR continues to compound from the approximately $30 billion April 2026 base toward the run-rates that the current valuation implies, which validates the roughly 30x revenue multiple (VentureBeat and SaaStr coverage, as of 2026-04). API margins expand as inference compute costs fall, which allows the cloud platform holders (Amazon, Alphabet, and Microsoft) to convert their commercial relationships into operating income rather than into balance-sheet marks alone. NVIDIA’s technology partnership generates measurable architectural design wins that drive additional accelerator revenue. The $30 billion Azure compute commitment from Anthropic delivers material Azure segment growth for Microsoft (Microsoft blog, as of 2025-11-18).

    Downside Conditions

    AI commoditisation compresses model margins and reduces the revenue trajectory below the slope implied by current valuation multiples. A down-round at Anthropic, or at a comparable frontier model lab, forces a mark-down on the public-company holders’ carrying values, with the Fortune-described “half of AI profits” disclosure pattern reversing in subsequent quarters (Fortune, as of 2026-04-30). Antitrust scrutiny, in either the United States or the European Union, requires partial divestiture of one or more strategic stakes, particularly in view of the simultaneous public-cloud, model-availability, and equity-holding combinations at Amazon and Alphabet. Geopolitical disruption of AI compute supply chains, which is covered in adjacent terms in the US-China trade war investment strategy 2026 piece and the framework in how geopolitical events affect US stocks, slows the underlying compute build-out that makes the ARR trajectory possible.

    Neutral Conditions

    Anthropic remains private indefinitely and continues to raise periodic primary capital that re-anchors the valuation mark, but without a realising event for existing holders. The accounting gains from mark-to-market continue to appear in Amazon’s and Alphabet’s quarterly disclosures but oscillate with each round’s pricing, which produces reported earnings volatility without a directional change in operating fundamentals. NVIDIA’s accelerator demand and Microsoft’s Azure capture remain healthy, but neither can be attributed to the Anthropic stake specifically rather than to broader AI infrastructure spending.

    On the basis of the available data, conditions appear to favour the neutral scenario, with the upside scenario incrementally favoured by the still-expanding ARR trajectory and the downside scenario primarily a function of valuation multiple compression risk rather than near-term operating disappointment. The observation is conditional on the $30 billion April 2026 ARR figure proving durable and on the May 2026 round closing at or near the reported $900 billion (Bloomberg, as of 2026-05-12).

    Tip: Investors interested in the asymmetry profile of holding the four anchor names should consider that the Anthropic-attributable upside is concentrated in liquidity events and earnings disclosures, while the downside is concentrated in valuation re-rating events. The framework discussed in options trading basics for US stocks covers how that asymmetry can be expressed with defined risk, although the options market does not price Anthropic-stake exposure separately from each anchor stock’s overall beta.

    Limitations of This Analysis

    The valuation marks and stake percentages cited above are drawn from press releases and reported figures that may differ from the carrying values disclosed in each public company’s regulatory filings under US GAAP. The current talks for a $900 billion round have not closed as of the publication date, so the implied marks in the materiality section should be treated as indicative rather than as realised.

    Frequently Asked Questions

    Can retail investors buy Anthropic stock directly?

    No. Anthropic is a privately-held company, and its equity does not trade on a public exchange. Secondary-market access to private AI lab shares is generally restricted to qualified institutional buyers. The four public-company anchors — Amazon (AMZN, NASDAQ), Alphabet (GOOGL, NASDAQ), NVIDIA (NVDA, NASDAQ), and Microsoft (MSFT, NASDAQ) — are the most accessible way to gain proxy exposure.

    Which public company has the largest Anthropic stake?

    Amazon (AMZN, NASDAQ) is the largest single investor by committed capital, with up to roughly $33 billion in total committed (including milestone-tied amounts). Fortune reported that the original $8 billion funded base was worth more than $70 billion (Fortune, as of 2026-04-30). Alphabet (GOOGL, NASDAQ) is second, with an estimated 14% stake and up to $40 billion in total potential commitment (Data Center Dynamics, as of 2026).

    Has Anthropic disclosed its profitability?

    Anthropic has disclosed annualized revenue figures — approximately $30 billion as of April 2026 (VentureBeat / SaaStr coverage citing Anthropic disclosures, as of 2026-04) — but no specific profit or loss figure or cash-burn figure has been publicly confirmed in the primary sources reviewed for this analysis. Investors evaluating margin structure should treat this absence as a known information gap.

    What is the difference between Anthropic’s “committed” and “funded” investor amounts?

    Committed amounts include both capital that has already been transferred to Anthropic (funded) and amounts that will be transferred only if Anthropic meets specific commercial or performance milestones (milestone-tied options). Headline figures such as “Amazon’s $33 billion” or “Alphabet’s $40 billion” are total commitments including the milestone-tied portion. The funded base is smaller.

    How does Microsoft’s Anthropic investment relate to its OpenAI relationship?

    Microsoft committed up to $5 billion to Anthropic in November 2025 while continuing to hold significant economics in OpenAI under a separate arrangement (Microsoft blog, as of 2025-11-18). The Anthropic investment is paired with a $30 billion Anthropic Azure compute purchase commitment, indicating the relationship is structured as much around cloud capture as around exclusive model alignment. Claude is now available on Microsoft Foundry alongside other frontier models.

    Investment Disclaimer: This post is provided for informational purposes only and does not constitute a recommendation to buy or sell any specific security. All investment decisions and their outcomes are the sole responsibility of the individual investor.
  • AMD vs NVIDIA in 2026: Prospects, Risks, and Conditional Scenarios

    Published: 2026-05-17

    This analysis examines the relative prospects of Advanced Micro Devices (AMD, NASDAQ) and NVIDIA (NVDA, NASDAQ) in the AI accelerator market as of May 2026. The starting point is the disparity in scale: AMD reported Data Center revenue of $5.8 billion in Q1 2026, a 57 percent increase year-over-year (AMD IR press release, as of 2026-05-05), while NVIDIA reported Data Center revenue of $62 billion in Q4 FY2026 alone, a 75 percent increase year-over-year (NVIDIA newsroom, as of 2026-02-25). The latter figure is approximately 10.7 times the former. This ratio frames the question that follows.

    The analysis is US-focused, anchored on AMD and NVIDIA as the listed comparables, and considers a short horizon of one to three months and a mid horizon of six to twelve months. Custom silicon programmes at Google (TPU), Amazon (Trainium), and Apple, and Korean high-bandwidth memory (HBM) suppliers as a connected supply-chain layer, are addressed where they materially affect the comparison. Intel (INTC, NASDAQ) is referenced only to the extent required to bound the competitive set.

    Summary

    What this post covers: A May 2026 head-to-head assessment of AMD’s prospects relative to those of NVIDIA over a one-to-three-month and a six-to-twelve-month horizon. The assessment is anchored on the Q1 2026 financial results, on the MI450 and Blackwell Ultra roadmaps, and on the Meta and OpenAI six-gigawatt deployment commitments. It is provided for informational purposes only and does not constitute investment advice.

    Key insights:

    • The revenue scale gap is approximately 10.7 times at the Data Center segment level ($62 billion at NVIDIA in Q4 FY26 versus $5.8 billion at AMD in Q1 2026), and NVIDIA’s total revenue is also growing faster, both in rate (73 percent versus 38 percent year-over-year) and in absolute dollars.
    • The 80 percent versus 5 to 7 percent AI accelerator market-share split is structurally explained by CUDA’s software lock-in relative to ROCm, together with the timing gap: Blackwell Ultra is shipping at scale, while MI450 first deployments are scheduled for the second half of 2026.
    • The Meta and OpenAI six-gigawatt commitments are non-overlapping according to Lisa Su, but they become material to reported revenue only from late 2026 onward, and they do not meaningfully alter the six-to-twelve-month comparison.
    • The 53 percent versus 75.0 percent GAAP gross margin gap is the under-examined structural issue: even if AMD prevails on inference total-cost-of-ownership comparisons, NVIDIA’s margin profile affords considerably more pricing flexibility for defending share.
    • Across three conditional scenarios, namely upside (the gap narrows), downside (the gap widens), and neutral (mixed signals), the data tend toward the neutral case for the six-to-twelve-month window, and the upside case remains feasible only if the MI450 ramp executes cleanly and Data Center growth accelerates rather than decelerates.

    Main topics: why the comparison matters in May 2026, the Q1 2026 numbers side by side, product roadmaps and the AI accelerator race, market share and hyperscaler CapEx, the Meta and OpenAI 6-GW commitments, valuation and analyst positioning, three conditional scenarios for AMD relative to NVIDIA, limitations, and FAQ.

    Key Takeaways:
    • AMD Q1 2026 revenue reached $10.3 billion (+38% year-over-year), with Data Center revenue of $5.8 billion (+57% year-over-year); Q2 2026 guidance is $11.2 billion (AMD IR, as of 2026-05-05).
    • NVIDIA FY2026 revenue reached $215.9 billion (+65% year-over-year); Q4 Data Center revenue was $62 billion (+75% year-over-year); Q1 FY2027 guidance is $78 billion (NVIDIA newsroom, as of 2026-02-25).
    • NVIDIA still holds roughly 80% of the AI accelerator market, with AMD at roughly 5-7% (Silicon Analysts coverage, as of first half 2026).
    • AMD has secured separate 6-gigawatt deployment commitments from Meta and OpenAI for MI450-based systems beginning H2 2026; Lisa Su has stated the two commitments do not overlap (AMD press releases, as of 2026-02-24).
    • Whether AMD continues to close the gap depends on three concrete conditions, namely ROCm software adoption, MI450 ramp execution, and hyperscaler diversification appetite, rather than on a single binary answer.

    The Relevance of This Comparison in May 2026

    The AI accelerator market, defined as the supply of specialised graphics processing units (GPUs) and related chips used to train and run large neural networks, has expanded from approximately $55 billion in 2023 to an estimated $200 billion or more in 2026 (Silicon Analysts, as of the first half of 2026). Within this market, inference workloads, namely the running of models in production rather than the training of them, are on track to represent approximately two-thirds of total AI compute spending (Silicon Analysts, as of the first half of 2026). Inference is the segment in which AMD has consistently positioned its Instinct GPUs as the most competitive option on price and total cost of ownership.

    Three developments during the first five months of 2026 justify revisiting the relative-prospects question at this point rather than later. First, AMD reported a quarter on 2026-05-05 in which Data Center revenue grew 57 percent year-over-year to $5.8 billion (AMD IR, as of 2026-05-05). Second, NVIDIA closed fiscal 2026 with $215.9 billion in total revenue and guided Q1 FY2027 to $78 billion, a figure larger than AMD’s expected full-year 2026 Data Center revenue under most analyst models (NVIDIA newsroom, as of 2026-02-25). Third, AMD announced two separate six-gigawatt customer commitments in late February 2026, one with Meta and one with OpenAI, and AMD CEO Lisa Su confirmed that they are non-overlapping (AMD press releases, as of 2026-02-24).

    Korean memory suppliers sit one layer behind both companies. High-bandwidth memory (HBM) is the stacked DRAM used on every modern AI accelerator package, and SK Hynix (000660, KOSPI) and Samsung supply the bulk of it. The relevance to this analysis is bounded: HBM availability and pricing influence the gross margins that both AMD and NVIDIA achieve on each accelerator sold, but they do not differentiate the two companies on their own. For investors considering the broader semiconductor stack, the international stock investing piece covering markets beyond the US discusses how Korean memory equities interact with US AI compute demand.

    Readers tracking the broader US large-cap technology setup may also find the NVIDIA, AMD, and Intel semiconductor stock comparison useful as a predecessor framing, since it considered the three-company landscape before the Q1 2026 results were available.

    The Q1 2026 Figures Side by Side

    AMD and NVIDIA report on different fiscal calendars. AMD’s Q1 2026 ended on 2026-03-29 and was reported on 2026-05-05. NVIDIA’s Q4 FY2026, the most recently reported quarter, ended on 2026-01-25 and was reported on 2026-02-25 (NVIDIA newsroom, as of 2026-02-25). The table below compares each company’s most recent reported quarter on a like-for-like basis where possible. Readers should note that the periods do not align perfectly in calendar time.

    Data as of 2026-05-05 (AMD) and 2026-02-25 (NVIDIA). Sources: AMD IR press release, NVIDIA newsroom.

    Metric | AMD (Q1 2026) | NVIDIA (Q4 FY26)
    Total revenue | $10.3B | $68.1B
    YoY growth | +38% | +73%
    Data Center revenue | $5.8B | $62B
    Data Center YoY growth | +57% | +75%
    GAAP gross margin | 53% | 75.0%
    Diluted EPS (GAAP) | $0.84 | No GAAP EPS figure cited in this brief
    Forward-quarter guidance | $11.2B (Q2 2026) | $78B (Q1 FY2027)

    Several observations follow from this table that do not require additional data to support. AMD’s growth rate is high but trails NVIDIA’s on every comparable line: 38 percent versus 73 percent total revenue growth, and 57 percent versus 75 percent Data Center growth. AMD’s GAAP gross margin of 53 percent (AMD IR, as of 2026-05-05) versus NVIDIA’s 75.0 percent (NVIDIA newsroom, as of 2026-02-25) reflects a meaningful structural gap; NVIDIA captures approximately 22 percentage points more of each dollar of revenue as gross profit. AMD’s non-GAAP gross margin of 55 percent (AMD IR, as of 2026-05-05) and non-GAAP diluted EPS of $1.37 (AMD IR, as of 2026-05-05) reduce part of the gap on adjusted measures but do not eliminate it.

    AMD also disclosed that it has raised its long-term Data Center CPU market growth forecast to more than 35 percent (AMD IR, as of 2026-05-05). This is a market-size statement rather than a market-share claim and applies to the EPYC server CPU business rather than to Instinct GPUs.

    Tip: When comparing semiconductor businesses with different fiscal calendars, Data Center segment revenue is a more reliable anchor than total revenue. AMD continues to derive approximately 44 percent of its total Q1 2026 revenue from outside the Data Center segment ($10.3 billion total minus $5.8 billion in Data Center revenue), including the Client (PC CPU), Gaming, and Embedded segments, in which NVIDIA is either absent or substantially smaller.
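
    The derived figures quoted in this section (the roughly 10.7 times scale ratio, the 22-percentage-point margin gap, and the roughly 44 percent non-Data-Center share) can be reproduced from the table with a few lines of arithmetic. The sketch below is illustrative only; the variable names are chosen here for readability and the inputs are the reported figures from the table above.

```python
# Reported figures from the side-by-side table (revenue in billions of US dollars).
amd_total_revenue = 10.3      # AMD Q1 2026 total revenue
amd_dc_revenue = 5.8          # AMD Q1 2026 Data Center revenue
nvda_dc_revenue = 62.0        # NVIDIA Q4 FY26 Data Center revenue
amd_gaap_margin = 53.0        # percent
nvda_gaap_margin = 75.0       # percent

# How many times larger NVIDIA's Data Center segment is than AMD's.
dc_scale_ratio = nvda_dc_revenue / amd_dc_revenue

# Gross margin gap, expressed in percentage points.
margin_gap_pp = nvda_gaap_margin - amd_gaap_margin

# Share of AMD's total revenue that comes from outside the Data Center segment.
amd_non_dc_share = (amd_total_revenue - amd_dc_revenue) / amd_total_revenue * 100

print(f"Data Center scale ratio: {dc_scale_ratio:.1f}x")
print(f"GAAP gross margin gap: {margin_gap_pp:.0f} percentage points")
print(f"AMD revenue outside Data Center: {amd_non_dc_share:.0f}%")
```

    Because the two quarters end roughly two months apart, the ratio is a scale indicator rather than a same-period comparison.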

    Product Roadmaps and the AI Accelerator Race

    The AI accelerator competition divides into two interrelated contests: hardware generations and the software stack that runs on them. On hardware, both vendors have moved to approximately annual cadences. On software, NVIDIA’s CUDA platform, the parallel computing API and runtime layer in which the company has invested since 2007, remains the dominant developer environment, while AMD’s ROCm (Radeon Open Compute) is the competing open-source stack.

    The product generation map below summarises the announced flagship hardware on each side. CUDA denotes Compute Unified Device Architecture, and ROCm denotes the Radeon Open Compute platform. Hopper, Blackwell, Blackwell Ultra, and MI450 are GPU architecture or product family names rather than acronyms.

    Data as of 2026-05-17. Sources: NVIDIA newsroom, AMD press releases.

    Year shipping | NVIDIA flagship | AMD flagship
    2023 | Hopper (H100) | MI300X
    2024 | Hopper continued / Blackwell ramp | MI325X
    2025 | Blackwell | MI350X (MI355X variant in MLPerf)
    2026 | Blackwell Ultra | MI450 (first deployments H2 2026)
    2027 | Next-generation platform (no publicly disclosed name confirmed in this brief) | MI450 ramp continues; subsequent generation not confirmed in this brief

    With respect to benchmarks, NVIDIA has marketed Blackwell Ultra with a claimed 50-fold performance improvement and 35-fold cost reduction relative to Hopper for agentic AI, namely software systems in which multiple AI models coordinate to complete multi-step tasks, based on SemiAnalysis InferenceX benchmarks (Silicon Analysts coverage, as of the first half of 2026). AMD’s MI355X delivered competitive MLPerf results across the full suite (Silicon Analysts coverage, as of the first half of 2026); MLPerf is an industry-standard benchmark suite for AI training and inference performance, maintained by the MLCommons consortium.

    With respect to price-performance, AMD’s MI300X and MI325X have been characterised by independent coverage as offering prices approximately 30 to 40 percent lower than the NVIDIA equivalent on inference workloads (Silicon Analysts coverage, as of the first half of 2026). This price advantage is the strongest single argument for hyperscaler adoption, and it is the lever that AMD is most likely to use on MI450.

    The software question is more difficult to quantify. CUDA benefits from approximately two decades of developer mindshare, a fully developed ecosystem of libraries (cuDNN, cuBLAS, TensorRT, and NCCL), and deep integration with every mainstream machine learning framework. ROCm has narrowed the functional gap on major frameworks (PyTorch, TensorFlow, and JAX), but the porting effort and the long tail of niche libraries remain genuine friction. A hyperscaler that deploys tens of thousands of GPUs is concerned with both raw cost-per-token and the engineering hours required to port and maintain its inference stack. A lower hardware price does not automatically prevail if porting costs are sufficiently high.

    Caution: Vendor-published benchmarks, including SemiAnalysis-cited internal figures and MLPerf submissions, are useful as floors but not as workload-realistic ceilings. Production inference performance depends on model architecture, batch size, sequence length, quantisation, and the specific frameworks in use. The 30 to 40 percent MI3xx price advantage cited above is an industry-coverage figure rather than an audited TCO calculation.

    Market Share, Hyperscaler CapEx, and the 80 to 5-7 Percent Gap

    NVIDIA holds approximately 80 percent of the AI accelerator market on 2026 estimates, while AMD holds approximately 5 to 7 percent, with Instinct GPU revenue of approximately $7 to $8 billion in 2025 (Silicon Analysts coverage, as of the first half of 2026). The remaining 13 to 15 percent is divided among internal accelerators (Google TPU and Amazon Trainium), Intel’s Gaudi line, and smaller participants. For AMD to gain share, the share must be taken from one of three sources: NVIDIA, the custom silicon programmes, or some combination of the two.

    The potential prize is large. The five largest US hyperscalers (Microsoft, Amazon, Google, Meta, and Oracle) are guiding 2026 capital expenditures of approximately $600 to $690 billion, of which approximately 75 percent, or roughly $450 billion at the low end of that range, is AI-related (Silicon Analysts coverage, as of the first half of 2026). Industry-wide hyperscaler AI capital expenditure for 2026 was revised upward to approximately $725 billion in Q1 2026 reporting, from a prior range of $660 to $690 billion (Silicon Analysts coverage, as of the first half of 2026). Even if accelerator silicon represents only a fraction of this capital expenditure, with the remainder allocated to power, real estate, networking, and storage, the addressable revenue pool is on the order of $200 billion or more in 2026 (Silicon Analysts coverage, as of the first half of 2026).

    Within this pool, a one-percentage-point gain in share for AMD from a base of 6 percent, to 7 percent, would correspond to approximately $2 billion of additional revenue at 2026 total-addressable-market levels, all else being equal. A five-percentage-point gain (to 11 percent) would correspond to approximately $10 billion. The shape of the share-gain trajectory is important because AMD’s reported Data Center revenue of $5.8 billion in Q1 2026 (AMD IR, as of 2026-05-05) implies an annualised run-rate of approximately $23 billion for Data Center alone, of which Instinct GPUs are only one component, the other being EPYC server CPUs. Increasing Instinct revenue alone from the 2025 level of $7 to $8 billion toward the $20 billion-plus range over 2026-2027 would require, at minimum, that the announced Meta and OpenAI MI450 deployment milestones be met on schedule.
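
    The share-gain and run-rate arithmetic in the preceding paragraph can be checked with a short sketch. The function names are hypothetical and chosen for this illustration. The $200 billion market size is the estimate cited above, and the calculation holds it fixed, which is the "all else being equal" assumption in the text.

```python
def added_revenue_bn(share_gain_pp: float, market_size_bn: float) -> float:
    """Revenue (in $bn) implied by a share gain, in percentage points, of a fixed-size market."""
    return share_gain_pp / 100 * market_size_bn


def annualised_run_rate_bn(quarterly_revenue_bn: float) -> float:
    """Simple annualisation: one quarter's revenue multiplied by four."""
    return quarterly_revenue_bn * 4


MARKET_SIZE_BN = 200.0   # estimated 2026 AI accelerator market, as cited in the text
AMD_DC_Q1_BN = 5.8       # AMD Q1 2026 Data Center revenue

one_point_gain = added_revenue_bn(1, MARKET_SIZE_BN)    # 6% -> 7% share
five_point_gain = added_revenue_bn(5, MARKET_SIZE_BN)   # 6% -> 11% share
dc_run_rate = annualised_run_rate_bn(AMD_DC_Q1_BN)

print(f"+1 pp of share: ~${one_point_gain:.0f}B")
print(f"+5 pp of share: ~${five_point_gain:.0f}B")
print(f"Annualised Data Center run-rate: ~${dc_run_rate:.1f}B")
```

    The run-rate covers the whole Data Center segment, EPYC server CPUs included, so it is an upper bound on, not an estimate of, annualised Instinct GPU revenue.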

    Custom silicon is the competitor on the other flank. Google TPU v6 is expanding beyond Google’s internal workloads to external customers, AWS Trainium 2 is being actively positioned for inference, and Apple Silicon dominates on-device inference (Silicon Analysts coverage, as of the first half of 2026). Independent industry analysis has characterised the collective custom-silicon threat as a more rapidly growing share threat to NVIDIA than AMD currently represents (Silicon Analysts coverage, as of the first half of 2026). The implication for AMD is sobering: even if NVIDIA’s share erodes meaningfully over 2026 to 2028, AMD is not the only, or even the most likely, beneficiary.

    Concentration risk in either single stock should be considered carefully, and the piece on whether concentration is preferable to diversification for serious investors sets out that framework. For volatile semiconductor names specifically, the margin and leverage guide covers the additional risk overlay involved in leveraged exposure.

    The Meta and OpenAI Six-Gigawatt Commitments: Material or Marginal

    On 2026-02-24, AMD announced two strategic partnerships within approximately the same news cycle. Meta committed to a six-gigawatt deployment across multiple Instinct generations, with the first deployment using a custom MI450-based GPU on AMD’s Helios rack-scale architecture and running ROCm alongside the sixth-generation EPYC server CPU codenamed Venice; first shipments are scheduled for the second half of 2026 (AMD press release, as of 2026-02-24). OpenAI committed separately to a six-gigawatt MI450 deployment, with the first one gigawatt scheduled to come online in the second half of 2026 (AMD press release, as of 2026-02-24). AMD CEO Lisa Su has stated publicly that the two commitments do not overlap (AMD press releases, as of 2026-02-24).

    Quantifying what 12 gigawatts of combined committed AI compute capacity means requires care. A gigawatt of AI data-centre capacity is a power-delivery figure, not a revenue figure or a unit-volume figure. The translation depends on rack density (kilowatts per rack), GPU power draw, and price per accelerator, all of which vary across MI450 system configurations and have not been publicly disclosed in dollar terms for these specific deals at the time of writing.

    The following observations can be made without extrapolating beyond the available disclosure. First, 12 gigawatts represents a structural commitment from two of the most capital-intensive AI buyers in the world, rather than a pilot deployment. Second, the deals fix MI450, rather than MI355X or earlier products, as the principal hardware, which makes execution on the MI450 ramp from the second half of 2026 onward the gating factor for both customers. Third, Meta’s choice to run ROCm in production at this scale is the clearest signal to date that ROCm is now considered hyperscaler-grade by at least one major buyer. This choice is more meaningful than any benchmark publication because Meta is dedicating its own engineering hours to the commitment.

    The bearish interpretation is also defensible. Twelve gigawatts spread over multiple years and multiple Instinct generations does not, by itself, imply that AMD overtakes NVIDIA at either customer; both Meta and OpenAI continue to be very large NVIDIA buyers. No specific FY2026 NVIDIA purchase figures for these two customers were cited in this brief, so the analysis does not assign a number. Hyperscalers routinely diversify suppliers to preserve negotiating leverage, and a diversification award, even a large one, does not necessarily indicate technical preference.

    Key Takeaway: The Meta and OpenAI commitments are large enough to be material to AMD’s revenue trajectory over 2026-2028, and Meta’s adoption of ROCm in production is qualitatively significant. They are not large enough, even in combination, to imply that AMD displaces NVIDIA as the volume leader in AI accelerators on any specific timeline disclosed publicly to date.

    Valuation and Analyst Positioning

    Valuation comparisons between AMD and NVIDIA are sensitive to the forward earnings figure used and to the analyst’s price target referenced. The table below summarises published consensus and individual analyst positioning as of mid-May 2026.

    Data as of 2026-05-16 unless otherwise noted. Sources: Public.com, MarketBeat, Yahoo Finance, TradingKey post-earnings analysis (AMD price, as of 2026-05-06).

    Metric | AMD (AMD, NASDAQ) | NVIDIA (NVDA, NASDAQ)
    Recent price (approximate) | ~$415 (as of 2026-05-06) | No specific recent price cited in this brief
    1-year return | +253% | No publicly disclosed figure confirmed in this brief
    Consensus rating | Buy (41% Strong Buy, 41% Buy, 18% Hold) | Strong Buy (37 analysts)
    Avg analyst price target | ~$390-$397 consensus | $273.62
    Implied upside | Negative on consensus vs ~$415 print | ~21%
    Highest / lowest analyst PT | Bernstein $525 (Outperform); Barclays $500; Cantor Fitzgerald $500; BofA $450 | $360 high / $195 low

    Two features of this table warrant commentary. First, AMD’s approximate price of $415 (TradingKey, as of 2026-05-06) is above the consensus analyst average of $390 to $397 (MarketBeat and Public.com, as of 2026-05-16). This is unusual and reflects the speed at which the stock has moved: the one-year return is 253 percent, the one-month return is 63 percent, and the one-week return is 10 percent (Public.com and MarketBeat, as of 2026-05-16). The post-earnings move on 2026-05-05 alone was +17.46 percent (TradingKey, as of 2026-05-06). Consensus targets often lag price action by several weeks; the negative implied upside on consensus should be interpreted as indicating that the stock has outrun the median analyst model, rather than as a statement that analysts expect the stock to decline.

    Second, the spread of individual targets is wide on AMD. Bernstein at $525 implies meaningful further upside from the recent print, while BofA at $450 implies modest upside; the consensus average sits below the spot price because not every analyst has updated forecasts following the Q1 print. NVIDIA’s consensus implied upside of approximately 21 percent on a $273.62 target (MarketBeat, as of 2026-05-16) reflects a more dispersed but generally constructive analyst stance with a range of $195 to $360.
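
    The implied-upside comparisons above follow from one formula: the price target divided by the spot price, minus one. The sketch below applies it to the AMD targets cited in the table, using the approximate $415 print as the spot price. The function name and dictionary labels are illustrative rather than drawn from any data provider, and NVIDIA is omitted because no spot price is cited for it in this brief.

```python
def implied_upside_pct(price_target: float, spot_price: float) -> float:
    """Percentage move from the spot price to the price target."""
    return (price_target / spot_price - 1.0) * 100.0


AMD_SPOT = 415.0  # approximate price, as of 2026-05-06

# Targets as cited in the valuation table (US dollars).
AMD_TARGETS = {
    "Consensus (low)": 390.0,
    "Consensus (high)": 397.0,
    "BofA": 450.0,
    "Barclays": 500.0,
    "Cantor Fitzgerald": 500.0,
    "Bernstein": 525.0,
}

upsides = {name: implied_upside_pct(target, AMD_SPOT)
           for name, target in AMD_TARGETS.items()}

for name, pct in upsides.items():
    print(f"{name:<18} ${AMD_TARGETS[name]:.0f}  {pct:+.1f}%")
```

    A negative value for the consensus rows reproduces the "negative implied upside" entry in the table. As noted above, that result reflects stale targets at least as much as an expectation of decline.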

    Entry-strategy considerations for either name, particularly after large one-week and one-month moves, are addressed in the dollar-cost averaging versus lump sum investing piece. For traders considering defined-risk exposure to either stock through derivatives, the options trading basics guide covers the mechanics.

    Three Conditional Scenarios for AMD Relative to NVIDIA

    The question of AMD’s prospects compared with those of NVIDIA is directional. This analysis declines to answer it as a binary judgement. Instead, the three scenarios below set out concrete conditions under which AMD either narrows the gap, fails to narrow the gap, or produces a mixed result over the six-to-twelve-month mid horizon.

    Upside Conditions for AMD: The Gap Narrows

    The upside case requires three conditions to be met, rather than only one. First, the MI450 ramp from the second half of 2026 must reach the volume and yield targets implied by the Meta and OpenAI commitments (AMD press releases, as of 2026-02-24). Public confirmation of MI450 production volumes at the announced gigawatt levels by the Q4 2026 or Q1 2027 reporting would be the most direct trigger. Second, ROCm adoption must extend beyond Meta to at least one additional top-five hyperscaler that runs ROCm on Instinct as a primary production stack rather than as a hedge. Third, AMD’s Data Center segment must continue to compound at or above the 57 percent year-over-year rate posted in Q1 2026 (AMD IR, as of 2026-05-05) through the next two reported quarters; a deceleration to the 30 to 35 percent range would not constitute upside, even with the Meta and OpenAI deals announced.

    Downside Conditions for AMD: The Gap Widens or Remains

    The downside case has clearer single-trigger pathways. First, NVIDIA Blackwell Ultra retains developer and hyperscaler lock-in. The 50-times performance and 35-times cost-reduction figures versus Hopper for agentic AI cited by SemiAnalysis InferenceX (Silicon Analysts coverage, as of the first half of 2026) are vendor-friendly, but if real-world inference TCO comparisons by independent third parties land in approximately the same range, MI450’s price advantage shrinks materially. Second, custom silicon, in the form of Google TPU v6 and AWS Trainium 2, captures share more rapidly than AMD. Independent coverage has already characterised custom silicon as a more material near-term threat to NVIDIA’s share than AMD (Silicon Analysts coverage, as of the first half of 2026); the same dynamic that erodes NVIDIA’s share also erodes the addressable share pool for which AMD competes. Third, ROCm friction in production, whether in drivers, framework versions, or networking, slows MI450 deployment at Meta or OpenAI relative to the announced schedule.

    Neutral Conditions: Mixed Signals

    The neutral case is, by construction, the most likely. AMD continues to grow Data Center revenue at high double-digit rates, MI450 ships at Meta and OpenAI on approximately the announced schedule with normal production difficulties, ROCm advances on major frameworks but does not displace CUDA outside committed deployments, and NVIDIA continues to grow its absolute Data Center revenue more rapidly than AMD in dollar terms even if AMD comes to grow more rapidly in percentage terms. In this scenario, the share gap (80 percent versus 5 to 7 percent) narrows modestly, perhaps to 78 percent versus 8 to 10 percent on the twelve-month horizon, but does not close, and both stocks can perform well in absolute terms while NVIDIA retains the volume leadership.

    On the basis of the data referenced, namely the 73 percent versus 38 percent revenue growth gap, the 75.0 percent versus 53 percent GAAP gross margin gap, the 80 percent versus 5 to 7 percent share gap, and the second-half 2026 timing of the MI450 ramp, conditions appear to favour the neutral scenario over the upside scenario across the six-to-twelve-month mid horizon. This is a tentative observation, grounded in the premise that the MI450 ramp will not contribute materially to AMD Data Center revenue until late 2026 at the earliest, rather than a definitive conclusion. The upside scenario remains feasible if the second-half 2026 MI450 ramp executes cleanly and if reported Data Center growth in the second half of 2026 accelerates rather than decelerates.

    Macro variables fall outside the company-specific scenarios but bound them. Rate-cut expectations and their effect on long-duration growth stocks are discussed in the US interest rate cut outlook piece, and the broader geopolitical overlay, including export controls relevant to AI accelerators sold into China, is covered in the US-China trade war investment strategy piece and the geopolitical events framework.

    Limitations of This Analysis

    This analysis relies on company-reported financials, vendor-provided benchmarks, and third-party industry coverage; none of these sources constitute audited TCO calculations, and the market-share and AI capital expenditure figures are estimates subject to revision. Forward-looking statements regarding MI450 ramp execution, ROCm hyperscaler adoption, and Blackwell Ultra real-world performance cannot be verified ahead of subsequent reporting cycles, and readers should expect the scenario conditions above to be re-evaluated against each quarterly print.

    Frequently Asked Questions

    Is AMD overtaking NVIDIA in AI accelerators?

    No publicly disclosed data supports this characterization as of writing. NVIDIA holds roughly 80% of the AI accelerator market versus AMD’s roughly 5-7% (Silicon Analysts coverage, as of first half 2026). AMD’s Q1 2026 Data Center revenue of $5.8 billion (AMD IR, as of 2026-05-05) compares to NVIDIA’s Q4 FY2026 Data Center revenue of $62 billion (NVIDIA newsroom, as of 2026-02-25), a roughly 10.7x ratio. AMD is growing Data Center revenue at 57% year-over-year, faster than the broader market, but absolute dollar growth at NVIDIA remains larger.

    What do the Meta and OpenAI 6-gigawatt commitments mean in dollar terms?

    AMD has not publicly disclosed dollar values for either the Meta or the OpenAI commitment as of writing; both are framed in gigawatts of deployed capacity rather than in revenue (AMD press releases, as of 2026-02-24). Translating gigawatts to revenue requires rack density, GPU power draw, and price-per-accelerator inputs that have not been disclosed for these specific deals. What is confirmed is that the two commitments are non-overlapping (per AMD CEO Lisa Su, AMD press releases, as of 2026-02-24) and that first shipments for both begin in H2 2026.

    How does ROCm compare to CUDA in 2026?

    ROCm (Radeon Open Compute) has narrowed the functional gap with CUDA (Compute Unified Device Architecture) on major machine learning frameworks including PyTorch, TensorFlow, and JAX. Meta’s decision to run ROCm in production on its custom MI450-based Helios deployment (AMD press release, as of 2026-02-24) is the strongest single signal that ROCm is now considered hyperscaler-grade. The gap that remains is in the long tail of niche libraries and in two decades of accumulated CUDA developer mindshare; no public metric quantifies this gap precisely.

    What is the biggest risk to AMD’s AI accelerator business?

    Independent industry coverage has characterized the collective custom-silicon threat (Google TPU v6 expanding beyond Google, AWS Trainium 2, Apple Silicon for on-device) as a faster-growing share threat to NVIDIA than AMD currently represents (Silicon Analysts coverage, as of first half 2026). The implication for AMD is that even if NVIDIA’s share erodes, AMD may not be the primary beneficiary. The second risk is execution on the MI450 ramp in H2 2026; the Meta and OpenAI commitments are MI450-specific.

    What about Intel and Korean memory suppliers?

    Intel (INTC, NASDAQ) competes in the AI accelerator market through its Gaudi product line, which is included in the roughly 13-15% non-NVIDIA, non-AMD share figure (Silicon Analysts coverage, as of first half 2026); detailed Intel-specific Gaudi revenue figures were not cited in this brief. Korean memory suppliers — SK Hynix (000660, KOSPI) and Samsung — supply the HBM (high-bandwidth memory) used on both AMD and NVIDIA accelerator packages; their influence is on package gross margin rather than on AMD-versus-NVIDIA differentiation.

    Investment Disclaimer: This post is provided for informational purposes only and does not constitute a recommendation to buy or sell any specific security. All investment decisions and their outcomes are the sole responsibility of the individual investor.
  • xPatch Explained: Dual-Stream Time Series Forecasting with EMA Decomposition

    PatchTST established the prevailing benchmark for transformer-based time series forecasting. A subsequent paper from KAIST then demonstrated a less comfortable result: a non-transformer model composed of two simple streams, an MLP and a CNN, outperforms PatchTST. xPatch achieves this with approximately one-quarter of the compute and an established idea, namely exponential moving averages.

    The paper is xPatch: Dual-Stream Time Series Forecasting with Exponential Seasonal-Trend Decomposition by Artyom Stitsyuk and Jaesik Choi, published at AAAI 2025 (arXiv:2412.17323). It is the type of paper that quietly recalibrates the field. There is no new attention variant, no foundation model with 100 billion parameters, only a careful re-examination of which inductive biases actually contribute to forecasting performance for electricity load, traffic, weather, or stock returns.

    This article examines in detail every load-bearing component of the paper: the EMA decomposition, the dual-stream architecture, the arctangent loss, the sigmoid learning-rate schedule, the experimental results, and the implications for practitioners deploying forecasts in production.

    Summary

    What this post covers: A detailed examination of the AAAI 2025 xPatch paper by Stitsyuk and Choi, including its EMA decomposition, dual-stream MLP and CNN architecture, training methods (arctangent loss, sigmoid learning-rate schedule, RevIN), benchmark results, and the implications for transformer-dominated time-series forecasting.

    Key insights:

    • A non-transformer dual-stream model (a linear stream for the trend and a depthwise-separable CNN for the seasonal component) outperforms CARD, the previous state of the art, by an average of 2.46 percent in MSE and 2.34 percent in MAE across eight standard benchmarks, while running approximately four times faster.
    • The appropriate inductive bias (EMA trend-seasonal decomposition combined with patching and dual specialisation) consistently outperforms generic attention for typical multivariate forecasting, echoing the earlier critique advanced by DLinear in “Are Transformers Effective?”
    • Training-side techniques contribute meaningfully to performance. The arctangent loss (a horizon-weighted MAE that prevents any single horizon from dominating the gradient) and the sigmoid learning-rate schedule also transfer to PatchTST and CARD, suggesting that many architecture comparisons in the literature have employed suboptimal training recipes.
    • The recommended default for the EMA alpha is 0.3 on large benchmarks (Weather, Traffic, Electricity). On smaller or noisier datasets, a sweep over {0.1, 0.3, 0.5, 0.7, 0.9} is appropriate. A smaller alpha produces smoother trends, while a larger alpha produces more reactive trends.
    • xPatch is preferable to PatchTST as a production default unless the application involves heavy channel correlations that benefit from cross-channel attention, or requires a look-back longer than 96 steps. xPatch is faster to train, faster to infer, slightly more accurate, and easier to debug because the two streams are individually interpretable.

    Main topics: Why this paper matters, The EMA Decomposition at the Centre of xPatch, The Dual-Stream Architecture, Training Components: Arctangent Loss, Sigmoid Schedule, and RevIN, Benchmark Results, Ablations: What Drives Performance, How to use xPatch (PyTorch sketch), When to use xPatch versus alternatives, Limitations and open questions, Implications for the Field, Frequently asked questions.

    Why this paper matters

    For approximately three years, time series forecasting has been dominated by transformer-based models. Informer (2021) made attention practical for long sequences. Autoformer (2021) incorporated series decomposition. FEDformer (2022) shifted attention into the frequency domain. PatchTST (2023) adapted the patching technique from Vision Transformers and became the strongest model on a substantial set of benchmarks. iTransformer (2024) inverted the embedding dimension. CARD (2024) refined the channel-aligned attention design.

    DLinear, introduced in 2022, raised an awkward question: is attention actually required for forecasting? A two-line linear model, consisting of a single fully connected layer with a moving-average decomposition, could match or surpass several transformer variants on standard benchmarks. The community responded with a wave of “are transformers effective” papers, and the consensus that emerged was nuanced: transformers help on some datasets, harm on others, and the gains are often smaller than the speed advantages forgone.

    xPatch takes the next logical step. Rather than abandoning the transformer entirely (as DLinear does) or retaining a transformer while refining attention (as CARD and iTransformer do), it constructs a dual-stream non-transformer model with stronger inductive biases. One stream is a simple MLP. The other is a compact depthwise-separable CNN. Combined with EMA-based decomposition and an improved loss function, the result outperforms CARD, the previous state of the art, while training approximately four times faster.

    For an overview of the broader landscape in which these models operate, see the companion overview of time series forecasting models in 2026. xPatch is one of the clearest examples of a non-foundation-model approach that continues to deliver competitive performance on real benchmarks.

    Key Takeaway: xPatch provides evidence that for typical multivariate forecasting, appropriate inductive biases (decomposition, patching, and dual specialisation) contribute more than attention itself. Architecture is not the only frontier; loss functions and learning-rate schedules also account for a substantial share of observed performance differences.

    Figure: xPatch dual-stream architecture. The input X (L × N) is RevIN-normalised and split by EMA decomposition into a trend X_T and a seasonal residual X_S = X − X_T. Linear stream (X_T): FC → AvgPool(k=2) → LN, applied twice with no activation, then a projection to T. CNN stream (X_S): patching (P=16, S=8) → depthwise convolution (k=P) → pointwise convolution → GELU → BatchNorm → residual. The two outputs are concatenated, passed through a linear layer, and de-normalised to give Ŷ. The linear stream handles the smooth trend; the CNN stream handles bursty seasonal patterns.

    The EMA Decomposition at the Centre of xPatch

    The single most important point to retain about xPatch is the following: the model’s first operation is to separate every channel of the input series into a slow component and a fast component, and then to model each component with a distinct network. The separation is performed using an exponential moving average.

    Why decomposition matters

    Trend and seasonality have fundamentally different dynamics. A trend is slow, often nearly linear over short windows, and dominated by accumulating shifts in level. A seasonal component is fast, often locally periodic, and frequently bursty (for example, traffic spikes or weather fronts). If one network is asked to model both at once, it must compromise: smooth filters blur the seasonal spikes, while sharp filters chase the trend’s drift. Decomposition removes that conflict by assigning each component to a specialist.

    This is not a new idea. Classical statistics has applied decomposition for decades:

    • STL (Seasonal-Trend decomposition using Loess): local polynomial regression for seasonality extraction.
    • Holt-Winters: three exponential smoothers (level, trend, and seasonal) chained together.
    • X-11 / X-13ARIMA-SEATS: a workhorse of official statistics based on iterative moving averages.

    Recent machine-learning approaches retained the spirit of decomposition while employing different tools. DLinear used a simple moving-average filter, and FEDformer projected the series into the frequency domain via Fourier transforms. xPatch adopts a different choice: an exponential moving average.

    The recursive formula

    The EMA decomposition is defined by Equation 2 of the paper:

    s₀ = x₀
    sₜ = α · xₜ + (1 - α) · sₜ₋₁    for t > 0
    
    X_T = EMA(X)         (trend)
    X_S = X − X_T        (seasonal residual)

    The parameter α is the smoothing factor, taking values in (0, 1). A small α (such as 0.1) produces a very smooth trend dominated by older observations, while a large α (such as 0.9) causes the trend to track the most recent value almost immediately. The seasonal stream consists of whatever the trend cannot explain.
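
    As a concrete illustration of the recursion, the following plain-Python sketch decomposes a short series for a small and a large α. The function name `ema_decompose` and the toy series are illustrative choices, not taken from the paper or the official repository.

```python
def ema_decompose(x, alpha):
    """Split a 1-D series into (trend, seasonal).

    Trend follows s_0 = x_0, s_t = alpha * x_t + (1 - alpha) * s_{t-1};
    the seasonal part is the residual x - trend.
    """
    trend = [x[0]]
    for value in x[1:]:
        trend.append(alpha * value + (1.0 - alpha) * trend[-1])
    seasonal = [xi - si for xi, si in zip(x, trend)]
    return trend, seasonal


# Toy series (made up for illustration): a flat level with one spike at t = 3.
series = [1.0, 1.0, 1.0, 5.0, 1.0, 1.0]

smooth_trend, smooth_seasonal = ema_decompose(series, alpha=0.1)
reactive_trend, reactive_seasonal = ema_decompose(series, alpha=0.9)

# Small alpha: the trend barely reacts, so the spike lands in the seasonal part.
print(round(smooth_trend[3], 3), round(smooth_seasonal[3], 3))      # 1.4 3.6
# Large alpha: the trend follows the spike, leaving little for the seasonal part.
print(round(reactive_trend[3], 3), round(reactive_seasonal[3], 3))  # 4.6 0.4
```

    The same spike is therefore routed to a different stream depending on α, which is why the choice of α shapes what each specialist network sees.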

    The recursion appears computationally expensive, since it is sequential by definition. However, Appendix D of the paper presents a vectorised form with O(1) per-step cost in terms of GPU operations. The technique is to expand the recursion into a closed-form weighted sum and compute it as a single matrix multiplication with a Toeplitz-style weight matrix. In practice, the EMA pre-processing is essentially free relative to the rest of the forward pass.
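
    Unrolling the recursion gives sₜ = (1 − α)ᵗ · x₀ + Σ α(1 − α)ᵗ⁻ʲ · xⱼ for j = 1..t, so the whole trend is one lower-triangular weight matrix applied to the input. The sketch below builds that matrix in plain Python and checks it against the loop. It illustrates the idea only and is not the paper's Appendix D code; the function names are hypothetical. In PyTorch the same matrix would be built once and applied with a single matrix multiplication.

```python
def ema_weight_matrix(length, alpha):
    """Lower-triangular W with trend = W @ x for the EMA seeded by s_0 = x_0."""
    W = [[0.0] * length for _ in range(length)]
    for t in range(length):
        W[t][0] = (1.0 - alpha) ** t                     # weight on the seed x_0
        for j in range(1, t + 1):
            W[t][j] = alpha * (1.0 - alpha) ** (t - j)   # weight on x_j
    return W


def ema_closed_form(x, alpha):
    """Apply the weight matrix: one weighted sum per output step."""
    W = ema_weight_matrix(len(x), alpha)
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def ema_recursive(x, alpha):
    """Reference implementation: the sequential recursion."""
    trend = [x[0]]
    for value in x[1:]:
        trend.append(alpha * value + (1.0 - alpha) * trend[-1])
    return trend


x = [2.0, 4.0, 3.0, 7.0, 5.0]
closed = ema_closed_form(x, alpha=0.3)
looped = ema_recursive(x, alpha=0.3)

# Both routes agree up to floating-point rounding.
assert all(abs(a - b) < 1e-12 for a, b in zip(closed, looped))
```

    Each row of the matrix sums to 1, so the trend is always a weighted average of past observations.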

    Why α = 0.3 performs best on large datasets

    The paper sweeps α over {0.1, 0.3, 0.5, 0.7, 0.9}. On Weather, Traffic, and Electricity, the larger and more channel-rich benchmarks, α = 0.3 is consistently optimal. The intuition is as follows. With many noisy channels, the trend must be genuinely slow in order to filter short-lived noise while still tracking the multi-step drift. A smaller α oversmooths, so slow drift that belongs in the trend spills into the seasonal residual, whereas a larger α allows excessive high-frequency content to leak into the trend. The value 0.3 balances these two failure modes on the large benchmarks.

    On smaller and noisier datasets the result is less clear-cut. In some cases α = 0.5 or 0.7 is preferable because the trend must react more quickly to abrupt regime changes. The paper treats α as a hyperparameter rather than a learnable parameter; making α learnable is one obvious direction for follow-up research.

    Simple moving average versus exponential moving average

    Property | Simple Moving Average (DLinear-style) | Exponential Moving Average (xPatch)
    Weight scheme | Uniform inside a window | Geometric decay, recent > old
    Hyperparameter | Window length k | Smoothing factor α
    Edge effects | Hard window boundary | Smooth, no boundary discontinuity
    Reactivity to recent shocks | Slow (averaged equally with old data) | Fast (recent point gets weight α)
    Implementation cost | O(k) per step | O(1) per step (vectorized)
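
    The difference in weight scheme can be seen by listing the weight each filter places on the most recent observations. This is a small sketch with hypothetical helper names; the EMA weights ignore the s₀ seed term.

```python
def sma_weights(window):
    """Simple moving average: every point in the window counts equally."""
    return [1.0 / window] * window


def ema_lag_weights(alpha, n_lags):
    """Weight on the observation `lag` steps back, ignoring the s_0 seed term."""
    return [alpha * (1.0 - alpha) ** lag for lag in range(n_lags)]


print([round(w, 4) for w in sma_weights(4)])           # [0.25, 0.25, 0.25, 0.25]
print([round(w, 4) for w in ema_lag_weights(0.3, 4)])  # [0.3, 0.21, 0.147, 0.1029]
```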

     

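    Figure: EMA decomposition with α = 0.3. Panels: the original series X; the trend X_T = EMA(X), computed as sₜ = α·xₜ + (1−α)·sₜ₋₁; and the seasonal component X_S = X − X_T. The trend is a smooth low-pass of the input via EMA; the seasonal component is a bursty residual that carries the high-frequency structure.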

    The Dual-Stream Architecture

    Once X_T (the trend) and X_S (the seasonal component) are obtained, xPatch processes them in two specialised streams. The design principle is to use the appropriate tool for each component and combine the results at the end.

    The linear stream (processing X_T)

    The trend is, by construction, smooth. After EMA filtering, little non-linear structure remains. xPatch therefore processes the trend through two MLP-style blocks, each composed of:

    • A fully connected (FC) projection.
    • A 1D average pooling layer with kernel size k = 2.
    • A LayerNorm operation.

    Importantly, there is no non-linear activation function anywhere in the linear stream. Up to the LayerNorm, the entire stream consists of a sequence of linear operators. The final output is projected to dimension T (the forecast horizon). Readers familiar with DLinear will recognise the structure: xPatch retains the DLinear approach for trend modelling.

    The LayerNorm is the only operator in the stream with a non-linear character, since it divides by an instance-computed standard deviation that is data-dependent. It stabilises training when the trend’s scale varies across samples. The average pooling acts as an additional smoothing step, reducing the probability that the linear stream over-fits to high-frequency noise that leaks through the decomposition.
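
    The claim that average pooling is linear while LayerNorm is not can be checked numerically. The sketch below is plain Python with hypothetical helper names; it assumes non-overlapping pooling windows (kernel 2, stride 2) and a LayerNorm without the learnable scale and shift.

```python
import math


def avg_pool_k2(values):
    """Average non-overlapping pairs (kernel 2, stride 2 assumed)."""
    return [(values[i] + values[i + 1]) / 2.0 for i in range(0, len(values) - 1, 2)]


def layer_norm(values, eps=1e-5):
    """LayerNorm without the learnable scale and shift."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


h = [1.0, 3.0, 2.0, 6.0]
doubled = [2.0 * v for v in h]

# Pooling is linear: doubling the input doubles the output.
print(avg_pool_k2(h), avg_pool_k2(doubled))  # [2.0, 4.0] [4.0, 8.0]

# LayerNorm is not: doubling the input leaves the output (almost) unchanged,
# because the data-dependent standard deviation doubles as well.
print([round(v, 3) for v in layer_norm(h)])
print([round(v, 3) for v in layer_norm(doubled)])
```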

    The CNN stream (processing X_S)

    The seasonal stream is where most of the modelling work occurs. Seasonal residuals are bursty, locally periodic, and channel-correlated. xPatch handles them with a depthwise-separable CNN:

    • Patching: the input is segmented into patches of length P = 16 with stride S = 8. The number of patches is ⌊(L − P) / S⌋ + 2, matching the PatchTST configuration. With L = 96, this gives ⌊80 / 8⌋ + 2 = 12 patches per channel.
    • Depthwise convolution: kernel size P = 16, stride P = 16, with groups equal to the number of channels N. Each channel receives its own filter aligned to patch boundaries, with no cross-channel mixing at this step.
    • Pointwise convolution: a 1×1 convolution that mixes information across channels.
    • GELU activation: the only major non-linearity in the entire model. The smooth saturating shape of GELU is well suited to spiky residuals.
    • BatchNorm: applied for training stability across batches.
    • Residual connection: the input is added back to the output, which simplifies optimisation and allows the stream to behave approximately as an identity if the seasonal component is near zero.
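
    The patch count in the first bullet can be reproduced with a few lines of plain Python. One reading of the "+ 2" is that ⌊(L − P) / S⌋ + 1 windows fit in the raw series and one more appears once the series is padded at the end by S steps, as the PatchTST reference code does. That padding interpretation, the helper names, and the choice to pad by repeating the last value are assumptions of this sketch.

```python
def num_patches(look_back, patch_len, stride, pad_end=True):
    """Number of sliding windows, plus one if the series is end-padded by `stride`."""
    n = (look_back - patch_len) // stride + 1
    return n + 1 if pad_end else n


def make_patches(series, patch_len=16, stride=8):
    """End-pad by repeating the last value `stride` times, then slice windows."""
    padded = list(series) + [series[-1]] * stride
    starts = range(0, len(padded) - patch_len + 1, stride)
    return [padded[s:s + patch_len] for s in starts]


window = [float(t) for t in range(96)]   # stand-in look-back window, L = 96
patches = make_patches(window)

print(num_patches(96, 16, 8))            # 12
print(len(patches), len(patches[0]))     # 12 16
```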

    The depthwise plus pointwise pattern is the classic MobileNet-style separable convolution. It reduces parameters substantially relative to a full convolution while retaining a similar receptive field. For time series with many channels (Traffic has 862 and Electricity has 321), the reduction is essential, since a full Conv1D would be prohibitively large.
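
    The size of the reduction is easy to quantify. The sketch below counts weights and biases for a 1-D convolution using the standard formula out_channels × (in_channels / groups) × kernel_size + out_channels, and applies it to the Traffic channel count with kernel size 16. The helper name is hypothetical.

```python
def conv1d_params(in_ch, out_ch, kernel, groups=1, bias=True):
    """Parameter count of a 1-D convolution layer."""
    weights = out_ch * (in_ch // groups) * kernel
    return weights + (out_ch if bias else 0)


channels, kernel = 862, 16   # Traffic channel count, patch-length kernel

full = conv1d_params(channels, channels, kernel)
depthwise = conv1d_params(channels, channels, kernel, groups=channels)
pointwise = conv1d_params(channels, channels, 1)
separable = depthwise + pointwise

print(full)                          # 11889566
print(separable)                     # 758560
print(round(full / separable, 1))    # 15.7
```

    On these assumptions the separable pair needs roughly one-sixteenth of the parameters of a full convolution, and almost all of what remains sits in the pointwise channel mixer.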

    Why this division of labour is effective

    An MLP can learn arbitrary linear projections but must allocate capacity to discover local structure. A patch-aligned CNN encodes locality and translation-equivariance directly into the architecture. By passing only the seasonal residual into the CNN, xPatch allows the CNN to concentrate on local patterns, the task it is best suited to, without expending capacity on re-learning the trend. Conversely, the linear stream is not required to model seasonal spikes that would force a compromise.

    This is the same lesson that graph attention networks illustrate in a different domain: the architecture’s inductive biases should align with the structure of the signal being modelled. Attention is a powerful general-purpose mixer, but its generality is not free.

    Combining the two streams

    The outputs of the linear and CNN streams are concatenated and passed through a final linear layer (Equation 12 in the paper) to produce the forecast over horizon T. The combination is intentionally simple. The model is not required to learn a complex gating mechanism; it learns a linear combination of the two specialists’ outputs.

    Tip: For implementations starting from scratch, an effective sanity check is to begin with the linear stream alone and verify that it matches DLinear performance on ETTh1. The CNN stream can then be added, and the gains will become visible on noisier datasets such as Weather and Traffic.

    Training Components: Arctangent Loss, Sigmoid Schedule, and RevIN

    The architecture is only half of the story. The other half is the training recipe, and the paper makes a strong case that some of the gains derive from techniques that any forecasting model can adopt.

    RevIN (Reversible Instance Normalisation)

    Distribution shift is endemic in time series. The mean and variance of a channel during training rarely match those at inference time, particularly in non-stationary domains such as finance, traffic, or weather. RevIN addresses this issue with a simple procedure:

    1. Before the model: subtract the per-instance mean and divide by the per-instance standard deviation, where the instance is a single look-back window.
    2. After the model: multiply by the same standard deviation and add back the same mean, along with learnable affine parameters.

    The model therefore only sees standardised inputs and does not need to memorise the level or scale of any particular channel. The de-normalisation at the output returns the forecast to the original scale. RevIN is now standard equipment in modern forecasting models, and xPatch employs it in the same manner as PatchTST and CARD.
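
    A minimal single-channel sketch of the two RevIN steps follows, in plain Python with hypothetical function names. It omits the learnable affine parameters and simply shows that two windows with different level and scale look the same to the model, and that de-normalisation recovers the original values.

```python
import math


def revin_normalise(window, eps=1e-5):
    """Standardise one look-back window; return the statistics for later reuse."""
    mean = sum(window) / len(window)
    var = sum((v - mean) ** 2 for v in window) / len(window)
    std = math.sqrt(var + eps)
    return [(v - mean) / std for v in window], mean, std


def revin_denormalise(values, mean, std):
    """Undo the standardisation using the statistics saved at the input."""
    return [v * std + mean for v in values]


low = [10.0, 12.0, 14.0]       # same shape, different level and scale
high = [100.0, 120.0, 140.0]

low_norm, low_mean, low_std = revin_normalise(low)
high_norm, high_mean, high_std = revin_normalise(high)

# After normalisation the two windows are numerically almost identical.
print([round(v, 3) for v in low_norm])
print([round(v, 3) for v in high_norm])

# The round trip returns the original scale.
restored = revin_denormalise(high_norm, high_mean, high_std)
```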

    The arctangent loss

    This is one of the more novel components of the paper. CARD popularised a horizon-weighted loss that scales down the error at far-future steps, with weights proportional to i^(−1/2) for horizon index i. The motivation is reasonable, since distant targets are noisier and harder to predict, but these weights shrink quickly, so the far end of the forecast window contributes little to the optimisation.

    xPatch replaces this with a more slowly decaying weight based on the arctangent (Equations 16 and 17):

    ρ_arctan(i) = −arctan(i) + π/4 + 1
    
    L_arctan = (1/T) · Σᵢ ρ_arctan(i) · ||Ŷᵢ − yᵢ||₁

    The motivation for the arctangent function is that it is bounded, monotonic, and smooth. The weight equals 1 at i = 1 and levels off at 1 − π/4 ≈ 0.21 as i grows, instead of shrinking toward zero. Near-term steps still count for more, but no part of the forecast window is effectively ignored, which empirically translates into improved performance on long horizons without degrading performance on shorter ones.

    The paper’s most notable ablation finding is that the arctangent loss helps even when applied to other models. Substituting it into PatchTST or CARD improves accuracy. The loss is therefore a transferable technique that can serve as a free upgrade for an existing forecasting pipeline.
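
    Evaluating the weighting function ρ_arctan directly makes its shape concrete. The sketch below is plain Python for a single channel; it indexes horizons from i = 1, which makes the first weight exactly 1. That indexing is an assumption, since the equations above do not state the range of i.

```python
import math


def rho_arctan(i):
    """Horizon weight from the equation above: -arctan(i) + pi/4 + 1."""
    return -math.atan(i) + math.pi / 4 + 1.0


def arctangent_loss(pred, target):
    """Weighted MAE over one forecast window (lists of length T)."""
    T = len(pred)
    weighted = (
        rho_arctan(i) * abs(p - y)
        for i, (p, y) in enumerate(zip(pred, target), start=1)
    )
    return sum(weighted) / T


for step in (1, 2, 10, 96, 720):
    print(step, round(rho_arctan(step), 3))

# The weights fall from 1 toward the floor 1 - pi/4 but never reach it.
floor = 1.0 - math.pi / 4
```

    A constant error of 1 at every step therefore yields a loss between roughly 0.21 and 1, with early steps contributing the most.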

    Sigmoid learning-rate schedule

    Standard schedules in this literature are step decay (the learning rate is halved every K epochs) or cosine annealing. xPatch introduces a sigmoid-shaped schedule (Equation 23) with a warm-up parameter w. The shape consists of a smooth ramp-up from a low initial value, a flat plateau in the middle, and a gentle ramp-down. Compared with step decay, it avoids the discontinuities that can destabilise training. Compared with cosine annealing, the explicit warm-up provides the optimiser with time to locate a suitable basin before the learning rate becomes high.
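
    One simple way to obtain a ramp-up, plateau, and gentle ramp-down is the difference of two sigmoids, sketched below in plain Python. The functional form and the constants `k`, `s`, `warmup`, and `base_lr` are illustrative assumptions and are not claimed to reproduce Equation 23; the paper should be consulted for the exact schedule.

```python
import math


def _sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_lr(epoch, base_lr=1e-4, warmup=10, k=0.5, s=10.0):
    """Illustrative sigmoid-shaped schedule (not the paper's exact formula).

    `rise` switches on around `warmup`; `fall` is a stretched copy that
    switches on around `s * warmup`, pulling the rate back down slowly.
    """
    rise = _sigmoid(k * (epoch - warmup))
    fall = _sigmoid((k / s) * (epoch - s * warmup))
    return base_lr * (rise - fall)


schedule = [sigmoid_lr(epoch) for epoch in range(101)]

for epoch in (0, 10, 25, 60, 100):
    print(epoch, f"{schedule[epoch]:.2e}")
```

    With these constants the rate starts at zero, climbs through the warm-up epochs, stays close to `base_lr` for a stretch, and then decays smoothly with no step discontinuities.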

    As with the arctangent loss, the paper shows that the sigmoid schedule transfers cleanly to other models. The implication is that learning-rate schedules are often under-tuned in benchmark comparisons. When every model is trained with the same default schedule, a reported architecture win is measured against competitors that were themselves trained suboptimally, so part of the margin may come from the schedule rather than the architecture.

    Compute footprint

    xPatch is trained for 100 epochs on a single NVIDIA Quadro RTX 6000. The configuration corresponds to a single mid-range GPU and a short schedule by current standards. There is no foundation-model pre-training, no distributed setup, and no specialised quantisation. This minimal footprint is part of the paper’s argument: state-of-the-art forecasting does not necessarily require state-of-the-art compute.

    Caution: The arctangent loss spreads weight across the whole forecast window, falling only from 1 at the first step to roughly 0.21 at distant steps. If the downstream application cares mainly about the next-step forecast (for example, real-time anomaly detection on the next minute), the weighting should be shifted further toward shorter horizons, or a custom ρ function should be used. The paper’s choice is well motivated for the standard MSE-on-all-horizons benchmark, but it is not necessarily optimal for every production setting.

    Benchmark Results

    The experimental setup is the standard long-horizon forecasting suite that has dominated the literature since Informer.

    Datasets

    Dataset | Dim | Frequency | Forecast horizons
    ETTh1, ETTh2 | 7 | Hourly | 96, 192, 336, 720
    ETTm1, ETTm2 | 7 | 15 min | 96, 192, 336, 720
    Weather | 21 | 10 min | 96, 192, 336, 720
    Traffic | 862 | Hourly | 96, 192, 336, 720
    Electricity | 321 | Hourly | 96, 192, 336, 720
    Exchange-rate | 8 | Daily | 96, 192, 336, 720
    Solar | 137 | 10 min | 96, 192, 336, 720
    ILI | 7 | Weekly | 24, 36, 48, 60

     

    The look-back window is L = 96 for all datasets except ILI, which uses L = 36. The baselines are the principal models of the past few years: Autoformer, FEDformer, ETSformer, TimesNet, DLinear, RLinear, MICN, PatchTST, iTransformer, TimeMixer, and CARD.

    Headline numbers

    Dataset | Horizon | xPatch MSE | xPatch MAE
    ETTh1 | 96 | 0.428 | 0.419
    Weather | 720 | 0.310 | 0.322

     

    Across all eight datasets and all four horizons, xPatch outperforms CARD, the previous state of the art, by an average of 2.46 percent in MSE and 2.34 percent in MAE. The margin is small but clear, given how saturated these benchmarks have become. Gains of 1 to 3 percent are now considered meaningful in the literature, and such gains are typically obtained at the cost of new attention variants, larger models, or longer training.

    Speed

    While accuracy is the headline result, the speed advantage is equally important. Table 3 of the paper reports per-step training and inference times.

    Model | Training (msec/step) | Inference (msec/step) | Relative speed vs xPatch
    xPatch | 3.099 | 1.303 | 1.0×
    CARD | 14.877 | (not listed) | 4.8× slower (training)

     

    Training is approximately 4.8 times faster than CARD per step. The paper does not provide equivalently precise per-step numbers for PatchTST and DLinear, but the general ordering reported is DLinear < xPatch < PatchTST < CARD in training time. In production settings, where forecasting models may be retrained daily on streaming data, this speed advantage matters more than the marginal MSE gain.

    Figure: Speed vs accuracy. Horizontal axis: training time per step (msec), lower is better. Vertical axis: MSE, lower is better. Plotted points: DLinear (1 msec, 0.50), PatchTST (~7 msec, ~0.45), iTransformer (~10 msec, ~0.46), CARD (15 msec, 0.44), and xPatch (3 msec, 0.43), which is marked as Pareto-optimal. MSE values are illustrative averages across benchmarks; xPatch achieves both lower MSE and faster training than CARD/PatchTST.

    Ablations: What Drives Performance

    Ablation studies indicate whether a paper’s gains are robust or fragile. The ablations reported for xPatch are transparent and informative.

    EMA α sweep

    α | Weather | Traffic | Electricity | Notes
    0.1 | slightly worse | slightly worse | slightly worse | Trend too smooth, leaks structure
    0.3 | best | best | best | Optimal balance for big datasets
    0.5 | close | close | close | Reasonable fallback
    0.7 | worse | worse | worse | Trend tracks too fast
    0.9 | worst | worst | worst | Trend ~= input, decomposition fails

     

    The pattern is clear: 0.3 dominates on the larger datasets. The paper notes that smaller and noisier datasets sometimes favour higher α values, so fixing α = 0.3 for every problem is unwise. The parameter should instead be swept on a held-out validation split.
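
    A sweep of this kind needs very little scaffolding. In the sketch below, `sweep_alpha` is a hypothetical helper and `toy_validation_loss` is a stand-in for "train with this α and return the loss on the held-out split"; its values are made up so that the example runs and do not come from the paper.

```python
def sweep_alpha(validation_loss, candidates=(0.1, 0.3, 0.5, 0.7, 0.9)):
    """Evaluate each candidate alpha and return (best_alpha, all_scores)."""
    scores = {alpha: validation_loss(alpha) for alpha in candidates}
    best = min(scores, key=scores.get)
    return best, scores


def toy_validation_loss(alpha):
    # Stand-in only: a real version would train the model with this alpha
    # and return the error measured on a held-out validation split.
    return 0.4 + (alpha - 0.3) ** 2


best_alpha, scores = sweep_alpha(toy_validation_loss)
print(best_alpha)
```

    The important detail is that the score comes from a validation split, never from the test horizon used to report final numbers.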

    Necessity of both streams

    The paper ablates the removal of each stream. Removing the linear stream (so that the CNN handles both trend and seasonal components) degrades performance. Removing the CNN stream (so that the linear stream attempts to capture seasonality) degrades performance more substantially. The two streams are genuinely complementary, and neither is dispensable.

    Transferability of the arctangent loss

    This is arguably the most important ablation in the paper. When the default training loss in PatchTST or CARD is replaced with the arctangent loss, those models also improve. The loss is therefore a low-cost upgrade for the field. Practitioners operating an existing forecasting pipeline can adopt it by swapping a single loss function, and the paper’s ablation suggests a modest accuracy gain is likely, although the size of the gain will depend on the model and dataset.

    Transferability of the sigmoid schedule

    The same conclusion applies to the sigmoid schedule: it also helps other models. The implication is uncomfortable for the literature. A non-trivial fraction of past “architecture wins” may have been confounded by suboptimal training schedules. xPatch at least isolates how much of its margin derives from the loss and the schedule, as distinct from the dual-stream design itself.

    Key Takeaway: A meaningful share of the gains attributed to xPatch derives from training methods rather than architecture. The honest reading is that xPatch outperforms on multiple dimensions, including better decomposition, better dual-stream design, a better loss, and a better schedule. Practitioners should consider carefully which of these components to adopt independently.

    How to use xPatch (PyTorch sketch)

    The official implementation is available at github.com/stitsyuk/xPatch and follows the structure of standard long-horizon forecasting library scaffolds. The full code includes data loaders, evaluation harnesses, and configurations for each benchmark, but the model itself is compact enough to summarise in a single screen.

    The following is a minimal but faithful PyTorch outline. It is not a drop-in replacement for the official repository, which should be used for benchmarking, but it represents the architecture clearly.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class EMADecomp(nn.Module):
        """Exponential moving-average decomposition (Eq. 2)."""
        def __init__(self, alpha: float = 0.3):
            super().__init__()
            self.alpha = alpha
    
        def forward(self, x):
            # x shape: (B, L, N)  batch, look-back, channels
            B, L, N = x.shape
            trend = torch.zeros_like(x)
            trend[:, 0, :] = x[:, 0, :]
            for t in range(1, L):
                trend[:, t, :] = (
                    self.alpha * x[:, t, :]
                    + (1.0 - self.alpha) * trend[:, t - 1, :]
                )
            seasonal = x - trend
            return trend, seasonal
    
    
    class LinearStream(nn.Module):
        """2 FC + AvgPool + LayerNorm blocks, no activation."""
        def __init__(self, L: int, T: int, hidden: int = 128):
            super().__init__()
            self.fc1 = nn.Linear(L, hidden)
            self.pool1 = nn.AvgPool1d(kernel_size=2, stride=1, padding=1)
            self.ln1 = nn.LayerNorm(hidden + 1)
            self.fc2 = nn.Linear(hidden + 1, hidden)
            self.pool2 = nn.AvgPool1d(kernel_size=2, stride=1, padding=1)
            self.ln2 = nn.LayerNorm(hidden + 1)
            self.proj = nn.Linear(hidden + 1, T)
    
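        def forward(self, x):
            # x: (B, L, N) -> (B, N, L) so each channel's look-back is the feature axis
            x = x.transpose(1, 2)
            # AvgPool1d pools over the last axis. With kernel 2, stride 1, padding 1
            # it maps `hidden` features to `hidden + 1`, matching the LayerNorm sizes.
            h = self.pool1(self.fc1(x))   # (B, N, hidden + 1)
            h = self.ln1(h)
            h = self.pool2(self.fc2(h))   # (B, N, hidden + 1)
            h = self.ln2(h)
            return self.proj(h)  # (B, N, T)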
    
    
    class CNNStream(nn.Module):
        """Depthwise -> pointwise -> GELU -> BN -> linear head (simplified sketch)."""
        def __init__(self, N: int, L: int, T: int,
                     P: int = 16, S: int = 8):
            super().__init__()
            # S is kept for parity with the text; this simplified sketch convolves
            # with stride P and does not build overlapping patches.
            self.P, self.S = P, S
            n_windows = (L - P) // P + 1  # output length of the depthwise conv
            self.depthwise = nn.Conv1d(
                in_channels=N, out_channels=N,
                kernel_size=P, stride=P, groups=N,
            )
            self.pointwise = nn.Conv1d(N, N, kernel_size=1)
            self.bn = nn.BatchNorm1d(N)
            self.proj = nn.Linear(n_windows, T)
    
        def forward(self, x):
            # x: (B, L, N) -> (B, N, L)
            x = x.transpose(1, 2)
            h = self.depthwise(x)      # (B, N, n_windows), one filter per channel
            h = self.pointwise(h)      # 1x1 conv mixes information across channels
            h = F.gelu(h)
            h = self.bn(h)
            # The residual connection described in the text is omitted for brevity.
            return self.proj(h)        # (B, N, T)
    
    
    class XPatch(nn.Module):
        def __init__(self, L: int, T: int, N: int, alpha: float = 0.3):
            super().__init__()
            self.decomp = EMADecomp(alpha)
            self.linear_stream = LinearStream(L, T)
            self.cnn_stream = CNNStream(N, L, T)
            self.fuse = nn.Linear(2 * T, T)
    
        def forward(self, x):
            # RevIN
            mean = x.mean(dim=1, keepdim=True)
            std = x.std(dim=1, keepdim=True) + 1e-5
            x_norm = (x - mean) / std
    
            trend, seasonal = self.decomp(x_norm)
            y_lin = self.linear_stream(trend)        # (B, N, T)
            y_cnn = self.cnn_stream(seasonal)        # (B, N, T)
            y = torch.cat([y_lin, y_cnn], dim=-1)
            y = self.fuse(y).transpose(1, 2)         # (B, T, N)
    
            # de-RevIN
            return y * std + mean
    
    
    def arctangent_loss(pred, target):
        """L_arctan from Eq. 16-17."""
        T = pred.size(1)
        # 1-based horizon index, so that rho(1) = 1 (arctan(1) = pi/4)
        i = torch.arange(1, T + 1, device=pred.device, dtype=torch.float32)
        rho = -torch.atan(i) + torch.pi / 4 + 1.0
        abs_err = (pred - target).abs().mean(dim=-1)  # (B, T)
        return (rho * abs_err).mean()
    

    Several practical notes apply:

    • The Python loop in EMADecomp should be replaced with the vectorised closed-form for a genuine speed-up. The mathematics is presented in Appendix D of the paper, and the official repository implements the vectorised version.
    • The CNN stream’s output projection is sketched in a simplified manner here; the official implementation handles the patching dimensions more carefully.
    • For a clean initial configuration, use L = 96, P = 16, S = 8, α = 0.3, 100 epochs, the sigmoid learning-rate schedule with a warm-up of approximately 10 epochs, and the arctangent loss.

    For applications involving anomaly detection on the same series, the overview of time series anomaly detection models is relevant. Many of the same training techniques (RevIN, patching, decomposition) carry over.

    Hyperparameter reference

    Hyperparameter | Default | When to change
    Look-back L | 96 (36 for ILI) | Increase if your seasonality is longer than 96 steps
    Patch size P | 16 | Should align with your series’ natural local period
    Stride S | 8 | Smaller for more overlap, larger for fewer patches
    EMA α | 0.3 | Sweep {0.1, 0.3, 0.5, 0.7, 0.9} on small/noisy data
    Epochs | 100 | Use early stopping to cut wasted compute
    Loss | Arctangent | Switch to standard MAE if all horizons matter equally

    When to use xPatch versus alternatives

    No single model is appropriate for every problem. xPatch occupies a specific region of the design space: low-latency, accuracy-competitive, supervised, point-forecast, and multivariate. The following framework is useful for selecting an appropriate model.

    Need | Recommended approach | Why
    Fastest training/inference, good accuracy | xPatch | Beats CARD, ~5× faster than CARD per training step
    Foundation model / zero-shot | TimesFM, Chronos, Moirai | Pretrained at scale, generalize across domains without fine-tuning
    Calibrated uncertainty estimates | Gaussian processes | Native posterior variances, principled credible intervals
    Long-context attention reasoning | PatchTST, iTransformer | When channel relationships are essential and context exceeds ~512 steps
    Tabular-style features without temporal structure | XGBoost / LightGBM | When good lag/window features can be engineered, GBMs are difficult to beat on tabular forecasting
    Linear/stationary signal, minimal compute | DLinear, classical ARIMA | If the data is genuinely simple, simpler is better
    High-throughput streaming infra | xPatch + Kafka time-series engine | Low-latency model fits well with streaming pipelines

    For principled tuning of hyperparameters in any of these alternatives, the companion note on Bayesian hyperparameter optimisation is a useful reference.

    Limitations and open questions

    xPatch is a strong paper, but no paper is without weaknesses. The honest limitations are as follows:

    • α is a hyperparameter rather than a learned parameter. A natural extension is to make α differentiable, or even to make it both per-channel and per-timescale. The paper acknowledges this and identifies it as future work.
    • The datasets are relatively small. The largest is Traffic, with 862 channels and approximately 17,000 timesteps. This is small compared with the data on which foundation models such as Chronos and TimesFM are pre-trained. The behaviour of xPatch on substantially larger streams remains untested in the paper.
    • Two streams imply two forward passes. Inference remains fast, but a fused single-pass implementation would be faster still and might be feasible with a careful architectural redesign.
    • The model produces point forecasts only. xPatch produces a single-trajectory forecast without a probabilistic interpretation. For risk-sensitive applications such as finance, energy, and healthcare, quantiles or full distributions are typically required, and xPatch does not provide them natively. A quantile head or a Bayesian wrapper is necessary.
    • Benchmark saturation. The community has acknowledged that ETTh, Weather, and related benchmarks are showing signs of saturation. Gains of 2 to 3 percent may not transfer to messier real-world data with greater drift, missing values, and concept shift. xPatch’s results are the current best on these benchmarks; whether they generalise to, for example, the tick data of a finance trading desk is an empirical question.
    • The paper presents no theoretical analysis. The contribution is empirical. There is no generalisation bound, no convergence proof for the recursion, and no analysis of the loss landscape. This is acceptable for an applied paper but leaves room for follow-up theory.
    Caution: If an application is characterised by heavy concept drift (for example, post-COVID demand forecasting or regime-changing financial markets), benchmark gains do not automatically transfer. Practitioners should evaluate on their own data with a realistic backtest before relying on leaderboard results.

    Implications for the Field

    Considered at a higher level, the broader narrative is more interesting than the architectural details alone:

    • Inductive biases continue to matter. Decomposition (the separation of trend from seasonality) has been valuable since the 1950s, and it remains valuable in 2025. Patching, locality, and dual-specialisation all encode useful priors. Generic attention without such priors is rarely the appropriate choice for time series.
    • Loss functions and learning-rate schedules are underrated. The fact that the arctangent loss and the sigmoid schedule transfer to other models suggests that the field has been comparing architectures under suboptimal training. Future benchmark papers should standardise the training recipe before claiming architectural wins.
    • The Pareto frontier is the appropriate evaluation axis. A model that is 1 percent more accurate but 10 times slower may not be worth deploying. xPatch occupies the region in which accuracy is competitive and speed is meaningfully better, which is the appropriate position for production systems.
    • Foundation models are not the only path forward. The same year that produced TimesFM and Chronos also produced xPatch, which is task-specific, compact, fast, and competitive. Both styles will coexist; the appropriate choice depends on deployment constraints.
    • Self-supervised pre-training remains an open opportunity. xPatch is fully supervised. Whether self-supervised pre-training of the CNN stream, analogous to TS2Vec and related methods, would unlock further gains is an open question. The overview of self-supervised pretraining covers the relevant techniques.

    For a concise reminder of the statistical foundations on which these models rest (independence, the role of variance, the importance of sample size for stable estimators), the explainer on the Central Limit Theorem is relevant. For deployment considerations, the comparison of databases for preprocessed time series reviews the relevant trade-offs.

    Frequently asked questions

    Why does a non-transformer model outperform PatchTST?

    Three factors combine. First, the EMA decomposition provides the model with two cleaner sub-signals rather than a single mixed signal. Second, the dual-stream architecture matches the appropriate tool to each component: a linear stream for the smooth trend and a CNN for the bursty seasonal residual. Third, the arctangent loss and the sigmoid learning-rate schedule provide a training-side improvement. PatchTST employs channel-independent attention over fixed-length patches, but it asks a single stack of attention layers to handle both trend and seasonal components simultaneously. xPatch’s specialisation wins by an average of 2.46 percent in MSE while running approximately 4.8 times faster than CARD.

    Should xPatch or PatchTST be used in production?

    The default choice should be xPatch unless there is a specific reason to prefer PatchTST. xPatch is faster to train, faster to infer, slightly more accurate on the standard benchmarks, and easier to debug because the streams are individually interpretable. PatchTST is preferable if a look-back longer than 96 steps is required and the global receptive field of attention is needed. If the dataset is heavily channel-correlated and cross-channel mixing is essential, note that PatchTST attends within each channel independently, so a channel-mixing model such as iTransformer is the more natural alternative.

    How is the EMA alpha parameter tuned?

    The recommended starting point is α = 0.3, which is optimal for the largest benchmarks in the paper (Weather, Traffic, Electricity). For smaller or noisier datasets, a sweep over {0.1, 0.3, 0.5, 0.7, 0.9} on a held-out validation split is appropriate. A smaller α produces smoother trends, which is suitable when noise dominates. A larger α produces more reactive trends, which is suitable when regime changes are abrupt. The paper deliberately keeps α non-learnable; making it learnable is a reasonable research extension.
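
    The effect of α can be checked on a toy series before running a full sweep. The sketch below assumes the common recursion s_t = α·x_t + (1 − α)·s_{t−1} initialised with the first observation; the helper names `ema` and `roughness` and the sample values are illustrative and are not taken from the paper or its repository.

```python
def ema(series, alpha):
    """Exponential moving average: s_t = alpha * x_t + (1 - alpha) * s_{t-1}."""
    smoothed = [series[0]]  # assumption: initialise with the first observation
    for x in series[1:]:
        smoothed.append(alpha * x + (1 - alpha) * smoothed[-1])
    return smoothed


def roughness(values):
    """Mean absolute step-to-step change; lower means a smoother curve."""
    steps = [abs(b - a) for a, b in zip(values, values[1:])]
    return sum(steps) / len(steps)


# Toy series: a level shift at index 4 with alternating noise on top.
series = [1.0, 1.4, 0.8, 1.3, 5.1, 4.7, 5.4, 4.9, 5.2, 4.8]

for alpha in (0.1, 0.3, 0.5, 0.7, 0.9):
    trend = ema(series, alpha)
    print(f"alpha={alpha}: roughness={roughness(trend):.3f}, "
          f"final trend value={trend[-1]:.2f}")
```

    Small α gives a smooth trend that lags the level shift; large α tracks the shift quickly but passes more noise into the trend component.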

    What is the arctangent loss and why does it help?

    The arctangent loss replaces standard MSE or MAE with a horizon-weighted MAE in which the weights follow ρ(i) = −arctan(i) + π/4 + 1. The weights start at ρ(1) = 1 and decay slowly toward a floor of 1 − π/4 ≈ 0.21, which is a gentler down-weighting of distant horizons than the decay used by CARD. Far-horizon errors are therefore discounted but never removed from the gradient, and the learning signal stays more uniform across forecast horizons. Empirically, the loss benefits not only xPatch but also other models such as PatchTST and CARD, which makes it a transferable upgrade for any forecasting pipeline.
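
    To see how the weights behave, the formula can be evaluated at a few horizons. This is a minimal sketch that assumes a 1-based horizon index i; the function name `rho_arctan` and the chosen horizons are illustrative.

```python
import math


def rho_arctan(i):
    """Horizon weight rho(i) = -arctan(i) + pi/4 + 1."""
    return -math.atan(i) + math.pi / 4 + 1.0


# Limit of rho(i) as the horizon grows without bound (arctan tends to pi/2).
floor = 1.0 - math.pi / 4

for i in (1, 2, 8, 24, 96, 720):
    print(f"i={i:4d}  rho={rho_arctan(i):.3f}")
print(f"floor = {floor:.3f}")
```

    The weights fall quickly over the first few steps and then flatten, so long horizons keep a non-trivial share of the loss.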

    Does xPatch support multivariate forecasting?

    Yes. The architecture is designed for multivariate inputs. The depthwise convolution in the CNN stream operates per channel (groups = N), and the pointwise convolution mixes information across channels. The linear stream processes each channel through the same weights while preserving the channel dimension. The paper evaluates on datasets with up to 862 channels (Traffic) without modification.

    This article is for informational and educational purposes only. It summarizes a publicly available academic paper and is not a substitute for reading the original. Implementation details should be verified against the official repository before production use.

  • Anomaly Detection Metrics Explained: AUROC, AUPRC, F1, Precision, Recall, FAR

    This guide examines the evaluation metrics that are appropriate for anomaly detection systems, in which the positive class is by definition rare. When 99.9 percent of transactions are legitimate, a model that flags every record as “normal” attains 99.9 percent accuracy while delivering no operational value. The choice of evaluation metric is therefore one of the most consequential decisions in an anomaly detection project.

    The discussion proceeds through the metrics that are relevant for this task, from the basic measures (Precision and Recall) to threshold-independent ranking metrics (AUROC and AUPRC) and the specialised time-series metrics (PA-F1 and VUS). For each metric the formula, the trade-offs, and a full Python implementation are presented so that the material can be applied directly.

    Summary

    What this post covers: A complete reference for selecting and computing anomaly detection metrics, including Precision, Recall, F1, FAR, MCC, AUROC, AUPRC, the time-series variants, and Top-K measures. The discussion presents the formulas, the trade-offs, and the Python implementations for ML engineers building rare-event detectors in fraud, intrusion, defects, and biometrics.

    Key insights:

    • Accuracy is degenerate when anomalies are rare. A constant “normal” predictor can score 99.9 percent, so the first decision in any anomaly-detection project is to discard accuracy as the headline metric.
    • For severely imbalanced data (anomalies below 1 percent), AUPRC is the primary ranking metric and AUROC is secondary. AUROC can appear misleadingly high on heavily imbalanced data because the large TN count dominates the FPR denominator (FP + TN), so even thousands of false positives barely move the curve.
    • Different stakeholders require different metrics for the same model. Engineers focus on AUROC and AUPRC, operations focuses on FAR and alert volume, and finance focuses on dollar-weighted recall. A single number is therefore always a stakeholder choice in disguise.
    • Standard point-wise F1 fails for time-series anomalies because real anomalies are contiguous events, not isolated samples. Range-based F1, VUS, or NAB Score should be used instead.
    • Most production teams should report a small bundle: AUPRC, Precision@K, Recall, and FAR. This combination covers model quality, operational alert volume, miss rate, and false-alarm rate together.

    Main topics: why anomaly metrics matter, the confusion matrix foundation, threshold-dependent metrics, threshold-independent metrics, a decision framework for picking metrics, time-series-specific metrics, Top-K ranking metrics, Python implementations, threshold selection for production, common pitfalls, and domain reporting templates.

    Why Anomaly Detection Metrics Matter and Why Accuracy Does Not

    Consider a scenario in which a team builds a fraud detector and reports that it attains 99.9 percent accuracy. The result appears impressive. When a stakeholder asks how many actual fraud cases the system caught in the previous quarter, however, the answer may be none. The model achieves 99.9 percent accuracy by predicting “not fraud” for every transaction, because the base rate of fraud at a typical payment processor is approximately 0.1 percent. The model is in effect a constant, the accuracy figure is real, and the system is operationally worthless.

    This is the foundational point of anomaly detection: the positive class, namely the anomaly, is rare and sometimes extremely rare. Network intrusions, manufacturing defects, credit-card fraud, and rare diseases all have base rates between approximately 0.01 percent and 5 percent. When the negative class dominates, accuracy becomes a degenerate metric, and a model that predicts “normal” for every input will appear excellent.

    This is the imbalance problem. A second issue is equally important: cost asymmetry. Missing a true anomaly (a false negative) almost always costs more than flagging a legitimate event by mistake (a false positive). A missed credit-card fraud may cost $5,000, while an unnecessary alert costs perhaps 30 seconds of an analyst’s time. These errors are not symmetric, and the chosen metric must reflect the asymmetry.

    Different stakeholders are concerned with different metrics for the same model:

    • The ML engineer requires AUROC and AUPRC for comparing model architectures.
    • The product manager requires Precision@K because the user interface shows the top 50 alerts per day.
    • The operations lead requires False Alarm Rate (FAR) and Mean Time To Detect (MTTD) because analysts must triage every alert.
    • The CFO requires dollar-weighted recall, namely the fraction of fraud value caught, rather than the count of incidents.

    The selection of a single number to optimise implicitly entails a stakeholder choice. The appropriate response is to report a small set of complementary metrics so that each audience receives the information that it requires.

    Key Takeaway: Accuracy is almost never the appropriate metric for anomaly detection. The base rate is too low, and the cost of false negatives is too high. Precision, Recall, F1, AUPRC, and FAR should be used in combinations selected according to the operational objective.

    The Confusion Matrix Foundation

    Every metric in this guide is built from four numbers, namely the cells of the confusion matrix. By convention, in anomaly detection the anomaly is the positive class and the normal point is the negative class.

    Term | Definition | Fraud Example
    True Positive (TP) | Model predicts anomaly, truly is anomaly | Caught a fraudulent transaction
    False Positive (FP) | Model predicts anomaly, truly is normal | Flagged a legitimate purchase
    True Negative (TN) | Model predicts normal, truly is normal | Correctly cleared a normal payment
    False Negative (FN) | Model predicts normal, truly is anomaly | Missed a fraudulent transaction

    The following is a worked example. Consider 10,000 credit-card transactions in which 100 are fraudulent (a 1 percent anomaly rate) and the model produces the predictions shown below:

    Figure: Confusion matrix for fraud detection (1% anomaly rate; total = 10,000, true anomalies = 100, predicted anomalies = 125).

    Actual \ Predicted | Anomaly (positive) | Normal (negative)
    Anomaly | TP = 95 (caught fraud, of 100 frauds) | FN = 5 (missed fraud, slipped past)
    Normal | FP = 30 (false alarm, of 9,900 normals) | TN = 9,870 (correctly cleared normal traffic)

    Derived metrics: Precision = 95/(95+30) = 0.760; Recall = 95/(95+5) = 0.950; F1 = 2·P·R/(P+R) = 0.844; FAR = 30/(30+9870) = 0.0030; Accuracy = 99.65% (misleading). Accuracy alone hides the fact that the model missed 5 frauds and raised 30 false alarms.

    From the cells above, every metric discussed in this guide is derivable. One observation is important: the accuracy for this model is (95 + 9870) / 10000 = 99.65 percent, which sounds excellent. A constant “always normal” model, however, would score 99.0 percent. The improvement from a real model is therefore only 0.65 percentage points. A comparison of two models on accuracy alone yields almost no useful information.
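
    The derivation from the four cells can be reproduced in a few lines. This is a minimal sketch using the counts from the worked example; the function name `confusion_metrics` is illustrative, and returning 0.0 when a denominator is zero is an assumed convention.

```python
def confusion_metrics(tp, fp, tn, fn):
    """Derive the headline metrics from the four confusion-matrix cells."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0  # assumed: 0.0 when nothing is flagged
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    far = fp / (fp + tn) if (fp + tn) else 0.0
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    return {"precision": precision, "recall": recall, "f1": f1,
            "far": far, "accuracy": accuracy}


model = confusion_metrics(tp=95, fp=30, tn=9870, fn=5)
always_normal = confusion_metrics(tp=0, fp=0, tn=9900, fn=100)

for name, metrics in (("model", model), ("always normal", always_normal)):
    print(name, {k: round(v, 4) for k, v in metrics.items()})
```

    The two accuracy values differ by only 0.65 percentage points, while recall separates the two detectors completely (0.95 versus 0).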

    The fundamental trade-off in any threshold-based detector is as follows. Lowering the threshold catches more anomalies (TP increases) but also flags more normals (FP increases). Raising the threshold reduces false alarms (FP decreases) but misses more anomalies (FN increases). Every metric in this guide either fixes one threshold and reports performance at that point, or sweeps over all thresholds and summarises the trade-off.
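
    The trade-off is easy to observe on a handful of scored points. The sketch below uses made-up scores and labels (1 = anomaly) and flags every point whose score is at or above the threshold; the helper name `counts_at` is illustrative.

```python
def counts_at(scores, labels, threshold):
    """Return (TP, FP, FN) when every score >= threshold is flagged."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
    fn = sum(labels) - tp
    return tp, fp, fn


# Toy data: four anomalies among ten points.
scores = [0.95, 0.90, 0.80, 0.70, 0.60, 0.40, 0.30, 0.20, 0.10, 0.05]
labels = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]

for threshold in (0.85, 0.65, 0.25):
    tp, fp, fn = counts_at(scores, labels, threshold)
    print(f"threshold={threshold:.2f}  TP={tp}  FP={fp}  FN={fn}")
```

    Each step down in threshold recovers one more anomaly and admits more false positives.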

    Threshold-Dependent Metrics: Precision, Recall, F1, FAR, MCC

    These metrics require commitment to a single decision threshold (typically 0.5 for probabilities, or a calibrated value for anomaly scores). Once the threshold is fixed, the four-cell confusion matrix can be computed and the metrics below derived.

    Precision: The Purity of Alerts

    Precision = TP / (TP + FP). The metric answers the question: of everything flagged as anomalous, how many actually were anomalous? In the worked example, Precision = 95/125 = 0.76, which indicates that 76 percent of the alerts were genuine fraud and 24 percent were false alarms.

    Precision matters most in the following contexts:

    • Alert fatigue. If a SOC analyst receives 100 alerts per day of which 90 are incorrect, the analyst will cease to trust the system. The corresponding precision is 0.10.
    • Costly interventions. If acting on an alert involves freezing a customer’s account, the alert must be correct.
    • Limited human review capacity. When only the top 50 cases can be investigated, the investigated cases must be of high quality.

    Recall (Sensitivity, True Positive Rate): The Proportion Caught

    Recall = TP / (TP + FN). The metric answers: of all true anomalies, how many were caught? In the worked example, Recall = 95/100 = 0.95, a 95 percent catch rate.

    Recall matters most in the following contexts:

    • Catastrophic miss costs. Cancer screening, cybersecurity intrusions, and aircraft engine faults are domains in which missing an event is unacceptable.
    • Rare but serious anomalies. When the cost of a false negative greatly exceeds the cost of a false positive.
    • Compliance and regulatory contexts. Anti-money-laundering regulations effectively mandate high recall.

    F1 Score: A Balanced Measure

    F1 = 2·P·R / (P + R) is the harmonic mean of Precision and Recall, constructed so that a low score in either component reduces F1 substantially. In the worked example, F1 = 2 · (0.76)(0.95) / (0.76 + 0.95) = 0.844.

    The harmonic mean is preferred to the arithmetic mean because, for example, Precision = 1.0 and Recall = 0.01 (only one true anomaly flagged out of 100) should not average to 0.505, which would be misleading. The harmonic mean gives 0.0198, which more accurately reflects the model’s poor performance.

    For asymmetric costs, the F-beta measure should be used:

    Fβ = (1 + β²) · P · R / (β² · P + R)

    • β = 1 produces the standard F1, with equal weight on precision and recall.
    • β = 2 produces F2, in which recall is weighted twice as heavily as precision (suitable for medical or security applications).
    • β = 0.5 produces F0.5, in which precision is weighted twice as heavily as recall (suitable for alert-fatigue contexts).
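
    The F-beta formula above can be applied directly to the worked example. This is a minimal sketch; the function name `f_beta` and the zero-denominator convention are assumptions.

```python
def f_beta(precision, recall, beta=1.0):
    """F-beta = (1 + beta^2) * P * R / (beta^2 * P + R)."""
    b2 = beta ** 2
    denom = b2 * precision + recall
    return (1 + b2) * precision * recall / denom if denom else 0.0


p, r = 0.76, 0.95  # precision and recall from the worked example

for beta in (0.5, 1.0, 2.0):
    print(f"beta={beta:g}  F-beta={f_beta(p, r, beta):.3f}")
```

    Because recall exceeds precision in this example, the recall-weighted F2 is higher than F1 and the precision-weighted F0.5 is lower.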

    Specificity (TNR) and False Alarm Rate (FAR/FPR)

    Specificity = TN / (TN + FP) is the fraction of true normals correctly left alone. FAR (= FPR = 1 − Specificity) is the fraction of normals that have been flagged. In the worked example, FAR = 30/9900 = 0.30 percent.

    FAR is the metric that the operations team typically quotes. When 1 million events are processed per day at FAR = 0.5 percent, the result is 5,000 false alarms per day, which is operationally unworkable. Most operational systems target FAR below 0.1 percent or even 0.01 percent and accept the resulting recall.
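
    The alert-volume arithmetic is worth scripting when negotiating a FAR target. The sketch below assumes that nearly all events are normal (controlled by an illustrative `anomaly_rate` parameter); the function name is hypothetical.

```python
def daily_false_alarms(events_per_day, far, anomaly_rate=0.0):
    """Expected false alarms per day: FAR applied to the normal events."""
    normals = events_per_day * (1 - anomaly_rate)
    return normals * far


for far in (0.005, 0.001, 0.0001):
    alarms = daily_false_alarms(1_000_000, far)
    print(f"FAR={far:.2%}  false alarms per day={alarms:,.0f}")
```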

    False Reject Rate (FRR)

    FRR = FN / (FN + TP) = 1 − Recall. This is biometrics terminology: in face recognition or fingerprint authentication, FRR is the fraction of legitimate users incorrectly rejected. The “False Acceptance Rate” in biometrics is identical to FAR or FPR in this context.

    Matthews Correlation Coefficient (MCC)

    MCC = (TP·TN − FP·FN) / √((TP+FP)(TP+FN)(TN+FP)(TN+FN))

    The range is [−1, +1]. A value of +1 indicates perfect classification, 0 corresponds to random classification, and −1 indicates inverted classification. Unlike F1, MCC uses all four cells of the confusion matrix and remains informative even under severe imbalance. It is particularly useful when a single, balanced number that is not deceived by a majority-class predictor is required.
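
    A short sketch of the formula, applied to the worked example and to the constant "always normal" predictor. Returning 0.0 when the denominator vanishes is a common convention and is an assumption here; the function name `mcc` is illustrative.

```python
import math


def mcc(tp, fp, tn, fn):
    """Matthews Correlation Coefficient from the four confusion-matrix cells."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        return 0.0  # assumed convention when a row or column of the matrix is empty
    return (tp * tn - fp * fn) / denom


print(f"worked example: {mcc(tp=95, fp=30, tn=9870, fn=5):.3f}")
print(f"always normal:  {mcc(tp=0, fp=0, tn=9900, fn=100):.3f}")
```

    The constant predictor that scores 99 percent accuracy receives an MCC of 0, which is the behaviour the paragraph above describes.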

    Balanced Accuracy

    Balanced Accuracy = (Sensitivity + Specificity) / 2 is the simple average of the per-class accuracies. The “always normal” model achieves 50 percent balanced accuracy regardless of the imbalance. This metric is appropriate when an accuracy-like figure is required that does not reward majority-class prediction.
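
    The same comparison for balanced accuracy, as a minimal sketch with the worked-example counts (the function name is illustrative):

```python
def balanced_accuracy(tp, fp, tn, fn):
    """Average of the per-class recalls: (sensitivity + specificity) / 2."""
    sensitivity = tp / (tp + fn)  # recall on the anomaly class
    specificity = tn / (tn + fp)  # recall on the normal class
    return (sensitivity + specificity) / 2


print(f"worked example: {balanced_accuracy(tp=95, fp=30, tn=9870, fn=5):.4f}")
print(f"always normal:  {balanced_accuracy(tp=0, fp=0, tn=9900, fn=100):.4f}")
```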

    Metric | Formula | Range | When to Use
    Precision | TP / (TP + FP) | [0, 1] | Alert fatigue, costly interventions
    Recall (TPR, Sensitivity) | TP / (TP + FN) | [0, 1] | Catastrophic miss costs, security, medical
    F1 | 2PR / (P + R) | [0, 1] | Single threshold, balanced trade-off
    Fβ | (1+β²)PR / (β²P+R) | [0, 1] | Asymmetric costs (β>1: recall, β<1: precision)
    Specificity (TNR) | TN / (TN + FP) | [0, 1] | Medical screening (avoid false positives)
    FAR (FPR) | FP / (FP + TN) | [0, 1] | Operations, alert volume control
    FRR (FNR) | FN / (FN + TP) | [0, 1] | Biometrics
    MCC | see formula above | [−1, 1] | Balanced single number for imbalanced data
    Balanced Accuracy | (TPR + TNR) / 2 | [0, 1] | Accuracy-like, imbalance-aware
    AUROC | ∫TPR d(FPR) | [0, 1] | Threshold-free comparison, mild imbalance
    AUPRC (AP) | ∫P d(R) | [0, 1] | Severe imbalance; preferred over AUROC

    Threshold-Independent Metrics: AUROC, AUPRC, DET

    The metrics above all assume that a threshold has been chosen. During model development, however, a single number that summarises the model’s quality across all possible thresholds is usually required. Ranking metrics serve this purpose.

    ROC Curve and AUROC

    The Receiver Operating Characteristic (ROC) curve plots TPR (on the y-axis) against FPR (on the x-axis) as the threshold varies. Each point on the curve corresponds to a different decision threshold. The area under this curve, AUROC, has a useful probabilistic interpretation:

    AUROC = P(score(positive) > score(negative))

    If one anomaly and one normal point are drawn at random, AUROC is the probability that the model scores the anomaly higher. A value of 0.5 corresponds to random guessing, 1.0 corresponds to perfect ranking, and 0.95 indicates that 95 percent of randomly chosen pairs are correctly ordered.

    AUROC has useful properties: it is threshold-independent, it is scale-invariant (only the rank order of scores matters), and the random baseline is always exactly 0.5 regardless of class balance. The last property is also its weakness.
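
    The probabilistic interpretation translates directly into code: count the fraction of (anomaly, normal) pairs in which the anomaly receives the higher score. The sketch below is a brute-force illustration on toy data, with ties counted as half a win; the function name `auroc_pairwise` is illustrative, and the quadratic loop is only suitable for small samples.

```python
def auroc_pairwise(scores, labels):
    """AUROC as the fraction of (anomaly, normal) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.2]
labels = [1, 0, 1, 0, 0, 0]

print(auroc_pairwise(scores, labels))
# Scale invariance: any order-preserving transform of the scores gives the same value.
print(auroc_pairwise([s ** 3 for s in scores], labels))
```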

    Situations in Which AUROC Misleads

    Consider the following scenario. A dataset of 1 million transactions includes 1,000 fraudulent records (a 0.1 percent rate). The model attains AUROC = 0.97, which sounds impressive. The operational usability is more sobering: at the threshold that produces 1,000 alerts, the model may catch 600 frauds and raise 400 false positives, yielding Precision = 60 percent and Recall = 60 percent. The model still misses 400 frauds, and 40 percent of alerts are false. AUROC = 0.97 has therefore conveyed an impression that the operational reality does not deliver.

    The reason is that AUROC averages TPR over the full FPR range from 0 to 1. In production, however, only the range below approximately 1 percent FPR is of practical interest. Most of the AUROC area is contributed by regions in which the system will never operate. Under severe imbalance, even a sub-1 percent FPR generates substantial numbers of false positives because the negative class is very large.
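
    The effect of the base rate can be made explicit. From the definitions, precision = TPR·π / (TPR·π + FPR·(1 − π)), where π is the anomaly rate. The sketch below evaluates this identity; the function name `precision_at` and the operating points other than the 600/400 example are illustrative.

```python
def precision_at(tpr, fpr, prevalence):
    """Precision implied by an ROC operating point and the anomaly rate."""
    tp = tpr * prevalence          # expected true positives per event
    fp = fpr * (1 - prevalence)    # expected false positives per event
    return tp / (tp + fp) if (tp + fp) else 0.0


# Same ROC point (TPR = 0.90, FPR = 1%), different base rates.
for prevalence in (0.5, 0.05, 0.001):
    print(f"anomaly rate={prevalence:.1%}  "
          f"precision={precision_at(0.90, 0.01, prevalence):.3f}")

# The scenario above: 600 of 1,000 frauds caught with 400 false positives.
print(f"{precision_at(0.6, 400 / 999_000, 0.001):.2f}")
```

    An operating point that looks excellent on the ROC curve yields low precision once the anomaly rate drops to 0.1 percent.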

    Precision-Recall Curve and AUPRC

    The PR curve plots Precision (on the y-axis) against Recall (on the x-axis) as the threshold varies. The area under this curve, AUPRC, also referred to as Average Precision (AP), is considerably more informative for imbalanced data. Saito and Rehmsmeier (2015) demonstrated empirically that PR curves provide a more informative picture than ROC curves when class imbalance is severe.

    The random baseline for AUPRC equals the positive-class fraction. If anomalies constitute 1 percent of the data, a coin-flip detector attains AUPRC of approximately 0.01. Exceeding this baseline by a substantial margin is considerably more demanding than exceeding AUROC’s 0.5 baseline.
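
    The baseline claim can be checked empirically. The sketch below computes Average Precision as the mean of the precision values measured at each true anomaly in the score ranking, then applies it to random scores at a 1 percent anomaly rate. The function name `average_precision`, the seed, and the sample size are illustrative, and tied scores are not given special treatment.

```python
import random


def average_precision(scores, labels):
    """Mean of precision@rank taken at the rank of each true anomaly."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0


rng = random.Random(0)
labels = [1] * 100 + [0] * 9900                 # 1 percent anomaly rate
random_scores = [rng.random() for _ in labels]  # a detector with no signal
random_ap = average_precision(random_scores, labels)

print(f"positive fraction = {sum(labels) / len(labels):.3f}")
print(f"AP of random scores = {random_ap:.3f}")
```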

    The following figure presents the canonical illustration of the same model evaluated by both curves on a severely imbalanced dataset.

    Figure: Same model, two stories: ROC vs PR (1% anomaly rate). Left panel: ROC curve (True Positive Rate against False Positive Rate) with the random baseline, AUROC = 0.95 (looks great). Right panel: Precision-Recall curve (Precision against Recall) with the random baseline at 0.01, AUPRC = 0.42 (much less impressive). Both panels show the same model on the same data; AUROC is inflated by the large negative class.

    The two curves describe the same model. AUROC = 0.95 suggests a top-tier detector, while AUPRC = 0.42 indicates that the model is adequate but will produce many false positives in production. The PR curve is closer to operational reality.

    Caution: Both AUROC and AUPRC should be reported for imbalanced anomaly detection. Reporting only AUROC for a 0.1 percent anomaly task is, at best, misleading and, at worst, deceptive.

    Detection Error Tradeoff (DET) Curve

    The DET curve is widely used in biometrics and speaker recognition. It plots FAR (on the x-axis) against FRR (on the y-axis), with both axes on a probit (normal-deviate) scale. This transformation stretches the small-error region and facilitates comparison of near-perfect detectors. The Equal Error Rate (EER), the point at which FAR equals FRR, is a single-number summary commonly quoted in this domain.
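
    The EER can be approximated by sweeping the observed scores as candidate thresholds and picking the one at which FAR and FRR are closest. This is a minimal sketch on toy data; the function names are illustrative, and production tooling typically interpolates between thresholds rather than taking the nearest one.

```python
def far_frr(scores, labels, threshold):
    """FAR and FRR when every score >= threshold is treated as positive."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    far = sum(s >= threshold for s in neg) / len(neg)
    frr = sum(s < threshold for s in pos) / len(pos)
    return far, frr


def equal_error_rate(scores, labels):
    """Approximate EER: the threshold minimising |FAR - FRR|, and the rate there."""
    def gap(threshold):
        far, frr = far_frr(scores, labels, threshold)
        return abs(far - frr)

    best = min(sorted(set(scores)), key=gap)
    far, frr = far_frr(scores, labels, best)
    return (far + frr) / 2, best


scores = [0.9, 0.8, 0.7, 0.6, 0.55, 0.5, 0.4, 0.3, 0.2, 0.1]
labels = [1, 1, 1, 0, 1, 0, 1, 0, 0, 0]

eer, threshold = equal_error_rate(scores, labels)
print(f"EER = {eer:.2f} at threshold {threshold}")
```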

    When to Use Which Metric: A Decision Framework

    If only one decision aid is to be retained from this article, the following table should be used:

    Situation | Recommended Metric(s)
    Severe imbalance (anomalies < 1%) | AUPRC (primary), AUROC (secondary)
    Need a single threshold for production | F1 (or F-beta if asymmetric costs)
    Operations team cares about alert volume | FAR + Recall, or Precision@K
    Cost-sensitive (FN ≫ FP) | Recall, F2, cost-weighted score
    Cost-sensitive (FP ≫ FN) | Precision, F0.5
    Model selection across architectures | AUROC for general comparison; AUPRC if imbalanced
    Reporting to non-technical stakeholders | Precision@K, Recall@K, dollar-weighted recall
    Time-series anomaly detection | Range-based F1, VUS, NAB Score
    Biometrics / authentication | EER, DET curve, FAR @ fixed FRR

    Most production teams report a small bundle of metrics: AUPRC, Precision@K, Recall, and FAR. This combination covers model quality, operational alert volume, miss rate, and false-alarm rate, and is sufficient for useful discussion across stakeholder groups.

    Time-Series-Specific Metrics

    Time-series anomaly detection is the domain in which most standard metrics fail. The central issue is that anomalies are typically events, namely contiguous segments of points rather than isolated samples. If a real anomaly lasts from t = 100 to t = 120 (21 timesteps) and a model detects it at t = 103 only, has the model detected the event? Standard point F1 records “1 TP, 20 FN”, which yields a recall of 1/21 = 4.8 percent. Operationally, however, the event has been caught, whereas the point-wise score suggests an almost complete miss.

    Several alternative metrics have been proposed. None is fully satisfactory, and the appropriate choice remains a subject of active debate. For a more detailed survey of the models that produce these scores, see the companion guide on time-series anomaly detection models.

    Point-Adjusted (PA) F1

    Proposed in early time-series benchmarks (Xu et al., 2018), Point-Adjusted F1 specifies that if at least one point inside a true anomaly segment is detected, the entire segment is marked as detected. This adjustment substantially addresses the miss-by-one-point problem but it inflates scores in misleading ways. Kim et al. (2022) showed that even random scores can achieve PA-F1 above 0.9 on common benchmarks. PA-F1 should therefore be used with considerable caution and never as the sole metric.
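
    The adjustment itself is only a few lines, which makes its generosity easy to see. The sketch below applies it to the example from the start of this section (anomaly from t = 100 to t = 120, a single detection at t = 103). The function names are illustrative, and the code is a plain reading of the rule described above rather than a reference implementation from any benchmark.

```python
def point_adjust(pred, truth):
    """Mark a whole true segment as detected if any point inside it was flagged."""
    adjusted = list(pred)
    t, n = 0, len(truth)
    while t < n:
        if truth[t] == 1:
            end = t
            while end < n and truth[end] == 1:
                end += 1  # [t, end) is one contiguous true-anomaly segment
            if any(pred[t:end]):
                adjusted[t:end] = [1] * (end - t)
            t = end
        else:
            t += 1
    return adjusted


def recall(pred, truth):
    """Point-wise recall: flagged anomaly points over all anomaly points."""
    tp = sum(1 for p, y in zip(pred, truth) if p == 1 and y == 1)
    return tp / sum(truth)


truth = [1 if 100 <= t <= 120 else 0 for t in range(200)]
pred = [1 if t == 103 else 0 for t in range(200)]

print(f"point-wise recall     = {recall(pred, truth):.3f}")
print(f"point-adjusted recall = {recall(point_adjust(pred, truth), truth):.3f}")
```

    A single flagged point lifts recall from 1/21 to 1.0, which is why a detector that fires often enough at random can hit most long segments and score well under this protocol.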

    Range-Based Precision and Recall (Tatbul et al., 2018)

    The seminal paper by Tatbul et al. introduced a parametric framework for range-based recall and precision. Each detection range overlapping a real anomaly range earns partial credit, with adjustable parameters governing the reward for partial overlap (existence, cardinality, or size), the bias toward early or late detection, and the penalty for fragmentation. The framework is principled, configurable, and widely cited, but its parameters require careful selection for each use case.

    NAB Score (Numenta Anomaly Benchmark)

    This metric is designed for streaming detection. Each true anomaly segment is associated with a detection window. Points inside the window earn weighted positive credit (with greater credit for earlier detection), while points outside the window earn weighted negative credit. The result is normalised so that a perfect detector scores 100 and a “no detection” baseline scores 0. NAB is opinionated, since it explicitly rewards early detection, which makes it appropriate for streaming applications and inappropriate for retrospective analysis.

    VUS (Volume Under the Surface, Paparrizos et al., 2022)

    VUS is a range-aware extension of AUROC and AUPRC. Rather than computing area under a 2D curve, VUS computes volume under a 3D surface in which the third dimension is the detection-tolerance buffer. The result is a smooth, parameter-free range-aware metric. VUS-PR is currently among the most defensible single-number summaries for time-series anomaly detection benchmarks.

    Affiliation-Based Metrics (Huet et al., 2022)

    This metric defines a continuous “affiliation” between predicted and true segments based on temporal distance, with statistical normalisation that makes results comparable across datasets. It is more principled than PA-F1 but less widely supported by tooling.

    Metric | Range-Aware? | Threshold-Free? | Notes
    Point F1 | No | No | Penalizes brief detection lag harshly
    Point-Adjusted F1 | Partially | No | Inflates scores; controversial
    Range-Based F1 (Tatbul) | Yes | No | Configurable; needs parameters per use case
    NAB Score | Yes | No | Rewards early detection; for streaming
    VUS-ROC / VUS-PR | Yes | Yes | Modern, parameter-free, recommended
    Affiliation Metrics | Yes | No | Statistical normalization; less tooled

    Tip: For new time-series benchmarks, VUS-PR and range-based F1 with documented parameters should be reported. Reliance on PA-F1 alone should be avoided, since recent literature has shown that it can be gamed by random scores.

    Top-K Metrics for Ranking

    In many production environments, the relevant property is not binary classification quality but ranking quality at the top of the list. A SOC analyst reviews the top 50 alerts per shift, and a fraud team escalates the top 100 highest-risk transactions per day. For such contexts, top-K metrics are more appropriate.

    • Precision@K: the fraction of the top K most anomalous predictions that correspond to true anomalies. The measure is concrete and operationally meaningful.
    • Recall@K: the fraction of all true anomalies that appear in the top K. The measure is useful when a fixed review budget is in place.
    • Mean Average Precision (MAP@K): the mean, across queries or evaluation periods, of the average precision computed over the top K positions. It rewards placing true anomalies nearer the top of the list.
    • Lift@K: Precision@K divided by the base rate. A lift of 50 indicates that alerts in the top K are 50 times more likely to be anomalies than random samples.

    Top-K metrics require K to be fixed, typically by the available human review capacity. They are less useful for academic comparisons, because different K values produce different rankings, but they are essential for production health monitoring.
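
    The definitions above translate directly into code. The following is a minimal sketch using NumPy only; the helper name top_k_metrics and the ten-point demo arrays are hypothetical and serve purely as illustration.

```python
import numpy as np

def top_k_metrics(y_true, y_score, k):
    """Return (Precision@K, Recall@K, Lift@K) for binary labels and scores."""
    y_true = np.asarray(y_true).astype(int)
    y_score = np.asarray(y_score, dtype=float)
    # Indices of the k highest scores; a stable sort keeps ties deterministic.
    top = np.argsort(-y_score, kind="stable")[:k]
    hits = int(y_true[top].sum())
    n_pos = int(y_true.sum())
    base_rate = n_pos / len(y_true)
    precision_at_k = hits / k
    recall_at_k = hits / n_pos if n_pos else 0.0
    lift_at_k = precision_at_k / base_rate if base_rate else float("nan")
    return precision_at_k, recall_at_k, lift_at_k

# Hypothetical demo: 3 anomalies among 10 points, review budget K = 3.
y_true_demo = np.array([0, 1, 0, 0, 1, 0, 0, 0, 1, 0])
y_score_demo = np.array([0.1, 0.9, 0.2, 0.8, 0.7, 0.3, 0.1, 0.2, 0.4, 0.05])

prec_k, rec_k, lift_k = top_k_metrics(y_true_demo, y_score_demo, k=3)
print(f"Precision@3 = {prec_k:.3f}, Recall@3 = {rec_k:.3f}, Lift@3 = {lift_k:.2f}")
```

    In the demo, two of the three highest-scored points are true anomalies, so Precision@3 and Recall@3 both equal 2/3, and the lift is that precision divided by the 30 percent base rate.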

    Practical Implementation in Python

    The following section presents the implementations. The discussion proceeds from the confusion matrix to bootstrapped AUROC confidence intervals, providing both scikit-learn shortcuts and from-scratch implementations.

    Setup and Synthetic Data

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import (
        confusion_matrix, precision_score, recall_score, f1_score,
        fbeta_score, roc_auc_score, average_precision_score,
        roc_curve, precision_recall_curve, matthews_corrcoef,
        balanced_accuracy_score
    )
    
    np.random.seed(42)
    
    # 10,000 samples, 1% anomaly rate
    n = 10_000
    anomaly_rate = 0.01
    y_true = np.random.binomial(1, anomaly_rate, size=n)
    
    # Synthetic anomaly score: anomalies tend to score higher
    # Normal points: Beta(2, 5) -> mean ~0.29
    # Anomalies: shifted up by 0.4 (clipped at 1.0)
    y_score = np.random.beta(2, 5, size=n) + y_true * 0.4
    y_score = np.clip(y_score, 0, 1)
    
    print(f"Total samples: {n}")
    print(f"Anomalies: {y_true.sum()} ({y_true.mean()*100:.2f}%)")
    print(f"Score range: [{y_score.min():.3f}, {y_score.max():.3f}]")

    Building the Confusion Matrix from Scratch

    def confusion_from_scratch(y_true, y_pred):
        """Compute (TN, FP, FN, TP) without sklearn."""
        y_true = np.asarray(y_true).astype(int)
        y_pred = np.asarray(y_pred).astype(int)
        TP = int(((y_pred == 1) & (y_true == 1)).sum())
        FP = int(((y_pred == 1) & (y_true == 0)).sum())
        TN = int(((y_pred == 0) & (y_true == 0)).sum())
        FN = int(((y_pred == 0) & (y_true == 1)).sum())
        return TN, FP, FN, TP
    
    threshold = 0.5
    y_pred = (y_score >= threshold).astype(int)
    
    TN, FP, FN, TP = confusion_from_scratch(y_true, y_pred)
    print(f"TP = {TP}, FP = {FP}, TN = {TN}, FN = {FN}")
    
    # Verify against sklearn
    cm = confusion_matrix(y_true, y_pred)
    assert (TN, FP, FN, TP) == (cm[0,0], cm[0,1], cm[1,0], cm[1,1])

    All Threshold-Dependent Metrics, From Scratch

    def metrics_from_confusion(TN, FP, FN, TP):
        """Compute every threshold-dependent metric from a confusion matrix."""
        eps = 1e-12
        precision = TP / (TP + FP + eps)
        recall    = TP / (TP + FN + eps)        # TPR / sensitivity
        specificity = TN / (TN + FP + eps)       # TNR
        fpr = FP / (FP + TN + eps)               # FAR / FPR
        fnr = FN / (FN + TP + eps)               # FRR
        accuracy = (TP + TN) / (TP + TN + FP + FN + eps)
        balanced_acc = (recall + specificity) / 2
        f1 = 2 * precision * recall / (precision + recall + eps)
        f2 = 5 * precision * recall / (4 * precision + recall + eps)
        f05 = 1.25 * precision * recall / (0.25 * precision + recall + eps)
        # MCC
        num = TP * TN - FP * FN
        den = np.sqrt((TP+FP) * (TP+FN) * (TN+FP) * (TN+FN) + eps)
        mcc = num / den
    
        return {
            "Precision": precision, "Recall": recall, "Specificity": specificity,
            "FAR (FPR)": fpr, "FRR (FNR)": fnr, "Accuracy": accuracy,
            "BalancedAcc": balanced_acc, "F1": f1, "F2": f2, "F0.5": f05, "MCC": mcc,
        }
    
    m = metrics_from_confusion(TN, FP, FN, TP)
    for k, v in m.items():
        print(f"  {k:14s} = {v:.4f}")
    
    # Verify with sklearn
    assert abs(m["F1"] - f1_score(y_true, y_pred)) < 1e-6
    assert abs(m["MCC"] - matthews_corrcoef(y_true, y_pred)) < 1e-6
    assert abs(m["BalancedAcc"] - balanced_accuracy_score(y_true, y_pred)) < 1e-6

    AUROC and AUPRC With sklearn

    auroc = roc_auc_score(y_true, y_score)
    auprc = average_precision_score(y_true, y_score)
    print(f"AUROC = {auroc:.4f}  (random baseline = 0.5)")
    print(f"AUPRC = {auprc:.4f}  (random baseline = {y_true.mean():.4f})")

    Plotting ROC and PR Curves

    fpr, tpr, _ = roc_curve(y_true, y_score)
    prec, rec, _ = precision_recall_curve(y_true, y_score)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    ax1.plot(fpr, tpr, lw=2, label=f"Model (AUROC = {auroc:.3f})")
    ax1.plot([0, 1], [0, 1], "--", color="gray", label="Random")
    ax1.set_xlabel("False Positive Rate")
    ax1.set_ylabel("True Positive Rate")
    ax1.set_title("ROC Curve")
    ax1.legend()
    ax1.grid(alpha=0.3)
    
    ax2.plot(rec, prec, lw=2, color="crimson", label=f"Model (AUPRC = {auprc:.3f})")
    ax2.axhline(y=y_true.mean(), linestyle="--", color="gray",
                label=f"Random = {y_true.mean():.3f}")
    ax2.set_xlabel("Recall")
    ax2.set_ylabel("Precision")
    ax2.set_title("Precision-Recall Curve")
    ax2.legend()
    ax2.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig("roc_pr_curves.png", dpi=120)

    Finding the Optimal F1 Threshold

    prec, rec, thresholds = precision_recall_curve(y_true, y_score)
    # precision_recall_curve returns one extra point; align with thresholds
    prec_t, rec_t = prec[:-1], rec[:-1]
    
    f1_curve = 2 * prec_t * rec_t / (prec_t + rec_t + 1e-12)
    best_idx = int(np.argmax(f1_curve))
    best_threshold = thresholds[best_idx]
    best_f1 = f1_curve[best_idx]
    
    print(f"Best F1 = {best_f1:.4f} at threshold = {best_threshold:.4f}")
    print(f"  Precision = {prec_t[best_idx]:.4f}")
    print(f"  Recall    = {rec_t[best_idx]:.4f}")

    Sweeping the Threshold

    def threshold_sweep(y_true, y_score, n_thresholds=100):
        """Compute Precision, Recall, F1, FAR for a grid of thresholds."""
        grid = np.linspace(y_score.min(), y_score.max(), n_thresholds)
        rows = []
        for t in grid:
            y_pred = (y_score >= t).astype(int)
            TN, FP, FN, TP = confusion_from_scratch(y_true, y_pred)
            m = metrics_from_confusion(TN, FP, FN, TP)
            rows.append([t, m["Precision"], m["Recall"], m["F1"], m["FAR (FPR)"]])
        return np.asarray(rows)
    
    sweep = threshold_sweep(y_true, y_score, 200)
    t_grid, prec_g, rec_g, f1_g, far_g = sweep.T
    
    plt.figure(figsize=(9, 5))
    plt.plot(t_grid, prec_g, color="#e74c3c", label="Precision")
    plt.plot(t_grid, rec_g,  color="#3498db", label="Recall")
    plt.plot(t_grid, f1_g,   color="#27ae60", label="F1")
    plt.plot(t_grid, far_g,  color="#f39c12", label="FAR")
    plt.axvline(best_threshold, linestyle="--", color="black", alpha=0.6,
                label=f"Best F1 t={best_threshold:.3f}")
    plt.xlabel("Threshold")
    plt.ylabel("Metric value")
    plt.title("Metric vs Threshold (1% anomaly rate)")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()

    [Figure: Threshold trade-off for Precision, Recall, F1, and FAR. x-axis: decision threshold (0.0 to 1.0); y-axis: metric value (0.0 to 1.0); the best-F1 threshold is marked at t* ≈ 0.55.]

    Cost-Weighted Metric

    def cost_weighted_score(y_true, y_pred, c_fp=1.0, c_fn=10.0):
        """Lower is better. Useful when FN costs ~10x more than FP."""
        TN, FP, FN, TP = confusion_from_scratch(y_true, y_pred)
        return c_fp * FP + c_fn * FN
    
    def best_threshold_by_cost(y_true, y_score, c_fp=1.0, c_fn=10.0, n=200):
        grid = np.linspace(y_score.min(), y_score.max(), n)
        costs = []
        for t in grid:
            y_pred = (y_score >= t).astype(int)
            costs.append(cost_weighted_score(y_true, y_pred, c_fp, c_fn))
        best = int(np.argmin(costs))
        return grid[best], costs[best]
    
    t_cost, c_cost = best_threshold_by_cost(y_true, y_score, c_fp=1, c_fn=20)
    print(f"Cost-optimal threshold = {t_cost:.4f}, total cost = {c_cost:.0f}")

    Bootstrap Confidence Intervals: An Often Overlooked Step

    Single-number reports without uncertainty estimates are problematic. A 1,000-sample test set containing 10 positives can produce widely varying AUPRC values across reasonable bootstrap resamples. The bootstrap is the standard method for attaching a confidence interval: the test set is resampled with replacement many times, the metric is recomputed on each resample, and the spread of those values approximates the sampling variability of the metric. The percentile interval used below reads its bounds directly from the quantiles of the resampled values.

    def bootstrap_ci(y_true, y_score, metric_fn, n_boot=1000, alpha=0.05, seed=0):
        """Bootstrap percentile CI for any score-based metric."""
        rng = np.random.default_rng(seed)
        n = len(y_true)
        scores = []
        for _ in range(n_boot):
            idx = rng.integers(0, n, size=n)
            y_t, y_s = y_true[idx], y_score[idx]
            if y_t.sum() == 0 or y_t.sum() == n:
                continue  # degenerate resample
            scores.append(metric_fn(y_t, y_s))
        scores = np.asarray(scores)
        lo = np.quantile(scores, alpha/2)
        hi = np.quantile(scores, 1 - alpha/2)
        return float(np.mean(scores)), (float(lo), float(hi))
    
    mean_auroc, ci_auroc = bootstrap_ci(y_true, y_score, roc_auc_score, n_boot=500)
    mean_auprc, ci_auprc = bootstrap_ci(y_true, y_score, average_precision_score, n_boot=500)
    
    print(f"AUROC = {mean_auroc:.4f}  95% CI [{ci_auroc[0]:.4f}, {ci_auroc[1]:.4f}]")
    print(f"AUPRC = {mean_auprc:.4f}  95% CI [{ci_auprc[0]:.4f}, {ci_auprc[1]:.4f}]")

    Time-Series PA-F1 Implementation

    def get_event_segments(y):
        """Return list of (start, end_inclusive) for runs of 1s."""
        y = np.asarray(y).astype(int)
        if len(y) == 0:
            return []
        diff = np.diff(np.concatenate(([0], y, [0])))
        starts = np.where(diff == 1)[0]
        ends   = np.where(diff == -1)[0] - 1
        return list(zip(starts.tolist(), ends.tolist()))
    
    def point_adjusted_predictions(y_true, y_pred):
        """Apply Point-Adjusted (PA) protocol: if any point inside a true
        anomaly segment is detected, flag the entire segment as detected."""
        y_pred = y_pred.copy().astype(int)
        for s, e in get_event_segments(y_true):
            if y_pred[s:e+1].any():
                y_pred[s:e+1] = 1
        return y_pred
    
    # Worked example
    y_t = np.array([0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0])
    y_p = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
    
    print("Raw point F1     =", round(f1_score(y_t, y_p), 4))
    y_pa = point_adjusted_predictions(y_t, y_p)
    print("PA-adjusted pred =", y_pa.tolist())
    print("PA-F1            =", round(f1_score(y_t, y_pa), 4))

    In this example the raw point F1 is 0.20: one TP and four FN inside the first event, one FP outside any event, and three FN from the undetected second event, giving a precision of 1/2 and a recall of 1/8. After point adjustment, the entire first event is marked as "detected" because one point inside it was flagged; recall rises to 5/8 and PA-F1 to approximately 0.71. This is the inflation effect that Kim et al. (2022) identified: PA-F1 can appear impressive even when the underlying detection is weak. For range-aware alternatives, the VUS package or the Tatbul range-based implementation in the tsad Python library is recommended.
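
    The inflation effect can be illustrated with a detector that carries no information at all. The sketch below is a hypothetical setup, not a reproduction of the experiments in Kim et al. (2022): it assumes five long events of 400 points each in a 10,000-point series and a detector that flags roughly 3 percent of points uniformly at random. All names are illustrative.

```python
import numpy as np

def f1_binary(y_true, y_pred):
    """Point-wise F1 for 0/1 arrays, without sklearn."""
    tp = int(((y_pred == 1) & (y_true == 1)).sum())
    fp = int(((y_pred == 1) & (y_true == 0)).sum())
    fn = int(((y_pred == 0) & (y_true == 1)).sum())
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def point_adjust(y_true, y_pred):
    """Flag a whole true segment if any point inside it was predicted."""
    y_adj = y_pred.copy()
    diff = np.diff(np.concatenate(([0], y_true, [0])))
    starts = np.where(diff == 1)[0]
    ends = np.where(diff == -1)[0]          # exclusive end indices
    for s, e in zip(starts, ends):
        if y_adj[s:e].any():
            y_adj[s:e] = 1
    return y_adj

rng = np.random.default_rng(0)
n = 10_000
y_true = np.zeros(n, dtype=int)
for start in range(500, n, 2000):           # five events, 400 points each
    y_true[start:start + 400] = 1

# Uninformative detector: each point is flagged independently with p = 0.03.
y_rand = (rng.random(n) < 0.03).astype(int)

raw_f1 = f1_binary(y_true, y_rand)
pa_f1 = f1_binary(y_true, point_adjust(y_true, y_rand))
print(f"Random detector: raw point F1 = {raw_f1:.3f}, PA-F1 = {pa_f1:.3f}")
```

    Because each long event almost certainly contains at least one random flag, point adjustment credits the detector with every point of every event. Under these assumptions the raw point F1 stays close to zero while PA-F1 lands above 0.9.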

    Selecting the Threshold for Production

    Once the model has been trained and AUROC and AUPRC are acceptable, the question is which threshold to deploy. The five common strategies are presented below, ordered from the simplest to the most sophisticated.

    Maximise F1 on the Validation Set

    Thresholds are swept on a held-out validation set, and the one with the highest F1 is selected. The procedure is simple, defensible, and yields a balanced precision and recall point. Important caveat: the threshold should never be selected on the test set, as this constitutes data leakage. Validation data must always be reserved for hyperparameter and threshold selection.
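
    A minimal sketch of this protocol follows, written with NumPy only. The synthetic data, the 50/50 split, and the 201-point threshold grid are assumptions chosen for illustration; the essential point is that the threshold is chosen using validation indices alone and is then applied unchanged to the test indices.

```python
import numpy as np

def f1_at(y_true, y_score, t):
    """Point-wise F1 after thresholding scores at t."""
    y_pred = y_score >= t
    pos = y_true == 1
    tp = int((y_pred & pos).sum())
    fp = int((y_pred & ~pos).sum())
    fn = int((~y_pred & pos).sum())
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

# Hypothetical data: 1% anomalies whose scores are shifted upward.
rng = np.random.default_rng(0)
n = 20_000
y = rng.binomial(1, 0.01, size=n)
score = np.clip(rng.beta(2, 5, size=n) + 0.4 * y, 0, 1)

# Disjoint validation and test splits.
idx = rng.permutation(n)
val, test = idx[: n // 2], idx[n // 2:]

# Select the threshold on validation only.
grid = np.linspace(0.0, 1.0, 201)
val_f1 = np.array([f1_at(y[val], score[val], t) for t in grid])
t_star = float(grid[int(np.argmax(val_f1))])

# Report the test F1 at the frozen threshold.
test_f1 = f1_at(y[test], score[test], t_star)
print(f"Selected t* = {t_star:.3f} (validation F1 = {val_f1.max():.3f})")
print(f"Test F1 at t* = {test_f1:.3f}")
```

    The test F1 is usually somewhat lower than the validation F1 at the same threshold; that gap is the optimism that would go unnoticed if the threshold were tuned on the test set.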

    Fixed FAR Budget

    This is the operations-driven approach. For example, if the team can handle 100 alerts per day across 1 million events per day, FAR must be at most 0.01 percent. The threshold corresponding to FAR = 0.0001 on the validation set is selected, and the corresponding recall is reported. Most cybersecurity and network monitoring systems in production are tuned in this way.

    def threshold_for_far_budget(y_true, y_score, far_budget=0.001):
        """Largest recall achievable subject to FAR ≤ far_budget."""
        fpr, tpr, thr = roc_curve(y_true, y_score)
        feasible = fpr <= far_budget
        if not feasible.any():
            return None, 0.0, 0.0
        idx = np.argmax(tpr * feasible)
        return float(thr[idx]), float(tpr[idx]), float(fpr[idx])
    
    t, r, f = threshold_for_far_budget(y_true, y_score, far_budget=0.005)
    print(f"Threshold = {t:.4f}, Recall = {r:.4f} at FAR = {f:.4f}")

    Cost-Weighted Optimisation

    If the dollar cost of a false positive (such as analyst time and customer impact) and a false negative (such as missed fraud value) can be quantified, the threshold that minimises C_FP·FP + C_FN·FN should be selected. This is the most defensible approach when the asymmetry is well understood.

    Top-K Selection

    This approach forgoes the threshold entirely. Scores are ranked and the top K cases are selected. It is appropriate when human review capacity is the binding constraint and alert volume per period is fixed.

    Sliding or Contextual Threshold

    Time-of-day, day-of-week, or per-segment thresholds may be used. A retail fraud detector might use a threshold of 0.6 on weekday afternoons and 0.4 on holiday weekends. Implementation typically involves a small lookup table or a contextual model that outputs both score and threshold.
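
    The lookup-table variant can be sketched as follows. The two threshold values come from the retail example above; the fallback value, the definition of "afternoon" as 12:00 to 18:00, and the set of holiday-weekend dates are all hypothetical assumptions that a real deployment would replace with its own calendar logic.

```python
from datetime import date, datetime

# Threshold values from the retail example; context keys are illustrative.
THRESHOLDS = {"weekday_afternoon": 0.6, "holiday_weekend": 0.4}
DEFAULT_THRESHOLD = 0.5  # assumed fallback for all other contexts

# Hypothetical holiday-weekend calendar (a Saturday and a Sunday).
HOLIDAY_WEEKEND_DATES = {date(2024, 11, 30), date(2024, 12, 1)}

def threshold_for(ts):
    """Return the alert threshold that applies at timestamp ts."""
    if ts.date() in HOLIDAY_WEEKEND_DATES:
        return THRESHOLDS["holiday_weekend"]
    # weekday() is 0..4 for Monday..Friday.
    if ts.weekday() < 5 and 12 <= ts.hour < 18:
        return THRESHOLDS["weekday_afternoon"]
    return DEFAULT_THRESHOLD

def is_alert(score, ts):
    """Apply the contextual threshold to a single anomaly score."""
    return score >= threshold_for(ts)

tuesday_afternoon = datetime(2024, 12, 3, 14, 0)
holiday_saturday = datetime(2024, 11, 30, 10, 0)
print(is_alert(0.5, tuesday_afternoon))  # compared against 0.6
print(is_alert(0.5, holiday_saturday))   # compared against 0.4
```

    The same score of 0.5 is suppressed on the weekday afternoon and raised as an alert on the holiday weekend, which is the behaviour a contextual threshold is meant to provide.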

    Caution: Thresholds drift. As the data distribution shifts because of seasonal effects and the evolution of fraud patterns, the threshold that maximised F1 in January may produce twice the alert volume in June. Monthly threshold retuning should be scheduled, and precision and FAR should be monitored continuously.

    Common Pitfalls to Avoid

    The most frequently encountered errors across anomaly detection projects in fraud, manufacturing, security, and healthcare are listed below.

    • Reporting AUROC without AUPRC on imbalanced data. AUROC = 0.99 with 0.1 percent positives often corresponds to AUPRC = 0.40. Both should always be reported.
    • Reporting accuracy. For anomaly detection, accuracy is almost always uninformative. The "always negative" baseline outperforms most real models on accuracy.
    • Selecting the threshold on the test set. Tuning should be performed on the validation set, and evaluation on the test set. Maximising F1 across thresholds on the same test set constitutes overfitting.
    • Not using stratified k-fold. With 1 percent positives in 1,000 samples, a random fold may contain zero positives in the validation split. StratifiedKFold should be used.
    • Ignoring confidence intervals. A reported AUPRC of 0.42 ± 0.15 (95 percent CI) is qualitatively different from 0.42 ± 0.02. Bootstrap intervals should be computed and reported.
    • Comparing models on different test sets. This is not a like-for-like comparison. The same fixed test set must be used across all model comparisons.
    • Using point F1 for time series. A single-step detection lag reduces the score substantially. Range-based metrics or VUS should be used instead.
    • Confusion between microaverage and macroaverage in multi-class anomaly settings. Microaverage favours common classes; macroaverage equalises them. The choice must be made deliberately and documented.
    • Treating PA-F1 as a definitive measure. It can be inflated by random noise. If used, it should be reported alongside non-PA metrics.
    • Optimising offline metrics that do not translate to deployment. When the business operates on alert-volume budgets, the metric that respects that constraint should be optimised, rather than F1 alone.
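
    The stratification pitfall above can be quantified with a short calculation. The sketch below assumes 5-fold cross-validation (the fold count is an assumption; the list does not specify one) and uses the hypergeometric probability that a randomly drawn fold of 200 samples contains none of the 10 positives.

```python
from math import comb

n_samples, n_pos, n_folds = 1000, 10, 5   # 1% positives; 5 folds is assumed
fold_size = n_samples // n_folds

# Probability that one particular random fold contains zero positives:
# choose all fold members from the negatives only.
p_zero = comb(n_samples - n_pos, fold_size) / comb(n_samples, fold_size)

print(f"P(a given fold has no positives) = {p_zero:.3f}")
```

    Under these assumptions the probability is a little over 10 percent for any particular fold, so an unstratified split will regularly produce a validation fold on which recall, F1, and AUPRC are undefined.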

    Real-World Reporting Templates by Domain

    Different domains converge on different metric stacks. The following recommendations are distilled from observed production systems. For more detailed treatment of the underlying anomaly detection methods, the companion guides on Deep SVDD and One-Class SVM may be consulted.

    Domain | Recommended Metric Stack | Why
    Fraud detection | AUPRC, Precision@K, Recall, $-weighted recall | Severe imbalance + dollar asymmetry
    Network intrusion | AUROC, Precision, FAR @ fixed Recall | Operations cares about alert volume
    Medical screening | Sensitivity (Recall), Specificity, AUROC | Regulatory norms; symmetric reporting
    Industrial sensor | Range-based F1, Precision@K, time-to-detect | Time-series events; early detection valued
    Server monitoring | Precision@K, MTTD, false-alert-per-day | Streaming context, on-call workload
    Biometrics / authentication | EER, DET curve, FAR @ fixed FRR | Field-standard reporting
    Anti-money-laundering | Recall + Precision@K, regulatory alert quality | Compliance sets minimum recall
    Manufacturing defect | Recall, Precision, cost-weighted score | Defect cost vs over-inspection cost

    If the model is built on top of transfer learning or fine-tuning approaches, the same metric framework applies, although particular caution should be taken with confidence intervals, since pre-training source-target distribution gaps can render small test sets highly noisy.

    Key Takeaway: A robust default reporting set for any anomaly detection project comprises AUPRC, Precision@K, Recall, and FAR, each reported with bootstrap 95 percent confidence intervals and a documented threshold. This combination covers model quality, top-of-list usefulness, miss rate, and operational alert volume.

    Frequently Asked Questions

    Why isn't accuracy a good metric for anomaly detection?

    Because anomalies are rare. If 99% of your data is normal, a "predict normal always" model achieves 99% accuracy without learning anything. Real models often lift accuracy by only a few tenths of a percentage point, so accuracy can't distinguish good models from useless ones. Use AUPRC, F1, or Precision@K instead.

    AUROC vs AUPRC—when should I use which?

    For mild imbalance (positives 5–50%), AUROC and AUPRC tell roughly similar stories, and AUROC is fine. For severe imbalance (positives below 1%), AUROC inflates because most of its area comes from FPR regions you'll never operate in. AUPRC is more honest because its random baseline equals the positive class fraction. Best practice: report both, but rely on AUPRC for imbalanced anomaly detection.

    How do I pick a threshold for production?

    Pick the strategy that matches your business constraint. If your team has a fixed alert-review budget, use top-K or fixed-FAR. If you can quantify costs, optimize C_FP·FP + C_FN·FN. If neither, maximize F1 on a held-out validation set. Always select the threshold on validation, evaluate on test, and re-tune monthly as data shifts.

    What's the difference between FAR and FPR?

    None — they are the same metric: FP / (FP + TN). "False Alarm Rate" is the operations and biometrics term; "False Positive Rate" is the statistical term. Some literature also uses "False Acceptance Rate" (biometrics, identical concept) or "Type I Error rate" (classical statistics).

    Are time-series anomaly detection metrics different?

    Yes. Anomalies in time series are typically contiguous events, not isolated points, so naive point-wise F1 over-penalises brief detection lag. Use range-based metrics (Tatbul et al., 2018), VUS-PR (Paparrizos et al., 2022), or NAB Score for streaming. Reliance on Point-Adjusted F1 alone should be avoided, since recent work has shown that it can be gamed by random noise.

    References and Further Reading

    External References:

    • scikit-learn metrics documentation: https://scikit-learn.org/stable/modules/model_evaluation.html
    • Saito, T. & Rehmsmeier, M. (2015). "The Precision-Recall Plot Is More Informative than the ROC Plot When Evaluating Binary Classifiers on Imbalanced Datasets." PLOS ONE.
    • Tatbul, N., Lee, T. J., Zdonik, S., Alam, M., & Gottschlich, J. (2018). "Precision and Recall for Time Series." NeurIPS.
    • Paparrizos, J., Boniol, P., Palpanas, T., Tsay, R., Elmore, A., & Franklin, M. (2022). "Volume Under the Surface: A New Accuracy Evaluation Measure for Time-Series Anomaly Detection." VLDB.
    • Numenta Anomaly Benchmark (NAB): https://github.com/numenta/NAB
    • Huet, A., Navarro, J. M., & Rossi, D. (2022). "Local Evaluation of Time Series Anomaly Detection Algorithms." KDD.
    • Kim, S. et al. (2022). "Towards a Rigorous Evaluation of Time-Series Anomaly Detection." AAAI.

    This article is for informational purposes only and does not constitute investment, security, or medical advice. Always validate metrics against your specific operational context.

  • GP-Based Hyperparameter Optimization: Bayesian Tuning for ML Models

    Summary

    What this post covers: A practitioner’s guide to tuning ML hyperparameters with Gaussian Process Bayesian Optimization, walking through the full BayesOpt pipeline, acquisition functions, search-space design, and four working Python implementations (scikit-optimize, BoTorch, qNEHVI multi-objective, Optuna+BoTorch).

    Key insights:

    • GP-based Bayesian optimization typically reaches a good configuration in approximately twenty trials, compared with roughly sixty for random search and millions for grid search. It is the appropriate default whenever each training run requires substantial GPU time.
    • GPs perform well for HPO because they natively model observation noise, quantify uncertainty across the search space, and produce a smooth surrogate that an acquisition function can exploit. This combination accounts for their sample efficiency in low-to-moderate dimensions.
    • The choice of acquisition function matters. Expected Improvement is the safe default, UCB exposes an explicit explore-versus-exploit parameter, and Thompson Sampling or qNEHVI are preferable when parallel batches or multi-objective Pareto fronts are required.
    • Search-space design—log-uniform priors for learning rate, integer dimensions, conditional parameters—frequently determines success more than the choice of optimizer. Combining GP-BO with Hyperband (BOHB) is the practical optimum once tens of GPUs are available.
    • For most teams, the appropriate stack is Optuna with the BoTorch sampler. It handles mixed and conditional spaces, parallelizes effectively, and provides GP-grade sample efficiency without requiring direct BoTorch use.

    Main topics: Why Hyperparameter Tuning Is Hard, The HPO Landscape: A Survey of Methods, Why Gaussian Processes Are Effective for HPO, The Full BayesOpt Pipeline for HPO, Acquisition Functions Examined in Detail, Search Space Design, Full Python Implementation, Multi-Fidelity and Parallel HPO, Tools Comparison, Real-World Case Studies, Practical Guide and Pitfalls.

    Tuning a ten-hyperparameter neural network by grid search with five values per dimension requires 9.7 million experiments. Random search reaches a comparable configuration in approximately sixty trials. Gaussian Process Bayesian Optimization typically requires twenty. The level of accuracy is the same; the compute requirement is reduced by a factor of roughly half a million.

    This gap explains why GP-based hyperparameter optimization moved from an academic curiosity to the production default at Google, Meta, and OpenAI. When a single training run requires hours and costs hundreds of dollars in GPU time, grid search is economically infeasible. Random search is unreliable because it cannot incorporate the knowledge accumulated from previous trials. The optimizer must reason between trials, selecting the next configuration in light of every prior one.

    Gaussian Processes provide the mathematical machinery that makes this possible. A GP fits a smooth surrogate to the validation-loss landscape and quantifies its own uncertainty across the search space; an acquisition function then converts that uncertainty into a principled decision about where to evaluate next.

    This post is a practitioner guide. It does not re-derive GP regression; for the underlying mathematics covering kernels, posterior inference, and marginal likelihood, the Gaussian Process fundamentals post with Python and GPyTorch is the appropriate reference. The focus here is the applied question: how to tune XGBoost, a CNN, or a transformer using GP-based Bayesian optimization in production.

    The remainder of the article presents four working code examples (scikit-optimize on XGBoost, BoTorch on a CNN, multi-objective BO with qNEHVI, and Optuna with the BoTorch sampler), a discussion of common acquisition functions, three accompanying diagrams, and a considered recommendation regarding tools.

    Why Hyperparameter Tuning Is Hard

    Before considering the merits of GPs, it is useful to acknowledge what makes HPO genuinely difficult, since this difficulty is what justifies the additional machinery of Bayesian optimization.

    The Combinatorial Explosion

    A typical modern machine-learning model has between ten and thirty tunable hyperparameters. A baseline XGBoost has ten to fifteen (learning rate, max depth, n_estimators, subsample, colsample_bytree, min_child_weight, gamma, reg_alpha, reg_lambda, scale_pos_weight, and others). A vision transformer has more (depth, width, heads, MLP ratio, patch size, dropout, attention dropout, learning rate, weight decay, warmup, label smoothing, mixup alpha, drop path, EMA, and similar).

    Grid-searching ten hyperparameters with five values each requires 5^10 ≈ 9.77 million configurations. At thirty minutes per training run on a single GPU, this amounts to roughly 4.9 million GPU-hours, or about 557 GPU-years. Even with substantial parallelism, the approach is infeasible.

    Non-Trivial Interactions

    Hyperparameters are not independent. The optimal learning rate depends on batch size (the linear scaling rule), on the optimizer (Adam versus SGD), on weight initialization, and on architectural depth. Tuning one hyperparameter at a time assumes independence, which is incorrect, and a grid observes interactions only at the handful of values placed on each axis.

    Random search handles this better because it samples jointly and draws a fresh value on every axis in every trial, so interactions are observed at many more settings. It nevertheless wastes compute on unpromising regions because it has no memory between trials.

    Each Evaluation Is Expensive

    Training a single configuration can take from minutes for a small XGBoost model to days for a large language model fine-tune. When each evaluation costs $50 to $500 in cloud GPU time, sample efficiency moves from an academic preference to a budgetary necessity.

    Noise

    The same hyperparameters produce different validation losses across random seeds. Variance arising from data shuffling, dropout randomness, weight initialization, and stochastic optimization means that every observation is noisy. A naive optimizer interprets this noise as signal. GPs handle observation noise natively through a noise-variance term in the likelihood (equivalently, a white-noise component added to the kernel), which is a built-in advantage.

    Mixed Types and Conditional Spaces

    Real search spaces include continuous parameters such as the learning rate, integers such as max depth and the number of layers, categoricals such as activation function and optimizer choice, and conditional dimensions: the dropout rate matters only if dropout is enabled, and momentum matters only for SGD, not Adam. Standard GPs assume continuous Euclidean inputs, so this is a substantive engineering challenge that the search-space section addresses.

    Key Takeaway: HPO is difficult because the search space is substantial and irregularly shaped, evaluations are expensive, observations are noisy, and no gradient is available. Each of these properties points away from grid search and toward a sample-efficient, model-based optimizer, which is precisely the role of GP-based Bayesian optimization.

    The HPO Landscape: A Survey of Methods

    Before focusing on GPs, the practical taxonomy of methods encountered in real applications is summarized below.

    Grid search evaluates the Cartesian product of values for each hyperparameter. It is easy to implement and easy to parallelize, but markedly inefficient. The approach breaks down beyond four or five hyperparameters because of the curse of dimensionality. It is appropriate only for very small problems or the final pinning of two or three parameters.

    Random search samples uniformly from the search space. Bergstra and Bengio (2012) demonstrated that it outperforms grid search because most hyperparameters do not matter equally; random search effectively projects onto the important axes. It is the baseline that every other method should be able to surpass. A method that cannot exceed random search is not functioning correctly.
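
    As a point of reference, the random-search baseline can be sketched in a few lines of standard-library Python. The three-parameter search space and the toy objective below are hypothetical stand-ins for a real training run; the learning rate is drawn log-uniformly, which is an assumption about sensible search-space design and not part of the definition of random search.

```python
import math
import random

def sample_config(rng):
    """Draw one configuration from a hypothetical 3-D search space."""
    return {
        "learning_rate": 10 ** rng.uniform(-5, -1),  # log-uniform in [1e-5, 1e-1]
        "max_depth": rng.randint(2, 12),             # integer, inclusive bounds
        "subsample": rng.uniform(0.5, 1.0),
    }

def toy_objective(cfg):
    """Stand-in for validation loss; smallest near lr=1e-3, depth=6, subsample=0.8."""
    return ((math.log10(cfg["learning_rate"]) + 3) ** 2
            + 0.05 * (cfg["max_depth"] - 6) ** 2
            + (cfg["subsample"] - 0.8) ** 2)

def random_search(n_trials, seed=0):
    """Evaluate n_trials independent samples and keep the best one."""
    rng = random.Random(seed)
    best_cfg, best_loss = None, float("inf")
    for _ in range(n_trials):
        cfg = sample_config(rng)
        loss = toy_objective(cfg)
        if loss < best_loss:
            best_cfg, best_loss = cfg, loss
    return best_cfg, best_loss

best_cfg, best_loss = random_search(60)
print(f"Best loss after 60 trials: {best_loss:.4f}")
print(best_cfg)
```

    Each trial is drawn without reference to any earlier result. That independence is what makes the method trivially parallel, and it is also the reason the method cannot concentrate trials in promising regions.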

    Evolutionary and Genetic Algorithms

    An evolutionary algorithm maintains a population of configurations, applies mutation and recombination, and selects the fittest. The method parallelizes well, requires no gradient, and handles unusual search spaces. Sample efficiency is moderate—better than random search but usually worse than Bayesian optimization. The approach is used extensively in neural architecture search (Regularized Evolution and AmoebaNet, for example). For a more detailed exposition, see the genetic algorithm Python implementation guide.

    Bandit-Based Methods: Hyperband and ASHA

    Bandit-based methods frame HPO as a multi-armed bandit problem. Many configurations are run for a small budget, the worst are eliminated, the budget for survivors is doubled, and the process repeats. Successive Halving is the core idea. Hyperband sweeps over different initial budgets to hedge against poor fidelity choices, and ASHA is the asynchronous variant that scales to substantial parallelism. These are multi-fidelity methods that use cheap proxies, such as early epochs, to filter more expensive trials.
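
    The core Successive Halving loop can be sketched as follows. This is a toy illustration under stated assumptions: the learning curve toy_loss is hypothetical and preserves the ranking of configurations at every budget, which is precisely the property that multi-fidelity methods rely on and that real training curves only approximate. The elimination factor eta = 2 matches the doubling described above.

```python
import random

def toy_loss(cfg, budget):
    """Hypothetical learning curve: loss decays toward the config's final quality."""
    return cfg["quality"] + 1.0 / budget

def successive_halving(configs, min_budget=1, eta=2):
    """Keep the best 1/eta of configs each round and multiply the budget by eta."""
    budget = min_budget
    survivors = list(configs)
    history = []  # (number of configs evaluated, budget per config) per round
    while len(survivors) > 1:
        history.append((len(survivors), budget))
        ranked = sorted(survivors, key=lambda c: toy_loss(c, budget))
        survivors = ranked[: max(1, len(ranked) // eta)]
        budget *= eta
    return survivors[0], history

rng = random.Random(0)
configs = [{"id": i, "quality": rng.random()} for i in range(16)]

winner, history = successive_halving(configs)
total_budget = sum(n * b for n, b in history)
print("Rounds (configs, budget):", history)
print(f"Winner id = {winner['id']}, total budget spent = {total_budget}")
```

    Sixteen configurations are reduced to one in four rounds at a total cost of 64 budget units. Running all sixteen at the final budget of 8 would cost 128 units, so the saving in this toy setting is a factor of two, and it grows with the number of rungs.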

    Bayesian Optimization with GPs

    This method fits a GP surrogate to pairs of (hyperparameter, validation_loss) values and uses an acquisition function to select the next trial. It is sample-efficient, provides principled uncertainty quantification, and is theoretically well grounded. It is the focus of this post.

    TPE (Tree-Structured Parzen Estimator)

    TPE is a Bayesian optimization method with a different surrogate. Rather than a GP, it models two densities, p(x | y < threshold) and p(x | y ≥ threshold), and selects x to maximize their ratio. TPE handles conditional spaces natively, scales well to higher dimensions, and underpins the default samplers in Optuna and HyperOpt. It is less sample-efficient than GP-based BO in low dimensions but more flexible in high dimensions and with mixed types.

    A Hybrid Method: BOHB

    Falkner et al. (2018) combined Bayesian Optimization (with TPE) and Hyperband. The combination yields the compute efficiency of Hyperband through early stopping and the informed sampling of BO in place of random sampling within rungs. BOHB is frequently the appropriate default for deep-learning HPO when tens of GPUs are available.

    Figure: HPO method sample efficiency on a 10-D problem, shown as approximate trials needed to reach a good configuration (log scale, lower is better). Grid search ~9.7M (off-chart); random search ~100; genetic algorithm ~80; TPE (Optuna) ~40; GP-BO (BoTorch) ~25; BOHB (multi-fidelity) ~15. Winners: GP-BO (low-D) and BOHB (deep nets). Note: the BOHB advantage assumes you can early-stop confidently from partial training curves.

    Quick Decision: When to Use What

    Method | Sample Efficiency | Parallelism | Complexity | Categorical Support | Best For
    Grid Search | Very low | Trivial | Trivial | Native | ≤3 hyperparams, final pinning
    Random Search | Low | Trivial | Trivial | Native | Baseline, exploration phase
    Genetic Algorithm | Medium | Excellent | Medium | Native | NAS, irregular spaces
    Hyperband / ASHA | Medium | Excellent | Medium | Native | Big compute, slow training
    TPE | High | Good | Medium | Native, conditional | Mixed types, conditional spaces
    GP-BO | Highest | Good (qEI/Thompson) | High | Custom kernels needed | ≤20 dims, expensive evals
    BOHB | Highest | Excellent | High | Native (TPE-based) | Deep learning at scale

     

    Why Gaussian Processes Are Effective for HPO

    For the majority of real HPO problems (those with fewer than twenty dimensions, expensive evaluations, and largely continuous parameters), GP-based BO is consistently among the strongest methods in published benchmarks. The reasons are as follows.

    Sample Efficiency Is Paramount

    When each evaluation requires hours of GPU time, the few seconds of overhead associated with fitting a GP are inconsequential. The objective is to make every trial count. GPs use the full information of every prior observation when selecting the next one. Random search discards that information.

    Principled Uncertainty

    A GP does not merely predict the loss; it predicts the loss and a confidence interval. This capability enables intelligent exploration. The GP identifies the regions in which it is uncertain, and the acquisition function exploits this information. Without a probabilistic surrogate, “exploration” reduces to random sampling.

    Smooth Surrogate for a Smooth Landscape

    Hyperparameter loss landscapes are typically smooth, particularly in log-space coordinates such as learning rate and weight decay. The Matérn 5/2 kernel is a near-perfect inductive bias for this property. GPs interpolate cleanly between observations and provide a credible map of the search space after just ten to twenty trials.

    Calibrated Exploration and Exploitation

    Acquisition functions such as Expected Improvement automatically balance exploitation (sampling where the model predicts high quality) with exploration (sampling where the model is uncertain). The trade-off emerges from the mathematics rather than from a hand-tuned epsilon-greedy mechanism.

    Effective Range: at Most Approximately Twenty Dimensions

    GPs become unwieldy beyond approximately twenty dimensions because the kernel struggles to model meaningful similarity in high-dimensional Euclidean space. Fortunately, the vast majority of HPO problems fall within this regime. For higher dimensions, the discussion of TuRBO and random embeddings applies.

    Tip: If the search space has fewer than twenty dimensions, a few seconds of GP-fitting overhead per trial is tolerable, and each trial is expensive (more than a minute), GP-based BO is almost always the appropriate choice. The principal exceptions are extreme parallelism (use Thompson sampling), conditional spaces (use TPE), and genuinely high-dimensional problems (use TuRBO).

    The Full BayesOpt Pipeline for HPO

    The operation of GP-based Bayesian optimization is described step by step below. The loop is the one implemented in BoTorch, scikit-optimize, and Optuna’s GP sampler.

    Step 1: Define the Search Space

    Specify the bounds and type of each hyperparameter, choosing among continuous (with optional log scale), integer, and categorical. This step is responsible for most production errors: bounds set too tight miss the optimum, bounds set too wide waste trials in poor regions, and incorrect scales (linear rather than log for the learning rate, for example) degrade the optimizer.

    Step 2: Initial Random Trials

    Five to ten random configurations should be run to seed the GP. Without these observations the GP has no signal, and the acquisition function repeatedly selects the geometric center of the search box. A common rule of thumb is n_init = max(5, 2 · d), where d is the search-space dimension.

    Step 3: Fit the GP Surrogate

    Given observations (x_1, y_1), …, (x_n, y_n), fit a GP with a Matérn 5/2 kernel, which is the standard default for HPO. Optimize the kernel hyperparameters (lengthscales, signal variance, noise) by maximizing the marginal likelihood. This takes seconds for n < 1000.
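
    For reference, the Matérn 5/2 covariance itself is short enough to write out. The sketch below assumes a single shared lengthscale for readability; libraries typically learn one lengthscale per dimension, and the function name `matern52` is illustrative and not taken from any package.

```python
import math

def matern52(x1, x2, lengthscale=1.0, signal_var=1.0):
    """Matern 5/2 covariance between two points given as sequences of floats."""
    r = math.dist(x1, x2)                  # Euclidean distance
    s = math.sqrt(5.0) * r / lengthscale   # scaled distance
    return signal_var * (1.0 + s + s * s / 3.0) * math.exp(-s)

# Covariance decays smoothly as two configurations move apart.
for distance in (0.0, 0.5, 1.0, 2.0):
    print(f"distance {distance:.1f} -> k = {matern52([0.0], [distance]):.4f}")
```

    A shorter lengthscale makes the covariance fall off faster, which tells the GP that the loss changes quickly along that hyperparameter.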

    Step 4: Optimize the Acquisition Function

    The acquisition function α(x) takes the GP posterior and returns a scalar that expresses the value of evaluating at x. Maximize α(x) over the search space using L-BFGS, multi-start methods, or random sampling for non-smooth cases. The argmax is the next trial.

    Step 5: Run the Trial

    Train the model with the proposed hyperparameters and record (x_{n+1}, y_{n+1}).

    Step 6: Update and Repeat

    Append the new observation, refit the GP, optimize the acquisition function again, and propose the next trial. The loop continues until the budget is exhausted.
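
    The six steps can be sketched end to end in plain Python for a one-dimensional problem. Several simplifications are assumptions made for brevity: the kernel lengthscale is fixed at 0.2 instead of being fitted by marginal likelihood, the acquisition is a confidence bound (mean minus two standard deviations) maximized over a fixed grid instead of Expected Improvement with L-BFGS, and `objective` is a hypothetical stand-in for a training run.

```python
import math
import random
import statistics

def kernel(a, b, lengthscale=0.2):
    """Matern 5/2 covariance for scalar inputs."""
    s = math.sqrt(5.0) * abs(a - b) / lengthscale
    return (1.0 + s + s * s / 3.0) * math.exp(-s)

def cholesky(A):
    """Lower-triangular L with L L^T = A, for symmetric positive definite A."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            acc = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(acc) if i == j else acc / L[j][j]
    return L

def forward_solve(L, b):
    """Solve L y = b by forward substitution."""
    y = []
    for i, b_i in enumerate(b):
        y.append((b_i - sum(L[i][k] * y[k] for k in range(i))) / L[i][i])
    return y

def backward_solve(L, y):
    """Solve L^T x = y by back substitution."""
    n = len(y)
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x

def fit_gp(X, Y, noise=1e-4):
    """Step 3: factor the kernel matrix on standardized targets."""
    mean, std = statistics.fmean(Y), statistics.pstdev(Y) or 1.0
    targets = [(y - mean) / std for y in Y]
    K = [[kernel(a, b) + (noise if i == j else 0.0) for j, b in enumerate(X)]
         for i, a in enumerate(X)]
    L = cholesky(K)
    return L, backward_solve(L, forward_solve(L, targets))

def predict(model, X, x_new):
    """Posterior mean and standard deviation at x_new, in standardized units."""
    L, alpha = model
    k_star = [kernel(a, x_new) for a in X]
    mu = sum(k * a for k, a in zip(k_star, alpha))
    v = forward_solve(L, k_star)
    var = max(kernel(x_new, x_new) - sum(t * t for t in v), 1e-12)
    return mu, math.sqrt(var)

def objective(x):
    """Hypothetical stand-in for an expensive training run."""
    return (x - 0.7) ** 2

rng = random.Random(0)
X = [rng.random() for _ in range(5)]       # Step 2: initial random trials
Y = [objective(x) for x in X]
initial_best = min(Y)
grid = [i / 200 for i in range(201)]       # Step 1: search interval [0, 1]

for _ in range(10):
    model = fit_gp(X, Y)                   # Step 3: fit the surrogate

    def acquisition(x, beta=2.0):          # Step 4: lower values are more attractive
        mu, sigma = predict(model, X, x)
        return mu - beta * sigma

    x_next = min(grid, key=acquisition)
    X.append(x_next)                       # Steps 5 and 6: evaluate and append
    Y.append(objective(x_next))

print(f"best x = {X[Y.index(min(Y))]:.3f}, best loss = {min(Y):.5f}")
```

    Standardizing the targets before fitting keeps the mean and the uncertainty on comparable scales, so the confidence bound does not degenerate into pure exploration.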

    Figure: BayesOpt iterations, showing the GP posterior and the acquisition function in four panels. Iteration 5: wide uncertainty, exploring. Iteration 10: narrowing on a promising area. Iteration 20: converged on a local minimum, still exploring. Iteration 30: confirmed global optimum, acquisition flat (converged). Legend: GP mean, GP ±2σ, observed trial, next trial, acquisition α(x).

    Caution: A trade-off seldom mentioned: GP fitting combined with acquisition optimization introduces one to ten seconds of overhead per trial. When each trial completes in five seconds, as for a small model on a small dataset, this overhead dominates and BO underperforms random search. BO is advantageous specifically when each trial requires minutes to days. Applying BO to a scikit-learn linear regression is therefore inappropriate.

    Acquisition Functions Examined in Detail

    The acquisition function is the mechanism by which exploration is balanced with exploitation. The choice of acquisition function matters less than is sometimes claimed; Expected Improvement is appropriate in roughly 90 percent of cases. Nonetheless, an understanding of the alternatives is helpful when diagnosing problems.

    Expected Improvement (EI)

    EI(x) = E[max(0, f_best − f(x))], that is, the expected improvement over the current best. For a Gaussian posterior with mean μ(x) and standard deviation σ(x), the expression has a closed form.

    EI(x) = (f_best − μ(x)) · Φ(z) + σ(x) · φ(z), where z = (f_best − μ(x)) / σ(x).

    Φ denotes the standard normal CDF and φ the PDF. The expression is smooth, differentiable, and well-behaved. EI is the default choice. It exhibits a slight bias toward exploitation, but in practice it explores adequately because σ(x) is large in unexplored regions.
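
    The closed form translates directly into code. The sketch below uses only the standard library; the numeric inputs are illustrative values, not results from a real study.

```python
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, standard deviation 1

def expected_improvement(mu: float, sigma: float, f_best: float) -> float:
    """Closed-form EI for minimization under a Gaussian posterior N(mu, sigma^2)."""
    if sigma <= 0.0:
        return max(0.0, f_best - mu)  # no uncertainty: the improvement is deterministic
    z = (f_best - mu) / sigma
    return (f_best - mu) * _STD_NORMAL.cdf(z) + sigma * _STD_NORMAL.pdf(z)

f_best = 0.30
# A confident, slightly better prediction versus a worse but uncertain one.
print(expected_improvement(mu=0.28, sigma=0.01, f_best=f_best))
print(expected_improvement(mu=0.35, sigma=0.10, f_best=f_best))
```

    The two candidates score almost the same, which illustrates how the σ(x) · φ(z) term lets an uncertain region compete with a small, nearly certain gain.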

    Upper Confidence Bound (UCB)

    UCB(x) = μ(x) − β · σ(x) for minimization (in this form it is often called the lower confidence bound), with the sign flipped for maximization. The coefficient β explicitly controls the level of exploration: larger values produce more exploration. Theoretical regret bounds (Srinivas et al., 2010) establish that, with β_t growing logarithmically, UCB has sublinear cumulative regret. In practice, β = 2 is a reasonable default. UCB is more aggressive about exploration than EI when σ is large.

    Probability of Improvement (PI)

    PI(x) = P(f(x) < f_best) = Φ(z), which is simply the probability of any improvement over the current best. PI is purely greedy: it selects any small improvement and can stagnate by exploiting near the current best indefinitely. It is rarely used in modern HPO except as a pedagogical example.
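
    The difference between PI and the confidence-bound rule from the previous subsection shows up clearly on two hand-picked candidates. The posterior means and standard deviations below are hypothetical values chosen to make the contrast visible.

```python
from statistics import NormalDist

_STD_NORMAL = NormalDist()

def lower_confidence_bound(mu, sigma, beta=2.0):
    """Confidence bound for minimization: smaller values are more attractive."""
    return mu - beta * sigma

def probability_of_improvement(mu, sigma, f_best):
    """Probability that the candidate beats the current best (minimization)."""
    return _STD_NORMAL.cdf((f_best - mu) / sigma)

f_best = 0.30
candidates = {
    "near_best": (0.299, 0.001),   # tiny, almost certain gain
    "unexplored": (0.35, 0.10),    # worse mean, wide uncertainty
}

pi_choice = max(
    candidates, key=lambda name: probability_of_improvement(*candidates[name], f_best)
)
lcb_choice = min(candidates, key=lambda name: lower_confidence_bound(*candidates[name]))
print(f"PI picks {pi_choice}; confidence bound picks {lcb_choice}")
```

    PI prefers the point right next to the incumbent because any improvement counts equally, while the confidence bound is drawn to the region where the model is unsure.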

    Thompson Sampling

    Thompson sampling draws a function from the GP posterior and takes its argmin. The method is naturally diverse, since independent posterior samples select different points. Its principal advantage is trivial parallelization: for batch HPO of size k, k posterior samples can be drawn and their argmins evaluated simultaneously. It is widely used in production systems with many parallel workers.

    Knowledge Gradient (KG)

    EI is myopic: it considers only the immediate improvement. KG looks one step ahead and computes the expected best after an observation at x updates the GP. KG is more principled but also more expensive because it requires nested optimization. Empirically, it offers an improvement of roughly 10 to 20 percent for noisy problems. BoTorch’s qKnowledgeGradient is the standard implementation.

    Max-Value Entropy Search (MES)

    MES is an information-theoretic method: it selects x to maximize mutual information about the location of the optimum. The method is robust to noise and handles batches well, but it is more complex to implement (Wang and Jegelka, 2017). It is available as qMaxValueEntropy in BoTorch.

    Acquisition | Formula Intuition | Strength | Weakness | When to Use
    EI | Expected gain over best so far | Closed-form, balanced | Slight exploitation bias | Default; start here
    UCB | μ − β·σ | Tunable exploration, regret bounds | Need to set β | When EI underexplores
    PI | Probability of any improvement | Simplest | Stagnates, no exploration | Almost never
    Thompson | argmin of posterior sample | Trivial parallelization | Higher variance | Batch / parallel HPO
    KG | Look-ahead expected best | Robust to noise | Expensive to compute | Very noisy objectives
    MES | Mutual info about optimum | Strong batch behavior | Implementation complexity | Research / best-of-best

     

    Search Space Design

    This is the most underappreciated aspect of HPO. A GP can only optimize what is specified, and most HPO failures can be traced to a poorly defined search space.

    Log Scale for Multiplicative Parameters

    Learning rates, weight decay, and regularization coefficients have a fundamentally multiplicative effect: moving from 1e-3 to 1e-4 is comparable in magnitude to moving from 1e-4 to 1e-5. Log-uniform sampling is appropriate, and bounds of 1e-5 to 1e-1 are typical for the learning rate.
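
    A short sampling experiment makes the case for log-uniform bounds concrete. The helper `sample_log_uniform` is an illustrative name; the bounds are the learning-rate range quoted above.

```python
import math
import random

def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw uniformly in log10 space, then map back to the original scale."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))

rng = random.Random(0)
samples = [sample_log_uniform(1e-5, 1e-1, rng) for _ in range(10_000)]

# Count how many samples land in each decade: [1e-5, 1e-4), ..., [1e-2, 1e-1).
per_decade = [sum(10 ** e <= s < 10 ** (e + 1) for s in samples) for e in range(-5, -1)]
print("log-uniform samples per decade:", per_decade)

# Linear sampling over the same bounds almost never visits small learning rates.
linear = [rng.uniform(1e-5, 1e-1) for _ in range(10_000)]
share_top_decade_linear = sum(s >= 1e-2 for s in linear) / len(linear)
print(f"linear sampling: {share_top_decade_linear:.0%} of draws fall in [1e-2, 1e-1]")
```

    Log-uniform sampling spends roughly a quarter of the budget in each decade, whereas linear sampling concentrates about nine tenths of it in the largest one.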

    Linear Scale for Additive Parameters

    Layer sizes, the number of estimators, batch size, and the number of layers have additive and roughly linear effects.

    Integer Handling

    Most BO libraries treat integers as continuous and round at evaluation time. This works but creates plateaus in the objective. BoTorch’s Round input transform handles the integer case cleanly, and OneHotToNumeric plays the analogous role for one-hot-encoded categoricals. Optuna and scikit-optimize handle rounding automatically once the parameter is declared as integer.
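
    The plateau effect is easy to reproduce. In the sketch below, `objective` is a hypothetical loss over an integer depth parameter, and `relaxed_objective` is what a continuous optimizer effectively sees when rounding happens at evaluation time.

```python
def objective(max_depth: int) -> float:
    """Hypothetical loss that depends only on an integer hyperparameter."""
    return (max_depth - 7) ** 2 / 10.0

def relaxed_objective(x: float) -> float:
    """Continuous relaxation: the proposal is rounded just before evaluation."""
    return objective(round(x))

xs = [6.6, 6.8, 7.0, 7.2, 7.4]
values = [relaxed_objective(x) for x in xs]
for x, v in zip(xs, values):
    print(f"proposed {x:.1f} -> evaluated depth {round(x)} -> loss {v:.2f}")
```

    All five proposals evaluate the same configuration, so four of the five trials add no information. Declaring the parameter as an integer lets the library avoid such duplicates.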

    Categorical Handling

    Three approaches are available: (1) one-hot encode and treat as continuous, which functions adequately but incurs a slight efficiency loss; (2) use a custom kernel such as the categorical Hamming kernel, which is cleaner; or (3) use TPE, which handles categoricals natively. BoTorch’s MixedSingleTaskGP supports mixed continuous-categorical spaces.
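
    The first two options can be sketched in a few lines. The kernel below follows the common Hamming-style construction (exponential of the negative fraction of mismatching slots); it is an illustration of the idea under that assumption, not the exact implementation of any library, and the category lists are examples.

```python
import math

OPTIMIZERS = ["Adam", "SGD", "AdamW", "RMSprop"]

def one_hot(choice, categories):
    """Option 1: encode a categorical value as a 0/1 vector."""
    return [1.0 if c == choice else 0.0 for c in categories]

def hamming_kernel(a, b, lengthscale=1.0):
    """Option 2: similarity decays with the fraction of categorical slots that differ."""
    mismatches = sum(x != y for x, y in zip(a, b))
    return math.exp(-mismatches / (len(a) * lengthscale))

print(one_hot("SGD", OPTIMIZERS))
print(hamming_kernel(("Adam", "ReLU"), ("Adam", "ReLU")))   # identical configurations
print(hamming_kernel(("Adam", "ReLU"), ("SGD", "ReLU")))    # one of two slots differs
print(hamming_kernel(("Adam", "ReLU"), ("SGD", "GELU")))    # both slots differ
```

    The Hamming-style kernel treats every pair of distinct categories as equally far apart, which avoids the artificial geometry that one-hot vectors impose on a continuous kernel.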

    Conditional Spaces

    A dropout rate is meaningful only when dropout is enabled, and momentum is relevant only for SGD, not for Adam. TPE handles such structure natively and learns the conditional relationships. GP-based BO requires custom handling. The typical approach is to flatten to the union of possibilities and rely on the optimizer to learn that certain dimensions are irrelevant. For deeply conditional spaces, TPE is often preferable.

    Hyperparameter Type | Recommended Representation | Typical Range
    Learning rate | Log-uniform continuous | 1e-5 to 1e-1
    Weight decay / L2 | Log-uniform continuous | 1e-6 to 1e-2
    Dropout rate | Linear continuous | 0.0 to 0.5
    Hidden size / width | Log-uniform integer | 32 to 1024
    Number of layers | Linear integer | 2 to 12
    Batch size | Log-uniform integer (powers of 2) | 8 to 512
    Optimizer choice | Categorical | {Adam, SGD, AdamW, RMSprop}
    Activation | Categorical | {ReLU, GELU, SiLU, Mish}
    XGBoost max_depth | Linear integer | 3 to 12
    XGBoost subsample | Linear continuous | 0.5 to 1.0

     

    Caution: GPs extrapolate poorly outside their training data. If the best hyperparameter value lies on the boundary of the search space, this is a strong signal that the bounds were set too tight. The bounds should be widened and the optimization rerun.
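
    This check is simple to automate after a study finishes. The helper below is a sketch; the function name, the 2 percent tolerance, and the example values are assumptions made for illustration.

```python
import math

def params_on_boundary(best, bounds, log_scale=(), tol=0.02):
    """Return parameters whose best value lies within `tol` (fraction of range) of a bound."""
    flagged = []
    for name, value in best.items():
        low, high = bounds[name]
        if name in log_scale:
            # Measure the distance to the bound in log space for log-scaled parameters.
            value, low, high = math.log10(value), math.log10(low), math.log10(high)
        position = (value - low) / (high - low)
        if position <= tol or position >= 1.0 - tol:
            flagged.append(name)
    return flagged

bounds = {"learning_rate": (1e-5, 1e-1), "dropout": (0.0, 0.5)}
best = {"learning_rate": 0.098, "dropout": 0.22}
flagged = params_on_boundary(best, bounds, log_scale={"learning_rate"})
print("widen the bounds for:", flagged)
```

    Here the best learning rate sits at the upper edge of its range, so that bound should be widened before the study is rerun.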

    Full Python Implementation

    Four working examples are presented in order of increasing complexity. Any of them can serve as a starting template for a particular HPO task.

    Example 1: Tuning XGBoost with scikit-optimize

    scikit-optimize is the gentlest entry point: pip-installable, with a scikit-learn-style API and GP-based defaults. It is well suited to tabular machine learning.

    """
    GP-BO for XGBoost using scikit-optimize.
    pip install scikit-optimize xgboost scikit-learn matplotlib
    """
    import numpy as np
    from skopt import gp_minimize
    from skopt.space import Real, Integer
    from skopt.utils import use_named_args
    from skopt.plots import plot_convergence, plot_objective
    from sklearn.datasets import fetch_openml
    from sklearn.model_selection import cross_val_score
    from xgboost import XGBClassifier
    import matplotlib.pyplot as plt
    
    # Load a real tabular dataset
    data = fetch_openml("adult", version=2, as_frame=True)
    X = data.data.select_dtypes(include=[np.number]).fillna(0).values
    y = (data.target == ">50K").astype(int).values
    
    # Define the search space
    space = [
        Real(1e-3, 0.3, prior="log-uniform", name="learning_rate"),
        Integer(3, 12, name="max_depth"),
        Integer(50, 500, name="n_estimators"),
        Real(0.5, 1.0, name="subsample"),
        Real(0.5, 1.0, name="colsample_bytree"),
        Real(1e-6, 1.0, prior="log-uniform", name="reg_alpha"),
        Real(1e-6, 1.0, prior="log-uniform", name="reg_lambda"),
        Real(0.0, 5.0, name="gamma"),
    ]
    
    @use_named_args(space)
    def objective(**params):
        """We minimize negative ROC AUC (skopt minimizes)."""
        clf = XGBClassifier(
            **params,
            tree_method="hist",
            eval_metric="logloss",
            n_jobs=-1,
            random_state=42,
            verbosity=0,
        )
        score = cross_val_score(
            clf, X, y, cv=3, scoring="roc_auc", n_jobs=1
        ).mean()
        return -score
    
    # Run GP-BO with EI acquisition
    result = gp_minimize(
        objective,
        space,
        n_calls=50,            # total trials
        n_initial_points=10,   # random seed trials
        acq_func="EI",         # Expected Improvement
        random_state=42,
        verbose=True,
    )
    
    print(f"Best AUC: {-result.fun:.4f}")
    print("Best hyperparameters:")
    for name, val in zip([s.name for s in space], result.x):
        print(f"  {name}: {val}")
    
    # Diagnostics: running-best curve
    fig, ax = plt.subplots(figsize=(7, 5))
    plot_convergence(result, ax=ax)
    ax.set_title("Convergence")
    fig.tight_layout()
    fig.savefig("xgb_bo_convergence.png", dpi=120)
    
    # plot_objective builds its own grid of partial-dependence axes,
    # so it is saved as a separate figure.
    plot_objective(result)
    plt.savefig("xgb_bo_objective.png", dpi=120)
    

    The procedure runs ten random seed trials followed by forty GP-guided trials using Expected Improvement. The plot_convergence function displays the running best score against the trial number, the canonical view of optimizer progress; overlaying the same curve from a random-search run shows whether BO is actually ahead. The plot_objective function displays partial-dependence plots for each hyperparameter and reveals which dimensions actually mattered.

    On the Adult dataset with fifty trials, GP-BO typically improves on the fifty-trial best from random search by 0.5 to 1.5 percent AUC. The gain is modest in isolation but valuable because it requires no additional trial budget and is reproducible.

    Example 2: Tuning a PyTorch CNN with BoTorch

    BoTorch is the appropriate next step once scikit-optimize becomes restrictive. It is PyTorch-native, GPU-accelerated, and built on GPyTorch (the same library used in the GP fundamentals post). For research and production deep-learning HPO, it is the established standard.

    """
    GP-BO for a PyTorch CNN using BoTorch.
    pip install botorch gpytorch torch torchvision
    """
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    from botorch.models import SingleTaskGP
    from botorch.fit import fit_gpytorch_mll
    from botorch.acquisition import qExpectedImprovement
    from botorch.optim import optimize_acqf
    from gpytorch.mlls import ExactMarginalLogLikelihood
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Search space: [log_lr, log_wd, dropout, log_hidden]
    # Bounds in normalized space [0,1] mapped to actual ranges below.
    BOUNDS = torch.tensor(
        [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
        device=device, dtype=torch.double,
    )
    
    def unnormalize(x):
        """Map [0,1]^4 to actual hyperparameter ranges."""
        log_lr   = -5.0 + x[..., 0] * 3.0   # 1e-5 to 1e-2
        log_wd   = -6.0 + x[..., 1] * 4.0   # 1e-6 to 1e-2
        dropout  = x[..., 2] * 0.5          # 0 to 0.5
        log_hidden = 5.0 + x[..., 3] * 4.0  # 32 to 512 (log2)
        return {
            "lr": float(10 ** log_lr),
            "wd": float(10 ** log_wd),
            "dropout": float(dropout),
            "hidden": int(2 ** round(log_hidden.item())),
        }
    
    class SmallCNN(nn.Module):
        def __init__(self, hidden, dropout):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(32 * 7 * 7, hidden), nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, 10),
            )
        def forward(self, x):
            return self.net(x)
    
    # Load FashionMNIST (small enough to iterate quickly)
    tfm = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.FashionMNIST("./data", train=True, download=True, transform=tfm)
    val_ds = datasets.FashionMNIST("./data", train=False, download=True, transform=tfm)
    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=512, num_workers=2)
    
    def train_eval(params, epochs=3):
        """Train CNN with given hyperparams, return validation accuracy."""
        model = SmallCNN(params["hidden"], params["dropout"]).to(device)
        opt = optim.AdamW(model.parameters(), lr=params["lr"], weight_decay=params["wd"])
        crit = nn.CrossEntropyLoss()
        for _ in range(epochs):
            model.train()
            for xb, yb in train_loader:
                xb, yb = xb.to(device), yb.to(device)
                opt.zero_grad()
                crit(model(xb), yb).backward()
                opt.step()
        # Evaluate
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                preds = model(xb).argmax(1)
                correct += (preds == yb).sum().item()
                total += yb.size(0)
        return correct / total
    
    # Initial random trials
    N_INIT = 8
    torch.manual_seed(0)
    X_obs = torch.rand(N_INIT, 4, device=device, dtype=torch.double)
    Y_obs = torch.tensor(
        [[train_eval(unnormalize(x))] for x in X_obs],
        device=device, dtype=torch.double,
    )
    print(f"Init complete. Best so far: {Y_obs.max().item():.4f}")
    
    # BO loop
    N_BO_ITERS = 20
    for it in range(N_BO_ITERS):
        # Fit GP (BoTorch handles standardization, kernel, MLL)
        gp = SingleTaskGP(X_obs, Y_obs)
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        fit_gpytorch_mll(mll)
    
        # qEI acquisition (q=1 for sequential)
        acq = qExpectedImprovement(model=gp, best_f=Y_obs.max())
        candidate, _ = optimize_acqf(
            acq_function=acq,
            bounds=BOUNDS,
            q=1,
            num_restarts=10,
            raw_samples=512,
        )
        # Evaluate candidate
        new_y = train_eval(unnormalize(candidate.squeeze(0)))
        X_obs = torch.cat([X_obs, candidate], dim=0)
        Y_obs = torch.cat([Y_obs, torch.tensor([[new_y]], device=device, dtype=torch.double)], dim=0)
        print(f"Iter {it+1}: y={new_y:.4f} | best={Y_obs.max().item():.4f}")
    
    best_idx = Y_obs.argmax()
    print("\nBest hyperparameters:")
    print(unnormalize(X_obs[best_idx]))
    

    Several details merit note.

    • The implementation operates in normalized [0,1]d space and unnormalizes before training. BoTorch strongly prefers normalized inputs.
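    • BoTorch’s SingleTaskGP has historically defaulted to a Matérn 5/2 kernel; recent releases default to an RBF kernel with dimension-scaled lengthscale priors instead. In both cases the kernel uses automatic relevance determination, which learns per-dimension lengthscales.
    • optimize_acqf draws 512 raw samples, selects ten starting points from them, and runs a gradient-based L-BFGS optimization from each. Multi-start makes it likely, though not guaranteed, that the global optimum of the acquisition function is found.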
    • The loop executes twenty-eight trials in total (eight random plus twenty BO). On a single GPU with three-epoch FashionMNIST, this takes approximately thirty minutes.

    Example 3: Multi-Objective BO with qNEHVI

    Real-world deployment depends on more than accuracy: latency and memory also matter. Multi-objective BO produces the entire Pareto frontier between competing objectives.

    """
    Multi-objective HPO: maximize accuracy AND minimize latency.
    Returns the Pareto frontier instead of a single best.
    """
    import time
    import torch
    from botorch.models import SingleTaskGP, ModelListGP
    from botorch.fit import fit_gpytorch_mll
    from botorch.acquisition.multi_objective.monte_carlo import qNoisyExpectedHypervolumeImprovement
    from botorch.optim import optimize_acqf
    from botorch.utils.multi_objective.box_decompositions.dominated import DominatedPartitioning
    from gpytorch.mlls import ExactMarginalLogLikelihood
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.double
    
    # Search space: same 4-dim CNN tuning problem
    BOUNDS = torch.tensor([[0.0]*4, [1.0]*4], device=device, dtype=DTYPE)
    
    # Two objectives: accuracy (maximize) and -latency_ms (maximize, since BoTorch maximizes)
    REF_POINT = torch.tensor([0.5, -200.0], device=device, dtype=DTYPE)  # worst-case bounds
    
    def objective_2d(x_norm):
        """Returns [accuracy, -latency_ms]."""
        params = unnormalize(x_norm)  # reuse from Example 2
        acc = train_eval(params, epochs=3)
        # Measure latency on a batch
        model = SmallCNN(params["hidden"], params["dropout"]).to(device).eval()
        dummy = torch.randn(64, 1, 28, 28, device=device)
        # Warm up
        with torch.no_grad():
            _ = model(dummy)
        if device == "cuda":
            torch.cuda.synchronize()  # wait for queued GPU work before timing
        t0 = time.perf_counter()
        with torch.no_grad():
            for _ in range(20):
                _ = model(dummy)
        if device == "cuda":
            torch.cuda.synchronize()  # ensure all forward passes have finished
        latency_ms = (time.perf_counter() - t0) * 1000 / 20
        return torch.tensor([acc, -latency_ms], device=device, dtype=DTYPE)
    
    # Initial design
    N_INIT = 10
    torch.manual_seed(0)
    X_obs = torch.rand(N_INIT, 4, device=device, dtype=DTYPE)
    Y_obs = torch.stack([objective_2d(x) for x in X_obs])
    
    # Multi-objective BO loop
    for it in range(20):
        # Fit independent GPs for each objective
        models = [SingleTaskGP(X_obs, Y_obs[:, i:i+1]) for i in range(2)]
        model_list = ModelListGP(*models)
        for m in models:
            mll = ExactMarginalLogLikelihood(m.likelihood, m)
            fit_gpytorch_mll(mll)
    
        # qNEHVI acquisition
        acq = qNoisyExpectedHypervolumeImprovement(
            model=model_list,
            ref_point=REF_POINT,
            X_baseline=X_obs,
            prune_baseline=True,
        )
        candidate, _ = optimize_acqf(
            acq_function=acq, bounds=BOUNDS,
            q=2, num_restarts=10, raw_samples=512,
        )
        new_y = torch.stack([objective_2d(x) for x in candidate])
        X_obs = torch.cat([X_obs, candidate])
        Y_obs = torch.cat([Y_obs, new_y])
        # Compute hypervolume
        hv = DominatedPartitioning(ref_point=REF_POINT, Y=Y_obs).compute_hypervolume()
        print(f"Iter {it+1}: HV={hv.item():.3f} | n_obs={len(X_obs)}")
    
    # Extract Pareto frontier
    from botorch.utils.multi_objective.pareto import is_non_dominated
    mask = is_non_dominated(Y_obs)
    pareto = Y_obs[mask]
    print(f"\nPareto frontier: {len(pareto)} points")
    for acc, neg_lat in pareto.cpu().numpy():
        print(f"  acc={acc:.4f}, latency={-neg_lat:.2f}ms")
    

    The output is not a single best configuration but a frontier of Pareto-optimal configurations. For each point on this frontier, accuracy cannot be improved without sacrificing latency, and vice versa. The hypervolume metric quantifies the size of the dominated region; larger values are better.
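
    Pareto dominance itself needs no library support. The sketch below filters a small list of hypothetical (accuracy, negative latency) pairs, using the same sign convention as the example above so that both objectives are maximized.

```python
def pareto_front(points):
    """Return the points not dominated by any other point (all objectives maximized)."""
    def dominates(a, b):
        # a dominates b if it is at least as good everywhere and strictly better somewhere.
        return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))
    return [p for p in points if not any(dominates(q, p) for q in points)]

# Hypothetical trial results: (accuracy, -latency_ms).
trials = [(0.90, -12.0), (0.92, -20.0), (0.89, -15.0), (0.92, -25.0), (0.85, -8.0)]
front = pareto_front(trials)
for acc, neg_latency in front:
    print(f"acc={acc:.2f}, latency={-neg_latency:.1f}ms")
```

    Two of the five trials drop out because another trial is at least as accurate and at least as fast. The remaining three represent genuinely different trade-offs.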

    Example 4: Optuna with BoTorch Sampler

    Optuna is the most widely adopted HPO library, and an underappreciated feature is that its default TPE sampler can be replaced with a GP-based BoTorch sampler in a single line of code.

    """
    Optuna with GP (BoTorch) sampler vs default TPE.
    pip install optuna "optuna-integration[botorch]" xgboost scikit-learn matplotlib
    """
    import optuna
    from optuna.samplers import TPESampler
    from optuna.integration import BoTorchSampler
    import xgboost as xgb
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import cross_val_score
    import numpy as np
    
    X, y = load_breast_cancer(return_X_y=True)
    
    def objective(trial):
        params = {
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
            "max_depth": trial.suggest_int("max_depth", 3, 12),
            "n_estimators": trial.suggest_int("n_estimators", 50, 500),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "reg_alpha": trial.suggest_float("reg_alpha", 1e-6, 1.0, log=True),
            "reg_lambda": trial.suggest_float("reg_lambda", 1e-6, 1.0, log=True),
        }
        clf = xgb.XGBClassifier(
            **params, tree_method="hist", eval_metric="logloss",
            n_jobs=-1, random_state=42, verbosity=0,
        )
        return cross_val_score(clf, X, y, cv=5, scoring="roc_auc").mean()
    
    # A: TPE sampler (Optuna default)
    study_tpe = optuna.create_study(
        direction="maximize",
        sampler=TPESampler(seed=42, n_startup_trials=10),
    )
    study_tpe.optimize(objective, n_trials=50, show_progress_bar=True)
    
    # B: BoTorch (GP) sampler
    study_gp = optuna.create_study(
        direction="maximize",
        sampler=BoTorchSampler(n_startup_trials=10, seed=42),
    )
    study_gp.optimize(objective, n_trials=50, show_progress_bar=True)
    
    print(f"TPE best AUC: {study_tpe.best_value:.4f}")
    print(f"GP-BO best AUC: {study_gp.best_value:.4f}")
    
    # Visualize convergence
    import matplotlib.pyplot as plt
    def running_best(trials):
        # Skip pruned or failed trials, whose value is None
        vals = [t.value for t in trials if t.value is not None]
        return np.maximum.accumulate(vals)
    
    plt.figure(figsize=(10, 5))
    plt.plot(running_best(study_tpe.trials), label="TPE", linewidth=2)
    plt.plot(running_best(study_gp.trials), label="GP-BO (BoTorch)", linewidth=2)
    plt.xlabel("Trial")
    plt.ylabel("Best AUC so far")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.title("TPE vs GP-BO convergence")
    plt.savefig("tpe_vs_gp.png", dpi=120, bbox_inches="tight")
    

    Empirically, for smaller search spaces (no more than ten dimensions) and noisy objectives, GP-BO converges faster than TPE in trial count. For larger spaces or those with conditional dimensions, TPE closes the gap. The principal benefit of Optuna is the framework: pruning, distributed trials, a web dashboard, and straightforward sampler substitution.

    Tip: For an end-to-end HPO orchestration pipeline that queues trials, distributes them to workers, and persists results, Optuna pairs naturally with Apache Airflow. Each Airflow task corresponds to one trial, and the study state lives in a shared database.

    Multi-Fidelity and Parallel HPO

    A key fact of modern deep learning is that full training is expensive while partial training is informative. A 100-epoch run is roughly ten times more expensive than a 10-epoch run, yet the 10-epoch result often correlates well with the 100-epoch outcome. Multi-fidelity HPO exploits this relationship.

    BOHB (Falkner et al., 2018)

    BOHB combines Hyperband (early stopping based on partial training curves) with BO (informed sampling rather than random). Hyperband decides when to terminate a trial; BO decides which configurations to try at each rung. Empirically the combination outperforms either method alone for deep-learning HPO.

    BOHB uses TPE rather than a GP for the BO component because the sampling-based density model handles the high-dimensional, conditional spaces of neural-network architectures well. GP variants exist (Falkner discusses the trade-offs), but TPE is the default.

    Multi-Fidelity BO (MFBO)

    MFBO adds fidelity, such as training epochs or dataset fraction, as an additional dimension in the GP. The GP learns the relationship between fidelity and final performance, and the acquisition function selects both x and a fidelity, balancing information gain against compute cost. BoTorch provides qMultiFidelityKnowledgeGradient for this purpose.

    Asynchronous BO (Kriging Believer)

    For batch parallelism, while a trial is running, its result is fantasized using the GP posterior mean. The hallucinated observation is added to the training set, a temporary GP is fitted, and the next trial is selected on the assumption that the in-flight trial will reach its predicted value. The observation is corrected when the trial finishes. This decouples scheduling from observations and enables many parallel workers without serializing on the GP fit.

    Trust Region BO (TuRBO)

    Eriksson et al. (2019) proposed TuRBO for high-dimensional HPO (50 or more dimensions). The method maintains a small trust region around the current best, fits a local GP, and optimizes within it. The trust region expands after successful steps and contracts after unsuccessful ones. The approach effectively replaces one global model over a high-dimensional space with a sequence of local models over small regions, where a GP remains accurate. A reference implementation is available as a BoTorch tutorial.
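
    The expand-and-contract bookkeeping can be sketched as a small state object. The sketch assumes the common formulation in which the region doubles after a run of consecutive successes and halves after a run of consecutive failures; the class name and the numeric defaults are illustrative assumptions and should be tuned per problem.

```python
from dataclasses import dataclass

@dataclass
class TrustRegion:
    length: float = 0.8          # side length in the normalized [0, 1] search space
    length_min: float = 0.5 ** 7
    length_max: float = 1.6
    success_tol: int = 3         # consecutive improvements before expanding
    failure_tol: int = 5         # consecutive non-improvements before shrinking
    successes: int = 0
    failures: int = 0

    def update(self, improved: bool) -> None:
        """Record one trial outcome and resize the region if a tolerance is reached."""
        if improved:
            self.successes, self.failures = self.successes + 1, 0
        else:
            self.successes, self.failures = 0, self.failures + 1
        if self.successes >= self.success_tol:
            self.length = min(2.0 * self.length, self.length_max)
            self.successes = 0
        elif self.failures >= self.failure_tol:
            self.length /= 2.0
            self.failures = 0

    @property
    def needs_restart(self) -> bool:
        """A region that has collapsed is discarded and restarted elsewhere."""
        return self.length < self.length_min

region = TrustRegion()
for outcome in [True, True, True] + [False] * 5:
    region.update(outcome)
    print(f"improved={outcome!s:5} -> length={region.length:.2f}")
```

    Candidates for the next trial are then drawn only from a box of this side length centered on the current best point.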

    Key Takeaway: With eight or more GPUs and slow training, BOHB typically outperforms vanilla GP-BO. With one GPU and up to twenty hyperparameters, vanilla GP-BO with Expected Improvement offers the best return on investment. With more than fifty hyperparameters, characteristic of neural architecture search, TuRBO or evolutionary methods are appropriate.

    Tools Comparison

    [Figure: HPO Tools Landscape. X-axis: simplicity to research flexibility; Y-axis: sample efficiency (low, med, high). Tools plotted by backend: skopt (GP-based), Optuna (TPE / GP), Ax (GP-based), BoTorch (GP-based), HyperOpt (TPE), Ray Tune (mixed), W&B (mixed), Vizier (GP-based), SageMaker (managed).]

    Tool | Default Backend | Multi-Objective | Constraints | Conditional Spaces | Best For
    Optuna | TPE (GP via BoTorch) | Yes | Limited | Native | Production engineering
    Ax | GP (BoTorch) | Yes (Pareto) | Yes | Yes | Adaptive experimentation
    BoTorch | GP (PyTorch) | Yes | Yes | Custom | Research, custom algorithms
    scikit-optimize | GP / RF | No | No | No | Quickstart, sklearn integration
    HyperOpt | TPE | Limited | No | Native | Mature distributed TPE
    Ray Tune | Pluggable (BO/TPE/PBT/ASHA) | Yes (via Ax) | Via backend | Via backend | Distributed orchestration
    W&B Sweeps | Bayes / Random / Grid | No | No | Limited | Experiment tracking integration
    Vertex AI Vizier | GP (Google) | Yes | Yes | Yes | Managed, GCP-native
    SageMaker AMT | GP / Hyperband | No | No | Limited | Managed, AWS-native

    Practical Recommendation

    For the majority of practitioners and HPO problems, the following guidance is appropriate.

    • Begin with Optuna. The API is the cleanest, the defaults are sensible, the dashboard is effective, and the BoTorch sampler can be substituted when TPE becomes inadequate.
    • Move to Ax when multi-objective optimization with constraints is required, or when a higher-level service-style API for ongoing experimentation is desirable.
    • Use BoTorch directly when implementing custom acquisition functions, conducting research, or requiring fine-grained control over GP fitting through custom kernels, priors, or multi-task models.
    • Use scikit-optimize for one-off tabular machine-learning tuning where simplicity outweighs power.
    • Use Ray Tune when distributed orchestration is the bottleneck and hundreds of workers require scheduling.

    Real-World Case Studies

    Google Vizier

    Vizier is Google’s internal Bayesian Optimization service, used to tune systems ranging from ad models to ranking systems to LLM training pipelines. The original 2017 paper reported thousands of studies per day across the company. The default algorithm is GP-based BO with batched parallel evaluation. Vertex AI Vizier exposes the service externally on Google Cloud Platform.

    Meta’s Ax and BoTorch

    Meta open-sourced Ax and BoTorch from work on optimizing ranking models. Published results indicate ranking-quality improvements exceeding 40 percent relative to random search, with substantially fewer trials required. The same stack is used to tune hyperparameters in video-encoding research, ad-auction simulators, and infrastructure scheduling.

    AlphaGo and AlphaFold

    DeepMind has used Bayesian optimization in inner loops for many years. AlphaGo reportedly used GP-based BO to tune MCTS hyperparameters and training schedules. AlphaFold 2’s training pipeline used multi-fidelity BO for architecture-related hyperparameters where each evaluation was prohibitively expensive.

    Drug Discovery and Protein Design

    Beyond machine-learning hyperparameters, GP-BO is the standard tool for real experimental design: which molecules to synthesize next, which protein variants to screen, and which experimental conditions to test. Each trial requires days of laboratory time and thousands of dollars in reagents, making sample efficiency essential.

    Key Takeaway: GP-based BO is not a research curiosity. It runs in production at scale at every major technology company and at most pharmaceutical firms. The supporting tools (BoTorch, Ax, Optuna, Vizier) reflect hundreds of person-years of engineering. Teams that do not use BO for HPO are likely forgoing accuracy gains of 0.5 to 5 percent.

    Practical Guide and Pitfalls

    Initial Design: Avoid a Cold Start

    Five to ten random trials should be run before BO begins. Without seed observations, the GP has no signal and the acquisition function selects the geometric center of the search box. The rule of thumb is n_init = max(5, 2 · d), where d is the search-space dimension.
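
    The rule of thumb translates directly into code. The sketch below draws the seed trials uniformly at random; the search space, its parameter names, and the helper names are hypothetical, and a quasi-random design such as a Sobol sequence is a common alternative.

```python
import random


def n_initial_trials(dim):
    """Rule of thumb from the text: n_init = max(5, 2 * d)."""
    return max(5, 2 * dim)


def initial_design(bounds, seed=0):
    """Draw n_init uniform random configurations inside the given bounds."""
    rng = random.Random(seed)
    return [
        {name: rng.uniform(lo, hi) for name, (lo, hi) in bounds.items()}
        for _ in range(n_initial_trials(len(bounds)))
    ]


# Hypothetical 3-dimensional search space (log10 scales for lr and weight decay).
space = {
    "lr_log10": (-5.0, -1.0),
    "dropout": (0.0, 0.5),
    "weight_decay_log10": (-6.0, -2.0),
}
seed_trials = initial_design(space)
print(len(seed_trials), seed_trials[0])
```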

    Parallelize Four to Eight Trials per BO Step

    Modern HPO at scale uses batch acquisition functions (qEI, qNEI, qNEHVI) to propose four to eight candidates per BO iteration. This represents an effective compromise: enough parallelism to utilize a multi-GPU node, but not so much that GP information gain saturates within a batch.

    Stopping Criteria

    • Trial budget (the most common): for example, running 100 trials. Simple and reproducible.
    • Time budget: for example, running for 24 hours. Useful in production where wall-clock time matters more than trial count.
    • Convergence: stop when the running-best improvement is less than ε for k consecutive trials. This criterion is risky in isolation because BO can stall before identifying the global optimum.
    • Combination: stop when the trial budget is exhausted or when there has been no improvement for k consecutive trials, whichever comes first. A practical default.
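
    One way to combine a trial budget with a convergence check is sketched below. It assumes a maximization objective and stops at whichever condition triggers first; the function name and the default values for `patience` and `epsilon` are illustrative choices.

```python
def should_stop(scores, trial_budget=100, patience=20, epsilon=1e-3):
    """Return True when the budget is spent or the running best has stalled.

    scores: objective values in trial order (higher is better).
    """
    if len(scores) >= trial_budget:
        return True
    if len(scores) <= patience:
        return False  # not enough history to judge convergence
    best_before = max(scores[:-patience])
    best_recent = max(scores[-patience:])
    return best_recent - best_before < epsilon


improving = [0.1 * i for i in range(30)]
stalled = [0.5] + [0.4] * 20
print(should_stop(improving), should_stop(stalled))
```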

    Reproducibility

    All random seeds should be set: numpy, torch, the BO library, and the model-training loop. Every (config, score, wallclock, seed) tuple should be logged. The most common way to lose value from HPO is to be unable to reproduce the best configuration. Pairing the optimizer with experiment tracking such as W&B or MLflow is sufficient.

    Debugging GP Fits

    If BO recommendations appear pathological (clustered in a corner, or oscillating widely), the following checks are appropriate.

    • Lengthscales: check that the fitted lengthscales are reasonable. Very small values indicate that the GP treats observations as nearly uncorrelated and fits each point in isolation; very large values indicate that it considers the function nearly constant.
    • Output standardization: BoTorch handles standardization internally; some libraries do not. Standardizing y manually when in doubt is prudent.
    • Input normalization: inputs should always be normalized to [0, 1]^d before being passed to a GP.
    • Noise: if observation noise is too low, refit with a slightly higher noise prior.
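
    The two scaling checks can be performed by hand when a library's behavior is unclear. The helpers below are a minimal sketch with hypothetical names; the bounds and scores are made-up illustration values.

```python
import statistics


def normalize_inputs(X, bounds):
    """Map each row of X from its search-space bounds to the unit cube [0, 1]^d."""
    return [
        [(x - lo) / (hi - lo) for x, (lo, hi) in zip(row, bounds)]
        for row in X
    ]


def standardize_outputs(y):
    """Shift and scale y to zero mean and unit (sample) standard deviation."""
    mean = statistics.mean(y)
    std = statistics.stdev(y) or 1.0  # guard against a constant y
    return [(v - mean) / std for v in y], mean, std


bounds = [(0.0, 0.5), (32, 512)]  # e.g. dropout and hidden width
X = [[0.0, 32], [0.25, 272], [0.5, 512]]
y = [0.70, 0.80, 0.90]

X_unit = normalize_inputs(X, bounds)
y_std, y_mean, y_scale = standardize_outputs(y)
print(X_unit, y_std)
```

    The returned `y_mean` and `y_scale` are needed to map GP predictions back to the original units.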

    High-Dimensional Pitfalls

    Beyond approximately twenty dimensions, vanilla GPs degrade. The symptoms are that BO no longer outperforms random search and that GP lengthscales reach the boundary of their allowed range. Possible remedies include TuRBO (trust regions), random embeddings (REMBO), dimensionality reduction by PCA on a random sample, or a switch to evolutionary methods. For further discussion of high-dimensional optimization, see the companion posts on genetic algorithms and mixed-integer programming.

    Constrained BO

    Infeasible configurations should not consume evaluations. If a model has a memory budget, latency budget, or hardware constraint, the constraint should be modeled as a separate GP and used with a constrained acquisition function such as expected feasible improvement or qNEHVI with constraints in BoTorch. The savings in trial budget can be substantial.
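
    Expected feasible improvement multiplies ordinary Expected Improvement by the probability that the constraint is satisfied. The sketch below assumes a maximization objective, a constraint of the form c(x) ≤ limit, and independent GPs for the objective and the constraint; the posterior means and standard deviations passed in are hypothetical values that a fitted model would supply.

```python
import math


def normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def expected_improvement(mean, std, best):
    """Closed-form EI for maximization, given the GP posterior at one point."""
    if std <= 0.0:
        return max(mean - best, 0.0)
    z = (mean - best) / std
    return (mean - best) * normal_cdf(z) + std * normal_pdf(z)


def prob_feasible(c_mean, c_std, limit):
    """P(c(x) <= limit) under the constraint GP's Gaussian posterior."""
    if c_std <= 0.0:
        return 1.0 if c_mean <= limit else 0.0
    return normal_cdf((limit - c_mean) / c_std)


def expected_feasible_improvement(mean, std, best, c_mean, c_std, limit):
    return expected_improvement(mean, std, best) * prob_feasible(c_mean, c_std, limit)


# Candidate A: promising accuracy, memory estimate right at the 16 GB limit.
efi_a = expected_feasible_improvement(0.90, 0.02, best=0.90,
                                      c_mean=16.0, c_std=2.0, limit=16.0)
# Candidate B: same accuracy posterior, memory estimate far above the limit.
efi_b = expected_feasible_improvement(0.90, 0.02, best=0.90,
                                      c_mean=30.0, c_std=2.0, limit=16.0)
print(efi_a, efi_b)
```

    Candidate B scores close to zero despite an identical accuracy posterior, so the optimizer does not spend a trial on it.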

    The Cold-Start Problem

    When tuning a new but related task, prior trials from similar tasks are typically available. Transferable BO initializes the GP using observations from prior studies (with appropriate weighting), providing an informative prior in place of a cold start. The method is available in Ax (multi-task BO) and in the academic literature.

    Trial Replication and Noise

    For genuinely noisy objectives such as reinforcement-learning rewards or classification on small datasets, the best candidates should be replicated to reduce noise. The Central Limit Theorem guide covers the underlying mathematics: averaging k noisy observations reduces the standard error by a factor of √k. Allocating 20 percent of the trial budget to replication yields a substantially more reliable best configuration.
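
    The √k relationship can be checked numerically. In the sketch below, `noisy_accuracy` is a hypothetical stand-in for a full training run whose score has Gaussian noise with standard deviation 0.02; the helper names are illustrative.

```python
import math
import random
import statistics


def standard_error(noise_std, k):
    """Standard error of the mean of k independent replicates."""
    return noise_std / math.sqrt(k)


def noisy_accuracy(config, rng):
    """Hypothetical noisy evaluation: true score 0.90, noise std 0.02."""
    return 0.90 + rng.gauss(0.0, 0.02)


def replicated_score(evaluate, config, k, seed=0):
    """Average k replicates and report the estimated standard error."""
    rng = random.Random(seed)
    runs = [evaluate(config, rng) for _ in range(k)]
    return statistics.mean(runs), statistics.stdev(runs) / math.sqrt(k)


print(standard_error(0.02, 1), standard_error(0.02, 4), standard_error(0.02, 25))
mean_score, estimated_se = replicated_score(noisy_accuracy, {"lr": 0.1}, k=25)
print(mean_score, estimated_se)
```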

    Caution: The most common HPO failure mode is not the wrong method but the wrong objective. If the validation loss is not a good proxy for test loss (a small validation set, data leakage, or distribution shift), no optimizer can compensate. The evaluation pipeline should be audited before tuning begins. Cross-validation, held-out validation, and techniques covered in the semi-supervised learning guide matter more than the choice of optimizer.

    Frequently Asked Questions

    Why is GP-based BO better than random search for HPO?

    GP-based BO uses information from prior trials to pick the next one. Random search throws that information away. On benchmark HPO problems with 5–20 hyperparameters, GP-BO typically reaches the same accuracy as random search using 3–10× fewer trials. When each trial costs hours of GPU time, that compounds into significant compute savings—typically 60–90% of the budget.

    When does TPE beat GP-based BO?

    Three regimes: (1) high-dimensional spaces (30+ hyperparameters) where GPs degrade, (2) heavily conditional spaces (this hyperparameter only exists if that one is true) where TPE handles structure natively, (3) when you need very fast wall-clock per BO iteration because TPE’s sampling is cheaper than GP fitting + acquisition optimization. For most “normal” HPO with ≤20 dims, GP-BO is more sample-efficient.

    How many initial random trials should I run before starting BO?

    Rule of thumb: n_init = max(5, 2 · d) where d is the search space dimension. For a 4-dimensional space, 8–10 random trials. For 10 dimensions, 20 random trials. Without enough seeds, the GP has no signal and BO collapses to picking the box center repeatedly.

    Can GP-BO handle categorical hyperparameters like activation function or optimizer choice?

    Yes, three approaches: (1) one-hot encode and treat as continuous (works, slight efficiency loss), (2) use a custom kernel like Hamming distance for categoricals (cleaner, BoTorch’s MixedSingleTaskGP does this), (3) switch to TPE which handles categoricals natively. For 1–2 categorical dimensions, one-hot is fine. For many categoricals, use TPE or a properly mixed kernel.

    BoTorch vs Optuna—which should I use?

    For most production HPO, start with Optuna: cleaner API, better tooling (dashboard, study persistence, distributed trials), and you can swap in the BoTorch sampler for GP-BO when needed. Use BoTorch directly when you need custom acquisition functions, multi-task GPs, advanced features (qNEHVI, qKG, MES), or are doing research. Many production setups use both: Optuna for orchestration, BoTorch sampler under the hood.

    References and Further Reading

    • Bergstra & Bengio (2012). Random Search for Hyper-Parameter Optimization. JMLR. The paper that established random search as the baseline.
    • Frazier (2018). A Tutorial on Bayesian Optimization. arXiv:1807.02811. The clearest intro to BO mathematics.
    • Falkner et al. (2018). BOHB: Robust and Efficient Hyperparameter Optimization at Scale. ICML. The BOHB paper.
    • Eriksson et al. (2019). Scalable Global Optimization via Local Bayesian Optimization. NeurIPS. TuRBO.
    • Wang & Jegelka (2017). Max-value Entropy Search for Efficient Bayesian Optimization. ICML.
    • BoTorch documentation: official docs for Meta’s Bayesian optimization library.
    • Optuna documentation: practical HPO framework with TPE and GP samplers.
    • scikit-optimize documentation: sklearn-style GP and forest-based BO.
    • Ax (Adaptive Experimentation Platform): Meta’s higher-level wrapper around BoTorch.

  • Gaussian Processes Explained: Bayesian Regression with Uncertainty

    Summary

    What this post covers: A first-principles tour of Gaussian Processes (GPs) for regression and Bayesian optimization, with the underlying math, a from-scratch NumPy implementation, a production GPyTorch workflow, kernel design, and the scalability tricks that push GPs past their classical O(n^3) limit.

    Key insights:

    • A Gaussian Process is a nonparametric Bayesian model that returns both a mean prediction and a calibrated confidence interval at every input. Uncertainty grows automatically in regions where training data is sparse, which is precisely the behavior a trustworthy model should exhibit.
    • The kernel constitutes the entire model. It encodes assumptions about smoothness, periodicity, or linearity, and a Matérn-5/2 kernel with Automatic Relevance Determination (ARD), together with per-dimension input standardization, is an appropriate default in practice.
    • Hyperparameters such as lengthscales, output scale, and noise variance are learned by maximizing the log marginal likelihood, which automatically penalizes overly complex models. Occam’s razor follows from the mathematics rather than being applied externally.
    • GPs are particularly effective for small-to-medium, sample-expensive problems such as Bayesian optimization of hyperparameters, surrogate modeling of simulations, drug discovery, and geostatistics, where neural networks tend to overfit and calibrated uncertainty materially affects the resulting decisions.
    • The O(n^3) scaling barrier is no longer a hard ceiling. Inducing-point methods such as SVGP, BBMM in GPyTorch, and Deep Kernel Learning allow modern GPs to handle 10^5 to 10^6 points and high-dimensional structured inputs.

    Main topics: The Central Idea: Distributions Over Functions, The Underlying Mathematics, Kernels: The Heart of Gaussian Processes, Hyperparameter Learning and the Marginal Likelihood, Full Python Implementation, Applications: Where GPs Excel, Scalability: Breaking the O(n^3) Wall, Gaussian Processes vs. Alternatives, Common Pitfalls and How to Avoid Them, Related Reading, Frequently Asked Questions, Conclusion and Further Reading.

    A neural network predicts a stock price of $127.50. A Gaussian Process predicts $125 to $130 with 95 percent confidence. The distinction is not one of precision but of recognizing the limits of one’s knowledge. Gaussian Processes are the principal mechanism by which machine learning models can express well-calibrated uncertainty.

    This characteristic explains why Gaussian Processes (GPs) have quietly become indispensable in domains where uncertainty matters more than raw predictive power: Bayesian optimization of hyperparameters, surrogate modeling of expensive physics simulations, geostatistics, drug discovery, robotic control, and active learning. A neural network returns a single number. A Gaussian Process returns a probability distribution over possible answers—a mean prediction accompanied by a principled estimate of its reliability.

    The remainder of this article examines Gaussian Processes from first principles. The mathematics is presented accessibly but rigorously, a GP is constructed from scratch with NumPy, and the implementation is then extended to production-grade code in GPyTorch. The discussion covers kernels, hyperparameter learning, Bayesian optimization, classification, and the scalability techniques that allow modern GPs to handle hundreds of thousands of points. Readers will gain an understanding of not only how to use a GP, but when and why to do so.

    The Central Idea: Distributions Over Functions

    Most machine learning models parameterize a function. Linear regression selects two numbers (slope and intercept). A neural network selects millions of weights. Given those parameters, the model becomes a single fixed function that maps inputs to outputs. Provided an input x, the model returns an output y.

    A Gaussian Process operates differently and, once understood, more elegantly. Rather than committing to a single function, a GP defines a probability distribution over infinitely many possible functions. Before any data are observed, every function that could plausibly describe the problem carries some prior probability. After observing training points, the GP updates this distribution: functions consistent with the data become more likely while others diminish in probability. The “prediction” is therefore not a single curve but a family of curves, and the spread of that family at any point x* indicates precisely how uncertain the model is.

    Why Gaussian Processes Matter

    Four reasons recommend GPs for inclusion in a practitioner’s toolkit.

    • Principled uncertainty quantification. Every prediction is accompanied by a calibrated confidence interval grounded in Bayes’ rule rather than heuristics.
    • Excellent sample efficiency. GPs often perform well with 20, 50, or 500 training points, a regime in which deep networks routinely overfit.
    • Bayesian by design. There is no separate pipeline for training and uncertainty evaluation; the posterior is the model.
    • Interpretable inductive bias. The kernel expresses assumptions about smoothness, periodicity, or linearity in explicit and inspectable form.
    Key Takeaway: A Gaussian Process is a nonparametric Bayesian model that returns both a prediction and a calibrated confidence interval at every input point. Its uncertainty grows naturally in regions where training data are sparse, which is precisely the behavior a trustworthy model should exhibit.

    When to Use a Gaussian Process

    GPs are the appropriate tool in the following circumstances.

    • The data are small to medium in size, typically N < 10,000 for a standard GP, or up to 100,000 with approximations.
    • The application requires uncertainty estimates that can be relied upon, rather than softmax outputs or heuristic approximations such as dropout.
    • Evaluating the target function is expensive, for example a wet-lab experiment, a supercomputer simulation, or a 48-hour hyperparameter sweep.
    • The underlying process is smooth and structured, such as a physical system, a spatial field, or a slowly varying time series.

    GPs are usually not the right tool when the following conditions hold.

    • The dataset contains millions of rows and is expected to continue growing, in which case the O(n^3) training cost becomes prohibitive.
    • The inputs are very high-dimensional, such as raw images, long sequences, or graphs; kernels on raw pixels rarely capture useful structure.
    • The features are categorical with no natural distance metric.
    • The problem requires deep hierarchical feature learning that only a neural network can provide.

    A useful heuristic: if the dataset fits in RAM and the problem has smooth structure, a GP is a sensible first choice. More complex methods may not be necessary.

    The Underlying Mathematics

    This section develops intuition for what a Gaussian Process is mathematically. Plain language accompanies each equation.

    Formal Definition

    A Gaussian Process is fully specified by two objects.

    • A mean function m(x), which describes the average value of the process at any input x. In practice m(x) = 0 is almost always adopted after the data are centered, leaving the kernel to perform the main modeling work.
    • A covariance function or kernel k(x, x’), which describes how strongly two outputs are correlated given the similarity of their inputs.

    This is written as follows.

    f(x) ∼ GP(m(x), k(x, x’))

    The defining property is elegantly simple: for any finite set of inputs {x_1, x_2, …, x_n}, the corresponding outputs [f(x_1), f(x_2), …, f(x_n)] follow a multivariate Gaussian distribution. For any n input points, the joint distribution of the function values is a bell-shaped cloud in n dimensions, with means given by m and covariance matrix entries given by k.

    This is why GPs lie at the intersection of functional analysis and probability: they enable reasoning about an infinite-dimensional object (a whole function) by projecting it down to finite-dimensional Gaussians whenever necessary. Any property that holds for multivariate Gaussians, including conditioning, marginalization, and linear transformation, also holds for GPs. The connection to the Central Limit Theorem and multivariate Gaussians is not coincidental; it is precisely what makes this model class tractable.
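
    The finite-dimensional view is also how functions are sampled from a GP prior in practice: evaluate the kernel on a grid, then draw from the resulting multivariate Gaussian. The sketch below assumes a zero mean and an RBF kernel (its formula appears later in the article); the jitter value of 1e-6 is an assumed choice that keeps the Cholesky factorization stable.

```python
import numpy as np


def rbf_gram(x, lengthscale=1.0, variance=1.0):
    """Covariance matrix of an RBF kernel evaluated on a 1-D grid."""
    diff = x[:, None] - x[None, :]
    return variance * np.exp(-0.5 * diff**2 / lengthscale**2)


x = np.linspace(-3, 3, 50)
K = rbf_gram(x) + 1e-6 * np.eye(len(x))  # jitter for numerical stability

# If K = L L^T and z ~ N(0, I), then L z ~ N(0, K).
L = np.linalg.cholesky(K)
rng = np.random.default_rng(0)
samples = (L @ rng.standard_normal((len(x), 3))).T  # three functions on the grid

print(samples.shape)
```

    Restricting the grid to a subset of points yields the corresponding sub-block of `K`, which is the marginalization property in concrete form.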

    The Posterior Predictive Distribution

    Consider training inputs X = [x_1, …, x_n] with noisy observations y = [y_1, …, y_n], where each y_i = f(x_i) + ε_i and ε_i ∼ N(0, σ_n²). The objective is to predict f(x*) at a new test input x*.

    Because the prior over f is a GP and the observation noise is Gaussian, the posterior over f(x*) is also Gaussian, and its mean and variance can be expressed in closed form.

    Posterior mean:     μ*  = K(x*, X) · [K(X, X) + σ_n² I]⁻¹ · y
    Posterior variance: σ*² = K(x*, x*) - K(x*, X) · [K(X, X) + σ_n² I]⁻¹ · K(X, x*)

    In plain language, the components have the following meanings.

    • K(X, X) is the n×n matrix of kernel evaluations between all pairs of training inputs. Each entry expresses the similarity between two training points.
    • K(x*, X) is a 1×n row vector that expresses the similarity between the test point and each training input.
    • σ_n² I is the noise variance added to the diagonal. It both reflects measurement noise and provides jitter for numerical stability.
    • The posterior mean is a weighted combination of training targets, with weights determined by similarity.
    • The posterior variance begins at the prior variance K(x*, x*) and is reduced by an amount that depends on the informativeness of nearby training points.

    The consequence is straightforward. When x* is close to many training points, the similarity vector K(x*, X) contains large entries, the variance reduction is substantial, and the model becomes confident. When x* is far from every training point, all similarities are small, the variance reduction is negligible, and the posterior variance remains close to the prior variance. GPs therefore identify their own extrapolation regions and report them explicitly.
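
    With a single training point the matrix inverse reduces to a scalar division, which makes the behavior easy to verify by hand. The sketch below assumes an RBF kernel with unit variance and lengthscale, a noise variance of 0.01, and one observation y = 2 at x = 0; all of these are illustration values.

```python
import math


def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    return variance * math.exp(-0.5 * (x1 - x2) ** 2 / lengthscale**2)


def posterior_one_point(x_star, x_train, y_train, noise_var=0.01):
    """Posterior mean and variance at x_star given a single observation."""
    k_star = rbf(x_star, x_train)                 # K(x*, X)
    k_train = rbf(x_train, x_train) + noise_var   # K(X, X) + sigma_n^2
    mean = k_star / k_train * y_train
    var = rbf(x_star, x_star) - k_star**2 / k_train
    return mean, var


results = {}
for x_star in (0.0, 1.0, 5.0):
    results[x_star] = posterior_one_point(x_star, x_train=0.0, y_train=2.0)
    mean, var = results[x_star]
    print(f"x* = {x_star:.1f}   mean = {mean:.3f}   variance = {var:.3f}")
```

    At x* = 0 the variance is nearly zero and the mean sits close to the observation; at x* = 5 the mean has returned to the prior mean of zero and the variance to the prior variance of one.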

    Visualizing the Posterior

    [Figure: Gaussian Process Posterior: Mean and 95% Confidence Band. X-axis: input x; Y-axis: output f(x). Shows the observed training data, the posterior mean μ*, and the 95% confidence band μ* ± 2σ*, which is wide where there is no data and narrow near data.]

    The blue shaded band expands in regions far from the black training points and contracts where data are dense. This is the GP communicating its confidence directly: high confidence near observed points and lower confidence elsewhere, without any additional calibration step.

    Kernels: The Heart of Gaussian Processes

    If the kernel is the heart of a GP, each kernel choice constitutes a theory about how the modeled phenomenon behaves. Kernels encode what “similar” means in the input space: whether nearby points are expected to have similar outputs, whether seasonality should be encoded, and whether the underlying function is smooth or jagged. The most common kernels are reviewed below.

    The RBF (Squared Exponential) Kernel

    The RBF kernel is the workhorse and frequently the first choice in practice.

    k_RBF(x, x') = σ² · exp( - ||x - x'||² / (2 · ℓ²) )

    The parameter ℓ is the length scale, which controls how rapidly correlation decays with distance. A small ℓ produces highly oscillatory functions in which neighbors barely influence each other; a large ℓ produces smooth, slowly varying functions. The output variance σ² scales the overall amplitude. Samples drawn from an RBF-kernel GP are infinitely differentiable, which is sometimes unrealistically smooth.

    The Matérn Kernel

    Real-world functions are rarely infinitely smooth. The Matérn family introduces a smoothness parameter ν that interpolates between jagged and smooth behavior. Common choices are ν = 3/2 (once-differentiable) and ν = 5/2 (twice-differentiable). Both are standard defaults in Bayesian optimization precisely because they model realistic physical processes more accurately than the RBF kernel.
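
    The two common Matérn kernels have simple closed forms as functions of the distance d = |x - x'|. The sketch below writes them next to the RBF kernel for comparison; the function names are illustrative.

```python
import math


def rbf(d, lengthscale=1.0, variance=1.0):
    return variance * math.exp(-0.5 * (d / lengthscale) ** 2)


def matern32(d, lengthscale=1.0, variance=1.0):
    """Matern kernel with nu = 3/2 (once-differentiable samples)."""
    s = math.sqrt(3.0) * abs(d) / lengthscale
    return variance * (1.0 + s) * math.exp(-s)


def matern52(d, lengthscale=1.0, variance=1.0):
    """Matern kernel with nu = 5/2 (twice-differentiable samples)."""
    s = math.sqrt(5.0) * abs(d) / lengthscale
    return variance * (1.0 + s + s * s / 3.0) * math.exp(-s)


for d in (0.0, 0.5, 1.0, 2.0):
    print(f"d = {d:.1f}   RBF = {rbf(d):.4f}   "
          f"Matern-5/2 = {matern52(d):.4f}   Matern-3/2 = {matern32(d):.4f}")
```

    All three kernels equal the variance at d = 0. At one lengthscale of separation the Matérn-3/2 correlation is the lowest of the three and the RBF correlation the highest.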

    The Periodic Kernel

    k_periodic(x, x') = σ² · exp( -2 · sin²(π |x - x'| / p) / ℓ² )

    The parameter p denotes the period. The periodic kernel is appropriate for phenomena that repeat, including daily electricity demand, annual temperature cycles, and tidal patterns. It extrapolates periodic behavior indefinitely into the future, which is both a strength and a risk.

    The Linear Kernel

    k(x, x’) = σ² · x · x’. A GP with a linear kernel is equivalent to Bayesian linear regression and is useful when combined with other kernels to model long-term trends.

    Composite Kernels

    The real power of GPs lies in combining kernels. Two fundamental operations preserve positive semi-definiteness, which is a required property.

    • Addition: k_1(x, x’) + k_2(x, x’). Encodes multiple independent effects, for example a trend combined with seasonality.
    • Multiplication: k_1(x, x’) · k_2(x, x’). Encodes interactions, for example a periodic pattern whose amplitude varies slowly.

    A common time-series specification is RBF + Periodic + Linear, which simultaneously models local smoothness, repeating seasonality, and a drifting trend. The kernel grammar effectively functions as a small programming language for expressing inductive biases.
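
    The kernel grammar can be sketched with plain Python closures. The helpers `add` and `multiply` are hypothetical names, and the period of 12 and the other parameter values are assumed for illustration (for example, monthly data with a yearly cycle).

```python
import math


def rbf(lengthscale=1.0, variance=1.0):
    return lambda x1, x2: variance * math.exp(-0.5 * (x1 - x2) ** 2 / lengthscale**2)


def periodic(period=1.0, lengthscale=1.0, variance=1.0):
    return lambda x1, x2: variance * math.exp(
        -2.0 * math.sin(math.pi * abs(x1 - x2) / period) ** 2 / lengthscale**2)


def linear(variance=1.0):
    return lambda x1, x2: variance * x1 * x2


def add(*kernels):
    """Sum of kernels: independent effects superimposed."""
    return lambda x1, x2: sum(k(x1, x2) for k in kernels)


def multiply(*kernels):
    """Product of kernels: one effect modulated by another."""
    return lambda x1, x2: math.prod(k(x1, x2) for k in kernels)


# Local smoothness + yearly seasonality + drifting trend.
trend_plus_season = add(rbf(lengthscale=2.0), periodic(period=12.0), linear(variance=0.01))
# Seasonal pattern whose shape is allowed to change slowly over time.
decaying_season = multiply(periodic(period=12.0), rbf(lengthscale=50.0))

print(trend_plus_season(3.0, 3.0))
print(periodic(period=12.0)(0.0, 12.0), decaying_season(0.0, 12.0))
```

    The pure periodic kernel treats points exactly one period apart as perfectly correlated, while the product with a long-lengthscale RBF lets that correlation fade over many cycles.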

    Automatic Relevance Determination (ARD)

    For multi-dimensional inputs, each dimension can be assigned its own length scale ℓ_i. Dimensions irrelevant to the output acquire large length scales and are effectively ignored, while informative features acquire short length scales. This procedure, known as Automatic Relevance Determination, turns a GP into a feature-importance ranker as a byproduct of training.
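
    The effect of per-dimension length scales is easy to see numerically. In the sketch below the length scales are hand-picked stand-ins for values that training would learn, the feature names are hypothetical, and inputs are assumed to be on comparable scales so that length scales can be compared across dimensions.

```python
import math


def ard_rbf(x1, x2, lengthscales, variance=1.0):
    """RBF kernel with one length scale per input dimension."""
    scaled_sqdist = sum(((a - b) / l) ** 2 for a, b, l in zip(x1, x2, lengthscales))
    return variance * math.exp(-0.5 * scaled_sqdist)


def relevance_ranking(names, lengthscales):
    """Order features from most to least relevant (shortest length scale first)."""
    return sorted(zip(names, lengthscales), key=lambda pair: pair[1])


lengthscales = [0.5, 100.0]  # dimension 0 matters, dimension 1 is nearly ignored

move_dim0 = ard_rbf([0.0, 0.0], [1.0, 0.0], lengthscales)
move_dim1 = ard_rbf([0.0, 0.0], [0.0, 1.0], lengthscales)
print(move_dim0, move_dim1)
print(relevance_ranking(["learning_rate", "seed_offset"], lengthscales))
```

    A unit step along the first dimension destroys most of the correlation, while the same step along the second leaves it almost untouched.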

    [Figure: Sample Draws from GPs with Different Kernels. Panels: RBF (very smooth), Matérn-3/2 (rougher), Periodic (repeating), RBF + Periodic (trend + seasonality). Each panel shows three sample functions drawn from the GP prior with the indicated kernel.]

    Kernel Cheat Sheet

    Kernel | Formula | Smoothness | Typical Use Case
    RBF (Squared Exponential) | σ² exp(-d² / (2ℓ²)) | Infinitely differentiable | Default choice, very smooth signals
    Matérn-3/2 | σ² (1 + √3 d/ℓ) exp(-√3 d/ℓ) | Once differentiable | Realistic physics, Bayesian opt
    Matérn-5/2 | σ² (1 + √5 d/ℓ + 5d²/(3ℓ²)) exp(-√5 d/ℓ) | Twice differentiable | Hyperparameter tuning (BoTorch default)
    Periodic | σ² exp(-2 sin²(π d/p) / ℓ²) | Infinitely differentiable, repeating | Seasonality, cycles
    Linear | σ² x · x’ | Linear only | Drifts, trends, baselines

    Hyperparameter Learning and the Marginal Likelihood

    Kernels come equipped with hyperparameters: length scales, output variances, and noise levels. The natural question is how these should be selected. The GP’s answer is elegant: maximize the log marginal likelihood of the observed data.

    The Log Marginal Likelihood

    For training targets y, inputs X, and hyperparameters θ = {ℓ, σ, σ_n}, the log marginal likelihood takes the following form.

    log p(y | X, θ) = -½ yᵀ K_y⁻¹ y  -  ½ log |K_y|  -  (n/2) log(2π)
    
    where K_y = K(X, X) + σ_n² I

    The three terms perform three distinct roles.

    • The first term (the data-fit term) penalizes hyperparameters that make the observed y implausible under the prior.
    • The second term (the complexity penalty) penalizes overly flexible kernels. Occam’s razor is built into the mathematics: a highly flexible kernel can fit anything, but it incurs a cost here.
    • The third term is a normalization constant that depends only on the number of observations n, not on the hyperparameters, so it does not affect the optimization.

    The complexity penalty is why GPs regularize automatically. Unlike a neural network, which requires dropout, weight decay, or early stopping to prevent overfitting, a GP trained by maximizing the marginal likelihood naturally settles at an appropriate level of smoothness. This is one of the principal reasons GPs perform well on small datasets.

    Optimization in Practice

    The log marginal likelihood is differentiable with respect to θ, so gradient-based optimizers are applicable. L-BFGS is the traditional choice; Adam works effectively in GPyTorch because it integrates with PyTorch’s autograd system.

    A fully Bayesian treatment, in which priors are placed on hyperparameters and the hyperparameters are integrated out, can be performed via MCMC (slower but more principled) or variational approximations. This is particularly important when data are scarce and marginal likelihood estimates are themselves noisy.

    Caution: When N is small (below twenty, for example), the marginal likelihood landscape is multimodal and optimization can become stuck. Initialization from several random starts, or placement of informative priors on hyperparameters, is advisable.
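
    A cheap safeguard against a poor local optimum is to evaluate the log marginal likelihood on a coarse grid before any gradient-based refinement. The sketch below scans only the lengthscale of an RBF kernel, with the variance and noise held at assumed values; the grid stands in for several random restarts, and the synthetic data are for illustration only.

```python
import numpy as np


def log_marginal_likelihood(x, y, lengthscale, variance=1.0, noise_var=0.05):
    """LML of a zero-mean GP with an RBF kernel on 1-D inputs."""
    diff = x[:, None] - x[None, :]
    K_y = variance * np.exp(-0.5 * diff**2 / lengthscale**2) + noise_var * np.eye(len(x))
    L = np.linalg.cholesky(K_y)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return (-0.5 * y @ alpha               # data fit
            - np.sum(np.log(np.diag(L)))   # complexity penalty: 0.5 * log|K_y|
            - 0.5 * len(x) * np.log(2 * np.pi))


rng = np.random.default_rng(0)
x = np.sort(rng.uniform(-3, 3, 15))
y = np.sin(x) + rng.normal(0, 0.2, 15)

candidates = np.logspace(-2, 2, 41)  # lengthscales from 0.01 to 100
scores = np.array([log_marginal_likelihood(x, y, ls) for ls in candidates])
best_lengthscale = candidates[int(np.argmax(scores))]
print(best_lengthscale, scores.max())
```

    The best grid point, or the top few, can then seed L-BFGS or Adam. Very short lengthscales are penalized by the data-fit term's failure to share information, and very long ones by their inability to follow the curve.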

    Full Python Implementation

    Having developed the theory, the next step is to construct a GP. The implementation begins with a from-scratch NumPy version to consolidate intuition and then proceeds to GPyTorch for practical use.

    From Scratch with NumPy

    The implementation below follows the equations above literally. Cholesky decomposition handles the matrix inverse efficiently and stably.

    import numpy as np
    import matplotlib.pyplot as plt
    
    
    def rbf_kernel(X1, X2, lengthscale=1.0, variance=1.0):
        """RBF / squared-exponential kernel."""
        X1 = np.atleast_2d(X1)
        X2 = np.atleast_2d(X2)
        sqdist = (np.sum(X1**2, axis=1).reshape(-1, 1)
                  + np.sum(X2**2, axis=1)
                  - 2 * X1 @ X2.T)
        return variance * np.exp(-0.5 * sqdist / lengthscale**2)
    
    
    class GaussianProcess:
        def __init__(self, lengthscale=1.0, variance=1.0, noise=1e-4):
            self.lengthscale = lengthscale
            self.variance = variance
            self.noise = noise
    
        def fit(self, X, y):
            self.X_train = np.atleast_2d(X)
            self.y_train = y.reshape(-1)
            K = rbf_kernel(self.X_train, self.X_train,
                           self.lengthscale, self.variance)
            # Add noise to diagonal + tiny jitter for numerical stability
            K += (self.noise + 1e-8) * np.eye(len(self.X_train))
            # Cholesky factorization: K = L L^T
            self.L = np.linalg.cholesky(K)
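            # alpha = K^{-1} y via two solves against L and L^T. np.linalg.solve is a
            # general solver; scipy.linalg.solve_triangular would exploit the structure.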
            self.alpha = np.linalg.solve(
                self.L.T, np.linalg.solve(self.L, self.y_train))
            return self
    
        def predict(self, X_test, return_std=True):
            X_test = np.atleast_2d(X_test)
            K_s = rbf_kernel(self.X_train, X_test,
                             self.lengthscale, self.variance)
            mu = K_s.T @ self.alpha                         # posterior mean
            v = np.linalg.solve(self.L, K_s)
            K_ss = rbf_kernel(X_test, X_test,
                              self.lengthscale, self.variance)
            cov = K_ss - v.T @ v                            # posterior cov
            std = np.sqrt(np.maximum(np.diag(cov), 0))
            return (mu, std) if return_std else mu
    
        def log_marginal_likelihood(self):
            n = len(self.y_train)
            return (-0.5 * self.y_train @ self.alpha
                    - np.sum(np.log(np.diag(self.L)))
                    - 0.5 * n * np.log(2 * np.pi))
    
    
    # ---------------- Demo: noisy sine function ----------------
    rng = np.random.default_rng(42)
    X_train = np.sort(rng.uniform(-5, 5, 12)).reshape(-1, 1)
    y_train = np.sin(X_train).ravel() + rng.normal(0, 0.15, 12)
    X_test = np.linspace(-7, 7, 300).reshape(-1, 1)
    
    gp = GaussianProcess(lengthscale=1.0, variance=1.0, noise=0.02).fit(X_train, y_train)
    mu, std = gp.predict(X_test)
    
    plt.figure(figsize=(10, 5))
    plt.fill_between(X_test.ravel(), mu - 2*std, mu + 2*std,
                     color="#93c5fd", alpha=0.5, label="95% confidence")
    plt.plot(X_test, mu, color="#1d4ed8", lw=2, label="Posterior mean")
    plt.plot(X_test, np.sin(X_test), "g--", lw=1.5, label="True function")
    plt.scatter(X_train, y_train, color="black", zorder=10, label="Training data")
    plt.legend()
    plt.title(f"GP Regression  |  LML = {gp.log_marginal_likelihood():.2f}")
    plt.show()
    

    When executed, the mean tracks the sine function closely near the data, with confidence bands widening substantially outside the training range. The Cholesky factorization performed by np.linalg.cholesky avoids explicit matrix inversion and maintains numerical stability.
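    The same solve can also be written with SciPy's Cholesky helpers, which exploit the triangular structure of the factor explicitly. The following is a minimal sketch, assuming SciPy is installed; the toy matrix and the variable names are illustrative and are not part of the class above.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

# Toy symmetric positive-definite matrix standing in for K + noise * I
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
K = A @ A.T + 1e-2 * np.eye(5)
y = rng.normal(size=5)

# Factor once, then reuse the factor for every right-hand side
factor = cho_factor(K, lower=True)
alpha = cho_solve(factor, y)            # solves K @ alpha = y

# log|K| is available from the diagonal of the Cholesky factor
log_det = 2.0 * np.sum(np.log(np.diag(factor[0])))

# Comparison with the explicit inverse, which is slower and less stable
alpha_inv = np.linalg.inv(K) @ y
print(np.allclose(alpha, alpha_inv))
```

    Factoring once and reusing the factor matters in practice, because the same matrix appears in the posterior mean, the posterior covariance, and the log marginal likelihood.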

    Production-Grade GPs with GPyTorch

    For real applications requiring GPU acceleration, automatic differentiation, modern kernel structures, and scalable methods, GPyTorch is the appropriate tool. It integrates directly with the PyTorch ecosystem and allows kernels, approximations, and likelihoods to be substituted with minimal code changes.

    import torch
    import gpytorch
    
    
    class ExactGPModel(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            # Matérn-5/2 with ARD if train_x is multi-dimensional
            base_kernel = gpytorch.kernels.MaternKernel(
                nu=2.5, ard_num_dims=train_x.shape[-1])
            self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)
    
        def forward(self, x):
            mean = self.mean_module(x)
            covar = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean, covar)
    
    
    # ---------------- Data ----------------
    torch.manual_seed(0)
    train_x = torch.linspace(0, 1, 50).unsqueeze(-1)
    train_y = torch.sin(train_x * 2 * torch.pi).squeeze() + 0.1 * torch.randn(50)
    
    # ---------------- Model ----------------
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(train_x, train_y, likelihood)
    
    # ---------------- Training loop ----------------
    model.train(); likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    
    for i in range(100):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()
        if i % 20 == 0:
            print(f"iter {i:3d}  loss={loss.item():.3f}  "
                  f"ls={model.covar_module.base_kernel.lengthscale.item():.3f}  "
                  f"noise={model.likelihood.noise.item():.4f}")
    
    # ---------------- Prediction ----------------
    model.eval(); likelihood.eval()
    test_x = torch.linspace(-0.2, 1.2, 200).unsqueeze(-1)
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        pred = likelihood(model(test_x))
        mean = pred.mean
        lower, upper = pred.confidence_region()  # ± 2 σ
    

    Several aspects of this snippet warrant note. The ScaleKernel adds the output variance σ² as a learnable parameter. The Matérn-5/2 base kernel with ard_num_dims automatically provides per-dimension length scales. The training loop is standard PyTorch, supporting any optimizer, scheduler, or device. For data that fit on a GPU, calling .cuda() on the tensors, the model, and the likelihood is sufficient; GPyTorch manages the remainder.

    Tip: Inputs and targets should always be standardized (zero mean, unit variance) before a GP is trained. Kernels with a single length scale perform poorly when features differ markedly in magnitude, and non-zero-mean data wastes the model’s expressive capacity.
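    The standardization step can be sketched in a few lines of NumPy. The helper name fit_standardizer and the toy data are hypothetical; the placeholder arrays stand in for whatever a fitted GP returns in the scaled space.

```python
import numpy as np

def fit_standardizer(a):
    """Return the per-column mean and standard deviation of a 2-D array."""
    mean = a.mean(axis=0)
    std = a.std(axis=0)
    std = np.where(std == 0, 1.0, std)    # guard against constant columns
    return mean, std

rng = np.random.default_rng(0)
# Two features on very different scales, and a target with a large offset
X = np.column_stack([rng.uniform(0, 1, 50), rng.uniform(0, 1000, 50)])
y = 500.0 + 20.0 * rng.normal(size=(50, 1))

x_mean, x_std = fit_standardizer(X)
y_mean, y_std = fit_standardizer(y)
X_scaled = (X - x_mean) / x_std
y_scaled = (y - y_mean) / y_std

# After predicting in the scaled space, map predictions back to original units
mu_scaled, std_scaled = np.zeros(3), np.ones(3)    # placeholder GP outputs
mu = mu_scaled * y_std[0] + y_mean[0]
std = std_scaled * y_std[0]
```

    The statistics should be computed on the training split only and then reused unchanged for test inputs, so that no information leaks from the test set.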

    Applications: Where GPs Excel

    Bayesian Optimization: The Primary Application

    Consider a function that is expensive to evaluate, such as training a deep neural network with a particular set of hyperparameters, synthesizing a candidate molecule, or running a multi-week physical simulation. Grid search is infeasible, so each evaluation should yield as much information as possible.

    Bayesian Optimization uses a GP as a surrogate for the expensive function. Each iteration proceeds as follows.

    1. Fit a GP to the data observed so far.
    2. Use an acquisition function to determine where to evaluate next, balancing exploitation (sampling where the GP predicts a high value) against exploration (sampling where the GP is most uncertain).
    3. Evaluate the true function at that point.
    4. Add the new observation to the dataset and repeat.

    Common acquisition functions include the following.

    • Expected Improvement (EI): the expected amount by which the new point improves on the best observed value. EI has a closed form under a GP.
    • Upper Confidence Bound (UCB): μ(x) + β · σ(x), with tunable exploration through β.
    • Probability of Improvement (PI): the probability that the new point exceeds the incumbent. Simple but often excessively greedy.
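
    UCB and PI are each a one-line formula once the posterior mean and standard deviation are available. The sketch below assumes a maximization problem; the function names, the default β = 2, and the three candidate points are illustrative choices rather than values from any library.

```python
import numpy as np
from scipy.stats import norm

def upper_confidence_bound(mu, sigma, beta=2.0):
    """UCB for maximization: the optimistic estimate mu + beta * sigma."""
    return mu + beta * sigma

def probability_of_improvement(mu, sigma, f_best, xi=0.01):
    """P(f(x) > f_best + xi) under the GP posterior N(mu, sigma^2)."""
    sigma = np.maximum(sigma, 1e-12)       # avoid division by zero
    return norm.cdf((mu - f_best - xi) / sigma)

# Three hypothetical candidates: posterior mean and posterior std
mu = np.array([1.0, 0.8, 0.2])
sigma = np.array([0.05, 0.30, 1.00])
f_best = 0.9                               # best value observed so far

ucb = upper_confidence_bound(mu, sigma, beta=2.0)
pi = probability_of_improvement(mu, sigma, f_best)
print("UCB picks candidate", int(np.argmax(ucb)))
print("PI picks candidate", int(np.argmax(pi)))
```

    With these numbers UCB selects the most uncertain candidate (index 2), whereas PI selects the candidate whose mean only slightly exceeds the incumbent (index 0), which illustrates the greedy tendency of PI.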

    [Figure: Bayesian Optimization: Narrowing in on the Optimum. Three panels show iteration 1 (2 points, wide uncertainty), iteration 3 (4 points, narrowing), and iteration 5 (6 points, optimum found). Legend: evaluated points, GP posterior mean, 95% confidence, true function (hidden), acquisition function, next query. Each iteration updates the GP, the acquisition function peaks at the most informative next query, and uncertainty collapses near observed points.]

    A working Bayesian optimization loop in approximately forty lines is shown below.

    import numpy as np
    from scipy.stats import norm
    
    def expensive_function(x):
        """The black box we want to maximize — pretend this takes hours."""
        return -((x - 2.3)**2) + 0.5 * np.sin(3 * x) + 2.0
    
    def expected_improvement(mu, sigma, f_best, xi=0.01):
        with np.errstate(divide='ignore', invalid='ignore'):
            imp = mu - f_best - xi
            z = imp / sigma
            ei = imp * norm.cdf(z) + sigma * norm.pdf(z)
            ei[sigma < 1e-9] = 0.0
        return ei
    
    # Seed with 2 random evaluations
    rng = np.random.default_rng(7)
    X_obs = rng.uniform(0, 5, 2).reshape(-1, 1)
    y_obs = expensive_function(X_obs.ravel())
    
    for step in range(10):
        gp = GaussianProcess(lengthscale=0.8, variance=1.0, noise=1e-3).fit(X_obs, y_obs)
        X_grid = np.linspace(0, 5, 500).reshape(-1, 1)
        mu, sigma = gp.predict(X_grid)
        ei = expected_improvement(mu, sigma, y_obs.max())
        x_next = X_grid[np.argmax(ei)]
        y_next = expensive_function(x_next[0])  # pass a scalar so y_next can be formatted with :.3f
        X_obs = np.vstack([X_obs, x_next.reshape(1, -1)])
        y_obs = np.append(y_obs, y_next)
        print(f"step {step+1:2d}  queried x={x_next[0]:.3f}  "
              f"y={y_next:.3f}  best={y_obs.max():.3f}")
    

    In production use, established libraries such as BoTorch (built on GPyTorch), scikit-optimize, Optuna, and Ax are recommended. They support mixed discrete and continuous spaces, multi-objective problems, constraints, and batch acquisition. Bayesian optimization is widely used to tune model hyperparameters, design experiments, and optimize materials. It is also a natural alternative to evolutionary search; the companion piece on genetic algorithms for black-box optimization provides a useful comparison.

    Time Series Forecasting

    GPs are well suited to time series forecasting because kernels can directly encode expected features: a periodic kernel for seasonality, a Matérn kernel for local smoothness, and a linear kernel for drift. Composite kernels such as RBF + Periodic + Linear reproduce results close to those of Facebook Prophet while including calibrated uncertainty by construction.
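
    A composite kernel of this kind is simply a sum of kernel matrices. The sketch below builds RBF + Periodic + Linear in NumPy; the function names, the time axis in years, and every hyperparameter value are illustrative assumptions, not fitted quantities.

```python
import numpy as np

def rbf(x1, x2, lengthscale=1.0):
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

def periodic(x1, x2, period=1.0, lengthscale=1.0):
    d = np.abs(x1[:, None] - x2[None, :])
    return np.exp(-2.0 * np.sin(np.pi * d / period) ** 2 / lengthscale ** 2)

def linear(x1, x2, variance=1.0):
    return variance * x1[:, None] * x2[None, :]

def composite(x1, x2):
    """Slow trend (RBF) + seasonality (periodic) + drift (linear)."""
    return (rbf(x1, x2, lengthscale=5.0)
            + periodic(x1, x2, period=1.0, lengthscale=0.7)
            + linear(x1, x2, variance=0.1))

t = np.linspace(0, 4, 60)          # hypothetical time axis, in years
K = composite(t, t)

# A sum of valid kernels is a valid kernel, so K is symmetric and PSD
eigvals = np.linalg.eigvalsh(K + 1e-8 * np.eye(len(t)))
```

    The matrix K can be passed to the same Cholesky-based fitting code used for a single kernel, since nothing downstream depends on how the covariance was assembled.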

    A related application is time series anomaly detection: a GP is fitted to normal behavior, and any new observation falling outside the 3σ prediction band is flagged. The method is interpretable, adapts to local seasonality, and does not require labeled anomalies.
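
    The flagging rule itself is a single comparison against the predictive band. The following sketch assumes the predictive mean and standard deviation have already been produced by a fitted GP; the function name and the numbers are made up for illustration.

```python
import numpy as np

def flag_anomalies(y_new, mu, std, k=3.0):
    """Return a boolean mask that is True where |y_new - mu| > k * std."""
    return np.abs(y_new - mu) > k * std

# Hypothetical GP predictions for five time steps
# (std is the predictive standard deviation, including observation noise)
mu = np.array([10.0, 10.5, 11.0, 10.8, 10.2])
std = np.array([0.3, 0.3, 0.4, 0.3, 0.3])
y_new = np.array([10.1, 10.4, 13.0, 10.9, 8.9])

mask = flag_anomalies(y_new, mu, std)
print(np.flatnonzero(mask))        # indices of flagged observations
```

    In this toy example the third and fifth observations fall outside the 3σ band and are flagged; the others are within 0.1 of the predicted mean.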

    Spatial Modeling and Kriging

    In geostatistics, the technique known as Kriging is, in mathematical terms, a Gaussian Process under a different name. Developed by the mining engineer Danie Krige in the 1950s, it has been used for decades to interpolate ore grades, oil-reservoir properties, soil contamination maps, and climate variables from sparse measurements. A heatmap of pollution concentrations interpolated from thirty monitoring stations was very likely produced by a GP.

    GP Classification

    GP regression assumes Gaussian noise, which makes posterior inference available in closed form. For classification, outputs are discrete, so the latent GP is passed through a sigmoid (binary) or softmax (multi-class) link function. The posterior is then no longer Gaussian and must be approximated, typically with the Laplace approximation, expectation propagation, or variational inference. This takes more effort than training a neural-network classifier, particularly on high-dimensional data, but it remains useful when calibrated class probabilities are required and data are scarce.

    Active Learning and Surrogate Modeling

    Given a query budget and a candidate pool, a GP selects the next query to label by maximizing the posterior variance, which corresponds to the most informative point. This active-learning loop substantially reduces labeling cost in domains such as materials discovery, protein engineering, and any setting in which ground-truth labels require an experiment. GPs combine particularly well with semi-supervised learning and self-supervised representation learning when labels are scarce but unlabeled data are abundant.
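
    The variance-maximizing selection rule can be sketched directly. The code below assumes a zero-mean GP with a unit-variance RBF kernel and fixed hyperparameters; the function names and the one-dimensional pool are illustrative.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    d2 = (a[:, None] - b[None, :]) ** 2
    return np.exp(-0.5 * d2 / lengthscale ** 2)

def posterior_variance(x_train, x_pool, noise=1e-4):
    """Posterior variance of a zero-mean, unit-variance RBF GP at x_pool."""
    K = rbf(x_train, x_train) + noise * np.eye(len(x_train))
    K_s = rbf(x_train, x_pool)
    v = np.linalg.solve(np.linalg.cholesky(K), K_s)
    return 1.0 - np.sum(v ** 2, axis=0)

x_train = np.array([0.0, 1.0, 2.0])       # inputs that already have labels
x_pool = np.linspace(0.0, 6.0, 61)        # unlabeled candidates

var = posterior_variance(x_train, x_pool)
x_next = x_pool[np.argmax(var)]           # most uncertain candidate
print(f"query next at x = {x_next:.1f}")
```

    With fixed hyperparameters the posterior variance depends only on the input locations and not on the observed labels, so this criterion selects the candidate farthest from existing data in kernel terms.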

    Applications at a Glance

    Application | Typical N | Popular Libraries
    Bayesian optimization (hyperparameter tuning) | 20 – 500 | BoTorch, Ax, Optuna, scikit-optimize
    Time series / forecasting | 100 – 10,000 | GPyTorch, GPflow, PyMC
    Spatial interpolation (Kriging) | 500 – 100,000 (sparse) | PyKrige, scikit-gstat, GPyTorch
    Surrogate modeling for simulation | 50 – 5,000 | GPyTorch, SMT, emukit
    Classification | 100 – 5,000 | scikit-learn, GPyTorch, GPflow

    Scalability: Breaking the O(n³) Wall

    Standard GPs must solve linear systems with an n×n kernel matrix, which requires O(n³) time and O(n²) memory. At n = 1,000 the cost is negligible. At n = 10,000 the wait becomes noticeable. At n = 100,000 the computation is infeasible on a laptop. Much of contemporary GP research is devoted to raising this ceiling.

    Sparse GPs via Inducing Points

    The dominant approach is to approximate the n training points with a much smaller set of M inducing points, typically M = 50 to 1000. Computation is then reduced to O(nM²).

    Method | Idea | Strengths / Caveats
    FITC | Fully Independent Training Conditional | Fast, but can underestimate noise and produce overconfident predictions.
    DTC | Deterministic Training Conditional | Simpler than FITC, tends to overestimate variance.
    VFE | Variational Free Energy (Titsias 2009) | Principled variational bound, well-calibrated; a common default.
    SVGP | Stochastic Variational GP (Hensman 2013) | Mini-batch training, scales to millions of points, handles non-Gaussian likelihoods.
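
    To make the O(nM²) cost concrete, the sketch below computes a sparse predictive mean in the subset-of-regressors form, which DTC shares for the mean. It is a simplified illustration, not the VFE or SVGP objective: the inducing inputs z are fixed on a grid rather than learned, and the function names and data are hypothetical.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    d2 = (a[:, None] - b[None, :]) ** 2
    return np.exp(-0.5 * d2 / lengthscale ** 2)

def sparse_gp_mean(x, y, z, x_test, noise=0.01):
    """Predictive mean using M inducing inputs z (SoR/DTC form).

    Only an M x M system is solved; the n x n kernel matrix is never formed.
    """
    K_uu = rbf(z, z)                      # (M, M)
    K_uf = rbf(z, x)                      # (M, n)
    K_su = rbf(x_test, z)                 # (n_test, M)
    A = K_uf @ K_uf.T + noise * K_uu      # (M, M), built in O(n M^2)
    A += 1e-10 * np.eye(len(z))           # small jitter for stability
    return K_su @ np.linalg.solve(A, K_uf @ y)

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, 2000)                      # n = 2000 training inputs
y = np.sin(x) + 0.1 * rng.normal(size=x.size)
z = np.linspace(-3, 3, 10)                        # M = 10 inducing inputs
x_test = np.linspace(-3, 3, 5)

mu = sparse_gp_mean(x, y, z, x_test)
```

    When the inducing inputs coincide with the training inputs, this expression reduces to the exact GP mean, which is a convenient sanity check for any implementation.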

     

    Exact GPs at Scale with BBMM

    GPyTorch introduced Black-Box Matrix-Matrix multiplication (BBMM), which uses preconditioned conjugate gradients and Lanczos iterations to solve the relevant linear systems without forming the inverse. On a GPU, exact GPs now scale to more than 100,000 points, a regime that previously required approximation.
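
    The core idea, solving the linear system through matrix-vector products alone, can be shown with plain conjugate gradients. The sketch below is a textbook CG loop in NumPy and deliberately omits the preconditioning, batching, and Lanczos components that BBMM adds; the function name and the toy data are assumptions for illustration.

```python
import numpy as np

def conjugate_gradient(matvec, b, tol=1e-8, max_iter=1000):
    """Solve A x = b for symmetric positive-definite A using only A @ v."""
    x = np.zeros_like(b)
    r = b - matvec(x)              # residual
    p = r.copy()                   # search direction
    rs_old = r @ r
    for _ in range(max_iter):
        Ap = matvec(p)
        step = rs_old / (p @ Ap)
        x += step * p
        r -= step * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X).ravel()
K = np.exp(-0.5 * (X - X.T) ** 2) + 0.1 * np.eye(200)   # RBF kernel + noise

# alpha = (K + noise * I)^{-1} y, obtained without factorizing K
alpha = conjugate_gradient(lambda v: K @ v, y)
```

    Because the solver only needs a function that multiplies by K, the kernel matrix can be evaluated in blocks or on a GPU without ever being stored in full.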

    Deep Kernel Learning and Deep GPs

    Deep Kernel Learning (DKL) places a neural network before the kernel: the network extracts features φ(x), and the kernel then operates on φ. The result combines deep representation learning with GP uncertainty quantification. For structured inputs such as images, graphs, and sequences, DKL is often the appropriate compromise. It complements graph-based architectures such as Graph Attention Networks when both rich features and calibrated uncertainty are required.

    Deep GPs stack multiple GP layers, each feeding into the next. They can learn hierarchical nonstationary functions but require variational inference for training. The added expressiveness is powerful but frequently more than is required.

    Gaussian Processes Compared to Alternatives

    The comparison between GPs and other common models is summarized below, followed by a brief discussion.

    Model | Uncertainty | Small-data performance | Scalability | Interpretability
    Gaussian Process | Native, calibrated | Excellent | O(n³) standard | High (via kernel)
    Linear Regression | Yes (Bayesian version) | Good if linear | O(n d²) | Very high
    Random Forest | Partial (ensemble variance) | Good | O(n log n) | Medium
    Neural Network | No (heuristic only) | Overfits easily | O(n) | Low
    Bayesian NN | Approximate | Good | Expensive (MCMC/VI) | Low-medium

    Several observations are worth noting.

    • GP versus linear regression. A GP with a linear kernel is Bayesian linear regression. Adding an RBF kernel produces a nonlinear, nonparametric counterpart.
    • GP versus random forest. Random forests produce discontinuous step functions and only approximate variance estimates. GPs produce smooth, calibrated predictions. Random forests handle categorical features natively, whereas GPs require custom kernels.
    • GP versus neural network. Neural networks dominate large-data, high-dimensional problems. GPs dominate small-data, uncertainty-critical problems. In the infinite-width limit a Bayesian neural network with suitably scaled priors is equivalent to a GP, a result known as the NNGP correspondence; the related Neural Tangent Kernel describes gradient-descent training of wide networks.
    • GP versus Bayesian neural network. GPs admit closed-form posteriors for Gaussian likelihoods. Bayesian neural networks rely on variational or MCMC approximations that are difficult to validate.
    • GP versus MCMC. The two are complementary rather than competing. MCMC is appropriate for exploring complex non-Gaussian posteriors; a GP is appropriate when the posterior is close to Gaussian and computational speed is important.
    • GP versus SVM. Both are kernel methods, but SVMs optimize a margin-based classifier and provide no uncertainty. The companion SVM comparison guide covers kernel machines outside the GP family.
    • Combination. Deep Kernel Learning is a natural hybrid: a neural network extracts features and a GP supplies uncertainty on top. The combination frequently performs well in competitions.
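
    The first observation, that a GP with a linear kernel is Bayesian linear regression, can be checked numerically. The sketch below assumes a standard normal prior on the weights, w ~ N(0, I), and a known noise variance; all data are synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 3))
w_true = np.array([1.5, -2.0, 0.5])
noise_var = 0.1
y = X @ w_true + np.sqrt(noise_var) * rng.normal(size=30)
X_test = rng.normal(size=(5, 3))

# Bayesian linear regression with prior w ~ N(0, I): posterior mean of w
w_post = np.linalg.solve(X.T @ X + noise_var * np.eye(3), X.T @ y)
mu_blr = X_test @ w_post

# GP regression with the linear kernel k(x, x') = x . x'
K = X @ X.T + noise_var * np.eye(30)
mu_gp = (X_test @ X.T) @ np.linalg.solve(K, y)

print(np.allclose(mu_blr, mu_gp))
```

    The two routes solve a 3×3 system and a 30×30 system respectively yet give the same predictions, which is the weight-space versus function-space duality in miniature.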

    Common Pitfalls and How to Avoid Them

    The following traps commonly arise when GPs are deployed in real projects.

    • Failure to center the target. The default mean function is zero. When targets have a mean of 500, the GP extrapolates toward zero far from training data, producing implausible predictions. The training mean should always be subtracted from y before fitting and added back during prediction.
    • Numerical instability. Kernel matrices are nearly singular when training points cluster. A small “jitter” (for example 1e-6) should be added to the diagonal of K(X, X) before Cholesky decomposition. GPyTorch does this automatically; from-scratch implementations should do so as well.
    • Wrong kernel for the data. Using RBF for a jagged function produces oversmoothed predictions with overconfident error bars. For rough-looking data, Matérn-3/2 or Matérn-5/2 is preferable. For periodic data, a periodic kernel is appropriate.
    • Overfitting hyperparameters with very small N. When N < 20, the marginal likelihood can have multiple local optima. Priors on hyperparameters and optimization from several random seeds are recommended.
    • Scaling without approximations. When N > 10,000, attempting to use a standard GP without GPyTorch’s scalable kernels or an SVGP exhausts memory. The recommended approximations should be used.
    • Gaussian noise assumption. Standard GP regression assumes Gaussian observation noise. For data with heavy tails or outliers, Student-t likelihoods or a different model should be considered.
    • Failure to standardize features. A single length scale cannot accommodate features with widely different units. Inputs should be standardized, or ARD kernels with per-dimension length scales should be used.
    Key Takeaway: A GP is as much an engineering artifact as a mathematical one. Sound numerical hygiene (jitter, standardization, random restarts) is the difference between a model that works reliably and one that fails inexplicably. These practices apply to engineering in general; see the clean code principles guide for further discussion.
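
    The jitter advice can be wrapped in a small helper that escalates the jitter only when the factorization fails. The following is one possible sketch; the function name safe_cholesky, the starting jitter, and the retry count are illustrative choices rather than a library API.

```python
import numpy as np

def safe_cholesky(K, jitter=1e-8, max_tries=6):
    """Cholesky with escalating diagonal jitter.

    Tries K first, then K + jitter * I, multiplying jitter by 10 after each
    failure. Returns the factor and the jitter that was finally added.
    """
    try:
        return np.linalg.cholesky(K), 0.0
    except np.linalg.LinAlgError:
        pass
    eye = np.eye(K.shape[0])
    for _ in range(max_tries):
        try:
            return np.linalg.cholesky(K + jitter * eye), jitter
        except np.linalg.LinAlgError:
            jitter *= 10.0
    raise np.linalg.LinAlgError("K is not positive definite even with jitter")

# Two nearly identical inputs make the RBF kernel matrix numerically singular
x = np.array([0.0, 1e-9, 1.0])
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)
L, used = safe_cholesky(K)
print(f"jitter added: {used:g}")
```

    Returning the jitter that was used makes it possible to log how often, and how strongly, the kernel matrix had to be regularized.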

    Frequently Asked Questions

    Gaussian Process vs. Neural Network — when should I use which?

    Use a Gaussian Process when you have small to medium data (under ~10,000 points), need calibrated uncertainty, and believe the underlying function is smooth and structured. Use a neural network when you have large data (100k+), high-dimensional raw inputs (images, text, graphs), and your primary need is raw predictive accuracy rather than uncertainty. When you want both — deep features and uncertainty — combine them via Deep Kernel Learning, which puts a neural network feature extractor in front of a GP.

    Can Gaussian Processes handle large datasets?

    Standard GPs scale as O(n³) in time and O(n²) in memory, which breaks down past roughly 10,000 training points. Modern approximations change this picture dramatically. Sparse variational GPs like SVGP use a small set of inducing points and can train on millions of rows with mini-batching. GPyTorch’s BBMM algorithm uses conjugate gradients to solve exact GPs with 100,000+ points on a GPU. For most practical workloads, scalability is no longer a hard barrier; you just need to pick the right approximation.

    What kernel should I choose?

    A safe starting point is the Matérn-5/2 kernel with Automatic Relevance Determination (ARD); it assumes realistic smoothness and learns per-dimension length scales automatically. Use RBF if you truly expect infinitely differentiable behavior. Add a periodic kernel if your data has clear cycles (daily, weekly, yearly). Combine kernels by addition (for independent effects) or multiplication (for interactions). When in doubt, train several kernels and compare them by log marginal likelihood on the training data or by predictive log-likelihood on held-out data.

    Is a Gaussian Process the same as Kriging?

    Yes, essentially. Kriging is the name used in geostatistics and mining engineering, dating back to Danie Krige’s work in the 1950s, while “Gaussian Process” is the machine-learning community’s term. The underlying mathematics is identical: both model spatial (or more general) data as a realization of a Gaussian random field, use kernel-based covariance, and produce predictions with uncertainty. Ordinary Kriging corresponds to a GP with a constant mean; universal Kriging corresponds to a GP with a parametric mean function.

    Can GPs do classification, not just regression?

    Yes, but it’s more complex than regression. A GP classifier wraps the latent GP output in a link function (sigmoid for binary, softmax for multi-class), which makes the posterior non-Gaussian. Inference requires approximations like the Laplace approximation, Expectation Propagation, or modern variational methods. Libraries like GPyTorch and scikit-learn support GP classification out of the box. In practice, for low-dimensional inputs with small to medium data and a need for calibrated probabilities, GP classification is a powerful option — but for high-dimensional inputs like images, a neural network is still the better tool.

    Conclusion and Further Reading

    Gaussian Processes occupy an unusual position in machine learning. They are mathematically elegant, practically useful, and philosophically honest: they return not a number but a distribution, not an answer but a calibrated belief. Where neural networks excel in scale, GPs reassure with calibration. Where tree-based models prevail on heterogeneous tabular data, GPs prevail on smooth structured signals. Where MCMC is principled but slow, GPs are principled and fast, at least for regression.

    The practical toolkit derived from this discussion is as follows.

    • Begin with a Matérn-5/2 kernel with ARD and GPyTorch.
    • Standardize inputs and outputs.
    • Train by maximizing the log marginal likelihood using Adam or L-BFGS.
    • Use Bayesian optimization (BoTorch, Optuna, or Ax) for expensive black-box functions.
    • Scale with inducing points or BBMM when N > 10,000.
    • Combine with neural networks via Deep Kernel Learning for structured high-dimensional inputs.
    • Respect the Gaussian noise assumption; if the noise is non-Gaussian, use a different likelihood or a different model.

    GPs are worth including in any practitioner’s repertoire if only for the epistemic humility they enforce. A model that explicitly acknowledges the limits of its knowledge is one that can be trusted. In an environment increasingly populated by confident-sounding predictions, such humility is a rare and valuable trait. Readers interested in adjacent Python engineering choices may find the broader discussion in the Python versus Rust comparison useful.

    References and Further Reading

    • Rasmussen, C. E. & Williams, C. K. I. — Gaussian Processes for Machine Learning, MIT Press, 2006. Free online at gaussianprocess.org/gpml. The canonical textbook.
    • GPyTorch documentation — gpytorch.ai. Modern scalable GPs in PyTorch.
    • Distill.pub — A Visual Exploration of Gaussian Processes. Stunning interactive visualizations.
    • BoTorch documentation — botorch.org. Production Bayesian optimization built on GPyTorch.
    • scikit-learn GP regressor — scikit-learn.org/stable/modules/gaussian_process. Good for small experiments and teaching.
    • Titsias, M. — Variational Learning of Inducing Variables in Sparse Gaussian Processes, AISTATS 2009. The VFE paper.
    • Hensman, J., Fusi, N., Lawrence, N. D. — Gaussian Processes for Big Data, UAI 2013. The SVGP paper.

    Disclaimer: This post is for educational and informational purposes only. Any illustrative example involving investment prices or financial returns is for pedagogical purposes and is not investment advice.

  • Semi-Supervised Learning Explained: Pseudo-Labeling, FixMatch, and More

    Summary

    What this post covers: A detailed examination of semi-supervised learning (SSL), from classical methods through modern consistency-based approaches, with a full PyTorch implementation of FixMatch that enables a model to match supervised accuracy using 10 to 100 times fewer labels.

    Key insights:

    • Modern SSL methods like FixMatch can match fully-supervised performance with 10x to 100x fewer labels by combining weak augmentation, confidence thresholding (tau = 0.95), and strong-augmentation consistency.
    • Semi-supervised learning is not self-supervised learning: SSL uses some task labels plus unlabeled data, while self-supervised invents labels from data structure and produces a pretrained backbone.
    • SSL only works when the smoothness, cluster, manifold, or low-density assumption holds; applying it blindly across distribution shift between labeled and unlabeled splits will silently destroy accuracy.
    • The confidence-gated pseudo-label is a natural curriculum: early in training most unlabeled examples fall below threshold and are ignored, so the model is not poisoned by its own bad predictions.
    • FixMatch’s effectiveness comes mostly from strong augmentation (RandAugment + Cutout) and high confidence thresholds, not from complex architectures, which is why it generalizes across vision, audio, NLP, and medical imaging.

    Main topics: The Promise of Learning from Almost-Free Data, What Semi-Supervised Learning Is (and Isn’t), Semi-Supervised vs Self-Supervised: The Critical Distinction, The Four Assumptions That Make SSL Work, Classical Semi-Supervised Methods, The Deep Learning Era of SSL, FixMatch in Detail: How the Method Works, Full PyTorch Implementation of FixMatch, Real-World Applications Across Domains, Paradigm Comparison: SSL, Self-SSL, Transfer, Active, Practical Guide: Thresholds, Data Ratios, Pitfalls, Connections to Transfer, Active, and Domain Adaptation, Frequently Asked Questions, References and Further Reading.

    The Promise of Learning from Almost-Free Data

    Consider a setting with 1,000 labelled medical images and 100,000 unlabelled ones. Training only on the labelled portion yields 78% accuracy. Adding the unlabelled data through semi-supervised learning raises that figure to 93%, with no additional labels required.

    That single observation explains why semi-supervised learning has quietly become one of the most consequential ideas in modern machine learning. Labels are expensive. A radiologist annotating a chest X-ray represents both real cost and real time. A crowd worker labelling toxic comments must read each one carefully. An engineer hand-segmenting pedestrians in a video frame may require ten minutes per frame. The raw data, however, is largely free: unlabelled X-rays accumulate on hospital servers, billions of comments sit on social platforms, and petabytes of driving footage occupy onboard storage.

    Semi-supervised learning (SSL) refers to the set of techniques that train models using both kinds of data simultaneously: a small set of labelled examples and a much larger set of unlabelled ones. When SSL succeeds, the gains can be substantial. Modern methods such as FixMatch match fully supervised performance with 10 to 100 times fewer labels. When SSL fails, the causes are typically subtle—confirmation bias, distribution shift, and class imbalance—and are examined in detail below.

    Important Disambiguation: This post concerns semi-supervised learning. It does not concern self-supervised learning, even though both are sometimes abbreviated as “SSL.” The two are distinct paradigms addressing distinct problems. Readers seeking the self-supervised treatment (pretext tasks, contrastive learning, masked image modelling) should consult the dedicated guide to self-supervised learning. The distinction is examined in detail in the next section, as the difference is consequential.

    By the end of the article, a reader should understand the full arc: why SSL works in theory, how the classical methods of the 1960s evolved into today’s state of the art, how FixMatch became the default, and how to implement it from scratch in PyTorch. The article also identifies cases in which SSL should not be applied, since applying it without consideration of distribution shift between labelled and unlabelled splits can quietly degrade accuracy.

    What Semi-Supervised Learning Is (and Isn’t)

    The formal definition is straightforward. Semi-supervised learning involves two datasets:

    • A labeled set $D_L = \{(x_1, y_1), (x_2, y_2), \dots, (x_n, y_n)\}$, typically small.
    • An unlabeled set $D_U = \{x_{n+1}, x_{n+2}, \dots, x_{n+m}\}$, typically large; often m is 10 to 1000 times larger than n.

    The labels correspond to the same target task of interest (for example, “cat” or “dog” or “pneumonia”). The unlabelled data is drawn from approximately the same distribution as the labelled data, but lacks annotations. The objective is to train a model that performs well on that target task, with the expectation that the unlabelled data, used judiciously, improves performance beyond what the labelled data alone would permit.

    Semi-supervised learning sits on a spectrum of supervision:

    • Fully supervised: every example has a label. The default. Expensive.
    • Semi-supervised: some examples labeled, most not. Solves the downstream task directly.
    • Self-supervised: no human labels at all. Invents labels from data structure (predict masked pixels, predict next token, match augmented views). Usually produces a backbone that’s then fine-tuned.
    • Unsupervised: no labels, no downstream task, just clustering, density estimation, dimensionality reduction.
    • Weakly supervised: labels exist but are noisy, imprecise, or indirect (e.g., image-level labels used for segmentation).

    [Figure: The Supervision Spectrum, from 100% labels to 0% labels: supervised (all examples labeled), semi-supervised (few labeled + many unlabeled), self-supervised (no labels; invents pretext tasks), unsupervised (no labels; no task). Inset: a semi-supervised data mixture of n = 1,000 labeled and m = 100,000 unlabeled examples. Goal: jointly train a model using both sets for the downstream task.]

    Semi-Supervised vs Self-Supervised: The Critical Distinction

    The two paradigms are frequently conflated, partly because of the shared “SSL” abbreviation and partly because both involve unlabelled data. They are nonetheless distinct, and a clear separation prevents considerable downstream confusion.

    Self-supervised learning uses no human-provided labels at training time. It generates labels from the structure of the data itself. A common pattern is to mask 15% of tokens in a sentence and predict them (BERT). Another is to crop two patches of an image and train the network to identify which pair came from the same image (contrastive learning). A third is to predict whether a rotated image was rotated 0°, 90°, 180°, or 270°. The “label” is generated automatically. The output of self-supervised learning is typically not a task-solving model but a pretrained backbone that is subsequently fine-tuned on a downstream task with labels.

    Semi-supervised learning uses some human-provided labels together with unlabelled data. The labels correspond directly to the downstream task (“cat” versus “dog,” “malignant” versus “benign,” “spam” versus “ham”). The output is a model that solves that task. There is no pretext task. Unlabelled data is used to enforce consistency, propagate labels, or minimise entropy, but the objective is always tied back to the labelled task.

    Aspect | Semi-Supervised | Self-Supervised
    Goal | Solve downstream task directly | Learn general representations (pretraining)
    Human labels used | Yes, a small number | None during pretraining
    Label source | Humans (partial coverage) | Invented from data (masking, pairs, rotations)
    Typical methods | FixMatch, Mean Teacher, MixMatch, pseudo-labeling | MAE, SimCLR, MoCo, DINO, BERT, GPT
    Output artifact | Task-ready classifier/regressor | Frozen backbone to be fine-tuned later
    When to use | You have some labels and can’t afford more | You have substantial unlabeled corpora and want reusable features
    Example | 250 labeled CIFAR-10 + 50k unlabeled → 94% accuracy | Pretrain on 1B images → fine-tune on ImageNet

    [Figure: Semi-Supervised vs Self-Supervised. Semi-supervised learning: labeled data (small, e.g. 1,000 images + labels) and unlabeled data (large, e.g. 100,000 images) feed joint training (supervised loss + consistency loss), producing a downstream classifier that is ready to use; one pipeline, one model, solves the task directly. Self-supervised learning: unlabeled data only (no human labels) feeds a pretext task (mask, contrast, rotate, predict next token), producing a pretrained backbone that is then fine-tuned on a labeled downstream task.]

    A useful summary: self-supervised learning produces backbones; semi-supervised learning produces task solvers. The two can be combined: pretrain with self-supervision, then fine-tune with semi-supervised learning. In practice, this combination underlies many of the strongest current pipelines. For the self-supervised half of that combination, the self-supervised learning guide covers masked image modelling, contrastive learning, and the DINO family in detail.

    The Four Assumptions That Make SSL Work

    Semi-supervised learning does not succeed unconditionally. If the unlabelled data were unrelated to the labelled data, no algorithmic refinement would help. SSL relies on structural assumptions about the relationship between inputs and labels. Four assumptions are most commonly cited:

    • Smoothness: if two points are close in input space, their labels should be similar. This is what enables consistency regularization—perturb the input slightly, and the prediction shouldn’t change.
    • Cluster assumption: data naturally forms clusters, and points in the same cluster share labels. Decision boundaries should run between clusters, not through them.
    • Low-density separation: the optimal decision boundary lies in a low-density region of the input space. This is the cluster assumption restated in terms of density; semi-supervised SVMs (S³VM) encode it directly.
    • Manifold assumption: high-dimensional data actually lies on a lower-dimensional manifold, and the relevant variation for labels happens along the manifold. Graph-based methods exploit this by defining similarity along the data manifold.
    Key Takeaway: When SSL produces strong gains, it is generally because one or more of these assumptions hold approximately for the data. When SSL fails silently, the typical cause is that the unlabelled data violates the cluster or manifold assumption: for example, the unlabelled set contains classes absent from the labelled set, or originates from a different sensor or population.

    Classical Semi-Supervised Methods

    Before deep learning, researchers developed a substantial body of semi-supervised algorithms. Many remain useful, and their ideas recur in modern deep methods.

    Self-Training (Pseudo-Labelling)

    This is the oldest approach, dating to Scudder in 1965 and popularised for deep learning by Dong-Hyun Lee in 2013. The procedure is simple:

    1. Train a model on the labeled set.
    2. Predict labels for the unlabeled set.
    3. Keep the predictions where the model is very confident (softmax > threshold).
    4. Add those pseudo-labeled examples to the training set.
    5. Retrain. Optionally iterate.

    The principal risk is confirmation bias: if the model’s initial predictions are biased, retraining on those biased predictions reinforces the bias. Pseudo-labelling alone is rarely the strongest method, but it is a building block of many modern approaches, including FixMatch.
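
    The five steps above can be sketched with a deliberately tiny stand-in model. This is an illustration rather than a reference implementation: the nearest-centroid classifier, the two synthetic Gaussian blobs, and the names fit_centroids, predict_proba, and self_train are all assumptions introduced here, and the 0.95 threshold is only one plausible setting.

```python
import numpy as np

def fit_centroids(X, y, n_classes):
    """'Train' a nearest-centroid classifier: one mean vector per class."""
    return np.stack([X[y == c].mean(axis=0) for c in range(n_classes)])

def predict_proba(X, centroids):
    """Softmax over negative squared distances to each class centroid."""
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    logits = -d2
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

def self_train(X_lab, y_lab, X_unlab, n_classes=2, threshold=0.95, rounds=3):
    X_train, y_train = X_lab.copy(), y_lab.copy()
    pool = X_unlab.copy()
    for _ in range(rounds):
        centroids = fit_centroids(X_train, y_train, n_classes)    # step 1
        if len(pool) == 0:
            break
        probs = predict_proba(pool, centroids)                    # step 2
        keep = probs.max(axis=1) > threshold                      # step 3
        if not keep.any():
            break
        X_train = np.concatenate([X_train, pool[keep]])           # step 4
        y_train = np.concatenate([y_train, probs.argmax(axis=1)[keep]])
        pool = pool[~keep]
    # Step 5: final retrain on labeled + pseudo-labeled data.
    n_added = len(y_train) - len(y_lab)
    return fit_centroids(X_train, y_train, n_classes), n_added

# Synthetic data: two well-separated blobs, 5 labels per class.
rng = np.random.default_rng(0)
centers = np.array([[-3.0, 0.0], [3.0, 0.0]])
y_all = np.repeat([0, 1], 200)
X_all = centers[y_all] + rng.normal(size=(400, 2))
lab = np.concatenate([np.arange(5), 200 + np.arange(5)])
unlab = np.setdiff1d(np.arange(400), lab)

centroids, n_added = self_train(X_all[lab], y_all[lab], X_all[unlab])
pred = predict_proba(X_all[unlab], centroids).argmax(axis=1)
acc = (pred == y_all[unlab]).mean()
print(f"pseudo-labels added: {n_added}, accuracy on unlabeled pool: {acc:.3f}")
```

    On data this clean nearly every unlabelled point clears the threshold in the first round. On harder data the threshold is the main control over how much confirmation bias is admitted.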

    Co-Training

    Blum and Mitchell (1998) proposed training two classifiers on two different “views” of the input, such as the URL of a web page and the text on the page. Each classifier labels the unlabelled examples on which it is most confident, and those pseudo-labels are used to train the other classifier. The underlying assumption is that the two views are conditionally independent given the label. When this assumption holds, co-training can substantially reduce the number of labels required.

    Label Propagation

    The procedure constructs a k-nearest-neighbour graph over all examples (labelled and unlabelled). Labels propagate through the graph, with each node’s label becoming a weighted average of its neighbours’ labels. Iteration continues until convergence. Labelled nodes remain pinned to their true labels; unlabelled nodes absorb labels from their neighbourhood. This represents a direct implementation of the manifold assumption and pairs naturally with graph neural networks. See the graph attention networks (GAT) guide for the modern deep counterpart.
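
    A minimal sketch of the propagation loop described above, assuming a symmetrised k-nearest-neighbour graph with unit edge weights and a fixed iteration count in place of a convergence test. The function name label_propagation, the value k=7, and the synthetic blobs are illustrative choices, not part of any particular library.

```python
import numpy as np

def label_propagation(X, y, n_classes, k=7, n_iter=200):
    """y holds class ids for labeled points and -1 for unlabeled points."""
    n = len(X)
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)                     # no self-edges
    neighbours = np.argsort(d2, axis=1)[:, :k]       # k nearest per node
    W = np.zeros((n, n))
    W[np.repeat(np.arange(n), k), neighbours.ravel()] = 1.0
    W = np.maximum(W, W.T)                           # make the graph symmetric
    P = W / W.sum(axis=1, keepdims=True)             # row-normalised weights

    lab_idx = np.flatnonzero(y >= 0)
    F = np.zeros((n, n_classes))
    F[lab_idx, y[lab_idx]] = 1.0
    for _ in range(n_iter):
        F = P @ F                                    # average neighbours' labels
        F[lab_idx] = 0.0
        F[lab_idx, y[lab_idx]] = 1.0                 # pin labeled nodes
    return F.argmax(axis=1)

rng = np.random.default_rng(0)
centers = np.array([[-3.0, 0.0], [3.0, 0.0]])
y_true = np.repeat([0, 1], 100)
X = centers[y_true] + rng.normal(size=(200, 2))
known = np.array([0, 1, 2, 100, 101, 102])           # 3 labels per class
y_partial = np.full(200, -1)
y_partial[known] = y_true[known]

pred = label_propagation(X, y_partial, n_classes=2)
acc = (pred == y_true).mean()
print(f"accuracy with 6 labels: {acc:.3f}")
```

    The dense distance matrix keeps the sketch short but scales quadratically; a real graph of any size would need a sparse neighbour search.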

    Transductive SVM (S³VM)

    A standard SVM finds the maximum-margin hyperplane separating labelled points. A transductive SVM considers both labelled and unlabelled points, and seeks a hyperplane that (i) separates labels correctly and (ii) passes through a low-density region of the unlabelled data. The optimisation is non-convex and difficult, but the underlying idea—that decision boundaries should avoid data-dense regions—is central.

    Generative Methods

    The approach fits a generative model (a Gaussian mixture, a naive Bayes model, a variational autoencoder) jointly on labelled and unlabelled data. EM-style updates treat unlabelled examples as having latent class labels. Provided the generative model is well-specified, unlabelled data tightens parameter estimates and improves the classifier. If the model is misspecified—for example, if the data is not Gaussian—unlabelled data can degrade performance.

    Entropy Minimisation

    Grandvalet and Bengio (2005) observed that if the cluster assumption holds, the model should make confident predictions on unlabelled data. Their approach adds a term to the loss that minimises the entropy of predictions on unlabelled inputs:

    L_total = L_supervised + lambda * H(p_model(y | x_unlabeled))

    This term encourages the model to avoid decision boundaries that run through unlabelled data. Entropy minimisation is a small but pervasive component of nearly every modern method. FixMatch implements it indirectly through confidence thresholding and pseudo-labelling.
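
    The entropy term can be made concrete with a few lines of NumPy. This is a sketch under stated assumptions: the helper names and the weight lam=0.1 are hypothetical, and the entropy is averaged over the unlabelled batch.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_entropy(logits, eps=1e-12):
    """Average Shannon entropy H(p) = -sum_c p_c log p_c over a batch."""
    p = softmax(logits)
    return float(-(p * np.log(p + eps)).sum(axis=1).mean())

def total_loss(supervised_loss, unlabeled_logits, lam=0.1):
    """L_total = L_supervised + lambda * H(p_model(y | x_unlabeled))."""
    return supervised_loss + lam * mean_entropy(unlabeled_logits)

uncertain = np.zeros((4, 3))                     # uniform over 3 classes
confident = np.array([[8.0, 0.0, 0.0]] * 4)      # one class dominates

print(f"entropy (uncertain): {mean_entropy(uncertain):.4f}")
print(f"entropy (confident): {mean_entropy(confident):.4f}")
```

    Uniform predictions carry the maximum penalty (log of the number of classes), while confident predictions carry almost none, which is what pushes the decision boundary away from the unlabelled points.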

    The Deep Learning Era of SSL

    Deep networks transformed SSL in two principal ways. First, they made representation learning on unlabelled data genuinely useful, whereas shallow models gain little from unlabelled data once the feature space is fixed. Second, they made consistency regularisation, a powerful tool, practical at scale.

    Consistency Regularisation

    The central idea is that predictions should be invariant to small perturbations of the input. Flipping an image horizontally, cropping it, adding a small amount of noise, or applying different dropout masks should not materially change the output probability distribution. This constraint can be enforced directly in the loss, and importantly it can be applied to unlabelled examples, because stability under noise does not require a label.

    Π-model (Laine and Aila, 2017). For each unlabelled example, two forward passes are run with different stochastic augmentations and dropout masks. The squared difference between the two softmax outputs is minimised. Combined with standard cross-entropy on the labelled data, this constitutes a complete SSL algorithm.
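
    A minimal sketch of the Π-model consistency term, assuming a toy linear map in place of a network and Gaussian input noise in place of augmentation and dropout. The names pi_consistency and noisy_forward are hypothetical, and the squared difference is summed over classes and averaged over the batch (other normalisations differ only by a constant factor).

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def pi_consistency(logits_a, logits_b):
    """Squared difference between two softmax outputs, averaged over the batch."""
    diff = softmax(logits_a) - softmax(logits_b)
    return float((diff ** 2).sum(axis=1).mean())

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 3))            # hypothetical linear "model"
x_unlab = rng.normal(size=(8, 5))      # 8 unlabeled inputs, 5 features

def noisy_forward(x):
    """Stand-in for stochastic augmentation: Gaussian input noise, then W."""
    return (x + 0.1 * rng.normal(size=x.shape)) @ W

# Two stochastic passes over the same unlabeled batch.
loss_cons = pi_consistency(noisy_forward(x_unlab), noisy_forward(x_unlab))
print(f"consistency loss: {loss_cons:.4f}")
```

    No label appears anywhere in the computation, which is why the term can be applied to the unlabelled set.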

    Temporal Ensembling. The Π-model’s two predictions are noisy. Temporal Ensembling replaces one of them with an exponential moving average of predictions across epochs, producing a smoother and more stable target. The drawback is memory consumption: running predictions must be stored for every unlabelled example.

    Mean Teacher (Tarvainen and Valpola, 2017). Rather than averaging predictions over time, the method averages model weights over time. Two networks are maintained: a “student” trained via SGD, and a “teacher” whose weights are an exponential moving average of the student’s weights. The teacher produces the target for the consistency loss. Mean Teacher is more stable and more memory-efficient than Temporal Ensembling, and it remains an excellent baseline, particularly for regression and segmentation tasks.
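
    The teacher update is a one-line exponential moving average. The sketch below uses plain dictionaries of NumPy arrays as a stand-in for model parameters; the name ema_update and the decay of 0.9 in the demo are illustrative (a value near 0.999 is more typical in practice).

```python
import numpy as np

def ema_update(teacher, student, decay=0.999):
    """teacher <- decay * teacher + (1 - decay) * student, per parameter."""
    for name, w in student.items():
        teacher[name] = decay * teacher[name] + (1.0 - decay) * w
    return teacher

student = {"w": np.array([1.0, 2.0]), "b": np.array([0.5])}
teacher = {k: v.copy() for k, v in student.items()}   # teacher starts as a copy

student["w"] = student["w"] + 1.0      # pretend one SGD step moved the student
ema_update(teacher, student, decay=0.9)
print(teacher["w"], teacher["b"])
```

    The teacher receives no gradient of its own; it only trails the student, which is what makes its predictions a smoother consistency target.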

    Pseudo-Labelling, Revisited

    Noisy Student (Xie et al., 2020). This method returned pseudo-labelling to the front rank of techniques. The procedure trains a teacher on labelled ImageNet, uses it to pseudo-label 300 million unlabelled images from JFT, and trains a larger student on the combined set under heavy noise (RandAugment, dropout, stochastic depth). The noisy student generalises better than its teacher; iteration follows, with each student becoming the next teacher. Noisy Student raised ImageNet accuracy beyond what fully supervised models had achieved.

    Hybrid Methods

    MixMatch (Berthelot et al., 2019). Combines (a) K augmented predictions averaged and sharpened into a soft pseudo-label, (b) MixUp between labelled and unlabelled batches, and (c) consistency. The method was strong at the time of publication.
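
    Step (a), averaging and sharpening, is easy to check by hand. The sketch below assumes a temperature of T = 0.5 and K = 2 augmentations; the function names are hypothetical and the MixUp and consistency steps are omitted.

```python
import numpy as np

def sharpen(p, T=0.5):
    """Raise probabilities to the power 1/T and renormalise."""
    p_t = p ** (1.0 / T)
    return p_t / p_t.sum(axis=-1, keepdims=True)

def guess_label(probs_k, T=0.5):
    """probs_k: (K, n_classes) predictions for K augmentations of one input."""
    return sharpen(probs_k.mean(axis=0), T)

probs_k = np.array([[0.5, 0.3, 0.2],
                    [0.7, 0.2, 0.1]])
q = guess_label(probs_k)
print(q)   # soft pseudo-label, more peaked than the plain average
```

    Lower temperatures push the soft pseudo-label toward one-hot, which acts as an implicit form of entropy minimisation.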

    ReMixMatch. Adds distribution alignment (the unlabelled pseudo-label distribution should match the labelled class distribution) and augmentation anchoring (predictions are anchored to weakly-augmented copies, not averages).

    FixMatch (Sohn et al., 2020). The current default. It strips away most of MixMatch’s complexity and retains only what works: a weakly augmented view to generate the pseudo-label, a strongly augmented view whose prediction must match that pseudo-label, and a confidence threshold. The method is implemented from scratch later in this article.

    FlexMatch. Replaces FixMatch’s single global threshold with per-class dynamic thresholds that reflect each class’s learning difficulty. It is helpful on imbalanced or curriculum-style problems.
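
    A simplified sketch of per-class thresholds in the spirit of FlexMatch. It assumes that each class's learning status is estimated by counting confident predictions for that class, normalised by the best-learned class and passed through the mapping β / (2 − β); the threshold warm-up used in practice is omitted, and the function name is hypothetical.

```python
import numpy as np

def flex_thresholds(probs, tau=0.95):
    """Scale the global threshold tau per class by estimated learning status."""
    conf, pred = probs.max(axis=1), probs.argmax(axis=1)
    n_classes = probs.shape[1]
    # sigma[c]: number of unlabeled examples confidently predicted as class c
    sigma = np.bincount(pred[conf > tau], minlength=n_classes).astype(float)
    beta = sigma / max(sigma.max(), 1.0)       # 1.0 for the best-learned class
    return tau * beta / (2.0 - beta)           # convex mapping of beta

probs = np.array([[0.98, 0.01, 0.01],
                  [0.97, 0.02, 0.01],
                  [0.99, 0.005, 0.005],
                  [0.02, 0.96, 0.02],
                  [0.30, 0.30, 0.40],
                  [0.10, 0.50, 0.40]])
thresholds = flex_thresholds(probs)
print(thresholds)
```

    A class with no confident predictions ends up with a threshold of zero in this stripped-down version, which admits every pseudo-label for that class; that behaviour is the reason a warm-up term is normally added.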

    Graph-Based Deep SSL

    When data naturally lives on a graph—citation networks, molecular graphs, social networks—semi-supervised node classification with a Graph Convolutional Network or Graph Attention Network is the canonical approach. A handful of labelled nodes coexist with millions of unlabelled ones, and information flows along edges. The GAT architecture is, in effect, learned label propagation with attention-weighted edges.

    FixMatch in Detail: How the Method Works

    FixMatch warrants close examination. The method is simple, highly effective, and offers a useful mental model for what “modern SSL” entails.

    The Idea in One Sentence

    For every unlabelled example, if the model produces a confident prediction for a particular class from a weakly augmented version of the image, the model is then required to predict that class from a strongly augmented version of the same image.

    Ingredients

    • A backbone network f (ResNet, WideResNet, etc.) with a classification head.
    • A weak augmentation α: typically random horizontal flip and random crop.
    • A strong augmentation A: RandAugment or CTAugment (color, rotation, shear, contrast), followed by Cutout.
    • A labeled batch of size B and an unlabeled batch of size μB (usually μ = 7, so 7× more unlabeled per step).
    • A confidence threshold τ, commonly 0.95.
    • A loss weight λ for the unsupervised term, commonly 1.0.

    The Loss

    On each training step, compute two losses:

    Supervised loss on the labeled batch:

    L_s = (1/B) * sum over labeled examples of CE(y_b, f(alpha(x_b)))

    Unsupervised loss on the unlabeled batch:

    # For each unlabeled example x_u:
    q_u    = softmax(f(alpha(x_u)))        # weak-aug prediction
    p_hat  = argmax(q_u)                   # pseudo-label
    mask   = 1 if max(q_u) >= tau else 0   # confidence gate
    L_u   += mask * CE(p_hat, f(A(x_u)))   # strong-aug prediction vs pseudo-label

    The total loss is L = L_s + λ · L_u.

    Two practical subtleties matter:

    1. The weak-augmentation forward pass uses torch.no_grad(), or gradients are otherwise stopped on q_u. Backpropagation through the pseudo-label target is not permitted.
    2. The confidence mask is applied element-wise. Early in training, most unlabelled examples fall below the threshold and are ignored. As the model improves, an increasing fraction of examples receive pseudo-labels. This produces a natural curriculum.
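
    The gating logic above can be checked by hand on a three-example batch. This NumPy sketch mirrors the pseudocode; the logits are made-up numbers and fixmatch_unsup_loss is a hypothetical name. The loss is averaged over the whole unlabelled batch, including masked-out examples.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def fixmatch_unsup_loss(weak_logits, strong_logits, tau=0.95):
    q = softmax(weak_logits)                         # weak-aug prediction
    pseudo = q.argmax(axis=1)                        # hard pseudo-label
    mask = (q.max(axis=1) >= tau).astype(float)      # confidence gate
    log_p = np.log(softmax(strong_logits))           # strong-aug log-probs
    ce = -log_p[np.arange(len(pseudo)), pseudo]      # CE against pseudo-label
    return float((mask * ce).mean()), pseudo, mask

weak_logits = np.array([[6.0, 0.0, 0.0],    # confident in class 0
                        [1.0, 0.5, 0.0],    # uncertain, will be masked out
                        [0.0, 5.0, 0.0]])   # confident in class 1
strong_logits = np.array([[2.0, 0.0, 0.0],
                          [0.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]])

loss_u, pseudo, mask = fixmatch_unsup_loss(weak_logits, strong_logits)
print(f"pseudo-labels={pseudo}, mask={mask}, L_u={loss_u:.4f}")
```

    Only the first and third examples contribute; the second is ignored until the model becomes confident about it.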

    Figure: FixMatch training step. Labelled branch: image x_b with label y_b, weak augmentation α (flip + crop), shared model f, supervised loss L_s = CE(y_b, f(α(x_b))). Unlabelled branch: image x_u passes through weak augmentation α and model f (no gradient) to give q_u, from which the pseudo-label p_hat = argmax(q_u) and the mask (1 if max(q_u) ≥ τ) are derived; the same image passes through strong augmentation A (RandAugment + Cutout) and model f to give f(A(x_u)); L_u = mask · CE(p_hat, f(A(x_u))). Total loss: L = L_s + λ · L_u, with λ typically 1.0 and τ typically 0.95. Gradients flow from both losses, but not through the pseudo-label target.

    Full PyTorch Implementation of FixMatch

    The following is a complete FixMatch implementation on CIFAR-10. It uses a simple WideResNet-style backbone and follows the original paper’s recipe closely enough to reach roughly 90% accuracy or more with 250 labels given sufficient training (the paper reports 94.93%). For illustration, the training loop is kept short; full results require many more training steps.

    Tip: FixMatch requires many iterations: the original paper trains for 1,048,576 steps (2^20). Results are not visible in 10 epochs. Compute should be planned accordingly, or a faster dataset such as MNIST may be used for prototyping.
    import math
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import datasets, transforms
    from torchvision.transforms import RandAugment
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # ---------- 1. Dataset split: labeled + unlabeled ----------
    
    def split_labeled_unlabeled(dataset, n_labeled_per_class=25, n_classes=10):
        """Create a small labeled subset and treat the rest as unlabeled."""
        labels = np.array(dataset.targets)
        labeled_idx, unlabeled_idx = [], []
        for c in range(n_classes):
            idx = np.where(labels == c)[0]
            np.random.shuffle(idx)
            labeled_idx.extend(idx[:n_labeled_per_class])
            unlabeled_idx.extend(idx[n_labeled_per_class:])
        return labeled_idx, unlabeled_idx
    
    # ---------- 2. Weak and strong augmentation ----------
    
    CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR_STD  = (0.2470, 0.2435, 0.2616)
    
    class WeakAug:
        def __init__(self):
            self.t = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ])
        def __call__(self, x): return self.t(x)
    
    class StrongAug:
        """Weak flip/crop + RandAugment + Cutout."""
        def __init__(self):
            self.base = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
                RandAugment(num_ops=2, magnitude=10),
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ])
        def __call__(self, x):
            img = self.base(x)
            # Cutout: random 16x16 zero patch
            _, H, W = img.shape
            y, x_ = np.random.randint(H), np.random.randint(W)
            y1, y2 = max(0, y-8), min(H, y+8)
            x1, x2 = max(0, x_-8), min(W, x_+8)
            img[:, y1:y2, x1:x2] = 0
            return img
    
    class LabeledDataset(Dataset):
        def __init__(self, base, idx):
            self.base, self.idx, self.aug = base, idx, WeakAug()
        def __len__(self): return len(self.idx)
        def __getitem__(self, i):
            img, y = self.base[self.idx[i]]
            return self.aug(img), y
    
    class UnlabeledDataset(Dataset):
        """Returns (weak_aug, strong_aug) pair."""
        def __init__(self, base, idx):
            self.base, self.idx = base, idx
            self.weak, self.strong = WeakAug(), StrongAug()
        def __len__(self): return len(self.idx)
        def __getitem__(self, i):
            img, _ = self.base[self.idx[i]]
            return self.weak(img), self.strong(img)
    
    # ---------- 3. Simple WideResNet-ish backbone ----------
    
    class BasicBlock(nn.Module):
        def __init__(self, cin, cout, stride=1):
            super().__init__()
            self.bn1 = nn.BatchNorm2d(cin)
            self.conv1 = nn.Conv2d(cin, cout, 3, stride, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(cout)
            self.conv2 = nn.Conv2d(cout, cout, 3, 1, 1, bias=False)
            self.shortcut = (nn.Conv2d(cin, cout, 1, stride, bias=False)
                             if stride != 1 or cin != cout else nn.Identity())
        def forward(self, x):
            h = self.conv1(F.relu(self.bn1(x)))
            h = self.conv2(F.relu(self.bn2(h)))
            return h + self.shortcut(x)
    
    class WideResNet(nn.Module):
        def __init__(self, num_classes=10, widen=2):
            super().__init__()
            n = 16
            self.stem = nn.Conv2d(3, n, 3, 1, 1, bias=False)
            widths = [n, n*widen, n*2*widen, n*4*widen]
            layers = []
            for i in range(3):
                stride = 1 if i == 0 else 2
                layers.append(BasicBlock(widths[i], widths[i+1], stride))
                layers.append(BasicBlock(widths[i+1], widths[i+1], 1))
            self.blocks = nn.Sequential(*layers)
            self.bn = nn.BatchNorm2d(widths[-1])
            self.fc = nn.Linear(widths[-1], num_classes)
        def forward(self, x):
            h = self.blocks(self.stem(x))
            h = F.relu(self.bn(h))
            h = F.adaptive_avg_pool2d(h, 1).flatten(1)
            return self.fc(h)
    
    # ---------- 4. Data pipeline ----------
    
    raw = datasets.CIFAR10("./data", train=True, download=True)
    test = datasets.CIFAR10("./data", train=False, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]))
    
    lab_idx, unlab_idx = split_labeled_unlabeled(raw, n_labeled_per_class=25)
    lab_ds   = LabeledDataset(raw, lab_idx)           # 250 images
    unlab_ds = UnlabeledDataset(raw, unlab_idx)       # ~49,750 images
    
    B, mu = 64, 7
    lab_loader   = DataLoader(lab_ds,   batch_size=B,    shuffle=True,
                              num_workers=2, drop_last=True)
    unlab_loader = DataLoader(unlab_ds, batch_size=B*mu, shuffle=True,
                              num_workers=2, drop_last=True)
    test_loader  = DataLoader(test, batch_size=256, num_workers=2)
    
    # ---------- 5. FixMatch training loop ----------
    
    model = WideResNet(num_classes=10, widen=2).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=0.03,
                          momentum=0.9, nesterov=True, weight_decay=5e-4)
    tau, lam = 0.95, 1.0
    
    def infinite(loader):
        while True:
            for batch in loader:
                yield batch
    
    lab_iter   = infinite(lab_loader)
    unlab_iter = infinite(unlab_loader)
    
    for step in range(5000):         # paper uses 2**20; 5k is illustrative
        model.train()
        x_l, y_l        = next(lab_iter)
        x_u_w, x_u_s    = next(unlab_iter)
        x_l, y_l        = x_l.to(device), y_l.to(device)
        x_u_w, x_u_s    = x_u_w.to(device), x_u_s.to(device)
    
        # One concatenated forward pass for speed; BatchNorm statistics are
        # computed over labeled and unlabeled inputs together (no interleaving here):
        x = torch.cat([x_l, x_u_w, x_u_s], dim=0)
        logits = model(x)
        l_logits = logits[:B]
        u_w_logits, u_s_logits = logits[B:].chunk(2)
    
        # Supervised loss
        loss_s = F.cross_entropy(l_logits, y_l)
    
        # Pseudo-label from weak aug (no grad through target)
        with torch.no_grad():
            probs_w = F.softmax(u_w_logits, dim=-1)
            max_probs, pseudo = probs_w.max(dim=-1)
            mask = (max_probs >= tau).float()
    
        # Unsupervised loss on strong aug
        loss_u = (F.cross_entropy(u_s_logits, pseudo, reduction="none") * mask).mean()
    
        loss = loss_s + lam * loss_u
        opt.zero_grad(); loss.backward(); opt.step()
    
        if step % 500 == 0:
            model.eval()
            correct = total = 0
            with torch.no_grad():
                for xb, yb in test_loader:
                    xb, yb = xb.to(device), yb.to(device)
                    pred = model(xb).argmax(-1)
                    correct += (pred == yb).sum().item()
                    total   += yb.size(0)
            print(f"step {step:5d}  loss_s={loss_s.item():.3f}  "
                  f"loss_u={loss_u.item():.3f}  mask_used={mask.mean().item():.2f}  "
                  f"test_acc={100*correct/total:.2f}%")

    Several observations follow from running the code above:

    • For the first few hundred steps, mask_used remains near zero: the model is not yet confident on anything, so the unsupervised term contributes nothing. This is expected; the supervised loss is performing the work.
    • Between approximately step 1,000 and step 3,000, mask_used begins climbing into the 0.2 to 0.6 range, and test accuracy increases noticeably. This is the point at which FixMatch begins to contribute substantively.
    • The 5,000-step budget here is roughly 200 times shorter than the 2^20 steps used in the paper. Reproducing the reported 94.93% on CIFAR-10 with 250 labels requires much longer training, a cosine learning-rate schedule, and EMA weights at evaluation time.
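
    The cosine schedule mentioned in the last observation can be sketched in a few lines. This assumes the decay form η · cos(7πk / 16K) described in the FixMatch paper, where k is the current step and K the total number of steps; the function name fixmatch_lr is hypothetical.

```python
import math

def fixmatch_lr(step, total_steps, base_lr=0.03):
    """Cosine decay: base_lr * cos(7 * pi * step / (16 * total_steps))."""
    return base_lr * math.cos(7.0 * math.pi * step / (16.0 * total_steps))

total_steps = 2 ** 20
for k in (0, total_steps // 2, total_steps):
    print(f"step {k:>8d}  lr = {fixmatch_lr(k, total_steps):.5f}")
```

    Because the argument stops at 7π/16 rather than π/2, the learning rate never reaches zero: it ends at about a fifth of its starting value. In a PyTorch loop the same factor could be supplied to torch.optim.lr_scheduler.LambdaLR as a multiplier on the base rate.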

    A realistic labelled-only baseline (the same backbone, the same 250 labels, no unlabelled data, with only heavy augmentation) tends to land in the range of 50% to 60% test accuracy. FixMatch, trained on the full schedule, approaches 95%. That gap of more than 30 percentage points, from the same 250 labels, is the central result of modern semi-supervised learning.

    Real-World Applications Across Domains

    Semi-supervised learning is most valuable wherever the ratio of labelled to unlabelled data is extreme and the cost of labelling is high.

    | Domain | Why SSL fits | Typical setup |
    | --- | --- | --- |
    | Medical imaging | Radiologist time is expensive; raw DICOMs accumulate | 5k labeled scans + 500k unlabeled; FixMatch or Mean Teacher |
    | Manufacturing QA | Defects are rare; passing parts flood the line | Few labeled defects, many unlabeled parts; SSL + one-class anomaly models |
    | NLP (sentiment, NER) | Labeled corpora are small; web text is effectively unlimited | Backtranslation or UDA on top of a pretrained transformer |
    | Autonomous driving | Segmentation labels cost minutes per frame; fleet logs run to petabytes | Mean Teacher for segmentation; auto-labeling pipelines |
    | Fraud detection | Confirmed frauds are rare; transactions number in the billions | Graph SSL + entropy minimization + active learning loop |
    | Speech recognition | Transcribed audio is scarce; raw audio is abundant | wav2vec 2.0 pretrain + semi-supervised fine-tuning |
    | Industrial anomaly detection | Very few examples of failure; many normal runs | Deep SAD (semi-supervised variant of Deep SVDD) |

     

    The manufacturing and anomaly-detection cases deserve a particular note: a semi-supervised variant of one-class classification, Deep SAD, builds directly on the Deep SVDD framework. It uses the few labelled abnormal examples to tighten the hypersphere around normal data. For anomaly detection with even a handful of confirmed anomalies, Deep SAD typically outperforms pure Deep SVDD.
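
    The way labelled anomalies tighten the hypersphere can be illustrated with a sketch of the Deep SAD objective as it is usually stated, with the weight-decay term omitted. Unlabelled and labelled-normal embeddings are pulled toward the centre, while labelled anomalies are penalised by the inverse of their squared distance. The embeddings below are made-up vectors standing in for network outputs, and deep_sad_loss is a hypothetical name.

```python
import numpy as np

def deep_sad_loss(z_unlabeled, z_labeled, y_labeled, center, eta=1.0, eps=1e-6):
    """y_labeled: +1 for labeled normal, -1 for labeled anomaly."""
    d_u = ((z_unlabeled - center) ** 2).sum(axis=1)
    d_l = ((z_labeled - center) ** 2).sum(axis=1)
    # Normal: penalise distance. Anomaly: penalise closeness (inverse distance).
    labeled_term = np.where(y_labeled == 1, d_l, 1.0 / (d_l + eps))
    n_total = len(d_u) + len(d_l)
    return float((d_u.sum() + eta * labeled_term.sum()) / n_total)

center = np.zeros(2)
z_unlabeled = np.array([[0.1, 0.0], [0.0, 0.2]])   # mostly normal, near centre
y_anomaly = np.array([-1])

loss_far = deep_sad_loss(z_unlabeled, np.array([[3.0, 0.0]]), y_anomaly, center)
loss_near = deep_sad_loss(z_unlabeled, np.array([[0.1, 0.0]]), y_anomaly, center)
print(f"anomaly far from centre:  {loss_far:.3f}")
print(f"anomaly near the centre: {loss_near:.3f}")
```

    The loss is far lower when the known anomaly sits well outside the sphere, so training pushes confirmed anomalies away while keeping normal data compact.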

    Paradigm Comparison: Semi-Supervised, Self-Supervised, Transfer, Active

    When a stakeholder asks which approach to use, the underlying question is often whether more labelling can be avoided. Several paradigms address this question in different ways.

    | Paradigm | Data | Labeling cost | Typical performance | When to use |
    | --- | --- | --- | --- | --- |
    | Fully supervised | All labeled | High | Baseline | Labels are cheap or already exist |
    | Semi-supervised | Few labeled + many unlabeled | Low | Can approach fully supervised accuracy with 1–10% of the labels | Labels scarce, unlabeled data plentiful, distributions match |
    | Self-supervised | Unlabeled only (pretrain) | None for pretraining | Strong when scaled to large datasets | You need reusable backbones; substantial unlabeled corpus |
    | Transfer learning | Pretrained weights + small labeled | Low | Strong and fast | A suitable pretrained model exists in your modality |
    | Active learning | Iteratively label smartly | Medium | Maximizes the return on each label | Labeling is possible but slow/expensive; you want to budget it |
    | Domain adaptation | Labeled source + unlabeled target | Medium | Bridges distribution shift | Your deployment data differs from your labeled data |

     

    These paradigms combine freely. A strong 2026 pipeline might: (1) pretrain a backbone with self-supervised learning, (2) fine-tune with semi-supervised learning on the actual task, (3) apply DANN-style domain adaptation when deploying to a new facility, and (4) use active learning to prioritise which difficult examples to return to human annotators.

    Method Comparison Within SSL

    | Method | Complexity | Typical CIFAR-10 accuracy (250 labels) | Strengths | Weaknesses |
    | --- | --- | --- | --- | --- |
    | Pseudo-labeling | Very low | ~60–70% | Trivial to implement | Confirmation bias, error amplification |
    | Mean Teacher | Medium | ~80% | Stable; good for regression/segmentation | Weaker on classification vs FixMatch |
    | MixMatch | High | ~88% | Strong with limited tricks | Many moving parts; sensitive to sharpening temperature |
    | FixMatch | Medium | ~95% | Simple, strong default, broadly applicable | Global threshold can stall on hard classes |
    | FlexMatch | Medium-high | ~95.5% | Per-class dynamic thresholds; handles curriculum | More hyperparameters |

     

    Practical Guide: Thresholds, Data Ratios, Pitfalls

    How Much Labelled Data Is Required?

    Empirically, SSL gains are largest when very few labels are available (for example, 4 to 40 per class) and diminish as the count approaches thousands per class. Above roughly 10% of the dataset labelled, FixMatch and related methods tend to converge with the fully supervised baseline. This does not mean SSL is useless above 10%; rather, the marginal advantage of SSL over additional labelling becomes smaller. The most favourable regime is one in which labels are genuinely scarce.

    Key Takeaway: The classic SSL gain curve shows substantial improvements at small labelled fractions (1% to 5%), steady diminution through 10%, and marginal returns by 20%. Labelling budgets should be designed accordingly.

    Choosing a Method

    • Standard image classification. Start with FixMatch. It is a strong default with minimal hyperparameter sensitivity.
    • Regression or segmentation. Mean Teacher adapts more naturally, because the consistency target can be a continuous prediction or pixel map rather than a class.
    • Imbalanced classes or class-dependent difficulty. FlexMatch’s dynamic thresholds prevent the majority classes from absorbing all pseudo-labels.
    • Graph-structured data. Use GCN or GAT directly; both are natively semi-supervised.

    Hyperparameter Tips

    • Confidence threshold τ: 0.95 is the FixMatch default. Lower it (0.7 to 0.8) if mask_used remains near zero for an extended period; raise it if pseudo-labels appear noisy.
    • Unsupervised weight λ: 1.0 typically works. If the supervised loss is unstable early in training, ramp λ from 0 to 1 over the first few epochs.
    • EMA decay (Mean Teacher): 0.999 is standard. Lower values cause the teacher to track the student noisily; higher values make the teacher lag so far behind the student that it effectively stops improving.
    • Batch size ratio μ: FixMatch uses μ = 7 (seven times more unlabelled per labelled batch). The unlabelled batch must be large enough that confidence-gated pseudo-labels are not all of the same class.
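
    The λ ramp suggested in the tips above could look like the following. This is a minimal sketch assuming a linear schedule measured in steps; the function name ramp_lambda and the 1,000-step ramp are illustrative choices rather than values from the FixMatch paper.

```python
def ramp_lambda(step, ramp_steps, lam_max=1.0):
    """Linearly increase the unsupervised weight from 0 to lam_max."""
    if ramp_steps <= 0:
        return lam_max
    return lam_max * min(1.0, step / ramp_steps)

for step in (0, 250, 500, 1000, 5000):
    print(f"step {step:5d}  lambda = {ramp_lambda(step, ramp_steps=1000):.2f}")
```

    In a training loop the returned value would simply replace the constant weight on the unsupervised loss at each step.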

    Common Pitfalls

    • Confirmation bias. The model pseudo-labels unlabelled data confidently but incorrectly, then trains on those incorrect labels. Strong augmentation and confidence thresholding mitigate this risk.
    • Class imbalance. If the labelled set is 90% class A, pseudo-labels will skew toward class A on unlabelled data, reinforcing the imbalance. FlexMatch and distribution alignment (ReMixMatch) address this.
    • Distribution shift. If the labelled data originates from Hospital A and the unlabelled data from Hospital B, SSL can degrade performance. The appropriate response is domain adaptation, either instead of SSL or in conjunction with it.
    • Open-set contamination. The unlabelled set contains classes that are absent from the labelled set. Pseudo-labelling forces these into known classes, corrupting the model.
    • Insufficient iterations. FixMatch requires extended training for mask_used to rise. Judgments should not be made after a single epoch.
    Caution: If the labelled and unlabelled sets originate from different distributions—different hospitals, sensors, geographies, or time periods—semi-supervised learning can actively degrade performance. SSL should always be benchmarked against a supervised baseline on a held-out set that reflects deployment conditions.

    Tools and Libraries

    • USB (Unified Semi-supervised learning Benchmark). PyTorch framework with more than 15 SSL algorithms and a common evaluation harness.
    • TorchSSL. Curated implementations of the classical SSL algorithms for image classification.
    • MMClassification / MMSegmentation. OpenMMLab tools with SSL support for image classification and segmentation.
    • Google’s official FixMatch repository. The paper authors’ reference TensorFlow implementation.

    Connections to Transfer, Active, and Domain Adaptation

    Semi-supervised learning is most powerful when treated not as a standalone technique but as one element of a broader set of complementary methods.

    Semi-Supervised plus Transfer Learning

    A common pattern is to begin with a pretrained backbone (ImageNet, CLIP, wav2vec) and fine-tune it using FixMatch on a small labelled set together with the unlabelled data. This combination routinely outperforms either approach in isolation. The pretrained features provide a head start on representation; SSL allows the model to adapt to the specific label structure. The transfer learning guide presents a concrete version of this pipeline for a cobot anomaly-detection project.

    Semi-Supervised plus Active Learning

    Active learning selects which unlabelled examples are most worth labelling next, while SSL uses the unlabelled examples without labelling them. The combined workflow trains with SSL, identifies examples on which the model is least confident or on which the SSL pseudo-label fluctuated across epochs, sends those to a human annotator, returns them as labelled data, and repeats. This pattern characterises most production labelling pipelines.
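
    The selection step in that workflow can be sketched with margin sampling, one common uncertainty criterion among several. The function name select_for_annotation and the probabilities below are hypothetical; in practice the probabilities would come from the SSL model's predictions on the unlabelled pool.

```python
import numpy as np

def select_for_annotation(probs, budget):
    """Return indices of the `budget` examples with the smallest top-2 margin."""
    sorted_p = np.sort(probs, axis=1)
    margin = sorted_p[:, -1] - sorted_p[:, -2]   # top-1 minus top-2 probability
    return np.argsort(margin)[:budget]

probs = np.array([[0.90, 0.05, 0.05],    # confident: keep as pseudo-label
                  [0.40, 0.35, 0.25],    # ambiguous
                  [0.50, 0.49, 0.01],    # most ambiguous
                  [0.70, 0.20, 0.10]])
to_label = select_for_annotation(probs, budget=2)
print(to_label)   # indices to send to a human annotator
```

    Confident examples stay in the unlabelled pool and are handled by pseudo-labelling; only the ambiguous ones consume annotation budget.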

    Semi-Supervised plus Domain Adaptation

    If the labelled data (source domain) and unlabelled data (target domain) originate from different distributions, plain SSL is likely to fail. Domain-adversarial training (DANN) or maximum-mean-discrepancy methods align the feature distributions, and once alignment is achieved, SSL can operate effectively. This combination is the basis on which many medical AI systems generalise across hospitals.
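
    Maximum mean discrepancy can also serve as a quick diagnostic of whether the labelled and unlabelled features come from the same distribution. The sketch below computes a biased estimate of squared MMD with an RBF kernel; the bandwidth gamma=0.5, the function names, and the synthetic features are assumptions made for illustration.

```python
import numpy as np

def rbf_kernel(A, B, gamma=0.5):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-gamma * d2)

def mmd2(X_s, X_t, gamma=0.5):
    """Biased estimate of squared MMD between source and target features."""
    return float(rbf_kernel(X_s, X_s, gamma).mean()
                 + rbf_kernel(X_t, X_t, gamma).mean()
                 - 2.0 * rbf_kernel(X_s, X_t, gamma).mean())

rng = np.random.default_rng(0)
source = rng.normal(size=(100, 2))                   # labeled-domain features
target_same = rng.normal(size=(100, 2))              # same distribution
target_shifted = rng.normal(loc=2.0, size=(100, 2))  # shifted distribution

mmd_same = mmd2(source, target_same)
mmd_shifted = mmd2(source, target_shifted)
print(f"MMD^2 same domain:    {mmd_same:.4f}")
print(f"MMD^2 shifted domain: {mmd_shifted:.4f}")
```

    A value that is clearly larger than what two random splits of the labelled data produce is a warning sign that alignment is needed before SSL. MMD-based adaptation methods minimise this same quantity as a training loss on the learned features.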

    Semi-Supervised plus Self-Supervised

    The two approaches need not be alternatives; they can be stacked. Pretrain with self-supervised learning on a substantial unlabelled corpus (see the self-supervised learning guide), then fine-tune with FixMatch on a small labelled set together with a focused unlabelled set. This combination underlies the modern recipe used in speech (wav2vec 2.0), vision (MAE plus FixMatch fine-tune), and NLP (pretraining plus UDA).

    Statistical reasoning helps explain why additional data tends to help: when the SSL assumptions hold and unlabelled examples contribute to parameter estimation, the effective sample size grows and the variance of the estimates falls, in much the same way that the standard error of a sample mean shrinks as the sample grows.

    Frequently Asked Questions

    What’s the difference between semi-supervised and self-supervised learning?

    Semi-supervised learning uses some human-labeled data plus unlabeled data to solve a specific downstream task directly. Self-supervised learning uses only unlabeled data and invents its own labels from data structure (masking, contrastive pairs) to produce a reusable pretrained backbone, which is later fine-tuned with labeled data on a downstream task. Semi-supervised is a training strategy for a task; self-supervised is a pretraining strategy for representations.

    How many labeled samples do I need for semi-supervised learning?

    The requirement depends on task complexity, but as a rule of thumb, FixMatch-class methods produce substantial gains with as few as 4 to 40 labelled examples per class for image classification. Returns diminish once approximately 10% of the dataset is labelled. For NLP and tabular data the curve is similar, though the inflection often arises with slightly more labels per class due to greater input variability.

    When does semi-supervised learning hurt rather than help?

    SSL can degrade performance when (a) the unlabelled data distribution differs materially from the labelled data distribution, (b) the unlabelled set contains novel classes not present in the labelled set, (c) class imbalance in the labelled set biases the pseudo-labels, or (d) the core assumptions (smoothness, cluster, manifold) do not hold for the data. The SSL model should always be measured against a strong supervised baseline on a held-out set that reflects deployment conditions.

    FixMatch vs MixMatch—which should I use?

    FixMatch is simpler, performs better on most benchmarks, and has fewer hyperparameters. It is the recommended starting point unless a specific reason argues for MixMatch (for example, a separate requirement for MixUp regularisation). MixMatch’s averaging-and-sharpening is conceptually elegant, but its empirical gains have been surpassed by FixMatch’s weak/strong pseudo-label scheme.

    Can I combine semi-supervised learning with transfer learning?

    Yes, and combining them is generally recommended. Initialise with a pretrained backbone (ImageNet, CLIP, or a domain-specific model) and apply FixMatch or Mean Teacher on top. Pretrained weights provide strong features from the start, which means FixMatch’s mask threshold is reached earlier in training and pseudo-labels are more reliable. This combination approximates the default recipe in modern practice.

