import copy
import multiprocessing as mp
import traceback
import typing as ty
from abc import abstractmethod
from collections.abc import Callable, Iterable
import numpy as np
import torch
from torch import nn
from torch.cuda.amp import GradScaler
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import ablator.utils.base as butils
from ablator.main.configs import ModelConfig, RunConfig, TrainConfig
from ablator.main.model.main import EvaluationError, ModelBase, TrainPlateauError
from ablator.modules.metrics.main import LossDivergedError, TrainMetrics
from ablator.modules.optimizer import OptimizerConfig
from ablator.modules.scheduler import Scheduler, SchedulerConfig
[docs]class ModelWrapper(ModelBase):
"""
A wrapper around model_class that removes training boiler-plate code, with over-writable functions
with support for custom use-cases.
Attributes
----------
model_class: torch.nn.Module
The model class to wrap.
model: torch.nn.Module
The model created from the model class or checkpoint
optimizer: Optimizer
The optimizer created from the optimizer config or checkpoint
scaler: GradScaler
The scaler created from the scaler config or checkpoint
scheduler: Scheduler
The scheduler created from the scheduler config or checkpoint
"""
[docs] def __init__(
self,
model_class: type[nn.Module],
):
"""
Initializes the model wrapper.
Parameters
----------
model_class: torch.nn.Module
The model class to wrap.
"""
super().__init__(
model_class=model_class,
)
# Will be loaded or created from checkpoint
self.model: nn.Module
self.optimizer: Optimizer
self.scaler: GradScaler
self.scheduler: Scheduler | None
@property
def train_config(self) -> TrainConfig:
return self.run_config.train_config
@property
def model_config(self) -> ModelConfig:
return self.run_config.model_config
[docs] def create_model(
self,
save_dict: dict[str, ty.Any] | None = None,
strict_load: bool = True,
) -> None:
"""
Creates the model, optimizer, scheduler and scaler from the save dict or from config.
Parameters
----------
save_dict: dict[str, ty.Any]
The save dict to load from.
strict_load: bool
Whether to load the model strictly or not.
"""
save_dict = {} if save_dict is None else save_dict
scheduler_state = save_dict["scheduler"] if "scheduler" in save_dict else None
optimizer_state = save_dict["optimizer"] if "optimizer" in save_dict else None
scaler_state = save_dict["scaler"] if "scaler" in save_dict else None
model_class = self.model_class
model: nn.Module
if (model_config := self.model_config) is not None:
model = model_class(model_config) # type: ignore
else:
# Support of decleartive paradigm without model over-writing
model = model_class()
if "model" in save_dict:
model.load_state_dict(save_dict["model"], strict=strict_load)
elif self.train_config.rand_weights_init:
model.apply(butils.init_weights)
model = model.to(self.device)
optimizer = self.create_optimizer(
model=model,
optimizer_config=self.train_config.optimizer_config,
optimizer_state=optimizer_state,
)
scheduler = self.create_scheduler(
model=model,
optimizer=optimizer,
scheduler_config=self.train_config.scheduler_config,
scheduler_state=scheduler_state,
)
scaler = self.create_scaler(scaler_state=scaler_state)
self.model = model
self.optimizer = optimizer
self.scaler = scaler
self.scheduler = scheduler
[docs] def create_scheduler(
self,
model: nn.Module,
optimizer: Optimizer,
scheduler_config: SchedulerConfig | None = None,
scheduler_state: dict | None = None,
) -> Scheduler | None:
"""
Creates the scheduler from the saved state or from config.
Parameters
----------
model: nn.Module
The model to create the scheduler for.
optimizer: Optimizer
The optimizer to create the scheduler for.
scheduler_config: SchedulerConfig
The scheduler config to create the scheduler from.
scheduler_state: dict[str, ty.Any]
The scheduler state to load the scheduler from.
Returns
-------
scheduler: Scheduler
The scheduler.
"""
scheduler: ty.Optional[Scheduler] = None
if scheduler_config is not None:
scheduler = scheduler_config.make_scheduler(model, optimizer)
if scheduler_state is not None:
if scheduler is None:
self.logger.warn(
"Supplied `scheduler_state` without `scheduler_config`. Ignoring scheduler."
)
return None
scheduler.load_state_dict(scheduler_state)
return scheduler
[docs] def create_optimizer(
self,
model: nn.Module,
optimizer_config: OptimizerConfig | None = None,
optimizer_state: dict[str, ty.Any] | None = None,
) -> Optimizer:
"""
Creates the optimizer from the saved state or from config.
Parameters
----------
model: nn.Module
The model to create the optimizer for.
optimizer_config: OptimizerConfig
The optimizer config to create the optimizer from.
optimizer_state: dict[str, ty.Any]
The optimizer state to load the optimizer from.
Returns
-------
optimizer: Optimizer
The optimizer.
"""
optimizer: Optimizer
if optimizer_config is not None:
optimizer = optimizer_config.make_optimizer(model)
if optimizer_state is not None and optimizer is not None:
# NOTE: because https://github.com/pytorch/pytorch/issues/80809
# TODO any good fix for this yet?
for k in optimizer_state["state"].keys():
if "step" in optimizer_state["state"][k] and isinstance(
optimizer_state["state"][k]["step"], torch.Tensor
):
optimizer_state["state"][k]["step"] = optimizer_state["state"][k][
"step"
].cpu()
optimizer.load_state_dict(optimizer_state)
elif optimizer_state is not None:
self.logger.warn(
"Supplied `optimizer_state` without `optimizer_config`. Ignoring optimizer."
)
return optimizer
[docs] def create_scaler(self, scaler_state: ty.Optional[dict] = None) -> GradScaler:
"""
Creates the scaler from the saved state or from config.
Parameters
----------
scaler_state: dict[str, ty.Any]
The scaler state to load the scaler from.
Returns
-------
scaler: GradScaler
The scaler.
"""
scaler = GradScaler(enabled=self.run_config.amp)
if scaler_state:
scaler.load_state_dict(scaler_state)
return scaler
[docs] def reset_optimizer_scheduler(self):
"""
Resets the optimizer and scheduler by recreating them.
"""
optimizer_config = self.train_config.optimizer_config
scheduler_config = self.train_config.scheduler_config
optimizer = self.create_optimizer(
model=self.model,
optimizer_config=optimizer_config,
)
scheduler = self.create_scheduler(
model=self.model,
optimizer=optimizer,
scheduler_config=scheduler_config,
)
self.optimizer = optimizer
self.scheduler = scheduler
[docs] def load_checkpoint(
self, save_dict: dict[str, ty.Any], model_only: bool = False
) -> None:
"""
Loads the checkpoint from the save dict.
Parameters
----------
save_dict: dict[str, ty.Any]
The save dict to load the checkpoint from.
model_only: bool
Whether to load only the model or include scheduler, optimizer and scaler.
Notes
-----
This method is the implementation of the abstract method in the base class.
"""
if model_only:
del save_dict["scheduler"]
del save_dict["optimizer"]
del save_dict["scaler"]
self.create_model(
save_dict,
strict_load=True,
)
[docs] def to_device(self, data: Iterable, device=None) -> Iterable:
"""
Moves the data to the specified device.
Parameters
----------
data: Iterable
The data to move to the device.
device: ty.Optional[ty.Union[torch.device, str]]
The device to move the data to. If ``None``, the device specified in the config is used.
Returns
-------
data: Iterable
The data on the device.
"""
if device is None:
device = self.device
return butils.iter_to_device(data, device)
[docs] def model_step(
self, model: nn.Module, batch: Iterable
) -> tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]:
"""
A single inference step for the model.
Parameters
----------
model: nn.Module
The model to train.
batch: Iterable
The batch of input data to pass through the model,it could be a list, dict or a single tensor.
Returns
-------
out: tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]
The output of the model,contains current predictions and loss of the model
"""
batch = self.to_device(batch)
with self.autocast:
if isinstance(batch, list):
out = model(*batch)
elif isinstance(batch, dict):
out = model(**batch)
else:
out = model(batch)
return out
@ty.final
def _update_learning_rate(self):
self.learning_rate = butils.get_lr(self.optimizer)
return self.learning_rate
@ty.final
def _inc_iter(self):
self.current_iteration += 1
def _is_step(self, step_interval):
return (
step_interval > 0
and self.current_iteration > 0
and self.current_iteration % step_interval == 0
)
def _train_evaluation_step(self, smoke_test=False):
is_best = False
val_loss = None
if self.val_dataloader is not None:
metrics = self._validation_loop(
model=self.model,
dataloader=self.val_dataloader,
tag="val",
metrics=self.metrics,
subsample=self.run_config.eval_subsample,
smoke_test=smoke_test,
)
val_loss = metrics["val_loss"] if "val_loss" in metrics else None
if val_loss is not None:
# Use val loss for scheduling or finding best checkpoint
is_best = val_loss < self.best_loss
if is_best or self.best_loss == 0:
self.best_iteration = self.current_iteration
self.best_loss = val_loss
divergence_step = (
self.current_iteration > self.epoch_len * self.run_config.warm_up_epochs
)
is_diverged = val_loss / self.best_loss > self.run_config.divergence_factor
if is_diverged and divergence_step:
raise LossDivergedError(
f"Val loss {val_loss:.4e} has diverged by"
f"a factor of {self.run_config.divergence_factor} to "
f"best loss {self.best_loss:.4e}"
)
if (
self.scheduler is not None
and hasattr(self.train_config.scheduler_config.arguments, "step_when")
and self.train_config.scheduler_config.arguments.step_when == "val"
):
if val_loss is None:
raise EvaluationError(
f"A validation dataset is rquired with {self.scheduler.__class__.__name__} scheduler"
)
self.scheduler.step(val_loss)
self._checkpoint()
if is_best:
self._checkpoint(is_best=True)
# Early stopping
early_stopping_iter = self.run_config.early_stopping_iter
if (
early_stopping_iter is not None
and (self.current_iteration - self.best_iteration) > early_stopping_iter
):
raise TrainPlateauError(
f"Early stopping, no improvement for {early_stopping_iter} iterations."
)
def _model_step(
self, model: nn.Module, batch: Iterable
) -> tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]:
out = self.model_step(model=model, batch=batch)
try:
outputs, loss = out
assert isinstance(outputs, (dict, type(None))) and isinstance(
loss, (torch.Tensor, type(None))
)
if outputs is not None:
for k, v in outputs.items():
assert isinstance(k, str) and isinstance(v, torch.Tensor)
except Exception as exc:
raise RuntimeError(
"Model should return outputs: dict[str, torch.Tensor] | None, loss: torch.Tensor | None."
) from exc
return outputs, loss
[docs] @ty.final
def train_step(
self, batch: Iterable
) -> tuple[dict[str, torch.Tensor] | None, dict[str, ty.Any]]:
"""
A single step for training.
It also updates learning rate with scheduler.
Parameters
----------
batch: Iterable
The batch of input data to pass through the model,it could be a list, dict or a single tensor.
Returns
-------
outputs: dict[str, torch.Tensor] | None
The output of the model.
train_metrics: dict[str, ty.Any]
The training metrics.
"""
model = self.model
optimizer = self.optimizer
scaler = self.scaler
scheduler = self.scheduler
# Ensure no left-over grads are in the model's parameters from custom evaluation or what-not
optimizer.zero_grad()
outputs, loss = self._model_step(model=model, batch=batch)
loss_value = self.apply_loss(model, loss, optimizer, scaler, scheduler)
aux_metrics = None
if outputs is not None:
aux_metrics = self.aux_metrics(outputs)
if (
scheduler is not None
and getattr(scheduler, "step_when", None) == "epoch"
and self._is_step(self.epoch_len)
):
scheduler.step() # type: ignore
self._inc_iter()
self._update_learning_rate()
train_metrics = {}
if loss is not None:
train_metrics["loss"] = loss_value
if aux_metrics is not None:
assert (
"loss" not in aux_metrics
), "Can not return key `loss` from `aux_metrics`"
train_metrics.update(aux_metrics)
return outputs, train_metrics
[docs] def log_step(self):
"""
A single step for logging.
Notes
-----
This method is update the logger with the current metrics and log a status message.
"""
self.logger.update(self.metrics)
msg = self.status_message()
verbose = self.verbose == "console"
self.logger.info(msg, verbose=verbose)
[docs] @ty.final
def mock_train(
self,
run_config: ty.Optional[RunConfig] = None,
run_async=True,
block: bool = True,
) -> mp.Process | TrainMetrics:
"""
Mock train the model as a smoke test
Parameters
----------
run_config: RunConfig
The run config to use for the mock train.
run_async: bool
Whether to run the mock train in a separate process.
block: bool
Whether to block the current process until the mock train is finished.
Returns
-------
p: mp.Process
The process running the mock train.
metrics: TrainMetrics
The metrics from the mock train.
"""
mock_model = copy.deepcopy(self)
if run_config is None:
run_config = mock_model.run_config
if run_async:
p = mp.Process(target=mock_model.train, args=(run_config, True))
p.start()
if block:
p.join()
return p
return mock_model.train(run_config=run_config, smoke_test=True)
[docs] def update_status(self):
"""
Update the metrics with current training stats,
and then all metrics (static and moving average) will be set as description for the ``tqdm`` progress.
"""
self.metrics.update_static_metrics(self.train_stats)
if self.verbose != "tqdm":
return
rate = self.train_tqdm.format_dict["rate"]
time_remaining = "??"
if rate is not None and isinstance(rate, (int, float)):
time_remaining = self.train_tqdm.format_interval(
(self.total_steps - self.current_iteration) / rate
)
msg = self.status_message()
self.train_tqdm.set_description(msg)
self.train_tqdm.set_postfix_str(f"Remaining: {time_remaining}")
self.train_tqdm.update(1)
[docs] def status_message(self) -> str:
"""
Return a string generated from dictionary of current metrics,including all the static metrics and moving average metrics.
Returns
-------
str
The status message.
"""
# must return current epoch, iter, losses and metrics
return " ".join([f"{k}: {v}" for k, v in self.metrics.to_dict().items()])
[docs] def log(self):
"""
Log if the current iteration is a logging step. It also evaluate training metrics for logging.
"""
# Log step
if self._is_step(self.log_itr):
self.metrics.evaluate("train", reset=False)
self.log_step()
[docs] @ty.final
def eval(self, smoke_test=False):
"""
Evaluate the model then update scheduler and save checkpoint if the current iteration is an evaluation step.
It also check if it is early stopping (check Model Configuration module for more details).
"""
# Evaluation step
if self._is_step(self.eval_itr):
try:
self._train_evaluation_step(smoke_test=smoke_test)
except (LossDivergedError, TrainPlateauError) as e:
error = traceback.format_exc()
self.logger.error(error)
raise e
finally:
eval_step = (
self.current_iteration
if self.eval_itr == 0
else self.current_iteration // self.eval_itr
)
msg = self.status_message()
self.logger.info(f"Evaluation Step [{eval_step}] {msg}", verbose=False)
@property
def total_steps(self):
"""
The total number of steps for training.
"""
return self.epoch_len * self.epochs
[docs] def train_loop(self, smoke_test=False):
"""
Train the model in many steps, evaluate the model and log the metrics for each iteration.
metrics including static metrics like learning rate, along with validation and training metrics like loss and mean.
Parameters
----------
smoke_test: bool
Whether to run a smoke test.
"""
train_dataloader = self.train_dataloader
generator = iter(train_dataloader)
for i in range(self.current_iteration, self.total_steps):
self.model.train()
try:
batch = next(generator)
except StopIteration:
# restart the generator if the previous generator is exhausted.
generator = iter(train_dataloader)
batch = next(generator)
self.metrics.reset("train")
self.train_tqdm.reset()
outputs, train_metrics = self.train_step(batch)
if outputs is not None:
self.metrics.append_batch(**outputs, tag="train")
self.metrics.update_ma_metrics(train_metrics, tag="train")
if "loss" in train_metrics and not np.isfinite(train_metrics["loss"]):
msg = f"Loss Diverged. Terminating. loss: {train_metrics['loss']}"
self.logger.error(msg)
raise LossDivergedError(msg)
if not smoke_test:
self.update_status()
self.log()
self.eval()
if smoke_test and i > self.epoch_len * 0.01:
self.eval(smoke_test=True)
break
return self.metrics
[docs] @ty.final
def train(
self,
run_config: RunConfig,
smoke_test: bool = False,
debug: bool = False,
resume: bool = False,
) -> TrainMetrics:
"""
Initialize states and train the model.
When keyboard interrupts, saves a checkpoint
Parameters
----------
run_config : RunConfig
The run config to use for training.
smoke_test : bool, default=False
Whether to run a smoke test.
debug : bool, default=False
Whether to run in debug mode.
resume : bool, default=False
Whether to resume training the model from existing checkpoints and existing experiment state.
Returns
-------
TrainMetrics
The metrics from the training.
"""
self._init_state(
run_config=run_config, smoke_test=smoke_test, debug=debug, resume=resume
)
try:
return self.train_loop(smoke_test)
except KeyboardInterrupt:
self._checkpoint()
return self.metrics
[docs] @ty.final
def evaluate(
self,
run_config: RunConfig,
):
"""
Evaluate the model after training on the test and validation sets.
Parameters
----------
run_config: RunConfig
The run config to use for evaluation.
"""
self._init_state(run_config, resume=True)
self.logger.info(f"Evaluating {self.current_checkpoint}")
msg = self.metrics.to_dict()
self.logger.info(f"Current metrics: {msg}")
metrics = {}
for loader, tag in zip(
[self.test_dataloader, self.val_dataloader], ["test", "val"]
):
if loader is not None:
# NOTE we set max memory limit and let it crash because we do not want
# inaccurate metrics calculation. Possibly smarter ways to go about it.
eval_metrics = TrainMetrics(
batch_limit=len(loader) + 1,
memory_limit=int(1e9),
moving_average_limit=len(loader),
evaluation_functions=self.evaluation_functions(),
tags=[tag],
moving_aux_metrics=["loss"] + getattr(self, "aux_metric_names", []),
)
self._validation_loop(
model=self.model,
dataloader=loader,
tag=tag, # type: ignore
metrics=eval_metrics,
subsample=1,
)
metrics[tag] = eval_metrics
msg = self.metrics.to_dict()
self.logger.info(f"Evaluation: {msg}")
return metrics
[docs] def apply_loss(
self,
model: nn.Module,
loss: torch.Tensor | None,
optimizer: Optimizer,
scaler: torch.cuda.amp.GradScaler,
scheduler: ty.Optional[Scheduler],
) -> float | None:
"""
Calculate the loss and apply the gradients, call ``optimizer.step()`` and ``scheduler.step()``.
Parameters
----------
model: nn.Module
The model to apply the loss to.
loss: torch.Tensor | None
The loss to apply.
optimizer: Optimizer
The optimizer to step.
scaler: torch.cuda.amp.GradScaler
The scaler to use for mixed precision training.
scheduler: ty.Optional[Scheduler]
The scheduler to step.
Returns
-------
float | None
The loss value.
"""
if loss is not None:
loss = torch.mean(loss)
if self.amp:
scaler.scale(loss).backward()
else:
loss.backward()
loss_value = loss.item()
else:
loss_value = None
if self.amp:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad()
if scheduler is not None and getattr(scheduler, "step_when", None) == "train":
scheduler.step() # type: ignore
return loss_value
@torch.no_grad()
def _validation_loop(
self,
model: nn.Module,
dataloader: DataLoader,
metrics: TrainMetrics,
tag: ty.Literal["train", "test", "val"],
subsample: float = 1.0,
smoke_test: bool = False,
) -> dict[str, float]:
was_training = model.training
model.eval()
if (batch_lim := metrics.__batch_limit__) < len(dataloader):
self.logger.warn(
f"Metrics batch-limit {batch_lim} is smaller than "
f"the validation dataloader length {len(dataloader)}. "
"Consider increasing `metrics_n_batches`."
)
metrics_dict = self.validation_loop(
model, dataloader, metrics, tag, subsample, smoke_test
)
if was_training:
model.train()
return metrics_dict
[docs] def validation_loop(
self,
model: nn.Module,
dataloader: DataLoader,
metrics: TrainMetrics,
tag: ty.Literal["train", "test", "val"],
subsample: float = 1.0,
smoke_test: bool = False,
) -> dict[str, float]:
"""
Validate the model on data in dataloader (which can either be val dataloader
- so tag is ``val``, or test dataloader - so tag is ``test``)
Parameters
----------
model: nn.Module
The model to validate.
dataloader: DataLoader
The dataloader to use for validation.
metrics: TrainMetrics
The metrics to use for validation.
tag: ty.Literal["train", "test", "val"]
The tag to use for validation. Also see ``TrainMetrics`` for details.
subsample: float
The fraction of the dataloader to use for validation.
smoke_test: bool
Whether to execute this function as a smoke test. If ``True``, only one iteration will be performed,
which is useful for quickly checking if the code runs without errors. Default is ``False``.
Returns
-------
dict[str, float]
The metrics from the validation.
"""
cutoff_itr = len(dataloader) * subsample
if model.training:
self.logger.warn(
"Called `validation_loop` without setting the model to evaluation mode. i.e. `model.eval()`"
)
for i, batch in enumerate(dataloader):
with torch.no_grad():
outputs, loss = self._model_step(model=model, batch=batch)
val_metrics = {}
if outputs is not None:
aux_metrics = self.aux_metrics(outputs)
metrics.append_batch(tag=tag, **outputs)
if aux_metrics is not None:
assert (
"loss" not in aux_metrics
), "Invalid return key `loss` from `aux_metrics`"
val_metrics.update(aux_metrics)
if loss is not None:
val_metrics["loss"] = torch.mean(loss).item()
metrics.update_ma_metrics(val_metrics, tag=tag)
if i > cutoff_itr or smoke_test:
break
metrics.evaluate(tag)
metrics_dict = {
k: v for k, v in metrics.to_dict().items() if k.startswith(f"{tag}_")
}
return metrics_dict
[docs] @abstractmethod
def make_dataloader_train(self, run_config: RunConfig) -> DataLoader:
"""
Function to make the training dataloader.
Parameters
----------
run_config: RunConfig
The run configuration.
Returns
-------
DataLoader
The training dataloader.
"""
pass
[docs] def evaluation_functions(self) -> dict[str, Callable] | None:
"""
Returns
-------
dict[str, Callable]
The evaluation functions to use.Also see ``TrainMetrics`` for details.
"""
return None
# Functions that can be optionally over-written.
[docs] def make_dataloader_test(self, run_config: RunConfig) -> DataLoader | None:
"""
Function to make the test dataloader.
Parameters
----------
run_config: RunConfig
The run configuration.
Returns
-------
DataLoader | None
The test dataloader.
"""
pass
[docs] def make_dataloader_val(self, run_config: RunConfig) -> DataLoader | None:
"""
Function to make the validation dataloader.
Parameters
----------
run_config: RunConfig
The run configuration.
Returns
-------
DataLoader | None
The validation dataloader.
"""
pass
def custom_evaluation(
self, model: nn.Module, dataloader: Iterable
) -> ty.Optional[dict[str, ty.Any]]:
pass
[docs] def aux_metrics(
self, output_dict: dict[str, torch.Tensor] | None
) -> ty.Optional[dict[str, ty.Any]]:
"""
Auxiliary metrics to be computed during training.
Parameters
----------
output_dict: dict[str, torch.Tensor] | None
The output dictionary from the model.
Returns
-------
ty.Optional[dict[str, ty.Any]]
The auxiliary metrics.
Notes
-----
Auxiliary metrics are computed during training and are used for ``moving_aux_metrics`` in ``TrainMetrics``.
Check ``TrainMetrics`` for more details.
"""
pass
[docs] def config_parser(self, run_config: RunConfig):
"""
Used to initialize Derived properties
"""
return run_config
[docs] def make_dataloaders(self, run_config: RunConfig) -> None:
"""
This function is done post-initialization because otherwise the
dataloaders are pickled with the object when running distributed.
"""
self.train_dataloader = self.make_dataloader_train(run_config)
self.val_dataloader = self.make_dataloader_val(run_config)
self.test_dataloader = self.make_dataloader_test(run_config)
[docs] def checkpoint(self, is_best=False):
"""
Save a checkpoint of the model.It will use the class name of the model as the filename.
Parameters
----------
is_best: bool
Whether this is the best model so far.
"""
self.logger.checkpoint(
self.current_state,
self.model.__class__.__name__,
is_best=is_best,
itr=self.current_iteration,
)
[docs] def save_dict(self) -> dict[str, ty.Any]:
"""
Save the current state of the trainer, including model parameters,
and current states of the optimizer, the scaler, and the scheduler.
Returns
-------
dict[str, ty.Any]
The current state of the trainer.
"""
model_state_dict = self.model.state_dict()
optimizer_state_dict = None
if self.optimizer is not None:
optimizer_state_dict = self.optimizer.state_dict()
scheduler_state_dict = None
if self.scheduler is not None:
scheduler_state_dict = self.scheduler.state_dict()
scaler_state_dict = None
if self.scaler is not None:
scaler_state_dict = self.scaler.state_dict()
return {
"model": model_state_dict,
"optimizer": optimizer_state_dict,
"scheduler": scheduler_state_dict,
"scaler": scaler_state_dict,
}