
Robust and Explainable CNN-Based Chest X-Ray Classification Across Unseen Clinical Datasets

Ongoing

A cross-dataset generalization study that trains CNN models on CheXpert and evaluates them zero-shot on MIMIC-CXR, NIH ChestX-ray14, PadChest, VinDr-CXR, and BRAX to measure architectural robustness and detect dataset-specific shortcuts.

Medical Imaging · CNN · PyTorch · XAI · Grad-CAM · Multi-label Classification · CheXpert · Generalization

Overview

This project implements a strict external validation framework for multi-label chest X-ray classification. Models are trained solely on CheXpert (224,316 images, 14 labels, 65,240 patients) and then evaluated, without any fine-tuning, on five external datasets: MIMIC-CXR-JPG, ChestX-ray14, PadChest, VinDr-CXR, and BRAX. A label harmonization module maps all datasets to a shared disease space.

The backbone comparison spans 15+ CNN architectures. Each model is evaluated on per-class AUROC, macro-AUROC, macro-F1, sensitivity, and calibration metrics. The generalization gap (internal CheXpert performance minus external performance) is the primary measure of architectural robustness.

Secondary experiments ablate optimizer choice, loss functions, and uncertainty-label strategies. A dataset-source classifier detects shortcut features, and Grad-CAM heatmaps verify whether models attend to clinically meaningful lung regions or acquisition artifacts.

Cross-Dataset Setup and Label Harmonization

Models are trained solely on CheXpert (224,316 images, 14 labels) and evaluated without fine-tuning on MIMIC-CXR-JPG, NIH ChestX-ray14, PadChest, VinDr-CXR, and BRAX. A label harmonization module maps all dataset annotations to a shared disease space: No Finding, Cardiomegaly, Atelectasis, Consolidation, Pleural Effusion, Pneumothorax, Pneumonia, Edema, and Lung Lesion. Labels without a direct counterpart are handled per dataset; for example, Lung Lesion is approximated from Mass + Nodule in datasets that annotate those findings instead.
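The harmonization step can be pictured as a per-dataset lookup table in which each shared label lists the source labels that count as positive for it. The sketch below assumes NIH ChestX-ray14's `|`-separated "Finding Labels" strings as input; the table name `CHESTXRAY14_MAP` and the `harmonize` helper are hypothetical, not the project's code. Shared labels a dataset cannot express are returned as `None` so they can be masked out of evaluation.

```python
from typing import Dict, Iterable, List, Optional

# Shared disease space used across all datasets.
SHARED_LABELS = [
    "No Finding", "Cardiomegaly", "Atelectasis", "Consolidation",
    "Pleural Effusion", "Pneumothorax", "Pneumonia", "Edema", "Lung Lesion",
]

# Hypothetical mapping for NIH ChestX-ray14 finding names. Each shared label
# lists the source labels that make it positive (combined with logical OR).
CHESTXRAY14_MAP: Dict[str, List[str]] = {
    "No Finding": ["No Finding"],
    "Cardiomegaly": ["Cardiomegaly"],
    "Atelectasis": ["Atelectasis"],
    "Consolidation": ["Consolidation"],
    "Pleural Effusion": ["Effusion"],
    "Pneumothorax": ["Pneumothorax"],
    "Pneumonia": ["Pneumonia"],
    "Edema": ["Edema"],
    "Lung Lesion": ["Mass", "Nodule"],  # approximation: Mass OR Nodule
}


def harmonize(source_labels: Iterable[str],
              mapping: Dict[str, List[str]]) -> Dict[str, Optional[int]]:
    """Map one image's source labels onto the shared disease space.

    Returns 1/0 for shared labels the dataset can express and None for
    labels with no counterpart in this dataset (to be masked out later).
    """
    present = set(source_labels)
    out: Dict[str, Optional[int]] = {}
    for label in SHARED_LABELS:
        sources = mapping.get(label)
        if not sources:
            out[label] = None  # unmappable for this dataset
        else:
            out[label] = int(any(s in present for s in sources))
    return out
```

For example, `harmonize("Nodule|Effusion".split("|"), CHESTXRAY14_MAP)` marks both Lung Lesion and Pleural Effusion as positive and every other shared label as negative. Keeping one table per dataset makes each approximation explicit and easy to audit.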

Image Preprocessing Pipeline

Source CXR images range from low-resolution downsampled versions to full 2,000–3,000 px scans. The pipeline tests multiple input resolutions (224, 384, 512, 1024, and 2048 px) and several intensity-normalization and cropping strategies: ImageNet normalization, min-max scaling, per-image z-score normalization, histogram standardization, ROI cropping, adaptive ROI cropping combined with histogram standardization, and CLAHE. The preprocessing variant is treated as an experimental variable.
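Treating preprocessing as an experimental variable amounts to crossing every resolution with every strategy. The sketch below enumerates that grid and implements three of the simpler intensity normalizations on a flat list of pixel values; the strategy identifiers are hypothetical names, and histogram standardization, ROI cropping, and CLAHE are deliberately left out because they need an image library.

```python
import itertools
import math
from typing import List

RESOLUTIONS = [224, 384, 512, 1024, 2048]
# Hypothetical identifiers for the seven preprocessing strategies.
PREPROCESSING = [
    "imagenet", "minmax", "zscore", "histstd",
    "roi_crop", "adaptive_roi_histstd", "clahe",
]

# Every (resolution, preprocessing) pair is one experimental condition.
GRID = list(itertools.product(RESOLUTIONS, PREPROCESSING))


def minmax_scale(pixels: List[float]) -> List[float]:
    """Rescale intensities to [0, 1]; a constant image maps to all zeros."""
    lo, hi = min(pixels), max(pixels)
    span = hi - lo
    return [0.0 if span == 0 else (p - lo) / span for p in pixels]


def zscore(pixels: List[float], eps: float = 1e-8) -> List[float]:
    """Per-image z-score: zero mean and unit population std for this image."""
    n = len(pixels)
    mean = sum(pixels) / n
    std = math.sqrt(sum((p - mean) ** 2 for p in pixels) / n)
    return [(p - mean) / (std + eps) for p in pixels]


def imagenet_normalize(pixels01: List[float], channel: int = 0) -> List[float]:
    """Standard ImageNet mean/std for one RGB channel; input must be in [0, 1]."""
    mean = (0.485, 0.456, 0.406)[channel]
    std = (0.229, 0.224, 0.225)[channel]
    return [(p - mean) / std for p in pixels01]
```

With five resolutions and seven strategies the grid has 35 conditions, which is one reason a fixed backbone is useful when sweeping preprocessing. Per-image z-scoring removes global brightness and contrast differences between scanners, whereas ImageNet normalization applies the same fixed statistics to every image.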

CNN Architecture Comparison

15+ CNN architectures are trained under the same CheXpert data split and evaluation protocol:

  • Classic baselines: DenseNet-121, ResNet-18/50/101, DenseNet-161
  • Modern CNNs: ConvNeXt-S, ConvNeXt V2-T, EfficientNetV2-S
  • Efficient/NAS models: EfficientNet-B0, EfficientNet-B4
  • Attention and grouped-convolution models: TResNet-50, SE-ResNet-50 (squeeze-and-excitation), ResNeXt-101 (grouped convolutions)
  • Lightweight models: MobileNetV3-Large, RepVGG-B0
  • Less commonly benchmarked on CXR: RegNet-Y-4GF, Xception, DPN-style models

Generalization Gap Measurement

Each trained checkpoint is evaluated on all external datasets without any fine-tuning. The generalization gap is defined as internal CheXpert performance minus external dataset performance, so a larger positive gap indicates a larger drop on unseen data. Per-class AUROC, macro-AUROC, macro-F1, precision, recall (sensitivity), and calibration metrics are computed for each dataset. Efficiency tracking records parameter count, FLOPs, GPU memory, training time per epoch, and inference throughput alongside accuracy.
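A minimal sketch of the core metrics, assuming per-class binary labels and sigmoid probabilities. AUROC uses the rank-based Mann-Whitney formulation with tied scores averaged, macro-AUROC skips classes with only one label value in a dataset (for example, a shared label that is rare or absent there), and ECE uses one common equal-width-bin variant. The function names are hypothetical; in practice a library such as scikit-learn would normally compute these.

```python
from typing import Dict, Optional, Sequence, Tuple


def auroc(y_true: Sequence[int], y_score: Sequence[float]) -> Optional[float]:
    """Rank-based (Mann-Whitney U) AUROC; None if only one class is present."""
    pairs = sorted(zip(y_score, y_true))
    n = len(pairs)
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        # Extend j over a run of tied scores; ties share their average rank.
        while j + 1 < n and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[k] = avg_rank
        i = j + 1
    n_pos = sum(1 for _, y in pairs if y == 1)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return None
    pos_rank_sum = sum(r for r, (_, y) in zip(ranks, pairs) if y == 1)
    return (pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def macro_auroc(
    per_class: Dict[str, Tuple[Sequence[int], Sequence[float]]]
) -> Optional[float]:
    """Unweighted mean AUROC over the classes where AUROC is defined."""
    scores = [a for a in (auroc(t, s) for t, s in per_class.values())
              if a is not None]
    return sum(scores) / len(scores) if scores else None


def expected_calibration_error(y_true: Sequence[int],
                               y_prob: Sequence[float],
                               n_bins: int = 10) -> float:
    """Equal-width-bin ECE for one binary label (positive rate vs mean prob)."""
    n = len(y_true)
    bins = [[] for _ in range(n_bins)]
    for y, p in zip(y_true, y_prob):
        bins[min(int(p * n_bins), n_bins - 1)].append((y, p))
    ece = 0.0
    for b in bins:
        if b:
            observed = sum(y for y, _ in b) / len(b)   # empirical positive rate
            predicted = sum(p for _, p in b) / len(b)  # mean predicted probability
            ece += len(b) / n * abs(observed - predicted)
    return ece


def generalization_gap(internal: float, external: float) -> float:
    """Internal CheXpert score minus external score; positive = off-domain drop."""
    return internal - external
```

For instance, `auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])` returns `0.75`. Computing the gap on the same metric and the same set of shared classes on both sides keeps internal and external numbers comparable.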

Optimizer and Loss Ablations

Using a fixed backbone (DenseNet-121 or ConvNeXt-S), the project ablates optimizer choice: AdamW, Muon, Schedule-Free AdamW, Lion, SOAP, Adan, MARS, SAM variants, and RMSProp. Loss functions tested include binary cross-entropy, focal loss, margin-based surrogate loss, and AUC-oriented loss. Uncertainty-label strategies for ambiguous annotations are compared: U-Ones, U-Zeros, U-Ignore, and label smoothing.
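CheXpert encodes each observation as positive (1), negative (0), uncertain (-1), or blank. The sketch below shows how the four uncertainty strategies change the per-label target and loss weight, and pairs them with binary cross-entropy and an unweighted focal loss. Treating blanks as negative and using a constant soft target of 0.5 for label smoothing are assumptions made for illustration; the margin-based and AUC-oriented losses are not shown, and all names are hypothetical.

```python
import math
from typing import Optional, Sequence, Tuple

UNCERTAIN = -1  # CheXpert's code for uncertain mentions


def resolve_label(raw: Optional[int], policy: str,
                  smooth_value: float = 0.5) -> Tuple[float, float]:
    """Return (target, weight) for one raw CheXpert label under a policy.

    smooth_value is a hypothetical soft target for the label-smoothing policy.
    """
    if raw is None:
        return 0.0, 1.0  # assumption: blank treated as negative
    if raw != UNCERTAIN:
        return float(raw), 1.0
    if policy == "U-Ones":
        return 1.0, 1.0
    if policy == "U-Zeros":
        return 0.0, 1.0
    if policy == "U-Ignore":
        return 0.0, 0.0  # weight 0 removes this term from the loss
    if policy == "label-smoothing":
        return smooth_value, 1.0
    raise ValueError(f"unknown policy: {policy}")


def bce(p: float, y: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy for one sigmoid output p and (soft) target y."""
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def focal(p: float, y: float, gamma: float = 2.0, eps: float = 1e-7) -> float:
    """Focal loss without alpha balancing: down-weights easy examples."""
    p = min(max(p, eps), 1 - eps)
    return -(y * (1 - p) ** gamma * math.log(p)
             + (1 - y) * p ** gamma * math.log(1 - p))


def masked_mean_loss(probs: Sequence[float],
                     raw_labels: Sequence[Optional[int]],
                     policy: str, loss_fn=bce) -> float:
    """Weighted mean loss over one image's labels under a given policy."""
    total, weight_sum = 0.0, 0.0
    for p, raw in zip(probs, raw_labels):
        y, w = resolve_label(raw, policy)
        total += w * loss_fn(p, y)
        weight_sum += w
    return total / weight_sum if weight_sum > 0 else 0.0
```

Under U-Ignore, an uncertain label contributes nothing to the loss, so `masked_mean_loss([0.9, 0.2], [1, -1], "U-Ignore")` equals the BCE of the first label alone. Swapping `loss_fn=focal` keeps the uncertainty handling fixed while changing only the loss, which is the kind of one-factor-at-a-time comparison the ablation describes.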

Shortcut Analysis

A dataset-source classifier is trained to predict which hospital or dataset an image came from rather than its disease labels. High accuracy indicates strong dataset-specific visual signatures (scanner type, resolution, preprocessing artifacts), which disease classifiers may exploit as shortcuts. Source-classifier accuracy is then correlated with the observed generalization gap across architectures.

XAI-Based Inspection

Grad-CAM-style heatmaps are generated for correct predictions, incorrect predictions, and high-confidence failures across all external datasets. Heatmaps are inspected to determine whether models attend to clinically meaningful regions (lungs, pleura, cardiac silhouette, lesion areas) or to irrelevant features (image borders, text markers, black padding, acquisition artifacts, dataset-specific cues).
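Grad-CAM weights each channel of a convolutional feature map by the spatial mean of the class score's gradient with respect to that channel, sums the weighted maps, and applies a ReLU. The sketch below implements that combination step on plain nested lists, assuming the activations and gradients were already captured from the target layer (for example with framework hooks); upsampling to image resolution is omitted. The `border_fraction` score is a hypothetical helper, not part of Grad-CAM, that measures how much heatmap mass falls near the image edge, where borders, padding, and text markers usually sit.

```python
from typing import List

Matrix = List[List[float]]


def grad_cam(activations: List[Matrix], gradients: List[Matrix]) -> Matrix:
    """Grad-CAM map from K channel activations A_k and gradients dy/dA_k.

    Both inputs are K maps of shape H x W from the same conv layer.
    Output is the ReLU'd weighted sum, scaled so its maximum is 1.
    """
    h, w = len(activations[0]), len(activations[0][0])
    # alpha_k = spatial mean of the gradient for channel k.
    alphas = [sum(sum(row) for row in g) / (h * w) for g in gradients]
    cam = [[0.0] * w for _ in range(h)]
    for alpha, a in zip(alphas, activations):
        for i in range(h):
            for j in range(w):
                cam[i][j] += alpha * a[i][j]
    cam = [[max(v, 0.0) for v in row] for row in cam]  # keep positive evidence
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


def border_fraction(cam: Matrix, margin: int = 1) -> float:
    """Hypothetical shortcut indicator: share of heatmap mass near the edge."""
    h, w = len(cam), len(cam[0])
    total = sum(sum(row) for row in cam)
    if total == 0:
        return 0.0
    border = sum(cam[i][j] for i in range(h) for j in range(w)
                 if i < margin or j < margin or i >= h - margin or j >= w - margin)
    return border / total
```

A heatmap concentrated on the lung fields gives a border fraction near 0, while one concentrated on corners or padding gives a value near 1. Such a score could help triage which high-confidence failures to inspect first, but it does not replace visual review.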

Implementation Pipeline

  1. Load and preprocess CheXpert as the primary training dataset.
  2. Harmonize labels across CheXpert, MIMIC-CXR-JPG, ChestX-ray14, PadChest, VinDr-CXR, and BRAX.
  3. Train CNN models on CheXpert using a multi-label sigmoid classification head.
  4. Validate internally on the CheXpert validation split.
  5. Freeze trained model weights and evaluate directly on all external datasets.
  6. Compute the generalization gap: internal CheXpert performance minus external dataset performance.
  7. Repeat the protocol across multiple CNN backbones, preprocessing strategies, optimizers, loss functions, and uncertainty-label strategies.
  8. Use efficiency metrics and XAI visualizations to identify models that are accurate, robust, and clinically interpretable.
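Steps 3 to 7 form a single loop over experimental configurations. The skeleton below is a sketch of that control flow only: `Config`, `run_protocol`, and the `train` and `evaluate` callables are hypothetical stand-ins for the real training and evaluation code, and the split name `"CheXpert-val"` is a placeholder.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List


@dataclass(frozen=True)
class Config:
    """One point in the ablation space (step 7)."""
    backbone: str
    preprocessing: str
    optimizer: str
    loss: str
    uncertainty: str


def run_protocol(configs: Iterable[Config],
                 train: Callable[[Config], Any],
                 evaluate: Callable[[Any, str], float],
                 external_datasets: List[str]) -> List[Dict[str, Any]]:
    """Train on CheXpert only, then score frozen weights on every external set."""
    results = []
    for cfg in configs:
        model = train(cfg)                          # step 3: CheXpert only
        internal = evaluate(model, "CheXpert-val")  # step 4: internal validation
        for name in external_datasets:              # step 5: no fine-tuning
            external = evaluate(model, name)
            results.append({
                "config": cfg,
                "dataset": name,
                "internal": internal,
                "external": external,
                "gap": internal - external,         # step 6
            })
    return results
```

Because `evaluate` only reads the model, the same frozen checkpoint is scored on every external dataset, and each result row carries its full configuration so that gaps can later be grouped by backbone, optimizer, loss, or uncertainty strategy.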