# Source code for ablator.main.state

import copy
import enum
import typing as ty
from collections import OrderedDict
from pathlib import Path
import builtins

import numpy as np
import optuna
from optuna.trial import TrialState as OptunaTrialState
from sqlalchemy import Integer, PickleType, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

import ablator.utils.base as butils
from ablator.main.configs import (
    Optim,
    ParallelConfig,
    SearchAlgo,
    SearchSpace,
    SearchType,
)
from ablator.modules.loggers.file import FileLogger
from ablator.utils.file import nested_set


class Base(DeclarativeBase):
    """Declarative base class for the ORM models of the experiment state database."""
class TrialState(enum.IntEnum):
    """
    An enumeration of possible states for a trial with more pruned states.

    Attributes
    ----------
    RUNNING : int
        A trial that has been successfully scheduled to run
    COMPLETE : int
        Successfully completed trial
    PRUNED : int
        Trial pruned because of various reasons
    FAIL : int
        Trial that produced an error during execution
    WAITING : int
        Trial that has been sampled but is not scheduled to run yet
    PRUNED_INVALID : int
        Trial that was pruned during sampling as it was invalid
    PRUNED_DUPLICATE : int
        Trial that was sampled but was already present
    PRUNED_POOR_PERFORMANCE : int
        Trial that was pruned during execution for poor performance
    RECOVERABLE_ERROR : int
        Trial that produced a recoverable error during execution and can be
        retried
    RESUME : int
        Trial that needs to be resumed

    Methods
    -------
    to_optuna_state:
        Convert this TrialState to an OptunaTrialState.
    """

    # extension of "optuna.trial.TrialState": values 0-4 coincide with Optuna's
    # RUNNING, COMPLETE, PRUNED, FAIL and WAITING.
    RUNNING = 0  # A trial that has been successfully scheduled to run
    COMPLETE = 1  # Successfully completed trial
    PRUNED = 2  # Trial pruned because of various reasons
    FAIL = 3  # Trial that produced an error during execution
    WAITING = 4  # Trial that has been sampled but is not scheduled to run yet
    PRUNED_INVALID = 5  # Trial that was pruned during sampling as it was invalid
    PRUNED_DUPLICATE = 6  # Trial that was sampled but was already present
    PRUNED_POOR_PERFORMANCE = (
        7  # Trial that was pruned during execution for poor performance
    )
    RECOVERABLE_ERROR = 8  # Trial that produced a recoverable error and can be retried
    RESUME = 9  # A trial that needs to be resumed

    def to_optuna_state(self) -> OptunaTrialState | None:
        """
        Convert this ``TrialState`` to an ``OptunaTrialState``.

        Returns
        -------
        OptunaTrialState | None:
            Corresponding ``OptunaTrialState`` or ``None`` if the state is not
            applicable (``RUNNING``, ``RECOVERABLE_ERROR`` and ``RESUME`` have no
            Optuna counterpart to report).
        """
        if self in {
            TrialState.PRUNED,
            TrialState.PRUNED_INVALID,
            TrialState.PRUNED_DUPLICATE,
            TrialState.PRUNED_POOR_PERFORMANCE,
        }:
            return OptunaTrialState.PRUNED
        # Values 8 and 9 do not exist in Optuna's TrialState; converting them
        # would raise ValueError, so they are reported as "not applicable".
        if self in {
            TrialState.RUNNING,
            TrialState.RECOVERABLE_ERROR,
            TrialState.RESUME,
        }:
            return None
        # Remaining members (COMPLETE, FAIL, WAITING) share their integer value
        # with the corresponding Optuna state.
        return OptunaTrialState(int(self))
def augment_trial_kwargs(
    trial_kwargs: dict[str, ty.Any], augmentation: dict[str, ty.Any]
) -> dict[str, ty.Any]:
    """
    Augment the ``trial_kwargs`` with additional key-value pairs specified in the
    augmentation dictionary.

    Parameters
    ----------
    trial_kwargs : dict
        The dictionary containing the key-value pairs to be augmented. It is not
        modified; a deep copy is augmented and returned.
    augmentation : dict
        The dictionary containing the additional key-value pairs. Keys are
        dot-separated paths (e.g. ``"d.e"``) into ``trial_kwargs``.

    Returns
    -------
    dict
        The augmented dictionary.

    Examples
    --------
    >>> trial_kwargs = {'a': 1, 'b': 2}
    >>> augmentation = {'c': 3, 'd.e': 4}
    >>> augment_trial_kwargs(trial_kwargs, augmentation)
    {'a': 1, 'b': 2, 'c': 3, 'd': {'e': 4}}
    """
    # Work on a copy so that the caller's configuration dictionary is untouched.
    augmented = copy.deepcopy(trial_kwargs)
    # ``augmentation`` is a mapping, so its dot-paths are unique by construction
    # and no duplicate-path check is required.
    for config_dot_path, val in augmentation.items():
        path: list[str] = config_dot_path.split(".")
        augmented = nested_set(augmented, path, val)
    return augmented
