Main model package#

Submodules#

Main Model module#

exception ablator.main.model.main.CheckpointNotFoundError[source]#

Bases: FileNotFoundError

exception ablator.main.model.main.EvaluationError[source]#

Bases: Exception

exception ablator.main.model.main.LogStepError[source]#

Bases: Exception

class ablator.main.model.main.ModelBase(model_class: type[torch.nn.modules.module.Module])[source]#

Bases: ABC

Base class that removes training boiler-plate code with extensible support for multiple use-cases. The class follows a stateful initialization paradigm. It requires the user to implement model loading and creation functionality specific to their use-case.

Notes

  1. Class attributes are listed by name below. Please check each property's docstring for more information.

  2. Users must implement the abstract methods to customize the model’s behavior.

  3. Mixed-precision training enables some operations to use the torch.float32 datatype while other operations use the lower-precision torch.float16 datatype. This saves time and reduces memory usage. Ordinarily, “automatic mixed precision training” means training with torch.autocast and torch.cuda.amp.GradScaler together. More information: https://pytorch.org/docs/stable/amp.html
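
For reference, a minimal sketch of this pattern in plain PyTorch (independent of ablator; the model and data are illustrative):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# GradScaler is only useful on CUDA; disable it elsewhere.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 10, device=device)
y = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device):
    # Ops inside the context run in reduced precision where safe
    # (float16 on CUDA, bfloat16 on CPU by default).
    loss = torch.nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid float16 underflow
scaler.step(optimizer)         # unscale gradients, then optimizer.step()
scaler.update()                # adjust the scale factor for the next step
```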

Attributes:
model_class: Type[nn.Module]

The class definition of the model’s structure, which is a subclass of nn.Module.

run_config: RunConfig

An instance of RunConfig containing configuration details.

train_dataloader: DataLoader

A DataLoader object used for model training.

val_dataloader: Optional[DataLoader]

An optional DataLoader object used for model evaluation.

test_dataloader: Optional[DataLoader]

An optional DataLoader object used for model testing.

logger: Union[SummaryLogger, tutils.Dummy]

Records information on the program’s operation and model training, such as progress and performance metrics.

device: str

The device used for running the experiment, e.g. "cuda", "cpu", "cuda:0".

model_dir: Path

The model directory.

experiment_dir: Path

The experiment directory.

autocast: torch.autocast

Enables autocasting for chosen regions. Autocasting automatically chooses the precision for GPU operations to improve performance while maintaining accuracy.

verbose: bool

If True, prints additional information while training. Only applied for the master process.

amp: bool

If True, apply automatic mixed precision training; otherwise use default precision.

random_seed: Optional[int]

Sets the seed for generating random numbers.

train_tqdm: tqdm, optional

An optional instance of tqdm that creates progress bars and displays real-time information during training, e.g. time remaining. Only applied for the master process.

current_checkpoint: Optional[Path]

Path to the current checkpoint file, by default None.

metrics: Metrics

Training metrics, including model information such as learning rate and loss value.

current_state: dict

The current state of the model, including run_config, metrics, and other necessary states.

learning_rate: float

The current learning rate.

total_steps: int

The total number of steps for the training process.

epochs: int

The total number of epochs for the training process.

current_iteration: int

The current iteration of training.

best_iteration: int

The iteration with the best loss value.

best_loss: float

The lowest loss value encountered during training.

__init__(model_class: type[torch.nn.modules.module.Module])[source]#

Initializes the ModelBase class with the required model_class and optional configurations.

Parameters:
model_class: type[nn.Module]

The base class for user’s model, which defines the neural network.

abstract checkpoint(is_best=False)[source]#

Abstract method to save a checkpoint of the model. Must be implemented by subclasses. Example implementation: Please see the checkpoint method in the ModelWrapper class.

Parameters:
is_best: bool, optional

Indicates if the current checkpoint is the best model so far, by default False.

abstract config_parser(run_config: RunConfig)[source]#

Abstract method to parse the provided configuration.

Must be implemented by subclasses. Example implementation: Please see the config_parser method in the ModelWrapper class.

Parameters:
run_config: RunConfig

An instance of RunConfig containing configuration details.

abstract create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) None[source]#

Abstract method to create and initialize the model. Must be implemented by subclasses. Example implementation: Please see the create_model method in the ModelWrapper class.

Parameters:
save_dict: dict[str, ty.Any] | None, optional

A dictionary containing saved model data, such as weights, optimizer state, etc., to be loaded into the model, by default None.

strict_load: bool, optional

If True, the model will be loaded strictly, ensuring that the saved state matches the model’s structure exactly. If False, the model can be loaded with a partially matching state, by default True.

property current_epoch: int#

Calculates and returns the current epoch during training.

Returns:
int

The current epoch number.

property epoch_len#

Returns the length of an epoch, which is the number of batches in the train_dataloader.

Returns:
int

The length of an epoch, represented as the number of batches in the train_dataloader.

Raises:
AssertionError

If the train_dataloader is not defined or its length is 0.

property eval_itr#

Calculate the interval between evaluations.

Returns:
int

The interval between evaluations.

abstract evaluate(run_config: RunConfig)[source]#

Abstract method to evaluate the model. Must be implemented by subclasses. Example implementation: Please see the evaluate method in the ModelWrapper class.

Parameters:
run_config: RunConfig

An instance of RunConfig containing configuration details.

abstract evaluation_functions() dict[str, collections.abc.Callable] | None[source]#

Abstract method to create and return a dictionary of evaluation functions used during training and validation.

Must be implemented by subclasses. Example implementation: Please see the evaluation_functions method in the ModelWrapper class.

Returns:
dict[str, Callable] | None

A dictionary containing evaluation functions as values and their names as keys.

abstract load_checkpoint(save_dict: dict[str, Any], model_only: bool = False) None[source]#

Abstract method to load the model and its state from a given save dictionary.

Must be implemented by subclasses. Example implementation: Please see the load_checkpoint method in the ModelWrapper class.

Parameters:
save_dict: dict[str, ty.Any]

A dictionary containing the saved model state and other necessary information.

model_only: bool, optional, default=False

If True, only the model’s weights will be loaded, ignoring other state information.

property log_itr#

Calculate the interval between logging steps.

Returns:
int

The interval between logging steps.

abstract make_dataloaders(run_config: RunConfig)[source]#

Abstract method to create dataloaders for the training, validation, and testing datasets.

This method should define the process of loading the data and creating dataloaders for the training, validation, and testing datasets based on the provided run_config.

Must be implemented by subclasses. Example implementation: Please see the make_dataloaders method in the ModelWrapper class.

Parameters:
run_config: RunConfig

An instance of RunConfig containing configuration details.

abstract save_dict() dict[str, Any] | None[source]#

Abstract method to create and return a save dictionary containing the model’s state and other necessary information.

Must be implemented by subclasses. Example implementation: Please see the save_dict method in the ModelWrapper class.

Returns:
dict[str, ty.Any] | None

A dictionary containing the saved model state and other necessary information.

abstract train(run_config: RunConfig, smoke_test: bool = False)[source]#

Abstract method to train the model. Must be implemented by subclasses. Example implementation: Please see the train method in the ModelWrapper class.

Parameters:
run_config: RunConfig

An instance of RunConfig containing configuration details.

smoke_test: bool, optional

Whether to run as a smoke test, by default False.

property train_stats: OrderedDict#

Returns an ordered dictionary containing the current training statistics.

Returns:
OrderedDict

An ordered dictionary with the following keys and values:

  • learning_rate: The current learning rate.

  • total_steps: The total steps for the training process.

  • epochs: The number of epochs for training.

  • current_epoch: The current epoch during training.

  • current_iteration: The current iteration during training.

  • best_iteration: The iteration with the best loss value so far.

  • best_loss: The best (lowest) loss value achieved during training.

property uid#

Returns a unique identifier (UID) for the current run configuration.

Returns:
str

A string representing the unique identifier of the current run configuration.

exception ablator.main.model.main.TrainPlateauError[source]#

Bases: Exception

Model Wrapper module#

class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])[source]#

Bases: ModelBase

A wrapper around model_class that removes training boiler-plate code, with overridable functions to support custom use-cases.

Attributes:
model_class: torch.nn.Module

The model class to wrap.

model: torch.nn.Module

The model created from the model class or checkpoint.

optimizer: Optimizer

The optimizer created from the optimizer config or checkpoint.

scaler: GradScaler

The scaler created from the scaler config or checkpoint.

scheduler: Scheduler

The scheduler created from the scheduler config or checkpoint.
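
For orientation, a minimal, hypothetical subclass is sketched below. The toy model, the synthetic data, and the assumption that the wrapped model’s forward returns a (predictions-dict, loss) pair (mirroring model_step’s return type) are illustrative, not part of the documented API:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from ablator.main.model.wrapper import ModelWrapper


class MyModel(torch.nn.Module):
    # Toy regression model (hypothetical).
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x, y):
        preds = self.linear(x)
        loss = torch.nn.functional.mse_loss(preds, y)
        return {"preds": preds}, loss


class MyWrapper(ModelWrapper):
    # make_dataloader_train is the abstract method of ModelWrapper; a real
    # implementation would read dataset settings from run_config.
    def make_dataloader_train(self, run_config):
        x, y = torch.randn(128, 10), torch.randn(128, 1)
        return DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)


wrapper = MyWrapper(model_class=MyModel)
# metrics = wrapper.train(run_config)  # run_config: a configured RunConfig
```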

__init__(model_class: type[torch.nn.modules.module.Module])[source]#

Initializes the model wrapper.

Parameters:
model_class: torch.nn.Module

The model class to wrap.

apply_loss(model: Module, loss: Tensor | None, optimizer: Optimizer, scaler: GradScaler, scheduler: _LRScheduler | ReduceLROnPlateau | Any | None) float | None[source]#

Calculate the loss and apply the gradients, calling optimizer.step() and scheduler.step().

Parameters:
model: nn.Module

The model to apply the loss to.

loss: torch.Tensor | None

The loss to apply.

optimizer: Optimizer

The optimizer to step.

scaler: torch.cuda.amp.GradScaler

The scaler to use for mixed precision training.

scheduler: ty.Optional[Scheduler]

The scheduler to step.

Returns:
float | None

The loss value.
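
For intuition, the conventional mixed-precision update sequence such a method performs is sketched below; this is the generic PyTorch pattern, not necessarily ablator’s exact internals:

```python
import torch


def apply_loss_sketch(model, loss, optimizer, scaler, scheduler=None):
    # Generic AMP update (a sketch; model is unused here but mirrors the
    # documented signature).
    if loss is not None:
        optimizer.zero_grad()
        scaler.scale(loss).backward()  # scale loss to avoid float16 underflow
        scaler.step(optimizer)         # unscale gradients, then optimizer.step()
        scaler.update()                # adjust scale factor for the next step
    if scheduler is not None and not isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    ):
        scheduler.step()  # ReduceLROnPlateau steps on a validation metric instead
    return None if loss is None else loss.item()
```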

aux_metrics(output_dict: dict[str, torch.Tensor] | None) dict[str, Any] | None[source]#

Auxiliary metrics to be computed during training.

Parameters:
output_dict: dict[str, torch.Tensor] | None

The output dictionary from the model.

Returns:
ty.Optional[dict[str, ty.Any]]

The auxiliary metrics.

Notes

Auxiliary metrics are computed during training and are used for moving_aux_metrics in TrainMetrics. Check TrainMetrics for more details.
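
A hypothetical override is sketched below: it reports the global gradient norm as an auxiliary metric. The metric name and the reliance on self.model are illustrative:

```python
import torch

from ablator.main.model.wrapper import ModelWrapper


class MyWrapper(ModelWrapper):
    def aux_metrics(self, output_dict):
        # clip_grad_norm_ with an infinite max_norm only measures the
        # global gradient norm without clipping anything.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), max_norm=float("inf")
        )
        return {"grad_norm": grad_norm.item()}
```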

checkpoint(is_best=False)[source]#

Save a checkpoint of the model. It will use the class name of the model as the filename.

Parameters:
is_best: bool

Whether this is the best model so far.

config_parser(run_config: RunConfig)[source]#

Used to initialize Derived properties.

create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) None[source]#

Creates the model, optimizer, scheduler and scaler from the save dict or from config.

Parameters:
save_dict: dict[str, ty.Any]

The save dict to load from.

strict_load: bool

Whether to load the model strictly or not.

create_optimizer(model: Module, optimizer_config: OptimizerConfig | None = None, optimizer_state: dict[str, Any] | None = None) Optimizer[source]#

Creates the optimizer from the saved state or from config.

Parameters:
model: nn.Module

The model to create the optimizer for.

optimizer_config: OptimizerConfig

The optimizer config to create the optimizer from.

optimizer_state: dict[str, ty.Any]

The optimizer state to load the optimizer from.

Returns:
optimizer: Optimizer

The optimizer.

create_scaler(scaler_state: dict | None = None) GradScaler[source]#

Creates the scaler from the saved state or from config.

Parameters:
scaler_state: dict[str, ty.Any]

The scaler state to load the scaler from.

Returns:
scaler: GradScaler

The scaler.

create_scheduler(model: Module, optimizer: Optimizer, scheduler_config: SchedulerConfig | None = None, scheduler_state: dict | None = None) _LRScheduler | ReduceLROnPlateau | Any | None[source]#

Creates the scheduler from the saved state or from config.

Parameters:
model: nn.Module

The model to create the scheduler for.

optimizer: Optimizer

The optimizer to create the scheduler for.

scheduler_config: SchedulerConfig

The scheduler config to create the scheduler from.

scheduler_state: dict[str, ty.Any]

The scheduler state to load the scheduler from.

Returns:
scheduler: Scheduler

The scheduler.

final eval(smoke_test=False)[source]#

Evaluate the model, then update the scheduler and save a checkpoint if the current iteration is an evaluation step. It also checks whether early stopping applies (see the Model Configuration module for more details).

final evaluate(run_config: RunConfig)[source]#

Evaluate the model after training on the test and validation sets.

Parameters:
run_config: RunConfig

The run config to use for evaluation.

evaluation_functions() dict[str, collections.abc.Callable] | None[source]#

Returns:
dict[str, Callable]

The evaluation functions to use. Also see TrainMetrics for details.
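
A hypothetical override is sketched below; the assumption that the metric function’s argument names (here y_pred, y_true) match the arrays collected during evaluation is illustrative:

```python
import numpy as np

from ablator.main.model.wrapper import ModelWrapper


class MyWrapper(ModelWrapper):
    def evaluation_functions(self):
        # The dictionary keys become the metric names in the results.
        def accuracy(y_pred, y_true):
            return float(np.mean(np.argmax(y_pred, axis=-1) == y_true))

        return {"accuracy": accuracy}
```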

load_checkpoint(save_dict: dict[str, Any], model_only: bool = False) None[source]#

Loads the checkpoint from the save dict.

Parameters:
save_dict: dict[str, ty.Any]

The save dict to load the checkpoint from.

model_only: bool

Whether to load only the model or include scheduler, optimizer and scaler.

Notes

This method is the implementation of the abstract method in the base class.

log()[source]#

Log if the current iteration is a logging step. It also evaluates training metrics for logging.

log_step()[source]#

A single step for logging.

Notes

This method updates the logger with the current metrics and logs a status message.

make_dataloader_test(run_config: RunConfig) DataLoader | None[source]#

Function to make the test dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader | None

The test dataloader.

abstract make_dataloader_train(run_config: RunConfig) DataLoader[source]#

Function to make the training dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader

The training dataloader.

make_dataloader_val(run_config: RunConfig) DataLoader | None[source]#

Function to make the validation dataloader.

Parameters:
run_config: RunConfig

The run configuration.

Returns:
DataLoader | None

The validation dataloader.

make_dataloaders(run_config: RunConfig) None[source]#

Creates the training, validation, and test dataloaders. This is done post-initialization because otherwise the dataloaders would be pickled with the object when running distributed.

final mock_train(run_config: RunConfig | None = None, run_async=True, block: bool = True) Process | TrainMetrics[source]#

Mock-train the model as a smoke test.

Parameters:
run_config: RunConfig

The run config to use for the mock train.

run_async: bool

Whether to run the mock train in a separate process.

block: bool

Whether to block the current process until the mock train is finished.

Returns:
p: mp.Process

The process running the mock train, when run_async is True.

metrics: TrainMetrics

The metrics from the mock train, when run_async is False.
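
Hypothetical usage, continuing the MyWrapper/MyModel sketch above (run_config assumed to be a configured RunConfig):

```python
# Run the smoke test in-process and get TrainMetrics back directly.
wrapper = MyWrapper(model_class=MyModel)
metrics = wrapper.mock_train(run_config, run_async=False)
```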

model_step(model: Module, batch: Iterable) tuple[dict[str, torch.Tensor] | None, torch.Tensor | None][source]#

A single inference step for the model.

Parameters:
model: nn.Module

The model to train.

batch: Iterable

The batch of input data to pass through the model; it can be a list, a dict, or a single tensor.

Returns:
out: tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]

The output of the model, containing the current predictions and the loss of the model.
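
To illustrate why a batch may be a list, dict, or tensor, a plausible dispatch is sketched below; this is illustrative, not necessarily ablator’s exact logic:

```python
def model_step_sketch(model, batch):
    if isinstance(batch, dict):
        return model(**batch)  # dict batch -> keyword arguments
    if isinstance(batch, (list, tuple)):
        return model(*batch)   # sequence batch -> positional arguments
    return model(batch)        # single tensor
```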

reset_optimizer_scheduler()[source]#

Resets the optimizer and scheduler by recreating them.

save_dict() dict[str, Any][source]#

Save the current state of the trainer, including model parameters and the current states of the optimizer, the scaler, and the scheduler.

Returns:
dict[str, ty.Any]

The current state of the trainer.

status_message() str[source]#

Return a string generated from the dictionary of current metrics, including all the static metrics and moving-average metrics.

Returns:
str

The status message.

to_device(data: Iterable, device=None) Iterable[source]#

Moves the data to the specified device.

Parameters:
data: Iterable

The data to move to the device.

device: ty.Optional[ty.Union[torch.device, str]]

The device to move the data to. If None, the device specified in the config is used.

Returns:
data: Iterable

The data on the device.
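
A sketch of the expected behavior on nested containers (illustrative, not ablator’s exact implementation):

```python
import torch


def to_device_sketch(data, device):
    # Recursively move tensors inside nested containers to the device.
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: to_device_sketch(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device_sketch(v, device) for v in data)
    return data
```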

property total_steps#

The total number of steps for training.

final train(run_config: RunConfig, smoke_test: bool = False, debug: bool = False, resume: bool = False) TrainMetrics[source]#

Initialize states and train the model. On keyboard interrupt, a checkpoint is saved.

Parameters:
run_config: RunConfig

The run config to use for training.

smoke_test: bool, default=False

Whether to run a smoke test.

debug: bool, default=False

Whether to run in debug mode.

resume: bool, default=False

Whether to resume training the model from existing checkpoints and existing experiment state.

Returns:
TrainMetrics

The metrics from the training.
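
A hypothetical end-to-end call, reusing the MyWrapper/MyModel sketch above (run_config assumed to be a configured RunConfig):

```python
wrapper = MyWrapper(model_class=MyModel)
metrics = wrapper.train(run_config)

# Resume from existing checkpoints and experiment state:
metrics = wrapper.train(run_config, resume=True)
```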

train_loop(smoke_test=False)[source]#

Train the model over many steps, evaluating the model and logging the metrics at each iteration. Metrics include static metrics such as the learning rate, along with validation and training metrics such as loss and mean values.

Parameters:
smoke_test: bool

Whether to run a smoke test.

final train_step(batch: Iterable) tuple[dict[str, torch.Tensor] | None, dict[str, Any]][source]#

A single step for training. It also updates learning rate with scheduler.

Parameters:
batch: Iterable

The batch of input data to pass through the model; it can be a list, a dict, or a single tensor.

Returns:
outputs: dict[str, torch.Tensor] | None

The output of the model.

train_metrics: dict[str, ty.Any]

The training metrics.

update_status()[source]#

Update the metrics with the current training stats; all metrics (static and moving-average) are then set as the description of the tqdm progress bar.

validation_loop(model: Module, dataloader: DataLoader, metrics: TrainMetrics, tag: Literal['train', 'test', 'val'], subsample: float = 1.0, smoke_test: bool = False) dict[str, float][source]#

Validate the model on the data in dataloader, which can be either the validation dataloader (tag "val") or the test dataloader (tag "test").

Parameters:
model: nn.Module

The model to validate.

dataloader: DataLoader

The dataloader to use for validation.

metrics: TrainMetrics

The metrics to use for validation.

tag: ty.Literal[“train”, “test”, “val”]

The tag to use for validation. Also see TrainMetrics for details.

subsample: float

The fraction of the dataloader to use for validation.

smoke_test: bool

Whether to execute this function as a smoke test. If True, only one iteration will be performed, which is useful for quickly checking if the code runs without errors. Default is False.

Returns:
dict[str, float]

The metrics from the validation.

Module contents#