Functions designed to visualize how the model is performing on the dataset via saliency maps.

First, let's train a model for 3 epochs to have something reasonable for visualization.

#example
from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
from torch import nn
from torchvision import transforms

def get_tta_transforms(resize_shape, normalize_transform, n=5):
    """Build a test-time-augmentation (TTA) transform.

    The returned transform maps one input image to a stacked tensor holding
    `n` randomly augmented views followed by one plainly resized view, with
    `normalize_transform` applied to every view.

    Args:
        resize_shape: Side length used for both the random resized crop and
            the plain resize, so every view is `resize_shape` x `resize_shape`.
        normalize_transform: Callable applied to each view tensor after
            stacking (e.g. a `transforms.Normalize` instance).
        n: Number of randomly augmented views to generate per image.

    Returns:
        A `transforms.Compose` producing `n + 1` normalized views stacked
        along a new leading dimension.
    """
    target_size = (resize_shape, resize_shape)

    # Random augmentation pipeline, sampled independently for each view.
    augment = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
    ])
    # Deterministic pipeline that yields the un-augmented reference view.
    plain = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    def make_views(image):
        # Augmented views come first; the plain view is always the last entry.
        views = [augment(image) for _ in range(n)]
        views.append(plain(image))
        return torch.stack(views)

    def normalize_views(views):
        # Normalize view by view, then re-stack into a single tensor.
        return torch.stack([normalize_transform(view) for view in views])

    return transforms.Compose([
        transforms.Lambda(make_views),
        transforms.Lambda(normalize_views),
    ])

def get_transforms(resize_shape, tta=False, tta_n=5):
    """Create the training and validation transforms.

    Args:
        resize_shape: Side length of the square images produced by both
            transforms.
        tta: When true, the validation transform is the test-time-augmentation
            transform from `get_tta_transforms`; otherwise it is a plain
            resize, tensor conversion and normalization.
        tta_n: Number of augmented views per image when `tta` is enabled.

    Returns:
        A `(train_transforms, val_transforms)` tuple.
    """
    target_size = (resize_shape, resize_shape)
    # Per-channel mean/std (the commonly used ImageNet statistics).
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training: random crop + random flip, then tensor conversion and normalization.
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    if tta:
        val_transforms = get_tta_transforms(resize_shape, normalize, n=tta_n)
    else:
        # Validation without TTA is fully deterministic.
        val_transforms = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            normalize,
        ])
    return train_transforms, val_transforms
    
# Build 224x224 transforms; with tta=True the validation transform returns a
# stack of augmented views per image rather than a single image tensor.
train_transform, val_transform = get_transforms(224, tta=True)
#example
# Create the dataset splits, attaching the matching transform to each split.
# NOTE(review): the dataset root is a machine-specific absolute path; point it
# at your own copy of BreaKHis_v1 before running.
ds_mapping = initialize_datasets(
    '/share/nikola/export/dt372/BreaKHis_v1/',
    label='tumor_class', criterion=['tumor_type', 'magnification'],
    split_transforms={'train': train_transform, 'val': val_transform}
)
#example
# Pull the two splits out of the mapping by name.
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
#example
# Only the training loader shuffles; both use batches of 32.
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)
#example
# Two-class ResNet-18 created with pretrained=True.
model = resnet18(pretrained=True, num_classes=2, create_log_and_save_dirs=False)
# Move the model to the GPU when one is available (done before the optimizer is built).
if torch.cuda.is_available():
    model = model.cuda()