def parse_metrics(
    metric_directions: dict[str, Optim], metrics: dict[str, float]
) -> dict[str, float]:
    """
    Convert metrics to ordered dictionary of float values using their direction
    (minimize or maximize).

    Missing (``None``) and non-finite values (``nan``, ``inf``) are replaced by the
    worst possible value for the metric's direction: ``-inf`` for metrics that are
    maximized and ``inf`` for metrics that are minimized.

    Parameters
    ----------
    metric_directions : dict
        The ordered dictionary containing the directions of the metrics (minimize or maximize).
    metrics : dict
        The dictionary containing the metric values. Must contain exactly the keys
        of ``metric_directions``.

    Returns
    -------
    OrderedDict
        The ordered dictionary (in the order of ``metric_directions``) of metric
        values converted to float using their direction.

    Raises
    ------
    AssertionError
        If the keys of ``metrics`` differ from the keys of ``metric_directions``.

    Examples
    --------
    >>> metric_directions = OrderedDict([('a', 'max'), ('b', 'min')])
    >>> metrics = {'a': 1, 'b': None}
    >>> parse_metrics(metric_directions, metrics)
    OrderedDict([('a', 1.0), ('b', inf)])
    """
    vals = OrderedDict()
    metric_keys = set(metric_directions)
    user_metrics = set(metrics)
    assert (
        user_metrics == metric_keys
    ), f"Different specified metric directions `{metric_keys}` and `{user_metrics}`"
    for k, v in metric_directions.items():
        val = metrics[k]
        if val is None or not np.isfinite(val):
            val = float("-inf") if Optim(v) == Optim.max else float("inf")
        # Normalize ints / numpy scalars to plain Python floats, as documented.
        vals[k] = float(val)
    return vals
def sample_trial_params(
    optuna_trial: optuna.Trial,
    search_space: dict[str, SearchSpace],
) -> dict[str, ty.Any]:
    """
    Draw one value per search-space entry using the given Optuna trial.

    Integer ranges are sampled with ``suggest_int``, numerical ranges with
    ``suggest_float`` and categorical choices with ``suggest_categorical``. A
    ``value_range`` takes precedence over ``categorical_values`` when the
    ``value_type`` is integer or numerical.

    Parameters
    ----------
    optuna_trial : optuna.Trial
        The Optuna trial used to suggest the values.
    search_space : dict of str to SearchSpace
        Mapping from parameter name to the space it is sampled from.

    Returns
    -------
    dict of str to any
        Mapping from parameter name to the sampled value.

    Raises
    ------
    ValueError
        If a ``SearchSpace`` has neither a usable ``value_range`` nor
        ``categorical_values``.
    AssertionError
        If a ``value_range`` is not given as ``(min, max)``.

    Examples
    --------
    >>> optuna_trial = optuna_study.ask()
    >>> search_space = {'x': SearchSpace(value_type=SearchType.numerical, value_range=(0.0, 1.0)),
    ...                 'y': SearchSpace(categorical_values=['a', 'b']),
    ...                 'z': SearchSpace(value_type=SearchType.integer, value_range=(1, 10))}
    >>> sample_trial_params(optuna_trial, search_space)
    {'x': 0.030961748695615783, 'y': 'a', 'z': 9}
    """
    sampled: dict[str, ty.Any] = {}
    for name, space in search_space.items():
        # TODO conditional sampling
        has_range = space.value_range is not None
        if has_range and space.value_type == SearchType.integer:
            low, high = space.value_range
            low, high = int(low), int(high)
            assert min(low, high) == low, "`value_range` must be in the format of (min,max)"
            sampled[name] = optuna_trial.suggest_int(name, low, high)
            continue
        if has_range and space.value_type == SearchType.numerical:
            low, high = space.value_range
            low, high = float(low), float(high)
            assert min(low, high) == low, "`value_range` must be in the format of (min,max)"
            sampled[name] = optuna_trial.suggest_float(name, low, high)
            continue
        if space.categorical_values is None:
            raise ValueError(f"Invalid SearchSpace {space}.")
        sampled[name] = optuna_trial.suggest_categorical(name, space.categorical_values)
    return sampled
