This function is called in `train` when `mixup` is set to `True`.

- `x`, `y` should be `torch.Tensor`s
- `criterion` should be a `torch` loss function, e.g. `nn.CrossEntropyLoss`
- `alpha` is a float defining the distribution for sampling the mixing value (see the Mixup paper for details; a minimal sketch of the standard recipe follows)
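For reference, here is a minimal sketch of the standard mixup recipe from the paper. The names `mixup_data` and `mixup_criterion` are illustrative, not necessarily the names used in this repository, and the repository's implementation may differ in details.

import numpy as np
import torch

def mixup_data(x, y, alpha=0.4):
    # Sample the mixing value from Beta(alpha, alpha) and mix inputs within the batch.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Interpolate the loss between the two label sets using the same mixing value.
    # With reduction='none', this yields per-sample losses that can be averaged afterwards.
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)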
This function helps set up logging so that metrics are output in a readable format.

- `model` should be the model constructed from the ResNet helper functions. These models are initialized with logging and output directories by default, as long as you specify existing overarching model and logging directories.
- `log_to_file` specifies whether to output metrics to a log file in the model's logging directory
- `log_to_stdout` specifies whether to output metrics to STDOUT

The function returns a closure that, when called, clears any handlers set up in the logging module for outputting to the log file or STDOUT, depending on what was specified. To avoid confusion when logging across multiple training runs in the same notebook, call this closure between runs so that handlers are not duplicated.
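As a rough illustration of the pattern (not the repository's exact implementation), the returned closure can simply detach the handlers that were registered:

import logging
import sys

def setup_stdout_logging(logger_name='metrics'):
    # Hypothetical helper: attach a STDOUT handler and return a closure that removes it.
    logger = logging.getLogger(logger_name)
    handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(handler)

    def clear_handlers():
        logger.removeHandler(handler)
        handler.close()

    return clear_handlers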
This function performs 1 epoch of training.
- `model` should be a `torch.nn.Module`
- `epoch` should indicate the current epoch of training, and is only really necessary for logging purposes
- `dataloader` should be a `torch.utils.data.DataLoader` wrapping a `BreaKHisDataset` object
- `criterion` should be a `torch` loss function
- `optimizer` should be a `torch.optim.Optimizer`, e.g. Adam
- `scheduler` is optional, but when included, should be a `torch.optim.lr_scheduler._LRScheduler`, e.g. `CyclicLR`
- `mixup` is a boolean indicating whether to use mixup augmentation for training (default is `False`)
- `alpha` is a float determining the distribution for sampling the mixing ratio
- `logging_frequency` determines the number of iterations between logging metrics
This function performs 1 epoch of validation.
- `model` should be a `torch.nn.Module`
- `epoch` should indicate the current epoch of training, and is only really necessary for logging purposes
- `dataloader` should be a `torch.utils.data.DataLoader` wrapping a `BreaKHisDataset` object
- `criterion` should be a `torch` loss function
- `tta` is a boolean indicating whether to use test-time augmentation (default is `False`)
- `tta_mixing` determines how much weight the test-time augmented predictions get in the final output (default is 0.6; see the sketch below)
- `logging_frequency` determines the number of iterations between logging metrics
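As a rough sketch of how `tta_mixing` might be applied (an assumption about the implementation, not the repository's exact code), the prediction on the original image can be blended with the average prediction over the augmented copies:

import torch

def combine_tta_outputs(logits, tta_mixing=0.6):
    # Hypothetical helper. logits: (batch, n_aug + 1, num_classes), where the last
    # slice corresponds to the un-augmented image (matching the ordering used by
    # get_tta_transforms in the examples below).
    augmented = logits[:, :-1].mean(dim=1)
    original = logits[:, -1]
    return tta_mixing * augmented + (1 - tta_mixing) * original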
Here are some toy examples using the functions defined above. For brevity, we use a small subset of the dataset.
#example
import os

import torch
from torch import nn
from torchvision import transforms

from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
def get_tta_transforms(resize_shape, normalize_transform, n=5):
    # Random augmentations applied to each of the n TTA copies.
    tta = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop((resize_shape, resize_shape)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
    ])
    # Deterministic transform for the original (un-augmented) image.
    original_transform = transforms.Compose([
        transforms.Resize((resize_shape, resize_shape)),
        transforms.ToTensor(),
    ])
    # Stack the n augmented copies plus the original, then normalize each.
    return transforms.Compose([
        transforms.Lambda(
            lambda image: torch.stack(
                [tta(image) for _ in range(n)] + [original_transform(image)]
            )
        ),
        transforms.Lambda(
            lambda images: torch.stack([
                normalize_transform(image) for image in images
            ])
        ),
    ])

def get_transforms(resize_shape, tta=False, tta_n=5):
    random_resized_crop = transforms.RandomResizedCrop((resize_shape, resize_shape))
    random_horizontal_flip = transforms.RandomHorizontalFlip()
    resize = transforms.Resize((resize_shape, resize_shape))
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    # Standard augmentation for training.
    train_transforms = transforms.Compose([
        random_resized_crop, random_horizontal_flip, transforms.ToTensor(), normalize
    ])
    # For validation, either the TTA stack or a simple resize + normalize.
    val_transforms = (
        get_tta_transforms(resize_shape, normalize, n=tta_n) if tta
        else transforms.Compose([resize, transforms.ToTensor(), normalize])
    )
    return train_transforms, val_transforms
train_transform, val_transform = get_transforms(224, tta=True)
#example
ds_mapping = initialize_datasets(
    '/share/nikola/export/dt372/BreaKHis_v1/',
    label='tumor_class', criterion=['tumor_type', 'magnification'],
    split_transforms={'train': train_transform, 'val': val_transform}
)
#example
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
#example
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)
#example
model = resnet18(pretrained=True, num_classes=2)
if torch.cuda.is_available():
    model = model.cuda()
optimizer = torch.optim.AdamW([{'params': model.out_fc.parameters(), 'lr': 1e-3}])
mixup = True
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}
The training loop might include something like the following. Note the call to `clear_logging_handlers`; include this in your own code as well to avoid redundant logging handlers.
#example
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=True)
try:
    tr_loss, tr_acc = train(
        model, 0, tr_dl, criterion['train'], optimizer, mixup=mixup, alpha=0.4,
        logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, 0, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
finally:
    # Always tear down the logging handlers, even if training is interrupted.
    clear_logging_handlers()
Since we were just testing here, it would be annoying to have to track down the log and state files later to delete them and free up space. We can resolve that with the following call (it deletes all contents of the model's logging and model/system state directories):
#example
model.clear_logging_and_output_dirs()
This function is useful for setting up parameter-to-LR mappings for fine-tuning the model. Specifically:

- `model` should be a `torch.nn.Module`
- `base_lr` should be a float, defining the LR for the linear head
- `finetune_body_factor` should be a list of two floats: a lower bound factor and an upper bound factor. The learning rates for the body of the model will be equally (log) spaced between `base_lr * lower_bound_factor` and `base_lr * upper_bound_factor` (see the sketch below).
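A rough sketch of the idea (assumed behavior, not necessarily the repository's exact grouping): assign log-spaced learning rates to successive blocks of the body, and `base_lr` to the head. The helper below is hypothetical.

import numpy as np

def sketch_param_lr_maps(body_param_groups, head_params, base_lr, finetune_body_factor):
    # body_param_groups: list of parameter lists, ordered from earliest to latest layers.
    lower, upper = finetune_body_factor
    body_lrs = np.logspace(
        np.log10(base_lr * lower), np.log10(base_lr * upper), num=len(body_param_groups)
    )
    maps = [
        {'params': params, 'lr': float(lr)}
        for params, lr in zip(body_param_groups, body_lrs)
    ]
    maps.append({'params': head_params, 'lr': base_lr})
    return maps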
In the example below, you can see how to set up the optimizer and scheduler to fine-tune using the one-cycle LR scheme. The linear head is fine-tuned with a learning rate of $10^{-3}$, and the body is fine-tuned with learning rates log-spaced between $10^{-8}$ and $10^{-5}$.
#example
model = resnet18(pretrained=True, num_classes=2)
if torch.cuda.is_available():
    model = model.cuda()
mixup = True
num_epochs = 5
base_lr = 1e-3
finetune_body_factor = [1e-5, 1e-2]
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
optimizer, scheduler = setup_optimizer_and_scheduler(param_lr_maps, base_lr, num_epochs, len(tr_dl))
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}
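To sanity-check the mapping, you can inspect the learning rate assigned to each parameter group on the optimizer (standard `torch.optim` API); note that a one-cycle scheduler may already have rescaled these to the starting point of the cycle:

for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'])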
A simple training loop would look like the following. Note that:

- The one-cycle LR scheduler is passed in, and the logic for stepping it is handled inside `train`.
- Different criteria are used for training and validation. With mixup, the effective loss changes per batch because of the mixing factor, so the training criterion is built with `reduction='none'` and the reduction is handled inside the training loop, whereas validation uses the standard mean reduction. (A rough sketch of this pattern follows this list.)
- Test-time augmentation is done in validation. This requires a special augmentation scheme, so the validation transforms need to be set appropriately; see above for an example of how to do that.
- The model state is checkpointed each epoch. After checkpointing the state of the model and system, the directory where the state was saved can be accessed via `model.save_dir`.
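As a rough sketch of what one training iteration might look like under these conventions (an illustration of the pattern, not the repository's `train` implementation; `mixup_data` refers to the hypothetical helper sketched earlier):

def train_step(model, x, y, criterion, optimizer, scheduler=None, mixup=True, alpha=0.4):
    # Illustrative only. With mixup, the criterion is built with reduction='none'
    # so per-sample losses can be mixed with the same lambda as the inputs,
    # then averaged manually.
    model.train()
    optimizer.zero_grad()
    if mixup:
        mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha)
        logits = model(mixed_x)
        per_sample = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
        loss = per_sample.mean()
    else:
        logits = model(x)  # without mixup, the criterion uses the usual mean reduction
        loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()  # one-cycle/cyclic schedulers are stepped once per batch
    return loss.item()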
#example
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=False)
for epoch in range(num_epochs):
    tr_loss, tr_acc = train(
        model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
        mixup=mixup, alpha=0.4, logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, epoch + 1, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
    checkpoint_state(
        model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
    )
clear_logging_handlers()
#example
os.listdir(model.save_dir)
We can just use the `validate` function with slightly different arguments to get the standard training accuracy (rather than the mixup accuracy, which might not be as representative).
#example
_, tr_acc_no_mixup = validate(model, epoch + 1, tr_dl, criterion['val'], tta=False, logging_frequency=25)
print("Training accuracy after %d epochs is %.5f" % (epoch + 1, tr_acc_no_mixup))