mixup = True
num_epochs = 3
base_lr = 1e-3
# Factors for the body's learning-rate range; base_lr * [1e-5, 1e-2] appears to
# match the [1e-8, 1e-5] range reported in the log line printed by this cell.
finetune_body_factor = [1e-5, 1e-2]
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
# The scheduler is given the number of epochs and the number of batches per epoch.
optimizer, scheduler = setup_optimizer_and_scheduler(param_lr_maps, base_lr, num_epochs, len(tr_dl))
criterion = {
    # With mixup the training loss is left unreduced (one value per sample);
    # presumably train() combines the per-sample losses itself -- confirm.
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    # Validation uses the default reduction.
    'val': nn.CrossEntropyLoss()
}
Setting up optimizer to fine-tune body with LR in range [0.00000001, 0.00001000] and head with LR 0.00100
#example
# Log to file only; setup_logging_streams returns a callable that undoes the setup.
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=False)
try:
    # Epochs are reported 1-based to train/validate/checkpoint_state.
    for epoch in range(num_epochs):
        tr_loss, tr_acc = train(
            model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
            mixup=mixup, alpha=0.4, logging_frequency=25
        )
        val_loss, val_acc = validate(
            model, epoch + 1, val_dl, criterion['val'], tta=True,
            logging_frequency=25
        )
        # Save a checkpoint after every epoch together with that epoch's metrics.
        checkpoint_state(
            model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
        )
finally:
    # Always undo the logging setup, even if training raises; otherwise the
    # handlers presumably stay attached and re-running this cell would add
    # more on top of them.
    clear_logging_handlers()

Now, with our trained model, let's use non-random (deterministic) transforms for inference and for the corresponding visualizations.

#example
# Deterministic preprocessing: fixed resize, tensor conversion, normalization.
resize = transforms.Resize((224, 224))
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
inference_transform = transforms.Compose([resize, transforms.ToTensor(), normalize])
# Put every sample into a single 'all' split (fraction 1.0) and take that
# split's `.dataset` attribute (presumably the underlying dataset -- confirm).
# NOTE(review): `initalize` is spelled without the second "i"; confirm that this
# matches the method name actually defined on BreaKHisDataset.
# NOTE(review): the dataset root is a machine-specific absolute path.
inference_ds = BreaKHisDataset.initalize(
    '/share/nikola/export/dt372/BreaKHis_v1/', label='tumor_class',
    criterion=['tumor_type', 'magnification'],
    split={'all': 1.0},
    split_transforms={'all': inference_transform}
)['all'].dataset

show_image[source]

show_image(datapoint, ax=None)

Shows the image corresponding to `datapoint` (taken from a `BreaKHisDataset` object).
Optionally provide an axis object `ax` from Matplotlib for multi-image plots.

Here's an example of what one of our images looks like.

#example
# Display the image of the first datapoint in the inference dataset.
show_image(inference_ds[0])

get_preprocessed_image[source]

get_preprocessed_image(datapoint, inference_transform)

Returns the pre-processed image and corresponding label ID using the `inference_transform`.
#example
# Pre-process the first datapoint; the (image tensor, label ID) pair is the
# cell's value, so it is echoed as the output that follows.
get_preprocessed_image(inference_ds[0], inference_transform)
(tensor([[[[ 0.9303,  1.0331,  1.0673,  ...,  1.3242,  1.3070,  1.1700],
           [ 0.9303,  1.0502,  1.1015,  ...,  1.3242,  1.2728,  1.1358],
           [ 0.9646,  1.1187,  1.1358,  ...,  1.1872,  1.0844,  1.0844],
           ...,
           [ 0.5878,  0.7419,  0.9132,  ...,  1.4954,  1.4612,  1.3584],
           [ 0.3652,  0.3823,  0.5193,  ...,  1.4269,  1.4098,  1.3584],
           [ 0.3652,  0.4508,  0.5707,  ...,  1.3755,  1.3927,  1.3584]],
 
          [[ 0.1702,  0.1527,  0.1176,  ...,  0.5378,  0.5378,  0.5203],
           [ 0.1352,  0.1352,  0.1176,  ...,  0.5203,  0.4678,  0.5028],
           [ 0.1352,  0.1527,  0.1527,  ...,  0.3803,  0.3102,  0.4503],
           ...,
           [ 0.0301,  0.1352,  0.3627,  ...,  0.6779,  0.5903,  0.5903],
           [-0.1975, -0.2500, -0.1275,  ...,  0.5903,  0.5203,  0.5728],
           [-0.1800, -0.2150, -0.1625,  ...,  0.5378,  0.4678,  0.5553]],
 
          [[ 0.9145,  1.0365,  1.0714,  ...,  1.2980,  1.3154,  1.3328],
           [ 0.8797,  1.0191,  1.0714,  ...,  1.2980,  1.2805,  1.2980],
           [ 0.8971,  1.0365,  1.0714,  ...,  1.1759,  1.1411,  1.2805],
           ...,
           [ 0.6008,  0.7402,  0.9668,  ...,  1.5420,  1.4722,  1.4897],
           [ 0.4788,  0.5311,  0.6356,  ...,  1.4548,  1.3851,  1.4548],
           [ 0.5311,  0.6531,  0.7228,  ...,  1.3677,  1.3502,  1.4548]]]],
        device='cuda:0'),
 0)