class Trial(Base):
    """
    ORM row of the ``trial`` table of the experiment state database.

    Each row stores one sampled configuration (``config_param``), its uid, the
    Optuna trial number it was sampled from, its ``TrialState``, the list of
    recorded metrics and how many recoverable runtime errors it produced.
    """

    __tablename__ = "trial"

    id: Mapped[int] = mapped_column(primary_key=True)
    config_uid: Mapped[str] = mapped_column(String(30))
    # List of metric dictionaries (or None), stored pickled.
    metrics: Mapped[PickleType] = mapped_column(PickleType)
    # Keyword arguments used to construct the trial configuration, stored pickled.
    config_param: Mapped[PickleType] = mapped_column(PickleType)
    # Integer column; annotated as int to match the column type and its readers.
    optuna_trial_num: Mapped[int] = mapped_column(Integer)
    state: Mapped[PickleType] = mapped_column(PickleType, default=TrialState.WAITING)
    runtime_errors: Mapped[int] = mapped_column(Integer, default=0)

    def __repr__(self) -> str:
        return (
            f"Trial(id={self.id!r}, config_uid={self.config_uid!r}, "
            f"config_param={self.config_param!r})"
        )
class OptunaState:
    """
    A class to store the state of the Optuna study.

    Attributes
    ----------
    optim_metrics : OrderedDict
        The ordered dictionary containing the names of the metrics to optimize and
        their direction (min or max).
    search_space : dict of str to SearchSpace
        The search space containing the parameters to sample from.
    optuna_study : optuna.study.Study
        The Optuna study object.
    """

    def __init__(
        self,
        storage: str,
        study_name: str,
        optim_metrics: dict[str, Optim],
        search_algo: SearchAlgo,
        search_space: dict[str, SearchSpace],
    ) -> None:
        """
        Initialize the Optuna state.

        Parameters
        ----------
        storage : str
            The path to the database URL or a database URL.
        study_name : str
            The name of the study. An existing study with this name in ``storage``
            is loaded instead of creating a new one.
        optim_metrics : dict[str, Optim]
            A dictionary of metric names and their optimization directions
            (either ``'max'`` or ``'min'``).
        search_algo : SearchAlgo
            The search algorithm to use (``'random'`` or ``'tpe'``).
        search_space : dict[str, SearchSpace]
            A dictionary of parameter names and their corresponding SearchSpace
            instances.

        Raises
        ------
        NotImplementedError
            If the specified search algorithm is not implemented.
        ValueError
            If ``optim_metrics`` is ``None``.

        Notes
        -----
        For tuning, add an attribute to the searchspace whose name is the name of
        the hyperparameter and whose value is the search space eg.
        ``search_space = {"train_config.optimizer_config.arguments.lr":
        SearchSpace(value_range=[0, 0.1], value_type="float")}``
        """
        sampler: optuna.samplers.BaseSampler
        if search_algo == SearchAlgo.random:
            sampler = optuna.samplers.RandomSampler()
        elif search_algo == SearchAlgo.tpe:
            sampler = optuna.samplers.TPESampler()
        else:
            raise NotImplementedError(
                f"Search algorithm `{search_algo}` is not implemented."
            )
        if optim_metrics is None:
            raise ValueError("Must specify optim_metrics.")
        self.optim_metrics = OrderedDict(optim_metrics)
        self.search_space = search_space
        # One direction per metric, in the same order as ``optim_metrics``.
        directions = [
            optuna.study.StudyDirection.MAXIMIZE
            if Optim(direction) == Optim.max
            else optuna.study.StudyDirection.MINIMIZE
            for direction in optim_metrics.values()
        ]
        self.optuna_study = optuna.create_study(
            study_name=study_name,
            storage=storage,
            directions=directions,
            load_if_exists=True,
            sampler=sampler,
        )

    def _optuna_optim_values(self, metrics: dict[str, float]) -> list[float]:
        """
        Convert the input metrics dictionary to a list of metric values, ordered
        like ``self.optim_metrics``.

        Parameters
        ----------
        metrics : dict[str, float]
            A dictionary of metric names and their values.

        Returns
        -------
        list[float]
            A list of metric values corresponding to the input metrics dictionary.
            Missing or non-finite values are replaced by the worst value for the
            metric's direction.

        Examples
        --------
        >>> optuna_state = OptunaState(
        ...     storage="sqlite:///example.db", study_name="test_study",
        ...     optim_metrics={"accuracy": Optim.max}, search_algo=SearchAlgo.tpe,
        ...     search_space = {"train_config.optimizer_config.arguments.lr":
        ...         SearchSpace(value_range=[0, 0.1], value_type="float")})
        >>> metrics = {"accuracy": None}
        >>> optuna_optim_values = optuna_state._optuna_optim_values(metrics)
        >>> print(optuna_optim_values)
        [-inf]
        """
        return list(parse_metrics(self.optim_metrics, metrics).values())

    def update_trial(
        self,
        trial_num: int,
        metrics: dict[str, float] | None,
        state: TrialState,
    ):
        """
        Update the state of a trial when it is completed with metrics.

        Only completed trials with metrics are reported to the Optuna study; any
        other state is ignored here.

        Parameters
        ----------
        trial_num : int
            The trial number.
        metrics : dict[str, float] or None
            A dictionary of metric names and their corresponding values, or
            ``None`` if the trial is not complete.
        state : TrialState
            The state of the trial.

        Raises
        ------
        RuntimeError
            If ``metrics`` is ``None`` and ``state`` is ``COMPLETE``.
        """
        if metrics is None and state == TrialState.COMPLETE:
            raise RuntimeError(f"Missing metrics for complete trial {trial_num}.")
        if metrics is None or state != TrialState.COMPLETE:
            return
        optuna_state = state.to_optuna_state()
        optuna_metrics = self._optuna_optim_values(metrics)
        # TODO raises error for nan values in metrics. Fixme
        self.optuna_study.tell(trial_num, optuna_metrics, optuna_state)

    def sample_trial(self):
        """
        Sample a new set of trial parameters.

        Returns
        -------
        Tuple[int, dict[str, Any]]
            A tuple of the trial number and a dictionary of parameter names and
            their corresponding values.
        """
        optuna_trial = self.optuna_study.ask()
        return (
            optuna_trial.number,
            sample_trial_params(optuna_trial, self.search_space),
        )
