Source code for ablator.modules.loggers.main

import copy
import json
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
from PIL import Image

import ablator.utils.file as futils
from ablator.main.configs import RunConfig
from ablator.modules.loggers import LoggerBase
from ablator.modules.loggers.file import FileLogger
from ablator.modules.loggers.tensor import TensorboardLogger
from ablator.modules.metrics.main import TrainMetrics
from ablator.modules.metrics.stores import MovingAverage


class DuplicateRunError(Exception):
    """
    Error raised by ``SummaryLogger`` when the model directory already exists
    and ``resume`` was not requested.
    """
class SummaryLogger:
    """
    A logger for training and evaluation summary.

    Attributes
    ----------
    SUMMARY_DIR_NAME : str
        Name of the summary directory.
    RESULTS_JSON_NAME : str
        Name of the results JSON file.
    LOG_FILE_NAME : str
        Name of the log file.
    CONFIG_FILE_NAME : str
        Name of the configuration file.
    METADATA_JSON : str
        Name of the metadata JSON file.
    CHKPT_DIR_NAMES : list[str]
        List of checkpoint directory names.
    CHKPT_DIR_VALUES : list[str]
        List of checkpoint directory values.
    CHKPT_DIRS : dict[str, Path]
        Dictionary containing checkpoint directories.
    keep_n_checkpoints : int
        Number of checkpoints to keep.
    log_iteration : int
        Current log iteration.
    checkpoint_iteration : dict[str, dict[str, int]]
        ``checkpoint_iteration`` is a dictionary that keeps track of the
        checkpoint iterations for each directory. It is used in the
        ``checkpoint()`` method to determine the appropriate iteration
        number for the saved checkpoint.
    log_file_path : Path | None
        Path to the log file.
    dashboard : LoggerBase | None
        Dashboard logger.
    model_dir : Path | None
        The model directory.
    result_json_path : Path | None
        Path to the results JSON file.
    """

    SUMMARY_DIR_NAME = "dashboard"
    RESULTS_JSON_NAME = "results.json"
    LOG_FILE_NAME = "train.log"
    CONFIG_FILE_NAME = "config.yaml"
    METADATA_JSON = "metadata.json"
    CHKPT_DIR_NAMES = ["best", "recent"]
    CHKPT_DIR_VALUES = ["best_checkpoints", "checkpoints"]
    CHKPT_DIRS: dict[str, Path]

    def __init__(
        self,
        run_config: RunConfig,
        model_dir: str | None | Path = None,
        resume: bool = False,
        keep_n_checkpoints: int | None = None,
        verbose: bool = True,
    ):
        """
        Initialize a SummaryLogger.

        Parameters
        ----------
        run_config : RunConfig
            The run configuration.
        model_dir : str | None | Path, optional
            Path to the model directory, by default ``None``. When ``None``
            no directories or files are created and only the message logger
            is set up.
        resume : bool, optional
            Whether to resume from an existing model directory, by default
            ``False``.
        keep_n_checkpoints : int | None, optional
            Number of checkpoints to keep, by default ``None`` (stored as
            ``int(1e6)``).
        verbose : bool, optional
            Whether to print messages to the console, by default ``True``.

        Raises
        ------
        DuplicateRunError
            If ``resume`` is ``False`` and ``model_dir`` already exists.
        AssertionError
            If resuming and the stored train/model configuration differs
            from the supplied ``run_config``.
        """
        # Work on a private copy of the configuration.
        run_config = copy.deepcopy(run_config)
        self.uid = run_config.uid
        self.keep_n_checkpoints: int = (
            keep_n_checkpoints if keep_n_checkpoints is not None else int(1e6)
        )
        self.log_iteration: int = 0
        self.checkpoint_iteration: dict[str, dict[str, int]] = {}
        self.log_file_path: Path | None = None
        self.dashboard: LoggerBase | None = None
        self.model_dir: Path | None = None
        self.result_json_path: Path | None = None
        self.CHKPT_DIRS = {}

        if model_dir is not None:
            self.model_dir = Path(model_dir)
            already_exists = self.model_dir.exists()
            if not resume and already_exists:
                raise DuplicateRunError(
                    f"SummaryLogger: Resume is set to {resume} but {self.model_dir} exists."
                )
            if resume and already_exists:
                # Restore the counters of the previous run, after checking
                # that it was produced with the same configuration.
                _run_config = type(run_config).load(
                    self.model_dir.joinpath(self.CONFIG_FILE_NAME)
                )
                assert (
                    run_config.train_config == _run_config.train_config
                    and run_config.model_config == _run_config.model_config
                ), (
                    "Different supplied run_config than"
                    f" existing run_config {self.model_dir}. \n{run_config}----\n{_run_config}"
                )
                metadata = json.loads(
                    self.model_dir.joinpath(self.METADATA_JSON).read_text(
                        encoding="utf-8"
                    )
                )
                self.checkpoint_iteration = metadata["checkpoint_iteration"]
                self.log_iteration = metadata["log_iteration"]

            (self.summary_dir, *chkpt_dirs) = futils.make_sub_dirs(
                model_dir, self.SUMMARY_DIR_NAME, *self.CHKPT_DIR_VALUES
            )
            for name, path in zip(self.CHKPT_DIR_NAMES, chkpt_dirs):
                self.CHKPT_DIRS[name] = path
            self.result_json_path = self.model_dir / self.RESULTS_JSON_NAME
            self.log_file_path = self.model_dir.joinpath(self.LOG_FILE_NAME)
            self.dashboard = self._make_dashboard(self.summary_dir, run_config)
            self._write_config(run_config)
            self._update_metadata()

        # The message logger is always created; its path is ``None`` when no
        # model directory was given.
        self.logger = FileLogger(path=self.log_file_path, verbose=verbose)

    def _update_metadata(self):
        """
        Update the metadata file.

        Writes ``log_iteration`` and ``checkpoint_iteration`` as JSON to the
        metadata file inside the model directory. Does nothing when there is
        no model directory.
        """
        if self.model_dir is None:
            return
        metadata = {
            "log_iteration": self.log_iteration,
            "checkpoint_iteration": self.checkpoint_iteration,
        }
        self.model_dir.joinpath(self.METADATA_JSON).write_text(
            json.dumps(metadata), encoding="utf-8"
        )

    def _make_dashboard(
        self, summary_dir: Path, run_config: RunConfig | None = None
    ) -> LoggerBase | None:
        """
        Make a dashboard logger.

        Parameters
        ----------
        summary_dir : Path
            Path to the summary directory.
        run_config : RunConfig | None, optional
            The run configuration, by default None.

        Returns
        -------
        LoggerBase | None
            A TensorboardLogger writing under ``summary_dir / "tensorboard"``,
            or ``None`` when no configuration is given or its ``tensorboard``
            option is falsy.
        """
        if run_config is None or not run_config.tensorboard:
            return None
        return TensorboardLogger(summary_dir.joinpath("tensorboard"))

    def _write_config(self, run_config: RunConfig):
        """
        Write the run configuration to the model directory and to the
        dashboard.

        Does nothing when there is no model directory.

        Parameters
        ----------
        run_config : RunConfig
            The run configuration.
        """
        if self.model_dir is None:
            return
        self.model_dir.joinpath(self.CONFIG_FILE_NAME).write_text(
            str(run_config), encoding="utf-8"
        )
        if self.dashboard is not None:
            self.dashboard.write_config(run_config)

    def _add_metric(self, k, v, itr):
        """
        Add a metric to the dashboard.

        The dashboard call is chosen from the type of ``v``. Does nothing
        when there is no dashboard.

        Parameters
        ----------
        k : str
            The metric name.
        v : Any
            The metric value. Supported: ``int``, ``float``, NumPy
            integer/floating/boolean scalars, ``str``, ``dict``, ``list``,
            ``np.ndarray``, ``MovingAverage``, ``PIL.Image.Image`` and
            ``pd.DataFrame``.
        itr : int
            The iteration.

        Raises
        ------
        ValueError
            If ``v`` is of an unsupported type.
        """
        if self.dashboard is None:
            return
        if isinstance(v, (list, np.ndarray)):
            v = np.array(v)
            if v.dtype.kind in {"b", "i", "u", "f", "c"}:
                # Numeric arrays are logged as one scalar per position.
                v_dict = {str(i): _v for i, _v in enumerate(v)}
                self.dashboard.add_scalars(k, v_dict, itr)
            else:
                self.dashboard.add_text(k, " ".join(v), itr)
        elif isinstance(v, dict):
            for sub_k, sub_v in v.items():
                self.dashboard.add_scalar(f"{k}_{sub_k}", sub_v, itr)
        elif isinstance(v, (np.integer, np.floating, np.bool_)):
            # NumPy scalars (e.g. np.float32, np.int64) are not instances of
            # int/float and used to be rejected below; log them as the
            # equivalent Python scalar.
            self.dashboard.add_scalar(k, v.item(), itr)
        elif isinstance(v, MovingAverage):
            self.dashboard.add_scalar(k, v.get(), itr)
        elif isinstance(v, str):
            self.dashboard.add_text(k, v, itr)
        elif isinstance(v, Image.Image):
            # NOTE(review): the transpose needs a 3-D (H, W, C) array, so a
            # 2-D (grayscale) image array would fail here - confirm whether
            # such images must be supported.
            self.dashboard.add_image(
                k, np.array(v).transpose(2, 0, 1), itr, dataformats="CHW"
            )
        elif isinstance(v, pd.DataFrame):
            self.dashboard.add_table(k, v, itr)
        elif isinstance(v, (int, float)):
            self.dashboard.add_scalar(k, v, itr)
        else:
            raise ValueError(
                f"Unsupported dashboard value {v}. Must be "
                "[int,float, pd.DataFrame, Image.Image, str, "
                "MovingAverage, dict[str,float|int], list[float,int], np.ndarray] "
            )

    def _append_metrics(self, metrics: dict[str, float]):
        """
        Append metrics to the result json file.

        One line is appended per call. Does nothing when there is no result
        file path.

        Parameters
        ----------
        metrics : dict[str, float]
            The metrics to append.
        """
        if self.result_json_path is not None:
            with open(self.result_json_path, "a", encoding="utf-8") as fp:
                fp.write(futils.dict_to_json(metrics) + "\n")

    def update(
        self,
        metrics: Union[TrainMetrics, dict],
        itr: Optional[int] = None,
    ):
        """
        Update the dashboard with the given metrics, append them to the
        results file and update the current metadata (``log_iteration``).

        Parameters
        ----------
        metrics : Union[TrainMetrics, dict]
            The metrics to update.
        itr : Optional[int], optional
            The iteration, by default ``None``.

        Raises
        ------
        AssertionError
            If the iteration is not greater than the current iteration.

        Notes
        -----
        When ``itr`` is ``None`` the metrics are logged at the current
        ``log_iteration``, which is then increased by 1. Otherwise
        ``log_iteration`` is set to ``itr``.
        """
        if itr is None:
            itr = self.log_iteration
            self.log_iteration += 1
        else:
            assert (
                itr > self.log_iteration
            ), f"Current iteration > {itr}. Can not add metrics."
            self.log_iteration = itr

        if isinstance(metrics, TrainMetrics):
            dict_metrics = metrics.to_dict()
        else:
            dict_metrics = metrics

        for k, v in dict_metrics.items():
            self._add_metric(k, v, itr)
        self._append_metrics(dict_metrics)
        self._update_metadata()

    def checkpoint(
        self,
        save_dict: object,
        file_name: str,
        itr: int | None = None,
        is_best: bool = False,
    ):
        """
        Save a checkpoint and update the checkpoint iteration.

        Saves the model checkpoint in the appropriate directory based on the
        ``is_best`` parameter. If ``is_best==True``, the checkpoint is saved
        in the ``"best"`` directory, indicating the best performing model so
        far. Otherwise, the checkpoint is saved in the ``"recent"``
        directory, representing the most recent checkpoint.

        The file path for the checkpoint is constructed using the selected
        directory name (``"best"`` or ``"recent"``), and the file name with
        the format ``"{file_name}_{itr:010}.pt"``, where ``itr`` is the
        iteration number.

        The ``checkpoint_iteration`` dictionary is updated with the current
        iteration number for each directory. If ``itr`` is not provided, the
        iteration number is increased by 1 each time a checkpoint is saved.
        Otherwise, the iteration number is set to the provided ``itr``.

        Nothing is saved when there is no model directory or when
        ``keep_n_checkpoints`` is not positive.

        Parameters
        ----------
        save_dict : object
            The object to save.
        file_name : str
            The file name.
        itr : int | None, optional
            The iteration, by default None. If not provided, the current
            iteration is incremented by 1.
        is_best : bool, optional
            Whether this is the best checkpoint, by default False.

        Raises
        ------
        AssertionError
            If the provided ``itr`` is not larger than the current iteration
            associated with the checkpoint, or if the checkpoint file
            already exists.
        """
        if self.model_dir is None:
            return
        # NOTE(review): saving, cleanup and the metadata update are all
        # treated as conditional on keep_n_checkpoints > 0 (otherwise ``itr``
        # could still be ``None`` when the file name is formatted) - confirm
        # against the upstream layout.
        if self.keep_n_checkpoints <= 0:
            return

        dir_name = "best" if is_best else "recent"
        # -1 marks "nothing saved yet", so the first automatic iteration is 0.
        counters = self.checkpoint_iteration.setdefault(dir_name, {})
        counters.setdefault(file_name, -1)
        if itr is None:
            counters[file_name] += 1
            itr = counters[file_name]
        else:
            cur_iter = counters[file_name]
            assert (
                itr > cur_iter
            ), f"Checkpoint iteration {cur_iter} > training iteration {itr}. Can not save checkpoint."
            counters[file_name] = itr

        dir_path = self.model_dir.joinpath(self.CHKPT_DIRS[dir_name])
        file_path = dir_path.joinpath(f"{file_name}_{itr:010}.pt")
        assert not file_path.exists(), f"Checkpoint exists: {file_path}"
        futils.save_checkpoint(save_dict, file_path)
        futils.clean_checkpoints(dir_path, self.keep_n_checkpoints)
        self._update_metadata()

    def clean_checkpoints(self, keep_n_checkpoints: int):
        """
        Clean up checkpoints and keep only the specified number of
        checkpoints.

        Applied to every checkpoint directory of the model directory. Does
        nothing when there is no model directory.

        Parameters
        ----------
        keep_n_checkpoints : int
            Number of checkpoints to keep.
        """
        if self.model_dir is None:
            return
        for chkpt_dir in self.CHKPT_DIR_VALUES:
            dir_path = self.model_dir.joinpath(chkpt_dir)
            futils.clean_checkpoints(dir_path, keep_n_checkpoints)

    def info(self, *args, **kwargs):
        """
        Log an info message using the logger.

        Parameters
        ----------
        *args
            Positional arguments passed to the logger's info method.
        **kwargs
            Keyword arguments passed to the logger's info method.
        """
        self.logger.info(*args, **kwargs)

    def warn(self, *args, **kwargs):
        """
        Log a warning message using the logger.

        Parameters
        ----------
        *args
            Positional arguments passed to the logger's warn method.
        **kwargs
            Keyword arguments passed to the logger's warn method.
        """
        self.logger.warn(*args, **kwargs)

    def error(self, *args, **kwargs):
        """
        Log an error message using the logger.

        Parameters
        ----------
        *args
            Positional arguments passed to the logger's error method.
        **kwargs
            Keyword arguments passed to the logger's error method.
        """
        self.logger.error(*args, **kwargs)