Explainability for 3DResNet Classifier

Last Updated on January 24, 2025 by Editorial Team

Author(s): Shashwat (Shawn) Gupta

Originally published on Towards AI.

GradCAM is one of the simplest techniques for getting explainability insights into a model's predictions. I was surprised to find that while there are many blogs on Medium about using GradCAM with ResNet, there are hardly any specifically for GradCAM with 3D images (e.g., for ResNet3D), and almost none in PyTorch. Furthermore, most GitHub implementations, inspired by 2D GradCAM, implement GradCAM incorrectly. Determined to fill this gap, I spent an entire night understanding the intricate details of the code and successfully wrote my own implementation.

2D explainability by GradCAM. Source: author; X-ray image from the Kaggle Chest X-ray dataset

What This Code Does

This code builds a ResNet3D model from scratch: a 3D version of the popular ResNet (Residual Network) used for image recognition. It incorporates GradCAM (Gradient-weighted Class Activation Mapping), a technique that visualizes which parts of the input data the model focuses on when making decisions. The model processes NIfTI files (a common format for medical imaging data) listed in train.txt and test.txt. Instead of performing image segmentation, we modify the model for classification by replacing the segmentation head with a feed-forward layer initialised with Xavier initialisation. Initially, only the new layers are trained while the pretrained weights stay frozen; after a set number of epochs, the entire network is fine-tuned. To speed up training, the code uses all available GPUs (Graphics Processing Units) through data parallelism. The train.txt, test.txt, and gradcam.txt files should contain the paths to the .nii.gz files and their corresponding class labels, separated by a space. For example:

./file1/a.nii.gz 1
./file2/b.nii.gz 0
…
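For reference, here is a minimal sketch of how one such manifest line is parsed; the path and label are illustrative, and the MyDataset class defined later does the equivalent split:

line = "./file1/a.nii.gz 1"           # illustrative manifest line
rel_path, label = line.strip().rsplit(" ", 1)
print(rel_path, int(label))           # -> ./file1/a.nii.gz 1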

Imports and Parameters

import os
import math
import time
import warnings
from functools import partial

import numpy as np
import nibabel
from scipy import ndimage
from scipy.ndimage import zoom

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")
# Parameters (passed on the command line in the original setup)
n_epochs = 700
epoch_unfreeze_all = 300          # Epoch at which all layers are unfrozen (Phase 2)
data_root = './data'
train_img_list = './data/train.txt'
test_img_list = './data/test.txt'
manual_seed = 1
num_classes = 2                   # Updated for classification
learning_rate = 0.001
num_workers = 4
batch_size = 1
save_intervals = 30
input_D = 56
input_H = 448
input_W = 448
resume_path = ''                  # Resume from this checkpoint if it is a file
model_depth = 10                  # 10 | 18 | 34 | 50 | 101 | 152 | 200
pretrain_path = f'pretrain/resnet_{model_depth}.pth'
new_layer_names = ['fc']          # Updated to 'fc' for the classification head
gpu_id = [i for i in range(num_gpus)]
model = 'resnet'
resnet_shortcut = 'B'             # A | B  (A: identity shortcut, B: projection shortcut)
save_folder = "./trails/models/{}_{}".format(model, model_depth)
test_batch_size = 1
test_num_workers = 4
no_cuda = not torch.cuda.is_available()

if not no_cuda and torch.cuda.device_count() > 0:
    pin_memory = True
    test_pin_memory = True  # Pin host memory when using a GPU
    print(f"Using GPU(s). Number of GPUs available: {torch.cuda.device_count()}")
else:
    pin_memory = False
    test_pin_memory = False
    print("No GPU available, using CPU.")

Model Description

The ResNet (Residual Network) is a type of neural network that uses residual blocks to allow the network to learn more effectively. In this implementation, we define different types of blocks and layers to build the ResNet3D model tailored for classification tasks.
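Concretely, a residual block computes (see the ResNet paper, reference 1 below):

$$y = \mathcal{F}(x, \{W_i\}) + x$$

where \(\mathcal{F}\) is the block's stack of convolutions and \(x\) is the input passed through unchanged, or through a downsampling shortcut when the shapes differ; the `downsample` argument in the code below handles exactly that case.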

def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        dilation=dilation,
        stride=stride,
        padding=dilation,
        bias=False)


def downsample_basic_block(x, planes, stride, no_cuda=False):
    # Type-A shortcut: strided average pooling, then zero-pad the channel dimension
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.Tensor(
        out.size(0), planes - out.size(1), out.size(2), out.size(3),
        out.size(4)).zero_()
    if not no_cuda:
        if isinstance(out.data, torch.cuda.FloatTensor):
            zero_pads = zero_pads.cuda()
    out = Variable(torch.cat([out.data, zero_pads], dim=1))
    return out
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 sample_input_D,
                 sample_input_H,
                 sample_input_W,
                 num_classes,  # Changed from num_seg_classes to num_classes
                 shortcut_type='B',
                 no_cuda=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.no_cuda = no_cuda
        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=7,
            stride=(2, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=1, dilation=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=1, dilation=4)

        # Placeholder for the gradients captured by the backward hook
        self.gradients = None

        # The original segmentation head (conv_seg) is removed and replaced
        # with a classification head.
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)  # Binary classification (2 classes)

        # Initialize weights for the new layers
        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    no_cuda=self.no_cuda)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False),
                    nn.BatchNorm3d(planes * block.expansion))
        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        # Initialize the new classification head with Xavier initialization
        nn.init.xavier_normal_(self.fc.weight)
        if self.fc.bias is not None:
            nn.init.constant_(self.fc.bias, 0)

    def forward(self, x, reg_hook=True):
        # Feature extraction
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Register the backward hook on the last convolutional feature map
        # (needed for Grad-CAM)
        if reg_hook:
            x.register_hook(self.activations_hook)
        # Classification head
        x = self.global_pool(x)    # [N, 512*expansion, 1, 1, 1]
        x = x.view(x.size(0), -1)  # [N, 512*expansion]
        x = self.fc(x)             # [N, num_classes]
        return x

    # Hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad

    def get_activations_gradient(self):
        return self.gradients

    def get_activations(self, x):
        """
        Feature extractor that returns the activations of the last
        convolutional layer for the Grad-CAM method.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x  # Activations from the last convolutional layer
def resnet10_classification(**kwargs):
    """Constructs a ResNet-10 model."""
    model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
    return model


def resnet18_classification(**kwargs):
    """Constructs a ResNet-18 model."""
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet34_classification(**kwargs):
    """Constructs a ResNet-34 model."""
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def resnet50_classification(**kwargs):
    """Constructs a ResNet-50 model."""
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def resnet101_classification(**kwargs):
    """Constructs a ResNet-101 model."""
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def resnet152_classification(**kwargs):
    """Constructs a ResNet-152 model."""
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model


def resnet200_classification(**kwargs):
    """Constructs a ResNet-200 model."""
    model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
    return model

[Edit] We use register_hook, which registers a backward hook on the feature-map tensor; the hook is called during the backward() pass and stores the gradients. We obtain the activations by factoring the part of forward() before the hook registration into the separate get_activations function. A more sophisticated (but less straightforward) approach registers forward and backward hooks on a module instead, which works with any loaded model; a minimal sketch of that alternative follows. A 2D example can be found here: https://towardsdatascience.com/grad-cam-in-pytorch-use-of-forward-and-backward-hooks-7eba5e38d569
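For completeness, here is a minimal sketch of that hook-based alternative (not the code used in this post). It assumes `net` is the unwrapped model (use `model.module` under DataParallel) and `volume` is a single preprocessed [1, 1, D, H, W] tensor:

# Module-level forward/backward hooks: works on any loaded model
# without editing its forward(). net.layer4 is the last conv stage here.
activations, gradients = {}, {}

def fwd_hook(module, inputs, output):
    activations['value'] = output.detach()

def bwd_hook(module, grad_input, grad_output):
    gradients['value'] = grad_output[0].detach()

fwd_handle = net.layer4.register_forward_hook(fwd_hook)
bwd_handle = net.layer4.register_full_backward_hook(bwd_hook)

logits = net(x=volume, reg_hook=False)      # hooks fire during forward/backward
logits[0, logits.argmax(dim=1)].backward()  # backprop the predicted-class score

# activations['value'] and gradients['value'] now hold what
# get_activations() / get_activations_gradient() return above.
fwd_handle.remove()
bwd_handle.remove()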

Setting Parameters and Training

In this section, we set up parameters such as the number of training epochs, the learning rates, and the paths to the data and pre-trained models. We also define the optimizer and learning-rate scheduler that control how the model learns over time.

# main.py
import os
import time
import warnings

import numpy as np
import nibabel
from scipy import ndimage
from scipy.ndimage import zoom

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import (accuracy_score, balanced_accuracy_score, recall_score,
                             precision_score, f1_score, matthews_corrcoef,
                             roc_auc_score, confusion_matrix)

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set the seed
torch.manual_seed(manual_seed)

# Check that the depth is valid
assert model_depth in [10, 18, 34, 50, 101, 152, 200], "Invalid depth"

# Model constructor arguments
model_parameters = {
    'sample_input_W': input_W,
    'sample_input_H': input_H,
    'sample_input_D': input_D,
    'shortcut_type': resnet_shortcut,
    'no_cuda': no_cuda,
    'num_classes': num_classes  # Updated parameter
}

# Initialize the appropriate ResNet model for classification
if model_depth == 10:
    model = resnet10_classification(**model_parameters)
elif model_depth == 18:
    model = resnet18_classification(**model_parameters)
elif model_depth == 34:
    model = resnet34_classification(**model_parameters)
elif model_depth == 50:
    model = resnet50_classification(**model_parameters)
elif model_depth == 101:
    model = resnet101_classification(**model_parameters)
elif model_depth == 152:
    model = resnet152_classification(**model_parameters)
elif model_depth == 200:
    model = resnet200_classification(**model_parameters)

# Move the model to GPU and wrap in DataParallel if necessary
if not no_cuda:
    if len(gpu_id) > 1:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=gpu_id)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id[0])
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
else:
    pass  # Handle CPU if necessary

# Load pretrained weights
print(f'Loading pretrained model from {pretrain_path}')
pretrain = torch.load(pretrain_path)

# Exclude the classification head from the pretrained weights
pretrain_dict = {k: v for k, v in pretrain['state_dict'].items()
                 if k in model.state_dict().keys() and 'fc' not in k}

# Update the model's state dict with the pretrained weights
model.load_state_dict(pretrain_dict, strict=False)

# Initialize the new classification head with Xavier initialization
def initialize_new_layers(model):
    for name, module in model.named_modules():
        if 'fc' in name:
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

initialize_new_layers(model)

# Freeze the pretrained layers except 'fc'
for name, param in model.named_parameters():
    if 'fc' not in name:
        param.requires_grad = False

# Split parameters into base (pretrained) and new (classification head)
new_parameters = [p for name, p in model.named_parameters() if 'fc' in name]
base_parameters = [p for name, p in model.named_parameters() if 'fc' not in name]
parameters = {'base_parameters': base_parameters, 'new_parameters': new_parameters}
params = [
    {'params': parameters['base_parameters'], 'lr': learning_rate},
    {'params': parameters['new_parameters'], 'lr': learning_rate * 10}  # Higher LR for new layers
]

# Optimizer for Phase 1: train only the 'fc' layers
optimizer = torch.optim.SGD([
    {'params': new_parameters, 'lr': learning_rate * 10}  # Higher LR for 'fc'
], momentum=0.9, weight_decay=1e-3)

# Learning-rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

# Resume from a checkpoint if one is given
if os.path.isfile(resume_path):
    print(f"=> loading checkpoint '{resume_path}'")
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    optimizer.load_state_dict(checkpoint['optimizer'])
    print(f"=> loaded checkpoint '{resume_path}' (epoch {checkpoint.get('epoch', 'Unknown')})")

Creating and Loading the Dataset

We define a Dataset class to handle loading and preprocessing the NIfTI files. The data is then loaded using a DataLoader, which helps in batching and shuffling the data during training and testing.

class MyDataset(Dataset):
    def __init__(self, root_dir, img_list, sets):
        with open(img_list, 'r') as f:
            self.img_list = [line.strip() for line in f]
        print(f"Processing {len(self.img_list)} samples")
        self.root_dir = root_dir
        self.input_D = sets.input_D
        self.input_H = sets.input_H
        self.input_W = sets.input_W
        self.phase = sets.phase

    def __nii2tensorarray__(self, data):
        [z, y, x] = data.shape
        new_data = np.reshape(data, [1, z, y, x]).astype("float32")
        return new_data

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        # Each line: "relative_path.nii.gz 0" or "relative_path.nii.gz 1"
        line_info = self.img_list[idx].split(" ")
        img_path = os.path.join(self.root_dir, line_info[0])
        img_nii = nibabel.load(img_path)
        img_array = img_nii.get_fdata()

        # Read the label
        if self.phase in ["train", "val", "test"]:
            label = int(line_info[1])  # 0 or 1
        else:
            label = 0  # Placeholder if labels are not available

        # Preprocess the image: resize, normalize
        img_array = self.__resize_data__(img_array)
        img_array = self.__intensity_normalize_one_volume__(img_array)

        # Convert to torch tensors
        img_tensor = self.__nii2tensorarray__(img_array)
        img_tensor = torch.from_numpy(img_tensor)
        label_tensor = torch.tensor(label).long()
        return img_tensor, label_tensor

    def __resize_data__(self, data):
        """Resize the data to the desired input shape."""
        [depth, height, width] = data.shape
        scale = [
            self.input_D / depth,
            self.input_H / height,
            self.input_W / width
        ]
        data = ndimage.zoom(data, scale, order=1)  # order=1 for smoother resizing
        return data

    def __intensity_normalize_one_volume__(self, volume):
        """
        Normalize the intensity of the volume.
        Set the background (zeros) to random noise.
        """
        pixels = volume[volume > 0]
        if len(pixels) == 0:
            mean = 0
            std = 1
        else:
            mean = pixels.mean()
            std = pixels.std()
        out = (volume - mean) / std
        out_random = np.random.normal(0, 1, size=volume.shape)
        out[volume == 0] = out_random[volume == 0]
        return out


# Fetching the data (pin_memory was already set above based on GPU availability)
phase = 'train'

class Settings:
    def __init__(self, input_D, input_H, input_W, phase, no_cuda=False):
        self.input_D = input_D
        self.input_H = input_H
        self.input_W = input_W
        self.phase = phase
        self.no_cuda = no_cuda


train_sets = Settings(input_D, input_H, input_W, phase='train', no_cuda=no_cuda)
training_dataset = MyDataset(data_root, train_img_list, train_sets)
train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
print(len(training_dataset))

test_sets = Settings(input_D, input_H, input_W, phase='test', no_cuda=no_cuda)
testing_dataset = MyDataset(data_root, test_img_list, test_sets)
test_loader = DataLoader(testing_dataset, batch_size=test_batch_size, shuffle=False,
                         num_workers=test_num_workers, pin_memory=test_pin_memory)
print(len(testing_dataset))

Training the Model

The training function handles feeding data to the model, computing the loss, updating the weights, and periodically saving the model's state. It also evaluates the model's performance on the test set after each epoch.

import matplotlib.pyplot as plt


def train(data_loader, model, optimizer, scheduler, total_epochs, start_second_phase,
          save_interval, save_folder, no_cuda):
    # Loss function
    loss_fn = nn.CrossEntropyLoss()
    if not no_cuda:
        loss_fn = loss_fn.cuda()

    # Lists to store metrics
    train_losses = []
    test_losses = []
    train_balanced_accuracies = []
    test_balanced_accuracies = []
    train_sensitivities = []
    test_sensitivities = []

    print(f'{total_epochs} epochs in total.')

    for epoch in range(total_epochs):
        model.train()
        scheduler.step()
        print(f'Start epoch {epoch + 1}/{total_epochs}')
        print(f'Learning rate: {scheduler.get_last_lr()}')

        all_labels_train = []
        all_preds_train = []
        all_probs_train = []

        epoch_loss = 0.0

        for batch_id, batch_data in enumerate(data_loader):
            volumes, labels = batch_data

            if not no_cuda:
                volumes = volumes.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(x=volumes, reg_hook=False)  # [N, 2]

            # Compute the loss
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Predictions
            _, preds = torch.max(outputs, 1)             # Predicted class indices
            probs = torch.softmax(outputs, dim=1)[:, 1]  # Probability of class 1

            # Move to CPU and convert to numpy
            preds = preds.detach().cpu().numpy()
            labels_np = labels.detach().cpu().numpy()
            probs = probs.detach().cpu().numpy()

            # Accumulate for metrics
            all_labels_train.extend(labels_np)
            all_preds_train.extend(preds)
            all_probs_train.extend(probs)

            if (batch_id + 1) % 10 == 0 or (batch_id + 1) == len(data_loader):
                print(f'Epoch [{epoch + 1}/{total_epochs}], Step [{batch_id + 1}/{len(data_loader)}], Loss: {loss.item():.4f}')

            # Save a model checkpoint
            global_batch_id = epoch * len(data_loader) + batch_id
            if ((epoch + 1) % 10 == 0 or epoch == 0) and ((global_batch_id + 1) % save_interval == 0 or (global_batch_id + 1) >= len(training_dataset)):
                model_save_path = os.path.join(save_folder, f'model_epoch_{epoch + 1}_batch_{batch_id + 1}.pth.tar')
                os.makedirs(save_folder, exist_ok=True)
                torch.save({
                    'epoch': epoch + 1,
                    'batch_id': batch_id + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()},
                    model_save_path)
                print(f'Saved checkpoint: {model_save_path}')

        # Average training loss
        avg_train_loss = epoch_loss / len(data_loader)
        train_losses.append(avg_train_loss)

        # Training balanced accuracy
        train_balanced_acc = balanced_accuracy_score(all_labels_train, all_preds_train)
        train_balanced_accuracies.append(train_balanced_acc)

        # Training sensitivity (recall)
        train_sensitivity = recall_score(all_labels_train, all_preds_train, pos_label=1)
        train_sensitivities.append(train_sensitivity)

        # Evaluate on the test set
        model.eval()
        all_labels_test = []
        all_preds_test = []
        all_probs_test = []
        test_epoch_loss = 0.0
        with torch.no_grad():
            for test_volumes, test_labels in test_loader:
                if not no_cuda:
                    test_volumes = test_volumes.cuda()
                    test_labels = test_labels.cuda()

                test_outputs = model(x=test_volumes, reg_hook=False)  # [N, 2]
                test_loss = loss_fn(test_outputs, test_labels)
                test_epoch_loss += test_loss.item()

                # Predictions
                _, test_preds = torch.max(test_outputs, 1)
                test_probs = torch.softmax(test_outputs, dim=1)[:, 1]

                # Move to CPU and convert to numpy
                test_preds = test_preds.detach().cpu().numpy()
                test_labels_np = test_labels.detach().cpu().numpy()
                test_probs = test_probs.detach().cpu().numpy()

                # Accumulate for test metrics
                all_labels_test.extend(test_labels_np)
                all_preds_test.extend(test_preds)
                all_probs_test.extend(test_probs)

        # Average test loss
        avg_test_loss = test_epoch_loss / len(test_loader)
        test_losses.append(avg_test_loss)

        # Test balanced accuracy
        test_balanced_acc = balanced_accuracy_score(all_labels_test, all_preds_test)
        test_balanced_accuracies.append(test_balanced_acc)

        # Test sensitivity (recall)
        test_sensitivity = recall_score(all_labels_test, all_preds_test, pos_label=1)
        test_sensitivities.append(test_sensitivity)

        # Training metrics at the end of the epoch
        accuracy = accuracy_score(all_labels_train, all_preds_train)
        balanced_acc = balanced_accuracy_score(all_labels_train, all_preds_train)
        sensitivity = recall_score(all_labels_train, all_preds_train, pos_label=1)
        precision = precision_score(all_labels_train, all_preds_train, pos_label=1)
        f1 = f1_score(all_labels_train, all_preds_train, pos_label=1)
        mcc = matthews_corrcoef(all_labels_train, all_preds_train)

        # Specificity
        cm = confusion_matrix(all_labels_train, all_preds_train)
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        else:
            specificity = 0  # Undefined if not binary

        # AUC
        try:
            auc = roc_auc_score(all_labels_train, all_probs_train)
        except ValueError:
            auc = 0.0  # Undefined if only one class is present

        # Display the metrics
        print(f'--- Epoch {epoch + 1} Metrics ---')
        print(f'Train Loss: {epoch_loss / len(data_loader):.4f}')
        print(f'Train Accuracy: {accuracy:.4f}')
        print(f'Train Balanced Accuracy: {balanced_acc:.4f}')
        print(f'Train Sensitivity (Recall): {sensitivity:.4f}')
        print(f'Train Specificity: {specificity:.4f}')
        print(f'Train Precision (PPV): {precision:.4f}')
        print(f'Train F1 Score: {f1:.4f}')
        print(f'Train Matthews Correlation Coefficient (MCC): {mcc:.4f}')
        print(f'Train AUC: {auc:.4f}')
        print(f'Test Loss: {avg_test_loss:.4f}')
        print(f'Test BA: {test_balanced_acc:.4f}')
        print(f'Test Sensitivity: {test_sensitivity:.4f}')
        print('-------------------------\n')

        # Phase transition: after start_second_phase epochs, unfreeze the base
        # parameters (defined at script level above) and fine-tune all layers
        if epoch + 1 == start_second_phase:
            print('--- Transitioning to Phase 2: Fine-tuning all layers ---')

            # Unfreeze the base parameters
            for param in base_parameters:
                param.requires_grad = True

            # Redefine the optimizer to include all parameters with appropriate learning rates
            optimizer = torch.optim.SGD([
                {'params': base_parameters, 'lr': learning_rate},     # Base layers
                {'params': new_parameters, 'lr': learning_rate * 10}  # 'fc' layer
            ], momentum=0.9, weight_decay=1e-3)

            # Optionally, redefine the scheduler as well
            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

            print('Phase 2 initialized: All layers are now trainable.')

    print("Finished Training")

    # Plot the metrics
    epochs = range(1, total_epochs + 1)

    # 1. Training and test loss vs. epoch
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.plot(epochs, test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Test Loss vs. Epoch')
    plt.legend()
    plt.grid(True)
    plt.axvline(x=start_second_phase, color='green', linestyle='dotted', linewidth=2, label='start of 2nd phase')
    plt.savefig(os.path.join(save_folder, 'loss_vs_epoch.png'))

    # 2. Balanced accuracy vs. epoch
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, test_balanced_accuracies, label='Test Balanced Accuracy', color='green')
    plt.plot(epochs, train_balanced_accuracies, label='Train Balanced Accuracy', color='blue')
    plt.xlabel('Epoch')
    plt.ylabel('Balanced Accuracy')
    plt.title('Balanced Accuracy vs. Epoch')
    plt.legend()
    plt.grid(True)
    plt.axvline(x=start_second_phase, color='green', linestyle='dotted', linewidth=2, label='start of 2nd phase')
    plt.savefig(os.path.join(save_folder, 'balanced_accuracy_vs_epoch.png'))

    # 3. Sensitivity (recall) for train and test vs. epoch
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_sensitivities, label='Train Sensitivity (Recall)', color='orange')
    plt.plot(epochs, test_sensitivities, label='Test Sensitivity (Recall)', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Sensitivity (Recall)')
    plt.title('Sensitivity (Recall) vs. Epoch')
    plt.legend()
    plt.grid(True)
    plt.axvline(x=start_second_phase, color='green', linestyle='dotted', linewidth=2, label='start of 2nd phase')
    plt.savefig(os.path.join(save_folder, 'sensitivity_vs_epoch.png'))


print('Starting training...')
train(train_loader, model, optimizer, scheduler, total_epochs=n_epochs,
      start_second_phase=epoch_unfreeze_all, save_interval=save_intervals,
      save_folder=save_folder, no_cuda=no_cuda)
print('Finished training')

Implementing GradCAM

GradCAM helps in visualizing which parts of the input data are most influential in the model’s decision-making process. This section of the code generates GradCAM heatmaps and overlays them on the original NIfTI images for better visualization.
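Concretely, Grad-CAM weights each channel k of the last convolutional feature map \(A^k\) by the spatially averaged gradient of the class score \(y^c\), then combines the channels and applies a ReLU; in 3D the average runs over depth, height, and width:

$$\alpha_k = \frac{1}{DHW}\sum_{d,h,w}\frac{\partial y^c}{\partial A^k_{dhw}}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k A^k\Big)$$

The generate_gradcam_heatmap function below implements this up to a constant factor: pooled_gradients is \(\alpha_k\), weighted_activations is \(\alpha_k A^k\), and the code averages rather than sums over channels, which the final normalization absorbs.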

# GradCAM
import os
from math import ceil

import cv2
import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def generate_gradcam_heatmap(net, img, original_shape=None):
    """Generate a 3D Grad-CAM heatmap for the input 3D image."""
    net.eval()  # Set the model to evaluation mode
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # Forward pass and gradient computation on the GPU
    img = img.cuda(non_blocking=True)
    pred = net(x=img, reg_hook=True)
    # Backpropagate the score of the predicted class
    # (this indexing assumes batch_size = 1, as configured above)
    pred[:, pred.argmax(dim=1)].backward()

    # Retrieve gradients and activations
    gradients = net.get_activations_gradient().cpu()  # Move to CPU
    activations = net.get_activations(img).cpu()      # Move to CPU
    img = img.cpu()  # Free GPU memory, no longer needed

    # Pool the gradients over the spatial dimensions (D, H, W)
    pooled_gradients = gradients.mean(dim=(2, 3, 4), keepdim=True)

    # Weight the activations by the pooled gradients
    weighted_activations = activations * pooled_gradients
    heatmap = weighted_activations.mean(dim=1)  # Average over channels

    # Apply ReLU and normalize to [0, 1]
    heatmap = F.relu(heatmap)
    heatmap -= heatmap.min()
    heatmap /= heatmap.max() + 1e-6  # Avoid divide-by-zero
    return heatmap

def visualize_and_save_gradcam(net, data_loader, num_images=10, output_dir="gradcam_outputs",
                               batch_size=1, opacity=0.4):
    """
    Visualize and save Grad-CAM heatmaps as 3D volumes.

    Args:
        net: The trained network (e.g., your 3D ResNet).
        data_loader: DataLoader for the validation/test set.
        num_images: Number of images to process.
        output_dir: Directory to save the Grad-CAM outputs.
    """
    # Make sure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    it = iter(data_loader)
    num_final_images = num_images % batch_size
    num_batches = ceil(num_images / batch_size)
    print('num_final_images', num_final_images)
    print('num_batches', num_batches)

    # Define the SpiceJet color map: Blue -> Cyan -> Yellow -> Red
    spicejet_colors = [(0, 0, 1), (0, 1, 1), (1, 1, 0), (1, 0, 0)]
    spicejet_cmap = LinearSegmentedColormap.from_list('SpiceJet', spicejet_colors)

    for i in range(num_batches):
        img, label = next(it)
        img = img.cuda()  # Move to GPU if needed

        # Original shape of the batch: (batch_size, channels, D, H, W)
        original_shape = img.shape

        # Generate the Grad-CAM heatmap for the 3D image
        heatmap_3d = generate_gradcam_heatmap(net, img, original_shape=original_shape)

        if i + 1 == num_batches:  # last batch
            internal_iter = num_final_images + 1
        else:
            internal_iter = batch_size

        for j in range(internal_iter):
            temp_heatmap = heatmap_3d[j]
            temp_heatmap = temp_heatmap.unsqueeze(0).unsqueeze(0)

            # Upsample the heatmap to the original image resolution
            img_nii = nibabel.load(files[i * batch_size + j][0])
            img_array = img_nii.get_fdata()
            target_size = tuple(img_array.shape)
            heatmap = F.interpolate(temp_heatmap, size=target_size, mode='trilinear', align_corners=False)
            heatmap = heatmap.squeeze(0).squeeze(0).detach().cpu().numpy()

            # Create a NIfTI image and save it
            output_filename = os.path.join(output_dir, f"gradcam_{os.path.basename(files[i * batch_size + j][0])}")
            img_nii2 = nib.Nifti1Image(heatmap, img_nii.affine)
            nib.save(img_nii2, output_filename)
            print(f"Saved {output_filename}")

            # Saving an overlay to view in NiiVue:
            # Load the original image
            original_img = img_nii
            original_data = original_img.get_fdata()

            # Normalize the original image to the 0-255 range for RGB
            original_normalized = ((original_data - np.min(original_data)) * 255 /
                                   (np.max(original_data) - np.min(original_data))).astype(np.uint8)

            # Create an RGB image from the original data
            # (use `c`, not `i`, to avoid clobbering the outer batch index)
            rgb_image = np.zeros((*original_data.shape, 3), dtype=np.uint8)
            for c in range(3):
                rgb_image[..., c] = original_normalized

            # Load the Grad-CAM mask
            gradcam_img = img_nii2
            gradcam_data = gradcam_img.get_fdata()

            # Normalize the Grad-CAM values to the 0-1 range
            gradcam_normalized = (gradcam_data - np.min(gradcam_data)) / (np.max(gradcam_data) - np.min(gradcam_data))

            # Map the normalized Grad-CAM values to the SpiceJet colormap
            gradcam_colored = spicejet_cmap(gradcam_normalized)

            # Overlay the Grad-CAM colors onto the original image with the given opacity
            for c in range(3):  # R, G, B channels
                rgb_image[..., c] = np.clip(
                    (1 - opacity) * rgb_image[..., c] + opacity * (gradcam_colored[..., c] * 255),
                    0, 255).astype(np.uint8)

            # Save the overlay as a new NIfTI file
            shape_3d = rgb_image.shape[:3]
            rgb_dtype = np.dtype([('R', 'u1'), ('G', 'u1'), ('B', 'u1')])
            ras_pos = rgb_image.copy().view(dtype=rgb_dtype).reshape(shape_3d)
            overlay_img = nib.Nifti1Image(ras_pos, original_img.affine)
            overlay_filename = os.path.join(output_dir, f"ogradcam_{os.path.basename(files[i * batch_size + j][0])}")
            nib.save(overlay_img, overlay_filename)
            print(f"Overlay saved as {overlay_filename}")

# Build the list of images for which we want to compute Grad-CAM
files = [
    ('./data/A.nii.gz', 1),
    ('./data/B.nii.gz', 0),
    ('./data/D.nii.gz', 1),
]
gradcam_img_list = './data/gradcam.txt'

# Join the file names and labels with newline separators
gradcam_img_list_content = "\n".join([f[0] + " " + str(f[1]) for f in files])

# Write the content to the file
with open(gradcam_img_list, 'w') as file:
    file.write(gradcam_img_list_content)
print("File written successfully!")

# Load the model from a checkpoint
checkpoint_path = "./trails/models/atestbtrain_resnet_10/model_epoch_500_batch_85.pth.tar"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['state_dict'])

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Build a loader over the Grad-CAM file list
gradcam_sets = Settings(input_D, input_H, input_W, phase='test', no_cuda=no_cuda)
gradcam_dataset = MyDataset(data_root, gradcam_img_list, gradcam_sets)
grad_validationloader = DataLoader(gradcam_dataset, batch_size=test_batch_size, shuffle=False,
                                   num_workers=test_num_workers, pin_memory=test_pin_memory)

opacity = 0.4
visualize_and_save_gradcam(model, grad_validationloader, num_images=len(files),
                           output_dir="gradcam_outputs", batch_size=batch_size, opacity=opacity)

Visualizing the Results

The GradCAM process generates .nii.gz files that visualize the areas of the original NIfTI images the model focused on. These visualizations can be viewed using NiiVue in VSCode, allowing you to see 3D images with highlighted regions indicating the model's attention.

Result of our algorithm on one of our images. Source: Image by author
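If you want a quick look without NiiVue, a minimal sketch like the following loads a saved heatmap with nibabel and shows its middle slice with matplotlib; the path is a hypothetical example of a file produced by visualize_and_save_gradcam above:

import nibabel as nib
import matplotlib.pyplot as plt

# Hypothetical path: a heatmap saved by visualize_and_save_gradcam above
heatmap = nib.load("gradcam_outputs/gradcam_A.nii.gz").get_fdata()

# Show the middle slice along the third axis of the 3D heatmap
plt.imshow(heatmap[:, :, heatmap.shape[2] // 2], cmap='jet')
plt.colorbar(label='Grad-CAM intensity')
plt.axis('off')
plt.show()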

Conclusion

I hope this guide helps those who, like me, were struggling to find resources on GradCAM with ResNet3D. Implementing these techniques can greatly enhance your understanding of how models make decisions, especially in complex 3D data scenarios. Feel free to leave comments or reach out if any part needs more explanation, and I’ll update the blog accordingly.

References

These references provide foundational knowledge and resources related to ResNet3D, GradCAM, data handling with NIfTI files, PyTorch functionalities, and visualization tools used in the implementation. They should help deepen your understanding and offer additional guidance as you work with similar projects.

  1. ResNet paper, Deep Residual Learning for Image Recognition: https://arxiv.org/abs/1512.03385
  2. Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization: https://arxiv.org/abs/1610.02391
  3. Nibabel, accessing data stored in common neuroimaging file formats: https://nipy.org/nibabel/
  4. NiiVue, NIfTI viewer for VSCode by João Moreno: https://marketplace.visualstudio.com/items?itemName=joaomoreno.NiiVue
  5. PyTorch documentation: https://pytorch.org/docs/stable/index.html (consider learning PyTorch Lightning if you are new to PyTorch, since it simplifies many things)
  6. Scipy.ndimage, n-dimensional image processing: https://docs.scipy.org/doc/scipy/reference/ndimage.html
  7. Data parallelism in PyTorch: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
  8. Understanding Xavier initialization: https://pytorch.org/docs/stable/generated/torch.nn.init.xavier_normal_.html
  9. 3D ResNet implementation and pretrained weights: https://github.com/kenshohara/3D-ResNets-PyTorch
  10. NIfTI file format specification: https://nifti.nimh.nih.gov/nifti-1
  11. Matplotlib documentation: https://matplotlib.org/stable/contents.html
  12. SciPy documentation: https://www.scipy.org/docs.html
  13. Torchvision models: https://pytorch.org/vision/stable/models.html
  14. Automatic Mixed Precision (AMP) in PyTorch: https://pytorch.org/docs/stable/amp.html
  15. Balanced accuracy score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html
  16. Matthews correlation coefficient (MCC): https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
  17. ROC AUC score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
  18. NiiVue VSCode extension documentation: https://github.com/joaomoreno/vscode-nii-vue
  19. 3D convolutional neural networks: https://towardsdatascience.com/a-comprehensive-introduction-to-different-types-of-convolutions-in-deep-learning-669281e58215
  20. Exponential learning rate scheduler in PyTorch: https://pytorch.org/docs/stable/optim.html#exponentiallr

Note: This blog has been reuploaded due to comments about formatting issues.


Published via Towards AI
