Source code for ablator.modules.scheduler

import typing as ty
from abc import abstractmethod

from torch import nn
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau, StepLR, _LRScheduler
from torch.optim import Optimizer

from ablator.config.main import ConfigBase, Derived, configclass


Scheduler = ty.Union[_LRScheduler, ReduceLROnPlateau, ty.Any]

StepType = ty.Literal["train", "val", "epoch"]


[docs]@configclass class SchedulerArgs(ConfigBase): """ Abstract base class for defining arguments to initialize a learning rate scheduler. Attributes ---------- step_when : StepType The step type at which the scheduler.step() should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ # step every train step or every validation step step_when: StepType
[docs] @abstractmethod def init_scheduler(self, model, optimizer): """ Abstract method to be implemented by derived classes, which creates and returns a scheduler object. """ raise NotImplementedError("init_optimizer method not implemented.")
[docs]@configclass class SchedulerConfig(ConfigBase): """ Class that defines a configuration for a learning rate scheduler. Attributes ---------- name : str The name of the scheduler. arguments : SchedulerArgs The arguments needed to initialize the scheduler. """ name: str arguments: SchedulerArgs
[docs] def __init__(self, name, arguments: dict[str, ty.Any]): """ Initializes the scheduler configuration. Parameters ---------- name : str The name of the scheduler, this can be any in ``['None', 'step', 'cycle', 'plateau']``. arguments : dict[str, ty.Any] The arguments for the scheduler, specific to a certain type of scheduler. Examples -------- In the following example, ``scheduler_config`` will initialize property ``arguments`` of type ``StepLRConfig``, setting ``step_size=1``, ``gamma=0.99`` as its properties. We also have access to ``init_scheduler()`` method of the property, which initalizes an StepLR scheduler. This method is actually called in ``make_scheduler()`` >>> scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99}) """ _arguments: None | StepLRConfig | OneCycleConfig | PlateuaConfig if (argument_cls := SCHEDULER_CONFIG_MAP[name]) is None: _arguments = StepLRConfig(gamma=1) else: _arguments = argument_cls(**arguments) super().__init__(name=name, arguments=_arguments)
[docs] def make_scheduler(self, model, optimizer) -> Scheduler: """ Creates a new scheduler for an optimizer, based on the configuration. Parameters ---------- model The model. optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- Scheduler The scheduler. Examples -------- >>> scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99}) >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler_config.make_scheduler(model, optimizer) """ return self.arguments.init_scheduler(model, optimizer)
[docs]@configclass class OneCycleConfig(SchedulerArgs): """ Configuration class for the OneCycleLR scheduler. Attributes ---------- max_lr : float Upper learning rate boundaries in the cycle. total_steps : Derived[int] The total number of steps to run the scheduler in a cycle. step_when : StepType The step type at which the scheduler.step() should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ max_lr: float total_steps: Derived[int] # TODO fix mypy errors for custom types # type: ignore step_when: StepType = "train"
[docs] def init_scheduler(self, model: nn.Module, optimizer: Optimizer): """ Initializes the OneCycleLR scheduler. Creates and returns a OneCycleLR scheduler that monitors optimizer's learning rate. Parameters ---------- model : nn.Module The model. optimizer : Optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- OneCycleLR The OneCycleLR scheduler, initialized with arguments defined as attributes of this class. Examples -------- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler = OneCycleConfig(max_lr=0.5, total_steps=100) >>> scheduler.init_scheduler(model, optimizer) """ kwargs = self.to_dict() del kwargs["step_when"] return OneCycleLR(optimizer, **kwargs)
[docs]@configclass class PlateuaConfig(SchedulerArgs): """Configuration class for ReduceLROnPlateau scheduler. Attributes ---------- patience : int Number of epochs with no improvement after which learning rate will be reduced. min_lr : float A lower bound on the learning rate. mode : str One of ``'min'``, ``'max'``, or ``'auto'``, which defines the direction of optimization, so as to adjust the learning rate accordingly, i.e when a certain metric ceases improving. factor : float Factor by which the learning rate will be reduced. ``new_lr = lr * factor``. threshold : float Threshold for measuring the new optimum, to only focus on significant changes. verbose : bool If ``True``, prints a message to ``stdout`` for each update. step_when : StepType The step type at which the scheduler should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ patience: int = 10 min_lr: float = 1e-5 mode: str = "min" factor: float = 0.0 # TODO {fixme} this is error prone -> new_lr = 0 threshold: float = 1e-4 verbose: bool = False # TODO fix mypy errors for custom types # type: ignore step_when: StepType = "val"
[docs] def init_scheduler(self, model: nn.Module, optimizer: Optimizer): """ Initialize the ReduceLROnPlateau scheduler. Parameters ---------- model : nn.Module The model being optimized. optimizer : Optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- ReduceLROnPlateau The ReduceLROnPlateau scheduler, initialized with arguments defined as attributes of this class. Examples -------- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler = PlateuaConfig(min_lr=1e-7, mode='min') >>> scheduler.init_scheduler(model, optimizer) """ kwargs = self.to_dict() del kwargs["step_when"] return ReduceLROnPlateau(optimizer, **kwargs)
[docs]@configclass class StepLRConfig(SchedulerArgs): """ Configuration class for StepLR scheduler. Parameters ---------- step_size : int Period of learning rate decay, by default 1. gamma : float Multiplicative factor of learning rate decay, by default 0.99. step_when : StepType The step type at which the scheduler should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ step_size: int = 1 gamma: float = 0.99 # TODO fix mypy errors for custom types # type: ignore step_when: StepType = "epoch"
[docs] def init_scheduler(self, model: nn.Module, optimizer: Optimizer): """ Initialize the StepLR scheduler for a given model and optimizer. Parameters ---------- model : nn.Module The model to apply the scheduler. optimizer : Optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- StepLR The StepLR scheduler, initialized with arguments defined as attributes of this class. Examples -------- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler = StepLRConfig(step_size=20, gamma=0.9) >>> scheduler.init_scheduler(model, optimizer) """ kwargs = self.to_dict() del kwargs["step_when"] return StepLR(optimizer, **kwargs)
SCHEDULER_CONFIG_MAP = { "none": None, "step": StepLRConfig, "cycle": OneCycleConfig, "plateau": PlateuaConfig, }