Using NBDev to Package Dynamic U-Net Code [Draft]
A brief follow-up to the previous post, covering how to install the dynamic U-Net code for your own use.
First, make sure to install the package, which can be found here. I run pip as
$(which pip) install
to ensure I'm installing into the right site-packages directory.
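Since the point of the command above is to land the package in the site-packages directory of the interpreter you actually run, one way to double-check is to ask Python where a module resolves from. This is a minimal sketch using only the standard library; the helper name `locate_module` is hypothetical and not part of the package.

```python
import importlib.util
import sys


def locate_module(name):
    """Return the file a top-level module would be imported from, or None if it is not installed."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    return spec.origin


# The interpreter running this code, and where a known stdlib module lives.
print("interpreter:", sys.executable)
print("json resolves to:", locate_module("json"))
```

After installing, calling `locate_module("dynamic_unet")` from the notebook should return a path inside the environment you expect. Running `python -m pip install` with the intended interpreter is another common way to get the same guarantee.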
#show
import dynamic_unet
We'll use some finer-grained imports to mimic exactly what we did in the previous blog post.
#show
from dynamic_unet.encoder import resnet34
from dynamic_unet.unet import DynamicUNet
from dynamic_unet.utils import CamvidDataset, display_segmentation_from_file, load_camvid_dataset
from dynamic_unet.opt import dice_similarity, DiceLoss
import torch
from tqdm.notebook import tqdm
Next, let's load the dataset.
#collapse_show
camvid_data_directory = "/home/jupyter/data/camvid"
all_data, val_indices, label_mapping = load_camvid_dataset(camvid_data_directory);
Then split it into training and validation sets.
#collapse_show
tr_data = [tpl for i, tpl in enumerate(all_data) if i not in val_indices]
val_data = [tpl for i, tpl in enumerate(all_data) if i in val_indices]
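The index-based split above can also be written as a small helper. The sketch below is a hypothetical variant (the name `split_by_indices` is mine, not the package's) that converts the validation indices to a set first, so each membership check is constant time, and walks the data only once. The file names are toy placeholders.

```python
def split_by_indices(items, val_indices):
    """Split items into (train, val) lists according to a collection of validation indices."""
    val_index_set = set(val_indices)  # O(1) membership checks
    train, val = [], []
    for i, item in enumerate(items):
        (val if i in val_index_set else train).append(item)
    return train, val


# Toy stand-in for the (image_path, mask_path) tuples used in this post.
pairs = [(f"img_{i}.png", f"mask_{i}.png") for i in range(6)]
train, val = split_by_indices(pairs, [1, 4])
print(len(train), len(val))
```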
Let's visualize an example quickly.
i = 10
display_segmentation_from_file(tr_data[i][0], tr_data[i][1])
Now, let's initialize the PyTorch dataset and dataloaders.
tr_ds = CamvidDataset(tr_data, resize_shape=(360, 480))
val_ds = CamvidDataset(val_data, resize_shape=(360, 480), is_train=False)
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=4, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds)
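One detail that matters later: `len(tr_dl)` counts batches, not images, and the scheduler setup further down uses it to compute its total number of steps. For a map-style dataset, the relationship can be sketched in plain Python as below. The sample counts are made up for illustration and are not the actual CamVid split sizes.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per epoch for a dataset of num_samples items."""
    if drop_last:
        return num_samples // batch_size  # the final partial batch is discarded
    return math.ceil(num_samples / batch_size)  # the final partial batch is kept


# Hypothetical dataset sizes, with batch_size=4 as in the training loader above.
print(steps_per_epoch(600, 4))
print(steps_per_epoch(601, 4))
```

Note that the validation loader above uses the default batch size of 1, so it yields one batch per image.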
Finally, we'll initialize our model, loss, and optimizer. For now, we'll train only the decoder parameters; the previous post covers further fine-tuning in more detail.
model = DynamicUNet(resnet34(), num_output_channels=32, input_size=(360, 480))

if torch.cuda.is_available():
    model = model.cuda()

decoder_parameters = [item for module in model.decoder for item in module.parameters()]
optimizer = torch.optim.AdamW(decoder_parameters)  # Only training the decoder for now

criterion = DiceLoss()

# Training specific parameters
num_epochs = 10
num_up_epochs, num_down_epochs = 3, 7
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-2,
    total_steps=num_epochs * len(tr_dl),
)
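To get a feel for what the one-cycle schedule does to the learning rate, here is a plain-Python sketch of its shape. It follows my understanding of PyTorch's defaults (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, cosine annealing) and is meant as an illustration rather than the library's code. The function name `one_cycle_lr` and the 100-step total are assumptions. The default `pct_start=0.3` lines up with the 3 warm-up and 7 cool-down epochs declared above.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Learning rate at a given step of a two-phase cosine one-cycle schedule (illustrative sketch)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Moves smoothly from start (pct=0) to end (pct=1) along a half cosine.
        return end + (start - end) / 2.0 * (1.0 + math.cos(math.pi * pct))

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


# Hypothetical run of 100 optimizer steps with the max_lr used in this post.
for step in (0, 15, 29, 60, 99):
    print(step, one_cycle_lr(step, total_steps=100, max_lr=1e-2))
```

The rate starts at `max_lr / 25`, peaks at `max_lr` when warm-up ends, and decays to a very small value by the final step.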
We'll copy over the training loop from the previous post to see how the model trains.
#collapse_hide
model.train()

losses = []
accuracies = []

tqdm_iterator = tqdm(range(num_epochs), position=0)
for epoch in tqdm_iterator:
    tr_loss, tr_correct_pixels, tr_total_pixels, tr_dice_similarity, total = 0., 0., 0., 0., 0.
    tqdm_epoch_iterator = tqdm(tr_dl, position=1, leave=False)
    for i, (x, y) in enumerate(tqdm_epoch_iterator):
        optimizer.zero_grad()

        # Drop the target's channel dimension once, on CPU as well as GPU
        y = y.squeeze(dim=1)
        if torch.cuda.is_available():
            x, y = x.cuda(), y.cuda()

        output = model(x)
        # argmax over the logits gives the same classes as argmax over softmax probabilities
        prediction = torch.argmax(output, dim=1)
        tr_correct_pixels += ((prediction == y).sum())
        tr_total_pixels += y.numel()
        tr_dice_similarity += dice_similarity(output, y) * len(y)

        loss = criterion(output, y)
        # Weight each batch's loss by its size so a smaller final batch is not over-counted
        tr_loss += loss.detach().cpu() * len(y)
        total += len(y)

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # OneCycleLR is stepped once per batch

        if i % 1 == 0:
            curr_loss = tr_loss / total
            curr_acc = tr_correct_pixels / tr_total_pixels
            curr_dice = tr_dice_similarity / total
            tqdm_epoch_iterator.set_postfix({
                "Loss": curr_loss.item(), "Accuracy": curr_acc.item(), "Dice": curr_dice.item()
            })

    overall_loss = tr_loss.item() / total
    overall_acc = tr_correct_pixels.item() / tr_total_pixels
    losses.append(overall_loss)
    accuracies.append(overall_acc)
    tqdm_iterator.set_postfix({"Loss": overall_loss, "Accuracy": overall_acc})
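The loop above reports running averages in which each batch's value is multiplied by the batch size and divided by the number of samples seen so far. The sketch below shows that bookkeeping in plain Python, together with the textbook Dice coefficient for flat binary masks. Both `dice_coefficient` and `WeightedMean` are hypothetical names for illustration. The package's `dice_similarity` and `DiceLoss` work on multi-class model outputs and may differ in details such as smoothing.

```python
def dice_coefficient(pred, target):
    """Dice = 2 * overlap / (size of pred + size of target), for flat 0/1 masks."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    return 1.0 if denom == 0 else 2.0 * intersection / denom


class WeightedMean:
    """Running mean in which each update is weighted by a sample count."""

    def __init__(self):
        self.weighted_sum = 0.0
        self.total = 0

    def update(self, value, count):
        self.weighted_sum += value * count
        self.total += count

    @property
    def value(self):
        return self.weighted_sum / self.total


# Toy data: (predicted mask, target mask, number of samples the score stands for).
dice_meter = WeightedMean()
toy_batches = [
    ([1, 1, 0, 0], [1, 0, 0, 0], 4),
    ([0, 1, 1, 0], [0, 1, 1, 0], 2),
]
for pred, target, count in toy_batches:
    dice_meter.update(dice_coefficient(pred, target), count)
print(dice_meter.value)
```

Weighting by count means a short final batch contributes in proportion to its size instead of counting as a full batch.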
And that's it! That's all you need to use pre-trained ResNets for image segmentation of natural images.