BenchMD: A Benchmark for Unified Learning on Medical Images and Sensors [ArXiv]

This repository contains the code for BenchMD, a benchmark for modality-agnostic self-supervised learning on medical image and sensor data. The BenchMD benchmark consists of 19 real-world medical datasets across 7 medical modalities. Models for each modality are first trained on a source dataset, and successful models will achieve high performance when evaluated on out-of-distribution target datasets in the same modality.

Methods Diagram

The basic components of the benchmark can be found in datasets, encoders, and algorithms. Training is implemented with the PyTorch Lightning framework, logging with Weights and Biases, and configuration management with Hydra.

BenchMD was built using the codebase for DABS: A Domain Agnostic Benchmark for Self-Supervised Learning (DABS, DABS 2.0).

Usage

We provide support for Python >= 3.7. Install requirements with

python -m pip install -r requirements.txt

For instructions on how to install PyTorch versions compatible with your CUDA versions, see pytorch.org.

Datasets

We provide a set of dataset implementations (in src/datasets) for Electrocardiograms (ECG), Electroencephalograms (EEG), Chest X-Rays (CXR), Dermoscopic Images (Derm), Mammograms (Mammo), Fundus Images (Fundus), and Low Dose Computed Tomography (LDCT). Preprocessing is minimal and hard-coded: simple resizing (for images), truncation (for signal data), or windowing (for 3D CT data). Do not change these operations, so that results remain comparable across users of the benchmark.
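To make the signal and CT operations concrete, here is a minimal pure-Python sketch of truncation and windowing. It is not the code in src/datasets: the function names, the zero-padding of short signals, and the window centre and width in the example are assumptions chosen for illustration.

```python
from typing import List, Sequence


def truncate_or_pad(signal: Sequence[float], length: int, pad_value: float = 0.0) -> List[float]:
    """Crop a 1D signal to `length` samples; right-pad if it is shorter (padding is an assumption)."""
    cropped = list(signal[:length])
    return cropped + [pad_value] * (length - len(cropped))


def window_ct(hu_values: Sequence[float], center: float, width: float) -> List[float]:
    """Clip Hounsfield units to [center - width/2, center + width/2] and rescale to [0, 1]."""
    low, high = center - width / 2.0, center + width / 2.0
    return [(min(max(v, low), high) - low) / (high - low) for v in hu_values]


if __name__ == "__main__":
    print(truncate_or_pad([0.1, 0.2, 0.3, 0.4], length=3))
    # Hypothetical window values; the benchmark's own values live in its dataset code.
    print(window_ct([-2000.0, -600.0, 500.0], center=-600.0, width=1500.0))
```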

See conf/dataset/*.yaml for all dataset configs, including the loss, metrics, and batch size used for each dataset. The name field in each dataset YAML file will be passed in under the dataset argument when you run pretraining or transfer.

| Modality | Source Dataset | Target Dataset(s) | Label type (unused) | Input Type |
|---|---|---|---|---|
| ECG | PTB-XL | Chapman-Shaoxing, Georgia, CPSC | Single label | 1d |
| EEG | SHHS | ISRUC | Single label | 1d |
| CXR | MIMIC-CXR | CheXpert, VinDr-CXR | Multi label | 2d |
| Mammo | VinDr-Mammo | CBIS-DDSM | Single label | 2d |
| Derm | BCN 20000 | HAM 10000, PAD-UFES-20 | Single label | 2d |
| Fundus | Messidor-2 | APTOS 2019, Jinchi | Single label | 2d |
| LDCT | LIDC-IDRI | LNDb | Multi label | 3d |

All information for downloading datasets, as well as label distributions and label mappings (where applicable) are provided in dataset-info.md.

Adding New Datasets

  1. Add a new Python file for preprocessing under the data_root directory (as specified in conf/pretrain.yaml and conf/transfer.yaml). Any of the existing preprocessing files can serve as a template; we recommend starting from one with a matching input type (1D, 2D, or 3D).

  2. Add a corresponding configuration file under conf/dataset, filling in at least all fields we have provided. The name argument will be used elsewhere in the repo and when you run training with this dataset.

  3. In data_root/catalog.py, import the dataset Python file and add the dataset name argument wherever applicable, depending on whether the dataset will be used for pretraining and/or transfer, whether it is labelled, and so on.
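As a rough illustration of step 3, the sketch below shows one way a catalog could map dataset names to classes. Everything here is hypothetical: the registry names, the `register` helper, and the `MyECG` class are placeholders, and the real catalog.py may be organised differently.

```python
from typing import Dict, List, Type


class MyECG:
    """Placeholder for a dataset class defined in your new preprocessing file."""


# Hypothetical registries; check catalog.py for the structures it actually uses.
PRETRAINING_DATASETS: Dict[str, Type] = {}
TRANSFER_DATASETS: Dict[str, Type] = {}
UNLABELED_DATASETS: List[str] = []


def register(name: str, dataset_cls: Type, pretrain: bool = False,
             transfer: bool = False, labeled: bool = True) -> None:
    """Record a dataset under the same name used in its conf/dataset YAML file."""
    if pretrain:
        PRETRAINING_DATASETS[name] = dataset_cls
    if transfer:
        TRANSFER_DATASETS[name] = dataset_cls
    if not labeled:
        UNLABELED_DATASETS.append(name)


# "my-ecg" stands in for the name field of the new dataset config.
register("my-ecg", MyECG, pretrain=True, transfer=True)
```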

Encoders

A domain-agnostic SSL method should have an encoder which remains as constant as possible across domains. In src/encoders, we provide a general domain-agnostic transformer encoder baseline, as well as an ImageNet-pretrained Vision Transformer (ViT-T). The transformer operates on a sequence of vectors that are produced by a small set of embedding modules (e.g. patch or token embeddings). Each dataset and encoder has its own config file (see conf/dataset and conf/model directories).

See conf/model/*.yaml for all model configs, including the embedding dimension, pooling method, depth, and dropout used for each encoder type.
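The idea of turning an input into a sequence of vectors can be sketched without any framework. The example below only shows the patch-splitting step for 1D and 2D inputs; it is not the embedding code in src/encoders, the function names are made up, and a real embedding module would also apply a learned projection to each patch.

```python
from typing import List, Sequence


def patchify_1d(signal: Sequence[float], patch_size: int) -> List[List[float]]:
    """Split a 1D signal into non-overlapping patches of `patch_size` samples."""
    if len(signal) % patch_size != 0:
        raise ValueError("signal length must be divisible by patch_size")
    return [list(signal[i:i + patch_size]) for i in range(0, len(signal), patch_size)]


def patchify_2d(image: Sequence[Sequence[float]], patch_size: int) -> List[List[float]]:
    """Split a 2D image (list of rows) into flattened square patches, in row-major order."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([image[top + i][left + j]
                            for i in range(patch_size)
                            for j in range(patch_size)])
    return patches


if __name__ == "__main__":
    image = [[r * 4 + c for c in range(4)] for r in range(4)]
    print(patchify_1d([1, 2, 3, 4, 5, 6], patch_size=3))
    print(patchify_2d(image, patch_size=2))
```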

Adding New Encoders

  1. Add a new Python class for the encoder under src/models. The new class must inherit from the BaseModel class in src/models/base_model.py.

  2. Add a corresponding configuration file under conf/model, filling in at least the name argument, which will be used elsewhere in the repo and when you run training with this encoder.

  3. Add corresponding logic for instantiating this model class by its name in the get_model method in src/systems/base_system.py.
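The three steps amount to a subclass plus a name-based lookup. The sketch below uses stand-ins throughout: this `BaseModel` is a stub, not the class in src/models/base_model.py, and the standalone `get_model` function and `MODEL_REGISTRY` dictionary are assumptions about how the dispatch could look rather than the repo's actual method.

```python
from typing import Dict, Type


class BaseModel:
    """Stub standing in for the real BaseModel; its true constructor may differ."""

    def __init__(self, embed_dim: int, depth: int) -> None:
        self.embed_dim = embed_dim
        self.depth = depth


class MyEncoder(BaseModel):
    """Hypothetical new encoder (step 1)."""

    def forward(self, x):
        # Identity placeholder; a real encoder would transform x here.
        return x


# The key matches the name field of the new conf/model YAML file (step 2).
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {"my-encoder": MyEncoder}


def get_model(name: str, **kwargs) -> BaseModel:
    """Instantiate an encoder by its configured name (step 3)."""
    try:
        model_cls = MODEL_REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown model name: {name!r}") from None
    return model_cls(**kwargs)


if __name__ == "__main__":
    encoder = get_model("my-encoder", embed_dim=256, depth=12)
    print(type(encoder).__name__, encoder.embed_dim, encoder.depth)
```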

Pretraining algorithms

The pretraining algorithm is the framework and objective that the encoder is trained with. Algorithms such as SimCLR, BYOL, and MoCo are domain-specific: they depend on vision-specific augmentations and so are not domain-agnostic. We provide domain-agnostic implementations of three recent algorithms: e-mix, a generalization of i-mix; Shuffled Embedding Detection (ShED), a generalization of ELECTRA that randomly permutes a subset of the input embeddings and trains the model to identify the permuted embeddings; and Masked Autoencoders (MAE), which mask a subset of the input embeddings and train the model to reconstruct them.
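To illustrate the corruption step behind ShED and MAE, here is a small framework-free sketch. It is not the implementation in src/systems: the function names, the ratios, and the use of a mask token (rather than dropping masked positions) are simplifying assumptions.

```python
import random
from typing import List, Sequence, Tuple

Embedding = List[float]


def shed_corrupt(embeddings: Sequence[Embedding], ratio: float,
                 rng: random.Random) -> Tuple[List[Embedding], List[int]]:
    """Permute a random subset of positions; label 1 marks a permuted position."""
    n = len(embeddings)
    if n < 2:
        raise ValueError("need at least two embeddings to permute")
    k = min(n, max(2, round(n * ratio)))
    chosen = rng.sample(range(n), k)
    corrupted = list(embeddings)
    # Rotate the chosen embeddings by one slot so every chosen position receives
    # an embedding that originally sat somewhere else.
    for src, dst in zip(chosen, chosen[1:] + chosen[:1]):
        corrupted[dst] = embeddings[src]
    labels = [0] * n
    for index in chosen:
        labels[index] = 1
    return corrupted, labels


def mae_mask(embeddings: Sequence[Embedding], ratio: float, rng: random.Random,
             mask_token: Embedding) -> Tuple[List[Embedding], List[int]]:
    """Replace a random subset of embeddings with `mask_token`; return the masked indices."""
    n = len(embeddings)
    masked = set(rng.sample(range(n), round(n * ratio)))
    visible = [mask_token if i in masked else e for i, e in enumerate(embeddings)]
    return visible, sorted(masked)


if __name__ == "__main__":
    rng = random.Random(0)
    sequence = [[float(i)] for i in range(8)]
    corrupted, labels = shed_corrupt(sequence, ratio=0.25, rng=rng)
    print(labels)       # targets for the detection objective
    masked_seq, masked_idx = mae_mask(sequence, ratio=0.5, rng=rng, mask_token=[-1.0])
    print(masked_idx)   # positions the model would be asked to reconstruct
```

In ShED the labels are the prediction target, while in MAE the original embeddings at the masked indices are the reconstruction target.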

Adding New Pretraining Algorithms

  1. Add a new Python file for the algorithm under src/systems. The new class must inherit from the BaseSystem class in src/systems/base_system.py.

  2. Add corresponding logic for it in pretrain.py.

Pretraining Phase

During the pretraining phase, self-supervised encoders are trained to learn good representations from unlabeled data. We currently support seven datasets for pretraining, one for each modality: PTB-XL, SHHS, MIMIC-CXR, VinDr-Mammo, BCN 20000, Messidor-2, and LIDC-IDRI. Since the pretraining datasets have associated labels, online linear evaluators are jointly trained with the encoders to provide heuristics of transfer performance.

Run pretraining with commands like

python pretrain.py exp.name=<experiment-name> dataset=<dataset> algorithm=<algorithm> model=<encoder-type>

Key Configuration Fields

  • exp.name: The directory under which checkpoints and other files will be saved, under exp.base_dir. Must be specified each time, no default.
  • dataset: Options include ptb-xl, sshs, mimic-cxr, vindr-mammo, isic2019, messidor2, and lidc. Default is mimic-cxr.
  • algorithm: Options include emix, shed, mae. Default is emix.
  • model: Encoder model to use for training. Options include transformer and imagenet-vit (the domain-agnostic Transformer and ImageNet-pretrained ViT-T, respectively). Default is transformer.

See pretrain.yaml for all pretraining configuration fields.

For example, to train a domain-agnostic Transformer on the PTB-XL dataset with the e-mix algorithm, run

python pretrain.py exp.name=emix-ptbxl model=transformer dataset=ptbxl algorithm=emix
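If you launch many runs, it can help to assemble the Hydra overrides programmatically. The helper below is a hypothetical convenience, not part of this repo; it only formats the `key=value` overrides documented above and leaves execution to you.

```python
import shlex
from typing import Dict, List


def hydra_command(script: str, overrides: Dict[str, str]) -> List[str]:
    """Build an argv list such as ['python', 'pretrain.py', 'exp.name=...', ...]."""
    if not overrides.get("exp.name"):
        # exp.name has no default, so fail early instead of inside the script.
        raise ValueError("exp.name must be specified")
    return ["python", script] + [f"{key}={value}" for key, value in overrides.items()]


command = hydra_command(
    "pretrain.py",
    {
        "exp.name": "emix-ptbxl",
        "dataset": "ptbxl",
        "algorithm": "emix",
        "model": "transformer",
    },
)
# Print a shell-safe version of the command for inspection.
print(" ".join(shlex.quote(arg) for arg in command))
# To actually launch it: subprocess.run(command, check=True)
```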

Transfer Learning Phase

After pretraining, a small linear classifier is trained on top of the frozen encoder. Run transfer learning from a randomly initialized encoder with

python transfer.py exp.name=<experiment-name> dataset=<dataset> model=<encoder-type> finetune_size=<null-or-sml> ckpt=null test=<t-or-f> 

Key Configuration Fields

  • exp.name: The directory under which checkpoints and other files will be saved, under exp.base_dir. Must be specified each time, no default.
  • dataset: Transfer learning options are the same source datasets as pretraining: ptb-xl, sshs, mimic-cxr, vindr-mammo, isic2019, messidor2, and lidc. Options for testing on target datasets (see below) include ga, chapman, cpsc, isruc, chexpert, vindr-cxr, cbis, isic2018, aptos, jinchi, lndb. Default is chexpert.
  • model: Encoder model to use for training. Options include transformer and imagenet-vit (the domain-agnostic Transformer and ImageNet-pretrained ViT-T, respectively). Default is transformer.
  • finetune_size: Options include null, small, medium, large, or full. Default is null, which corresponds to performing linear evaluation. If you'd like to perform finetuning, set finetune_size to be one of small (8 labels/class), medium (64 labels/class), large (256 labels/class), or full (all labels). This will allow the encoder's weights to be trained alongside the linear classifier.
  • ckpt: Path to the checkpoint file from which to begin transfer learning. Default is null, which corresponds to performing transfer learning on a randomly initialized model.
  • test: Either true or false. Default is false, which corresponds to performing transfer learning. The true option will freeze the model provided via the ckpt argument and evaluate it on the dataset.

See transfer.yaml for all transfer learning configuration fields.
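The finetune_size options map to a per-class label budget. The sketch below shows one plausible way to draw such a subset for a single-label dataset; how this repo actually selects the labelled examples is not described here, so treat the sampling strategy, the seed, and the function name as assumptions.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence

# Budgets as documented for finetune_size; None means "use all labels".
LABELS_PER_CLASS: Dict[str, Optional[int]] = {
    "small": 8,
    "medium": 64,
    "large": 256,
    "full": None,
}


def subsample_per_class(labels: Sequence[Hashable], finetune_size: str, seed: int = 0) -> List[int]:
    """Return sorted example indices, keeping at most the budgeted number per class."""
    budget = LABELS_PER_CLASS[finetune_size]
    if budget is None:
        return list(range(len(labels)))
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    rng = random.Random(seed)
    kept: List[int] = []
    for label in sorted(by_class):
        indices = by_class[label]
        # Classes with fewer examples than the budget are kept in full.
        kept.extend(rng.sample(indices, min(budget, len(indices))))
    return sorted(kept)


if __name__ == "__main__":
    toy_labels = [0] * 20 + [1] * 5
    print(len(subsample_per_class(toy_labels, "small")))
    print(len(subsample_per_class(toy_labels, "full")))
```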

For example, to perform finetuning with a "small" label set on the e-mix model we pretrained above, run

python transfer.py exp.name=finetune-small-ptbxl-emix dataset=ptbxl finetune_size=small ckpt=<path-to-pretrain-ckpt> 

To evaluate this model on the CPSC dataset, run

python transfer.py exp.name=test-finetune-small-ptbxl-emix dataset=cpsc ckpt=<path-to-transfer-ckpt> test=True
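To evaluate one checkpoint on every target dataset of a modality, the test command can be generated in a loop. The mapping below is inferred from the dataset table and the option names listed above, not read from the repo; the Derm entry is omitted because the option list does not make its target names unambiguous, and `target_eval_commands` is a hypothetical helper.

```python
from typing import Dict, List

# Source dataset -> target dataset names (an inferred mapping; verify against conf/dataset).
TARGETS: Dict[str, List[str]] = {
    "ptb-xl": ["chapman", "ga", "cpsc"],
    "sshs": ["isruc"],
    "mimic-cxr": ["chexpert", "vindr-cxr"],
    "vindr-mammo": ["cbis"],
    "messidor2": ["aptos", "jinchi"],
    "lidc": ["lndb"],
}


def target_eval_commands(source: str, ckpt: str, exp_prefix: str) -> List[List[str]]:
    """Build one `transfer.py ... test=True` argv list per target dataset of `source`."""
    return [
        [
            "python",
            "transfer.py",
            f"exp.name={exp_prefix}-{target}",
            f"dataset={target}",
            f"ckpt={ckpt}",
            "test=True",
        ]
        for target in TARGETS[source]
    ]


if __name__ == "__main__":
    # Replace the placeholder with the path to your transfer checkpoint.
    for argv in target_eval_commands("ptb-xl", "<path-to-transfer-ckpt>", "test-ptbxl-emix"):
        print(" ".join(argv))
```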

Results

Results Table

Bugs or questions?

Feel free to email Kathryn Wantlin ([email protected]) or open an issue in this repo with any questions related to the code or our paper, and also to let us know if you have updated results!

Citation

If you are using BenchMD, or are using our code in your research, please cite our paper:

@misc{wantlin2023benchmd,
  title={BenchMD: A Benchmark for Modality-Agnostic Learning on Medical Images and Sensors},
  author={Kathryn Wantlin and Chenwei Wu and Shih-Cheng Huang and Oishi Banerjee and Farah Dadabhoy and Veeral Vipin Mehta and Ryan Wonhee Han and Fang Cao and Raja R. Narayan and Errol Colak and Adewole Adamson and Laura Heacock and Geoffrey H. Tison and Alex Tamkin and Pranav Rajpurkar},
  year={2023},
  eprint={2304.08486},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}