Source code for ablator.utils.file
import copy
import json
import typing as ty
from pathlib import Path
import numpy as np
import pandas as pd
import torch
[docs]def make_sub_dirs(parent: str | Path, *dir_names) -> list[Path]:
"""
Create subdirectories under the given parent directory.
Parameters
----------
parent : str | Path
Parent directory where subdirectories should be created.
*dir_names : str
Names of the subdirectories to create.
Returns
-------
list[Path]
A list of created subdirectory paths.
"""
dirs = []
for dir_name in dir_names:
dir_path = Path(parent).joinpath(dir_name)
dir_path.mkdir(parents=True, exist_ok=True)
dirs.append(dir_path)
return dirs
def save_checkpoint(state, filename="checkpoint.pt"):
    """
    Serialize ``state`` to ``filename`` using ``torch.save``.

    Parameters
    ----------
    state : dict
        Model state dictionary to write out.
    filename : str, optional
        Destination of the checkpoint file, by default "checkpoint.pt".
    """
    torch.save(obj=state, f=filename)
[docs]def clean_checkpoints(checkpoint_folder: Path, n_checkpoints: int):
"""
Remove all but the n latest checkpoints from the given directory.
Parameters
----------
checkpoint_folder : Path
Directory containing the checkpoint files.
n_checkpoints : int
Number of checkpoints to keep.
"""
chkpts = sorted(list(checkpoint_folder.glob("*.pt")))[::-1]
# Keep only last n checkpoints (or first n because we sort in reverse)
if len(chkpts) > n_checkpoints:
chkpts_to_del = chkpts[n_checkpoints:]
for _chkpt in chkpts_to_del:
Path(_chkpt).unlink(missing_ok=True)
def default_val_parser(val):
    """Convert a value into something ``json`` can encode.

    Tensors and data frames are first turned into NumPy arrays; arrays
    are then converted to (nested) Python lists. Any other value is
    replaced by its ``str()`` representation.

    Parameters
    ----------
    val : ty.Any
        The value to be converted.

    Returns
    -------
    ty.Any
        A list for array-like inputs, otherwise a string.
    """
    # Normalise the array-like containers to a NumPy array first.
    if isinstance(val, torch.Tensor):
        # Detach from the autograd graph and move to host memory before converting.
        val = val.detach().cpu().numpy()
    elif isinstance(val, pd.DataFrame):
        val = np.array(val)

    if isinstance(val, np.ndarray):
        return val.tolist()
    # Fallback: anything else is stored as its string representation.
    return str(val)
def json_to_dict(_json):
    """
    Decode a JSON string with ``json.loads``.

    Parameters
    ----------
    _json : str
        JSON string to be decoded.

    Returns
    -------
    dict
        The decoded content; a dictionary when the top-level JSON value
        is an object.
    """
    return json.loads(_json)
def dict_to_json(_dict):
    """
    Encode a dictionary as a JSON string.

    Values that ``json`` cannot encode natively are passed through
    ``default_val_parser``. The output uses ``indent=0``.

    Parameters
    ----------
    _dict : dict
        The dictionary to be encoded.

    Returns
    -------
    str
        The JSON string representation of the dictionary.
    """
    encoded = json.dumps(_dict, default=default_val_parser, indent=0)
    # Decode once as a sanity check; the decoded result itself is discarded.
    json_to_dict(encoded)
    return encoded
def nested_set(_dict, keys: list[str], value: ty.Any):
    """
    Return a deep copy of a nested dictionary with one value set.

    The input dictionary is not modified: the update is applied to a
    deep copy, which is returned. Missing intermediate dictionaries along
    the path are created.

    Parameters
    ----------
    _dict : dict
        The dictionary to copy and update.
    keys : list[str]
        Non-empty list of keys representing the nested path.
    value : ty.Any
        The value to set at the specified path.

    Returns
    -------
    dict
        A deep copy of ``_dict`` with ``value`` set at the given path.

    Raises
    ------
    IndexError
        If ``keys`` is empty.

    Examples
    --------
    >>> _dict = {'a': {'b': {'c': 1}}}
    >>> nested_set(_dict, ['a', 'b', 'c'], 2)
    {'a': {'b': {'c': 2}}}
    >>> _dict
    {'a': {'b': {'c': 1}}}
    """
    if len(keys) == 0:
        # Same exception type as indexing an empty list, with a clearer message.
        raise IndexError("keys must contain at least one key")
    # Work on a copy so the caller's dictionary stays untouched.
    updated = copy.deepcopy(_dict)
    node = updated
    *parent_keys, leaf_key = keys
    for key in parent_keys:
        # Descend one level, creating the intermediate dict when it is absent.
        node = node.setdefault(key, {})
    node[leaf_key] = value
    return updated