In the last post, we saw how to build a dynamic U-Net architecture using pre-trained ResNets as the encoder backbone. In this post, we'll show how to reuse that code base without copy-pasting it, thanks to nbdev.

First, make sure to install the package, which can be found here.

Note: When installing packages, I like to use `$(which pip) install` to make sure I'm installing into the right site-packages directory.

Then, we import it below.
#show
import dynamic_unet

We'll use some finer-grained imports to mimic exactly what we did in the previous blog post.

#show
from dynamic_unet.encoder import resnet34
from dynamic_unet.unet import DynamicUNet
from dynamic_unet.utils import CamvidDataset, display_segmentation_from_file, load_camvid_dataset
from dynamic_unet.opt import dice_similarity, DiceLoss

import torch
from tqdm.notebook import tqdm

Next, let's load the dataset.

#collapse_show
# Location of the CamVid data on this machine; adjust for your environment.
camvid_data_directory = "/home/jupyter/data/camvid"

# load_camvid_dataset returns three values: every sample, the indices reserved
# for validation, and the label mapping.
all_data, val_indices, label_mapping = load_camvid_dataset(
    camvid_data_directory
)

Split it into training and validation sets.

#collapse_show
# Build the validation-index lookup once. A set gives O(1) membership tests,
# where `i in val_indices` would rescan the container for every sample.
# int() normalises numpy/torch scalar indices so they hash like plain ints.
val_index_set = {int(j) for j in val_indices}

tr_data = [tpl for i, tpl in enumerate(all_data) if i not in val_index_set]
val_data = [tpl for i, tpl in enumerate(all_data) if i in val_index_set]

Let's visualize an example quickly.

i = 10
# The first two fields of each training tuple are passed as the two file arguments.
display_segmentation_from_file(*tr_data[i][:2])

Now, let's initialize the PyTorch dataset and dataloaders.

# Both splits are resized to 360x480; the validation set is built with is_train=False.
tr_ds = CamvidDataset(tr_data, resize_shape=(360, 480))
val_ds = CamvidDataset(val_data, resize_shape=(360, 480), is_train=False)

# Shuffled mini-batches of 4 for training; default DataLoader settings for validation.
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=4, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds)

Finally, we'll initialize our model, loss, and optimizer. For now, we'll train only the decoder parameters; see the previous post for details on further fine-tuning.

# U-Net with a ResNet-34 encoder producing 32 output channels
# (presumably one per CamVid class -- confirm against label_mapping).
model = DynamicUNet(resnet34(), num_output_channels=32, input_size=(360, 480))
if torch.cuda.is_available():
    model = model.cuda()

# Only decoder parameters are handed to the optimizer, so the encoder's weights
# are never updated by it.
decoder_parameters = [item for module in model.decoder for item in module.parameters()]
optimizer = torch.optim.AdamW(decoder_parameters)  # Only training the decoder for now

criterion = DiceLoss()

# Training specific parameters.
# One-cycle schedule: the learning rate warms up for `num_up_epochs`, then
# anneals for `num_down_epochs`; the total epoch count is derived so the two
# phases and `num_epochs` cannot disagree.
num_up_epochs, num_down_epochs = 3, 7
num_epochs = num_up_epochs + num_down_epochs
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-2,
    total_steps=num_epochs * len(tr_dl),  # scheduler.step() is called once per batch
    pct_start=num_up_epochs / num_epochs,  # fraction of steps spent increasing the LR
)

We'll copy over the training loop from the previous post to see how the model trains.

#collapse_hide
model.train()
losses = []
accuracies = []

# Same placement rule as the model setup: GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tqdm_iterator = tqdm(range(num_epochs), position=0)
for epoch in tqdm_iterator:
    tr_loss, tr_correct_pixels, tr_total_pixels, tr_dice_similarity, total = 0., 0., 0., 0., 0.
    tqdm_epoch_iterator = tqdm(tr_dl, position=1, leave=False)
    for x, y in tqdm_epoch_iterator:
        optimizer.zero_grad()
        # Squeeze the mask's dim 1 on every device, not only on CUDA. Otherwise
        # `prediction == y` below would broadcast (B, H, W) against (B, 1, H, W)
        # and miscount correct pixels.
        x, y = x.to(device), y.squeeze(dim=1).to(device)
        output = model(x)

        # Metrics are bookkeeping only. Computing them under no_grad and storing
        # Python floats keeps the running sums from holding on to autograd graphs.
        with torch.no_grad():
            prediction = torch.argmax(output, dim=1)
            tr_correct_pixels += (prediction == y).sum().item()
            tr_total_pixels += y.numel()
            tr_dice_similarity += float(dice_similarity(output, y)) * len(y)

        loss = criterion(output, y)
        tr_loss += loss.item() * len(y)
        total += len(y)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Running (sample-weighted) averages over the epoch so far.
        tqdm_epoch_iterator.set_postfix({
            "Loss": tr_loss / total,
            "Accuracy": tr_correct_pixels / tr_total_pixels,
            "Dice": tr_dice_similarity / total,
        })
    overall_loss = tr_loss / total
    overall_acc = tr_correct_pixels / tr_total_pixels
    losses.append(overall_loss)
    accuracies.append(overall_acc)
    tqdm_iterator.set_postfix({"Loss": overall_loss, "Accuracy": overall_acc})
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-9-4caa77be1f34> in <module>
     12         if torch.cuda.is_available():
     13             x, y = x.cuda(), y.squeeze(dim=1).cuda()
---> 14         output = model(x)
     15         probs = torch.softmax(output, dim=1)
     16         prediction = torch.argmax(output, dim=1)

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/dynamic_unet/unet.py in forward(self, x)
     35 
     36         try:
---> 37             self.encoder(x)
     38         finally:
     39             if self.verbose >= 1:

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/dynamic_unet/encoder.py in forward(self, x)
    146         x = self.layer2(x)
    147         x = self.layer3(x)
--> 148         x = self.layer4(x)
    149 
    150         return x

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/dynamic_unet/encoder.py in forward(self, x)
     51         out = self.relu(out)
     52 
---> 53         out = self.conv2(out)
     54         out = self.bn2(out)
     55 

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/conv.py in forward(self, input)
    343 
    344     def forward(self, input):
--> 345         return self.conv2d_forward(input, self.weight)
    346 
    347 class Conv3d(_ConvNd):

/opt/conda/envs/technical_blog/lib/python3.7/site-packages/torch/nn/modules/conv.py in conv2d_forward(self, input, weight)
    340                             _pair(0), self.dilation, self.groups)
    341         return F.conv2d(input, weight, self.bias, self.stride,
--> 342                         self.padding, self.dilation, self.groups)
    343 
    344     def forward(self, input):

KeyboardInterrupt: 

And that's it! That's all you need to use pre-trained ResNets for image segmentation of natural images.

Note: this post is a draft — I'm waiting on GPUs to train with the Dice loss, which directly optimizes the Dice similarity metric.