utils package
Submodules
utils.estimate_loss module
- utils.estimate_loss.estimate_loss(model, ctx, dataset, config, splits=['train', 'val'])
Estimates the average loss for the provided model on the given data splits.
This function puts the model into evaluation mode, then, for each split, samples config.eval_iters batches from the dataset and computes the loss on each batch. It returns a dictionary mapping each split to its average loss.
Parameters
- model : torch.nn.Module
The model for which the loss is to be estimated.
- ctx : contextlib._GeneratorContextManager
The context manager that controls gradient computation during evaluation, typically torch.no_grad() or torch.enable_grad().
- dataset : object
The dataset object, which should have a get_batch() method for obtaining batches of data.
- config : object
The configuration object. It should have the following attribute:
eval_iters (int): the number of batches to evaluate for each split.
- splits : list, optional
The list of data splits for which the loss should be estimated. Defaults to ['train', 'val'].
Returns
- dict
A dictionary where the keys are the names of the data splits and the values are the estimated average losses for those splits.
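The evaluation procedure described above can be sketched as follows. This is a minimal, framework-agnostic sketch, not the module's actual source. It assumes that dataset.get_batch(split) returns an (inputs, targets) pair, that calling model(X, Y) returns a (logits, loss) pair, and that the model is switched back to training mode afterwards; none of these details is confirmed by the documentation. It also uses a tuple default for splits to avoid a mutable default argument.

```python
from statistics import mean


def estimate_loss(model, ctx, dataset, config, splits=("train", "val")):
    """Sketch: average the loss over config.eval_iters batches for each split."""
    model.eval()  # evaluation mode (disables dropout and similar layers)
    out = {}
    for split in splits:
        losses = []
        for _ in range(config.eval_iters):
            # Assumption: get_batch(split) returns an (inputs, targets) pair.
            X, Y = dataset.get_batch(split)
            # Assumption: ctx can be entered repeatedly, which holds for
            # torch.no_grad() and contextlib.nullcontext().
            with ctx:
                # Assumption: the model returns a (logits, loss) pair.
                _, loss = model(X, Y)
            # float() accepts plain numbers as well as one-element tensors.
            losses.append(float(loss))
        out[split] = mean(losses)
    model.train()  # assumption: restore training mode for the caller
    return out
```

Because the forward pass runs inside the with ctx: block, the same ctx object is entered once per batch, so it must be reusable across entries.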
utils.plot_losses module
- utils.plot_losses.plot_losses(losses, xlim=None, ylim=None)
Plots the training and validation losses per epoch on a black background with vibrant colors.
The function draws two lines, one for the training loss and one for the validation loss, with the epoch number on the x-axis and the loss on the y-axis. The x- and y-axis limits can optionally be set manually.
Parameters
- losses : dict
A dictionary with 'train' and 'val' keys, each mapping to a list of the losses recorded for each epoch during training and validation, respectively.
- xlim : tuple, optional
A tuple of two integers specifying the minimum and maximum x-values to be plotted on the graph. If None, the x-axis limits will be determined automatically.
- ylim : tuple, optional
A tuple of two integers specifying the minimum and maximum y-values to be plotted on the graph. If None, the y-axis limits will be determined automatically.
Returns
- None
The function doesn’t return a value. It displays a matplotlib plot.
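A plot like the one described can be produced with matplotlib's built-in "dark_background" style. This is a minimal sketch rather than the module's own code: the specific colours, line widths, figure size, and 1-based epoch numbering are assumptions standing in for the unspecified "vibrant colors" and axis conventions.

```python
import matplotlib.pyplot as plt


def plot_losses(losses, xlim=None, ylim=None):
    """Sketch: plot per-epoch training and validation losses on a black background."""
    # The built-in "dark_background" style gives a black figure; the context
    # manager restores the previous style once the plot has been shown.
    with plt.style.context("dark_background"):
        _, ax = plt.subplots(figsize=(10, 6))
        # Assumption: epochs are numbered from 1.
        train_epochs = range(1, len(losses["train"]) + 1)
        val_epochs = range(1, len(losses["val"]) + 1)
        # Hypothetical colour choices standing in for "vibrant colors".
        ax.plot(train_epochs, losses["train"], color="#00e5ff", linewidth=2, label="Training loss")
        ax.plot(val_epochs, losses["val"], color="#ff3da5", linewidth=2, label="Validation loss")
        # Only override the automatic limits when the caller supplies them.
        if xlim is not None:
            ax.set_xlim(*xlim)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        plt.show()
```

Using plt.style.context instead of plt.style.use keeps the black style from leaking into any figures the caller creates later.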
utils.training_loop module
- utils.training_loop.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
Creates a learning rate scheduler for training Transformers.
This scheduler first increases the learning rate linearly over the first num_warmup_steps steps, then decays it linearly to 0 over the remaining steps up to num_training_steps.
Parameters
- optimizer : torch.optim.Optimizer
The optimizer for which to schedule the learning rate.
- num_warmup_steps : int
The number of steps for the warmup phase.
- num_training_steps : int
The total number of training steps.
Returns
- torch.optim.lr_scheduler.LambdaLR
A learning rate scheduler that adjusts the learning rate of the optimizer according to the warm-up and decay strategy.
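The warmup-then-decay shape can be expressed as a multiplier on the optimizer's base learning rate, which is exactly what torch.optim.lr_scheduler.LambdaLR expects from its lr_lambda argument. The sketch below follows the formula used by the identically named helper in Hugging Face transformers; that this utils version matches it exactly (in particular, that warmup starts from 0) is an assumption, and the function name linear_warmup_decay_factor is hypothetical.

```python
def linear_warmup_decay_factor(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Multiplier applied to the optimizer's base learning rate at `step`."""
    if step < num_warmup_steps:
        # Warmup: ramp linearly from 0 up to 1.
        return step / max(1, num_warmup_steps)
    # Decay: fall linearly from 1 at the end of warmup to 0 at num_training_steps,
    # then stay at 0 for any steps beyond that.
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))
```

With 10 warmup steps and 100 total steps, the factor is 0.5 at step 5, peaks at 1.0 at step 10, is back to 0.5 at step 55, and reaches 0.0 at step 100. Passing functools.partial(linear_warmup_decay_factor, num_warmup_steps=..., num_training_steps=...) as lr_lambda to LambdaLR(optimizer, ...) gives a scheduler with this schedule; each scheduler.step() call advances the step counter by one.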
- utils.training_loop.training_loop(model, ctx, optimizer, scaler, dataset, config, saved_path='./out/transformer_state_dict.pth')
Performs the training loop for the given transformer model using the provided optimizer and dataset.
This function trains the model for a specified number of epochs with gradient accumulation, gradient clipping, and learning rate scheduling. It also monitors the training and validation losses and stops training early when the validation loss stops improving.
Parameters
- model : torch.nn.Module
The model to train.
- ctx : torch.cuda.amp.autocast_mode.autocast
The context manager for mixed precision training.
- optimizer : torch.optim.Optimizer
The optimizer to use for training.
- scaler : torch.cuda.amp.GradScaler
The gradient scaler for mixed precision training.
- dataset : torch.utils.data.Dataset
The dataset to use for training.
- config : object
The configuration object containing various hyperparameters and settings.
- saved_path : str, optional
The path where the model should be saved. Defaults to './out/transformer_state_dict.pth'.
Returns
- dict
A dictionary containing the training and validation losses at each evaluation step.
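The early-stopping behaviour mentioned in the description can be illustrated with a small pure-Python helper. The documentation does not say how early stopping is configured, so the EarlyStopping class and its patience and min_delta parameters are hypothetical. In this sketch, training stops once the validation loss has failed to improve on its best value for patience consecutive evaluations.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop when validation loss stops improving."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience    # evaluations without improvement before stopping
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.bad_evals = 0

    def step(self, val_loss: float) -> bool:
        """Record one evaluation; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_evals = 0
            # Assumption: a training loop would typically save a checkpoint
            # (for example to saved_path) at this point.
            return False
        self.bad_evals += 1
        return self.bad_evals >= self.patience
```

In a loop like the one described, step() would be called after each evaluation with the 'val' entry returned by estimate_loss, and the loop would break as soon as it returns True.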