When Explanations Lie MICCAI Educational Challenge tutorial

When Explanations Lie: Stress-Testing Saliency Maps Before Clinical Deployment¶

A hands-on MICCAI tutorial on shortcut learning, sanity checks, and trustworthy medical imaging AI

Teaser figure summarizing saliency sanity checks, shortcut learning, and quantitative explanation tests.

In [1]:
# Colab setup: clone the repository when opened directly from GitHub,
# then install the small set of packages that may be missing.
from pathlib import Path
import subprocess
import sys

REPO_URL = "https://github.com/Kirscher/when-explanations-lie.git"
REPO_DIR = Path("when-explanations-lie")
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    if not Path("src").exists() and not (REPO_DIR / "src").exists():
        subprocess.run(["git", "clone", "--depth", "1", REPO_URL, str(REPO_DIR)], check=True)
    if (REPO_DIR / "src").exists():
        %cd when-explanations-lie

    missing_packages = []
    for module_name, package_name in {
        "medmnist": "medmnist",
        "captum": "captum",
    }.items():
        try:
            __import__(module_name)
        except ModuleNotFoundError:
            missing_packages.append(package_name)

    if missing_packages:
        %pip -q install {" ".join(missing_packages)} tqdm scikit-learn nbformat nbconvert

Abstract¶

Saliency maps are widely used to explain medical imaging AI systems, but visually plausible explanations can be misleading. This hands-on notebook teaches students how to implement common saliency methods and, more importantly, how to stress-test them. Using PneumoniaMNIST, participants train a small chest X-ray classifier, implement vanilla gradients, SmoothGrad, Integrated Gradients, occlusion sensitivity, and Grad-CAM, and evaluate explanations using randomization tests, shortcut-learning experiments, deletion curves, stability analysis, and shortcut-region attribution. This is a controlled, clinically motivated teaching experiment rather than a deployable pneumonia detector: participants leave with a reproducible debugging workflow for asking whether explanations are sensitive to model parameters, robust to perturbations, and capable of revealing non-clinical shortcuts.

Reviewer-facing contribution statement. This tutorial contributes:

  1. A reproducible saliency stress-test workflow for medical image classifiers.
  2. A controlled shortcut-learning experiment with known non-clinical artifacts.
  3. Transparent implementations: minimal PyTorch implementations of common saliency methods so learners can see how each explanation is computed.
  4. A practical checklist for using explanations as debugging tools before clinical deployment.

Pedagogical implementation note. All explanation methods in this tutorial are implemented from scratch in compact PyTorch code under src/, rather than delegated to a black-box interpretability library. This is intentional: the goal is to make the assumptions, gradients, baselines, smoothing, occlusion masks, and Grad-CAM hooks visible to learners. These implementations are designed for teaching and reproducibility, not as optimized clinical software.

Who is this tutorial for?¶

This notebook is for incoming Master's and PhD students in medical image computing who know basic Python and deep learning. The goal is educational: build enough of the pipeline to understand where saliency maps can fail, not to produce a deployable pneumonia detector.

Learning objectives¶

By the end of this tutorial, you will be able to:

  1. Implement common saliency methods for medical image classifiers.
  2. Explain why visually plausible heatmaps can be misleading.
  3. Run model, layer, and label-randomization sanity checks.
  4. Detect shortcut learning using controlled artifacts.
  5. Interpret deletion curves, stability tests, and shortcut-region attribution.
  6. Describe why explanation quality is not the same as clinical validity.
  7. Connect explanation failures to real-world clinical deployment risks.

Clinical motivation: from pixels to patients¶

Imagine a hospital deploys a chest X-ray triage model for pneumonia. The model appears accurate, and its saliency maps seem to highlight the lungs. Later, performance collapses at another hospital. Investigation reveals that the model relied partly on non-clinical image artifacts: scanner-specific borders, preprocessing marks, or labels correlated with the training site.

This tutorial demonstrates why attractive heatmaps are not enough. We will train a simple medical image classifier, generate explanations, and then deliberately stress-test those explanations. The goal is not to build a deployable pneumonia detector, but to learn practical checks that should happen before explainability tools are used to support patient-facing decisions. Similar concerns appear in reports of cross-site pneumonia generalization failures and shortcut use in radiographic COVID-19 models [7, 8].

Setup and reproducibility¶

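Run the cells in order. The default configuration uses the full PneumoniaMNIST splits with a moderate epoch budget. Set FAST_DEV_RUN = True for a shorter smoke test (at most 2,048 training, 512 validation, and 512 test images, with fewer epochs and fewer SmoothGrad and Integrated Gradients samples) while developing or teaching under tight time constraints. Seeds improve reproducibility, but deterministic GPU execution can be slower, so this tutorial leaves deterministic kernels off by default and small run-to-run differences on GPU can occur.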

Runtime and data box. The default setting uses the full PneumoniaMNIST train, validation, and test splits, 10 baseline epochs, 5 label-noise epochs, and 8 shortcut epochs. Runtime depends on hardware, dependency installation, and the first dataset download. No private data are required.

Materials included. The submission package includes this rendered HTML, the executable notebook, the src/ package used by the notebook, requirements.txt and environment.yml, and a README.md with the local and Colab run paths.

In [2]:
from pathlib import Path
import copy
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

NOTEBOOK_DIR = Path.cwd()
REPO_ROOT = NOTEBOOK_DIR.parent if NOTEBOOK_DIR.name == "notebooks" else NOTEBOOK_DIR
sys.path.insert(0, str(REPO_ROOT))

import medmnist
from medmnist import INFO
import sklearn

from src.data import (
    DATA_FLAG,
    RandomLabelDataset,
    ShortcutDataset,
    count_labels,
    load_pneumoniamnist,
    make_subset,
    marker_mask_for_label,
)
from src.explainers import GradCAM, integrated_gradients, occlusion_sensitivity, smoothgrad, vanilla_gradient
from src.metrics import (
    binary_clinical_metrics,
    confusion_matrix_np,
    deletion_auc,
    deletion_curve,
    map_correlation,
    perturb_image,
    shortcut_region_attribution_fraction,
)
from src.models import SmallMedicalCNN, evaluate, find_example, reset_module_parameters, train_one_epoch
from src.utils import get_device, print_versions, set_seed
from src.visualization import (
    denorm,
    overlay_explanation,
    plot_confusion_matrix,
    plot_explanation_grid,
    plot_metric_bar,
    show_batch,
    show_image_tensor,
)

%matplotlib inline

set_seed(42, deterministic=False)
device = get_device()
print("Using device:", device)
print_versions({"scikit-learn": sklearn, "MedMNIST": medmnist})
Using device: cuda
Python: 3.12.3 (main, Mar 23 2026, 19:04:32) [GCC 13.3.0]
PyTorch: 2.11.0+cu130
NumPy: 2.4.4
scikit-learn: 1.8.0
MedMNIST: 3.0.2
In [3]:
FAST_DEV_RUN = False
IMAGE_SIZE = 64
BATCH_SIZE = 128
NUM_WORKERS = 0
LR = 1e-3

if FAST_DEV_RUN:
    MAX_TRAIN_SAMPLES = 2048
    MAX_VAL_SAMPLES = 512
    MAX_TEST_SAMPLES = 512
    EPOCHS_BASELINE = 5
    EPOCHS_RANDOM_LABEL = 3
    EPOCHS_SHORTCUT = 5
    SMOOTHGRAD_SAMPLES = 12
    IG_STEPS = 24
else:
    MAX_TRAIN_SAMPLES = None
    MAX_VAL_SAMPLES = None
    MAX_TEST_SAMPLES = None
    EPOCHS_BASELINE = 10
    EPOCHS_RANDOM_LABEL = 5
    EPOCHS_SHORTCUT = 8
    SMOOTHGRAD_SAMPLES = 20
    IG_STEPS = 32

print({
    "FAST_DEV_RUN": FAST_DEV_RUN,
    "EPOCHS_BASELINE": EPOCHS_BASELINE,
    "EPOCHS_RANDOM_LABEL": EPOCHS_RANDOM_LABEL,
    "EPOCHS_SHORTCUT": EPOCHS_SHORTCUT,
})
{'FAST_DEV_RUN': False, 'EPOCHS_BASELINE': 10, 'EPOCHS_RANDOM_LABEL': 5, 'EPOCHS_SHORTCUT': 8}

Dataset: PneumoniaMNIST¶

PneumoniaMNIST is part of MedMNIST v2, a lightweight biomedical image classification benchmark [1]. The task is binary pediatric chest X-ray classification: normal versus pneumonia. We normalize images to [-1, 1], which matters later when choosing baselines and occlusion values.
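A small sketch of what that normalization implies, assuming the common mapping (x - 0.5) / 0.5 from [0, 1] pixel intensities. The repository's actual transform lives in src/ and is not shown here, so `to_normalized` and its constants are illustrative assumptions. Under this mapping black becomes -1 and mid-gray becomes 0, so a "zero" baseline or occlusion fill is gray rather than black.

```python
import numpy as np

def to_normalized(pixels_01):
    """Map pixel intensities from [0, 1] to [-1, 1] (assumed mean=0.5, std=0.5)."""
    return (np.asarray(pixels_01, dtype=np.float32) - 0.5) / 0.5

# Black, mid-gray, and white pixels after normalization.
black, gray, white = to_normalized([0.0, 0.5, 1.0])
print(black, gray, white)
```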

In [4]:
info = INFO[DATA_FLAG]
print("Dataset:", info["description"][:500], "...")

bundle = load_pneumoniamnist(
    medmnist_module=medmnist,
    info=info,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    max_train_samples=MAX_TRAIN_SAMPLES,
    max_val_samples=MAX_VAL_SAMPLES,
    max_test_samples=MAX_TEST_SAMPLES,
)

train_ds, val_ds, test_ds = bundle.train_ds, bundle.val_ds, bundle.test_ds
train_loader, val_loader, test_loader = bundle.train_loader, bundle.val_loader, bundle.test_loader
CLASS_NAMES = bundle.class_names
NUM_CLASSES = bundle.num_classes

print("Classes:", CLASS_NAMES)
print("Used split sizes:", len(train_ds), len(val_ds), len(test_ds))
show_batch(train_loader, CLASS_NAMES, n=12)

counts = count_labels(train_ds, NUM_CLASSES)
plt.figure(figsize=(5, 3))
plt.bar(CLASS_NAMES, counts)
plt.title("Training class distribution")
plt.ylabel("Number of images")
plt.xticks(rotation=20)
plt.show()
print(dict(zip(CLASS_NAMES, counts.tolist())))
Dataset: The PneumoniaMNIST is based on a prior dataset of 5,856 pediatric chest X-Ray images. The task is binary-class classification of pneumonia against normal. We split the source training set with a ratio of 9:1 into training and validation set and use its source validation set as the test set. The source images are gray-scale, and their sizes are (384−2,916)×(127−2,713). We center-crop the images and resize them into 1×28×28. ...
Classes: ['normal', 'pneumonia']
Used split sizes: 4708 524 624
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Notebook-generated figure; see the surrounding heading and caption for interpretation.
{'normal': 1214, 'pneumonia': 3494}

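Figure caption: the image grid shows normalized PneumoniaMNIST samples from the training loader, followed by the class distribution of the training split used in this run. The split is imbalanced (1,214 normal versus 3,494 pneumonia images), which is one reason accuracy alone is not enough later on.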

Training a baseline pneumonia classifier¶

The classifier is intentionally small so students can inspect the full pipeline. It is a teaching model, not a clinical device. The training curve helps identify obvious underfitting or overfitting before we interpret any explanation.

In [5]:
baseline_model = SmallMedicalCNN(num_classes=NUM_CLASSES, image_size=IMAGE_SIZE).to(device)
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

for epoch in range(EPOCHS_BASELINE):
    train_loss, train_acc = train_one_epoch(baseline_model, train_loader, optimizer, criterion, device)
    val_out = evaluate(baseline_model, val_loader, device, criterion)
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_out["loss"])
    history["val_acc"].append(val_out["acc"])
    print(
        f"Epoch {epoch + 1:02d}/{EPOCHS_BASELINE} | "
        f"train loss {train_loss:.3f}, train acc {train_acc:.3f} | "
        f"val loss {val_out['loss']:.3f}, val acc {val_out['acc']:.3f}"
    )

plt.figure(figsize=(6, 4))
plt.plot(history["train_acc"], marker="o", label="train")
plt.plot(history["val_acc"], marker="o", label="validation")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Baseline training curve")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

test_out = evaluate(baseline_model, test_loader, device, criterion)
print(f"Test accuracy: {test_out['acc']:.3f}")
Epoch 01/10 | train loss 0.451, train acc 0.778 | val loss 0.242, val acc 0.939
Epoch 02/10 | train loss 0.203, train acc 0.894 | val loss 0.164, val acc 0.945
Epoch 03/10 | train loss 0.157, train acc 0.937 | val loss 0.144, val acc 0.948
Epoch 04/10 | train loss 0.121, train acc 0.957 | val loss 0.106, val acc 0.962
Epoch 05/10 | train loss 0.120, train acc 0.955 | val loss 0.096, val acc 0.968
Epoch 06/10 | train loss 0.106, train acc 0.962 | val loss 0.086, val acc 0.968
Epoch 07/10 | train loss 0.097, train acc 0.963 | val loss 0.083, val acc 0.964
Epoch 08/10 | train loss 0.085, train acc 0.972 | val loss 0.074, val acc 0.971
Epoch 09/10 | train loss 0.081, train acc 0.971 | val loss 0.069, val acc 0.971
Epoch 10/10 | train loss 0.072, train acc 0.975 | val loss 0.074, val acc 0.966
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Test accuracy: 0.875

Clinical metrics beyond accuracy¶

In clinical screening or triage, overall accuracy can hide clinically important errors. Sensitivity reflects how often pneumonia cases are detected, while specificity reflects how often normal cases are correctly ruled out. A model with high accuracy but poor sensitivity may be unsafe in a screening workflow.
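To make these definitions concrete, here is a minimal NumPy sketch of the three quantities. It is not the repository's `binary_clinical_metrics`; the function name `sketch_binary_metrics` and the toy labels are assumptions used only for illustration.

```python
import numpy as np

def sketch_binary_metrics(labels, preds, positive_class=1):
    """Accuracy, sensitivity, specificity, and balanced accuracy for a binary task."""
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    pos = labels == positive_class
    neg = ~pos
    true_pos = np.sum(pos & (preds == positive_class))
    true_neg = np.sum(neg & (preds != positive_class))
    sensitivity = true_pos / max(pos.sum(), 1)   # detected positives / all positives
    specificity = true_neg / max(neg.sum(), 1)   # ruled-out negatives / all negatives
    return {
        "accuracy": float(np.mean(labels == preds)),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "balanced_accuracy": float((sensitivity + specificity) / 2),
    }

# Toy labels (not tutorial results): 6 pneumonia cases and 4 normal cases.
toy_labels = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
toy_preds = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
toy_metrics = sketch_binary_metrics(toy_labels, toy_preds)
print(toy_metrics)
```

In the toy example every pneumonia case is caught, yet half of the normal cases are overcalled, so accuracy (0.8) looks better than balanced accuracy (0.75).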

In [6]:
labels_np = test_out["labels"].numpy()
preds_np = test_out["preds"].numpy()
probs_np = test_out["probs"].numpy()

cm = confusion_matrix_np(labels_np, preds_np, NUM_CLASSES)
plot_confusion_matrix(cm, CLASS_NAMES, title="Baseline test confusion matrix")

clinical_metrics = binary_clinical_metrics(labels_np, preds_np, probs_np, positive_class=1)
for name, value in clinical_metrics.items():
    print(f"{name}: {value:.3f}")
Notebook-generated figure; see the surrounding heading and caption for interpretation.
accuracy: 0.875
sensitivity_recall_pneumonia: 0.987
specificity_normal: 0.688
balanced_accuracy: 0.838
auroc: 0.941

Figure caption and clinical note: the confusion matrix summarizes normal-versus-pneumonia errors on the test split, and the printed metrics report sensitivity and specificity. In this run the model detects nearly all pneumonia cases (sensitivity 0.987) but has a specificity of only 0.688, meaning roughly 31% of normal test images are called pneumonia. That tradeoff would matter in triage because false positives can burden clinical workflow.

In [7]:
def find_or_fallback(model, loader, desired_label):
    try:
        return find_example(model, loader, device, desired_label=desired_label, correct=True)
    except ValueError:
        return find_example(model, loader, device, desired_label=desired_label, correct=False)

x_pneumonia, y_pneumonia, pred_pneumonia, prob_pneumonia = find_or_fallback(baseline_model, test_loader, desired_label=1)
x_normal, y_normal, pred_normal, prob_normal = find_or_fallback(baseline_model, test_loader, desired_label=0)

print("Pneumonia example:", CLASS_NAMES[y_pneumonia], "pred=", CLASS_NAMES[pred_pneumonia], "probs=", prob_pneumonia.numpy())
print("Normal example:", CLASS_NAMES[y_normal], "pred=", CLASS_NAMES[pred_normal], "probs=", prob_normal.numpy())
show_image_tensor(x_pneumonia, f"True: {CLASS_NAMES[y_pneumonia]} | Pred: {CLASS_NAMES[pred_pneumonia]}")
show_image_tensor(x_normal, f"True: {CLASS_NAMES[y_normal]} | Pred: {CLASS_NAMES[pred_normal]}")
Pneumonia example: pneumonia pred= pneumonia probs= [3.4789142e-07 9.9999964e-01]
Normal example: normal pred= normal probs= [0.99732214 0.00267792]
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Figure caption: the two example images selected for saliency analysis show one pneumonia case and one normal case with the model's predicted label in each title.

Saliency methods: intuition and implementation¶

For visualization, each saliency map is normalized independently. This improves visual contrast but can make weak and strong explanations look equally intense. Quantitative comparisons should use unnormalized maps or a shared normalization scale.

Taking the absolute value shows sensitivity regardless of direction, but it hides whether a pixel supports or suppresses the predicted class. Signed attributions may be preferable when the direction of evidence matters.
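The effect of independent normalization can be seen in a short sketch. The helper `minmax_normalize` is a hypothetical stand-in (the plotting code in src/ may normalize differently); the point is only that two maps with the same pattern but very different magnitudes become indistinguishable after per-map rescaling.

```python
import numpy as np

def minmax_normalize(saliency):
    """Rescale one map to [0, 1] independently of any other map."""
    saliency = np.asarray(saliency, dtype=np.float64)
    span = saliency.max() - saliency.min()
    if span == 0:
        return np.zeros_like(saliency)
    return (saliency - saliency.min()) / span

rng = np.random.default_rng(0)
strong_map = rng.random((8, 8))
weak_map = 1e-4 * strong_map  # same pattern, 10,000 times smaller

# Per-map normalization erases the difference in magnitude.
same_after_norm = np.allclose(minmax_normalize(strong_map), minmax_normalize(weak_map))
print(same_after_norm)

# A shared scale keeps the weak map visibly weak.
shared_max = max(strong_map.max(), weak_map.max())
weak_peak_shared = float((weak_map / shared_max).max())
print(weak_peak_shared)
```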

Explanation vocabulary¶

  • Plausibility: whether a heatmap looks reasonable to a human observer.
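  • Faithfulness: whether the map reflects what the model actually computes. In this tutorial it is probed by checking that the map changes when the learned parameters or the training labels change.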
  • Localization: whether attribution falls on a known region of interest, such as our synthetic marker.
  • Clinical validity: whether the model and its explanations remain useful, safe, and calibrated in a real clinical workflow.

These concepts are related but not interchangeable. A map can be plausible without being faithful, and a faithful debugging signal is still not proof of clinical validity.

Vanilla gradients¶

Vanilla saliency computes the gradient of a class score with respect to the input pixels:

$$ S(x) = \left| \frac{\partial f_c(x)}{\partial x} \right| $$

Here $x$ is the input image and $f_c(x)$ is the score for class $c$. Large values indicate pixels where small local changes strongly affect the class score.
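The formula above can be sketched in a few lines of PyTorch. This is an illustrative version, not the code in src/explainers: the name `sketch_vanilla_gradient`, the choice to take the maximum over channels, and the toy model are all assumptions.

```python
import torch
import torch.nn as nn

def sketch_vanilla_gradient(model, x, target, signed=False):
    """Gradient of the target-class score with respect to the input pixels."""
    model.eval()
    x = x.clone().detach().requires_grad_(True)
    score = model(x)[0, target]                  # f_c(x) for a batch of one image
    grad, = torch.autograd.grad(score, x)        # d f_c / d x, shape [1, C, H, W]
    grad = grad[0]
    if signed:
        return grad.sum(dim=0)                   # keep direction, collapse channels
    return grad.abs().max(dim=0)[0]              # magnitude, max over channels

# Hypothetical stand-in model and input, only to exercise the function.
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(4 * 16 * 16, 2)
)
toy_x = torch.rand(1, 1, 16, 16) * 2 - 1         # values in [-1, 1]
toy_map = sketch_vanilla_gradient(toy_model, toy_x, target=1)
print(toy_map.shape)
```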

In [8]:
vg = vanilla_gradient(baseline_model, x_pneumonia, pred_pneumonia)
vg_signed = vanilla_gradient(baseline_model, x_pneumonia, pred_pneumonia, signed=True)
overlay_explanation(x_pneumonia, vg, "Vanilla gradient")
overlay_explanation(x_pneumonia, vg_signed, "Signed gradient (direction retained)")
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Notebook-generated figure; see the surrounding heading and caption for interpretation.

SmoothGrad¶

SmoothGrad averages gradients over noisy copies of the same image [3]:

$$ S_{\text{SmoothGrad}}(x) = \frac{1}{N} \sum_{i=1}^{N} \left| \frac{\partial f_c(x + \epsilon_i)}{\partial x} \right|, \quad \epsilon_i \sim \mathcal{N}(0, \sigma^2) $$

Averaging can reduce visual noise. Because our images are normalized to [-1, 1], noisy inputs are clamped back to that range to avoid explaining unrealistic out-of-range images.
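A minimal sketch of that averaging loop, including the clamp to [-1, 1], might look as follows. The function name `sketch_smoothgrad` and the toy model are assumptions; the repository's `smoothgrad` may differ in details such as batching or channel handling.

```python
import torch
import torch.nn as nn

def sketch_smoothgrad(model, x, target, n_samples=20, noise_std=0.08):
    """Average absolute input gradients over clamped noisy copies of x."""
    model.eval()
    total = torch.zeros_like(x[0, 0])            # accumulator with shape [H, W]
    for _ in range(n_samples):
        # Add Gaussian noise, then clamp back to the valid normalized range.
        noisy = (x + noise_std * torch.randn_like(x)).clamp(-1.0, 1.0)
        noisy = noisy.detach().requires_grad_(True)
        grad, = torch.autograd.grad(model(noisy)[0, target], noisy)
        total += grad[0].abs().max(dim=0)[0]
    return total / n_samples

# Hypothetical stand-in model and input, only to exercise the function.
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(4 * 16 * 16, 2)
)
toy_x = torch.rand(1, 1, 16, 16) * 2 - 1
toy_sg = sketch_smoothgrad(toy_model, toy_x, target=1, n_samples=8)
print(toy_sg.shape)
```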

In [9]:
sg = smoothgrad(baseline_model, x_pneumonia, pred_pneumonia, n_samples=SMOOTHGRAD_SAMPLES, noise_std=0.08)
overlay_explanation(x_pneumonia, sg, "SmoothGrad")
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Integrated Gradients¶

Integrated Gradients was proposed by Sundararajan et al. [2]. It integrates gradients along a path from a baseline image $x'$ to the input image $x$:

$$ \mathrm{IG}_i(x) = (x_i - x'_i) \int_{\alpha=0}^{1} \frac{\partial f_c(x' + \alpha(x - x'))}{\partial x_i} \, d\alpha $$

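Here the default baseline is a gray image, which is all zeros in the normalized [-1, 1] space. A true black image would be represented by -1 after normalization, so the two baseline choices answer different questions: attribution relative to gray versus attribution relative to black.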
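The integral is approximated numerically in practice. The sketch below uses a right Riemann sum and checks the completeness property (attributions sum to the difference in score between input and baseline). The name `sketch_integrated_gradients` is hypothetical, and the toy model is deliberately linear so that completeness holds exactly; for a nonlinear network the match is only approximate and improves with more steps.

```python
import torch
import torch.nn as nn

def sketch_integrated_gradients(model, x, target, baseline, steps=32):
    """Signed IG attribution with a right Riemann sum along the straight-line path."""
    model.eval()
    avg_grad = torch.zeros_like(x)
    for step in range(1, steps + 1):
        alpha = step / steps
        point = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        grad, = torch.autograd.grad(model(point)[0, target], point)
        avg_grad += grad / steps
    return (x - baseline) * avg_grad             # same shape as x

# Hypothetical linear stand-in model; gray (all-zero) baseline in normalized space.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(64, 2))
toy_x = torch.rand(1, 1, 8, 8) * 2 - 1
toy_baseline = torch.zeros_like(toy_x)
toy_attr = sketch_integrated_gradients(toy_model, toy_x, 1, toy_baseline, steps=16)

with torch.no_grad():
    score_gap = toy_model(toy_x)[0, 1] - toy_model(toy_baseline)[0, 1]
print(float(toy_attr.sum()), float(score_gap))   # completeness: these should agree
```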

In [10]:
baseline_mode = "gray"  # options: "gray", "black"
ig = integrated_gradients(
    baseline_model,
    x_pneumonia,
    pred_pneumonia,
    baseline_mode=baseline_mode,
    steps=IG_STEPS,
)
overlay_explanation(x_pneumonia, ig, f"Integrated Gradients ({baseline_mode} baseline)")
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Extension: compare Integrated Gradients with Captum¶

Our in-repository Integrated Gradients implementation is intentionally explicit: it builds points along the straight-line path from a baseline image to the input image, backpropagates the target-class logit at each point, averages those gradients, and multiplies by input - baseline. That makes the algorithm inspectable for teaching.

In practice, you may prefer a maintained interpretability library. Captum provides IntegratedGradients for PyTorch models, accepts explicit baselines and target classes, supports several numerical integration rules, and can return a convergence delta. The cell below reproduces the same idea with Captum and compares it with the transparent tutorial implementation.

Captum documentation: https://captum.ai/api/integrated_gradients.html.

In [11]:
from captum.attr import IntegratedGradients as CaptumIntegratedGradients

x_for_captum = x_pneumonia.to(device)
if baseline_mode == "gray":
    captum_baseline = torch.zeros_like(x_for_captum)
elif baseline_mode == "black":
    captum_baseline = torch.full_like(x_for_captum, -1.0)
else:
    raise ValueError("baseline_mode must be 'gray' or 'black'")

captum_ig = CaptumIntegratedGradients(baseline_model)
captum_attr, captum_delta = captum_ig.attribute(
    x_for_captum,
    baselines=captum_baseline,
    target=pred_pneumonia,
    n_steps=IG_STEPS,
    method="riemann_right",
    return_convergence_delta=True,
)
captum_map = captum_attr.detach().abs().max(dim=1)[0][0].cpu()

print("Correlation between tutorial IG and Captum IG:", map_correlation(ig, captum_map))
print("Captum convergence delta:", captum_delta.detach().cpu().numpy())
plot_explanation_grid(
    x_pneumonia,
    {"Tutorial IG": ig, "Captum IG": captum_map},
    "Integrated Gradients: transparent implementation vs Captum",
)
Correlation between tutorial IG and Captum IG: 0.9998345971107483
Captum convergence delta: [0.06916523]
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Occlusion sensitivity¶

Occlusion sensitivity asks how the target probability changes when a region is hidden:

$$ S_R(x) = f_c(x) - f_c(x_{\setminus R}) $$

$R$ is an image region and $x_{\setminus R}$ is the image with that region occluded. In this section $f_c$ is read as the target-class probability, so a large drop means the region was important for the prediction. Perturbation methods are related to meaningful perturbation approaches [6].
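A sliding-window version of this idea can be sketched as below. The function name `sketch_occlusion`, the averaging of overlapping patches, and the `CornerModel` toy classifier are assumptions made for illustration; the toy model looks only at the top-left corner, so the occlusion map should be large there and zero far away.

```python
import torch
import torch.nn as nn

def sketch_occlusion(model, x, target, patch_size=8, stride=4, fill_value=0.0):
    """Average drop in target probability when each patch is replaced by fill_value."""
    model.eval()
    _, _, height, width = x.shape
    heat = torch.zeros(height, width, device=x.device)
    counts = torch.zeros(height, width, device=x.device)
    with torch.no_grad():
        base_prob = torch.softmax(model(x), dim=1)[0, target]
        for top in range(0, height - patch_size + 1, stride):
            for left in range(0, width - patch_size + 1, stride):
                occluded = x.clone()
                occluded[:, :, top:top + patch_size, left:left + patch_size] = fill_value
                prob = torch.softmax(model(occluded), dim=1)[0, target]
                heat[top:top + patch_size, left:left + patch_size] += base_prob - prob
                counts[top:top + patch_size, left:left + patch_size] += 1
    return heat / counts.clamp(min=1)            # average over overlapping patches

class CornerModel(nn.Module):
    """Toy classifier whose class-1 score depends only on the top-left 8x8 corner."""
    def forward(self, x):
        score = 10 * x[:, :, :8, :8].mean(dim=(1, 2, 3))
        return torch.stack([torch.zeros_like(score), score], dim=1)

toy_x = torch.full((1, 1, 16, 16), -1.0)
toy_x[:, :, :8, :8] = 1.0                        # bright corner drives the prediction
toy_occ = sketch_occlusion(CornerModel(), toy_x, target=1)
print(float(toy_occ[0, 0]), float(toy_occ[-1, -1]))
```

The fill value matters: with inputs in [-1, 1], a fill of 0.0 replaces a patch with gray, whereas -1.0 would replace it with black.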

In [12]:
occ = occlusion_sensitivity(baseline_model, x_pneumonia, pred_pneumonia, patch_size=8, stride=4, fill_value=0.0)
overlay_explanation(x_pneumonia, occ, "Occlusion sensitivity")
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Grad-CAM¶

Grad-CAM weights convolutional feature maps using gradients flowing into a target layer [4]:

$$ \alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial f_c}{\partial A_{ij}^k} $$

$$ L_{\text{Grad-CAM}}^c = \mathrm{ReLU} \left( \sum_k \alpha_k^c A^k \right) $$

$A^k$ is feature map $k$ of the target layer, $Z$ is the number of spatial positions in that feature map, and $\alpha_k^c$ measures how important feature map $k$ is for class $c$. The ReLU keeps features that positively support the class.
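The two equations translate fairly directly into a forward hook plus one gradient call. The sketch below is a simplified illustration, not the `GradCAM` class from src/explainers: `sketch_gradcam`, the toy network, and the choice of target layer are assumptions. It also assumes the hooked layer's output is not modified in place by the next layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def sketch_gradcam(model, target_layer, x, target):
    """Grad-CAM map for one image, upsampled to the input resolution."""
    store = {}
    handle = target_layer.register_forward_hook(
        lambda module, inputs, output: store.update(activations=output)
    )
    try:
        model.eval()
        logits = model(x)
        activations = store["activations"]                   # A^k, shape [1, K, h, w]
        grads, = torch.autograd.grad(logits[0, target], activations)
        weights = grads.mean(dim=(2, 3), keepdim=True)       # alpha_k^c (mean over i, j)
        cam = F.relu((weights * activations).sum(dim=1, keepdim=True))
        cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
        return cam[0, 0].detach()
    finally:
        handle.remove()                                      # always detach the hook

# Hypothetical stand-in network; the second convolution is used as the target layer.
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(4, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2),
)
toy_x = torch.rand(1, 1, 16, 16) * 2 - 1
toy_cam = sketch_gradcam(toy_model, toy_model[3], toy_x, target=1)
print(toy_cam.shape)
```

Because the map is computed at the resolution of the target layer and then upsampled, it is coarser than pixel-level gradient maps.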

In [13]:
gradcam = GradCAM(baseline_model, baseline_model.features[6])
gcam = gradcam(x_pneumonia, pred_pneumonia)
gradcam.remove()
overlay_explanation(x_pneumonia, gcam, "Grad-CAM")

heatmaps_baseline = {
    "Gradient": vg,
    "SmoothGrad": sg,
    "Integrated Gradients": ig,
    "Occlusion": occ,
    "Grad-CAM": gcam,
}
plot_explanation_grid(x_pneumonia, heatmaps_baseline, "Baseline model explanations")
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Figure caption: the baseline explanation grid compares five saliency methods on the same pneumonia example. Heatmaps are normalized independently for visibility, so color intensity should not be compared as an absolute score across methods.

Why plausible heatmaps can be misleading¶

Visual plausibility is not faithfulness. A map can look anatomical because chest X-rays have strong edges and repeated structures, even if the model parameters or labels are wrong. Sanity checks for saliency maps were popularized by Adebayo et al. [5] and are a practical first line of defense before explanation claims are used in clinical discussions.

Sanity check 1: model randomization¶

We compare the trained model with a randomly initialized model. Using the same target class isolates whether the explanation depends on learned model parameters, rather than changing simply because we explained a different output class.
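Randomization tests need a number that says how similar two maps are. The sketch below assumes a Pearson correlation over flattened maps; whether the repository's `map_correlation` uses Pearson, a rank correlation, or another similarity is not shown here, so `sketch_map_correlation` is only an illustrative stand-in.

```python
import numpy as np

def sketch_map_correlation(map_a, map_b):
    """Pearson correlation between two saliency maps of the same shape."""
    a = np.asarray(map_a, dtype=np.float64).ravel()
    b = np.asarray(map_b, dtype=np.float64).ravel()
    a = a - a.mean()
    b = b - b.mean()
    denom = np.sqrt((a ** 2).sum() * (b ** 2).sum())
    if denom == 0:
        return 0.0                               # a constant map carries no pattern
    return float((a * b).sum() / denom)

rng = np.random.default_rng(0)
toy_map = rng.random((8, 8))
corr_rescaled = sketch_map_correlation(toy_map, 3.0 * toy_map + 2.0)
corr_inverted = sketch_map_correlation(toy_map, -toy_map)
corr_constant = sketch_map_correlation(toy_map, np.ones((8, 8)))
print(corr_rescaled, corr_inverted, corr_constant)
```

A correlation near 1 after randomization is the warning sign; a value near 0 suggests the map depends on what was randomized.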

In [14]:
random_model = SmallMedicalCNN(num_classes=NUM_CLASSES, image_size=IMAGE_SIZE).to(device)
random_model.eval()

with torch.no_grad():
    random_pred = random_model(x_pneumonia).argmax(dim=1).item()

vg_random_pred = vanilla_gradient(random_model, x_pneumonia, random_pred)
vg_random_fixed_target = vanilla_gradient(random_model, x_pneumonia, pred_pneumonia)

print("Trained prediction:", CLASS_NAMES[pred_pneumonia])
print("Random model prediction:", CLASS_NAMES[random_pred])
print("Gradient map correlation trained vs random, random target:", map_correlation(vg, vg_random_pred))
print("Gradient map correlation trained vs random, fixed target:", map_correlation(vg, vg_random_fixed_target))

plot_explanation_grid(
    x_pneumonia,
    {"Trained model": vg, "Random model, fixed target": vg_random_fixed_target},
    "Model randomization test",
)
Trained prediction: pneumonia
Random model prediction: pneumonia
Gradient map correlation trained vs random, random target: 0.04296400025486946
Gradient map correlation trained vs random, fixed target: 0.04296400025486946
Notebook-generated figure; see the surrounding heading and caption for interpretation.

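Figure caption: the model-randomization sanity check compares the trained-model gradient map with a randomly initialized model for the same target class. A faithful explanation should usually change when the learned parameters are replaced; here the correlation of about 0.04 indicates that the two maps are nearly unrelated. The two printed correlations are identical because the random model also happened to predict pneumonia, so the "random target" and "fixed target" coincide.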

Sanity check 2: layer randomization¶

Now we randomize parts of the trained network. If explanations barely change after classifier or feature layers are reset, the method may be reflecting image structure more than learned decision logic.
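One way to reset only part of a network is to call each layer's own `reset_parameters()` on a copy of the model. This is a sketch under that assumption; the repository's `reset_module_parameters` may be implemented differently, and `sketch_reset_parameters` and the toy network are hypothetical.

```python
import copy
import torch
import torch.nn as nn

def sketch_reset_parameters(module):
    """Re-initialize every submodule that defines reset_parameters()."""
    for layer in module.modules():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()

# Hypothetical stand-in network: a conv feature layer followed by a linear classifier.
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Flatten(), nn.Linear(4 * 6 * 6, 2)
)

# Work on a deep copy so the original (trained) weights stay untouched.
toy_randomized = copy.deepcopy(toy_model)
sketch_reset_parameters(toy_randomized[3])       # reset only the final linear layer

conv_unchanged = torch.equal(toy_randomized[0].weight, toy_model[0].weight)
linear_changed = not torch.equal(toy_randomized[3].weight, toy_model[3].weight)
print(conv_unchanged, linear_changed)
```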

In [15]:
classifier_randomized = copy.deepcopy(baseline_model)
reset_module_parameters(classifier_randomized.classifier)
classifier_randomized.eval()

features_randomized = copy.deepcopy(baseline_model)
reset_module_parameters(features_randomized.features)
features_randomized.eval()

vg_classifier_rand = vanilla_gradient(classifier_randomized, x_pneumonia, pred_pneumonia)
vg_features_rand = vanilla_gradient(features_randomized, x_pneumonia, pred_pneumonia)

print("Correlation original vs classifier-randomized, fixed target:", map_correlation(vg, vg_classifier_rand))
print("Correlation original vs features-randomized, fixed target:", map_correlation(vg, vg_features_rand))

plot_explanation_grid(
    x_pneumonia,
    {"Original": vg, "Classifier randomized": vg_classifier_rand, "Features randomized": vg_features_rand},
    "Layer randomization test",
)
Correlation original vs classifier-randomized, fixed target: 0.3412981927394867
Correlation original vs features-randomized, fixed target: 0.08532992750406265
Notebook-generated figure; see the surrounding heading and caption for interpretation.

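Figure caption: the layer-randomization sanity check shows whether resetting the classifier or feature extractor changes the explanation. In this run, resetting the classifier leaves a correlation of about 0.34 with the original map, while resetting the feature extractor lowers it to about 0.09. High similarity after randomization is a warning that the map may be dominated by image structure rather than learned decision logic.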

Sanity check 3: label noise¶

This lightweight label-noise demonstration is not a full random-label memorization experiment. The point is narrower: even when the supervised task is corrupted, explanations can still look structured. A stronger memorization experiment would train longer on a small random-label subset.
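A label-corrupting wrapper can be sketched as a thin `Dataset` that keeps the images and replaces every label with a seeded random draw. This is an assumption about how such a wrapper could work, not the repository's `RandomLabelDataset`; the class name `SketchRandomLabelDataset` and the toy data are hypothetical.

```python
import torch
from torch.utils.data import Dataset, TensorDataset

class SketchRandomLabelDataset(Dataset):
    """Wrap a dataset and replace each label with a fixed, seeded random label."""

    def __init__(self, base, num_classes, seed=0):
        self.base = base
        generator = torch.Generator().manual_seed(seed)
        # Draw labels once so every epoch sees the same corrupted labels.
        self.labels = torch.randint(0, num_classes, (len(base),), generator=generator)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, _ = self.base[index]              # discard the true label
        return image, int(self.labels[index])

# Toy data (not PneumoniaMNIST): 10 blank images that all carry the true label 1.
toy_base = TensorDataset(torch.zeros(10, 1, 4, 4), torch.ones(10, dtype=torch.long))
noisy_a = SketchRandomLabelDataset(toy_base, num_classes=2, seed=123)
noisy_b = SketchRandomLabelDataset(toy_base, num_classes=2, seed=123)
print(len(noisy_a), noisy_a.labels.tolist())
```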

In [16]:
random_label_base = make_subset(train_ds, max_samples=min(512, len(train_ds)), seed=7)
random_label_ds = RandomLabelDataset(random_label_base, num_classes=NUM_CLASSES, seed=123)
random_label_loader = DataLoader(random_label_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

random_label_model = SmallMedicalCNN(num_classes=NUM_CLASSES, image_size=IMAGE_SIZE).to(device)
optimizer = torch.optim.Adam(random_label_model.parameters(), lr=LR)

for epoch in range(EPOCHS_RANDOM_LABEL):
    loss, acc = train_one_epoch(random_label_model, random_label_loader, optimizer, criterion, device)
    print(f"Label-noise epoch {epoch + 1:02d}/{EPOCHS_RANDOM_LABEL} | train loss {loss:.3f}, train acc {acc:.3f}")

with torch.no_grad():
    pred_rl = random_label_model(x_pneumonia).argmax(dim=1).item()
vg_rl = vanilla_gradient(random_label_model, x_pneumonia, pred_pneumonia)
print("Label-noise model prediction:", CLASS_NAMES[pred_rl])
print("Correlation baseline vs label-noise explanation, fixed target:", map_correlation(vg, vg_rl))
plot_explanation_grid(x_pneumonia, {"Baseline": vg, "Label-noise model": vg_rl}, "Label-noise sanity check")
Label-noise epoch 01/5 | train loss 0.701, train acc 0.531
Label-noise epoch 02/5 | train loss 0.697, train acc 0.486
Label-noise epoch 03/5 | train loss 0.694, train acc 0.496
Label-noise epoch 04/5 | train loss 0.690, train acc 0.537
Label-noise epoch 05/5 | train loss 0.688, train acc 0.535
Label-noise model prediction: normal
Correlation baseline vs label-noise explanation, fixed target: 0.03335356339812279
Notebook-generated figure; see the surrounding heading and caption for interpretation.

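Figure caption: the label-noise sanity check compares the baseline explanation with a model trained on corrupted labels. Training accuracy stays near chance (0.49 to 0.54), so the model has not memorized the random labels, and its map has a correlation of only about 0.03 with the baseline map. Structured-looking heatmaps can still appear even when the supervised task is not clinically meaningful.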

Shortcut learning experiment¶

In medical imaging, shortcuts can include scanner artifacts, burned-in text, portable AP markers, scanner borders, institution-specific preprocessing, acquisition protocols, or dataset-source leakage. We simulate a controlled non-clinical marker so we know exactly where the shortcut is and can test whether explanation methods reveal it.

Add a synthetic non-clinical marker¶

Pneumonia images receive a bright top-left marker; normal images receive a bright bottom-right marker. The marker is easy for a CNN to learn but has no clinical meaning.
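The marker rule can be written down in a few lines. The sketch below assumes [C, H, W] tensors normalized to [-1, 1] and a marker value of 1.0 (white); `sketch_add_marker` and `sketch_marker_mask` are hypothetical names, and the repository's `ShortcutDataset` and `marker_mask_for_label` may differ in detail.

```python
import torch

def sketch_add_marker(image, label, marker_size=8, value=1.0):
    """Return a copy of a [C, H, W] image with a label-coded bright corner marker."""
    marked = image.clone()
    if label == 1:                               # pneumonia: top-left corner
        marked[:, :marker_size, :marker_size] = value
    else:                                        # normal: bottom-right corner
        marked[:, -marker_size:, -marker_size:] = value
    return marked

def sketch_marker_mask(label, image_size, marker_size=8):
    """Boolean [H, W] mask that is True exactly where the marker was drawn."""
    mask = torch.zeros(image_size, image_size, dtype=torch.bool)
    if label == 1:
        mask[:marker_size, :marker_size] = True
    else:
        mask[-marker_size:, -marker_size:] = True
    return mask

toy_image = torch.full((1, 64, 64), -1.0)        # all-black toy image
toy_marked = sketch_add_marker(toy_image, label=1)
toy_mask = sketch_marker_mask(label=1, image_size=64)
marker_area_fraction = toy_mask.float().mean().item()
print(marker_area_fraction)
```

An 8x8 marker covers 64 of the 4,096 pixels in a 64x64 image, about 1.6% of the area, so a saliency map that places a much larger share of its attribution inside the mask is pointing at the shortcut.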

In [17]:
shortcut_train_ds = ShortcutDataset(train_ds, marker_size=8)
shortcut_val_ds = ShortcutDataset(val_ds, marker_size=8)
shortcut_test_ds = ShortcutDataset(test_ds, marker_size=8)
loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS, "pin_memory": torch.cuda.is_available()}
shortcut_train_loader = DataLoader(shortcut_train_ds, shuffle=True, **loader_kwargs)
shortcut_val_loader = DataLoader(shortcut_val_ds, shuffle=False, **loader_kwargs)
shortcut_test_loader = DataLoader(shortcut_test_ds, shuffle=False, **loader_kwargs)
show_batch(shortcut_train_loader, CLASS_NAMES, n=12)
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Figure caption: the shortcut examples show the synthetic label-coded corner marker: pneumonia images receive a bright top-left marker and normal images receive a bright bottom-right marker.

Train a shortcut model¶

A high score on marked data is not evidence of clinically meaningful learning. The clean test set helps reveal whether the classifier depends on the artificial marker.

In [18]:
shortcut_model = SmallMedicalCNN(num_classes=NUM_CLASSES, image_size=IMAGE_SIZE).to(device)
optimizer = torch.optim.Adam(shortcut_model.parameters(), lr=LR)
shortcut_history = {"train_acc": [], "val_acc": []}

for epoch in range(EPOCHS_SHORTCUT):
    train_loss, train_acc = train_one_epoch(shortcut_model, shortcut_train_loader, optimizer, criterion, device)
    val_out = evaluate(shortcut_model, shortcut_val_loader, device, criterion)
    shortcut_history["train_acc"].append(train_acc)
    shortcut_history["val_acc"].append(val_out["acc"])
    print(f"Shortcut epoch {epoch + 1:02d}/{EPOCHS_SHORTCUT} | train acc {train_acc:.3f} | val acc {val_out['acc']:.3f}")

plt.figure(figsize=(6, 4))
plt.plot(shortcut_history["train_acc"], marker="o", label="marked train")
plt.plot(shortcut_history["val_acc"], marker="o", label="marked validation")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Shortcut model training curve")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
Shortcut epoch 01/8 | train acc 0.912 | val acc 1.000
Shortcut epoch 02/8 | train acc 0.999 | val acc 1.000
Shortcut epoch 03/8 | train acc 1.000 | val acc 1.000
Shortcut epoch 04/8 | train acc 1.000 | val acc 1.000
Shortcut epoch 05/8 | train acc 1.000 | val acc 1.000
Shortcut epoch 06/8 | train acc 1.000 | val acc 1.000
Shortcut epoch 07/8 | train acc 1.000 | val acc 1.000
Shortcut epoch 08/8 | train acc 1.000 | val acc 1.000
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Compare clean vs marked test performance¶

If the marked and clean test results differ, the marker may be carrying predictive signal. This mirrors the clinical risk of models performing well in one hospital but failing after a shift in scanners, protocols, or preprocessing.

In [19]:
shortcut_marked_test = evaluate(shortcut_model, shortcut_test_loader, device, criterion)
shortcut_clean_test = evaluate(shortcut_model, test_loader, device, criterion)
print(f"Shortcut model accuracy on marked test images: {shortcut_marked_test['acc']:.3f}")
print(f"Shortcut model accuracy on clean test images:  {shortcut_clean_test['acc']:.3f}")
Shortcut model accuracy on marked test images: 1.000
Shortcut model accuracy on clean test images:  0.671
In [20]:
x_short, y_short, pred_short, prob_short = find_or_fallback(shortcut_model, shortcut_test_loader, desired_label=1)
print("Shortcut example true:", CLASS_NAMES[y_short], "pred:", CLASS_NAMES[pred_short], "probs:", prob_short.numpy())
show_image_tensor(x_short, "Marked shortcut image")

vg_short = vanilla_gradient(shortcut_model, x_short, pred_short)
sg_short = smoothgrad(shortcut_model, x_short, pred_short, n_samples=SMOOTHGRAD_SAMPLES, noise_std=0.08)
ig_short = integrated_gradients(shortcut_model, x_short, pred_short, baseline_mode=baseline_mode, steps=IG_STEPS)
occ_short = occlusion_sensitivity(shortcut_model, x_short, pred_short, patch_size=8, stride=4)
gradcam_short = GradCAM(shortcut_model, shortcut_model.features[6])
gcam_short = gradcam_short(x_short, pred_short)
gradcam_short.remove()

heatmaps_shortcut = {
    "Gradient": vg_short,
    "SmoothGrad": sg_short,
    "Integrated Gradients": ig_short,
    "Occlusion": occ_short,
    "Grad-CAM": gcam_short,
}
plot_explanation_grid(x_short, heatmaps_shortcut, "Shortcut model explanations")
Shortcut example true: pneumonia pred: pneumonia probs: [1.0583699e-12 1.0000000e+00]
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Figure caption: the marked shortcut image is the individual test example used for the shortcut-model explanations below.

Figure caption: shortcut-model explanations are shown for a marked pneumonia image. Grad-CAM can miss tiny corner artifacts because its map is computed from low-resolution convolutional features; this is a limitation of spatial granularity, not automatic evidence that the model ignored the marker.

Remove the marker and measure probability sensitivity¶

If removing the marker changes the prediction or target probability, that is direct evidence that a non-clinical feature influenced the model. A class flip is strong evidence, but it is not required: a smaller probability change still shows marker sensitivity, and no change in one image does not rule out shortcut use across the dataset.

In [21]:
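def remove_markers(x, marker_size=8, fill_value=0.0):
    # Overwrite both candidate marker corners on a copy so the input tensor is unchanged.
    x = x.clone()
    x[:, :, :marker_size, :marker_size] = fill_value    # top-left corner (pneumonia marker)
    x[:, :, -marker_size:, -marker_size:] = fill_value  # bottom-right corner (normal marker)
    return x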

x_short_clean = remove_markers(x_short, marker_size=8, fill_value=0.0)
with torch.no_grad():
    probs_with = F.softmax(shortcut_model(x_short), dim=1)[0].cpu()
    probs_without = F.softmax(shortcut_model(x_short_clean), dim=1)[0].cpu()

print("With marker probs:   ", {CLASS_NAMES[i]: float(probs_with[i]) for i in range(NUM_CLASSES)})
print("Without marker probs:", {CLASS_NAMES[i]: float(probs_without[i]) for i in range(NUM_CLASSES)})

plt.figure(figsize=(8, 3.5))
plt.subplot(1, 2, 1)
plt.imshow(denorm(x_short[0, 0].detach().cpu()), cmap="gray")
plt.title("With marker")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(denorm(x_short_clean[0, 0].detach().cpu()), cmap="gray")
plt.title("Marker removed")
plt.axis("off")
plt.show()
With marker probs:    {'normal': 1.0580105585036859e-12, 'pneumonia': 1.0}
Without marker probs: {'normal': 0.6006954312324524, 'pneumonia': 0.3993045687675476}
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Figure caption and interpretation: the two panels compare the same shortcut image before and after both possible marker corners are replaced with the normalized gray fill value. In the saved full-split run, removing the marker lowers the selected example's pneumonia probability from 1.000 to about 0.399, which flips the prediction to normal. In other runs the change may appear as a smaller confidence drop rather than a class flip; the larger clean-versus-marked test gap above remains the stronger dataset-level shortcut signal.

Quantify attribution to the shortcut region¶

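The shortcut-region attribution fraction estimates how much of a heatmap's total attribution lies inside the artificial marker. With $S_i$ the attribution at pixel $i$, $R_{\text{shortcut}}$ the set of pixels in the known marker region, and $\epsilon$ a small constant that guards against division by zero: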

$$ F_{\text{shortcut}} = \frac{ \sum_{i \in R_{\text{shortcut}}} S_i }{ \sum_i S_i + \epsilon } $$

This is a simple diagnostic, not a clinical validity score.
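
The fraction above can be sketched in a few lines of NumPy. This is a minimal illustration, not the notebook's `shortcut_region_attribution_fraction`, whose definition is not shown in this section: the name `region_attribution_fraction`, the use of absolute values, and the 28×28 image size are all assumptions made for the example.

```python
import numpy as np

def region_attribution_fraction(heatmap, region_mask, eps=1e-8):
    """Fraction of total attribution magnitude that falls inside region_mask."""
    s = np.abs(np.asarray(heatmap, dtype=float))  # assumption: score magnitudes
    mask = np.asarray(region_mask, dtype=bool)
    return float(s[mask].sum() / (s.sum() + eps))

# Toy case (assumed 28x28 image): attribution spread uniformly over all pixels.
heatmap = np.ones((28, 28))
mask = np.zeros((28, 28), dtype=bool)
mask[:8, :8] = True  # 8x8 top-left marker region

uniform_fraction = region_attribution_fraction(heatmap, mask)
print(round(uniform_fraction, 4))  # 0.0816, i.e. 64 / 784

# A map that puts all of its attribution inside the marker scores close to 1.
concentrated = np.where(mask, 1.0, 0.0)
concentrated_fraction = region_attribution_fraction(concentrated, mask)
```

Under these assumptions, a map with no spatial preference scores about 0.08, which is simply the marker's share of the image area. That gives a rough reference point: a fraction well above the area share suggests attribution is concentrated on the marker, while a fraction below it suggests the method is looking elsewhere.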

In [22]:
mask_short = marker_mask_for_label(y_short, IMAGE_SIZE, IMAGE_SIZE, marker_size=8)
fractions = {name: shortcut_region_attribution_fraction(h, mask_short) for name, h in heatmaps_shortcut.items()}
plot_metric_bar(
    fractions,
    ylabel="Fraction of attribution in marker",
    title="How much attribution falls in the non-clinical shortcut?",
    ylim=(0, max(0.2, max(fractions.values()) * 1.2)),
)
fractions
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Out[22]:
{'Gradient': 0.2587019184277326,
 'SmoothGrad': 0.2724421616220347,
 'Integrated Gradients': 0.5697475097826117,
 'Occlusion': 0.9999995758673301,
 'Grad-CAM': 0.0200277234358911}

Figure caption: the shortcut-region attribution plot reports the fraction of each method's attribution that falls inside the known marker region. Higher values mean the method localized more attribution to the non-clinical artifact.

Quantitative explanation evaluation¶

Visual inspection should be paired with quantitative stress tests. These metrics are diagnostics for model debugging; none of them proves clinical validity.

Map correlation¶

Map correlation compares two flattened heatmaps. We used it above to ask whether explanations change after randomizing model parameters.

In [23]:
print("Self-correlation:", map_correlation(vg, vg))
print("Baseline vs model-randomized fixed target:", map_correlation(vg, vg_random_fixed_target))
print("Baseline vs label-noise fixed target:", map_correlation(vg, vg_rl))
Self-correlation: 0.9999999403953552
Baseline vs model-randomized fixed target: 0.04296400025486946
Baseline vs label-noise fixed target: 0.03335356339812279

Deletion curves¶

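A deletion curve removes pixels in decreasing order of saliency and tracks the target-class score $f_c$:

$$ D(t) = f_c\left(x^{(t)}\right) $$

where $x^{(t)}$ is the input with the top fraction $t$ of pixels, ranked by saliency, replaced by a fill value. A sharper drop in $D(t)$ usually indicates a more class-relevant explanation.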

Deletion AUC¶

Deletion curves are useful stress tests, but replacing pixels can create unrealistic images. Interpret deletion AUC as a relative diagnostic, not a direct measure of clinical explanation quality. Lower AUC means the target probability fell faster after deleting high-attribution pixels.
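
The deletion procedure and its AUC can be sketched on a toy scorer. Everything here is illustrative: `toy_deletion_curve`, `trapezoid_auc`, and `corner_score` are hypothetical names, the "model" is a plain function that only looks at an 8×8 corner, and the notebook's own `deletion_curve` and `deletion_auc` helpers may be implemented differently.

```python
import numpy as np

def toy_deletion_curve(score_fn, image, heatmap, steps=20, fill_value=0.0):
    """Delete pixels from most to least salient and record the score."""
    # Stable sort on the negated map: most salient first, ties in index order.
    order = np.argsort(-heatmap.ravel(), kind="stable")
    n_pixels = order.size
    fractions = np.linspace(0.0, 1.0, steps + 1)
    scores = []
    for t in fractions:
        k = int(round(t * n_pixels))          # number of pixels deleted so far
        deleted = image.copy().ravel()
        deleted[order[:k]] = fill_value
        scores.append(score_fn(deleted.reshape(image.shape)))
    return fractions, np.array(scores)

def trapezoid_auc(xs, ys):
    """Area under the curve by the trapezoidal rule."""
    return float(np.sum((xs[1:] - xs[:-1]) * (ys[1:] + ys[:-1]) / 2.0))

def corner_score(img):
    """Hypothetical stand-in model: the score depends only on the top-left 8x8 patch."""
    return float(img[:8, :8].mean())

image = np.ones((28, 28))

faithful = np.zeros((28, 28))
faithful[:8, :8] = 1.0        # heatmap that points at the pixels the scorer uses

unfaithful = np.zeros((28, 28))
unfaithful[-8:, -8:] = 1.0    # heatmap that points at the opposite corner

auc_faithful = trapezoid_auc(*toy_deletion_curve(corner_score, image, faithful))
auc_unfaithful = trapezoid_auc(*toy_deletion_curve(corner_score, image, unfaithful))
print(auc_faithful < auc_unfaithful)  # True
```

In this toy setting the heatmap that targets the pixels the scorer actually uses drives the score to zero within the first few deletion steps, so its AUC is lower. With a real network the same comparison is noisier, because the fill value itself can move the image off-distribution.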

In [24]:
def plot_deletion_curves(model, x, heatmaps, target_class, title="Deletion curves"):
    aucs = {}
    plt.figure(figsize=(7, 4.5))
    for name, heatmap in heatmaps.items():
        xs, ys = deletion_curve(model, x, heatmap, target_class, steps=20, fill_value=0.0)
        aucs[name] = deletion_auc(xs, ys)
        plt.plot(xs, ys, marker="o", markersize=3, label=f"{name} AUC={aucs[name]:.3f}")
    plt.xlabel("Fraction of pixels deleted")
    plt.ylabel("Target class probability")
    plt.title(title)
    plt.legend(fontsize=8)
    plt.grid(alpha=0.3)
    plt.show()
    return aucs

baseline_deletion_aucs = plot_deletion_curves(
    baseline_model, x_pneumonia, heatmaps_baseline, pred_pneumonia, title="Deletion curves: baseline model"
)
shortcut_deletion_aucs = plot_deletion_curves(
    shortcut_model, x_short, heatmaps_shortcut, pred_short, title="Deletion curves: shortcut model"
)
print("Baseline deletion AUC:", baseline_deletion_aucs)
print("Shortcut deletion AUC:", shortcut_deletion_aucs)
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Baseline deletion AUC: {'Gradient': 0.9978689438139554, 'SmoothGrad': 0.9985082889324985, 'Integrated Gradients': 0.9898751714426908, 'Occlusion': 0.5917592504847562, 'Grad-CAM': 0.996942297759233}
Shortcut deletion AUC: {'Gradient': 0.8972911030869, 'SmoothGrad': 0.8860569059252157, 'Integrated Gradients': 0.8641085574490717, 'Occlusion': 0.2962369700023828, 'Grad-CAM': 0.7814376527858258}

Figure caption: the deletion-curve legends and printed AUC dictionaries are produced by the same cell. Lower AUC means the target probability dropped faster as high-attribution pixels were removed.

What deletion AUC shows in this run¶

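In this run, occlusion gives by far the lowest deletion AUC on the shortcut model (0.296), while Gradient, SmoothGrad, Integrated Gradients, and Grad-CAM stay between roughly 0.78 and 0.90. Grad-CAM's deletion AUC (0.781) is lower than those of the gradient-based maps even though its marker-region attribution fraction is only 0.020, so deletion behavior and known-region attribution do not rank the methods identically. Quantitative explanation tests are therefore method- and metric-dependent: some methods expose the shortcut, while coarse spatial methods may still miss small artifacts. The main lesson is not that one metric is definitive, but that visual inspection, deletion behavior, and known-region attribution should be read together.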

Stability under perturbations¶

A stable explanation should not change drastically under tiny input perturbations when the prediction remains the same. Instability is a warning that the visual story may be fragile.

In [25]:
def explanation_stability(model, x, method_fn, target_class, n=8, noise_std=0.03):
    base_map = method_fn(model, x, target_class).detach().cpu()
    corrs = []
    same_prediction = 0
    with torch.no_grad():
        base_pred = model(x).argmax(dim=1).item()
    for _ in range(n):
        x_pert = perturb_image(x, noise_std=noise_std)
        with torch.no_grad():
            pert_pred = model(x_pert).argmax(dim=1).item()
        same_prediction += int(pert_pred == base_pred)
        pert_map = method_fn(model, x_pert, target_class).detach().cpu()
        corrs.append(map_correlation(base_map, pert_map))
    return {
        "mean_corr": float(np.mean(corrs)),
        "std_corr": float(np.std(corrs)),
        "same_prediction_rate": same_prediction / n,
    }

stability_results = {
    "Gradient": explanation_stability(baseline_model, x_pneumonia, vanilla_gradient, pred_pneumonia, n=8),
    "Integrated Gradients": explanation_stability(
        baseline_model,
        x_pneumonia,
        lambda m, x, c: integrated_gradients(m, x, c, steps=12),
        pred_pneumonia,
        n=5,
    ),
}
print(stability_results)

names = list(stability_results.keys())
means = [stability_results[n]["mean_corr"] for n in names]
stds = [stability_results[n]["std_corr"] for n in names]
plt.figure(figsize=(6, 3.5))
plt.bar(names, means, yerr=stds, capsize=4)
plt.ylim(0, 1.05)
plt.ylabel("Mean heatmap correlation")
plt.title("Explanation stability under small noise")
plt.xticks(rotation=20)
plt.show()
{'Gradient': {'mean_corr': 0.743223138153553, 'std_corr': 0.011179810885290594, 'same_prediction_rate': 1.0}, 'Integrated Gradients': {'mean_corr': 0.8565796971321106, 'std_corr': 0.012707924825512515, 'same_prediction_rate': 1.0}}
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Figure caption: the stability plot summarizes how much the selected explanation maps change under small input noise. Low or highly variable correlation suggests that a visual explanation may be fragile even when the prediction is unchanged.

Shortcut-region attribution fraction¶

The marker-fraction plot above is most useful when compared across methods. Grad-CAM is spatially coarse because it is computed from low-resolution convolutional feature maps. Small artifacts such as corner markers may be blurred away or missed, even when the classifier uses them.

In [26]:
summary_rows = []
for name in heatmaps_baseline:
    summary_rows.append({
        "method": name,
        "baseline_deletion_auc": baseline_deletion_aucs.get(name, np.nan),
        "shortcut_deletion_auc": shortcut_deletion_aucs.get(name, np.nan),
        "shortcut_marker_fraction": fractions.get(name, np.nan),
    })
for row in summary_rows:
    print(row)

methods = [row["method"] for row in summary_rows]
marker_vals = [row["shortcut_marker_fraction"] for row in summary_rows]
base_auc_vals = [row["baseline_deletion_auc"] for row in summary_rows]
short_auc_vals = [row["shortcut_deletion_auc"] for row in summary_rows]

plt.figure(figsize=(7, 3.5))
plt.bar(methods, marker_vals)
plt.ylabel("Shortcut marker attribution fraction")
plt.title("Did the method reveal the shortcut?")
plt.xticks(rotation=30, ha="right")
plt.show()

plt.figure(figsize=(7, 3.5))
xpos = np.arange(len(methods))
width = 0.35
plt.bar(xpos - width / 2, base_auc_vals, width, label="baseline")
plt.bar(xpos + width / 2, short_auc_vals, width, label="shortcut")
plt.ylabel("Deletion AUC")
plt.title("Deletion AUC comparison")
plt.xticks(xpos, methods, rotation=30, ha="right")
plt.legend()
plt.show()
{'method': 'Gradient', 'baseline_deletion_auc': 0.9978689438139554, 'shortcut_deletion_auc': 0.8972911030869, 'shortcut_marker_fraction': 0.2587019184277326}
{'method': 'SmoothGrad', 'baseline_deletion_auc': 0.9985082889324985, 'shortcut_deletion_auc': 0.8860569059252157, 'shortcut_marker_fraction': 0.2724421616220347}
{'method': 'Integrated Gradients', 'baseline_deletion_auc': 0.9898751714426908, 'shortcut_deletion_auc': 0.8641085574490717, 'shortcut_marker_fraction': 0.5697475097826117}
{'method': 'Occlusion', 'baseline_deletion_auc': 0.5917592504847562, 'shortcut_deletion_auc': 0.2962369700023828, 'shortcut_marker_fraction': 0.9999995758673301}
{'method': 'Grad-CAM', 'baseline_deletion_auc': 0.996942297759233, 'shortcut_deletion_auc': 0.7814376527858258, 'shortcut_marker_fraction': 0.0200277234358911}
Notebook-generated figure; see the surrounding heading and caption for interpretation.
Notebook-generated figure; see the surrounding heading and caption for interpretation.

Figure caption: the final summary plots compare shortcut-region attribution and deletion AUC across explanation methods. These bars should match the printed summary_rows values above them after a full run-all execution.

Clinical interpretation and limitations¶

This notebook demonstrates failure modes, not deployment readiness. PneumoniaMNIST is small and preprocessed, the CNN is simple, and the marker shortcut is synthetic. Real clinical validation would require external sites, locked preprocessing, subgroup and scanner/protocol analysis, uncertainty and calibration reporting, decision-curve or workflow evaluation, clinician review of successes and failures, and prospective monitoring after deployment. Reporting checklists such as CLAIM help make medical imaging AI studies more transparent [10].

Learner exercises and expected observations¶

  1. Change the Integrated Gradients baseline from gray to black. Expected observation: maps can shift because the baseline defines the reference image; this is a reminder that attribution is partly a measurement choice.
  2. Change marker_size from 8 to 4 in the shortcut datasets. Expected observation: the shortcut becomes harder to see, and coarse methods such as Grad-CAM may miss it more often.
  3. Replace the corner marker with a border stripe or simulated burned-in text. Expected observation: different explanation methods may reveal different artifact shapes, so agreement across methods is more informative than one attractive map.
  4. Compare shortcut-model errors on clean data by class. Expected observation: the clean test set often exposes asymmetric false positives or false negatives that are hidden by high marked-test accuracy.
  5. Search for the test image with the largest probability drop after marker removal. Expected observation: a single dramatic counterfactual is useful for teaching, but the dataset-level clean-versus-marked gap is better evidence of shortcut reliance.
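
As a warm-up for exercise 1, the baseline dependence of Integrated Gradients can be seen on a three-feature logistic model without any deep learning framework. This is a sketch under stated assumptions: the model, weights, and the names `toy_model`, `toy_grad`, and `toy_integrated_gradients` are hypothetical, inputs are raw intensities in [0, 1] (so black is 0.0 and gray is 0.5), and the path integral is approximated with a midpoint Riemann sum.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def toy_model(x, w, b):
    """Hypothetical stand-in classifier: logistic regression on three features."""
    return sigmoid(x @ w + b)

def toy_grad(x, w, b):
    """Analytic gradient of the logistic output with respect to the input."""
    p = toy_model(x, w, b)
    return p * (1.0 - p) * w

def toy_integrated_gradients(x, baseline, w, b, steps=200):
    """Integrated Gradients along the straight path from baseline to x."""
    alphas = (np.arange(steps) + 0.5) / steps  # midpoints of `steps` equal intervals
    grads = np.stack([toy_grad(baseline + a * (x - baseline), w, b) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

w = np.array([2.0, -1.0, 0.5])
b = -0.2
x = np.array([0.9, 0.1, 0.5])

black = np.zeros(3)
gray = np.full(3, 0.5)

ig_black = toy_integrated_gradients(x, black, w, b)
ig_gray = toy_integrated_gradients(x, gray, w, b)

# Completeness: attributions sum to f(x) - f(baseline), up to numerical error.
gap_black = toy_model(x, w, b) - toy_model(black, w, b)
gap_gray = toy_model(x, w, b) - toy_model(gray, w, b)
print(ig_black, ig_gray)
```

The third feature equals the gray baseline, so it receives exactly zero attribution with a gray reference but a positive attribution with a black one. The model and input are identical in both cases; only the reference point changed, which is why the choice of baseline should be reported alongside any Integrated Gradients map.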

From pixels to patients: what this tutorial teaches¶

This notebook does not show how to deploy a pneumonia detector. Instead, it shows why deployment requires more than high accuracy and attractive heatmaps.

Before a model influences patient care, teams should ask:

  1. Does performance hold across hospitals, scanners, protocols, and patient groups?
  2. Are explanations sensitive to the learned model?
  3. Could the model be using non-clinical shortcuts?
  4. Are uncertainty, calibration, and failure modes reported?
  5. Have clinicians reviewed both successes and failures?
  6. Would the explanation change a clinical decision, and should it?

Checklist before trusting saliency maps in clinical AI¶

  • Evaluate performance on external data.
  • Report sensitivity, specificity, balanced accuracy, and uncertainty, not only accuracy.
  • Run model-randomization sanity checks.
  • Run label-randomization or label-noise sanity checks.
  • Test robustness to small perturbations.
  • Look for attribution to artifacts, borders, markers, or text.
  • Compare multiple explanation methods.
  • Involve clinical experts in failure analysis.
  • Document dataset provenance and preprocessing.
  • Treat explanations as debugging tools, not proof of clinical validity.
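
The metrics named in the checklist can be computed directly from predicted and true labels. The sketch below uses made-up labels and a hypothetical helper, `binary_screening_metrics`; it assumes a binary task with pneumonia encoded as 1 and is not taken from the notebook's evaluation code.

```python
import numpy as np

def binary_screening_metrics(y_true, y_pred, positive=1):
    """Accuracy, sensitivity, specificity, and balanced accuracy for a binary task."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    pos = y_true == positive
    neg = ~pos
    tp = int(np.sum(pos & (y_pred == positive)))
    fn = int(np.sum(pos & (y_pred != positive)))
    tn = int(np.sum(neg & (y_pred != positive)))
    fp = int(np.sum(neg & (y_pred == positive)))
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return {
        "accuracy": (tp + tn) / len(y_true),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "balanced_accuracy": (sensitivity + specificity) / 2,
    }

# Made-up example: 6 pneumonia cases (1), 4 normal cases (0),
# and a degenerate model that predicts pneumonia for every image.
y_true = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1] * 10

metrics = binary_screening_metrics(y_true, y_pred)
print(metrics)
```

In this made-up case accuracy is 0.6 and sensitivity is 1.0, but specificity is 0.0 and balanced accuracy is 0.5, no better than guessing. A single accuracy figure on an imbalanced test set can hide exactly this kind of one-sided failure.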

Key takeaways¶

  1. Saliency maps can be useful debugging tools, but visual plausibility is not evidence of faithful reasoning.
  2. Randomization, label-noise, deletion, stability, and shortcut-region tests reveal different failure modes.
  3. Accuracy alone is not enough for clinical screening; sensitivity and specificity matter.
  4. Synthetic shortcuts are useful teaching tools because the ground-truth artifact is known.
  5. Moving from pixels to patients requires external validation, transparent reporting, clinical review, and careful workflow design.

Checklist summary: the final checklist is the practical handoff from the tutorial. It frames saliency maps as one debugging input alongside external validation, clinical metrics, robustness, calibration, provenance, and expert review.

AI assistance disclosure¶

Parts of this tutorial were prepared with assistance from an AI language model for editing, structuring, and clarity. All code, experiments, results, and scientific claims were reviewed and validated by T. Kirscher.

References¶

[1] Yang, J. et al. MedMNIST v2: A large-scale lightweight benchmark for 2D and 3D biomedical image classification. Scientific Data, 2023.

[2] Sundararajan, M., Taly, A., and Yan, Q. Axiomatic Attribution for Deep Networks. ICML, 2017.

[3] Smilkov, D. et al. SmoothGrad: removing noise by adding noise. arXiv, 2017.

[4] Selvaraju, R. R. et al. Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. ICCV, 2017.

[5] Adebayo, J. et al. Sanity Checks for Saliency Maps. NeurIPS, 2018.

[6] Fong, R. C. and Vedaldi, A. Interpretable Explanations of Black Boxes by Meaningful Perturbation. ICCV, 2017.

[7] Zech, J. R. et al. Variable generalization performance of a deep learning model to detect pneumonia in chest radiographs: A cross-sectional study. PLOS Medicine, 2018.

[8] DeGrave, A. J., Janizek, J. D., and Lee, S.-I. AI for radiographic COVID-19 detection selects shortcuts over signal. Nature Machine Intelligence, 2021.

[9] Samek, W. et al., editors. Explainable AI: Interpreting, Explaining and Visualizing Deep Learning. Springer, 2019.

[10] Mongan, J., Moy, L., and Kahn, C. E. Checklist for Artificial Intelligence in Medical Imaging (CLAIM): A Guide for Authors and Reviewers. Radiology: Artificial Intelligence, 2020.