Welcome to PyTorch Warmup’s documentation!

This library contains PyTorch implementations of the warmup schedules described in "On the adequacy of untuned warmup for adaptive optimization".

Installation

Make sure you have Python 3.9+ and PyTorch 1.9+ or 2.x. Then, install the latest version from the Python Package Index:

pip install -U pytorch_warmup

Examples

  • CIFAR10 - A sample script to train a ResNet model on the CIFAR10 dataset using an optimization algorithm with a warmup schedule. Its README presents ResNet20 results obtained with each of AdamW, NAdamW, AMSGradW, and AdaMax combined with various warmup schedules, as well as a ResNet performance comparison (up to ResNet110) obtained with the SGD algorithm and a linear warmup schedule.

  • EMNIST - A sample script to train a CNN model on the EMNIST dataset using the AdamW algorithm with a warmup schedule. Its README presents results obtained with AdamW under each of the untuned linear warmup, the untuned exponential warmup, and the RAdam warmup.

  • Plots - A script to plot effective warmup periods as a function of \(\beta_{2}\), and warmup schedules over time (a rough sketch of these schedules appears below).
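
To give a feel for what the plot script visualizes, the following sketch draws the warmup factor over time for the two untuned schedules using only NumPy and Matplotlib. The closed-form factors \(\min(1, (t+1)/\tau)\) and \(1 - \exp(-(t+1)/\tau)\), and the warmup periods \(\tau = 2/(1-\beta_{2})\) (linear) and \(\tau = 1/(1-\beta_{2})\) (exponential), are assumptions used for illustration rather than values read from the package.

import numpy as np
import matplotlib.pyplot as plt

beta2 = 0.999
t = np.arange(4000)

# Assumed warmup periods for the untuned schedules (illustrative, not read from the package).
tau_linear = 2.0 / (1.0 - beta2)
tau_expo = 1.0 / (1.0 - beta2)

# Assumed closed-form warmup factors.
linear_factor = np.minimum(1.0, (t + 1) / tau_linear)
expo_factor = 1.0 - np.exp(-(t + 1) / tau_expo)

plt.plot(t, linear_factor, label='untuned linear warmup')
plt.plot(t, expo_factor, label='untuned exponential warmup')
plt.xlabel('iteration')
plt.ylabel('warmup factor')
plt.legend()
plt.show()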

Usage

When the learning rate schedule uses the global iteration number, the untuned linear warmup can be used together with Adam or one of its variants (AdamW, NAdam, etc.) as follows:

import torch
import pytorch_warmup as warmup

optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
   # This sample code uses the AdamW optimizer.
num_steps = len(dataloader) * num_epochs
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
   # The LR schedule initialization resets the initial LR of the optimizer.
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
   # The warmup schedule initialization dampens the initial LR of the optimizer.
for epoch in range(1, num_epochs + 1):
   for batch in dataloader:
      optimizer.zero_grad()
      loss = ...
      loss.backward()
      optimizer.step()
      with warmup_scheduler.dampening():
         lr_scheduler.step()

Warning

The warmup schedule must not be initialized before the learning rate schedule.

Other approaches can be found in the README.
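
For instance, one approach described there replaces the untuned schedule with a manually chosen warmup period. The sketch below assumes the same setup as the Usage section (params, dataloader, and num_epochs already defined) and an illustrative warmup period of 2000 iterations; consult the README for the approaches it actually documents, including the RAdam warmup.

import torch
import pytorch_warmup as warmup

optimizer = torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
num_steps = len(dataloader) * num_epochs
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=2000)
   # The warmup period of 2000 iterations is illustrative; tune it for your task.
   # The training loop is the same as above: call optimizer.step(), then
   # lr_scheduler.step() inside the warmup_scheduler.dampening() context.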

Indices and tables