Source code for ablator.main.model.main

import copy
import math
import traceback
import typing as ty
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Callable
from functools import cached_property
from pathlib import Path

import setproctitle
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import ablator.utils.base as butils
from ablator.main.configs import RunConfig
from ablator.modules.loggers.main import SummaryLogger
from ablator.modules.metrics.main import TrainMetrics
from ablator.utils.base import Dummy


[docs]class EvaluationError(Exception): pass
[docs]class TrainPlateauError(Exception): pass
[docs]class LogStepError(Exception): pass
[docs]class CheckpointNotFoundError(FileNotFoundError): pass
[docs]class ModelBase(ABC): """ Base class that removes training boiler-plate code with extensible support for multiple use-cases. The class follows a stateful initialization paradigm. Requires the user to implement specific to their use-case load model and creation functionality. Attributes ---------- model_class : Type[nn.Module] The class definition of the model's structure, which is a subclass of ``nn.Module``. run_config : RunConfig An instance of ``RunConfig`` containing configuration details. train_dataloader : DataLoader A DataLoader object responsible for model training. val_dataloader : Optional[DataLoader] An optional DataLoader object used for model evaluation. test_dataloader : Optional[DataLoader] An optional DataLoader object used for model testing. logger : Union[SummaryLogger, tutils.Dummy] Records information on the program's operation and model training, such as progress and performance metrics. device : str The type of device used for running the experiment. i.e. ``"cuda"``, ``"cpu"``, ``"cuda:0"``. model_dir : Path The model directory. experiment_dir : Path The experiment directory. autocast : torch.autocast Enables autocasting for chosen regions. Autocasting automatically chooses the precision for GPU operations to improve performance while maintaining accuracy. verbose : bool If ``True``, prints additional information while training. Only applied for the master process. amp : bool If ``True``, apply automatic mixed precision training, otherwise default precision. random_seed : Optional[int] Sets the seed for generating random numbers. train_tqdm : tqdm, optional An optional instance of ``tqdm`` that creates progress bars and displays real-time information during training. i.e. time remaining. Only applied for the master process. current_checkpoint : Optional[Path] Directory for the current checkpoint file, by default None. metrics : Metrics Training metrics including model information. i.e. learning rate and loss value. current_state : dict The currrent state of the model, including run_config, metrics and other necessary states. learning_rate : float The current learning rate. total_steps : int The total steps for the training process. epochs : int The total epochs for the training process. current_iteration : int The current iteration of training. best_iteration : int The iteration with the best loss value. best_loss : float The lowest loss value encountered during training. Notes ----- 1. Class properties are simply listed by name. Please check out property docstring for more information. 2. Users must implement the abstract methods to customize the model's behavior. 3. Mixed precision training enables some operations to use the ``torch.float32`` datatype and other operations use lower precision floating point datatype ``torch.float16``. This is for saving time and reducing memory usage. Ordinarily, "automatic mixed precision training" means training with ``torch.autocast`` and ``torch.cuda.amp.GradScaler`` together. More information: https://pytorch.org/docs/stable/amp.html """
[docs] def __init__( self, model_class: type[nn.Module], ): """Initializes the ModelBase class with the required ``model_class`` and optional configurations. Parameters ---------- model_class : type[nn.Module] The base class for user's model, which defines the neural network. """ self.model_class = model_class self.run_config: RunConfig self.train_dataloader: DataLoader self.val_dataloader: DataLoader | None = None self.test_dataloader: DataLoader | None = None self.logger: ty.Union[SummaryLogger, Dummy] self.device: str self.model_dir: Path | None = None self.experiment_dir: Path | None = None self.autocast: torch.autocast self.verbose: ty.Literal["tqdm", "console", "silent"] self.amp: bool self.random_seed: ty.Optional[int] self.train_tqdm: tqdm = None self.current_checkpoint: Path | None = None # Runtime metrics self.metrics: TrainMetrics self.current_state: dict = {} # stats self.learning_rate = float("inf") self.total_steps: int self.epochs: int self.current_iteration = 0 self.best_iteration = 0 self.best_loss = float("inf")
@property def train_stats(self) -> OrderedDict: """ Returns an ordered dictionary containing the current training statistics. Returns ------- OrderedDict An ordered dictionary with the following keys and values: - learning_rate: The current learning rate. - total_steps: The total steps for the training process. - epochs: The number of epochs for training. - current_epoch: The current epoch during training. - current_iteration: The current iteration during training. - best_iteration: The iteration with the best loss value so far. - best_loss: The best (lowest) loss value achieved during training. """ return OrderedDict( learning_rate=self.learning_rate, total_steps=self.total_steps, epochs=self.epochs, current_epoch=self.current_epoch, current_iteration=self.current_iteration, best_iteration=self.best_iteration, best_loss=self.best_loss, ) @property def current_epoch(self) -> int: """ Calculates and returns the current epoch during training. Returns ------- int The current epoch number. """ if self.current_iteration > 0: return math.floor(self.current_iteration / self.total_steps * self.epochs) return 0 @cached_property def epoch_len(self): """ Returns the length of an epoch, which is the number of batches in the ``train_dataloader``. Returns ------- int The length of an epoch, represented as the number of batches in the ``train_dataloader``. Raises ------ AssertionError If the ``train_dataloader`` is not defined or its length is 0. """ assert ( hasattr(self, "train_dataloader") and len(self.train_dataloader) > 0 ), "Undefined train_dataloader." return len(self.train_dataloader) @cached_property def eval_itr(self): """ Calculate the interval between evaluations. Returns ------- int The interval between evaluations. """ return math.ceil(self.run_config.eval_epoch * self.epoch_len) @cached_property def log_itr(self): """ Calculate the interval between logging steps. Returns ------- int The interval between logging steps. """ return math.ceil(self.run_config.log_epoch * self.epoch_len) @property def uid(self): """ Returns a unique identifier (UID) for the current run configuration. Returns ------- str A string representing the unique identifier of the current run configuration. """ return self.run_config.uid def _get_process_name(self) -> str: """ Retrieves the process name based on the model directory, experiment directory, or UID. Returns ------- str The process name for the current instance. """ if self.model_dir is not None and self.experiment_dir is not None: proc_title = self.model_dir.relative_to( self.experiment_dir.parent ).as_posix() else: proc_title = self.uid return proc_title
[docs] @abstractmethod def create_model( self, save_dict: dict[str, ty.Any] | None = None, strict_load: bool = True, ) -> None: """ Abstract method to create and initialize the model. Must be implemented by subclasses. Example implementation: Please see the ``create_model`` method in the ``ModelWrapper`` class. Parameters ---------- save_dict : dict[str, ty.Any] | None, optional A dictionary containing saved model data, such as weights, optimizer state, etc., to be loaded into the model, by default ``None``. strict_load : bool, optional If True, the model will be loaded strictly, ensuring that the saved state matches the model's structure exactly. If False, the model can be loaded with a partially matching state, by default ``True``. """ raise NotImplementedError
[docs] @abstractmethod def checkpoint(self, is_best=False): """ Abstract method to save a checkpoint of the model. Must be implemented by subclasses. Example implementation: Please see the ``checkpoint`` method in the ``ModelWrapper`` class. Parameters ---------- is_best : bool, optional Indicates if the current checkpoint is the best model so far, by default ``False``. """ raise NotImplementedError
[docs] @abstractmethod def train( self, run_config: RunConfig, smoke_test: bool = False, ): """ Abstract method to train the model. Must be implemented by subclasses. Example implementation: Please see the ``train`` method in the ``ModelWrapper`` class. Parameters ---------- run_config : RunConfig An instance of ``RunConfig`` containing configuration details. smoke_test : bool, optional Whether to run as a smoke test, by default ``False``. """ raise NotImplementedError
[docs] @abstractmethod def evaluate( self, run_config: RunConfig, ): """ Abstract method to evaluate the model. Must be implemented by subclasses. Example implementation: Please see the ``evaluate`` method in the ``ModelWrapper`` class. Parameters ---------- run_config : RunConfig An instance of ``RunConfig`` containing configuration details. """ raise NotImplementedError
[docs] @abstractmethod def make_dataloaders(self, run_config: RunConfig): """ Abstract method to create dataloaders for the training, validation, and testing datasets. This method should define the process of loading the data and creating dataloaders for the training, validation, and testing datasets based on the provided ``run_config``. Must be implemented by subclasses. Example implementation: Please see the ``make_dataloaders`` method in the ``ModelWrapper`` class. Parameters ---------- run_config : RunConfig An instance of ``RunConfig`` containing configuration details. """ raise NotImplementedError
[docs] @abstractmethod def config_parser(self, run_config: RunConfig): """ Abstract method to parse the provided configuration. Must be implemented by subclasses. Example implementation: Please see the ``make_dataloaders`` method in the ``ModelWrapper`` class. Parameters ---------- run_config : RunConfig An instance of ``RunConfig`` containing configuration details. """ raise NotImplementedError
def _init_logger(self, resume=False, debug=False): """ Initializes the logger used for recording experiment details and progress. Parameters ---------- resume : bool, optional If True, the logger will resume logging from a previous experiment, by default False. debug : bool, optional If True, logger will log additional debug information, by default False. """ self.logger = SummaryLogger( run_config=self.run_config, model_dir=self.model_dir, resume=resume, keep_n_checkpoints=self.run_config.keep_n_checkpoints, verbose=self.run_config.verbose == "console", ) if butils.debugger_is_active() and not debug: self.logger.warn("Debug flag is False but running in debug mode.") self.logger.info(f"Model directory: {self.model_dir}") def _make_dataloaders(self, run_config: RunConfig): """ Creates the data loaders for the training process. Parameters ---------- run_config : RunConfig An instance of ``RunConfig`` containing configuration details. """ self.make_dataloaders(run_config) assert ( len(self.train_dataloader) > 0 ), "Must define a train dataloader in `make_dataloader`" self.epochs = self.run_config.train_config.epochs def _init_class_attributes(self, debug=False): """ Initializes the class attributes based on the provided configuration. This function sets up various class attributes related to device, mixed precision, warnings handling, early stopping, metrics, experiment and model directories, and process title. Parameters ---------- debug : bool, optional If True, disables the experiment and model directories creation, by default False. """ run_config = self.run_config self.device = butils.parse_device(run_config.device) self.amp = run_config.amp if self.device == "cpu" and self.amp: raise ValueError( "AMP is not supported for CPU. You will need to set `run_config.amp` to False." ) self.autocast = torch.autocast( enabled=self.amp, device_type="cuda" if "cuda" in self.device else "cpu", ) self.verbose = run_config.verbose if self.verbose == "silent": warnings.filterwarnings("ignore") if ( run_config.early_stopping_iter is not None and run_config.early_stopping_iter > 0 ): assert ( self.val_dataloader is not None ), "dataloader function has to return validation set when setting early stopping to True" self.metrics = TrainMetrics( batch_limit=run_config.metrics_n_batches, memory_limit=int(run_config.metrics_mb_limit * 1e6), moving_average_limit=self.epoch_len, evaluation_functions=self.evaluation_functions(), tags=["train"] + (["val"] if self.val_dataloader is not None else []), static_aux_metrics=self.train_stats, moving_aux_metrics=["loss"] + getattr(self, "aux_metric_names", []), ) if self.run_config.experiment_dir is not None and not debug: self.experiment_dir = Path(self.run_config.experiment_dir) self.model_dir = self.experiment_dir.joinpath(self.uid) if debug and (self.experiment_dir is not None or self.model_dir is not None): self.experiment_dir = self.model_dir = None setproctitle.setproctitle(self._get_process_name()) def _init_model_state(self, resume: bool = False, smoke_test: bool = False): """ Initializes the model state based on provided parameters and configuration. Parameters ---------- resume : bool, optional If True, tries to resume training from a checkpoint, by default False. smoke_test : bool, optional Whether to run as a smoke test, by default False. """ if self.run_config.init_chkpt is not None and resume: self.current_checkpoint = Path(self.run_config.init_chkpt) self._load_model(self.current_checkpoint, model_only=False) elif self.run_config.init_chkpt is not None and not resume: # Loads only the weights self.current_checkpoint = Path(self.run_config.init_chkpt) self.logger.info( f"Initializing model weights ONLY from checkpoint. {self.current_checkpoint}" ) self._load_model(self.current_checkpoint, model_only=True) elif resume and not smoke_test: if "recent" not in self.logger.CHKPT_DIRS: raise RuntimeError("Checkpoint folder was not found.") recent_checkpoint_dir = self.logger.CHKPT_DIRS["recent"] # NOTE: current_checkpoint is found in _find_load_valid_checkpoint self._find_load_valid_checkpoint(recent_checkpoint_dir) else: self.current_checkpoint = None self.logger.info("Creating new model") self.create_model() self._update_save_dict() def _init_state( self, run_config: RunConfig, smoke_test: bool = False, debug: bool = False, resume: bool = False, ): """ Initializes the state of the trainer based on provided configuration and parameters. Parameters ---------- run_config : RunConfig An instance of ``RunConfig`` containing configuration details. smoke_test : bool, optional Whether to run as a smoke test, by default False. debug : bool, optional If True, disables logging and model directory creation, by default False. resume : bool, optional If True, tries to resume training from a checkpoint, by default False. """ self.run_config = run_config self.random_seed = self.run_config.random_seed if self.random_seed is not None: butils.set_seed(self.random_seed) self.run_config = run_config _run_config = copy.deepcopy(run_config) self._make_dataloaders(self.run_config) self.run_config = self.config_parser(run_config) self._init_class_attributes(debug=debug) # Does not create log artifacts during smoke test if not smoke_test: self._init_logger(resume=resume, debug=debug) else: self.logger = butils.Dummy() self._init_model_state(resume, smoke_test) self.run_config.assert_state(_run_config) if self.verbose == "tqdm" and not smoke_test: self.train_tqdm = tqdm( total=self.epoch_len, bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}", position=0, leave=True, dynamic_ncols=True, ) else: self.train_tqdm = butils.Dummy() def _find_load_valid_checkpoint(self, chkpt_dir): """ Finds and loads the latest valid checkpoint from the given directory. Parameters ---------- chkpt_dir : str The directory containing the checkpoints. Raises ------ CheckpointNotFoundError If no valid checkpoint is found in the specified directory. RuntimeError If a checkpoint is not found. """ latest_checkpoints = butils.get_latest_chkpts(chkpt_dir) current_checkpoint = None if len(latest_checkpoints) > 0: # Try to load first valid chkpt in case there was a crash and some checkpoint is unrecoverable for i, _checkpoint in enumerate(latest_checkpoints): try: self.logger.info(f"Loading checkpoint {_checkpoint}") self._load_model(_checkpoint, model_only=False) current_checkpoint = _checkpoint break except Exception as e: if i == len(latest_checkpoints) - 1: # if it is the last checkpoint raise exception raise RuntimeError("Checkpoint not found") from e # ignore exception self.logger.error( f"Error loading checkpoint {_checkpoint}. Trying another....\n{traceback.format_exc()}" ) if current_checkpoint is None: raise CheckpointNotFoundError(f"Could not find a valid checkpoint in {chkpt_dir}") self.current_checkpoint = current_checkpoint def _load_model(self, checkpoint_path: Path, model_only: bool = False) -> None: """ Loads the model and its state from the checkpoint file at the specified path. Parameters ---------- checkpoint_path : Path The path to the checkpoint file containing the model and its state. model_only : bool, optional, default=False If True, only the model's weights will be loaded, ignoring other state information. Raises ------ NotImplementedError If the model's run configuration is not initialized before attempting to load the model. """ if not hasattr(self, "run_config") or self.run_config is None: raise NotImplementedError( "Can not load model on an unitialzed model state. Consider run init_experiment_state function first" ) save_dict = torch.load(checkpoint_path, map_location="cpu") run_config = type(self.run_config)(**save_dict["run_config"]) assert run_config.uid == self.run_config.uid self._load_stats(save_dict) self.load_checkpoint(save_dict, model_only=model_only) self.current_state = save_dict
[docs] @abstractmethod def load_checkpoint( self, save_dict: dict[str, ty.Any], model_only: bool = False ) -> None: """ Abstract method to load the model and its state from a given save dictionary. Must be implemented by subclasses. Example implementation: Please see the ``load_checkpoint`` method in the ``ModelWrapper`` class. Parameters ---------- save_dict : dict[str, ty.Any] A dictionary containing the saved model state and other necessary information. model_only : bool, optional, default=False If ``True``, only the model's weights will be loaded, ignoring other state information. """ raise NotImplementedError
[docs] @abstractmethod def save_dict(self) -> dict[str, ty.Any] | None: """ Abstract method to create and return a save dictionary containing the model's state and other necessary information. Must be implemented by subclasses. Example implementation: Please see the ``save_dict`` method in the ``ModelWrapper`` class. Returns ------- dict[str, ty.Any] | None A dictionary containing the saved model state and other necessary information. """ raise NotImplementedError
[docs] @abstractmethod def evaluation_functions(self) -> dict[str, Callable] | None: """ Abstract method to create and return a dictionary of evaluation functions used during training and validation. Must be implemented by subclasses. Example implementation: Please see the ``evaluation_functions`` method in the ``ModelWrapper`` class. Returns ------- dict[str, Callable] | None A dictionary containing evaluation functions as values and their names as keys. """ raise NotImplementedError
def _load_stats(self, save_dict) -> None: """ Loads the saved training and validation metrics from the save_dict and updates the model's internal metrics with the loaded values. Parameters ---------- save_dict : dict A dictionary containing the saved model state, metrics, and other necessary information. """ metrics = copy.deepcopy(save_dict["metrics"]) for k in self.train_stats: if ( isinstance(getattr(type(self), k, None), property) and getattr(type(self), k).fset is None ): if getattr(self, k, None) != metrics[k]: self.logger.warn( f"Immutable class attribute {k} value {getattr(self, k)} " f"different than loaded value {metrics[k]}" ) del metrics[k] for k in self.train_stats: if k in metrics: setattr(self, k, metrics[k]) del metrics[k] self.metrics.update_static_metrics(self.train_stats) tags = {m.split("_")[0] for m in metrics} metric_names = {m.split("_")[1] for m in metrics} for tag in tags: self.metrics.update_ma_metrics( {m: metrics[f"{tag}_{m}"] for m in metric_names}, tag=tag ) def _update_save_dict(self, user_save_dict: dict[str, ty.Any] | None = None): """ Updates the current state dictionary with run_config and metrics. If a user_save_dict is provided, it is also merged into the current state dictionary. Parameters ---------- user_save_dict : dict[str, ty.Any] | None, optional A dictionary containing user-defined information to be saved, by default None. """ self.current_state = { "run_config": self.run_config.to_dict(), "metrics": self.metrics.to_dict(), } if user_save_dict is not None: self.current_state.update(**user_save_dict) def _checkpoint(self, is_best=False): """ Updates the current state dictionary with user-defined save_dict and calls the checkpoint method. Parameters ---------- is_best : bool, optional Indicates if the current checkpoint is the best model so far, by default False. """ user_save_dict = self.save_dict() self._update_save_dict(user_save_dict) self.checkpoint(is_best=is_best)