Overview

In this notebook, we'll take a more detailed look at how to set up a segmentation network dynamically from a given ResNet backbone. Specifically, we'll take advantage of PyTorch hooks to set up the decoder layers that output a segmentation, following the scheme shown in the U-Net paper (image shown below).

[Figure: U-Net architecture diagram from the U-Net paper]

The left half of the network will be referred to frequently as the encoder, and the right half of the network will be referred to frequently as the decoder. Briefly, the novel proposition when the U-Net paper was published was the idea of using skip connections (here, from the encoder to the decoder) to combat any loss of information when upsampling.

Concretely, skip connections concatenate the output of each encoder level with the input to the corresponding decoder level, so whenever an upsampling operation is performed, it has twice as many channels (and thus twice as much information) to work with.
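
As a minimal sketch of that concatenation (the tensor shapes here are hypothetical, chosen only for illustration), joining an encoder output and a decoder input along the channel dimension doubles the channel count while leaving the spatial size untouched:

```python
import torch

# Hypothetical shapes: batch of 1, 64 channels, 56x56 feature maps.
encoder_output = torch.randn(1, 64, 56, 56)  # arrives via the skip connection
decoder_input = torch.randn(1, 64, 56, 56)   # arrives from the decoder level below

# Concatenate along the channel dimension (dim=1 for NCHW tensors).
merged = torch.cat([encoder_output, decoder_input], dim=1)
print(merged.shape)  # torch.Size([1, 128, 56, 56])
```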

Encoder Setup

We'll first set up the code for the ResNet backbone (i.e. the encoder in the U-Net network, or the left-hand side).

#collapse_hide
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import tifffile as tiff
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.transforms.functional as tf
from tqdm.notebook import tqdm

The point of using the ResNet backbone is to leverage the pre-trained weights we can get from PyTorch's model zoo. The links are set up below, and will be used in our implementation for fetching the ResNet encoder backbone.

#collapse_hide
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

This code is taken from the official PyTorch implementations for ResNets. Some things are not ideal with how they initially trained these networks (based on recent empirical findings), but otherwise, these pre-trained weights are very useful as we now have a very good baseline to start with for segmentation of real images.

#collapse_hide
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

The nice thing about ResNets is the residual blocks - these form very clear levels for the encoder and decoder of our U-Net to interface with. This is made even more clear in the implementation of the ResNet network - the variables self.layer1 to self.layer4 comprise these levels, and self.layer0 comprises the initial input encoding for the ResNet (generally going from the standard 3 channels of input from the image to 64 channels, i.e. 64 filters).

The implementation is modified slightly from the original PyTorch version. Specifically:

  • We don't need the last FC layer, as we'll have something different for connecting to the decoder
  • We encompass the input encoding for the network (self.conv1 to self.bn1 to self.relu) in self.layer0 for clarity in the forward function

#collapse_show
class ResNetEncoder(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(ResNetEncoder, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer0 = nn.Sequential(self.conv1, self.bn1, self.relu)
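        # Build layer1 as max-pool followed by the residual blocks, keeping the blocks'
        # original names ("0", "1", ...). Wrapping them in a nested nn.Sequential would
        # rename layer1.0.* to layer1.1.0.*, and strict=False loading would then silently
        # skip the pre-trained layer1 weights.
        self.layer1 = nn.Sequential()
        self.layer1.add_module('maxpool', self.maxpool)
        for name, res_block in self._make_layer(block, 64, layers[0]).named_children():
            self.layer1.add_module(name, res_block)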
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

Now, we can load the pre-trained weights from the model_urls defined above. This is made easy using the utility API in PyTorch, model_zoo. Then, because we modified the ResNet implementation slightly but kept the parameter names the same, we can load the state dictionary non-strictly, so that every pre-trained parameter whose name is still present (e.g. those of self.conv1, self.bn1, and the residual layers) gets loaded, and as much of the pre-trained network as possible is reused.

These function definitions (shown below) are again taken from the official PyTorch implementation, but with two modifications: the weight loading is non-strict, and a pre-trained network is loaded by default.
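
To illustrate what non-strict loading does, here is a small sketch with two hypothetical toy modules (FullNet and TrimmedNet are made up for this example and are not part of the notebook). Parameters with matching names are copied, and the rest are reported rather than raising an error:

```python
import torch
import torch.nn as nn

class FullNet(nn.Module):
    """Stand-in for a pre-trained network that still has its FC head."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, bias=False)
        self.fc = nn.Linear(8, 10)

class TrimmedNet(nn.Module):
    """Same convolution, same parameter name, but no FC head."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, bias=False)

pretrained_state = FullNet().state_dict()
trimmed = TrimmedNet()

# strict=False copies every parameter whose name matches and skips the others.
result = trimmed.load_state_dict(pretrained_state, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['fc.weight', 'fc.bias']
```

Because strict=False skips mismatched names silently, it is worth inspecting missing_keys and unexpected_keys once to confirm that only the layers you intended to drop were skipped.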

#collapse_show
def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model


def resnet34(pretrained=True, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
    return model


def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model


def resnet101(pretrained=True, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
    return model


def resnet152(pretrained=True, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
    return model

Decoder Setup

Now that we have our setup for the pre-trained ResNet encoder, we need to automatically construct the Decoder using the architecture given in the encoder. To do so, we'll define some helper layers as nn.Modules.

  • ConvLayer is a general form of a convolution, ReLU, and batch normalization layer in sequence, with some empirical best practices (e.g. Kaiming uniform initialization of the convolution weights with the negative-slope parameter set to $a = \sqrt{5}$, as per the FastAI course).
  • ConcatLayer is a thin wrapper around the torch.cat function that concatenates all inputs along the channel dimension, assuming the inputs are image batches, i.e. they have shape (batch size, num channels, height, width).
  • LambdaLayer is a thin wrapper around a generic lambda function.
  • upconv2x2 is a utility function for setting up convolutions that upsample an image. As mentioned above, in the U-Net architecture, we first concatenate the encoder output with the corresponding decoder input, so that when we upsample an image (i.e. from $(h, w)$ in size to $(2h, 2w)$ in size), we always have 2 times the amount of information (in this case, from having two times the number of channels). Accordingly, we will always convolve using an atrous convolution (where we dilate the kernel, rather than inserting 0s in the input to the convolutional layer), followed by the actual upsampling operation (using bilinear upsampling).

class ConvLayer(nn.Module):
    def __init__(self, num_inputs, num_filters, bn=True, kernel_size=3, stride=1,
                 padding=None, transpose=False, dilation=1):
        super(ConvLayer, self).__init__()
        if padding is None:
            # Default padding keeps the spatial size for odd, undilated kernels
            padding = (kernel_size - 1) // 2
        if transpose:
            self.layer = nn.ConvTranspose2d(num_inputs, num_filters, kernel_size=kernel_size,
                                            stride=stride, padding=padding, dilation=dilation)
        else:
            # Pass dilation through so that upconv2x2 really performs an atrous convolution
            self.layer = nn.Conv2d(num_inputs, num_filters, kernel_size=kernel_size,
                                   stride=stride, padding=padding, dilation=dilation)
        # Kaiming uniform initialization with negative slope a = sqrt(5)
        nn.init.kaiming_uniform_(self.layer.weight, a=np.sqrt(5))
        self.bn_layer = nn.BatchNorm2d(num_filters) if bn else None

    def forward(self, x):
        # Order of operations: convolution -> ReLU -> (optional) batch normalization
        out = self.layer(x)
        out = F.relu(out)
        return out if self.bn_layer is None else self.bn_layer(out)
    
class ConcatLayer(nn.Module):
    def forward(self, x, dim=1):
        return torch.cat(list(x.values()), dim=dim)
    
class LambdaLayer(nn.Module):
    def __init__(self, f):
        super(LambdaLayer, self).__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)

def upconv2x2(inplanes, outplanes, size=None, stride=1):
    if size is not None:
        return [
            ConvLayer(inplanes, outplanes, kernel_size=2, dilation=2, stride=stride),
            nn.Upsample(size=size, mode='bilinear', align_corners=True)
        ] 
    else:
        return [
            ConvLayer(inplanes, outplanes, kernel_size=2, dilation=2, stride=stride),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        ]
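
As a rough sketch of the shape arithmetic behind upconv2x2 (using plain nn.Conv2d and nn.Upsample rather than the ConvLayer wrapper, with hypothetical channel counts and sizes): a 2x2 kernel with dilation 2 spans a 3x3 window, so without padding the feature map shrinks by 2 pixels per side length, and the bilinear upsampling then resizes it to the requested target size.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 128, 28, 28)  # hypothetical concatenated features

# Atrous (dilated) convolution: 2x2 kernel, dilation 2, no padding.
atrous = nn.Conv2d(128, 64, kernel_size=2, dilation=2)
# Bilinear upsampling to an explicit target size.
upsample = nn.Upsample(size=(56, 56), mode='bilinear', align_corners=True)

with torch.no_grad():
    y = atrous(x)
    z = upsample(y)

print(y.shape)  # torch.Size([1, 64, 26, 26])
print(z.shape)  # torch.Size([1, 64, 56, 56])
```

Passing an explicit size (rather than scale_factor=2) is what absorbs the small shrinkage from the unpadded convolution.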

Some specifics on how the decoder is coordinated follow (here, the first layer means the input encoding layer of the encoder, and the last layer means the last layer in the encoder). These details are not super important, and are probably understandable if you inspect the U-Net architecture image more closely.

  • The first layer's output is passed along, concatenated, fed through a conv3x3 before upsampling, then fed through a regular conv3x3 two times, then a conv1x1 to output the right number of channels for the segmentation output
  • The middle layers' outputs are all passed along, concatenated, and fed through a conv3x3 that first halves the number of channels to upsample, then a regular conv3x3
  • The last layer's output takes two pathways:
    • Going down in the figure, the output goes through: max-pool (2x2), conv3x3, conv3x3, upconv2x2. These operations are encompassed in the DecoderConnect class.
    • Going across, the output is passed across and concatenated to the result of the above step

Again, these details don't particularly matter, unless you're implementing the architecture yourself. The important point is that upsampling always happens after concatenating the encoder's output with the input to the corresponding level of the decoder.

#collapse_show
class DecoderConnect(nn.Module):
    def __init__(self, inplanes, output_size):
        super(DecoderConnect, self).__init__()
        self.bottom_process = nn.Sequential(
            ConvLayer(inplanes, inplanes * 2, kernel_size=3),
            ConvLayer(inplanes * 2, inplanes * 2, kernel_size=3),
            *upconv2x2(inplanes * 2, inplanes, size=output_size)
        )
        self.concat_process = nn.Sequential(
            ConcatLayer(),
            ConvLayer(inplanes * 2, inplanes * 2, kernel_size=1),
            ConvLayer(inplanes * 2, inplanes, kernel_size=3),
            ConvLayer(inplanes, inplanes, kernel_size=3)
        )
        
    def forward(self, x):
        decoder_input = self.bottom_process(x)
        return self.concat_process({0: x, 1: decoder_input})

The crux of constructing the decoder happens in the setup_decoder function call below, and consequently in construct_decoder. The details are hard to extract from the code below, so we can break it down as follows (tracing the code in the setup_decoder function first).

Getting Shapes Using Hooks

We're going to gather the input and output sizes of the tensors passing through every layer in the ResNet encoder network whose name has the prefix "layer". To do so, we'll use hooks. Specifically, a hook is a callback function that we register on a specific layer of our network, and that PyTorch then calls every time that layer finishes its forward pass. You can see here that

  • shape_hook is the function passed when registering a hook
  • We call register_forward_hook on every child layer of our network whose name starts with "layer", e.g. self.layer0, self.layer1, and so on.

Note the signature of shape_hook, and generally of any function passed to register_forward_hook - it receives the module itself, along with the input and output of the layer the hook is registered on (note that the input is a tuple of the layer's positional arguments, which is why the code indexes input[0]). In our case, we only care about their shapes, as we'll need the shape to determine the shape of the decoder's input, and accordingly the number of filters the convolutional layers need to output in the previous layer.

Accordingly, we'll take those shapes and append them to our input_sizes and output_sizes lists (which shape_hook captures as a closure), to keep track of the input and output shapes as the network processes an input. To actually populate these lists, we have to do exactly that - process an input. Thus, we'll make a dummy input (in the code, test_input) that we pass through our encoder, and after it finishes processing that input, our input_sizes and output_sizes lists will be populated!
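
The same pattern can be sketched on a toy network. TinyEncoder below is a hypothetical two-layer stand-in for the ResNet encoder, used only so that the example runs on its own:

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the ResNet encoder."""
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.head = nn.Identity()  # not hooked: its name lacks the "layer" prefix

    def forward(self, x):
        return self.head(self.layer1(self.layer0(x)))

encoder = TinyEncoder()
input_sizes, output_sizes = [], []

def shape_hook(module, inputs, output):
    # inputs is a tuple of the positional arguments passed to the layer
    input_sizes.append(tuple(inputs[0].shape))
    output_sizes.append(tuple(output.shape))

handles = [
    child.register_forward_hook(shape_hook)
    for name, child in encoder.named_children()
    if name.startswith('layer')
]
try:
    with torch.no_grad():
        encoder(torch.randn(1, 3, 32, 32))  # dummy input, only used to trigger the hooks
finally:
    for handle in handles:
        handle.remove()  # always detach the hooks, even if the forward pass fails

print(input_sizes)   # [(1, 3, 32, 32), (1, 8, 16, 16)]
print(output_sizes)  # [(1, 8, 16, 16), (1, 16, 8, 8)]
```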

Constructing the Decoder

Now that we have the input and output sizes of any tensors passing through the blocks of our ResNet encoder, we can construct our decoder level by level. To do so, we'll just look at the following things:

  • How much we need to upsample the size of the image (determined by looking at the ratio of the input image size and the output image size)
  • What the difference in channels between the input and output of the corresponding encoder level is (determined by looking at the ratio of channels between input and output)

Looking at both of these gives us a sense of the operation we need to do to reverse what the encoder did. Specifically, we can abide by the following assumptions when constructing the decoder:

  • The shape of the input to the level of the decoder we're working on will be the same as the shape of the output of the corresponding level of the encoder
  • The shape of the output of this level of the decoder will be the same shape as the shape of the input of the corresponding level of the encoder

With these assumptions in mind, and using the details above for constructing the operations for each level of the decoder, we can just use case work for actually constructing the decoder, depending on whether we're looking at the last layer of the encoder, one of the middle layers, or the first layer of the encoder.
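
A small sketch of that case work, in plain Python. The shapes below are an assumption: they are worked out by hand from the strides of a ResNet-18 encoder on a 224x224 input, not captured from a real run, and the plan list is a hypothetical name for illustration.

```python
# (batch, channels, height, width) at the input and output of layer0 .. layer4
input_sizes = [(1, 3, 224, 224), (1, 64, 112, 112), (1, 64, 56, 56),
               (1, 128, 28, 28), (1, 256, 14, 14)]
output_sizes = [(1, 64, 112, 112), (1, 64, 56, 56), (1, 128, 28, 28),
                (1, 256, 14, 14), (1, 512, 7, 7)]

plan = []
last_index = len(input_sizes) - 1
for index, (in_size, out_size) in enumerate(zip(input_sizes, output_sizes)):
    size_factor = in_size[-1] // out_size[-1]    # how much the decoder must upsample
    channel_factor = in_size[-3] / out_size[-3]  # how the channel count must change
    if index == last_index:
        role = "last"
    elif index == 0:
        role = "first"
    else:
        role = "middle"
    plan.append((role, size_factor, channel_factor))

# The decoder runs in the opposite order, from the deepest level back to the input.
for entry in reversed(plan):
    print(entry)
```

Every level here needs a 2x spatial upsampling, while the channel factor tells the decoder level how many filters its upsampling convolution should output.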

Since we're starting from the inputs and outputs of the first layer of the encoder, we add on the constructed layers as we inspect the shapes of the inputs and outputs of the encoder, and then reverse the list of constructed layers when finalizing the decoder architecture, to ensure that we go from the last output shape of the encoder to the first input shape of the encoder, which is (generally) what we want to output for segmentation. (This doesn't necessarily have to be true, in which case a 1x1 convolution at the end of the decoder produces the right number of output channels, specified when constructing the class as num_output_channels.)

Note that we maintain the decoder as a list of modules, i.e. an nn.ModuleList. This is an intentional choice, as we'll need to perform the operations of our network in sequence by level, as each level requires getting the corresponding output of the encoder, and processing it alongside the corresponding input of the decoder.

Model Forward Using Hooks

The last part of setting up our dynamic U-Net architecture is to specify the forward function. In order to do so, we need to keep track of the outputs of each level of our encoder. Since we've encompassed the encoder as one module when constructing our U-Net, the easiest way to get the outputs for each level of the encoder is to just use hooks again.

The setup for these hooks is very similar to how we set up the shape hooks above, but instead, we only keep track of the outputs, and we want the actual output tensor, not the shape. This is encompassed in the encoder_output_hook hook in the forward function below. Again, we register the hook for all layers in our encoder that have name starting with "layer".

To actually use these outputs, we only need to keep track of the corresponding input we are passing into the current level of the decoder. This becomes convenient to do since we left the decoder as an nn.ModuleList, so we need only iterate over the encoder outputs and the corresponding layer of the decoder that they'll be passed into with the corresponding input to the decoder. This is encompassed in the following loop in the forward function:

prev_output = None
for reo, rdl in zip(reversed(encoder_outputs), self.decoder):
    if prev_output is not None:
        prev_output = rdl({0: reo, 1: prev_output})
    else:
        prev_output = rdl(reo)

Note that for the first layer of the decoder (the one that ties with the last layer of the encoder), there's no previous output. This is because the first layer of the decoder has the additional pathway (seen at the bottom of the U-Net architecture figure) that is concatenated with the output of the last layer of the encoder. On the other hand, for all other layers, the encoder output (reo) and the decoder input (prev_output) are concatenated together in a single pathway (explicitly, via the ConcatLayer forward function).
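
To see the loop in isolation, here is a toy sketch with made-up encoder outputs and a hypothetical two-level decoder (the Concat class mirrors ConcatLayer; none of the sizes come from the real network):

```python
import torch
import torch.nn as nn

class Concat(nn.Module):
    """Concatenates the values of a dict of tensors along the channel dimension."""
    def forward(self, x):
        return torch.cat(list(x.values()), dim=1)

# Pretend encoder outputs, ordered from the shallowest level to the deepest.
encoder_outputs = [torch.randn(1, 8, 16, 16), torch.randn(1, 16, 8, 8)]

decoder = nn.ModuleList([
    # Deepest level: consumes only the encoder output, then upsamples 8x8 -> 16x16.
    nn.Sequential(
        nn.Conv2d(16, 8, kernel_size=1),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
    ),
    # Next level: concatenates 8 + 8 channels, reduces them, upsamples 16x16 -> 32x32.
    nn.Sequential(
        Concat(),
        nn.Conv2d(16, 4, kernel_size=1),
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
    ),
])

prev_output = None
with torch.no_grad():
    for reo, rdl in zip(reversed(encoder_outputs), decoder):
        if prev_output is not None:
            prev_output = rdl({0: reo, 1: prev_output})
        else:
            prev_output = rdl(reo)

print(prev_output.shape)  # torch.Size([1, 4, 32, 32])
```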

#collapse_show
class DynamicUNet(nn.Module):
    def __init__(self, encoder, input_size=(224, 224), num_output_channels=None, verbose=0):
        super(DynamicUNet, self).__init__()
        self.encoder = encoder
        self.verbose = verbose
        self.input_size = input_size
        self.num_input_channels = 3  # This must be 3 because we're using a ResNet encoder
        self.num_output_channels = num_output_channels
        
        self.decoder = self.setup_decoder()
        
    def forward(self, x):
        encoder_outputs = []
        def encoder_output_hook(self, input, output):
            encoder_outputs.append(output)

        handles = [
            child.register_forward_hook(encoder_output_hook) for name, child in self.encoder.named_children()
            if name.startswith('layer')
        ]

        try:
            self.encoder(x)
        finally:
            if self.verbose >= 1:
                print("Removing all forward handles")
            for handle in handles:
                handle.remove()

        prev_output = None
        for reo, rdl in zip(reversed(encoder_outputs), self.decoder):
            if prev_output is not None:
                prev_output = rdl({0: reo, 1: prev_output})
            else:
                prev_output = rdl(reo)
        return prev_output
                
    def setup_decoder(self):
        input_sizes = []
        output_sizes = []
        def shape_hook(self, input, output):
            input_sizes.append(input[0].shape)
            output_sizes.append(output.shape)

        handles = [
            child.register_forward_hook(shape_hook) for name, child in self.encoder.named_children()
            if name.startswith('layer')
        ]    

        self.encoder.eval()
        test_input = torch.randn(1, self.num_input_channels, *self.input_size)
        try:
            self.encoder(test_input)
        finally:
            if self.verbose >= 1:
                print("Removing all shape hook handles")
            for handle in handles:
                handle.remove()
        decoder = self.construct_decoder(input_sizes, output_sizes, num_output_channels=self.num_output_channels)
        return decoder
        
    def construct_decoder(self, input_sizes, output_sizes, num_output_channels=None):
        decoder_layers = []
        for layer_index, (input_size, output_size) in enumerate(zip(input_sizes, output_sizes)):
            upsampling_size_factor = int(input_size[-1] / output_size[-1])
            upsampling_channel_factor = input_size[-3] / output_size[-3]
            next_layer = []
            bs, c, h, w = input_size
            ops = []
            if layer_index == len(input_sizes) - 1:
                last_layer_ops = DecoderConnect(output_size[-3], output_size[2:])
                last_layer_ops_input = torch.randn(*output_size)
                last_layer_concat_ops_output = last_layer_ops(last_layer_ops_input)
                next_layer.extend([last_layer_ops])
                if upsampling_size_factor > 1 or upsampling_channel_factor != 1:
                    last_layer_concat_upconv_op = upconv2x2(output_size[-3], input_size[-3], size=input_size[2:])
                    last_layer_concat_upconv_op_output = nn.Sequential(*last_layer_concat_upconv_op)(
                        last_layer_concat_ops_output
                    )
                    next_layer.extend(last_layer_concat_upconv_op)
            elif layer_index == 0:
                first_layer_concat_ops = [
                    ConcatLayer(),
                    ConvLayer(output_size[-3] * 2, output_size[-3] * 2, kernel_size=1),
                    *upconv2x2(
                        output_size[-3] * 2,
                        output_size[-3],
                        size=[dim * upsampling_size_factor for dim in output_size[2:]]
                    ),
                    ConvLayer(output_size[-3], output_size[-3], kernel_size=3),
                    ConvLayer(
                        output_size[-3],
                        input_size[-3] if self.num_output_channels is None else self.num_output_channels,
                        kernel_size=1
                    ),
                ]
                first_layer_concat_ops_output = nn.Sequential(*first_layer_concat_ops)(
                    {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                )
                next_layer.extend(first_layer_concat_ops)
            else:
                middle_layer_concat_ops = [
                    ConcatLayer(),
                    ConvLayer(output_size[-3] * 2, output_size[-3] * 2, kernel_size=1),
                    ConvLayer(output_size[-3] * 2, output_size[-3], kernel_size=3),
                    ConvLayer(output_size[-3], output_size[-3], kernel_size=3)
                ]
                middle_layer_concat_ops_output = nn.Sequential(*middle_layer_concat_ops)(
                    {0: torch.randn(*output_size), 1: torch.randn(*output_size)}
                )
                next_layer.extend(middle_layer_concat_ops)
                if upsampling_size_factor > 1 or upsampling_channel_factor != 1:
                    middle_layer_concat_upconv_op = upconv2x2(output_size[-3], input_size[-3], size=input_size[2:])
                    middle_layer_concat_upconv_op_output = nn.Sequential(*middle_layer_concat_upconv_op)(
                        middle_layer_concat_ops_output
                    )
                    next_layer.extend(middle_layer_concat_upconv_op)
            decoder_layers.append(nn.Sequential(*next_layer))
        return nn.ModuleList(reversed(decoder_layers))

Testing on the CamVid Dataset

For this example, we'll use the CamVid dataset, since it consists of natural (3 channel) images, which is what our ResNet encoder expects.

#collapse_hide
def load_camvid_dataset(data_directory):
    with open(os.path.join(data_directory, "valid.txt"), "r") as f:
        val_names = [line.strip() for line in f]
    with open(os.path.join(data_directory, "codes.txt"), "r") as f:
        label_mapping = {l.strip(): i for i, l in enumerate(f)}
    data = []
    image_index_mapping = {}
    for im_f in os.listdir(os.path.join(data_directory, "images")):
        if im_f.split('.')[-1] != 'png':
            continue
        image_index_mapping[im_f] = len(data)
        fp = os.path.join(data_directory, "images", im_f)
        data.append(fp)
    for label_f in os.listdir(os.path.join(data_directory, "labels")):
        im_f = label_f.split('.')
        im_f[0] = '_'.join(im_f[0].split('_')[:-1])
        im_f = '.'.join(im_f)
        index = image_index_mapping[im_f]
        fp = os.path.join(data_directory, "labels", label_f)
        data[index] = (data[index], fp)
    val_indices = [image_index_mapping[name] for name in val_names]
    return data, val_indices, label_mapping

#collapse_hide
camvid_data_directory = "/home/jupyter/data/camvid"
all_data, val_indices, label_mapping = load_camvid_dataset(camvid_data_directory);

We'll split the data for now into training and validation, based on the split specified in the dataset itself.

#collapse_show
tr_data, val_data = [tpl for i, tpl in enumerate(all_data) if i not in val_indices], \
                    [tpl for i, tpl in enumerate(all_data) if i in val_indices]

#collapse_hide
def display_segmentation(image, target, ax=None):
    if ax:
        ax.imshow(image, cmap='gray')
    else:
        plt.imshow(image, cmap='gray')
    if ax:
        ax.imshow(target, cmap='jet', alpha=0.5)
    else:
        plt.imshow(target, cmap='jet', alpha=0.5)
    plt.show()
        
def display_segmentation_from_file(im_f, label_f):
    im, label = Image.open(im_f), Image.open(label_f)
    display_segmentation(im, label)

We can visualize some examples with their corresponding segmentations overlaid.

i = 10
display_segmentation_from_file(tr_data[i][0], tr_data[i][1])

Here, we'll use a simple dataset class that transforms each image and its corresponding segmentation label by randomly (with probability 0.5) horizontally flipping, randomly (with probability 0.5) vertically flipping, and then normalizing the input image using the standard ImageNet normalization statistics.

We also specify a resize shape, defaulting to 360 by 480. Nearest neighbor interpolation is used when resizing the labels, so that resizing never blends two class indices into a new value.
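
The key point of the augmentation is that the image and its label must receive the same flips. A minimal sketch on plain tensors (paired_random_flip is a hypothetical helper written for this illustration, not part of the dataset class):

```python
import torch

def paired_random_flip(image, mask, p=0.5):
    """Apply the same random horizontal and vertical flips to an image and its mask."""
    if torch.rand(1).item() < p:  # horizontal flip: reverse the width axis
        image, mask = torch.flip(image, dims=[-1]), torch.flip(mask, dims=[-1])
    if torch.rand(1).item() < p:  # vertical flip: reverse the height axis
        image, mask = torch.flip(image, dims=[-2]), torch.flip(mask, dims=[-2])
    return image, mask

# Toy (C, H, W) image and a mask holding the same values, so alignment is easy to check.
image = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)
mask = torch.arange(12).reshape(1, 3, 4)

flipped_image, flipped_mask = paired_random_flip(image, mask)
print(torch.equal(flipped_image.long(), flipped_mask))  # True: both received the same flips
```

Drawing each random number once and applying the result to both tensors is what keeps every pixel paired with its label.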

#collapse_show
class CamvidDataset(torch.utils.data.Dataset):
    def __init__(self, data, resize_shape=(360, 480), is_train=True):
        self.images, self.labels = [tpl[0] for tpl in data], \
                                   [tpl[1] for tpl in data]
        self.resize_shape = resize_shape
        self.is_train = is_train

    def transform(self, index):
        input, target = map(
            Image.open, (self.images[index], self.labels[index]))
        input, target = (
            tf.resize(input, self.resize_shape),
            tf.resize(target, self.resize_shape, interpolation=Image.NEAREST)
        )
        if self.is_train:
            horizontal_draw = torch.rand(1).item()
            vertical_draw = torch.rand(1).item()
            if horizontal_draw > 0.5:
                input, target = tf.hflip(input), tf.hflip(target)
            if vertical_draw > 0.5:
                input, target = tf.vflip(input), tf.vflip(target)
        
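        input, target = map(tf.to_tensor, (input, target))
        # to_tensor rescales the label image to [0, 1]; scale back to class indices and
        # round before casting so float error cannot shift an index down by one
        target = torch.clamp((255 * target).round(), 0, 32)
        return tf.normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), target.long()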

    def __getitem__(self, index):
        return self.transform(index)

    def __len__(self):
        return len(self.images)

tr_ds, val_ds = CamvidDataset(tr_data, resize_shape=(360, 480)),\
                CamvidDataset(val_data, resize_shape=(360, 480), is_train=False)
tr_dl, val_dl = torch.utils.data.DataLoader(tr_ds, batch_size=4, shuffle=True), torch.utils.data.DataLoader(val_ds)

One last thing that will be useful is the Dice loss, used commonly in segmentation. It is implemented below. We use the differentiable variant (i.e. the smooth variant). In the literature, both the Dice loss and cross-entropy are used, with comparable results. For the purpose of this blog post, we'll optimize directly for the Dice score.

In particular, the reason for using the Dice score is that it is just the F1 score, which is the harmonic mean of the precision $P$ and the recall $R$. Specifically, we can write it as:

$$F_1 = \frac{2}{\frac{1}{P} + \frac{1}{R}} = \frac{2PR}{P + R}$$

By using this loss, we optimize equally for precision and recall, whereas optimizing for cross-entropy is (theoretically) a bit more difficult because background pixels are more likely to occur than pixels of any specific class.
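
As a quick sanity check of the claim that the Dice score equals the F1 score, here is a worked example in plain Python on a made-up binary prediction and target (the two lists are arbitrary illustration data):

```python
# Hypothetical flattened binary masks (1 = foreground, 0 = background).
pred = [1, 1, 1, 0, 0, 0, 1, 0]
target = [1, 1, 0, 0, 0, 1, 1, 1]

tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)  # true positives
fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)  # false positives
fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)  # false negatives

precision = tp / (tp + fp)
recall = tp / (tp + fn)

# F1 as the harmonic mean of precision and recall.
f1 = 2 / (1 / precision + 1 / recall)

# Dice: twice the overlap divided by the total size of both masks.
dice = 2 * tp / (sum(pred) + sum(target))

print(precision, recall)  # 0.75 0.6
print(abs(f1 - dice) < 1e-12)  # True
```

The two agree because sum(pred) = TP + FP and sum(target) = TP + FN, so both expressions reduce to 2TP / (2TP + FP + FN).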

class DiceLoss(nn.Module):
    """
    Module to compute the Dice segmentation loss. Based on the following discussion:
    https://discuss.pytorch.org/t/one-hot-encoding-with-autograd-dice-loss/9781
    """
    def __init__(self, weights=None, ignore_index=None, eps=0.0001):
        super(DiceLoss, self).__init__()
        self.weights = weights
        self.ignore_index = ignore_index
        self.eps = eps
        
    def forward(self, output, target):
        """
        Arguments:
            output: (N, C, H, W) tensor of probabilities for the predicted output
            target: (N, H, W) tensor corresponding to the pixel-wise labels
        Returns:
            loss: the Dice loss averaged over channels
        """ 
        encoded_target = output.detach() * 0
        if self.ignore_index is not None:
            mask = target == self.ignore_index
            target = target.clone()
            target[mask] = 0
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
            mask = mask.unsqueeze(1).expand_as(encoded_target)
            encoded_target[mask] = 0
        else:
            encoded_target.scatter_(1, target.unsqueeze(1), 1)

        if self.weights is None:
            self.weights = 1

        intersection = output * encoded_target
        numerator = 2 * intersection.sum(0).sum(1).sum(1)
        denominator = output + encoded_target

        if self.ignore_index is not None:
            denominator[mask] = 0
        denominator = denominator.sum(0).sum(1).sum(1) + self.eps
        loss_per_channel = self.weights * (1 - (numerator / denominator))

        return loss_per_channel.sum() / output.size(1)

def dice_similarity(output, target, weights=None, ignore_index=None, eps=1e-8):
    """
    Arguments:
        output: (N, C, H, W) tensor of model output
        target: (N, H, W) tensor corresponding to the pixel-wise labels
    Returns:
        similarity: the Dice similarity averaged over channels
    """ 
    prediction = torch.argmax(output, dim=1)
    encoded_prediction = output.detach() * 0
    encoded_prediction.scatter_(1, prediction.unsqueeze(1), 1)
    
    encoded_target = output.detach() * 0
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    intersection = encoded_prediction * encoded_target
    numerator = 2 * intersection.sum(0).sum(1).sum(1) + eps
    denominator = encoded_prediction + encoded_target  # |prediction| + |target|
    if ignore_index is not None:
        denominator[mask] = 0
    denominator = denominator.sum(0).sum(1).sum(1) + eps
    acc_per_channel = weights * ((numerator / denominator))

    return acc_per_channel.sum() / output.size(1)
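
The difference between the two functions above is that the loss works on probabilities while the metric works on hard (argmax) predictions. The following pure-Python sketch shows why that matters for a single class. The function names, the 0.5 threshold (standing in for the argmax) and the toy values are assumptions for illustration, not the notebook's implementation.

```python
def soft_dice(probs, target, eps=1e-4):
    """Dice overlap computed from probabilities; it changes smoothly with probs."""
    intersection = sum(p * t for p, t in zip(probs, target))
    return 2 * intersection / (sum(probs) + sum(target) + eps)

def hard_dice(probs, target, eps=1e-8):
    """Dice overlap after thresholding at 0.5; it only changes when a pixel flips."""
    hard = [1.0 if p > 0.5 else 0.0 for p in probs]
    return soft_dice(hard, target, eps)

# Hypothetical per-pixel probabilities for one class on four pixels
toy_target = [1, 1, 0, 0]
confident = [0.9, 0.8, 0.1, 0.2]
hesitant = [0.6, 0.6, 0.4, 0.4]

soft_confident = soft_dice(confident, toy_target)
soft_hesitant = soft_dice(hesitant, toy_target)
hard_confident = hard_dice(confident, toy_target)
hard_hesitant = hard_dice(hesitant, toy_target)
print(soft_confident, soft_hesitant, hard_confident, hard_hesitant)
```

Both sets of probabilities give the same hard prediction, so the hard score cannot tell them apart, whereas the soft score rewards the more confident one. That smooth dependence on the probabilities is what gives the loss a useful gradient.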

To begin, we'll only train the decoder of our model, and freeze the weights of our encoder (since they are pre-trained weights). This is easily done by only specifying the parameters of our decoder in the optimizer settings.
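
As a small, self-contained illustration of this idea, the sketch below uses a hypothetical two-layer stand-in (`toy`) rather than the real `DynamicUNet`. It passes only the "decoder" parameters to the optimizer and checks that the "encoder" weights stay untouched after a step. Setting `requires_grad_(False)` on the encoder is an optional extra, not something the post does, which also avoids computing encoder gradients during the backward pass.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in for the U-Net: a tiny "encoder" and "decoder"
toy = nn.ModuleDict({
    "encoder": nn.Conv2d(3, 4, kernel_size=3, padding=1),
    "decoder": nn.Conv2d(4, 2, kernel_size=1),
})

# Only the decoder parameters are handed to the optimizer
toy_optimizer = optim.AdamW(toy["decoder"].parameters(), lr=1e-3)

# Optional (an assumption, not in the original post): skip encoder gradients entirely
for p in toy["encoder"].parameters():
    p.requires_grad_(False)

encoder_before = toy["encoder"].weight.clone()
decoder_before = toy["decoder"].weight.clone()

# One dummy optimization step on random data
toy_x = torch.randn(2, 3, 8, 8)
toy_loss = toy["decoder"](toy["encoder"](toy_x)).pow(2).mean()
toy_loss.backward()
toy_optimizer.step()

encoder_unchanged = torch.equal(encoder_before, toy["encoder"].weight)
decoder_changed = not torch.equal(decoder_before, toy["decoder"].weight)
print(encoder_unchanged, decoder_changed)
```

One caveat: leaving parameters out of the optimizer does not freeze batch-norm running statistics, which still update whenever the module is in training mode.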

model = DynamicUNet(resnet34(), num_output_channels=32, input_size=(360, 480))
if torch.cuda.is_available():
    model = model.cuda()

decoder_parameters = [item for module in model.decoder for item in module.parameters()]
optimizer = optim.AdamW(decoder_parameters)  # Only training the decoder for now

criterion = DiceLoss()

# Training specific parameters
num_epochs = 10
num_up_epochs, num_down_epochs = 3, 7
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-2,
    total_steps=num_epochs * len(tr_dl),
)

We'll train for 10 epochs to fine-tune the decoder to be on par with the pre-trained weights being used in the ResNet encoder. The learning rate is controlled by the one-cycle schedule set up above, which peaks at $10^{-2}$ rather than staying at AdamW's default of $10^{-3}$. We'll print the per-pixel accuracy, as well as the Dice similarity (which is the F1 score).
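
To get a feel for what the one-cycle schedule configured above does to the learning rate, here is a rough pure-Python sketch of its shape. It assumes the default settings of `OneCycleLR` as I understand them (cosine annealing, `pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`). The function `one_cycle_lr` and the step count of 100 are hypothetical; this is an approximation for intuition, not the library's code.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor          # starting learning rate
    min_lr = initial_lr / final_div_factor    # learning rate at the very end
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct = 0) to `end` (pct = 1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    pct = (step - warmup_end) / (last_step - warmup_end)
    return cos_anneal(max_lr, min_lr, pct)

# Hypothetical run of 100 optimizer steps with the same peak as in the post
schedule = [one_cycle_lr(s, total_steps=100, max_lr=1e-2) for s in range(100)]
print(schedule[0], max(schedule), schedule[-1])
```

Under these assumptions the rate starts at `max_lr / 25`, climbs to `max_lr` over the first 30% of the steps, and then decays to a value several orders of magnitude smaller.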

Warning: The following cell runs training, and might take a while - take caution if you’re using a preemptible instance of a cloud service.

#collapse_hide
model.train()
losses = []
accuracies = []

tqdm_iterator = tqdm(range(num_epochs), position=0)
for epoch in tqdm_iterator:
    tr_loss, tr_correct_pixels, tr_total_pixels, tr_dice_similarity, total = 0., 0., 0., 0., 0.
    tqdm_epoch_iterator = tqdm(tr_dl, position=1, leave=False)
    for i, (x, y) in enumerate(tqdm_epoch_iterator):
        optimizer.zero_grad()
        y = y.squeeze(dim=1)  # (N, 1, H, W) -> (N, H, W)
        if torch.cuda.is_available():
            x, y = x.cuda(), y.cuda()
        output = model(x)
        probs = torch.softmax(output, dim=1)  # DiceLoss expects probabilities
        prediction = torch.argmax(output, dim=1)
        tr_correct_pixels += (prediction == y).sum()
        tr_total_pixels += y.numel()
        tr_dice_similarity += dice_similarity(output, y) * len(y)
        loss = criterion(probs, y)
        tr_loss += loss.detach().cpu() * len(y)
        total += len(y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if i % 1 == 0:
            curr_loss = tr_loss / total
            curr_acc = tr_correct_pixels / tr_total_pixels
            curr_dice = tr_dice_similarity / total
            tqdm_epoch_iterator.set_postfix({
                "Loss": curr_loss.item(), "Accuracy": curr_acc.item(), "Dice": curr_dice.item()
            })
    overall_loss = tr_loss.item() / total
    overall_acc = tr_correct_pixels.item() / tr_total_pixels
    losses.append(overall_loss)
    accuracies.append(overall_acc)
    tqdm_iterator.set_postfix({"Loss": overall_loss, "Accuracy": overall_acc})

Now we can take a look at the losses over time and the accuracies over time.

#collapse_hide
plt.plot(list(range(len(losses))), losses)
plt.title("Losses over time for fine-tuning decoder")
plt.show()

#collapse_hide
plt.plot(list(range(len(accuracies))), accuracies)
plt.title("Accuracies over time for fine-tuning decoder")
plt.show()

Let's save a model checkpoint so that, if necessary, we can reset back here if our next step goes badly. This is a good practice in general.

# Save the weights with torch.save, the usual way to serialize a state dict
decoder_finetuned_model_state_dict = model.state_dict()
output_fp = "/share/nikola/export/dt372/technical_blog/dynamic_unet/model_ckpts/decoder_finetuned_model_state_dict.pkl"
torch.save(decoder_finetuned_model_state_dict, output_fp)

# Reload the checkpoint (this is also how to reset to this point later)
decoder_finetuned_model_state_dict = torch.load(output_fp)
model.load_state_dict(decoder_finetuned_model_state_dict)

Now, we'll fine-tune the whole network. We'll use a smaller learning rate for the encoder, and a larger, but still small, learning rate for the decoder. Ideally, we'd use some sort of learning-rate finder (as done in FastAI), but for simplicity, we'll just pick reasonable numbers.

We'll train for 20 epochs and see how the model performs over time.

optimizer = optim.AdamW([
    {'params': model.decoder.parameters(), 'lr': 1e-4},
    {'params': model.encoder.parameters(), 'lr': 1e-7},
])
num_epochs = 20
num_up_epochs, num_down_epochs = 4, 16
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=[1e-6, 1e-8], max_lr=[1e-3, 1e-6],
    step_size_up=num_up_epochs * len(tr_dl),
    step_size_down=num_down_epochs * len(tr_dl),
    mode='exp_range',
    cycle_momentum=False  # AdamW exposes no `momentum` parameter; older PyTorch versions reject momentum cycling with it
)
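
To make the schedule above more concrete, the following pure-Python sketch reproduces the triangular shape that `CyclicLR` follows, as far as I understand its formula. It is evaluated for the decoder and encoder ranges used above. The function `cyclic_lr` is hypothetical, and steps are counted in epochs here for brevity (the real schedule counts optimizer steps, hence the multiplication by `len(tr_dl)`).

```python
import math

def cyclic_lr(step, base_lr, max_lr, step_size_up, step_size_down, gamma=1.0):
    """Approximate 'exp_range' cyclic learning rate at a given step."""
    total = step_size_up + step_size_down
    ratio = step_size_up / total
    cycle = math.floor(1 + step / total)
    x = 1.0 + step / total - cycle        # position within the current cycle, in [0, 1)
    if x <= ratio:
        scale = x / ratio                 # rising edge of the triangle
    else:
        scale = (x - 1) / (ratio - 1)     # falling edge of the triangle
    # With gamma = 1.0 the exponential factor is 1, giving a plain triangle
    return base_lr + (max_lr - base_lr) * scale * gamma ** step

# Learning rates per epoch for the two parameter groups (4 epochs up, 16 down)
decoder_lrs = [cyclic_lr(e, 1e-6, 1e-3, 4, 16) for e in range(21)]
encoder_lrs = [cyclic_lr(e, 1e-8, 1e-6, 4, 16) for e in range(21)]
print(decoder_lrs[0], decoder_lrs[4], decoder_lrs[12], decoder_lrs[20])
```

Two things stand out in this sketch. The `base_lr` and `max_lr` values passed to the scheduler determine the rates actually used, so the `lr` entries given to the optimizer act only as placeholders. Also, because `gamma` is left at 1.0, `exp_range` should behave like the plain triangular mode.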

#collapse_hide
model.train()

tqdm_iterator = tqdm(range(num_epochs), position=0)
for epoch in tqdm_iterator:
    tr_loss, tr_correct_pixels, tr_total_pixels, total = 0., 0., 0., 0.
    tqdm_epoch_iterator = tqdm(tr_dl, position=1, leave=False)
    for i, (x, y) in enumerate(tqdm_epoch_iterator):
        optimizer.zero_grad()
        y = y.squeeze(dim=1)  # (N, 1, H, W) -> (N, H, W)
        if torch.cuda.is_available():
            x, y = x.cuda(), y.cuda()
        output = model(x)
        prediction = torch.argmax(output, dim=1)
        tr_correct_pixels += (prediction == y).sum()
        tr_total_pixels += y.numel()
        loss = criterion(torch.softmax(output, dim=1), y)  # DiceLoss expects probabilities
        tr_loss += loss.data.cpu() * len(y)
        total += len(y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if i % 1 == 0:
            curr_loss = tr_loss / total
            curr_acc = tr_correct_pixels / tr_total_pixels
            tqdm_epoch_iterator.set_postfix({"Loss": curr_loss.item(), "Accuracy": curr_acc.item()})
    overall_loss = tr_loss.item() / total
    overall_acc = tr_correct_pixels.item() / tr_total_pixels
    losses.append(overall_loss)
    accuracies.append(overall_acc)
    tqdm_iterator.set_postfix({"Loss": overall_loss, "Accuracy": overall_acc})

Again, let's look at how the losses and accuracies change (after epoch 10 is where we started the above fine-tuning cycle).

#collapse_hide
plt.plot(list(range(len(losses))), losses)
plt.title("Losses over time (decoder fine-tuning, then full network)")
plt.show()

#collapse_hide
plt.plot(list(range(len(accuracies))), accuracies)
plt.title("Accuracies over time (decoder fine-tuning, then full network)")
plt.show()

Now, we can visualize some results on the validation dataset. First is the prediction, second is the ground truth.

index = 10
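x, y = val_dl.dataset[index]
x = x.unsqueeze(0)  # add a batch dimension
if torch.cuda.is_available():
    x = x.cuda()

model.eval()
with torch.no_grad():  # no gradients are needed for inference
    output = model(x)
prediction = torch.argmax(output, dim=1).cpu()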

display_segmentation(tf.to_pil_image(x[0].cpu()), tf.to_pil_image(prediction.byte()))
display_segmentation_from_file(val_dl.dataset.images[index], val_dl.dataset.labels[index])

Clearly, there's some room for improvement.

Note: this post is a draft - I’m waiting on GPUs to do training using the Dice loss, to directly optimize for the Dice similarity metric.