First, let's train a model for 3 epochs to have something reasonable for visualization.
#example
from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
from torch import nn
from torchvision import transforms
def get_tta_transforms(resize_shape, normalize_transform, n=5):
    """Build a test-time-augmentation (TTA) pipeline for a single image.

    The returned transform maps one input image to a stacked tensor that
    holds ``n`` randomly augmented views followed by one plain resized view,
    with ``normalize_transform`` applied to every view.

    Args:
        resize_shape: Side length used for both the random crop and the
            plain resize (outputs are ``resize_shape`` x ``resize_shape``).
        normalize_transform: Callable applied to each view tensor after the
            views have been stacked.
        n: Number of random views to generate; the un-augmented view is
            always appended as the last entry.

    Returns:
        A ``transforms.Compose`` whose output has ``n + 1`` entries along
        its first dimension.
    """
    target_size = (resize_shape, resize_shape)

    # Random view: rotate, crop, flip and colour-jitter before tensor conversion.
    augmented_view = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
    ])
    # Deterministic view: resize only.
    plain_view = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    def make_views(image):
        # The n random views come first; the plain view is always last.
        views = [augmented_view(image) for _ in range(n)]
        views.append(plain_view(image))
        return torch.stack(views)

    def normalize_views(views):
        # Normalization is applied view by view, then the views are re-stacked.
        return torch.stack([normalize_transform(view) for view in views])

    return transforms.Compose([
        transforms.Lambda(make_views),
        transforms.Lambda(normalize_views),
    ])
def get_transforms(resize_shape, tta=False, tta_n=5):
    """Create the training and validation transform pipelines.

    Args:
        resize_shape: Side length of the square images both pipelines emit.
        tta: When true, validation uses the multi-view pipeline from
            ``get_tta_transforms`` instead of a single resize.
        tta_n: Number of random views requested from ``get_tta_transforms``;
            only used when ``tta`` is true.

    Returns:
        A ``(train_transforms, val_transforms)`` tuple.
    """
    target_size = (resize_shape, resize_shape)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training: random crop and flip, then tensor conversion and normalization.
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    if tta:
        val_transforms = get_tta_transforms(resize_shape, normalize, n=tta_n)
    else:
        # Validation without TTA: plain resize, no randomness.
        val_transforms = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            normalize,
        ])
    return train_transforms, val_transforms
# Build both pipelines at 224x224. tta=True selects the multi-view validation
# pipeline from get_tta_transforms (default tta_n=5 random views plus one
# plain resized view per image).
train_transform, val_transform = get_transforms(224, tta=True)
#example
# Load the dataset splits, handing each split its own transform.
# NOTE(review): the dataset root is a machine-specific absolute path; adjust
# it for your environment.
# NOTE(review): the meaning of `label` and `criterion` is defined by
# initialize_datasets (not visible here) - presumably the target column and
# the columns the split is stratified on; confirm.
ds_mapping = initialize_datasets(
'/share/nikola/export/dt372/BreaKHis_v1/',
label='tumor_class', criterion=['tumor_type', 'magnification'],
split_transforms={'train': train_transform, 'val': val_transform}
)
#example
# Pull the training and validation splits out of the mapping.
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
#example
# Both loaders use batches of 32; only the training loader shuffles.
# NOTE(review): only `from torch import nn` is visible in this section -
# confirm that `torch` itself is imported earlier in the notebook.
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)
#example
# Two-class ResNet-18 created with pretrained=True, moved to the GPU when
# one is available.
model = resnet18(pretrained=True, num_classes=2, create_log_and_save_dirs=False)
model = model.cuda() if torch.cuda.is_available() else model

# Training hyperparameters.
mixup = True
num_epochs = 3
base_lr = 1e-3
finetune_body_factor = [1e-5, 1e-2]

# NOTE(review): presumably derives per-parameter-group learning rates from
# base_lr and finetune_body_factor - confirm against get_param_lr_maps.
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
optimizer, scheduler = setup_optimizer_and_scheduler(
    param_lr_maps, base_lr, num_epochs, len(tr_dl)
)

