The utility functions here can be used for training and evaluation of the model.

mixup_data[source]

mixup_data(x, y, criterion, alpha=1.0)

Compute the mixup data for batch `x, y`. Return mixed inputs, pairs of targets, and lambda.

This function is called in `train` when `mixup=True`. A sketch of the underlying mixup computation follows the parameter list below.

  • x, y should be torch.Tensor
  • criterion should be a torch loss function, e.g. nn.CrossEntropyLoss
  • alpha is a float defining the distribution for sampling the mixing value (see the Mixup paper for details)
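
Mixup (Zhang et al., 2018) draws a mixing value from a Beta(alpha, alpha) distribution and forms convex combinations of pairs of inputs and targets. The following is a minimal sketch of that idea, not necessarily the exact implementation of mixup_data here (in particular, the real function also takes criterion):

import numpy as np
import torch

def mixup_data_sketch(x, y, alpha=1.0):
    # Sample the mixing value lambda from Beta(alpha, alpha); alpha=1.0 is uniform on [0, 1].
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    # Shuffle the batch to choose a mixing partner for every example.
    index = torch.randperm(x.size(0), device=x.device)
    # Convex combination of each example with its partner; targets come back as a pair.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

The loss for a mixed batch is then combined with the same value, e.g. lam * criterion(output, y_a) + (1 - lam) * criterion(output, y_b).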

setup_logging_streams[source]

setup_logging_streams(model, log_to_file=True, log_to_stdout=False)

Utility function for setting up logging handlers for `model`.

This function sets up logging handlers so that training and validation metrics are output in a readable format.

  • model should be the model constructed from the ResNet helper functions. These models are initialized with logging and output directories by default, as long as you specify existing top-level model and logging directories.
  • log_to_file specifies whether to output to a metrics log file in the model's logging directory
  • log_to_stdout specifies whether to output metrics to STDOUT

The function returns a closure that, when called, clears any handlers set up in the logging module for output to the log file or STDOUT, depending on what was specified. To avoid redundant logging between training runs in the same notebook, make sure to call this closure once a run finishes. A sketch of this pattern is shown below.
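
The teardown closure can be thought of as capturing the handlers that were added and removing them from the logger. The sketch below illustrates the pattern only; the handler setup and logger name in the actual implementation may differ:

import logging

def setup_logging_streams_sketch(logger_name, log_file=None, log_to_stdout=False):
    logger = logging.getLogger(logger_name)
    handlers = []
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file))
    if log_to_stdout:
        handlers.append(logging.StreamHandler())
    for handler in handlers:
        logger.addHandler(handler)

    def clear_logging_handlers():
        # Remove exactly the handlers added above so later runs start clean.
        for handler in handlers:
            logger.removeHandler(handler)

    return clear_logging_handlers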

train[source]

train(model, epoch, dataloader, criterion, optimizer, scheduler=None, mixup=False, alpha=0.4, logging_frequency=50)

Trains `model` on data in `dataloader` with loss `criterion` and optimization scheme
defined by `optimizer`, with optional learning schedule defined by `scheduler`.

This function performs 1 epoch of training.

  • model should be a torch.nn.Module
  • epoch should indicate the current epoch of training, and is only really necessary for logging purposes.
  • dataloader should be a torch.utils.data.DataLoader wrapping a BreaKHisDataset object
  • criterion should be a torch loss function
  • optimizer should be a torch.optim.Optimizer, e.g. Adam
  • scheduler is optional, but when included, should be a torch.optim._LRScheduler, e.g. CyclicLR
  • mixup is a boolean indicating whether to use mixup augmentation for training (default is False)
  • alpha is a float determining the distribution for sampling the mixing ratio
  • logging_frequency determines the cycle of iterations before logging metrics
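
For orientation, one epoch of training with optional mixup roughly follows the pattern below. This is a simplified sketch, not the library's exact implementation (device transfer, accuracy tracking and logging are omitted, and mixup_data_sketch is the illustrative helper from above), but it shows why the training criterion is constructed with reduction='none' when mixup is enabled in the examples further down:

def train_epoch_sketch(model, dataloader, criterion, optimizer, scheduler=None,
                       mixup=False, alpha=0.4):
    model.train()
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        if mixup:
            inputs, y_a, y_b, lam = mixup_data_sketch(inputs, targets, alpha)
            outputs = model(inputs)
            # Per-sample losses (reduction='none') are mixed, then averaged by hand.
            loss = (lam * criterion(outputs, y_a)
                    + (1 - lam) * criterion(outputs, y_b)).mean()
        else:
            # Without mixup, the criterion is expected to use the standard mean reduction.
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # one-cycle/cyclic schedules step once per iteration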

validate[source]

validate(model, epoch, dataloader, criterion, tta=False, tta_mixing=0.6, logging_frequency=50)

Validates `model` on data in `dataloader` for epoch `epoch` using objective `criterion`.

This function performs 1 epoch of validation.

  • model should be a torch.nn.Module
  • epoch should indicate the current epoch of training, and is only really necessary for logging purposes.
  • dataloader should be a torch.utils.data.DataLoader wrapping a BreaKHisDataset object
  • criterion should be a torch loss function
  • tta is a boolean indicating whether to use test-time augmentation (default is False)
  • tta_mixing determines how much weight the predictions on test-time augmented data receive in the final output (default is 0.6); see the sketch after this list
  • logging_frequency determines the cycle of iterations before logging metrics
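
A reasonable mental model for tta_mixing (the exact blending in validate may differ) is a weighted average between the mean prediction over the augmented views and the prediction on the original image. With the TTA transform defined in the example below, each batch element arrives as a stack of n augmented views plus the original view last:

import torch

def tta_blend_sketch(model, images, tta_mixing=0.6):
    # images: (batch, n_views, C, H, W), where the last view is the un-augmented image.
    batch, n_views = images.shape[:2]
    outputs = model(images.view(-1, *images.shape[2:]))
    outputs = outputs.view(batch, n_views, -1)
    augmented = outputs[:, :-1].mean(dim=1)   # average over augmented views
    original = outputs[:, -1]                 # prediction on the original image
    # Weighted blend controlled by tta_mixing.
    return tta_mixing * augmented + (1 - tta_mixing) * original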

Here are some toy examples using the functions defined above. For brevity, we use a small subset of the dataset.

#example
# os and torch are used directly in these examples (torch.stack, DataLoader, os.listdir)
import os
import torch
from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
from torch import nn
from torchvision import transforms

def get_tta_transforms(resize_shape, normalize_transform, n=5):
    tta = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop((resize_shape, resize_shape)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor()
    ])
    original_transform = transforms.Compose([
        transforms.Resize((resize_shape, resize_shape)),
        transforms.ToTensor()
    ])
    return transforms.Compose([
        transforms.Lambda(
            lambda image: torch.stack(
                [tta(image) for _ in range(n)] + [original_transform(image)]
            )
        ),
        transforms.Lambda(
            lambda images: torch.stack([
                normalize_transform(image) for image in images
            ])
        ),
    ])

def get_transforms(resize_shape, tta=False, tta_n=5):
    random_resized_crop = transforms.RandomResizedCrop((resize_shape, resize_shape))
    random_horizontal_flip = transforms.RandomHorizontalFlip()
    resize = transforms.Resize((resize_shape, resize_shape))
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transforms = transforms.Compose([
        random_resized_crop, random_horizontal_flip, transforms.ToTensor(), normalize
    ])
    val_transforms = (
        get_tta_transforms(resize_shape, normalize, n=tta_n) if tta
        else transforms.Compose([resize, transforms.ToTensor(), normalize])
    )
    return train_transforms, val_transforms
    
train_transform, val_transform = get_transforms(224, tta=True)
#example
ds_mapping = initialize_datasets(
    '/share/nikola/export/dt372/BreaKHis_v1/',
    label='tumor_class', criterion=['tumor_type', 'magnification'],
    split_transforms={'train': train_transform, 'val': val_transform}
)
#example
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
#example
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)
#example
model = resnet18(pretrained=True, num_classes=2)
if torch.cuda.is_available():
    model = model.cuda()
optimizer = torch.optim.AdamW([{'params': model.out_fc.parameters(), 'lr': 1e-3}])
mixup = True
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}

The training loop might include something like the following. Note the call to clear_logging_handlers in the finally block; include it in your own code as well to avoid redundant logging.

#example
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=True)
try:
    tr_loss, tr_acc = train(
        model, 0, tr_dl, criterion['train'], optimizer, mixup=mixup, alpha=0.4,
        logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, 0, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
finally:
    clear_logging_handlers()
Logging to /share/nikola/export/dt372/breakhis_gradcam/logs/2020-02-16-23-18-10/metrics.log
Logging to STDOUT
[Metrics][02:16:2020:06:18:29][DEBUG]: [Epoch 0, Iteration 25 / 198] Training Loss: 0.63333, Training Accuracy: 1.94728 [Projected Accuracy: 15.42243]
[Metrics][02:16:2020:06:18:41][DEBUG]: [Epoch 0, Iteration 50 / 198] Training Loss: 0.58616, Training Accuracy: 3.94854 [Projected Accuracy: 15.63623]
[Metrics][02:16:2020:06:18:54][DEBUG]: [Epoch 0, Iteration 75 / 198] Training Loss: 0.56723, Training Accuracy: 5.90912 [Projected Accuracy: 15.60008]
[Metrics][02:16:2020:06:19:07][DEBUG]: [Epoch 0, Iteration 100 / 198] Training Loss: 0.55472, Training Accuracy: 7.89645 [Projected Accuracy: 15.63498]
[Metrics][02:16:2020:06:19:21][DEBUG]: [Epoch 0, Iteration 125 / 198] Training Loss: 0.54388, Training Accuracy: 9.99066 [Projected Accuracy: 15.82520]
[Metrics][02:16:2020:06:19:34][DEBUG]: [Epoch 0, Iteration 150 / 198] Training Loss: 0.53556, Training Accuracy: 12.05890 [Projected Accuracy: 15.91775]
[Metrics][02:16:2020:06:19:47][DEBUG]: [Epoch 0, Iteration 175 / 198] Training Loss: 0.52685, Training Accuracy: 14.26282 [Projected Accuracy: 16.13737]
[Metrics][02:16:2020:06:19:58][INFO]: Reporting 0.52171 training loss, 15.97910 training accuracy for epoch 0.
[Metrics][02:16:2020:06:20:50][DEBUG]: [Epoch 0, Iteration 25 / 50] Validation Loss: 0.41677, Validation Accuracy: 0.46014 [Validation Accuracy: 0.92028]
[Metrics][02:16:2020:06:21:36][INFO]: Reporting 0.41893 validation loss, 0.87947 validation accuracy for epoch 0.
Cleared all logging handlers

Since we were just testing here, it would be annoying to have to track down the log and state files later and delete them to free up disk space. Instead, we can do the following, which deletes all contents of the logging and model/system state directories:

#example
model.clear_logging_and_output_dirs()
Removing directory /share/nikola/export/dt372/breakhis_gradcam/logs/2020-02-16-23-18-10 and all contents.
Removing directory /share/nikola/export/dt372/breakhis_gradcam/models/2020-02-16-23-18-10 and all contents.
Resetting /share/nikola/export/dt372/breakhis_gradcam/logs/2020-02-16-23-18-10 and /share/nikola/export/dt372/breakhis_gradcam/models/2020-02-16-23-18-10.

get_param_lr_maps[source]

get_param_lr_maps(model, base_lr, finetune_body_factor)

Output parameter LR mappings for setting up an optimizer for `model`.

This function is useful for setting up parameter to LR mappings for fine-tuning the model. Specifically:

  • model should be a torch.nn.Module
  • base_lr should be a float, defining the LR for the linear head
  • finetune_body_factor should be a list of two floats, a lower-bound factor and an upper-bound factor. The learning rates for the body of the model will be evenly log-spaced between base_lr * lower_bound_factor and base_lr * upper_bound_factor, as sketched below
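
The sketch below illustrates the idea; the layer grouping is an assumption based on the standard torchvision ResNet layout (layer1 through layer4 plus the out_fc head used elsewhere in these examples), not necessarily how get_param_lr_maps groups parameters:

import numpy as np

def get_param_lr_maps_sketch(model, base_lr, finetune_body_factor):
    lower, upper = finetune_body_factor
    body_layers = [model.layer1, model.layer2, model.layer3, model.layer4]
    # Evenly log-spaced learning rates from base_lr * lower to base_lr * upper.
    body_lrs = np.geomspace(base_lr * lower, base_lr * upper, num=len(body_layers))
    param_lr_maps = [
        {'params': layer.parameters(), 'lr': lr}
        for layer, lr in zip(body_layers, body_lrs)
    ]
    # The linear head keeps the base learning rate.
    param_lr_maps.append({'params': model.out_fc.parameters(), 'lr': base_lr})
    return param_lr_maps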

setup_optimizer_and_scheduler[source]

setup_optimizer_and_scheduler(param_lr_maps, base_lr, epochs, steps_per_epoch)

Create a PyTorch AdamW optimizer and OneCycleLR scheduler with `param_lr_maps` parameter mapping,
with base LR `base_lr`, for training for `epochs` epochs, with `steps_per_epoch` iterations
per epoch.
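
In spirit this pairs PyTorch's AdamW with a OneCycleLR schedule. Here is a sketch of equivalent setup code; the exact scheduler arguments used internally are not documented here, so treat them as assumptions:

import torch

def setup_optimizer_and_scheduler_sketch(param_lr_maps, base_lr, epochs, steps_per_epoch):
    optimizer = torch.optim.AdamW(param_lr_maps, lr=base_lr)
    # One-cycle schedule over the whole run; max_lr is given per parameter group.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[group['lr'] for group in param_lr_maps],
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
    )
    return optimizer, scheduler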

checkpoint_state[source]

checkpoint_state(model, epoch, optimizer, scheduler, train_loss, train_acc, val_loss, val_acc)

Checkpoint the state of the system, including `model` state, `optimizer` state, `scheduler`
state, for `epoch`, saving the metrics as well.
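
A checkpoint like this typically bundles the relevant state dicts and metrics into a single file with torch.save. The dictionary keys below are assumptions for illustration, although the epoch_N.pth naming matches the listing of model.save_dir later in this section:

import os
import torch

def checkpoint_state_sketch(model, epoch, optimizer, scheduler,
                            train_loss, train_acc, val_loss, val_acc):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'metrics': {
            'train_loss': train_loss, 'train_acc': train_acc,
            'val_loss': val_loss, 'val_acc': val_acc,
        },
    }
    torch.save(state, os.path.join(model.save_dir, 'epoch_%d.pth' % epoch))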

In the example below, you can see how to set up the optimizer and scheduler to fine-tune using the one-cycle LR scheme. The linear head is fine-tuned with a learning rate of $10^{-3}$, and the body is fine-tuned with learning rates log-spaced between $10^{-8}$ and $10^{-5}$.

#example
model = resnet18(pretrained=True, num_classes=2)
if torch.cuda.is_available():
    model = model.cuda()
mixup = True
num_epochs = 5
base_lr = 1e-3
finetune_body_factor = [1e-5, 1e-2]
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
optimizer, scheduler = setup_optimizer_and_scheduler(param_lr_maps, base_lr, num_epochs, len(tr_dl))
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}

A simple training loop would look like the following. Note that:

  • The one-cycle LR scheduler is passed in, and the logic for updating that is handled in train
  • Different criteria are used for training and validation. With mixup, the loss for each batch depends on the mixing factor, so the per-sample losses (reduction='none') are combined inside the training loop, whereas validation uses the standard mean reduction.
  • Test-time augmentation is done in validation. Note that this will require having a special augmentation scheme, so validation transforms will need to be set appropriately. You can see above for an example of how to do that.
  • The model state is checkpointed each epoch. After checkpointing the state of the model and system, the directory where the state was saved can be accessed by inspecting model.save_dir.
#example
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=False)
for epoch in range(num_epochs):
    tr_loss, tr_acc = train(
        model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
        mixup=mixup, alpha=0.4, logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, epoch + 1, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
    checkpoint_state(
        model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
    )
clear_logging_handlers()
Logging to /share/nikola/export/dt372/breakhis_gradcam/logs/2020-02-16-23-22-02/metrics.log
Cleared all logging handlers
#example
os.listdir(model.save_dir)
['epoch_1.pth', 'epoch_2.pth', 'epoch_3.pth', 'epoch_4.pth', 'epoch_5.pth']

We can use the validate function with some slight alterations to get the standard training accuracy (not the mixup accuracy, which might not be as representative).

#example
_, tr_acc_no_mixup = validate(model, epoch + 1, tr_dl, criterion['val'], tta=False, logging_frequency=25)
print("Training accuracy after %d epochs is %.5f" % (epoch + 1, tr_acc_no_mixup))
Training accuracy after 5 epochs is 0.96358