Explainability for a 3D ResNet Classifier
Author(s): Shashwat (Shawn) Gupta
Originally published on Towards AI.
GradCAM is one of the simplest techniques for gaining explainability insights into a model's predictions. I was surprised to find that while there are many blogs on Medium about using GradCAM with ResNet, there aren't any specifically for GradCAM with 3D images (e.g., for ResNet3D), and almost none in PyTorch. Furthermore, most GitHub repositories, adapted from 2D GradCAM code, implement GradCAM incorrectly for the 3D case. Determined to fill this gap, I spent an entire night understanding the intricate details of the code and successfully wrote my own implementation.
What This Code Does
This code builds a ResNet3D model from scratch, a 3D version of the popular ResNet (Residual Network) used for image recognition. It incorporates GradCAM (Gradient-weighted Class Activation Mapping), a technique that helps visualize which parts of the input data the model focuses on when making decisions. The model processes NIfTI files (a common format for medical imaging data) listed in train.txt and test.txt. Instead of performing image segmentation, we modify the model to do classification by replacing the segmentation head with a feedforward network (ffn) initialised using Xavier initialisation. Initially, only the new layers are trained while the pretrained weights stay frozen; after a few training cycles (epochs), the entire network is fine-tuned. To speed up training, the code utilizes multiple GPUs through data parallelism, allowing the model to use all available GPUs efficiently. The train.txt, test.txt, and gradcam.txt files should contain the paths to the .nii.gz files and their corresponding class labels, separated by a space. For example:
./file1/a.nii.gz 1
./file2/b.nii.gz 0
...
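If you prefer to generate these list files programmatically, a minimal sketch is below; the paths and labels are placeholders, not files from this project:

# Sketch: write a list file of (path, label) pairs in the expected format
samples = [("./file1/a.nii.gz", 1), ("./file2/b.nii.gz", 0)]  # placeholder pairs

with open("./data/train.txt", "w") as f:
    for path, label in samples:
        f.write(f"{path} {label}\n")  # one "path label" pair per line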
Imports and Parameters
import torch
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
import nibabel
from scipy import ndimage
import time
from scipy.ndimage import zoom
import warnings
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial
# Parameters (normally passed on the command line; hardcoded here)
n_epochs = 700
epoch_unfreeze_all = 300
data_root = './data'
train_img_list = './data/train.txt'
test_img_list = './data/test.txt'
manual_seed = 1
num_classes = 2 # Updated for classification
learning_rate = 0.001
num_workers = 4
batch_size = 1
save_intervals = 30
input_D = 56
input_H = 448
input_W = 448
resume_path = '' # Resume from this if it's a file
model_depth = 10 # 10 | 18 | 34 | 50 | 101 | 152 | 200
pretrain_path = f'pretrain/resnet_{model_depth}.pth'
new_layer_names = ['fc'] # Updated to 'fc' for classification head
gpu_id = [i for i in range(num_gpus)]
model = 'resnet'
resnet_shortcut = 'B' # 'A' = identity shortcut with zero-padding | 'B' = 1x1 conv projection shortcut
save_folder = "./trails/models/{}_{}".format(model, model_depth)
test_batch_size = 1
test_num_workers = 4
no_cuda = not torch.cuda.is_available()
if not no_cuda and torch.cuda.device_count() > 0:
pin_memory = True
test_pin_memory = True # Set to True if using GPU
print(f"Using GPU(s). Number of GPUs available: {torch.cuda.device_count()}")
else:
pin_memory = False
test_pin_memory = False
print("No GPU available, using CPU.")
Model Description
The ResNet (Residual Network) is a type of neural network that uses residual blocks, whose skip connections let gradients flow through very deep networks and make them substantially easier to train. In this implementation, we define the different blocks and layers needed to build a 3D ResNet tailored for classification tasks.
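Concretely, each residual block computes its output as the learned residual plus an identity shortcut, as in the original ResNet paper:

y = \mathcal{F}(x, \{W_i\}) + x

so the stacked convolutions only need to learn the correction \mathcal{F}(x) rather than the full mapping.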
def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
# 3x3x3 convolution with padding
return nn.Conv3d(
in_planes,
out_planes,
kernel_size=3,
dilation=dilation,
stride=stride,
padding=dilation,
bias=False)
def downsample_basic_block(x, planes, stride, no_cuda=False):
# Option 'A' shortcut: strided average pooling, then zero-padding of channels
out = F.avg_pool3d(x, kernel_size=1, stride=stride)
zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2),
out.size(3), out.size(4), dtype=out.dtype, device=out.device)
# torch.cat keeps gradients flowing (the old Variable/.data idiom is deprecated)
out = torch.cat([out, zero_pads], dim=1)
return out
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
self.bn1 = nn.BatchNorm3d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
self.bn2 = nn.BatchNorm3d(planes)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes)
self.conv2 = nn.Conv3d(
planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
self.bn2 = nn.BatchNorm3d(planes)
self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm3d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self,
block,
layers,
sample_input_D,
sample_input_H,
sample_input_W,
num_classes, # Changed from num_seg_classes to num_classes
shortcut_type='B',
no_cuda=False):
super(ResNet, self).__init__()
self.inplanes = 64
self.no_cuda = no_cuda
self.conv1 = nn.Conv3d(
1,
64,
kernel_size=7,
stride=(2, 2, 2),
padding=(3, 3, 3),
bias=False)
self.bn1 = nn.BatchNorm3d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
self.layer2 = self._make_layer(
block, 128, layers[1], shortcut_type, stride=2)
self.layer3 = self._make_layer(
block, 256, layers[2], shortcut_type, stride=1, dilation=2)
self.layer4 = self._make_layer(
block, 512, layers[3], shortcut_type, stride=1, dilation=4)
# placeholder for the gradients
self.gradients = None
# Remove or comment out the segmentation head
# self.conv_seg = nn.Sequential(
# ...
# )
# Add a classification head
self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes) # Binary classification (2 classes)
# Initialize weights for new layers
self._initialize_weights()
def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
if shortcut_type == 'A':
downsample = partial(
downsample_basic_block,
planes=planes * block.expansion,
stride=stride,
no_cuda=self.no_cuda)
else:
downsample = nn.Sequential(
nn.Conv3d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm3d(planes * block.expansion))
layers = []
layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation))
return nn.Sequential(*layers)
def _initialize_weights(self):
# Initialize weights for the new classification head using Xavier initialization
nn.init.xavier_normal_(self.fc.weight)
if self.fc.bias is not None:
nn.init.constant_(self.fc.bias, 0)
def forward(self, x, reg_hook=True):
# Feature extraction
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# register hook (needed for grad-cam)
if reg_hook:
x.register_hook(self.activations_hook)
# Classification head
x = self.global_pool(x) # [N, 512*expansion, 1, 1, 1]
x = x.view(x.size(0), -1) # [N, 512*expansion]
x = self.fc(x) # [N, num_classes]
return x
# hook for the gradients of the activations
def activations_hook(self, grad):
self.gradients = grad
def get_activations_gradient(self):
return self.gradients
def get_activations(self, x):
"""
This is the feature extractor that returns the activations
for the Grad-CAM method.
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
# Get activations from the last convolutional layer
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x # Return the activations from the last convolutional layer
def resnet10_classification(**kwargs):
"""Constructs a ResNet-18 model.
"""
model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
return model
def resnet18_classification(**kwargs):
"""Constructs a ResNet-18 model.
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
return model
def resnet34_classification(**kwargs):
"""Constructs a ResNet-34 model.
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
return model
def resnet50_classification(**kwargs):
"""Constructs a ResNet-50 model.
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
return model
def resnet101_classification(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
return model
def resnet152_classification(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
return model
def resnet200_classification(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
return model
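Before wiring these into the training pipeline, a quick sanity check helps. The following is a minimal sketch of my own (not part of the original training script), using a deliberately small random volume so it runs on CPU:

# Sketch: smoke-test the classification model on a random volume
model_test = resnet10_classification(
    sample_input_D=56, sample_input_H=112, sample_input_W=112,
    num_classes=2, shortcut_type='B', no_cuda=True)
model_test.eval()
with torch.no_grad():
    dummy = torch.randn(1, 1, 56, 112, 112)     # [N, C, D, H, W]
    logits = model_test(dummy, reg_hook=False)  # skip the Grad-CAM hook here
print(logits.shape)                             # torch.Size([1, 2])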
[Edit] We use register_hook, which registers a backward hook; the hook is called during the backward() pass and receives the gradients of the layer it was registered on. We obtain the corresponding activations by factoring the part of forward() before the hook registration into a separate get_activations function. A more sophisticated (but less straightforward) approach uses forward and backward hooks, which works with any loaded model. A 2D example can be found here: https://towardsdatascience.com/grad-cam-in-pytorch-use-of-forward-and-backward-hooks-7eba5e38d569
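For completeness, here is a minimal sketch of that hook-based alternative. It is an illustration under assumptions: net is an unwrapped ResNet instance as defined above, and volume is a single preprocessed [1, 1, D, H, W] tensor:

# Sketch: Grad-CAM capture via forward/backward hooks on layer4
captured = {}

def fwd_hook(module, inputs, output):
    captured['activations'] = output.detach()

def bwd_hook(module, grad_input, grad_output):
    captured['gradients'] = grad_output[0].detach()

h1 = net.layer4.register_forward_hook(fwd_hook)
h2 = net.layer4.register_full_backward_hook(bwd_hook)

logits = net(volume, reg_hook=False)         # volume: [1, 1, D, H, W]
logits[0, logits.argmax(dim=1)].backward()   # backprop the predicted-class score
h1.remove(); h2.remove()                     # clean up the hooks afterwards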
Setting Parameters and Training
In this section, we set up parameters such as the number of training epochs, learning rates, and paths to the data and pretrained weights. We also define the optimizer and learning-rate scheduler that control how the model learns over time.
# main.py
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
import nibabel
from scipy import ndimage
import time
from scipy.ndimage import zoom
import warnings
from sklearn.metrics import (accuracy_score, balanced_accuracy_score, recall_score,
precision_score, f1_score, matthews_corrcoef,
roc_auc_score, confusion_matrix)
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Setting Seed
torch.manual_seed(manual_seed)
# Check if the depth is valid
assert model_depth in [10, 18, 34, 50, 101, 152, 200], "Invalid depth"
# Loading the model
model_parameters = {
'sample_input_W': input_W,
'sample_input_H': input_H,
'sample_input_D': input_D,
'shortcut_type': resnet_shortcut,
'no_cuda': no_cuda,
'num_classes': num_classes # Updated parameter
}
# Initialize the appropriate ResNet model for classification
if model_depth == 10:
model = resnet10_classification(**model_parameters)
elif model_depth == 18:
model = resnet18_classification(**model_parameters)
elif model_depth == 34:
model = resnet34_classification(**model_parameters)
elif model_depth == 50:
model = resnet50_classification(**model_parameters)
elif model_depth == 101:
model = resnet101_classification(**model_parameters)
elif model_depth == 152:
model = resnet152_classification(**model_parameters)
elif model_depth == 200:
model = resnet200_classification(**model_parameters)
# Move model to GPU and handle DataParallel if necessary
if not no_cuda:
if len(gpu_id) > 1:
model = model.cuda()
model = nn.DataParallel(model, device_ids=gpu_id)
else:
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id[0])
model = model.cuda()
model = nn.DataParallel(model, device_ids=None)
else:
pass # Handle CPU if necessary
# Load pretrained weights
print(f'Loading pretrained model from {pretrain_path}')
pretrain = torch.load(pretrain_path)
# Exclude the classification head from pretrained weights
pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in model.state_dict().keys() and 'fc' not in k}
# Update the model's state dict with pretrained weights
model.load_state_dict(pretrain_dict, strict=False)
# Initialize the new classification head with Xavier initialization
def initialize_new_layers(model):
for name, module in model.named_modules():
if 'fc' in name:
if isinstance(module, nn.Linear):
nn.init.xavier_normal_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
initialize_new_layers(model)
# Freeze the pretrained layers except 'fc'
for name, param in model.named_parameters():
if 'fc' not in name:
param.requires_grad = False
# Define optimizer with classification head parameters
new_parameters = [p for name, p in model.named_parameters() if 'fc' in name]
base_parameters = [p for name, p in model.named_parameters() if 'fc' not in name]
parameters = {'base_parameters': base_parameters, 'new_parameters': new_parameters}
params = [
{'params': parameters['base_parameters'], 'lr': learning_rate},
{'params': parameters['new_parameters'], 'lr': learning_rate * 10} # Higher LR for new layers
]
# Define optimizer for Phase 1: Train only 'fc' layers
optimizer = torch.optim.SGD([
{'params': new_parameters, 'lr': learning_rate * 10} # Higher LR for 'fc'
], momentum=0.9, weight_decay=1e-3)
# Initialize scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
# Resume from checkpoint if needed
if os.path.isfile(resume_path):
print(f"=> loading checkpoint '{resume_path}'")
checkpoint = torch.load(resume_path)
model.load_state_dict(checkpoint['state_dict'], strict=False)
optimizer.load_state_dict(checkpoint['optimizer'])
print(f"=> loaded checkpoint '{resume_path}' (epoch {checkpoint.get('epoch', 'Unknown')})")
Creating and Loading the Dataset
We define a Dataset class to handle loading and preprocessing the NIfTI files. The data is then loaded using a DataLoader, which helps in batching and shuffling the data during training and testing.
class MyDataset(Dataset):
def __init__(self, root_dir, img_list, sets):
with open(img_list, 'r') as f:
self.img_list = [line.strip() for line in f]
print(f"Processing {len(self.img_list)} samples")
self.root_dir = root_dir
self.input_D = sets.input_D
self.input_H = sets.input_H
self.input_W = sets.input_W
self.phase = sets.phase
def __nii2tensorarray__(self, data):
[z, y, x] = data.shape
new_data = np.reshape(data, [1, z, y, x]).astype("float32")
return new_data
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
# Each line: "relative_path.nii 0" or "relative_path.nii 1"
line_info = self.img_list[idx].split(" ")
img_path = os.path.join(self.root_dir, line_info[0])
img_nii = nibabel.load(img_path)
img_array = img_nii.get_fdata()
# Read label
if self.phase in ["train", "val", "test"]:
label = int(line_info[1]) # 0 or 1
else:
label = 0 # Placeholder if labels are not available
# Preprocess the image: resize, normalize
img_array = self.__resize_data__(img_array)
img_array = self.__intensity_normalize_one_volume__(img_array)
# Convert to tensor
img_tensor = self.__nii2tensorarray__(img_array)
# Convert to torch tensors
img_tensor = torch.from_numpy(img_tensor)
label_tensor = torch.tensor(label).long()
return img_tensor, label_tensor
def __resize_data__(self, data):
""" Resize the data to the desired input shape """
[depth, height, width] = data.shape
scale = [
self.input_D / depth,
self.input_H / height,
self.input_W / width
]
data = ndimage.zoom(data, scale, order=1) # Changed order to 1 for smoother resizing
return data
def __intensity_normalize_one_volume__(self, volume):
"""
Normalize the intensity of the volume.
Set background (zeros) to random noise.
"""
pixels = volume[volume > 0]
if len(pixels) == 0:
mean = 0
std = 1
else:
mean = pixels.mean()
std = pixels.std()
out = (volume - mean) / std
out_random = np.random.normal(0, 1, size=volume.shape)
out[volume == 0] = out_random[volume == 0]
return out
# Fetching data
phase = 'train'
pin_memory = True
class Settings:
def __init__(self, input_D, input_H, input_W, phase, no_cuda=False):
self.input_D = input_D
self.input_H = input_H
self.input_W = input_W
self.phase = phase
self.no_cuda = no_cuda
train_sets = Settings(input_D, input_H, input_W, phase='train', no_cuda=no_cuda)
training_dataset = MyDataset(data_root, train_img_list, train_sets)
train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
print(len(training_dataset))
test_sets = Settings(input_D, input_H, input_W, phase='test', no_cuda=no_cuda)
testing_dataset = MyDataset(data_root, test_img_list, test_sets)
test_loader = DataLoader(testing_dataset, batch_size=test_batch_size, shuffle=False, num_workers=test_num_workers, pin_memory=test_pin_memory)
print(len(testing_dataset))
Training the Model
The training function handles feeding data to the model, computing the loss, updating the weights, and periodically saving the model's state. It also evaluates the model's performance on the test set after each epoch.
import matplotlib.pyplot as plt
def train(data_loader, model, optimizer, scheduler, total_epochs, start_second_phase, save_interval, save_folder, no_cuda):
# Define the loss function
loss_fn = nn.CrossEntropyLoss()
if not no_cuda:
loss_fn = loss_fn.cuda()
# Initialize lists to store metrics
train_losses = []
test_losses = []
train_balanced_accuracies = []
test_balanced_accuracies = []
train_sensitivities = []
test_sensitivities = []
print(f'{total_epochs} epochs in total.')
for epoch in range(total_epochs):
model.train()
# The LR scheduler is stepped at the end of each epoch, after the optimizer updates
print(f'Start epoch {epoch + 1}/{total_epochs}')
print(f'Learning rate: {scheduler.get_last_lr()}')
all_labels_train = []
all_preds_train = []
all_probs_train = []
epoch_loss = 0.0
for batch_id, batch_data in enumerate(data_loader):
volumes, labels = batch_data
if not no_cuda:
volumes = volumes.cuda()
labels = labels.cuda()
optimizer.zero_grad()
outputs = model(x=volumes,reg_hook=False) # [N, 2]
# Compute loss
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# Predictions
_, preds = torch.max(outputs, 1) # Predicted class indices
probs = torch.softmax(outputs, dim=1)[:, 1] # Probability of class 1
# Move to CPU and convert to numpy
preds = preds.detach().cpu().numpy()
labels_np = labels.detach().cpu().numpy()
probs = probs.detach().cpu().numpy()
# Accumulate for metrics
all_labels_train.extend(labels_np)
all_preds_train.extend(preds)
all_probs_train.extend(probs)
if (batch_id + 1) % 10 == 0 or (batch_id + 1) == len(data_loader):
print(f'Epoch [{epoch + 1}/{total_epochs}], Step [{batch_id + 1}/{len(data_loader)}], Loss: {loss.item():.4f}')
# Save model checkpoint
global_batch_id = epoch * len(data_loader) + batch_id
if ((epoch+1)%10 == 0 or epoch==0) and ((global_batch_id + 1) % save_interval == 0 or (global_batch_id + 1) >= len(training_dataset)):
model_save_path = os.path.join(save_folder, f'model_epoch_{epoch + 1}_batch_{batch_id + 1}.pth.tar')
os.makedirs(save_folder, exist_ok=True)
torch.save({
'epoch': epoch + 1,
'batch_id': batch_id + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()},
model_save_path)
print(f'Saved checkpoint: {model_save_path}')
# Compute and store average training loss
avg_train_loss = epoch_loss / len(data_loader)
train_losses.append(avg_train_loss)
# Compute training balanced accuracy
train_balanced_acc = balanced_accuracy_score(all_labels_train, all_preds_train)
train_balanced_accuracies.append(train_balanced_acc)
# Compute training sensitivity (Recall)
train_sensitivity = recall_score(all_labels_train, all_preds_train, pos_label=1)
train_sensitivities.append(train_sensitivity)
# Evaluate on Test Set
model.eval()
all_labels_test = []
all_preds_test = []
all_probs_test = []
test_epoch_loss = 0.0
with torch.no_grad():
for test_volumes, test_labels in test_loader:
if not no_cuda:
test_volumes = test_volumes.cuda()
test_labels = test_labels.cuda()
test_outputs = model(x=test_volumes,reg_hook=False) # [N, 2]
test_loss = loss_fn(test_outputs, test_labels)
test_epoch_loss += test_loss.item()
# Predictions
_, test_preds = torch.max(test_outputs, 1)
test_probs = torch.softmax(test_outputs, dim=1)[:, 1]
# Move to CPU and convert to numpy
test_preds = test_preds.detach().cpu().numpy()
test_labels_np = test_labels.detach().cpu().numpy()
test_probs = test_probs.detach().cpu().numpy()
# Accumulate for test metrics
all_labels_test.extend(test_labels_np)
all_preds_test.extend(test_preds)
all_probs_test.extend(test_probs)
# Compute and store average test loss
avg_test_loss = test_epoch_loss / len(test_loader)
test_losses.append(avg_test_loss)
# Compute test balanced accuracy
test_balanced_acc = balanced_accuracy_score(all_labels_test, all_preds_test)
test_balanced_accuracies.append(test_balanced_acc)
# Compute test sensitivity (Recall)
test_sensitivity = recall_score(all_labels_test, all_preds_test, pos_label=1)
test_sensitivities.append(test_sensitivity)
# Compute metrics at the end of the epoch
accuracy = accuracy_score(all_labels_train, all_preds_train)
balanced_acc = balanced_accuracy_score(all_labels_train, all_preds_train)
sensitivity = recall_score(all_labels_train, all_preds_train, pos_label=1)
precision = precision_score(all_labels_train, all_preds_train, pos_label=1)
f1 = f1_score(all_labels_train, all_preds_train, pos_label=1)
mcc = matthews_corrcoef(all_labels_train, all_preds_train)
# Compute specificity
cm = confusion_matrix(all_labels_train, all_preds_train)
if cm.shape == (2, 2):
tn, fp, fn, tp = cm.ravel()
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
else:
specificity = 0 # Undefined if not binary
# Compute AUC
try:
auc = roc_auc_score(all_labels_train, all_probs_train)
except ValueError:
auc = 0.0 # Undefined if only one class is present
# Display metrics
print(f'--- Epoch {epoch + 1} Metrics ---')
print(f'Train Loss: {epoch_loss / len(data_loader):.4f}')
print(f'Train Accuracy: {accuracy:.4f}')
print(f'Train Balanced Accuracy: {balanced_acc:.4f}')
print(f'Train Sensitivity (Recall): {sensitivity:.4f}')
print(f'Train Specificity: {specificity:.4f}')
print(f'Train Precision (PPV): {precision:.4f}')
print(f'Train F1 Score: {f1:.4f}')
print(f'Train Matthews Correlation Coefficient (MCC): {mcc:.4f}')
print(f'Train AUC: {auc:.4f}')
print(f'Test Loss: {avg_test_loss:.4f}')
print(f'Test BA: {test_balanced_acc:.4f}')
print(f'Test Sensitivity: {test_sensitivity:.4f}')
print('-------------------------\n')
# Step the learning-rate scheduler once per epoch, after the optimizer updates
scheduler.step()
# Phase Transition: after start_second_phase epochs, unfreeze base parameters and fine-tune all layers
if epoch + 1 == start_second_phase:
print('--- Transitioning to Phase 2: Fine-tuning all layers ---')
# Unfreeze base parameters
for param in base_parameters:
param.requires_grad = True
# Redefine the optimizer to include all parameters with appropriate learning rates
optimizer = torch.optim.SGD([
{'params': base_parameters, 'lr': learning_rate}, # Base layers
{'params': new_parameters, 'lr': learning_rate * 10} # 'fc' layer
], momentum=0.9, weight_decay=1e-3)
# Optionally, redefine the scheduler if necessary
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
print('Phase 2 initialized: All layers are now trainable.')
print("Finished Training")
# Plotting the metrics
epochs = range(1, total_epochs + 1)
# 1. Plot Training Loss and Test Loss vs. Epoch
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Training Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss vs. Epoch')
plt.axvline(x=start_second_phase, color='green', linestyle='dotted', linewidth=2, label='start of 2nd phase')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(save_folder, 'loss_vs_epoch.png'))
# 2. Plot Test Balanced Accuracy vs. Epoch
plt.figure(figsize=(10, 6))
plt.plot(epochs, test_balanced_accuracies, label='Test Balanced Accuracy', color='green')
plt.plot(epochs, train_balanced_accuracies, label='Train Balanced Accuracy', color='blue')
plt.xlabel('Epoch')
plt.ylabel('Balanced Accuracy')
plt.title('Test Balanced Accuracy vs. Epoch')
plt.axvline(x=start_second_phase, color='green', linestyle='dotted', linewidth=2, label='start of 2nd phase')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(save_folder, 'balanced_accuracy_vs_epoch.png'))
# 3. Plot Sensitivity (Recall) for Test and Training vs. Epoch
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_sensitivities, label='Train Sensitivity (Recall)', color='orange')
plt.plot(epochs, test_sensitivities, label='Test Sensitivity (Recall)', color='red')
plt.xlabel('Epoch')
plt.ylabel('Sensitivity (Recall)')
plt.title('Sensitivity (Recall) vs. Epoch')
plt.axvline(x=start_second_phase, color='green', linestyle='dotted', linewidth=2, label='start of 2nd phase')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(save_folder, 'sensitivity_vs_epoch.png'))
print('Starting training...')
train(train_loader, model, optimizer, scheduler, total_epochs=n_epochs, start_second_phase=epoch_unfreeze_all, save_interval=save_intervals, save_folder=save_folder, no_cuda=no_cuda)
print('Finished training')
Implementing GradCAM
GradCAM helps in visualizing which parts of the input data are most influential in the model's decision-making process. This section of the code generates GradCAM heatmaps and overlays them on the original NIfTI images for better visualization.
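The code below follows the standard Grad-CAM recipe from the Selvaraju et al. paper cited in the references, extended from 2D to 3D feature maps A^k of size D x H x W:

\alpha_k^c = \frac{1}{DHW}\sum_{d}\sum_{h}\sum_{w} \frac{\partial y^c}{\partial A^k_{dhw}},
\qquad
L^c_{\mathrm{GradCAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k^c \, A^k\Big)

where y^c is the score for class c and \alpha_k^c is the per-channel weight obtained by global-average-pooling the gradients. The heatmap is then normalized to [0, 1] and upsampled to the input resolution.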
# GradCAM
import torch
import numpy as np
import cv2
import nibabel as nib
import torch.nn.functional as F
import os
from math import ceil
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
def generate_gradcam_heatmap(net, img, original_shape=None):
"""Generate a 3D Grad-CAM heatmap for the input 3D image."""
net.eval() # Set the model to evaluation mode
if isinstance(net, torch.nn.DataParallel):
net = net.module
# Perform forward pass and gradient computation on GPU
img = img.cuda(non_blocking=True)
pred = net(x=img, reg_hook=True)
# Backpropagate the score of the predicted class; gather() keeps this correct for any batch size
score = pred.gather(1, pred.argmax(dim=1, keepdim=True)).sum()
score.backward()
# Retrieve gradients and activations
gradients = net.get_activations_gradient().cpu() # Move to CPU
activations = net.get_activations(img).cpu() # Move to CPU
img = img.cpu() # Free GPU memory as no longer needed
# Pool gradients
pooled_gradients = gradients.mean(dim=(2, 3, 4), keepdim=True)
# Weight activations
weighted_activations = activations * pooled_gradients
heatmap = weighted_activations.mean(dim=1) # Average over channels
# Apply ReLU and normalize
heatmap = F.relu(heatmap)
heatmap -= heatmap.min()
heatmap /= heatmap.max() + 1e-6 # Avoid divide-by-zero
return heatmap
def visualize_and_save_gradcam(net, data_loader, num_images=10, output_dir="gradcam_outputs", batch_size=1, opacity=0.4):
"""
Visualize and save Grad-CAM heatmaps as 3D volumes.
Args:
net: The trained network (e.g., your 3D ResNet).
data_loader: DataLoader for the validation/test set.
num_images: Number of images to process.
output_dir: Directory to save the Grad-CAM outputs.
batch_size: Batch size used by the data loader.
opacity: Blending weight of the heatmap in the saved overlay.
"""
# Make sure the output directory exists
if not os.path.exists(output_dir):
os.makedirs(output_dir)
it = iter(data_loader)
num_batches = ceil(num_images / batch_size)
# Size of the final (possibly partial) batch; falls back to batch_size when it divides evenly
num_final_images = num_images % batch_size or batch_size
print('num_final_images', num_final_images)
print('num_batches', num_batches)
# Define the SpiceJet color map
spicejet_colors = [(0, 0, 1), (0, 1, 1), (1, 1, 0), (1, 0, 0)] # Blue -> Cyan -> Yellow -> Red
spicejet_cmap = LinearSegmentedColormap.from_list('SpiceJet', spicejet_colors)
for i in range(num_batches):
img, label = next(it)
img = img.cuda() # Move to GPU if needed
# Get the original shape of the image
original_shape = img.shape # (batch_size, channels, D, H, W)
# Generate Grad-CAM heatmap for the 3D image
heatmap_3d = generate_gradcam_heatmap(net, img, original_shape=original_shape)
if i + 1 == num_batches: # last (possibly partial) batch
internal_iter = num_final_images
else:
internal_iter = batch_size
for j in range(internal_iter):
temp_heatmap = heatmap_3d[j]
temp_heatmap = temp_heatmap.unsqueeze(0).unsqueeze(0)
img_nii = nibabel.load(files[i*batch_size + j][0])
img_array = img_nii.get_fdata()
target_size = tuple(img_array.shape)
heatmap = F.interpolate(temp_heatmap, size=target_size, mode='trilinear', align_corners=False)
heatmap = heatmap.squeeze(0).squeeze(0).detach().cpu().numpy()
# Create a NIfTI image and save it
output_filename = os.path.join(output_dir, f"gradcam_{os.path.basename(files[i*batch_size + j][0])}")
img_nii2 = nib.Nifti1Image(heatmap, img_nii.affine)
nib.save(img_nii2, output_filename)
print(f"Saved {output_filename}")
# Saving overlay to view in NiiVue:
# Load original image
original_img = img_nii
original_data = original_img.get_fdata()
# Normalize the original image to 0-255 range for RGB
original_normalized = ((original_data - np.min(original_data)) * 255 /
(np.max(original_data) - np.min(original_data))).astype(np.uint8)
# Create an RGB image from the original data
rgb_image = np.zeros((*original_data.shape, 3), dtype=np.uint8)
for c in range(3): # use a separate index so the outer batch index `i` is not clobbered
rgb_image[..., c] = original_normalized
# Load Grad-CAM mask
gradcam_img = img_nii2
gradcam_data = gradcam_img.get_fdata()
# Normalize Grad-CAM values to 0-1 range
gradcam_normalized = (gradcam_data - np.min(gradcam_data)) / (np.max(gradcam_data) - np.min(gradcam_data))
# Map normalized Grad-CAM values to the SpiceJet colormap
gradcam_colored = spicejet_cmap(gradcam_normalized)
# Overlay Grad-CAM colors onto the original image with opacity
for c in range(3): # R, G, B channels (again, do not reuse the batch index `i`)
rgb_image[..., c] = np.clip(
(1 - opacity) * rgb_image[..., c] + opacity * (gradcam_colored[..., c] * 255),
0, 255).astype(np.uint8)
# Save the overlay as a new NIfTI file
shape_3d = rgb_image.shape[:3]
rgb_dtype = np.dtype([('R', 'u1'), ('G', 'u1'), ('B', 'u1')])
ras_pos = rgb_image.copy().view(dtype=rgb_dtype).reshape(shape_3d)
overlay_img = nib.Nifti1Image(ras_pos, original_img.affine)
overlay_path = os.path.join(output_dir, f"ogradcam_{os.path.basename(files[i*batch_size + j][0])}")
nib.save(overlay_img, overlay_path)
print(f"Overlay saved as {overlay_path}")
# Build the list of images for which we want to compute GradCAM
files = [
('./data/A.nii.gz', 1),
('./data/B.nii.gz', 0),
('./data/D.nii.gz', 1),
]
gradcam_img_list = './data/gradcam.txt'
# Join the file names and labels with newline separator
gradcam_img_list_content = "\n".join([f[0] + " " + str(f[1]) for f in files])
# Write the content to the file
with open(gradcam_img_list, 'w') as file:
file.write(gradcam_img_list_content)
print("File written successfully!")
# Load model from checkpoint
checkpoint_path = "./trails/models/atestbtrain_resnet_10/model_epoch_500_batch_85.pth.tar"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['state_dict'])
# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
gradcam_sets = Settings(input_D, input_H, input_W, phase='test', no_cuda=no_cuda)
gradcam_dataset = MyDataset(data_root, gradcam_img_list, gradcam_sets)
grad_validationloader = DataLoader(gradcam_dataset, batch_size=test_batch_size, shuffle=False, num_workers=test_num_workers, pin_memory=test_pin_memory)
opacity = 0.4
visualize_and_save_gradcam(model, grad_validationloader, num_images=len(files), output_dir="gradcam_outputs", batch_size=batch_size, opacity=opacity)
Visualizing the Results
The GradCAM process generates .nii.gz files that visualize the areas of the original NIfTI images the model focused on. These visualizations can be viewed using NiiVue in VSCode, allowing you to see 3D images with highlighted regions indicating the model's attention.
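If you only need a quick 2D preview without NiiVue, a short matplotlib sketch like the following works; the filename here is an assumption that follows the gradcam_ naming scheme used above:

import nibabel as nib
import matplotlib.pyplot as plt

# Load one saved heatmap (hypothetical filename following the gradcam_ prefix)
heatmap = nib.load("gradcam_outputs/gradcam_A.nii.gz").get_fdata()
mid = heatmap.shape[0] // 2          # middle slice along the first axis
plt.imshow(heatmap[mid], cmap='jet')
plt.colorbar(label='Grad-CAM intensity')
plt.title('Middle slice of the 3D Grad-CAM heatmap')
plt.savefig('gradcam_middle_slice.png')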
Conclusion
I hope this guide helps those who, like me, were struggling to find resources on GradCAM with ResNet3D. Implementing these techniques can greatly enhance your understanding of how models make decisions, especially in complex 3D data scenarios. Feel free to leave comments or reach out if any part needs more explanation, and I'll update the blog accordingly.
References
These references provide foundational knowledge and resources related to ResNet3D, GradCAM, data handling with NIfTI files, PyTorch functionalities, and visualization tools used in the implementation. They should help deepen your understanding and offer additional guidance as you work with similar projects.
- ResNet Paper: Deep Residual Learning for Image Recognition : https://arxiv.org/abs/1512.03385
- Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization : https://arxiv.org/abs/1610.02391
- Nibabel: Accessing the data stored in common neuroimaging file formats : https://nipy.org/nibabel/
- NiiVue: NIfTI Viewer for VSCode by João Moreno : https://marketplace.visualstudio.com/items?itemName=joaomoreno.NiiVue
- PyTorch Documentation : https://pytorch.org/docs/stable/index.html (consider learning PyTorch Lightning if you are new to PyTorch, as it simplifies many workflows)
- Scipy.ndimage: n-dimensional image processing : https://docs.scipy.org/doc/scipy/reference/ndimage.html
- Data Parallelism in PyTorch: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
- Understanding Xavier Initialization: https://pytorch.org/docs/stable/generated/torch.nn.init.xavier_normal_.html
- 3D ResNet implementation and pretrained weights: https://github.com/kenshohara/3D-ResNets-PyTorch
- NIfTI File Format Specification : https://nifti.nimh.nih.gov/nifti-1
- Matplotlib Documentation: https://matplotlib.org/stable/contents.html
- SciPy Documentation: https://www.scipy.org/docs.html
- Torchvision Models: https://pytorch.org/vision/stable/models.html
- Automatic Mixed Precision (AMP) in PyTorch: https://pytorch.org/docs/stable/amp.html
- Balanced Accuracy Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html
- Matthews Correlation Coefficient (MCC): https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
- ROC AUC Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
- NiiVue VSCode Extension Documentation: https://github.com/joaomoreno/vscode-nii-vue
- 3D Convolutional Neural Networks: https://towardsdatascience.com/a-comprehensive-introduction-to-different-types-of-convolutions-in-deep-learning-669281e58215
- Exponential Learning Rate Scheduler in PyTorch: https://pytorch.org/docs/stable/optim.html#exponentiallr