Main model package
Submodules
Main Model module
- class ablator.main.model.main.ModelBase(model_class: type[torch.nn.modules.module.Module])[source]#
Bases:
ABC
Base class that removes training boilerplate code and can be extended to support multiple use cases. The class follows a stateful initialization paradigm. Users must implement the model-creation and checkpoint-loading functionality specific to their use case.
Notes
1. Class properties are simply listed by name. Please check the property docstrings for more information.
2. Users must implement the abstract methods to customize the model’s behavior.
3. Mixed precision training lets some operations use the torch.float32 datatype while other operations use a lower-precision floating point datatype, torch.float16. This saves time and reduces memory usage. Ordinarily, “automatic mixed precision training” means training with torch.autocast and torch.cuda.amp.GradScaler together. More information: https://pytorch.org/docs/stable/amp.html
- Attributes:
- model_class: Type[nn.Module]
The class definition of the model’s structure, which is a subclass of nn.Module.
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- train_dataloader: DataLoader
A DataLoader object responsible for model training.
- val_dataloader: Optional[DataLoader]
An optional DataLoader object used for model evaluation.
- test_dataloader: Optional[DataLoader]
An optional DataLoader object used for model testing.
- logger: Union[SummaryLogger, tutils.Dummy]
Records information on the program’s operation and model training, such as progress and performance metrics.
- device: str
The type of device used for running the experiment, e.g. "cuda", "cpu", "cuda:0".
- model_dir: Path
The model directory.
- experiment_dir: Path
The experiment directory.
- autocast: torch.autocast
Enables autocasting for chosen regions. Autocasting automatically chooses the precision for GPU operations to improve performance while maintaining accuracy.
- verbose: bool
If True, prints additional information while training. Only applied for the master process.
- amp: bool
If True, applies automatic mixed precision training; otherwise uses default precision.
- random_seed: Optional[int]
Sets the seed for generating random numbers.
- train_tqdm: tqdm, optional
An optional instance of tqdm that creates progress bars and displays real-time information during training, e.g. time remaining. Only applied for the master process.
- current_checkpoint: Optional[Path]
Directory for the current checkpoint file, by default None.
- metrics: Metrics
Training metrics including model information, e.g. learning rate and loss value.
- current_state: dict
The current state of the model, including run_config, metrics and other necessary states.
- learning_rate: float
The current learning rate.
- total_steps: int
The total steps for the training process.
- epochs: int
The total epochs for the training process.
- current_iteration: int
The current iteration of training.
- best_iteration: int
The iteration with the best loss value.
- best_loss: float
The lowest loss value encountered during training.
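The mixed precision note above pairs autocasting with a gradient scaler. The sketch below uses only the standard library to illustrate why scaling helps: a small value that underflows to zero in half precision survives if it is multiplied by a scale factor before the cast and divided afterwards. The helper to_float16 and the scale value are assumptions made for illustration and are not part of ablator or PyTorch.

```python
import struct


def to_float16(value: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


# A small gradient-like value that float32 can hold but float16 cannot.
tiny_gradient = 1e-8

# Stored directly in half precision, the value underflows to zero.
unscaled = to_float16(tiny_gradient)

# Scaling before the cast keeps the value representable; dividing afterwards
# (in higher precision) recovers an approximation of the original value.
scale = 65536.0  # hypothetical scale factor, chosen for illustration
recovered = to_float16(tiny_gradient * scale) / scale

print(unscaled, recovered)
```

A real scaler also adjusts the scale factor dynamically and skips steps with non-finite gradients; none of that is modelled here.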
- __init__(model_class: type[torch.nn.modules.module.Module])[source]#
Initializes the ModelBase class with the required model_class and optional configurations.
- Parameters:
- model_class: type[nn.Module]
The base class for the user’s model, which defines the neural network.
- abstract checkpoint(is_best=False)[source]#
Abstract method to save a checkpoint of the model. Must be implemented by subclasses. Example implementation: Please see the checkpoint method in the ModelWrapper class.
- Parameters:
- is_best: bool, optional
Indicates if the current checkpoint is the best model so far, by default False.
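Because ModelBase derives from ABC, its abstract methods follow Python's standard abc pattern. The following is a minimal sketch of that pattern with hypothetical stand-in classes (TinyModelBase and TinyWrapper are invented for illustration and do not exist in ablator); it only shows that the base class cannot be instantiated until every abstract method is implemented.

```python
from abc import ABC, abstractmethod


class TinyModelBase(ABC):
    """Hypothetical stand-in for a base class with one abstract method."""

    @abstractmethod
    def checkpoint(self, is_best: bool = False) -> str:
        """Subclasses decide how and where a checkpoint is stored."""


class TinyWrapper(TinyModelBase):
    """Hypothetical concrete subclass that implements the abstract method."""

    def checkpoint(self, is_best: bool = False) -> str:
        # Only report which file would be written; no I/O happens here.
        return "best.pt" if is_best else "latest.pt"


# Instantiating the base class directly fails because `checkpoint`
# has no implementation.
try:
    TinyModelBase()
    base_is_instantiable = True
except TypeError:
    base_is_instantiable = False

wrapper = TinyWrapper()
print(base_is_instantiable, wrapper.checkpoint(is_best=True))
```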
- abstract config_parser(run_config: RunConfig)[source]#
Abstract method to parse the provided configuration.
Must be implemented by subclasses. Example implementation: Please see the make_dataloaders method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- abstract create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) None [source]#
Abstract method to create and initialize the model. Must be implemented by subclasses. Example implementation: Please see the create_model method in the ModelWrapper class.
- Parameters:
- save_dict: dict[str, ty.Any] | None, optional
A dictionary containing saved model data, such as weights, optimizer state, etc., to be loaded into the model, by default None.
- strict_load: bool, optional
If True, the model will be loaded strictly, ensuring that the saved state matches the model’s structure exactly. If False, the model can be loaded with a partially matching state, by default True.
- property current_epoch: int#
Calculates and returns the current epoch during training.
- Returns:
- int
The current epoch number.
- property epoch_len#
Returns the length of an epoch, which is the number of batches in the train_dataloader.
- Returns:
- int
The length of an epoch, represented as the number of batches in the train_dataloader.
- Raises:
- AssertionError
If the train_dataloader is not defined or its length is 0.
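The relationship between current_epoch, current_iteration and epoch_len can be sketched with integer division. This is an assumption about the arithmetic, not a copy of the library's implementation; in particular, counting epochs from zero is a choice made for this example.

```python
def current_epoch(current_iteration: int, epoch_len: int) -> int:
    """Number of full passes over the training data completed so far."""
    # Mirrors the documented AssertionError for an empty dataloader.
    assert epoch_len > 0, "train_dataloader must contain at least one batch"
    return current_iteration // epoch_len


# Hypothetical run: 250 batches per epoch.
epoch_len = 250
for iteration in (0, 249, 250, 1000):
    print(iteration, current_epoch(iteration, epoch_len))
```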
- property eval_itr#
Calculate the interval between evaluations.
- Returns:
- int
The interval between evaluations.
- abstract evaluate(run_config: RunConfig)[source]#
Abstract method to evaluate the model. Must be implemented by subclasses. Example implementation: Please see the evaluate method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- abstract evaluation_functions() dict[str, collections.abc.Callable] | None [source]#
Abstract method to create and return a dictionary of evaluation functions used during training and validation.
Must be implemented by subclasses. Example implementation: Please see the evaluation_functions method in the ModelWrapper class.
- Returns:
- dict[str, Callable] | None
A dictionary containing evaluation functions as values and their names as keys.
- abstract load_checkpoint(save_dict: dict[str, Any], model_only: bool = False) None [source]#
Abstract method to load the model and its state from a given save dictionary.
Must be implemented by subclasses. Example implementation: Please see the load_checkpoint method in the ModelWrapper class.
- Parameters:
- save_dict: dict[str, ty.Any]
A dictionary containing the saved model state and other necessary information.
- model_only: bool, optional, default=False
If True, only the model’s weights will be loaded, ignoring other state information.
- property log_itr#
Calculate the interval between logging steps.
- Returns:
- int
The interval between logging steps.
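Both eval_itr and log_itr describe an interval measured in training steps. A minimal sketch of how such an interval might be derived and checked is shown below; the idea of expressing the interval as a fraction of an epoch, and the names interval_in_steps and is_interval_step, are assumptions for illustration rather than the library's API.

```python
import math


def interval_in_steps(epoch_fraction: float, epoch_len: int) -> int:
    """Convert an interval given in epochs into a whole number of steps."""
    # Round up, and never return less than one step.
    return max(1, math.ceil(epoch_fraction * epoch_len))


def is_interval_step(current_iteration: int, interval: int) -> bool:
    """True on iterations that are positive multiples of the interval."""
    return current_iteration > 0 and current_iteration % interval == 0


# Hypothetical setup: evaluate twice per epoch of 250 batches.
eval_interval = interval_in_steps(0.5, 250)
eval_steps = [i for i in range(0, 501) if is_interval_step(i, eval_interval)]
print(eval_interval, eval_steps)
```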
- abstract make_dataloaders(run_config: RunConfig)[source]#
Abstract method to create dataloaders for the training, validation, and testing datasets.
This method should define the process of loading the data and creating dataloaders for the training, validation, and testing datasets based on the provided run_config.
Must be implemented by subclasses. Example implementation: Please see the make_dataloaders method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- abstract save_dict() dict[str, Any] | None [source]#
Abstract method to create and return a save dictionary containing the model’s state and other necessary information.
Must be implemented by subclasses. Example implementation: Please see the save_dict method in the ModelWrapper class.
- Returns:
- dict[str, ty.Any] | None
A dictionary containing the saved model state and other necessary information.
- abstract train(run_config: RunConfig, smoke_test: bool = False)[source]#
Abstract method to train the model. Must be implemented by subclasses. Example implementation: Please see the train method in the ModelWrapper class.
- Parameters:
- run_config: RunConfig
An instance of RunConfig containing configuration details.
- smoke_test: bool, optional
Whether to run as a smoke test, by default False.
- property train_stats: OrderedDict#
Returns an ordered dictionary containing the current training statistics.
- Returns:
- OrderedDict
An ordered dictionary with the following keys and values:
learning_rate: The current learning rate.
total_steps: The total steps for the training process.
epochs: The number of epochs for training.
current_epoch: The current epoch during training.
current_iteration: The current iteration during training.
best_iteration: The iteration with the best loss value so far.
best_loss: The best (lowest) loss value achieved during training.
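As an illustration of the shape of train_stats, the sketch below builds an OrderedDict with the keys listed above. All values are made up for the example, and the helper make_train_stats is hypothetical.

```python
from collections import OrderedDict


def make_train_stats(
    learning_rate: float,
    total_steps: int,
    epochs: int,
    current_epoch: int,
    current_iteration: int,
    best_iteration: int,
    best_loss: float,
) -> OrderedDict:
    """Collect the training statistics in the documented key order."""
    return OrderedDict(
        learning_rate=learning_rate,
        total_steps=total_steps,
        epochs=epochs,
        current_epoch=current_epoch,
        current_iteration=current_iteration,
        best_iteration=best_iteration,
        best_loss=best_loss,
    )


# Hypothetical values part-way through a run.
stats = make_train_stats(1e-3, 5000, 20, 4, 1000, 750, 0.42)
for name, value in stats.items():
    print(f"{name}: {value}")
```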
- property uid#
Returns a unique identifier (UID) for the current run configuration.
- Returns:
- str
A string representing the unique identifier of the current run configuration.
Model Wrapper module
- class ablator.main.model.wrapper.ModelWrapper(model_class: type[torch.nn.modules.module.Module])[source]#
Bases:
ModelBase
A wrapper around model_class that removes training boilerplate code and provides overridable methods to support custom use cases.
- Attributes:
- model_class: torch.nn.Module
The model class to wrap.
- model: torch.nn.Module
The model created from the model class or checkpoint
- optimizer: Optimizer
The optimizer created from the optimizer config or checkpoint
- scaler: GradScaler
The scaler created from the scaler config or checkpoint
- scheduler: Scheduler
The scheduler created from the scheduler config or checkpoint
- __init__(model_class: type[torch.nn.modules.module.Module])[source]#
Initializes the model wrapper.
- Parameters:
- model_class: torch.nn.Module
The model class to wrap.
- apply_loss(model: Module, loss: Tensor | None, optimizer: Optimizer, scaler: GradScaler, scheduler: _LRScheduler | ReduceLROnPlateau | Any | None) float | None [source]#
Calculate the loss and apply the gradients, then call optimizer.step() and scheduler.step().
- Parameters:
- model: nn.Module
The model to apply the loss to.
- loss: torch.Tensor | None
The loss to apply.
- optimizer: Optimizer
The optimizer to step.
- scaler: torch.cuda.amp.GradScaler
The scaler to use for mixed precision training.
- scheduler: ty.Optional[Scheduler]
The scheduler to step.
- Returns:
- float | None
The loss value.
- aux_metrics(output_dict: dict[str, torch.Tensor] | None) dict[str, Any] | None [source]#
Auxiliary metrics to be computed during training.
- Parameters:
- output_dict: dict[str, torch.Tensor] | None
The output dictionary from the model.
- Returns:
- ty.Optional[dict[str, ty.Any]]
The auxiliary metrics.
Notes
Auxiliary metrics are computed during training and are used for moving_aux_metrics in TrainMetrics. Check TrainMetrics for more details.
- checkpoint(is_best=False)[source]#
Save a checkpoint of the model. The class name of the model is used as the filename.
- Parameters:
- is_best: bool
Whether this is the best model so far.
- create_model(save_dict: dict[str, Any] | None = None, strict_load: bool = True) None [source]#
Creates the model, optimizer, scheduler and scaler from the save dict or from config.
- Parameters:
- save_dict: dict[str, ty.Any]
The save dict to load from.
- strict_load: bool
Whether to load the model strictly or not.
- create_optimizer(model: Module, optimizer_config: OptimizerConfig | None = None, optimizer_state: dict[str, Any] | None = None) Optimizer [source]#
Creates the optimizer from the saved state or from config.
- Parameters:
- model: nn.Module
The model to create the optimizer for.
- optimizer_config: OptimizerConfig
The optimizer config to create the optimizer from.
- optimizer_state: dict[str, ty.Any]
The optimizer state to load the optimizer from.
- Returns:
- optimizer: Optimizer
The optimizer.
- create_scaler(scaler_state: dict | None = None) GradScaler [source]#
Creates the scaler from the saved state or from config.
- Parameters:
- scaler_state: dict[str, ty.Any]
The scaler state to load the scaler from.
- Returns:
- scaler: GradScaler
The scaler.
- create_scheduler(model: Module, optimizer: Optimizer, scheduler_config: SchedulerConfig | None = None, scheduler_state: dict | None = None) _LRScheduler | ReduceLROnPlateau | Any | None [source]#
Creates the scheduler from the saved state or from config.
- Parameters:
- model: nn.Module
The model to create the scheduler for.
- optimizer: Optimizer
The optimizer to create the scheduler for.
- scheduler_config: SchedulerConfig
The scheduler config to create the scheduler from.
- scheduler_state: dict[str, ty.Any]
The scheduler state to load the scheduler from.
- Returns:
- scheduler: Scheduler
The scheduler.
- final eval(smoke_test=False)[source]#
Evaluates the model, then updates the scheduler and saves a checkpoint if the current iteration is an evaluation step. It also checks whether the early stopping condition is met (see the Model Configuration module for more details).
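The interplay between best_loss, best_iteration and early stopping can be sketched as follows. This is a simplified illustration under stated assumptions: the BestTracker class and the patience parameter (measured in iterations) are hypothetical, and the actual stopping rule is defined by the model configuration.

```python
from dataclasses import dataclass


@dataclass
class BestTracker:
    """Hypothetical helper that remembers the lowest loss seen so far."""

    best_loss: float = float("inf")
    best_iteration: int = 0

    def update(self, loss: float, iteration: int) -> bool:
        """Record the loss and return True if it is a new best."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_iteration = iteration
            return True
        return False

    def should_stop(self, iteration: int, patience: int) -> bool:
        """Stop once `patience` iterations pass without improvement."""
        return iteration - self.best_iteration >= patience


tracker = BestTracker()
stopped_at = None
# Hypothetical validation losses measured every 100 iterations.
for iteration, loss in [(100, 0.9), (200, 0.7), (300, 0.8), (400, 0.75)]:
    is_best = tracker.update(loss, iteration)  # would be passed to checkpoint
    if tracker.should_stop(iteration, patience=200):
        stopped_at = iteration
        break

print(tracker.best_loss, tracker.best_iteration, stopped_at)
```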
- final evaluate(run_config: RunConfig)[source]#
Evaluate the model after training on the test and validation sets.
- Parameters:
- run_config: RunConfig
The run config to use for evaluation.
- evaluation_functions() dict[str, collections.abc.Callable] | None [source]#
- Returns:
- dict[str, Callable]
The evaluation functions to use. Also see TrainMetrics for details.
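A dictionary of evaluation functions maps a metric name to a callable. The sketch below shows one plausible shape using plain Python lists; in practice the inputs would typically be tensors or arrays, and the argument names y_true and y_pred are assumptions made for this example.

```python
from collections.abc import Callable


def accuracy(y_true: list[int], y_pred: list[int]) -> float:
    """Fraction of predictions that match the labels."""
    matches = sum(t == p for t, p in zip(y_true, y_pred))
    return matches / len(y_true)


def error_rate(y_true: list[int], y_pred: list[int]) -> float:
    """Complement of the accuracy."""
    return 1.0 - accuracy(y_true, y_pred)


def evaluation_functions() -> dict[str, Callable]:
    """Metric names as keys, metric functions as values."""
    return {"accuracy": accuracy, "error_rate": error_rate}


# Apply every registered function to the same hypothetical predictions.
results = {
    name: fn(y_true=[1, 0, 1, 1], y_pred=[1, 0, 0, 1])
    for name, fn in evaluation_functions().items()
}
print(results)
```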
- load_checkpoint(save_dict: dict[str, Any], model_only: bool = False) None [source]#
Loads the checkpoint from the save dict.
- Parameters:
- save_dict: dict[str, ty.Any]
The save dict to load the checkpoint from.
- model_only: bool
Whether to load only the model or include scheduler, optimizer and scaler.
Notes
This method is the implementation of the abstract method in the base class.
- log()[source]#
Logs if the current iteration is a logging step. It also evaluates training metrics for logging.
- log_step()[source]#
A single step for logging.
Notes
This method updates the logger with the current metrics and logs a status message.
- make_dataloader_test(run_config: RunConfig) DataLoader | None [source]#
Function to make the test dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The test dataloader.
- abstract make_dataloader_train(run_config: RunConfig) DataLoader [source]#
Function to make the training dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader
The training dataloader.
- make_dataloader_val(run_config: RunConfig) DataLoader | None [source]#
Function to make the validation dataloader.
- Parameters:
- run_config: RunConfig
The run configuration.
- Returns:
- DataLoader | None
The validation dataloader.
- make_dataloaders(run_config: RunConfig) None [source]#
Creates the dataloaders. This is done after initialization because otherwise the dataloaders would be pickled with the object when running distributed.
- final mock_train(run_config: RunConfig | None = None, run_async=True, block: bool = True) Process | TrainMetrics [source]#
Mock-trains the model as a smoke test.
- Parameters:
- run_config: RunConfig
The run config to use for the mock train.
- run_async: bool
Whether to run the mock train in a separate process.
- block: bool
Whether to block the current process until the mock train is finished.
- Returns:
- p: mp.Process
The process running the mock train.
- metrics: TrainMetrics
The metrics from the mock train.
- model_step(model: Module, batch: Iterable) tuple[dict[str, torch.Tensor] | None, torch.Tensor | None] [source]#
A single inference step for the model.
- Parameters:
- model: nn.Module
The model to train.
- batch: Iterable
The batch of input data to pass through the model. It can be a list, a dict, or a single tensor.
- Returns:
- out: tuple[dict[str, torch.Tensor] | None, torch.Tensor | None]
The output of the model, containing the current predictions and the loss.
- save_dict() dict[str, Any] [source]#
Save the current state of the trainer, including model parameters, and current states of the optimizer, the scaler, and the scheduler.
- Returns:
- dict[str, ty.Any]
The current state of the trainer.
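The relationship between save_dict and load_checkpoint, including the effect of model_only, can be illustrated with plain dictionaries. This is a sketch under stated assumptions: ToyTrainer and the key names "model", "optimizer" and "scheduler" are hypothetical and do not reflect the actual layout of the saved state.

```python
from typing import Any


class ToyTrainer:
    """Hypothetical trainer whose state is plain Python data."""

    def __init__(self) -> None:
        self.model = {"weight": 0.0}
        self.optimizer = {"lr": 0.1}
        self.scheduler = {"last_step": 0}

    def save_dict(self) -> dict[str, Any]:
        # Copy each component so later training does not mutate the snapshot.
        return {
            "model": dict(self.model),
            "optimizer": dict(self.optimizer),
            "scheduler": dict(self.scheduler),
        }

    def load_checkpoint(
        self, save_dict: dict[str, Any], model_only: bool = False
    ) -> None:
        self.model = dict(save_dict["model"])
        if model_only:
            return  # keep the current optimizer and scheduler state
        self.optimizer = dict(save_dict["optimizer"])
        self.scheduler = dict(save_dict["scheduler"])


# Simulate a trained state and snapshot it.
trained = ToyTrainer()
trained.model["weight"] = 1.5
trained.optimizer["lr"] = 0.01
trained.scheduler["last_step"] = 40
snapshot = trained.save_dict()

weights_only = ToyTrainer()
weights_only.load_checkpoint(snapshot, model_only=True)

full_restore = ToyTrainer()
full_restore.load_checkpoint(snapshot)
```

Loading with model_only=True is the variant that suits fine-tuning or inference, since a fresh optimizer and scheduler are kept.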
- status_message() str [source]#
Returns a string generated from the dictionary of current metrics, including all static metrics and moving average metrics.
- Returns:
- str
The status message.
- to_device(data: Iterable, device=None) Iterable [source]#
Moves the data to the specified device.
- Parameters:
- data: Iterable
The data to move to the device.
- device: ty.Optional[ty.Union[torch.device, str]]
The device to move the data to. If None, the device specified in the config is used.
- Returns:
- data: Iterable
The data on the device.
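Since a batch may be a list, a dict or a single tensor, moving it to a device usually means walking the structure recursively. The sketch below shows that idea without requiring PyTorch: FakeTensor is a hypothetical stand-in that only records a device name, and this standalone to_device function is an illustration, not the method's actual implementation.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for a tensor; it only tracks its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like a tensor, return a new object rather than mutating in place.
        return FakeTensor(device)


def to_device(data: Any, device: str) -> Any:
    """Recursively move anything with a `.to` method, keeping the structure."""
    if hasattr(data, "to"):
        return data.to(device)
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device(item, device) for item in data)
    return data  # leave plain values (ints, strings, None) untouched


batch = {"x": FakeTensor(), "meta": [FakeTensor(), 3]}
moved = to_device(batch, "cuda:0")  # no GPU is needed for the stand-in
print(moved["x"].device, moved["meta"][0].device, moved["meta"][1])
```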
- property total_steps#
The total number of steps for training.
- final train(run_config: RunConfig, smoke_test: bool = False, debug: bool = False, resume: bool = False) TrainMetrics [source]#
Initializes states and trains the model. On a keyboard interrupt, saves a checkpoint.
- Parameters:
- run_config: RunConfig
The run config to use for training.
- smoke_test: bool, default=False
Whether to run a smoke test.
- debug: bool, default=False
Whether to run in debug mode.
- resume: bool, default=False
Whether to resume training the model from existing checkpoints and existing experiment state.
- Returns:
- TrainMetrics
The metrics from the training.
- train_loop(smoke_test=False)[source]#
Trains the model over many steps, evaluating the model and logging the metrics at each iteration. The metrics include static metrics such as the learning rate, along with validation and training metrics such as loss and mean.
- Parameters:
- smoke_test: bool
Whether to run a smoke test.
- final train_step(batch: Iterable) tuple[dict[str, torch.Tensor] | None, dict[str, Any]] [source]#
A single step for training. It also updates learning rate with scheduler.
- Parameters:
- batch: Iterable
The batch of input data to pass through the model. It can be a list, a dict, or a single tensor.
- Returns:
- outputs: dict[str, torch.Tensor] | None
The output of the model.
- train_metrics: dict[str, ty.Any]
The training metrics.
- update_status()[source]#
Updates the metrics with the current training stats, then sets all metrics (static and moving average) as the description of the tqdm progress bar.
- validation_loop(model: Module, dataloader: DataLoader, metrics: TrainMetrics, tag: Literal['train', 'test', 'val'], subsample: float = 1.0, smoke_test: bool = False) dict[str, float] [source]#
Validates the model on the data in dataloader, which can be either the validation dataloader (tag is val) or the test dataloader (tag is test).
- Parameters:
- model: nn.Module
The model to validate.
- dataloader: DataLoader
The dataloader to use for validation.
- metrics: TrainMetrics
The metrics to use for validation.
- tag: ty.Literal["train", "test", "val"]
The tag to use for validation. Also see TrainMetrics for details.
- subsample: float
The fraction of the dataloader to use for validation.
- smoke_test: bool
Whether to execute this function as a smoke test. If True, only one iteration will be performed, which is useful for quickly checking if the code runs without errors. Default is False.
- Returns:
- dict[str, float]
The metrics from the validation.
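The effect of subsample and smoke_test can be sketched with a plain list standing in for the dataloader. This is an illustration under stated assumptions: it takes the first fraction of the batches rather than a random sample, assumes at least one batch, and uses a hypothetical loss_fn callable; the library's loop may differ on all of these points.

```python
import math
from collections.abc import Callable, Sequence


def validation_loop(
    batches: Sequence[float],
    loss_fn: Callable[[float], float],
    subsample: float = 1.0,
    smoke_test: bool = False,
) -> dict[str, float]:
    """Average a loss over the first `subsample` fraction of the batches."""
    # Always process at least one batch, even for very small fractions.
    limit = max(1, math.ceil(len(batches) * subsample))
    losses = []
    for index, batch in enumerate(batches):
        if index >= limit:
            break
        losses.append(loss_fn(batch))
        if smoke_test:
            break  # one iteration is enough to check that the code runs
    return {"loss": sum(losses) / len(losses)}


def squared(value: float) -> float:
    """Hypothetical per-batch loss."""
    return value * value


batches = [1.0, 2.0, 3.0, 4.0]
full = validation_loop(batches, squared)
half = validation_loop(batches, squared, subsample=0.5)
smoke = validation_loop(batches, squared, smoke_test=True)
print(full, half, smoke)
```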