Using PyTorch Hooks for the Dynamic U-Net Architecture [Draft]
A look into how you can use PyTorch hooks, and using it to set up the dynamic U-Net architecture.
In this notebook, we'll take a look in more detail about how to set up a segmentation network dynamically from a given ResNet backbone. Specifically, we'll take advantage of PyTorch hooks to setup the decoder layers for outputting a segmentation, in the scheme shown in the U-Net paper (image shown below).
The left half of the network will be referred to frequently as the encoder, and the right half of the network will be referred to frequently as the decoder. Briefly, the novel proposition when the U-Net paper was published was the idea of using skip connections (here, from the encoder to the decoder) to combat any loss of information when upsampling.
Concretely, by using skip connections that concatenate a level of encoder's output with the input to the corresponding level of the decoder, whenever an upsampling operation is done, there is 2 times the amount of information in the number of channels that is used for upsampling.
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import tifffile as tiff
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.transforms.functional as tf
from tqdm.notebook import tqdm
The point of using the ResNet backbone is to leverage the pre-trained weights we can get from PyTorch's model zoo. The links are set up below, and will be used in our implementation for fetching the ResNet encoder backbone.
# Download locations of the ImageNet-pretrained ResNet checkpoints published
# by torchvision, keyed by architecture name.  The resnetXX() constructors
# pass these to model_zoo.load_url.
# NOTE(review): the URLs were blank in this export; they are restored to the
# checkpoint files of the original torchvision ResNet release.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
This code is taken from the official PyTorch implementations for ResNets. Some things are not ideal with how they initially trained these networks (based on recent empirical findings), but otherwise, these pre-trained weights are very useful as we now have a very good baseline to start with for segmentation of real images.
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with one pixel of zero padding."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Two-convolution residual block (used by ResNet-18/34 in this file).

    ``expansion`` is the ratio between the block's output channels and
    ``planes``; it is 1 for this block.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Sub-modules are created in the same order, under the same attribute
        # names, as before so state-dict keys are unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional module that reshapes the skip path to match the main path.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block (used by ResNet-50/101/152 here).

    The last 1x1 convolution widens the output to ``planes * expansion``
    channels, with ``expansion`` fixed at 4.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion
        # Sub-modules are created in the same order, under the same attribute
        # names, as before so state-dict keys are unchanged.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        # Optional module that reshapes the skip path to match the main path.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
The nice thing about ResNets are the residual blocks - these form very clear levels for the encoder and decoder of our U-Net to interface with. This is made even more clear in the implementation of the ResNet network - the variables self.layer1
to self.layer4
comprise these levels, and self.layer0
comprises the initial input encoding for the ResNet (generally going from the standard 3 channels of input from the image to 64 channels, i.e. 64 filters).
The implementation is modified slightly from the original PyTorch version. Specifically:
- We don't need the last FC layer, as we'll have something different for connecting to the decoder
- We encompass the input encoding for the network (`self.conv1`, `self.bn1`, and `self.relu`) in `self.layer0` for clarity in the `forward` function.
class ResNetEncoder(nn.Module):
    """ResNet feature extractor with no pooling / fully-connected head.

    The stem (conv1 -> bn1 -> relu) is exposed as ``layer0`` and the max-pool
    is folded into ``layer1``, so every child whose name starts with "layer"
    is one encoder level.

    Args:
        block: residual block class.  It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks per stage.
        num_classes: unused (there is no classification head); accepted so the
            constructor keeps the signature of the reference implementation.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNetEncoder, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # layer0 shares its modules with conv1/bn1, so loading "conv1.*" and
        # "bn1.*" from a checkpoint also fills layer0.
        self.layer0 = nn.Sequential(self.conv1, self.bn1, self.relu)

        # The pool is registered under a *name* and the residual blocks under
        # "0", "1", ... so the parameter keys stay "layer1.<i>.*", as in the
        # reference checkpoints.  Nesting the blocks one level deeper
        # (Sequential(maxpool, Sequential(blocks))) would produce
        # "layer1.1.<i>.*", and a non-strict load_state_dict would then skip
        # every pretrained layer1 weight without any error.
        self.layer1 = nn.Sequential()
        self.layer1.add_module('maxpool', self.maxpool)
        for idx, residual_block in enumerate(self._make_layer(block, 64, layers[0])):
            self.layer1.add_module(str(idx), residual_block)

        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Channel count of the tensor returned by forward().
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change shape."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the skip path whenever resolution or width changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the five encoder levels in order and return the last output."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
Now, we can load the pre-trained weights from the model_urls
defined above. This is made easy using the utility API in PyTorch, model_zoo
. Then, because we modified the ResNet implementation slightly, but kept the parameter names the same, we can load the state dictionary for any parameters still present (e.g. self.conv1
, self.bn1
, self.relu
, and the self.layer
s) without being strict to load as much of the pre-trained weights as possible.
These function definitions (shown below) are again taken from the official PyTorch implementation, but the modification of making the weight loading non-strict, and (by default) loading a pre-trained network.
def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 encoder.

    Args:
        pretrained (bool): If True, load the weights published at
            ``model_urls['resnet18']`` non-strictly, so checkpoint entries
            with no counterpart in the encoder are ignored.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model
def resnet34(pretrained=True, **kwargs):
    """Constructs a ResNet-34 encoder.

    Args:
        pretrained (bool): If True, load the weights published at
            ``model_urls['resnet34']`` non-strictly, so checkpoint entries
            with no counterpart in the encoder are ignored.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
    return model
