This notebook shows how to use PyTorch Lightning to wrap the model, train it, monitor training, validate it, and visualize the results.
from import initialize_datasets
from breakhis_gradcam.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from breakhis_gradcam.utils import get_param_lr_maps, mixup_data, setup_optimizer_and_scheduler
import os
import torch
from torch import nn
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.logging.tensorboard import TensorBoardLogger

# Maps each supported architecture name (the ``resnet_type`` argument of
# LightningResNet) to its constructor from breakhis_gradcam.resnet.
resnet_model_mapping = {
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
}

%load_ext tensorboard

Below we define the Lightning module, which brings together several pieces introduced in the previous modules, including:

  • The transforms, defined as a method whose behavior can be tweaked through the Lightning module's initialization arguments.
class LightningResNet(pl.LightningModule):
    """PyTorch Lightning wrapper around a BreaKHis ResNet classifier.

    Bundles the data pipeline (transforms, train/val datasets and their
    loaders), a pretrained ResNet backbone, learning-rate maps for the
    optimizer, optional mixup during training and optional test-time
    augmentation (TTA) during validation.

    NOTE(review): this cell appears truncated in the notebook export (the
    ``__init__`` parameter list and several closing brackets are missing),
    so the code below is not runnable exactly as shown.
    """
    # Constructor arguments read in the body (the signature itself is
    # truncated here): resnet_type, label, criterion, train_ratio,
    # batch_size, base_lr, finetune_body_factor, num_epochs, mixup,
    # mixup_alpha, tta, tta_mixing, resize_shape.
    # NOTE(review): ``criterion`` and ``finetune_body_factor`` use mutable
    # list defaults; they are not mutated in the visible code, but tuples
    # (or None + a fallback) would be safer.
    def __init__(
        criterion=['tumor_type', 'magnification'],
        finetune_body_factor=[1e-5, 1e-2],
        super(LightningResNet, self).__init__()
        # Switches read later by training_step (mixup) and validation_step (TTA).
        self.mixup, self.mixup_alpha = mixup, mixup_alpha
        self.tta, self.tta_mixing = tta, tta_mixing
        train_transform, val_transform = self.get_transforms(resize_shape, tta=tta)
        # Split the data into train/val by ``train_ratio`` and attach the
        # matching transform to each split.
        ds_mapping = initialize_datasets(
            split={'train': train_ratio, 'val': 1 - train_ratio},
            label=label, criterion=criterion,
            split_transforms={'train': train_transform, 'val': val_transform}
        tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
        # NOTE(review): the loader constructor calls on the next two lines are
        # truncated in this export (presumably DataLoader(tr_ds, ...) and
        # DataLoader(val_ds, ...)). Only the training loader shuffles.
        self.tr_dl =, batch_size=batch_size, shuffle=True)
        self.val_dl =, batch_size=batch_size)
        # NOTE(review): ``assert`` is stripped under ``python -O``; raising a
        # ValueError would keep this argument check unconditional.
        assert resnet_type in resnet_model_mapping, "Please specify a valid ResNet architecture."
        # Pretrained backbone with 2 output classes.
        self.model = resnet_model_mapping[resnet_type](
            pretrained=True, num_classes=2, create_log_and_save_dirs=False
        self.base_lr = base_lr
        self.num_epochs = num_epochs
        # Learning-rate maps derived from base_lr and finetune_body_factor.
        # The notebook output ("fine-tune body with LR in range ...") suggests
        # the body LRs are base_lr scaled by this factor range; confirm in
        # breakhis_gradcam.utils.get_param_lr_maps.
        self.param_lr_maps = get_param_lr_maps(self.model, base_lr, finetune_body_factor)
        # With mixup the training loss is left unreduced (per-sample),
        # presumably so mixup_criterion can weight it itself. Validation always
        # uses the mean loss.
        self.criterion = {
            'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
            'val': nn.CrossEntropyLoss()
    # NOTE(review): this method takes no ``self`` and is not a @staticmethod,
    # yet get_transforms calls it as a bare name (``get_tta_transforms(...)``).
    # A bare name inside a method does not resolve to a class attribute, so
    # as written this raises NameError whenever tta=True. It likely should be
    # a @staticmethod called as ``self.get_tta_transforms(...)``.
    def get_tta_transforms(resize_shape, normalize_transform, n=5):
        """Build a validation transform yielding ``n`` augmented views plus the original.

        Each image goes ``n`` times through a random-resized-crop +
        colour-jitter pipeline and once through a deterministic resize. The
        views are stacked with the un-augmented view LAST; validation_step
        relies on this ordering via ``output[:, -1, :]``. Every stacked view
        is then normalized with ``normalize_transform``. The sub-pipelines
        presumably end in ToTensor (lines truncated in this export), since
        torch.stack needs tensors.
        """
        tta = transforms.Compose([
            transforms.RandomResizedCrop((resize_shape, resize_shape)),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        original_transform = transforms.Compose([
            transforms.Resize((resize_shape, resize_shape)),
        return transforms.Compose([
                lambda image: torch.stack(
                    [tta(image) for _ in range(n)] + [original_transform(image)]
                lambda images: torch.stack([
                    normalize_transform(image) for image in images

    def get_transforms(self, resize_shape, tta=False, tta_n=5):
        """Return ``(train_transforms, val_transforms)`` for square ``resize_shape`` images.

        Training applies a random resized crop and a horizontal flip, then
        converts to a tensor and normalizes with the ImageNet channel
        statistics. Validation applies a deterministic resize, ToTensor and the
        same normalization. When ``tta`` is True, validation instead uses the
        TTA pipeline from get_tta_transforms with ``tta_n`` augmented views.
        """
        random_resized_crop = transforms.RandomResizedCrop((resize_shape, resize_shape))
        random_horizontal_flip = transforms.RandomHorizontalFlip()
        resize = transforms.Resize((resize_shape, resize_shape))
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        train_transforms = transforms.Compose([
            random_resized_crop, random_horizontal_flip, transforms.ToTensor(), normalize
        val_transforms = (
            get_tta_transforms(resize_shape, normalize, n=tta_n) if tta
            else transforms.Compose([resize, transforms.ToTensor(), normalize])
        return train_transforms, val_transforms
    def forward(self, x):
        """Run the wrapped ResNet on ``x`` and return its outputs."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the (optionally mixup) training loss for one batch.

        When mixup is disabled, mixup_data is called with alpha=0.0
        (presumably a no-op mix; confirm in breakhis_gradcam.utils). Returns
        the loss for backprop, the batch size and the correct-prediction
        score from mixup_acc (presumably mixup-weighted), and logs
        ``train_loss``.
        """
        x, y = batch
        mixed_x, y_a, y_b, lam, mixup_criterion, mixup_acc = mixup_data(
            x, y, self.criterion['train'], alpha=self.mixup_alpha if self.mixup else 0.0
        output = self.forward(mixed_x)
        prediction = torch.argmax(output, -1)
        loss = mixup_criterion(output)
        return {
            'loss': loss,
            'batch_size': len(y),
            'correct': mixup_acc(prediction),
            'log': {
                'train_loss': loss.item()
    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch, with optional test-time augmentation.

        With TTA, ``x`` is 5-D ``(batch, n_aug, c, h, w)``. All views are run
        through the model as one flattened batch and reshaped back to
        ``(batch, n_aug, num_outputs)``. The final output then blends the
        original (last) view with the mean of the augmented views:
        ``(1 - tta_mixing) * original + tta_mixing * mean(augmented)``.

        Returns the batch-mean loss as a Python float, the batch size and
        the number of correct predictions, and logs ``val_loss``.
        """
        x, y = batch
        if self.tta:
            bs, n_aug, c, h, w = x.size()
            output = self.forward(x.view(-1, c, h, w)).view(bs, n_aug, -1)
            output = (
                ((1 - self.tta_mixing) * output[:, -1, :]) + (self.tta_mixing * output[:, :-1, :].mean(1))
            # NOTE(review): an ``else:`` appears to be missing before the next
            # line in this export. The non-TTA path should not overwrite the
            # TTA output.
            output = self.forward(x)
        prediction = torch.argmax(output, -1)
        loss = self.criterion['val'](output, y)
        return {
            'loss': loss.item(),
            'batch_size': len(y),
            'correct': (prediction == y).sum().item(),
            'log': {
                'val_loss': loss.item()

    def validation_end(self, outputs):
        """Aggregate validation_step outputs into epoch-level loss and accuracy.

        NOTE(review): each ``out['loss']`` is already a per-batch MEAN
        (nn.CrossEntropyLoss defaults to reduction='mean'). Dividing their sum
        by the total number of samples therefore under-reports the loss by
        roughly a factor of the batch size. A sample-weighted mean would be
        ``sum(out['loss'] * out['batch_size'] for out in outputs) / total``.
        ``val_acc`` is fine, because ``correct`` is a count.
        """
        total_loss = sum([out['loss'] for out in outputs])
        total_correct = sum([out['correct'] for out in outputs])
        total = sum([out['batch_size'] for out in outputs])
        return {
            'val_loss': total_loss / total,
            'val_acc': total_correct / total,
            'log': {
                'val_loss': total_loss / total,
                'val_acc': total_correct / total

    def configure_optimizers(self):
        """Create the optimizer and LR scheduler from the stored LR maps.

        ``len(self.tr_dl)`` (batches per epoch) is passed together with
        ``num_epochs``, presumably so the scheduler can be sized in steps.
        Confirm this in setup_optimizer_and_scheduler.
        """
        optimizer, scheduler = setup_optimizer_and_scheduler(
            self.param_lr_maps, self.base_lr, self.num_epochs, len(self.tr_dl)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Return the training loader built in ``__init__``."""
        return self.tr_dl
    def val_dataloader(self):
        """Return the validation loader built in ``__init__``."""
        return self.val_dl
# Wrap a ResNet-34 backbone in the Lightning module; every other constructor
# argument is left at its default.
model = LightningResNet(
    resnet_type='resnet34',
)
Setting up optimizer to fine-tune body with LR in range [0.00000001, 0.00001000] and head with LR 0.00100

Lightning may resume from an existing checkpoint (usually one keyed to the Slurm job ID of the notebook/process you are running in). To avoid this, do something like the following. It sets the logger's version number to the next unused one, so that no checkpointed model state is loaded.

from pytorch_lightning.logging import TensorBoardLogger

# Set to False to let Lightning choose the logger version itself, which may
# resume from an existing checkpointed run.
use_new_model_version = True

# NOTE(review): the TensorBoardLogger constructor arguments (e.g. save_dir /
# name) and closing parenthesis are missing in this export.
logger = TensorBoardLogger(
if use_new_model_version:
    # Force a fresh version ID, i.e. the next unused one.
    # NOTE(review): _version and _get_next_version are private
    # TensorBoardLogger internals and may change between Lightning releases.
    # Passing ``version=...`` to the constructor is the public alternative;
    # confirm against the installed Lightning version.
    logger._version = logger._get_next_version()

# Single-GPU trainer that reports to the TensorBoard logger above.
trainer = pl.Trainer(logger=logger, gpus=1)
print("Logging under lightning_logs directory, under version ID %s" % trainer.logger.version)
Logging under lightning_logs directory, under version ID 270075

Because Lightning writes TensorBoard logs automatically, we can view the loss curves right here in the notebook!

%tensorboard --logdir lightning_logs/ --host --port 8081
Now, using this trained model, we can just as easily visualize the class activation maps with the `vis` module.

from import BreaKHisDataset
from breakhis_gradcam.vis import show_heatmap_and_original

# Deterministic inference preprocessing: resize to 224x224, convert to a
# tensor, and normalize with the ImageNet channel statistics.
# NOTE(review): the 224 here should match the ``resize_shape`` used in training.
resize = transforms.Resize((224, 224))
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
inference_transform = transforms.Compose([resize, transforms.ToTensor(), normalize])
# Load the whole dataset as a single 'all' split. The path is hardcoded and
# machine-specific; adjust it for your environment.
# NOTE(review): "initalize" looks like a misspelling of "initialize". Verify
# that it matches the method actually defined on BreaKHisDataset.
inference_ds = BreaKHisDataset.initalize(
    '/share/nikola/export/dt372/BreaKHis_v1/', label='tumor_class',
    criterion=['tumor_type', 'magnification'],
    split={'all': 1.0},
    split_transforms={'all': inference_transform}
# Show the activation heatmap for one sample next to the original image. It
# uses the underlying ResNet (model.model), not the Lightning wrapper.
# NOTE(review): if initalize returns a split-keyed mapping like
# initialize_datasets does, this should index ``inference_ds['all'][1]``.
show_heatmap_and_original(model.model, inference_ds[1], inference_transform, show_activation_grid=False)
Model would have predicted benign (0.82078 vs. 0.82078)
Showing activation heatmap for the given label: benign