utils package

Submodules

utils.estimate_loss module

utils.estimate_loss.estimate_loss(model, ctx, dataset, config, splits=['train', 'val'])

Estimates the average loss for the provided model on the given data splits.

This function puts the model into evaluation mode, then iteratively samples batches from the dataset and computes the loss for each split. The function returns a dictionary containing the average loss for each data split.

Parameters

model : torch.nn.Module

The model for which the loss is to be estimated.

ctx : contextlib._GeneratorContextManager

The context manager for gradient computation. This is typically the result of a torch.no_grad() or torch.enable_grad() context.

dataset : object

The dataset object, which should have a get_batch() method for obtaining batches of data.

config : object

The configuration object. It should have the following attribute:

eval_iters (int): The number of iterations to perform for each split.

splits : list, optional

The list of data splits for which the loss should be estimated. The default is ['train', 'val'].

Returns

dict

A dictionary where the keys are the names of the data splits and the values are the estimated average losses for those splits.
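The averaging described above can be sketched in plain Python. This is a minimal illustration, not the package's implementation: the name estimate_loss_sketch, the get_batch(split) call signature, and the assumption that calling the model returns a (logits, loss) pair are all hypothetical, and the toy classes stand in for a real torch model and dataset.

```python
from contextlib import nullcontext


def estimate_loss_sketch(model, ctx, dataset, config, splits=("train", "val")):
    """Average the loss over config.eval_iters batches for each split."""
    out = {}
    model.eval()  # evaluation mode, as the description above states
    for split in splits:
        total = 0.0
        for _ in range(config.eval_iters):
            x, y = dataset.get_batch(split)  # assumed signature
            with ctx:
                _, loss = model(x, y)  # assumed: model returns (logits, loss)
            total += float(loss)
        out[split] = total / config.eval_iters
    return out


# Hypothetical stand-ins so the sketch runs without torch.
class ToyModel:
    def eval(self):
        return self

    def __call__(self, x, y):
        return None, (x - y) ** 2  # squared error as a dummy loss


class ToyDataset:
    def get_batch(self, split):
        return (3.0, 1.0) if split == "train" else (2.0, 1.0)


class ToyConfig:
    eval_iters = 4


losses = estimate_loss_sketch(ToyModel(), nullcontext(), ToyDataset(), ToyConfig())
print(losses)  # {'train': 4.0, 'val': 1.0}
```

With a real model, ctx would typically be something like torch.no_grad() in place of nullcontext().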

utils.plot_losses module

utils.plot_losses.plot_losses(losses, xlim=None, ylim=None)

Plots the training and validation losses per epoch on a black background with vibrant colors.

The function creates a line plot with two lines: one for the training loss and one for the validation loss. The x-axis represents the epoch number and the y-axis represents the loss. The x- and y-axis limits can be set manually.

Parameters

losses : dict

A dictionary containing 'train' and 'val' lists. These lists should contain the recorded losses for each epoch during training and validation, respectively.

xlim : tuple, optional

A tuple of two integers specifying the minimum and maximum x-values to be plotted on the graph. If None, the x-axis limits will be determined automatically.

ylim : tuple, optional

A tuple of two integers specifying the minimum and maximum y-values to be plotted on the graph. If None, the y-axis limits will be determined automatically.

Returns

None

The function doesn’t return a value. It displays a matplotlib plot.
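A plot of this kind might be produced roughly as follows. This is a sketch under assumptions: the name plot_losses_sketch, the specific colors, and the axis labels are illustrative choices, and unlike the documented function it returns the Axes object instead of displaying the figure, so the result can be inspected.

```python
import matplotlib.pyplot as plt


def plot_losses_sketch(losses, xlim=None, ylim=None):
    """Draw train/val loss curves on a black background and return the Axes."""
    fig, ax = plt.subplots()
    fig.patch.set_facecolor("black")
    ax.set_facecolor("black")

    # Epochs are numbered from 1; each list holds one loss value per epoch.
    for split, color in (("train", "cyan"), ("val", "magenta")):
        values = losses[split]
        ax.plot(range(1, len(values) + 1), values, color=color, label=split)

    ax.set_xlabel("Epoch", color="white")
    ax.set_ylabel("Loss", color="white")
    ax.tick_params(colors="white")

    # Leave the limits to matplotlib unless the caller supplies them.
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    ax.legend()
    return ax


ax = plot_losses_sketch(
    {"train": [2.5, 1.8, 1.2], "val": [2.7, 2.0, 1.6]},
    xlim=(1, 3),
    ylim=(0, 5),
)
```

Calling plt.show() afterwards would display the figure in an interactive session.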

utils.training_loop module

utils.training_loop.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

Creates a learning rate scheduler for training Transformers.

This scheduler first warms up the learning rate linearly for a given number of steps, and then decays the learning rate linearly to 0 for the rest of the training steps.

Parameters

optimizer : torch.optim.Optimizer

The optimizer for which to schedule the learning rate.

num_warmup_steps : int

The number of steps for the warmup phase.

num_training_steps : int

The total number of training steps.

Returns

torch.optim.lr_scheduler.LambdaLR

A learning rate scheduler that adjusts the learning rate of the optimizer according to the warm-up and decay strategy.
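The warm-up and decay shape can be illustrated with the multiplicative factor that a LambdaLR scheduler applies to the base learning rate at each step. The function below is a sketch of one common formulation, assumed here and not confirmed to match this module's code; the name linear_warmup_decay_factor is hypothetical.

```python
def linear_warmup_decay_factor(step, num_warmup_steps, num_training_steps):
    """Return the learning rate multiplier for a given step."""
    if step < num_warmup_steps:
        # Warm-up: rise linearly from 0 to 1.
        return step / max(1, num_warmup_steps)
    # Decay: fall linearly from 1 to 0 over the remaining steps.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# 2 warm-up steps out of 6 total training steps.
factors = [linear_warmup_decay_factor(s, 2, 6) for s in range(7)]
print(factors)  # [0.0, 0.5, 1.0, 0.75, 0.5, 0.25, 0.0]
```

A function of this shape, with the two step counts bound in advance, is what would be passed as the lr_lambda argument of torch.optim.lr_scheduler.LambdaLR.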

utils.training_loop.training_loop(model, ctx, optimizer, scaler, dataset, config, saved_path='./out/transformer_state_dict.pth')

Performs the training loop for the given transformer model using the provided optimizer and dataset.

This function trains the model for a specified number of epochs, and implements gradient accumulation, gradient clipping, and learning rate scheduling. It also monitors the training and validation losses and implements early stopping when validation loss stops improving.

Parameters

model : torch.nn.Module

The model to train.

ctx : torch.cuda.amp.autocast_mode.autocast

The context manager for mixed precision training.

optimizer : torch.optim.Optimizer

The optimizer to use for training.

scaler : torch.cuda.amp.GradScaler

The gradient scaler for mixed precision training.

dataset : torch.utils.data.Dataset

The dataset to use for training.

config : object

The configuration object containing various hyperparameters and settings.

saved_path : str, optional

The path where the model should be saved. Defaults to "./out/transformer_state_dict.pth".

Returns

dict

A dictionary containing the training and validation losses at each evaluation step.
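The early stopping behaviour mentioned in the description could work along the following lines. This is a generic patience-based sketch, not the module's actual logic: the EarlyStopper class and its patience and min_delta parameters are hypothetical, and the real stopping criterion may differ.

```python
class EarlyStopper:
    """Signal a stop after `patience` evaluations without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # evaluations to tolerate without improvement
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_evals = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # New best validation loss: reset the counter.
            self.best = val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


stopper = EarlyStopper(patience=2)
val_history = [1.0, 0.8, 0.9, 0.85, 0.7]

stopped_at = None
for i, val_loss in enumerate(val_history):
    if stopper.should_stop(val_loss):
        stopped_at = i
        break

print(stopped_at, stopper.best)  # 3 0.8
```

In this run the loop stops at index 3, after two consecutive evaluations fail to beat the best value of 0.8, so the later value of 0.7 is never seen. That is the usual trade-off when choosing a small patience.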
