Med-BERT

The Med-BERT model is a natural language processing model for disease prediction based on structured electronic health record (EHR) data. You can read more about it in the paper:

Laila Rasmy, Yang Xiang, Ziqian Xie, Cui Tao, and Degui Zhi. “Med-BERT: pre-trained contextualized embeddings on large-scale structured electronic health records for disease prediction.” npj Digital Medicine, 2021. https://www.nature.com/articles/s41746-021-00455-y

Due to vendor restrictions, the authors could not share their trained model:

Initially we really hoped to share our models but unfortunately, the pre-trained models are no longer sharable. According to SBMI Data Service Office: “Under the terms of our contracts with data vendors, we are not permitted to share any of the data utilized in our publications, as well as large models derived from those data.”

but they shared the code to reproduce Med-BERT at https://github.com/ZhiGroup/Med-BERT.

If you have access to data that aligns with Med-BERT’s requirements, you can leverage LARCC’s resources to create your own instance of Med-BERT. Here is an example for the pre-training phase:

  1. Set up the code dependencies. In this case, the pretraining code depends on TensorFlow 1.x, which

    • is only compatible with Python 3.5 to 3.7. The cluster comes with Python 3.9 by default and, currently, there is no module for any of these Python versions, so you will need Conda to create an environment with a compatible Python version.

    • is compatible with protobuf versions prior to 4.0.

    • is compatible with CUDA versions up to CUDA 10. LARCC’s GPUs require CUDA 11.8 or newer, so you will need to run the pretraining on CPUs.

    module load miniforge3
    # Pin TensorFlow below 2.0 and use the CPU build: LARCC's GPUs require a
    # CUDA version that TensorFlow 1.x does not support.
    conda create --name my_tf1 python=3.7 'tensorflow<2' 'protobuf<=3.20' pandas numpy matplotlib
    
  2. Download the code and replace all spaces in folder names with _ to avoid path issues on Linux.

    cd $WORK
    git clone https://github.com/ZhiGroup/Med-BERT.git
    find Med-BERT -depth -type d -name '* *' -execdir bash -c 'mv "$0" "${0// /_}"' '{}' \;
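
    As a quick sanity check, the rename can be exercised on a throwaway directory tree first; the paths below are made up for illustration and mimic the repository's folder names:

    ```shell
    # Create a disposable tree with spaces in directory names (illustrative paths).
    mkdir -p '/tmp/rename_demo/Pretraining Code/Data Pre-processing Code'
    cd /tmp/rename_demo

    # -depth renames the deepest directories first, so parents are only
    # renamed after their children; otherwise the child paths would break.
    find . -depth -type d -name '* *' -execdir bash -c 'mv "$0" "${0// /_}"' '{}' \;

    # This should print nothing: no directory with a space remains.
    find . -type d -name '* *'
    ```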
    
  3. Preprocess the data you will use for the pretraining step. In the example below, the option --output_file='ehr_tf_features' writes a TensorFlow-formatted features file named ehr_tf_features, which is required for the pretraining.

    cd $WORK/Med-BERT/Pretraining_Code/Data_Pre-processing_Code
    # NOTE: You can run the following in a batch job instead.
    srun --partition=cpu384g --job-name med-bert --time=01:00:00 --ntasks-per-node=32 --cpu-bind=cores --pty /bin/bash -i
    cd $WORK/Med-BERT/Pretraining_Code/Data_Pre-processing_Code
    module load miniforge3
    conda activate my_tf1
    # NOTE: This assumes your input file is stored in the path below. Adjust
    # the path if your data lives somewhere else.
    INPUT=$WORK/Med-BERT/Pretraining_Code/Data_Pre-processing_Code/data_file.tsv
    OUT_PREFIX=preprocessed
    python3 preprocess_pretrain_data.py "$INPUT" NA "$OUT_PREFIX"
    python3 create_BERTpretrain_EHRfeatures.py \
        --input_file="$OUT_PREFIX.bencs.train" \
        --output_file='ehr_tf_features' \
        --vocab_file="$OUT_PREFIX.types" \
        --max_predictions_per_seq=1 \
        --max_seq_length=64
    exit
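
    The NOTE above mentions running the preprocessing in a batch job instead of an interactive session; a minimal submission script for that could look like the sketch below. The partition, time limit, environment name, and paths are assumptions carried over from the interactive example and should be adapted to your setup:

    ```shell
    #!/bin/bash
    #SBATCH --partition=cpu384g
    #SBATCH --job-name=med-bert-preprocess
    #SBATCH --time=01:00:00
    #SBATCH --ntasks-per-node=32

    module load miniforge3
    conda activate my_tf1

    cd "$WORK/Med-BERT/Pretraining_Code/Data_Pre-processing_Code"
    INPUT="$WORK/Med-BERT/Pretraining_Code/Data_Pre-processing_Code/data_file.tsv"
    OUT_PREFIX=preprocessed

    python3 preprocess_pretrain_data.py "$INPUT" NA "$OUT_PREFIX"
    python3 create_BERTpretrain_EHRfeatures.py \
        --input_file="$OUT_PREFIX.bencs.train" \
        --output_file='ehr_tf_features' \
        --vocab_file="$OUT_PREFIX.types" \
        --max_predictions_per_seq=1 \
        --max_seq_length=64
    ```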
    
  4. Create a submission script for the pretraining phase. Assume the script below is written to $WORK/med-bert.sbatch.

    #!/bin/bash
    #SBATCH --partition=cpu384g
    #SBATCH --time=3-00:00:00
    #SBATCH --ntasks=32
    #SBATCH --ntasks-per-core=1

    PROJECT_NAME="Med-BERT"
    PROJECT_DIR="/work/$USER/$PROJECT_NAME"
    SCRATCH_DIR="/mnt/local/scratch/$USER"
    WDIR="$SCRATCH_DIR/$SLURM_JOB_ID"
    PRETRAIN_DIR="$WDIR/$PROJECT_NAME/Pretraining_Code"
    OUTDIR="$WDIR/out"

    # Copy the project to scratch space to speed up reading and writing of data.
    # Checkpoints will be written to scratch too. You want to avoid writing to
    # your home storage as it's orders of magnitude slower for read/write operations.
    mkdir -p "$WDIR"
    mkdir -p "$OUTDIR"
    cp -r "$PROJECT_DIR" "$WDIR"
    cd "$PRETRAIN_DIR"

    # BACKUP settings:
    #   These variables control how often and where the checkpoints produced
    #   by the pretraining code are copied over to persistent storage.
    BACKUP_DIR="$PROJECT_DIR/Pretraining_Code/checkpoint_$SLURM_JOB_ID"
    BACKUP_FREQUENCY=600 # i.e. back up every 10m
    BACKUP_LOCK="$SCRATCH_DIR/backup_in_progress_$SLURM_JOB_ID"
    BACKUP_CMD="rsync -a $OUTDIR/ $BACKUP_DIR/"

    # Start a background process that copies any produced checkpoint back to
    # persistent storage while this job script is alive.
    bash -c "
      while [ -e /proc/$$ ]; do
        sleep $BACKUP_FREQUENCY
        touch '$BACKUP_LOCK'
        $BACKUP_CMD
        rm '$BACKUP_LOCK'
      done
    " &
    backup_pid=$!

    module load miniforge3
    conda activate my_tf1

    python3 run_EHRpretraining.py \
            --input_file="$PRETRAIN_DIR/Data_Pre-processing_Code/ehr_tf_features" \
            --output_dir="$OUTDIR" \
            --do_train=True \
            --do_eval=True \
            --bert_config_file="$PRETRAIN_DIR/config.json" \
            --train_batch_size=32 \
            --max_seq_length=64 \
            --max_predictions_per_seq=1 \
            --num_train_steps=4500000 \
            --num_warmup_steps=10000 \
            --learning_rate=5e-5

    # Wait for the active backup (if any) to finish
    while [ -e "$BACKUP_LOCK" ]; do sleep 1; done
    # Prevent any more auto-backups from starting and perform a final backup
    kill "$backup_pid"
    $BACKUP_CMD
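
    The backup mechanism in the script is a generic pattern: a background loop copies results to persistent storage while the parent process is alive, and a lock file ensures the final copy does not race an in-flight one. Below is a minimal, self-contained sketch of the pattern with made-up temporary paths, a 1-second interval, and cp standing in for rsync to keep it dependency-free:

    ```shell
    #!/bin/bash
    # Illustrative stand-ins: in the real job these are scratch and /work paths.
    SRC=$(mktemp -d)      # plays the role of $OUTDIR on scratch
    DST=$(mktemp -d)      # plays the role of $BACKUP_DIR on persistent storage
    LOCK="${SRC}.lock"    # lock file kept outside SRC so it is never copied

    # Background watcher: copy SRC to DST every second while this shell lives.
    # $$ expands here (inside double quotes) to the parent shell's PID, so the
    # loop exits on its own once the parent is gone.
    bash -c "
      while [ -e /proc/$$ ]; do
        sleep 1
        touch '$LOCK'
        cp -a '$SRC/.' '$DST/'
        rm -f '$LOCK'
      done
    " &
    backup_pid=$!

    # Simulate the long-running training step producing a checkpoint.
    echo 'fake checkpoint' > "$SRC/model.ckpt"
    sleep 3

    # Wait for any in-flight backup, stop the watcher, then do a final copy.
    while [ -e "$LOCK" ]; do sleep 1; done
    kill "$backup_pid"
    cp -a "$SRC/." "$DST/"

    cat "$DST/model.ckpt"   # → fake checkpoint
    ```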
    
  5. Submit the script to Slurm with sbatch $WORK/med-bert.sbatch.