# With mixup enabled the training loss is left unreduced (reduction='none');
# validation always uses the default reduction.
criterion = dict(
    train=nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    val=nn.CrossEntropyLoss(),
)
#example
# Send training logs to a file only (log_to_stdout=False). The helper returns
# a callable that is invoked once training is over.
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=False)
try:
    for epoch in range(num_epochs):
        # The helpers receive 1-based epoch numbers.
        tr_loss, tr_acc = train(
            model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
            mixup=mixup, alpha=0.4, logging_frequency=25
        )
        # tta=True matches the multi-view validation transform built earlier.
        val_loss, val_acc = validate(
            model, epoch + 1, val_dl, criterion['val'], tta=True,
            logging_frequency=25
        )
        checkpoint_state(
            model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
        )
finally:
    # Run the cleanup even when an epoch raises or the cell is interrupted,
    # so re-running the cell does not start with stale handlers attached.
    clear_logging_handlers()
Now, with our trained model, let's use non-random transforms for inference and for the corresponding visualization.
#example
# Deterministic inference pipeline: resize to 224x224, convert to a tensor,
# then normalize with the same mean/std constants used in get_transforms.
resize = transforms.Resize((224, 224))
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
inference_transform = transforms.Compose([resize, transforms.ToTensor(), normalize])
# Request a single split named 'all' with weight 1.0 and keep that split's
# `.dataset` attribute.
# NOTE(review): the method is spelled `initalize` (not `initialize`) -
# confirm this matches the definition on BreaKHisDataset.
# NOTE(review): presumably `.dataset` unwraps a subset-style wrapper around
# the full dataset; confirm.
inference_ds = BreaKHisDataset.initalize(
'/share/nikola/export/dt372/BreaKHis_v1/', label='tumor_class',
criterion=['tumor_type', 'magnification'],
split={'all': 1.0},
split_transforms={'all': inference_transform}
)['all'].dataset
Here's an example of what one of our images looks like.
#example
# Show the first sample of the inference dataset.
show_image(inference_ds[0])
#example
# NOTE(review): presumably yields the same sample after inference_transform
# has been applied - confirm against get_preprocessed_image.
get_preprocessed_image(inference_ds[0], inference_transform)
This is the main function for visualization. It will show an activation map using gradient-weighted activations from the last layer of the model (specifically, from the activations of `layer4` for every ResNet). Note that, by default, the activation map is shown for the label, i.e. based on how probable the model believes the label to be. By setting `show_for_label` to `False` and `show_for_prediction` to `True`, one can instead see the activation heatmap for the model's prediction, i.e. why the model might believe something other than the label is correct.
Below, the model trained above is visualized on one benign and one malignant example.
#example
# Sample 0, first without and then with the activation grid.
# NOTE(review): which of samples 0 and 4000 is benign and which is malignant
# is not visible here; confirm against the dataset.
show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=False)
#example
show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=True)
#example
# Sample 4000, first without and then with the activation grid.
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=False)
#example
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=True)
Here are some examples where the model was incorrect. Note how the activation heatmaps for the label and for the prediction complement each other.
#example
# Sample 3: heatmap computed for the label.
show_heatmap_and_original(
model, inference_ds[3], inference_transform, show_for_label=True, show_activation_grid=True
)
#example
# Sample 3 again: heatmap computed for the model's prediction instead.
show_heatmap_and_original(
model, inference_ds[3], inference_transform, show_for_label=False, show_for_prediction=True,
show_activation_grid=True
)
#example
# Sample 30: heatmap computed for the label.
show_heatmap_and_original(
model, inference_ds[30], inference_transform, show_for_label=True, show_activation_grid=True
)
#example
# Sample 30 again: heatmap computed for the model's prediction instead.
show_heatmap_and_original(
model, inference_ds[30], inference_transform, show_for_label=False, show_for_prediction=True,
show_activation_grid=True
)