class ExperimentState:
    """
    Track the trials of a hyperparameter search and persist them to disk.

    Two SQLite databases are created inside ``experiment_dir``:
    ``<uid>_optuna.db`` holds the Optuna study used for sampling and
    ``<uid>_state.db`` holds one ``Trial`` row per sampled configuration
    together with its state.
    """

    def __init__(
        self,
        experiment_dir: Path,
        config: ParallelConfig,
        logger: FileLogger | None = None,
        resume: bool = False,
    ) -> None:
        """
        Initializes the ExperimentState.

        Initialize databases for storing training states and optuna states.
        Create trials based on total num of trials specified in config.

        Parameters
        ----------
        experiment_dir : Path
            The directory where the experiment data will be stored.
        config : ParallelConfig
            The configuration object that defines the experiment settings.
        logger : FileLogger, optional
            The logger to use for outputting experiment logs. If not specified, a
            dummy logger will be used.
        resume : bool, optional
            Whether to resume a previously interrupted experiment. Default is
            ``False``.

        Raises
        ------
        RuntimeError
            If the specified ``search_space`` parameter is not found in the
            configuration.
        AssertionError
            If ``config.search_space`` is empty, or if no trials can be scheduled.
        RuntimeError
            If the optuna database already exists, or there are running trials,
            and ``resume`` is ``False``.
        """
        self.optuna_trial_map: dict[str, optuna.Trial] = {}
        self.config = config
        self.logger: FileLogger = logger if logger is not None else butils.Dummy()
        optuna.logging.set_verbosity(optuna.logging.WARNING)

        default_vals = self.config.make_dict(self.config.annotations, flatten=True)
        assert len(self.config.search_space), "Must specify a config.search_space."
        for k in self.config.search_space:
            if k not in default_vals:
                raise RuntimeError(
                    f"SearchSpace parameter {k} was not found in the configuration {sorted(list(default_vals.keys()))}."
                )

        study_name = config.uid
        self.experiment_dir = experiment_dir
        optuna_db_path = experiment_dir.joinpath(f"{study_name}_optuna.db")
        if optuna_db_path.exists() and not resume:
            raise RuntimeError(
                f"{optuna_db_path} exists. Please remove before starting a study."
            )
        self.optuna_state = OptunaState(
            f"sqlite:///{optuna_db_path}",
            study_name=config.uid,
            optim_metrics=config.optim_metrics,
            search_algo=config.search_algo,
            search_space=config.search_space,
        )

        experiment_state_db = experiment_dir.joinpath(f"{study_name}_state.db")
        self.engine = create_engine(f"sqlite:///{experiment_state_db}", echo=False)
        Trial.metadata.create_all(self.engine)
        self._init_trials(resume=resume)

    @staticmethod
    def search_space_dot_path(trial: ParallelConfig) -> dict[str, ty.Any]:
        """
        Returns a dictionary of parameter names and their corresponding values for
        a given trial.

        Parameters
        ----------
        trial : ParallelConfig
            The trial object to get the search space dot paths from.

        Returns
        -------
        dict[str, Any]
            A dictionary mapping every dot-path of ``trial.search_space`` to the
            value of that path in ``trial``.

        Examples
        --------
        For a trial whose search space is
        ``{"train_config.optimizer_config.arguments.lr": SearchSpace(value_range=[0, 0.1], value_type="float")}``
        and whose learning rate is ``0.1``:

        >>> ExperimentState.search_space_dot_path(trial)
        {"train_config.optimizer_config.arguments.lr": 0.1}
        """
        return {
            dot_path: trial.get_val_with_dot_path(dot_path)
            for dot_path in trial.search_space.keys()
        }

    @staticmethod
    def tune_trial_str(trial: ParallelConfig) -> str:
        """
        Generate a string representation of a trial object.

        The string starts with the trial uid followed by one
        ``<dot_path> -> <value>`` line per search-space parameter.

        Parameters
        ----------
        trial : ParallelConfig
            The trial object to generate a string representation for.

        Returns
        -------
        str
            A string representation of the trial object.
        """
        trial_map = ExperimentState.search_space_dot_path(trial)
        header = f"\n{trial.uid}:\n\t"
        # The header used to be overwritten by the parameter listing; keep both.
        body = "\n\t".join(
            [f"{dot_path} -> {val} " for dot_path, val in trial_map.items()]
        )
        return header + body

    def _init_trials(self, resume: bool = False) -> list[ParallelConfig]:
        """
        Initialize trials for the experiment. If resume is True, then load the
        trials from the database and create new trials for the remaining trials.

        Parameters
        ----------
        resume : bool, optional, default=False
            Whether to resume an existing experiment.

        Returns
        -------
        list[ParallelConfig]
            The list of initialized trials (at most ``concurrent_trials``).

        Raises
        ------
        RuntimeError
            If an experiment exists and ``resume`` is False.
        AssertionError
            If no trials can be scheduled.
        """
        max_trials_conc = min(self.config.concurrent_trials, self.config.total_trials)
        if self.config.search_algo in [
            SearchAlgo.random,
        ]:
            # presumably safe because random sampling does not depend on results
            # of previous trials, so all remaining trials are sampled up-front.
            trials_to_sample = self.n_trials_remaining
        else:
            trials_to_sample = max_trials_conc

        # If there are currently running trials, return those first when resuming.
        # Otherwise ERROR.
        if len(self.running_trials) > 0 and not resume:
            raise RuntimeError(
                "Experiment exists. You need to use `resume = True` or use a different path."
            )
        running_trials = []
        for trial in self.running_trials:
            self.update_trial_state(trial.uid, None, TrialState.RESUME)
            running_trials.append(trial)
        # ``pending_trials`` selects WAITING | RESUME, so the trials just marked
        # RESUME would otherwise be scheduled a second time.
        resumed_uids = {trial.uid for trial in running_trials}
        pending_trials = [
            trial for trial in self.pending_trials if trial.uid not in resumed_uids
        ]
        trials = self.__sample_trials(
            trials_to_sample,
            running_trials + pending_trials,
            ignore_errors=self.config.ignore_invalid_params,
        )[:max_trials_conc]
        assert len(trials) > 0, "No trials could be scheduled."
        return trials

    def sample_trials(self, n_trials_to_sample: int) -> list[ParallelConfig] | None:
        """
        Sample ``n`` trials from the search space and update database.

        Number ``n`` is the minimum value of ``n_trials_to_sample`` and
        ``n_trials_remaining``. ``n_trials_remaining`` is the number of
        ``total_trials`` (defined in config) minus the number of trials that have
        been sampled.

        Parameters
        ----------
        n_trials_to_sample : int
            The number of trials to sample. Must be positive.

        Returns
        -------
        list[ParallelConfig] | None
            The list of sampled trials, or ``None`` if the limit of trials has
            been reached.
        """
        assert n_trials_to_sample > 0
        n_remaining = self.n_trials_remaining
        if n_remaining <= 0:
            self.logger.warn(
                f"Limit of trials to sample '{self.config.total_trials}' reached."
            )
            return None
        n_trials_to_sample = min(n_remaining, n_trials_to_sample)
        trials = self.__sample_trials(
            n_trials_to_sample,
            prev_trials=[],
            ignore_errors=self.config.ignore_invalid_params,
        )[:n_trials_to_sample]
        return trials

    def __append_trial(
        self,
        trial_kwargs: dict[str, ty.Any],
        optuna_trial_num: int,
        trial_state: TrialState,
    ) -> bool:
        """
        Append a trial to the experiment state database.

        Parameters
        ----------
        trial_kwargs : dict[str, Any]
            config dict with new sampled hyperparameters.
        optuna_trial_num : int
            The optuna trial number.
        trial_state : TrialState
            The state of the trial.

        Returns
        -------
        bool
            True if the trial state is not pruned, False otherwise.
        """
        if trial_state in {TrialState.PRUNED_INVALID, TrialState.PRUNED_DUPLICATE}:
            # Pruned trials are recorded with a placeholder uid since an invalid
            # configuration cannot be constructed.
            self.__append_trial_internal(
                "none", trial_kwargs, optuna_trial_num, trial_state
            )
            return False
        trial_config = type(self.config)(**trial_kwargs)
        self.__append_trial_internal(
            trial_config.uid, trial_kwargs, optuna_trial_num, trial_state
        )
        return True

    def __sample_trials(
        self,
        n_trials: int,
        prev_trials: list[ParallelConfig] | None = None,
        ignore_errors=False,
    ) -> list[ParallelConfig]:
        """
        Samples a specified number of trials from the search space and persists
        states to experiment database. Previous trials can be reused to avoid
        sampling the same trials again.

        Parameters
        ----------
        n_trials : int
            The total number of trials to return, including ``prev_trials``.
        prev_trials : list[ParallelConfig] | None, optional
            A list of previously sampled trials, by default None. The list is not
            modified.
        ignore_errors : bool, optional
            Whether to ignore invalid parameters and continue sampling, by
            default False.

        Returns
        -------
        list[ParallelConfig]
            ``prev_trials`` followed by the newly sampled trials.

        Raises
        ------
        RuntimeError
            If the number of invalid or duplicate trials exceeds the
            error_upper_bound.
        TypeError
            If the trial parameter are invalid and ``ignore_errors`` is False.
        """
        error_upper_bound = n_trials * 10
        # Copy so the caller's list is not mutated by the appends below.
        sampled_trials: list[ParallelConfig] = (
            [] if prev_trials is None else list(prev_trials)
        )
        while len(sampled_trials) < n_trials:
            if (
                len(self.pruned_errored_trials) + len(self.pruned_duplicate_trials)
                > error_upper_bound
            ):
                raise RuntimeError(
                    f"Reached maximum limit of misconfigured trials. {error_upper_bound}\n"
                    f"Found {len(self.pruned_duplicate_trials)} duplicate and "
                    f"{len(self.pruned_errored_trials)} invalid trials."
                )

            trial_num, parameter = self.optuna_state.sample_trial()
            trial_kwargs = augment_trial_kwargs(
                trial_kwargs=self.config.to_dict(), augmentation=parameter
            )
            trial_state = TrialState.WAITING
            try:
                trial_config = type(self.config)(**trial_kwargs)
                if trial_config.uid in self.all_trials_uid:
                    trial_state = TrialState.PRUNED_DUPLICATE
            except builtins.Exception as e:
                if ignore_errors:
                    trial_state = TrialState.PRUNED_INVALID
                    self.logger.warn(f"ignoring: {parameter}. Error:{e}")
                else:
                    raise TypeError(f"Invalid trial parameters {parameter}") from e
            # ``trial_config`` is only read when the trial was not pruned, i.e.
            # when its construction above succeeded.
            if self.__append_trial(trial_kwargs, trial_num, trial_state):
                sampled_trials.append(trial_config)

        return sampled_trials

    def update_trial_state(
        self,
        config_uid: str,
        metrics: dict[str, float] | None = None,
        state: TrialState = TrialState.RUNNING,
    ) -> None:
        """
        Update the state of a trial in both the Experiment database and tell Optuna.

        A ``RECOVERABLE_ERROR`` increments the trial's error count instead; the
        trial is then set back to ``WAITING``, or to ``FAIL`` once it failed too
        many times.

        Parameters
        ----------
        config_uid : str
            The uid of the trial to update.
        metrics : dict[str, float] | None, optional
            The metrics of the trial, by default ``None``.
        state : TrialState, optional
            The state of the trial, by default ``TrialState.RUNNING``.

        Examples
        --------
        >>> experiment.update_trial_state("fje_2211", {"loss": 0.1}, TrialState.COMPLETE)
        """
        if state == TrialState.RECOVERABLE_ERROR:
            self._inc_error_count(config_uid, state)
            return

        self._update_internal_trial_state(config_uid, metrics, state)
        # NOTE currently it is error prone to update the optuna state
        trial_num = self._get_optuna_trial_num(config_uid)
        self.optuna_state.update_trial(trial_num, metrics, state)

    def _get_optuna_trial_num(self, config_uid: str) -> int:
        """
        Get the optuna trial number from the database.

        Parameters
        ----------
        config_uid : str
            The uid of the trial

        Returns
        -------
        int
            The optuna trial number.

        Raises
        ------
        ValueError
            If no trial with ``config_uid`` exists.
        """
        with Session(self.engine) as session:
            stmt = select(Trial).where(Trial.config_uid == config_uid)
            res = session.scalar(stmt)
            if res is not None:
                return int(res.optuna_trial_num)
        raise ValueError(f"No trial found with config_uid: {config_uid}")

    def _update_internal_trial_state(
        self, config_uid: str, metrics: dict[str, float] | None, state: TrialState
    ):
        """
        Update the state of a trial in the Experiment state database and append
        the (parsed) metrics, or ``None``, to its metrics history.

        Parameters
        ----------
        config_uid : str
            The uid of the trial to update.
        metrics : dict[str, float] | None
            The metrics of the trial.
        state : TrialState
            The state of the trial.

        Returns
        -------
        bool
            True if the update was successful.
        """
        if metrics is not None:
            internal_metrics = parse_metrics(self.config.optim_metrics, metrics)
        else:
            internal_metrics = None

        with Session(self.engine) as session:
            stmt = select(Trial).where(Trial.config_uid == config_uid)
            res = session.execute(stmt).scalar_one()
            # In-place mutation of a PickleType value is not tracked by the ORM, so
            # ``res.metrics.append(...)`` would never be written. Assign a new list.
            res.metrics = [*(res.metrics or []), internal_metrics]  # type: ignore
            res.state = state  # type: ignore # TODO fix this
            session.commit()
        return True

    def _inc_error_count(self, config_uid: str, state: TrialState):
        """
        Record a recoverable runtime error of a trial.

        The trial is re-queued as ``WAITING`` for up to 10 failures and marked as
        ``FAIL`` afterwards.

        Parameters
        ----------
        config_uid : str
            The uid of the trial that failed.
        state : TrialState
            Must be ``TrialState.RECOVERABLE_ERROR``.
        """
        with Session(self.engine) as session:
            stmt = select(Trial).where(Trial.config_uid == config_uid)
            res = session.execute(stmt).scalar_one()
            assert state == TrialState.RECOVERABLE_ERROR
            # Number of failures including the current one.
            n_failures = int(res.runtime_errors) + 1
            # SQL-side increment of the stored counter.
            res.runtime_errors = Trial.runtime_errors + 1
            session.commit()

        if n_failures <= 10:
            self.logger.warn(f"{config_uid} failed {n_failures} times.")
            self.update_trial_state(config_uid, None, TrialState.WAITING)
        else:
            self.logger.error(f"{config_uid} failed {n_failures} times. Skipping.")
            self.update_trial_state(config_uid, None, TrialState.FAIL)

    def __append_trial_internal(
        self,
        config_uid: str,
        trial_kwargs: dict[str, ty.Any],
        optuna_trial_num: int,
        trial_state: TrialState,
    ):
        """
        Append a trial to the Experiment state database.

        Parameters
        ----------
        config_uid : str
            The uid of the trial to update.
        trial_kwargs : dict[str, ty.Any]
            config dict with new sampled hyperparameters.
        optuna_trial_num : int
            The optuna trial number.
        trial_state : TrialState
            The state of the trial.
        """
        with Session(self.engine) as session:
            trial = Trial(
                config_uid=config_uid,
                config_param=trial_kwargs,
                optuna_trial_num=optuna_trial_num,
                state=trial_state,
                metrics=[],
            )
            session.add(trial)
            session.commit()

    def _get_trials_by_stmt(self, stmt) -> list[Trial]:
        """Execute ``stmt`` and return the resulting rows."""
        with self.engine.connect() as conn:
            trials: list[Trial] = conn.execute(stmt).fetchall()  # type: ignore
        return trials

    def _get_trial_configs_by_stmt(self, stmt) -> list[ParallelConfig]:
        """Execute ``stmt`` and rebuild a configuration object for every row."""
        trials = self._get_trials_by_stmt(stmt)
        configs = []
        for trial in trials:
            trial_config = type(self.config)(**dict(trial.config_param))
            configs.append(trial_config)
            assert trial_config.uid == trial.config_uid
        return configs

    @property
    def all_trials_uid(self) -> list[str]:
        """Uids of all trials that were not pruned during sampling."""
        return [c.uid for c in self.all_trials]

    @property
    def all_trials(self) -> list[ParallelConfig]:
        """All trials except those pruned as duplicate or invalid during sampling."""
        stmt = select(Trial).where(
            (Trial.state != TrialState.PRUNED_DUPLICATE)
            & (Trial.state != TrialState.PRUNED_INVALID)
        )
        return self._get_trial_configs_by_stmt(stmt)

    @property
    def pruned_errored_trials(self) -> list[dict[str, ty.Any]]:
        """
        Error trials can not be initialized to a configuration and as such the
        kwargs parameters are returned.
        """
        stmt = select(Trial).where((Trial.state == TrialState.PRUNED_INVALID))
        trials = self._get_trials_by_stmt(stmt)
        return [dict(trial.config_param) for trial in trials]

    @property
    def pruned_duplicate_trials(self) -> list[dict[str, ty.Any]]:
        """Kwargs parameters of the trials pruned as duplicates."""
        stmt = select(Trial).where((Trial.state == TrialState.PRUNED_DUPLICATE))
        trials = self._get_trials_by_stmt(stmt)
        return [dict(trial.config_param) for trial in trials]

    @property
    def running_trials(self) -> list[ParallelConfig]:
        """Trials in the ``RUNNING`` state."""
        stmt = select(Trial).where((Trial.state == TrialState.RUNNING))
        return self._get_trial_configs_by_stmt(stmt)

    @property
    def pending_trials(self) -> list[ParallelConfig]:
        """Trials in the ``WAITING`` or ``RESUME`` state."""
        stmt = select(Trial).where(
            (Trial.state == TrialState.WAITING) | (Trial.state == TrialState.RESUME)
        )
        return self._get_trial_configs_by_stmt(stmt)

    @property
    def resumed_trials(self) -> list[ParallelConfig]:
        """Trials in the ``RESUME`` state."""
        stmt = select(Trial).where((Trial.state == TrialState.RESUME))
        return self._get_trial_configs_by_stmt(stmt)

    @property
    def complete_trials(self) -> list[ParallelConfig]:
        """Trials in the ``COMPLETE`` state."""
        stmt = select(Trial).where((Trial.state == TrialState.COMPLETE))
        return self._get_trial_configs_by_stmt(stmt)

    @property
    def failed_trials(self) -> list[ParallelConfig]:
        """Trials in the ``FAIL`` state."""
        stmt = select(Trial).where((Trial.state == TrialState.FAIL))
        return self._get_trial_configs_by_stmt(stmt)

    @property
    def n_trials_remaining(self) -> int:
        """
        Number of trials that can still be sampled.

        Every trial that was not pruned during sampling (duplicate or invalid)
        counts towards ``total_trials``, whatever its state, including trials
        that are pending and not yet scheduled.
        """
        return self.config.total_trials - len(self.all_trials)