This function is called in `train` if `mixup` is specified as `True`.

- `x`, `y` should be `torch.Tensor`s
- `criterion` should be a `torch` loss function, e.g. `nn.CrossEntropyLoss`
- `alpha` is a float defining the distribution for sampling the mixing value (see the mixup paper for details)
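As a rough sketch of what such a mixup helper typically does (the function name and exact signature here are assumptions for illustration, not the actual implementation):

```python
import numpy as np
import torch

def mixup_data(x, y, alpha=0.4):
    # Sample the mixing value from a Beta(alpha, alpha) distribution,
    # as described in the mixup paper.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    # Mix each example with a randomly chosen partner from the same batch.
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[index]
    # Return both label sets so the caller can compute the mixed loss:
    # lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
    return mixed_x, y, y[index], lam
```

This is why the training criterion below uses `reduction='none'` when mixup is enabled: the per-sample losses need to be weighted by the mixing value before reduction.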
This function sets up logging for readable output of metrics.

- `model` should be a model constructed from the ResNet helper functions. These models are initialized with logging and output directories by default, as long as you specify existing overarching model and logging directories.
- `log_to_file` specifies whether to output metrics to a log file in the model's logging directory
- `log_to_stdout` specifies whether to output metrics to STDOUT

The function returns a closure that, when called, clears any handlers set up in the logging module for outputting to the log file or STDOUT, depending on what was specified. To avoid confusion when logging across training runs in the same notebook, it's important to call this closure so logging isn't duplicated.
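A minimal sketch of how such a setup-and-clear closure can be built with the standard `logging` module (the names here are illustrative, not the actual implementation):

```python
import logging
import sys

def setup_logging_streams_sketch(log_file=None, log_to_stdout=True):
    logger = logging.getLogger('metrics')
    logger.setLevel(logging.INFO)
    handlers = []
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file))
    if log_to_stdout:
        handlers.append(logging.StreamHandler(sys.stdout))
    for handler in handlers:
        logger.addHandler(handler)

    def clear_handlers():
        # Remove exactly the handlers added above, so repeated runs in
        # the same notebook don't produce duplicate log lines.
        for handler in handlers:
            logger.removeHandler(handler)
            handler.close()

    return clear_handlers
```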
This function performs one epoch of training.

- `model` should be a `torch.nn.Module`
- `epoch` should indicate the current epoch of training, and is only needed for logging purposes
- `dataloader` should be a `torch.utils.data.DataLoader` wrapping a `BreaKHisDataset` object
- `criterion` should be a `torch` loss function
- `optimizer` should be a `torch.optim.Optimizer`, e.g. Adam
- `scheduler` is optional, but when included, should be a `torch.optim.lr_scheduler._LRScheduler`, e.g. `CyclicLR`
- `mixup` is a boolean indicating whether to use mixup augmentation for training (default is `False`)
- `alpha` is a float determining the distribution for sampling the mixing ratio
- `logging_frequency` determines the number of iterations between metric logs
This function performs one epoch of validation.

- `model` should be a `torch.nn.Module`
- `epoch` should indicate the current epoch of training, and is only needed for logging purposes
- `dataloader` should be a `torch.utils.data.DataLoader` wrapping a `BreaKHisDataset` object
- `criterion` should be a `torch` loss function
- `tta` is a boolean indicating whether to use test-time augmentation (default is `False`)
- `tta_mixing` determines how much of the test-time augmented output to use in determining the final output (default is 0.6)
- `logging_frequency` determines the number of iterations between metric logs
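One common way a weight like `tta_mixing` is used to blend the original prediction with the augmented ones (a sketch of the idea, not necessarily how `validate` implements it):

```python
import torch

def combine_tta_outputs(logits, tta_mixing=0.6):
    # logits: (batch, n_aug + 1, num_classes), where the last slice along
    # dim 1 is assumed to be the prediction on the un-augmented image
    # (matching the stacking order used in the TTA transform below).
    augmented = logits[:, :-1].mean(dim=1)
    original = logits[:, -1]
    # Weighted blend: tta_mixing controls how much the augmented
    # predictions contribute to the final output.
    return tta_mixing * augmented + (1 - tta_mixing) * original
```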
Here are some toy examples using the functions defined above. For brevity, we use a small subset of the dataset.
#example
import torch

from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
from torch import nn
from torchvision import transforms

def get_tta_transforms(resize_shape, normalize_transform, n=5):
    tta = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop((resize_shape, resize_shape)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor()
    ])
    original_transform = transforms.Compose([
        transforms.Resize((resize_shape, resize_shape)),
        transforms.ToTensor()
    ])
    return transforms.Compose([
        transforms.Lambda(
            lambda image: torch.stack(
                [tta(image) for _ in range(n)] + [original_transform(image)]
            )
        ),
        transforms.Lambda(
            lambda images: torch.stack([
                normalize_transform(image) for image in images
            ])
        ),
    ])

def get_transforms(resize_shape, tta=False, tta_n=5):
    random_resized_crop = transforms.RandomResizedCrop((resize_shape, resize_shape))
    random_horizontal_flip = transforms.RandomHorizontalFlip()
    resize = transforms.Resize((resize_shape, resize_shape))
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transforms = transforms.Compose([
        random_resized_crop, random_horizontal_flip, transforms.ToTensor(), normalize
    ])
    val_transforms = (
        get_tta_transforms(resize_shape, normalize, n=tta_n) if tta
        else transforms.Compose([resize, transforms.ToTensor(), normalize])
    )
    return train_transforms, val_transforms

train_transform, val_transform = get_transforms(224, tta=True)
#example
ds_mapping = initialize_datasets(
'/share/nikola/export/dt372/BreaKHis_v1/',
label='tumor_class', criterion=['tumor_type', 'magnification'],
split_transforms={'train': train_transform, 'val': val_transform}
)
#example
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
#example
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)
#example
model = resnet18(pretrained=True, num_classes=2)
if torch.cuda.is_available():
    model = model.cuda()

optimizer = torch.optim.AdamW([{'params': model.out_fc.parameters(), 'lr': 1e-3}])
mixup = True
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}
The training loop might include something like the following. Note the call to `clear_logging_handlers` in the `finally` block - this should be included in your code as well to avoid logging redundancy.

#example
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=True)
try:
    tr_loss, tr_acc = train(
        model, 0, tr_dl, criterion['train'], optimizer, mixup=mixup, alpha=0.4,
        logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, 0, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
finally:
    clear_logging_handlers()
Since we were just testing here, it would be annoying to have to find the log and state files later and remove them to save disk space. We can resolve that as follows (this deletes all the contents of the log and model/system state directories):
#example
model.clear_logging_and_output_dirs()
This function is useful for setting up parameter-to-LR mappings for fine-tuning the model. Specifically:

- `model` should be a `torch.nn.Module`
- `base_lr` should be a float defining the LR for the linear head
- `finetune_body_factor` should be a list of two floats: a lower bound factor and an upper bound factor. The learning rates for the body of the model will be equally (log) spaced between `base_lr * lower_bound_factor` and `base_lr * upper_bound_factor`.
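The log spacing can be sketched as follows (a hypothetical helper for illustration, not the actual `get_param_lr_maps` implementation), assuming the body is split into `n_groups` parameter groups:

```python
import numpy as np

def log_spaced_lrs(base_lr, finetune_body_factor, n_groups):
    lower, upper = finetune_body_factor
    # Equally log-spaced learning rates between base_lr * lower and
    # base_lr * upper, one per parameter group in the body.
    return np.logspace(
        np.log10(base_lr * lower), np.log10(base_lr * upper), n_groups
    )
```

With `base_lr = 1e-3` and `finetune_body_factor = [1e-5, 1e-2]`, this yields learning rates between 1e-8 and 1e-5, matching the example below.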
In the below example, you can see how to set up the optimizer and scheduler to fine-tune using the one-cycle LR scheme. The linear head is fine-tuned with a learning rate of $10^{-3}$, and the body is fine-tuned with a learning rate spaced between $10^{-8}$ and $10^{-5}$.
#example
model = resnet18(pretrained=True, num_classes=2)
if torch.cuda.is_available():
    model = model.cuda()

mixup = True
num_epochs = 5
base_lr = 1e-3
finetune_body_factor = [1e-5, 1e-2]
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
optimizer, scheduler = setup_optimizer_and_scheduler(param_lr_maps, base_lr, num_epochs, len(tr_dl))
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}
A simple training loop would look like the following. Note that:

- The one-cycle LR scheduler is passed in, and the logic for stepping it is handled in `train`.
- Different criteria are used for training and validation. With mixup, the loss depends on the per-batch mixing factor, so reduction is handled inside the training loop (hence `reduction='none'`), whereas validation uses the standard mean reduction.
- Test-time augmentation is done in validation. This requires a special augmentation scheme, so the validation transforms need to be set appropriately; see above for an example of how to do that.
- The model state is checkpointed each epoch. After checkpointing the state of the model and system, the directory where the state was saved can be accessed by inspecting `model.save_dir`.
#example
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=False)
for epoch in range(num_epochs):
    tr_loss, tr_acc = train(
        model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
        mixup=mixup, alpha=0.4, logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, epoch + 1, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
    checkpoint_state(
        model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
    )
clear_logging_handlers()
#example
import os

os.listdir(model.save_dir)
We can just use the `validate` function with some slight alterations to get the standard training accuracy (not the mixup accuracy, which might not be as representative).
#example
_, tr_acc_no_mixup = validate(model, epoch + 1, tr_dl, criterion['val'], tta=False, logging_frequency=25)
print("Training accuracy after %d epochs is %.5f" % (epoch + 1, tr_acc_no_mixup))