Deep Learning based medical image segmentation — Part 6— Neural Network Model Training
We’ve downloaded the TCIA Glioblastoma MRI dataset, analyzed its contents, prepared the Google Colab/Drive environment to train our deep learning neural network model using nnUNet, prepared the training and test data, and preprocessed the images.
In this post, let’s train our model on the MRI scans and their associated ground truth segmentation maps (labels) that were created by radiologists.
GPU selection for model training
If you remember, we split the 147 patients’ MRI scans and segmentation labels into an 80% training cohort (n=117) and a 20% test cohort (n=30). We will now train the neural network on the 117 training cases.
Until now, I have used the free version of Google Colab. The free tier gives you a Tesla K80 GPU, which is not going to be sufficient to train on the 117 training cases. You are also limited in runtime and memory, and your training will almost certainly time out.
To overcome this, you can sign up for Google Colab Pro or Pro+ to get access to faster GPUs with more GPU memory, which will allow us to train the neural network.
For this project, I signed up for the Pay As You Go plan and purchased 100 compute units for $9.99. This gives us access to faster GPUs and more memory. I was able to complete both the training and inference tasks and still had about 10 compute units remaining. Unused units expire after 90 days. After the purchase, you can select the faster GPU and more memory by going to the Runtime → Change Runtime Type menu in your Colab notebook.
You may see the NVIDIA P100 (Pascal architecture), V100 (Volta architecture), or A100 (Ampere architecture) as your available options for GPU type. Choose the A100, as it is the fastest GPU of the three. Also, choose the High-RAM option for more system memory.
To save your compute units, remember to Disconnect and Delete Runtime if you are not actively training the model.
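After connecting to a runtime, you can confirm which GPU you were actually assigned by running nvidia-smi from a notebook cell:
# Check the GPU assigned to this runtime (model, driver version, and memory)
!nvidia-smi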
- Open a new Colab notebook — 05_t501_glio_model_train.ipynb
- Mount your Google Drive
from google.colab import drive
drive.mount('/content/drive')
- Install the nnUNet V2 package
!pip install nnunetv2
- Install the hiddenlayer package, which will print the neural network architecture to a PDF file for review.
!pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git
nnUNet minor hyperparameter hack
By default, nnUNet trains each model for 1000 epochs. It takes a long time to run 5 cross-validation folds for 1000 epochs each. For this project, I want to run the training for only 100 epochs per fold. Unfortunately, as of this writing, there is no command-line option to change the number of epochs, so we will need to change the hardcoded value in the nnUNet package. Training for fewer epochs may result in lower accuracy, but it will suffice for our purpose of learning.
- With your Google Colab Pro subscription, you also get access to the VM’s terminal (bottom left of your Colab window). Click on that and your screen will split, with the terminal window on the right, and your notebook on the left.
- The number of epochs is hardcoded in this file on the Colab VM: /usr/local/lib/python3.10/dist-packages/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
- You can see the original file in the nnUNet repository on GitHub.
- In the Colab Terminal window (not the Notebook area), go to that directory and open the file using a Unix editor like vim. Go to the section that says — “Some hyperparameters for you to fiddle with”.
- Let’s fiddle with a few. Change the following three hyperparameters to the values below -
### Some hyperparameters for you to fiddle with
...
self.num_iterations_per_epoch = 50
self.num_val_iterations_per_epoch = 25
self.num_epochs = 100
...
- Save and exit the file after making these changes. If you would rather not edit the file in vim, you can make the same edits from a notebook cell, as sketched below. With the hyperparameters updated, we are ready to train our neural network model.
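Here is a minimal sketch of that notebook-cell alternative. It patches the same three values in place and assumes the default dist-packages path shown above; adjust the path for your Python version, and re-run it if you ever reinstall nnunetv2.
# Sketch: patch the hardcoded nnUNet trainer hyperparameters from a notebook cell.
# Assumes the default Colab install path noted above; adjust for your Python version.
import re

trainer_py = "/usr/local/lib/python3.10/dist-packages/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py"

overrides = {
    "num_iterations_per_epoch": 50,
    "num_val_iterations_per_epoch": 25,
    "num_epochs": 100,
}

with open(trainer_py) as f:
    src = f.read()

for name, value in overrides.items():
    # Replace assignments like "self.num_epochs = 1000" with the override value
    src = re.sub(rf"self\.{name}\s*=\s*\d+", f"self.{name} = {value}", src)

with open(trainer_py, "w") as f:
    f.write(src)

print("Patched", trainer_py)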
Model Training
- In the same Colab notebook, with the A100 GPU and High RAM enabled, and hyperparameters updated, let’s set up our environment variables and run the nnUNet training command.
- The model is trained using the following command format -
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
- We will use these options for our model training: DATASET_NAME_OR_ID will be Dataset501_Glioblastoma, UNET_CONFIGURATION will be 3d_fullres (3D full resolution), and FOLD will range from 0 to 4. We will have to repeat the command for all 5 cross-validation folds (fold 0, 1, 2, 3 and 4), running the training one fold after another, starting with fold 0.
- We will also include the "--npz" flag, which tells nnUNet to store the softmax outputs during final validation. This is needed because, after the training, we are going to ask nnUNet to choose the best configuration.
# Set up environment variables
import os
os.environ['nnUNet_raw'] = "/content/drive/MyDrive/TCIA/nnUNet/nnUNet_raw"
os.environ['nnUNet_preprocessed'] = "/content/drive/MyDrive/TCIA/nnUNet/nnUNet_preprocessed"
os.environ['nnUNet_results'] = "/content/drive/MyDrive/TCIA/nnUNet/nnUNet_results"
# Run training one fold at a time, one after the other, from 0 - 4
# This is not a Python command, but an OS command running on the VM
# Uncomment each fold command in turn and comment out the prior fold command
!nnUNetv2_train Dataset501_Glioblastoma 3d_fullres 0 --npz
# !nnUNetv2_train Dataset501_Glioblastoma 3d_fullres 1 --npz
# !nnUNetv2_train Dataset501_Glioblastoma 3d_fullres 2 --npz
# !nnUNetv2_train Dataset501_Glioblastoma 3d_fullres 3 --npz
# !nnUNetv2_train Dataset501_Glioblastoma 3d_fullres 4 --npz
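One practical note: Colab sessions can disconnect mid-fold. nnUNet checkpoints periodically, and as of this writing the trainer accepts a --c flag to continue a fold from its latest checkpoint (run nnUNetv2_train -h to confirm this on your version):
# Resume an interrupted fold from its latest checkpoint (confirm the flag with -h first)
# !nnUNetv2_train Dataset501_Glioblastoma 3d_fullres 0 --npz --c
Here is the (abridged) training log from one of the folds: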
Using device: cuda:0
#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################
This is the configuration used by this training:
Configuration name: 3d_fullres
{'data_identifier': 'nnUNetPlans_3d_fullres', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 2, 'patch_size': [128, 160, 112], 'median_image_size_in_voxels': [140.0, 172.0, 137.0], 'spacing': [1.0, 1.0, 1.0], 'normalization_schemes': ['ZScoreNormalization', 'ZScoreNormalization', 'ZScoreNormalization', 'ZScoreNormalization'], 'use_mask_for_norm': [True, True, True, True], 'UNet_class_name': 'PlainConvUNet', 'UNet_base_num_features': 32, 'n_conv_per_stage_encoder': [2, 2, 2, 2, 2, 2], 'n_conv_per_stage_decoder': [2, 2, 2, 2, 2], 'num_pool_per_axis': [5, 5, 4], 'pool_op_kernel_sizes': [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 1]], 'conv_kernel_sizes': [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], 'unet_max_num_features': 320, 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape', 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'batch_dice': False}
These are the global plan.json settings:
{'dataset_name': 'Dataset501_Glioblastoma', 'plans_name': 'nnUNetPlans', 'original_median_spacing_after_transp': [1.0, 1.0, 1.0], 'original_median_shape_after_transp': [140, 172, 137], 'image_reader_writer': 'SimpleITKIO', 'transpose_forward': [0, 1, 2], 'transpose_backward': [0, 1, 2], 'experiment_planner_used': 'ExperimentPlanner', 'label_manager': 'LabelManager', 'foreground_intensity_properties_per_channel': {'0': {'max': 1464.0, 'mean': 329.0144958496094, 'median': 322.0, 'min': 0.0, 'percentile_00_5': 138.0, 'percentile_99_5': 963.0, 'std': 93.77518463134766}, '1': {'max': 2347.0, 'mean': 427.0850524902344, 'median': 379.0, 'min': 0.0, 'percentile_00_5': 153.0, 'percentile_99_5': 1258.0, 'std': 186.98681640625}, '2': {'max': 2957.0, 'mean': 538.0934448242188, 'median': 515.0, 'min': 0.0, 'percentile_00_5': 74.0, 'percentile_99_5': 1556.0, 'std': 351.8751525878906}, '3': {'max': 1536.0, 'mean': 397.7358093261719, 'median': 376.0, 'min': 0.0, 'percentile_00_5': 110.0, 'percentile_99_5': 973.0, 'std': 146.7981414794922}}}
2023-06-22 01:03:55.889929: unpacking dataset...
2023-06-22 01:03:59.957053: unpacking done...
2023-06-22 01:03:59.960094: do_dummy_2d_data_aug: False
2023-06-22 01:03:59.968338: Using splits from existing split file: /content/drive/MyDrive/TCIA/nnUNet_raw_data_base/nnUNet_preprocessed/Dataset501_Glioblastoma/splits_final.json
2023-06-22 01:03:59.971466: The split file contains 5 splits.
2023-06-22 01:03:59.973142: Desired fold for training: 4
2023-06-22 01:03:59.974958: This split has 94 training and 23 validation cases.
2023-06-22 01:04:03.552835:
2023-06-22 01:04:03.554749: Epoch 0
2023-06-22 01:04:03.556601: Current learning rate: 0.01
using pin_memory on device 0
using pin_memory on device 0
2023-06-22 01:04:43.621871: train_loss 0.648
2023-06-22 01:04:43.679138: val_loss -0.1173
2023-06-22 01:04:43.715275: Pseudo dice [0.0001, 0.6325, 0.5719]
2023-06-22 01:04:43.795578: Epoch time: 40.07 s
2023-06-22 01:04:43.842204: Yayy! New best EMA pseudo Dice: 0.4015
2023-06-22 01:04:48.932346:
2023-06-22 01:04:48.934511: Epoch 1
2023-06-22 01:04:48.937582: Current learning rate: 0.00991
2023-06-22 01:05:08.461267: train_loss -0.2243
2023-06-22 01:05:08.500713: val_loss -0.2635
2023-06-22 01:05:08.529491: Pseudo dice [0.0, 0.5936, 0.7116]
2023-06-22 01:05:08.555052: Epoch time: 19.53 s
2023-06-22 01:05:08.610179: Yayy! New best EMA pseudo Dice: 0.4049
2023-06-22 01:05:13.500060:
2023-06-22 01:05:13.501934: Epoch 2
2023-06-22 01:05:13.505103: Current learning rate: 0.00982
2023-06-22 01:05:37.768519: train_loss -0.2903
2023-06-22 01:05:37.822310: val_loss -0.4272
2023-06-22 01:05:37.859740: Pseudo dice [0.1743, 0.6974, 0.7755]
2023-06-22 01:05:37.882360: Epoch time: 24.27 s
2023-06-22 01:05:37.911484: Yayy! New best EMA pseudo Dice: 0.4193
2023-06-22 01:05:42.975867:
.....
.....
2023-06-22 01:50:13.120085: Epoch 98
2023-06-22 01:50:13.122439: Current learning rate: 0.0003
2023-06-22 01:50:33.998227: train_loss -0.7056
2023-06-22 01:50:34.044335: val_loss -0.7099
2023-06-22 01:50:34.091183: Pseudo dice [0.8565, 0.8246, 0.8284]
2023-06-22 01:50:34.132137: Epoch time: 20.9 s
2023-06-22 01:50:37.814465:
2023-06-22 01:50:37.818093: Epoch 99
2023-06-22 01:50:37.820400: Current learning rate: 0.00016
2023-06-22 01:51:04.239096: train_loss -0.7435
2023-06-22 01:51:04.280445: val_loss -0.6882
2023-06-22 01:51:04.302704: Pseudo dice [0.8099, 0.8502, 0.8387]
2023-06-22 01:51:04.349361: Epoch time: 26.43 s
2023-06-22 01:51:10.418192: Using splits from existing split file: /content/drive/MyDrive/TCIA/nnUNet_raw_data_base/nnUNet_preprocessed/Dataset501_Glioblastoma/splits_final.json
2023-06-22 01:51:10.439320: The split file contains 5 splits.
2023-06-22 01:51:10.447497: Desired fold for training: 4
2023-06-22 01:51:10.454137: This split has 94 training and 23 validation cases.
2023-06-22 01:51:10.469486: predicting 102
2023-06-22 01:51:14.112995: predicting 112
2023-06-22 01:51:16.829235: predicting 118
2023-06-22 01:51:19.061675: predicting 119
2023-06-22 01:51:21.536263: predicting 13
2023-06-22 01:51:23.938888: predicting 154
2023-06-22 01:51:26.321893: predicting 176
2023-06-22 01:51:28.583216: predicting 180
2023-06-22 01:51:30.781342: predicting 193
2023-06-22 01:51:32.962270: predicting 2
2023-06-22 01:51:35.076004: predicting 206
2023-06-22 01:51:37.160351: predicting 262
2023-06-22 01:51:39.198120: predicting 330
2023-06-22 01:51:41.221925: predicting 362
2023-06-22 01:51:43.297707: predicting 371
2023-06-22 01:51:45.319198: predicting 373
2023-06-22 01:51:47.350525: predicting 375
2023-06-22 01:51:49.343235: predicting 404
2023-06-22 01:51:51.385258: predicting 474
2023-06-22 01:51:53.361364: predicting 6
2023-06-22 01:51:55.358148: predicting 88
2023-06-22 01:51:57.391516: predicting 9
2023-06-22 01:51:59.424051: predicting 91
2023-06-22 01:52:10.922611: Validation complete
2023-06-22 01:52:10.944788: Mean Validation Dice: 0.8029053811307724
Interpreting model training outputs
Let’s look at the model training output in detail.
- Training Log — The training log is shown in your Colab console and also stored in this folder — /content/drive/MyDrive/TCIA/nnUNet/nnUNet_results/Dataset501_Glioblastoma/nnUNetTrainer__nnUNetPlans__3d_fullres/fold_n, where fold_n is the fold number from 0–4.
- Training:Validation Split — nnUNet splits the 117 training cases into five fixed folds (stored in splits_final.json). For each training run, four folds (n=94, roughly 80%) are used for training and the remaining fold (n=23, roughly 20%) for validation, so the validation cohort differs from fold to fold and every training case serves as a validation case exactly once across the five runs. This 5-fold cross-validation lets us assess the model on different sub-sets of the training data and later pick the best configuration. A small sketch after the log excerpt below shows how to inspect the split file yourself.
2023-06-22 01:51:10.418192: Using splits from existing split file: /content/drive/MyDrive/TCIA/nnUNet_raw_data_base/nnUNet_preprocessed/Dataset501_Glioblastoma/splits_final.json
2023-06-22 01:51:10.439320: The split file contains 5 splits.
2023-06-22 01:51:10.447497: Desired fold for training: 4
2023-06-22 01:51:10.454137: This split has 94 training and 23 validation cases.
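If you want to see the split for every fold at once, here is a small sketch that reads splits_final.json directly (adjust the path to wherever your nnUNet_preprocessed folder lives):
# Sketch: inspect the 5-fold split nnUNet generated for Dataset501_Glioblastoma.
# Adjust the path if your nnUNet_preprocessed folder is located elsewhere.
import json

splits_path = "/content/drive/MyDrive/TCIA/nnUNet/nnUNet_preprocessed/Dataset501_Glioblastoma/splits_final.json"

with open(splits_path) as f:
    splits = json.load(f)

for fold, split in enumerate(splits):
    print(f"Fold {fold}: {len(split['train'])} training, {len(split['val'])} validation cases")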
- Training Time — We can see that the total training time for 100 epochs on one fold was about 49 minutes (1:03-1:52). So to train 5 folds for 100 epochs each, you will need roughly 4 hours. This is also shown in the progress.png file, located in the same folder as the log file above.
- Training and Validation Loss — For each epoch, the training and validation loss are shown, along with the pseudo-Dice score (based on the Dice-Sørensen coefficient, a metric for the overlap between the predicted and ground-truth segmentations; it is "pseudo" because it is estimated from sampled validation patches rather than full volumes). These metrics are also plotted in the progress.png file, which you can display inline with the sketch after the log excerpt below. You should check the training and validation loss curves for underfitting or overfitting. I will not go into the details of overfitting here, but in general, if both loss curves keep decreasing and stay close together, the fit is reasonable; if the validation loss starts rising while the training loss keeps falling, the model is overfitting.
2023-06-22 01:04:48.934511: Epoch 1
2023-06-22 01:04:48.937582: Current learning rate: 0.00991
2023-06-22 01:05:08.461267: train_loss -0.2243
2023-06-22 01:05:08.500713: val_loss -0.2635
2023-06-22 01:05:08.529491: Pseudo dice [0.0, 0.5936, 0.7116]
2023-06-22 01:05:08.555052: Epoch time: 19.53 s
2023-06-22 01:05:08.610179: Yayy! New best EMA pseudo Dice: 0.4049
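To view these curves without leaving the notebook, you can display progress.png inline. A sketch (adjust the fold number and the results path if yours differ):
# Sketch: display the progress.png chart for a given fold inline in the notebook.
from IPython.display import Image, display

fold = 0  # change to the fold you want to inspect (0-4)
progress_png = (
    "/content/drive/MyDrive/TCIA/nnUNet/nnUNet_results/Dataset501_Glioblastoma/"
    f"nnUNetTrainer__nnUNetPlans__3d_fullres/fold_{fold}/progress.png"
)
display(Image(filename=progress_png))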
- Learning Rate — nnUNet decreases the learning rate after every epoch, starting at 0.01 and decaying towards zero by the final epoch. The trendline is charted in the progress.png file, and the sketch after the log excerpt below reproduces the logged values.
2023-06-22 01:04:03.554749: Epoch 0
2023-06-22 01:04:03.556601: Current learning rate: 0.01
2023-06-22 01:04:48.934511: Epoch 1
2023-06-22 01:04:48.937582: Current learning rate: 0.00991
....
....
2023-06-22 01:50:13.120085: Epoch 98
2023-06-22 01:50:13.122439: Current learning rate: 0.0003
2023-06-22 01:50:37.818093: Epoch 99
2023-06-22 01:50:37.820400: Current learning rate: 0.00016
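The schedule is a polynomial decay, lr = initial_lr * (1 - epoch / num_epochs) ** 0.9, which you can check against the logged values in a few lines:
# Sketch: reproduce the learning rates seen in the log above
# using nnUNet's polynomial decay (exponent 0.9).
initial_lr = 0.01
num_epochs = 100

for epoch in [0, 1, 98, 99]:
    lr = initial_lr * (1 - epoch / num_epochs) ** 0.9
    print(f"Epoch {epoch}: learning rate {lr:.5f}")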
- Inference on the Validation Cohort — At the end of each fold, nnUNet runs full inference on that fold's 23 held-out validation cases (the roughly 20% chosen at the start of the fold). It then calculates the mean validation Dice score, which gives us an indication of how the trained model might perform on the unseen test dataset. The per-fold validation results are also written to a summary.json file; see the sketch after the log excerpt below.
2023-06-22 01:51:10.454137: This split has 94 training and 23 validation cases.
2023-06-22 01:51:10.469486: predicting 102
2023-06-22 01:51:14.112995: predicting 112
....
....
2023-06-22 01:51:57.391516: predicting 9
2023-06-22 01:51:59.424051: predicting 91
2023-06-22 01:52:10.922611: Validation complete
2023-06-22 01:52:10.944788: Mean Validation Dice: 0.8029053811307724
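The same numbers can be pulled from the fold's validation summary file. A sketch, assuming the folder layout and key names of the current nnUNet v2 release (inspect the file first if your version differs):
# Sketch: read the per-fold validation summary written by nnUNet.
# The validation folder name and key names are assumptions based on the current
# nnUNet v2 layout; inspect the file contents if your version differs.
import json

fold = 0  # fold to inspect (0-4)
summary_path = (
    "/content/drive/MyDrive/TCIA/nnUNet/nnUNet_results/Dataset501_Glioblastoma/"
    f"nnUNetTrainer__nnUNetPlans__3d_fullres/fold_{fold}/validation/summary.json"
)

with open(summary_path) as f:
    summary = json.load(f)

print("Mean foreground Dice:", summary["foreground_mean"]["Dice"])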
It’s been a great journey so far. In our next post, we will find the best model configuration and run inference on the unseen test dataset to see how our model will perform.