def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 encoder.

    Args:
        pretrained (bool): If True, load the weights published at
            ``model_urls['resnet50']`` non-strictly, so checkpoint entries
            with no counterpart in the encoder are ignored.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model
def resnet101(pretrained=True, **kwargs):
    """Constructs a ResNet-101 encoder.

    Args:
        pretrained (bool): If True, load the weights published at
            ``model_urls['resnet101']`` non-strictly, so checkpoint entries
            with no counterpart in the encoder are ignored.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
    return model
def resnet152(pretrained=True, **kwargs):
    """Constructs a ResNet-152 encoder.

    Args:
        pretrained (bool): If True, load the weights published at
            ``model_urls['resnet152']`` non-strictly, so checkpoint entries
            with no counterpart in the encoder are ignored.
        **kwargs: forwarded to ``ResNetEncoder``.
    """
    model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
    return model
Decoder Setup
Now that we have our setup for the pre-trained ResNet encoder, we need to automatically construct the decoder using the architecture given in the encoder. To do so, we'll define some helper layers as `nn.Module` subclasses:
- `ConvLayer` is just a general form of a convolution, ReLU, and batch normalization layer in sequence, with some empirical best practices (e.g. initializing the convolution weights with Kaiming-uniform initialization using $a = \sqrt{5}$, as per the FastAI course).
- `ConcatLayer` is just a thin wrapper on the `torch.cat` function that concatenates all inputs along the channel dimension, assuming inputs are image batches, i.e. they have shape (batch size, num channels, height, width).
- `LambdaLayer` is just a thin wrapper of a generic lambda function.
- `upconv2x2` is a utility function for setting up convolutions that upsample an image. As mentioned above, in the U-Net architecture, we first concatenate the encoder output with the corresponding decoder input, so that when we upsample an image (i.e. from $(h, w)$ in size to $(2h, 2w)$ in size), we always have 2 times the amount of information (in this case, from having two times the number of channels). Accordingly, we will always convolve using an atrous convolution (where we dilate the kernel, rather than inserting 0s in the input to the convolutional layer), followed by the actual upsampling operation (using bilinear upsampling).
class ConvLayer(nn.Module):
    """Convolution -> ReLU -> (optional) BatchNorm, applied in that order.

    Args:
        num_inputs: number of input channels.
        num_filters: number of output channels.
        bn: append a ``BatchNorm2d`` after the ReLU when True.
        kernel_size: side length of the square kernel.
        stride: convolution stride.
        padding: explicit padding; when None, ``(kernel_size - 1) // 2``.
        transpose: use ``ConvTranspose2d`` instead of ``Conv2d`` when truthy.
        dilation: kernel dilation, forwarded to whichever layer is built.
    """

    def __init__(self, num_inputs, num_filters, bn=True, kernel_size=3, stride=1,
                 padding=None, transpose=False, dilation=1):
        super(ConvLayer, self).__init__()
        if padding is None:
            # NOTE(review): ``transpose`` defaults to False, so this test is
            # true for every caller that does not pass None explicitly.  The
            # resulting padding is kept as-is because the plain convolutions
            # in this file depend on it; confirm whether transposed
            # convolutions were meant to get padding 0 instead.
            padding = (kernel_size-1)//2 if transpose is not None else 0
        if transpose:
            self.layer = nn.ConvTranspose2d(num_inputs, num_filters, kernel_size=kernel_size,
                                            stride=stride, padding=padding, dilation=dilation)
        else:
            # The two constructions must be exclusive (otherwise the plain
            # convolution always replaces the transposed one), and
            # ``dilation`` has to be forwarded here too or dilated callers
            # silently get an ordinary convolution.
            self.layer = nn.Conv2d(num_inputs, num_filters, kernel_size=kernel_size,
                                   stride=stride, padding=padding, dilation=dilation)
        nn.init.kaiming_uniform_(self.layer.weight, a=np.sqrt(5))
        self.bn_layer = nn.BatchNorm2d(num_filters) if bn else None

    def forward(self, x):
        """Apply the convolution, then ReLU, then batch norm if enabled."""
        out = self.layer(x)
        out = F.relu(out)
        return out if self.bn_layer is None else self.bn_layer(out)
class ConcatLayer(nn.Module):
    """Concatenate several tensors along one dimension (channels by default).

    Callers in this file pass a dict such as ``{0: encoder_out, 1: decoder_in}``
    so that both tensors travel through ``nn.Sequential`` as a single
    argument.  The dict's values are concatenated in insertion order; a plain
    list or tuple of tensors is accepted as well.
    """

    def forward(self, x, dim=1):
        tensors = list(x.values()) if isinstance(x, dict) else list(x)
        return torch.cat(tensors, dim=dim)
class LambdaLayer(nn.Module):
    """Module adapter that applies a stored callable to its input."""

    def __init__(self, f):
        super().__init__()
        # Kept as a plain attribute; the callable is invoked as-is in forward.
        self.f = f

    def forward(self, x):
        result = self.f(x)
        return result
def upconv2x2(inplanes, outplanes, size=None, stride=1):
    """Return ``[conv, upsample]``: a 2x2 ``ConvLayer`` then bilinear upsampling.

    Args:
        inplanes: input channels of the convolution.
        outplanes: output channels of the convolution.
        size: exact output spatial size for the upsampling; when None the
            resolution is doubled instead (``scale_factor=2``).
        stride: stride of the convolution.

    Returns:
        A two-element list of modules, ready to be unpacked into an
        ``nn.Sequential``.
    """
    conv = ConvLayer(inplanes, outplanes, kernel_size=2, dilation=2, stride=stride)
    if size is not None:
        upsample = nn.Upsample(size=size, mode='bilinear', align_corners=True)
    else:
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    return [conv, upsample]
Some specifics in how the decoder is coordinated (here, the first layer means the input encoding layer of the encoder, and the last layer indicates the last layer in the encoder). These details are not super important, and are probably understandable if you inspect the U-Net architecture image more closely.
- The first layer's output is passed along, concatenated, and fed through a `conv1x1` before upsampling, then fed through a regular `conv3x3` two times, then a `conv1x1` to output the right number of channels for the segmentation output
- The middle layers' outputs are all passed along, concatenated, and fed through a `conv1x1`, then a `conv3x3` that first halves the number of channels to upsample, then a regular `conv3x3`
- The last layer's output takes two pathways:
  - Going down in the figure, the output goes through: max-pool (2x2), conv3x3, conv3x3, upconv2x2. These operations are encompassed in the `DecoderConnect` class.
  - Going across, the output is passed across and concatenated to the result of the above step
Again, these details don't particularly matter, unless you're implementing the architecture yourself. The important point is that upsampling always happens after a concatenation of the encoder's output with the corresponding input to the corresponding level of the decoder.
class DecoderConnect(nn.Module):
    """Bottom of the "U": joins the last encoder level to the decoder.

    ``bottom_process`` takes the deepest encoder output one level further down
    and back up to ``output_size``; ``concat_process`` concatenates that
    result with the original encoder output and reduces it back to
    ``inplanes`` channels.

    Args:
        inplanes: channel count of the last encoder level's output.
        output_size: spatial size ``(h, w)`` of that output, i.e. the size the
            bottom pathway is upsampled back to.
    """

    def __init__(self, inplanes, output_size):
        super(DecoderConnect, self).__init__()
        # NOTE(review): the pooling layer and the leading ConcatLayer were
        # lost in the export.  They are restored from the accompanying
        # write-up ("max-pool (2x2), conv3x3, conv3x3, upconv2x2") and from
        # forward(), which hands concat_process a dict that only a
        # ConcatLayer can consume -- confirm against the original notebook.
        self.bottom_process = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            ConvLayer(inplanes, inplanes * 2, kernel_size=3),
            ConvLayer(inplanes * 2, inplanes * 2, kernel_size=3),
            *upconv2x2(inplanes * 2, inplanes, size=output_size)
        )
        self.concat_process = nn.Sequential(
            ConcatLayer(),
            ConvLayer(inplanes * 2, inplanes * 2, kernel_size=1),
            ConvLayer(inplanes * 2, inplanes, kernel_size=3),
            ConvLayer(inplanes, inplanes, kernel_size=3)
        )

    def forward(self, x):
        """Return the fused bottom-level features (``inplanes`` channels)."""
        decoder_input = self.bottom_process(x)
        return self.concat_process({0: x, 1: decoder_input})
The crux of constructing the decoder happens in the setup_decoder
function call below, and consequently in the construct_decoder
. The details are hard to extract from the code below, so we can break it down as follows (tracing the code in the setup_decoder
function first.
Getting Shapes Using Hooks
We're going to gather the input size and output size of a tensor to any layer in the ResNet encoder network with a name that has the prefix "layer". To do so, we'll use hooks. Specifically, a hook is a closure, i.e. a function that's passed as an argument when registering a hook for a specific layer in our network. You can see here that:
- `shape_hook` is the function passed when registering a hook
- We call `register_forward_hook` for any `child` layer of our network that has a name that `startswith` "layer", e.g. `self.layer0`, `self.layer1`, and so on.
Note the specification for shape_hook
, and generally for the function passed to register_forward_hook
- it will have access to the input and output of the layer we are calling register_forward_hook
for (note that input and output can be tuples here). In our case, we only care about their shapes, as we'll need the shape to determine the shape of the decoder's input, and accordingly the number of filters the convolutional layers need to output in the previous layer.
Accordingly, we'll take those shapes, and add them to our input_sizes
and output_sizes
array, to keep track of the input and output shapes as the network processes an input. To actually populate these arrays, we have to do exactly that - process an input. Thus, we'll make a dummy input (in the code, test_input
) that we pass through our encoder, and after it finishes processing that input, our input_sizes
and output_sizes
array will be populated!
Constructing the Decoder
Now that we have the input and output sizes of any tensors passing through the blocks of our ResNet encoder, we can construct our decoder level by level. To do so, we'll just look at the following things:
- How much we need to upsample the size of the image (determined by looking at the ratio of the input image size and the output image size)
- What the difference in channels between the input and output of the corresponding encoder level are (determined by looking at the ratio of channels between input and output)
Looking at both of these gives us a sense of the operation we need to do to reverse what the encoder did. Specifically, we can abide by the following assumptions when constructing the decoder:
- The shape of the input to the level of the decoder we're working on will be the same as the shape of the output of the corresponding level of the encoder
- The shape of the output of this level of the decoder will be the same shape as the shape of the input of the corresponding level of the encoder
With these assumptions in mind, and using the details above for constructing the operations for each level of the decoder, we can just use case work for actually constructing the decoder, depending on whether we're looking at the last layer of the encoder, one of the middle layers, or the first layer of the encoder.
Since we're starting from the inputs and outputs of the first layer of the encoder, we add on the constructed layers as we inspect the shapes of the inputs and outputs of the encoder, and then reverse the list of constructed layers when finalizing the decoder architecture, to ensure that we go from the last output shape of the encoder to the first input shape of the encoder, which is (generally) what we want to output for segmentation. (This doesn't necessarily have to be true, in which case a 1x1 convolution is added at the end of the decoder to get the right number of output channels, specified when constructing the class as `num_output_channels`.)
Note that we maintain the decoder as a list of modules, i.e. an nn.ModuleList
. This is an intentional choice, as we'll need to perform the operations of our network in sequence by level, as each level requires getting the corresponding output of the encoder, and processing it alongside the corresponding input of the decoder.
Model Forward Using Hooks
The last part of setting up our dynamic U-Net architecture is to specify the forward
function. In order to do so, we need to keep track of the outputs of each level of our encoder. Since we've encompassed the encoder as one module when constructing our U-Net, the easiest way to get the outputs for each level of the encoder is to just use hooks again.
The setup for these hooks is very similar to how we set up the shape hooks above, but instead, we only keep track of the outputs, and we want the actual output tensor, not the shape. This is encompassed in the encoder_output_hook
hook in the forward function below. Again, we register the hook for all layers in our encoder that have name starting with "layer".
To actually use these outputs, we only need to keep track of the corresponding input we are passing into the current level of the decoder. This becomes convenient to do since we left the decoder as an nn.ModuleList
, so we need only iterate over the encoder outputs and the corresponding layer of the decoder that they'll be passed into with the corresponding input to the decoder. This is encompassed in the following loop in the `forward` function:

    prev_output = None
    for reo, rdl in zip(reversed(encoder_outputs), self.decoder):
        if prev_output is not None:
            prev_output = rdl({0: reo, 1: prev_output})
        else:
            prev_output = rdl(reo)
Note how that for the first layer of the decoder (the one that ties with the last layer of the encoder), there's no previous output. This is because the first layer of the decoder has the additional pathway (seen at the bottom of the U-Net architecture figure) that is concatenated with the output of the last layer of the encoder. On the other hand, for all other layers, the encoder output (reo
) and the decoder input (prev_output
) are concatenated together in a single pathway (explicitly, via the ConcatLayer
forward function).
class DynamicUNet(nn.Module):
    """U-Net whose decoder is derived from the shapes an encoder produces.

    Every child of ``encoder`` whose name starts with "layer" is treated as
    one encoder level.  Forward hooks record each level's input/output shape
    on a dummy batch, and one decoder level is built per encoder level so
    that it maps the level's output shape back to its input shape.

    NOTE(review): hook bodies, encoder invocations, ``handle.remove()`` calls,
    ``else`` branches and the assembly of ``decoder_layers`` were missing from
    the exported source.  They are reconstructed here from the surviving code
    and the accompanying write-up; compare with the original notebook before
    relying on checkpoint compatibility.

    Args:
        encoder: module exposing its levels as children named ``layer*``.
        input_size: ``(height, width)`` of the images the model will see.
        num_output_channels: channels of the final prediction; when None the
            encoder's input channel count is used.
        verbose: values >= 1 print a message whenever hooks are removed.
    """

    def __init__(self, encoder, input_size=(224, 224), num_output_channels=None, verbose=0):
        super(DynamicUNet, self).__init__()
        self.encoder = encoder
        self.verbose = verbose
        self.input_size = input_size
        self.num_input_channels = 3  # This must be 3 because we're using a ResNet encoder
        self.num_output_channels = num_output_channels
        self.decoder = self.setup_decoder()

    def _run_encoder_with_hook(self, x, hook, removal_message):
        """Run the encoder on ``x`` with ``hook`` attached to every level."""
        handles = [
            child.register_forward_hook(hook) for name, child in self.encoder.named_children()
            if name.startswith('layer')
        ]
        try:
            self.encoder(x)
        finally:
            # Always detach the hooks, even if the forward pass raised, so
            # they cannot pile up across calls.
            if self.verbose >= 1:
                print(removal_message)
            for handle in handles:
                handle.remove()

    @staticmethod
    def _dry_run(ops, sample):
        """Push ``sample`` through ``ops`` as a construction-time shape check.

        Runs in eval mode without gradients so a batch of one cannot trip
        BatchNorm and no running statistics are updated with random data.
        """
        probe = nn.Sequential(*ops)
        probe.eval()
        with torch.no_grad():
            result = probe(sample)
        probe.train()
        return result

    def forward(self, x):
        """Encode ``x``, then decode level by level using skip connections."""
        encoder_outputs = []

        def encoder_output_hook(module, input, output):
            encoder_outputs.append(output)

        self._run_encoder_with_hook(x, encoder_output_hook, "Removing all forward handles")

        prev_output = None
        for reo, rdl in zip(reversed(encoder_outputs), self.decoder):
            if prev_output is not None:
                prev_output = rdl({0: reo, 1: prev_output})
            else:
                # Deepest level: DecoderConnect builds its own second input.
                prev_output = rdl(reo)
        return prev_output

    def setup_decoder(self):
        """Record per-level encoder shapes on a dummy batch, build the decoder."""
        input_sizes = []
        output_sizes = []

        def shape_hook(module, input, output):
            # ``input`` is the tuple of positional arguments of the level.
            input_sizes.append(input[0].shape)
            output_sizes.append(output.shape)

        test_input = torch.randn(1, self.num_input_channels, *self.input_size)
        # Probe in eval mode so random data does not overwrite the (possibly
        # pretrained) BatchNorm running statistics of the encoder.
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            with torch.no_grad():
                self._run_encoder_with_hook(test_input, shape_hook, "Removing all shape hook handles")
        finally:
            self.encoder.train(was_training)

        decoder = self.construct_decoder(input_sizes, output_sizes, num_output_channels=self.num_output_channels)
        return decoder

    def construct_decoder(self, input_sizes, output_sizes, num_output_channels=None):
        """Build one decoder level per encoder level, returned deepest first.

        Args:
            input_sizes: per-level encoder input shapes, shallowest first.
            output_sizes: per-level encoder output shapes, same order.
            num_output_channels: channels of the final prediction; falls back
                to ``self.num_output_channels`` and then to the channel count
                of the first level's input.
        """
        if num_output_channels is None:
            num_output_channels = self.num_output_channels

        decoder_layers = []
        for layer_index, (input_size, output_size) in enumerate(zip(input_sizes, output_sizes)):
            upsampling_size_factor = int(input_size[-1] / output_size[-1])
            upsampling_channel_factor = input_size[-3] / output_size[-3]
            needs_upconv = upsampling_size_factor > 1 or upsampling_channel_factor != 1
            next_layer = []

            if layer_index == len(input_sizes) - 1:
                # Deepest encoder level: bottom of the "U".
                last_layer_ops = DecoderConnect(output_size[-3], output_size[2:])
                last_layer_ops_input = torch.randn(*output_size)
                last_layer_concat_ops_output = self._dry_run([last_layer_ops], last_layer_ops_input)
                next_layer.append(last_layer_ops)
                if needs_upconv:
                    last_layer_concat_upconv_op = upconv2x2(output_size[-3], input_size[-3], size=input_size[2:])
                    self._dry_run(last_layer_concat_upconv_op, last_layer_concat_ops_output)
                    next_layer.extend(last_layer_concat_upconv_op)
            elif layer_index == 0:
                # Shallowest level: upsample to full resolution and project to
                # the requested number of output channels.
                # NOTE(review): the constructor of the final projection was
                # lost in the export; a 1x1 ConvLayer is assumed here.
                first_layer_concat_ops = [
                    ConcatLayer(),
                    ConvLayer(output_size[-3] * 2, output_size[-3] * 2, kernel_size=1),
                    *upconv2x2(
                        output_size[-3] * 2,
                        output_size[-3],
                        size=[dim * upsampling_size_factor for dim in output_size[2:]]
                    ),
                    ConvLayer(output_size[-3], output_size[-3], kernel_size=3),
                    ConvLayer(
                        output_size[-3],
                        input_size[-3] if num_output_channels is None else num_output_channels,
                        kernel_size=1
                    ),
                ]
                self._dry_run(
                    first_layer_concat_ops,
                    {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                )
                next_layer.extend(first_layer_concat_ops)
            else:
                middle_layer_concat_ops = [
                    ConcatLayer(),
                    ConvLayer(output_size[-3] * 2, output_size[-3] * 2, kernel_size=1),
                    ConvLayer(output_size[-3] * 2, output_size[-3], kernel_size=3),
                    ConvLayer(output_size[-3], output_size[-3], kernel_size=3)
                ]
                middle_layer_concat_ops_output = self._dry_run(
                    middle_layer_concat_ops,
                    {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                )
                next_layer.extend(middle_layer_concat_ops)
                if needs_upconv:
                    middle_layer_concat_upconv_op = upconv2x2(output_size[-3], input_size[-3], size=input_size[2:])
                    self._dry_run(middle_layer_concat_upconv_op, middle_layer_concat_ops_output)
                    next_layer.extend(middle_layer_concat_upconv_op)

            decoder_layers.append(nn.Sequential(*next_layer))

        # Levels were collected shallowest-first; decoding starts at the deepest.
        return nn.ModuleList(reversed(decoder_layers))
def load_camvid_dataset(data_directory):
    """Index a CamVid-style directory of images and label masks.

    Expected layout under ``data_directory``: ``images/*.png``, ``labels/``
    (one file per image, named like the image with an extra ``_<suffix>``
    before the extension), ``valid.txt`` (one validation image file name per
    line) and ``codes.txt`` (one class name per line).

    Returns:
        data: list of ``(image_path, label_path)`` tuples.  An image with no
            label file keeps a bare path string instead of a tuple.
        val_indices: positions in ``data`` of the images listed in valid.txt.
        label_mapping: dict from class name to its line index in codes.txt.

    Raises:
        KeyError: if a label file or a valid.txt entry names an image that is
            not present in ``images/``.
    """
    with open(os.path.join(data_directory, "valid.txt"), "r") as f:
        # Ignore blank lines (e.g. a trailing newline) rather than failing
        # the image lookup on an empty name.
        val_names = [line.strip() for line in f if line.strip()]
    with open(os.path.join(data_directory, "codes.txt"), "r") as f:
        label_mapping = {l.strip(): i for i, l in enumerate(f)}

    data = []
    image_index_mapping = {}
    for im_f in os.listdir(os.path.join(data_directory, "images")):
        if im_f.split('.')[-1] != 'png':
            continue
        image_index_mapping[im_f] = len(data)
        fp = os.path.join(data_directory, "images", im_f)
        data.append(fp)

    for label_f in os.listdir(os.path.join(data_directory, "labels")):
        # "<stem>_<suffix>.png" -> "<stem>.png": drop the last "_"-separated
        # piece of the first dot-separated component to get the image name.
        im_f = label_f.split('.')
        im_f[0] = '_'.join(im_f[0].split('_')[:-1])
        im_f = '.'.join(im_f)
        index = image_index_mapping[im_f]
        fp = os.path.join(data_directory, "labels", label_f)
        data[index] = (data[index], fp)

    val_indices = [image_index_mapping[name] for name in val_names]
    return data, val_indices, label_mapping
# Root of the local CamVid copy (expects images/, labels/, valid.txt, codes.txt).
camvid_data_directory = "/home/jupyter/data/camvid"
camvid_index = load_camvid_dataset(camvid_data_directory)
all_data, val_indices, label_mapping = camvid_index
We'll split the data for now into training and validation, based on the split specified in the dataset itself.
# Partition in a single pass.  Membership is tested against a set because
# ``val_indices`` is a list and ``i in list`` is linear per sample.
val_index_set = set(val_indices)
tr_data, val_data = [], []
for i, tpl in enumerate(all_data):
    (val_data if i in val_index_set else tr_data).append(tpl)
def display_segmentation(image, target, ax=None):
    """Draw ``image`` in grayscale with ``target`` overlaid at 50% opacity.

    Args:
        image: image to show underneath.
        target: segmentation mask drawn on top with the 'jet' colormap.
        ax: object with an ``imshow`` method (e.g. a matplotlib Axes) to draw
            on; when None the current pyplot figure is used.
    """
    # Draw on exactly one canvas: the given axes, or pyplot as the fallback.
    canvas = ax if ax is not None else plt
    canvas.imshow(image, cmap='gray')
    canvas.imshow(target, cmap='jet', alpha=0.5)
def display_segmentation_from_file(im_f, label_f):
    """Open an image file and its label file and show them overlaid.

    Args:
        im_f: path of the image file.
        label_f: path of the matching segmentation label file.
    """
    im, label = Image.open(im_f), Image.open(label_f)
    display_segmentation(im, label)
We can visualize some examples and their corresponding segmentations overlayed.
# Preview one training sample with its segmentation overlaid.
i = 10
sample_image_path, sample_label_path = tr_data[i][0], tr_data[i][1]
display_segmentation_from_file(sample_image_path, sample_label_path)
Here, we'll use a simple dataset class that transforms each image and its corresponding segmentation label by randomly (with probability 0.5) vertically flipping, randomly (with probability 0.5) horizontally flipping, and then normalizing the input image using the standard ImageNet normalization statistics.
We also specify a resize shape, defaulting to 360 by 480, where nearest neighbor interpolation is used for resizing.
class CamvidDataset(torch.utils.data.Dataset):
    """Image / label-mask pairs with resizing, random flips and normalization.

    Args:
        data: sequence of ``(image_path, label_path)`` tuples.
        resize_shape: ``(height, width)`` both image and mask are resized to;
            the mask uses nearest-neighbour interpolation so label values are
            not blended.
        is_train: when True each sample is flipped horizontally and/or
            vertically, each with probability 0.5.
    """

    def __init__(self, data, resize_shape=(360, 480), is_train=True):
        self.images = [tpl[0] for tpl in data]
        self.labels = [tpl[1] for tpl in data]
        self.resize_shape = resize_shape
        self.is_train = is_train

    def transform(self, index):
        """Load sample ``index``; return ``(normalized_image, label_tensor)``."""
        image, target = map(Image.open, (self.images[index], self.labels[index]))
        image = tf.resize(image, self.resize_shape)
        target = tf.resize(target, self.resize_shape, interpolation=Image.NEAREST)

        if self.is_train:
            # Both draws happen up front so the image and its mask always
            # receive the same flips.
            horizontal_draw = torch.rand(1).item()
            vertical_draw = torch.rand(1).item()
            if horizontal_draw > 0.5:
                image, target = tf.hflip(image), tf.hflip(target)
            if vertical_draw > 0.5:
                image, target = tf.vflip(image), tf.vflip(target)

        image, target = map(tf.to_tensor, (image, target))
        # The mask arrives scaled by 1/255 as floats.  Round before casting:
        # 255 * (k / 255) can land just below k in float32, and a bare
        # .long() would truncate it to k - 1.
        # NOTE(review): the upper clamp bound of 32 is kept from the original;
        # with 32 output channels the largest valid class index is 31.
        target = torch.clamp((255 * target).round(), 0, 32).long()
        return tf.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), target

    def __getitem__(self, index):
        return self.transform(index)

    def __len__(self):
        return len(self.images)
# Training data gets random flips; validation data does not.
tr_ds = CamvidDataset(tr_data, resize_shape=(360, 480))
val_ds = CamvidDataset(val_data, resize_shape=(360, 480), is_train=False)

tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=4, shuffle=True)
# NOTE(review): the validation loader's arguments were lost in the export;
# an unshuffled loader with the training batch size is assumed.
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=4)
One last thing that will be useful is the Dice loss, used commonly in segmentation. It is implemented below. We use the differentiable variant (i.e. the smooth variant). In the literature, both the Dice loss and cross-entropy are used, with comparable results. For the purpose of this blog post, we'll optimize directly for the Dice score.
In particular, the reason for using the Dice score is that it is just the F1 score, which is the harmonic mean of the precision $P$ and the recall $R$. Specifically, we can write it as:
$$F_1 = \frac{2}{\frac{1}{P} + \frac{1}{R}} = \frac{2PR}{P + R}$$
By using this loss, we optimize equally for precision and recall, where optimizing for cross-entropy is (theoretically) a bit more difficult due to background pixels being more likely to occur than specific class pixels.
class DiceLoss(nn.Module):
    """Smooth (differentiable) Dice loss, averaged over channels.

    Args:
        weights: optional per-channel weights multiplied into the loss
            (a scalar or anything broadcastable against a length-C tensor);
            None means every channel has weight 1.
        ignore_index: label value whose pixels are left out of the loss.
        eps: constant added to the denominator to avoid division by zero.
    """

    def __init__(self, weights=None, ignore_index=None, eps=0.0001):
        super(DiceLoss, self).__init__()
        self.weights = weights
        self.ignore_index = ignore_index
        self.eps = eps

    def forward(self, output, target):
        """
        output: (N, C, H, W) tensor of probabilities for the predicted output
        target: (N, H, W) tensor corresponding to the pixel-wise labels

        Returns
        -------
        loss: the Dice loss averaged over channels
        """
        # One-hot encode the labels along the channel dimension.
        encoded_target = torch.zeros_like(output)
        mask = None
        if self.ignore_index is not None:
            mask = target == self.ignore_index
            target = target.clone()
            # Give ignored pixels a valid index so scatter_ does not fail,
            # then blank them out again in the one-hot encoding.
            target[mask] = 0
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
            mask = mask.unsqueeze(1).expand_as(encoded_target)
            encoded_target[mask] = 0
        else:
            encoded_target.scatter_(1, target.unsqueeze(1), 1)

        # Resolve the default locally instead of overwriting self.weights.
        weights = 1 if self.weights is None else self.weights

        intersection = output * encoded_target
        numerator = 2 * intersection.sum(0).sum(1).sum(1)
        denominator = output + encoded_target
        if mask is not None:
            denominator[mask] = 0
        denominator = denominator.sum(0).sum(1).sum(1) + self.eps
        loss_per_channel = weights * (1 - (numerator / denominator))

        return loss_per_channel.sum() / output.size(1)
def dice_similarity(output, target, weights=None, ignore_index=None, eps=1e-8):
    """
    Dice similarity between the arg-max prediction and the labels.

    output: (N, C, H, W) tensor of model output
    target: (N, H, W) tensor corresponding to the pixel-wise labels
    weights: optional per-channel weights (None means 1 for every channel)
    ignore_index: label value whose pixels are excluded from the score
    eps: added to numerator and denominator so an absent class scores 1

    Returns
    -------
    score: the Dice similarity averaged over channels
    """
    # One-hot encode the hard prediction.
    prediction = torch.argmax(output, dim=1)
    encoded_prediction = torch.zeros_like(output)
    encoded_prediction.scatter_(1, prediction.unsqueeze(1), 1)

    # One-hot encode the labels, leaving ignored pixels all-zero.
    encoded_target = torch.zeros_like(output)
    mask = None
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    intersection = encoded_prediction * encoded_target
    numerator = 2 * intersection.sum(0).sum(1).sum(1) + eps
    # Dice = 2|P n T| / (|P| + |T|): the denominator counts predicted pixels
    # plus labelled pixels, not intersection plus labelled pixels.
    denominator = encoded_prediction + encoded_target
    if mask is not None:
        denominator[mask] = 0
    denominator = denominator.sum(0).sum(1).sum(1) + eps
    acc_per_channel = weights * ((numerator / denominator))

    return acc_per_channel.sum() / output.size(1)
To begin, we'll only train the decoder of our model, and freeze the weights of our encoder (since they are pre-trained weights). This is easily done by only specifying the parameters of our decoder in the optimizer settings.
model = DynamicUNet(resnet34(), num_output_channels=32, input_size=(360, 480))
if torch.cuda.is_available():
    model = model.cuda()

# Only the decoder's parameters are handed to the optimizer, so the encoder's
# weights receive no updates in this phase.
decoder_parameters = list(model.decoder.parameters())
optimizer = optim.AdamW(decoder_parameters)  # Only training the decoder for now
criterion = DiceLoss()

# Training specific parameters
num_epochs = 10
num_up_epochs, num_down_epochs = 3, 7
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-2,
    total_steps=num_epochs * len(tr_dl),
    # Fraction of the cycle spent ramping up (3 of 10 epochs, which is also
    # the scheduler's default of 0.3).
    pct_start=num_up_epochs / num_epochs,
)
We'll train for 10 epochs to fine-tune the decoder to be on par with the pre-trained weights being used in the ResNet encoder. We're using AdamW with a one-cycle learning-rate schedule that peaks at $10^{-2}$ (so the optimizer's default learning rate of $10^{-3}$ is overridden by the scheduler). We'll print the per-pixel accuracy, as well as the Dice similarity (which is the F1 score).
# Per-epoch history, plotted after training.
losses = []
accuracies = []

# Follow the model's device rather than branching on CUDA availability, so
# the batch is handled identically (labels squeezed to (N, H, W)) on CPU.
device = next(model.parameters()).device

tqdm_iterator = tqdm(range(num_epochs), position=0)
for epoch in tqdm_iterator:
    model.train()
    tr_loss, tr_correct_pixels, tr_total_pixels, tr_dice_similarity, total = 0., 0., 0., 0., 0.
    tqdm_epoch_iterator = tqdm(tr_dl, position=1, leave=False)
    for i, (x, y) in enumerate(tqdm_epoch_iterator):
        x, y = x.to(device), y.squeeze(dim=1).to(device)

        optimizer.zero_grad()
        output = model(x)
        probs = torch.softmax(output, dim=1)
        prediction = torch.argmax(output, dim=1)

        tr_correct_pixels += (prediction == y).sum().item()
        tr_total_pixels += y.numel()
        tr_dice_similarity += dice_similarity(output, y).item() * len(y)

        # DiceLoss is documented to take probabilities, so it is fed the
        # softmax output rather than the raw network output.
        loss = criterion(probs, y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        tr_loss += loss.item() * len(y)
        total += len(y)

        # Running averages over the samples seen so far in this epoch.
        tqdm_epoch_iterator.set_postfix({
            "Loss": tr_loss / total,
            "Accuracy": tr_correct_pixels / tr_total_pixels,
            "Dice": tr_dice_similarity / total,
        })

    overall_loss = tr_loss / total
    overall_acc = tr_correct_pixels / tr_total_pixels
    losses.append(overall_loss)
    accuracies.append(overall_acc)
    tqdm_iterator.set_postfix({"Loss": overall_loss, "Accuracy": overall_acc})
Now we can take a look at the losses over time and the accuracies over time.
# Show each curve in its own figure; without plt.show() in between, both
# would be drawn on the same axes under the second title.
plt.plot(list(range(len(losses))), losses)
plt.title("Losses over time for fine-tuning decoder")
plt.show()

plt.plot(list(range(len(accuracies))), accuracies)
plt.title("Accuracies over time for fine-tuning decoder")
plt.show()
Let's save a model checkpoint so that, if necessary, we can reset back here if our next step goes badly. This is a good practice in general.
import pickle

# Snapshot the weights after decoder-only training so the next phase can be
# rolled back to this point.
# NOTE(review): pickle.load executes arbitrary code from the file it reads;
# only load checkpoints written by this notebook.
output_fp = "/share/nikola/export/dt372/technical_blog/dynamic_unet/model_ckpts/decoder_finetuned_model_state_dict.pkl"
decoder_finetuned_model_state_dict = model.state_dict()

with open(output_fp, 'wb') as checkpoint_out:
    pickle.dump(decoder_finetuned_model_state_dict, checkpoint_out)

# Read the snapshot straight back to confirm it round-trips.
with open(output_fp, 'rb') as checkpoint_in:
    decoder_finetuned_model_state_dict = pickle.load(checkpoint_in)
Now, we'll fine-tune the whole network. We'll use a smaller learning rate for the encoder, and a larger, but still small learning rate for the decoder. Ideally, we'd do some sort of learning rate finding scheme (as done in FastAI), but for simplicity, we'll just pick reasonable numbers.
Now we'll train for 20 epochs, and see how the model performs over time.
# Two parameter groups: the decoder trains faster than the pretrained encoder.
# Group order matters: the per-group lists given to the scheduler below are
# matched to the groups by position (decoder first, encoder second).
optimizer = optim.AdamW([
    {'params': model.decoder.parameters(), 'lr': 1e-4},
    {'params': model.encoder.parameters(), 'lr': 1e-7},
])

num_epochs = 20
num_up_epochs, num_down_epochs = 4, 16
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=[1e-6, 1e-8], max_lr=[1e-3, 1e-6],
    step_size_up=num_up_epochs * len(tr_dl),
    step_size_down=num_down_epochs * len(tr_dl),
    cycle_momentum=False  # Need to set this since we're using AdamW?
)
# Follow the model's device rather than branching on CUDA availability, so
# the batch is handled identically (labels squeezed to (N, H, W)) on CPU.
device = next(model.parameters()).device

tqdm_iterator = tqdm(range(num_epochs), position=0)
for epoch in tqdm_iterator:
    model.train()
    tr_loss, tr_correct_pixels, tr_total_pixels, total = 0., 0., 0., 0.
    tqdm_epoch_iterator = tqdm(tr_dl, position=1, leave=False)
    for i, (x, y) in enumerate(tqdm_epoch_iterator):
        x, y = x.to(device), y.squeeze(dim=1).to(device)

        optimizer.zero_grad()
        output = model(x)
        prediction = torch.argmax(output, dim=1)

        tr_correct_pixels += (prediction == y).sum().item()
        tr_total_pixels += y.numel()

        # DiceLoss is documented to take probabilities, so normalize the raw
        # network output with a softmax first.
        loss = criterion(torch.softmax(output, dim=1), y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        tr_loss += loss.item() * len(y)
        total += len(y)

        # Running averages over the samples seen so far in this epoch.
        tqdm_epoch_iterator.set_postfix({
            "Loss": tr_loss / total,
            "Accuracy": tr_correct_pixels / tr_total_pixels,
        })

    overall_loss = tr_loss / total
    overall_acc = tr_correct_pixels / tr_total_pixels
    # Extend the history started during the decoder-only phase.
    losses.append(overall_loss)
    accuracies.append(overall_acc)
    tqdm_iterator.set_postfix({"Loss": overall_loss, "Accuracy": overall_acc})
Again, let's look at how the losses and accuracies change (after epoch 10 is where we started the above fine-tuning cycle).
# Show each curve in its own figure; without plt.show() in between, both
# would be drawn on the same axes under the second title.
plt.plot(list(range(len(losses))), losses)
plt.title("Losses over time for fine-tuning decoder")
plt.show()

plt.plot(list(range(len(accuracies))), accuracies)
plt.title("Accuracies over time for fine-tuning decoder")
plt.show()
Now, we can visualize some results on the validation dataset. First is the prediction, second is the ground truth.
index = 10
device = next(model.parameters()).device

# Inference only: switch BatchNorm to its running statistics and skip
# gradient bookkeeping.
model.eval()
with torch.no_grad():
    x, y = val_dl.dataset[index]
    x, y = x.unsqueeze(0).to(device), y.to(device)
    output = model(x)
prediction = torch.argmax(output, dim=1).cpu()

# First figure: the model's prediction over the (normalized) input image.
display_segmentation(tf.to_pil_image(x[0].cpu()), tf.to_pil_image(prediction.byte()))
plt.show()

# Second figure: the ground-truth mask over the original image file.
display_segmentation_from_file(val_dl.dataset.images[index], val_dl.dataset.labels[index])
plt.show()
Clearly, there's some room for improvement.