show_heatmap_and_original[source]

show_heatmap_and_original(model, datapoint, inference_transform, show_for_label=True, show_for_prediction=False, label_type='tumor_class', show_activation_grid=False)

Shows a heatmap corresponding to the `model`'s prediction for `datapoint` after transforming the image
using `inference_transform`. Assumes that the model was trained on labels of `label_type`. Optionally
shows the activation grid when `show_activation_grid` is set.

This is the main function for visualization. It shows an activation map using gradient-weighted activations from the last layer of the model (specifically, the activations of `layer4` for every ResNet). Note that, by default, the activation map is computed with respect to the given label, i.e. based on how probable the model believes the label to be. By specifying `show_for_label` as `False` and `show_for_prediction` as `True`, one can instead see the activation heatmap for the model's prediction, which shows why the model might believe something other than the label is correct.

Below, the model trained above is visualized on one benign example and one malignant example.

#example
# Datapoint 0 (labelled benign): heatmap for the given label (the default), without the activation grid.
show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=False)
Model would have predicted benign (0.62375 vs. 0.62375)
Showing activation heatmap for the given label: benign
#example
# Datapoint 0 (labelled benign): heatmap for the given label (the default), with the activation grid.
show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=True)
Model would have predicted benign (0.62375 vs. 0.62375)
Showing activation heatmap for the given label: benign
#example
# Datapoint 4000 (labelled malignant): heatmap for the given label (the default), without the activation grid.
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=False)
Model would have predicted malignant (0.93616 vs. 0.93616)
Showing activation heatmap for the given label: malignant
#example
# Datapoint 4000 (labelled malignant): heatmap for the given label (the default), with the activation grid.
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=True)
Model would have predicted malignant (0.93616 vs. 0.93616)
Showing activation heatmap for the given label: malignant

Here are some examples where the model was incorrect. Note how the activation heatmaps for the given label and for the model's prediction piece together.

#example
# Misclassified datapoint 3 (labelled benign, predicted malignant): heatmap for the given label.
show_heatmap_and_original(
    model, inference_ds[3], inference_transform, show_for_label=True, show_activation_grid=True
)
Model would have predicted malignant (0.65845 vs. 0.34155)
Showing activation heatmap for the given label: benign
#example
# Misclassified datapoint 3 (labelled benign, predicted malignant): heatmap for the model's prediction.
show_heatmap_and_original(
    model, inference_ds[3], inference_transform, show_for_label=False, show_for_prediction=True,
    show_activation_grid=True
)
Model would have predicted malignant (0.65845 vs. 0.34155)
Showing activation heatmap for the model's prediction: malignant
#example
# Misclassified datapoint 30 (labelled benign, predicted malignant): heatmap for the given label.
show_heatmap_and_original(
    model, inference_ds[30], inference_transform, show_for_label=True, show_activation_grid=True
)
Model would have predicted malignant (0.58048 vs. 0.41952)
Showing activation heatmap for the given label: benign
#example
# Misclassified datapoint 30 (labelled benign, predicted malignant): heatmap for the model's prediction.
show_heatmap_and_original(
    model, inference_ds[30], inference_transform, show_for_label=False, show_for_prediction=True,
    show_activation_grid=True
)
Model would have predicted malignant (0.58048 vs. 0.41952)
Showing activation heatmap for the model's prediction: malignant