# Source code for ablator.analysis.plot.main

import logging
import typing as ty
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from PIL import Image

from ablator.analysis.main import Analysis
from ablator.analysis.plot import Plot
from ablator.analysis.plot.cat_plot import ViolinPlot
from ablator.analysis.plot.num_plot import LinearPlot
from ablator.main.configs import Optim

logger = logging.getLogger(__name__)


class PlotAnalysis(Analysis):
    """
    Analysis that draws and saves plots of metrics against experiment attributes.

    Categorical attributes are plotted with ``ViolinPlot`` and numerical
    attributes with ``LinearPlot`` (see :meth:`make_figures`).
    """

    @classmethod
    def _write_images(
        cls,
        fig_map: dict[str, ty.Union[Axes, Figure, Image.Image]],
        path: Path,
        file_format: ty.Literal["png", "pdf", "jpg"] = "png",
    ):
        """
        Write images to a directory, dispatching on the type of each object.

        Parameters
        ----------
        fig_map: dict[str, ty.Union[Axes, Figure, Image.Image]]
            A dictionary mapping file names (without extension) to the object to
            save. ``Axes`` are saved through their parent figure, ``Figure``
            objects directly, and PIL images with ``Image.save``. Objects of any
            other type are silently skipped.
        path: Path
            Directory to save the images to. It is created (with parents) if it
            does not exist.
        file_format: ty.Literal["png", "pdf", "jpg"]
            The file extension to save the images with; the writers infer the
            format from it.

        Examples
        --------
        >>> fig_map = {"figure1": plt.subplots()[0]}
        >>> path = Path("output_dir")
        >>> PlotAnalysis._write_images(fig_map, path, "png")
        """
        path.mkdir(exist_ok=True, parents=True)
        for name, fig in fig_map.items():
            img_path = path.joinpath(f"{name}.{file_format}")
            if isinstance(fig, Axes):
                fig.figure.savefig(img_path)
            elif isinstance(fig, Figure):
                fig.savefig(img_path)
            elif isinstance(fig, Image.Image):
                fig.save(img_path)

    @classmethod
    def _make_metric_plots(
        cls,
        path: Path | None,
        plot_cls: type[Plot],
        metrics: pd.DataFrame,
        results: pd.DataFrame,
        metric_map: dict[str, Optim],
        append=False,
        ax: Axes | None = None,
        metric_name_remap=None,
        attribute_name_remap=None,
        **kwargs,
    ):
        """
        Plot every metric column against every attribute column.

        For each metric, ``_make_plot`` draws one plot per attribute, or a
        single combined plot when ``append`` is True. If ``path`` is given, the
        plots are written to ``<path>/<original metric name>/<attribute>.png``.
        All produced figures are closed before returning.

        Parameters
        ----------
        path: Path | None
            Directory to write images to, or ``None`` to skip writing.
        plot_cls: type[Plot]
            A subclass of Plot representing the type of plot to make.
        metrics: pd.DataFrame
            A pandas DataFrame containing metric values, one column per metric.
        results: pd.DataFrame
            A pandas DataFrame containing attribute values, one column per attribute.
        metric_map: dict[str, Optim]
            A dictionary mapping metric names to their optimization direction.
        append: bool
            Whether to draw all attributes of a metric onto the same axes.
        ax: Axes | None
            A matplotlib axes to draw the first plot on.
        metric_name_remap: dict | None
            A dictionary mapping metric names to new metric names.
        attribute_name_remap: dict | None
            A dictionary mapping attribute names to new attribute names.
        kwargs:
            Additional keyword arguments forwarded to the plot's ``make`` method.

        Returns
        -------
        dict[str, dict[str, Axes]]
            Maps each metric name (as it appears after remapping) to the
            attribute -> axes mapping returned by ``_make_plot``.

        Examples
        --------
        >>> metrics = pd.DataFrame({"metric1": [1, 2, 3], "metric2": [4, 5, 6]})
        >>> results = pd.DataFrame({"attr1": [7, 8, 9], "attr2": [10, 11, 12]})
        >>> metric_map = {"metric1": Optim.max, "metric2": Optim.min}
        >>> PlotAnalysis._make_metric_plots(None, LinearPlot, metrics, results, metric_map, False, None, None, None)
        """
        axes = {}
        (results, metrics, metric_map) = cls._remap_results(
            results, metrics, metric_map, metric_name_remap, attribute_name_remap
        )
        # The columns below are looked up in the inverted maps, i.e. they are
        # the *remapped* names; invert so files are written under the original
        # names.
        inv_metric_name_remap = None
        if metric_name_remap is not None:
            inv_metric_name_remap = {v: k for k, v in metric_name_remap.items()}
        inv_attribute_name_remap = None
        if attribute_name_remap is not None:
            inv_attribute_name_remap = {v: k for k, v in attribute_name_remap.items()}

        for metric_name in metrics.columns:
            metric_values = metrics[metric_name]
            metric_obj_fn = metric_map[metric_name]
            axes_map = cls._make_plot(
                plot_cls=plot_cls,
                metric_values=metric_values,
                metric_obj_fn=metric_obj_fn,
                results=results,
                append=append,
                ax=ax,
                inv_attribute_name_map=inv_attribute_name_remap,
                **kwargs,
            )
            if inv_metric_name_remap is not None:
                original_metric_name = inv_metric_name_remap[metric_name]
            else:
                original_metric_name = metric_name
            if path is not None:
                p = path.joinpath(original_metric_name)
                cls._write_images(axes_map, p)
            # Release the figures so repeated calls do not accumulate open figures.
            for axe in axes_map.values():
                plt.close(axe.figure)
            axes[metric_name] = axes_map
        return axes

    @classmethod
    def _make_plot(
        cls,
        metric_values: pd.Series,
        results: pd.DataFrame,
        plot_cls: type[Plot],
        metric_obj_fn: Optim,
        append=False,
        ax: Axes | None = None,
        inv_attribute_name_map: dict[str, str] | None = None,
        **kwargs,
    ) -> dict[str, Axes]:
        """
        Plot one metric against each attribute column of ``results``.

        Parameters
        ----------
        metric_values: pd.Series
            Values of the metric; its ``name`` is used as the y-axis label.
        results: pd.DataFrame
            Attribute values, one column per attribute (used as x-axis labels).
        plot_cls: type[Plot]
            The plot class to instantiate for each attribute.
        metric_obj_fn: Optim
            Optimization direction of the metric, forwarded to ``plot_cls``.
        append: bool
            If True, every attribute is drawn onto the axes produced by the
            previous one and a single ``"combined_attribute"`` entry is returned.
        ax: Axes | None
            Axes to draw the first plot on.
        inv_attribute_name_map: dict[str, str] | None
            Maps column names back to original attribute names for the keys of
            the returned mapping.
        kwargs:
            Forwarded to ``plot_cls(...).make``.

        Returns
        -------
        dict[str, Axes]
            Original attribute name -> axes, or ``{"combined_attribute": axes}``
            when ``append`` is True. Empty if ``results`` has no columns.
        """
        axes_map: dict[str, Axes] = {}
        # Tracks the shared axes in append mode; stays None if nothing was plotted
        # (previously this case raised UnboundLocalError).
        combined_axes = None
        for attribute_name in results.columns:
            attribute_values = results[attribute_name]
            _, axes = plot_cls(
                metric=metric_values,
                attributes=attribute_values,
                metric_obj_fn=metric_obj_fn,
                y_axis=metric_values.name,
                x_axis=attribute_name,
                ax=ax,
            ).make(**kwargs)
            if inv_attribute_name_map is not None:
                original_attribute_name = inv_attribute_name_map[attribute_name]
            else:
                original_attribute_name = attribute_name
            if append:
                # Draw the next attribute onto the same axes.
                ax = axes
                combined_axes = axes
            else:
                axes_map[original_attribute_name] = axes
                plt.close()
        if append and combined_axes is not None:
            axes_map["combined_attribute"] = combined_axes
            plt.close()
        return axes_map

    def make_violinplot(
        self,
        attribute_names: list[str],
        metrics: list[str],
        save_dir: ty.Union[Path, str],
        **plt_kwargs,
    ):
        """
        Draw violin plots of ``metrics`` against (typically categorical) attributes.

        Images are written under ``<save_dir>/violinplot/<metric>/<attribute>.png``.

        Parameters
        ----------
        attribute_names: list[str]
            Columns of ``self.results`` to use as attributes.
        metrics: list[str]
            Columns of ``self.results`` to use as metrics.
        save_dir: Path | str
            Root directory for the written images.
        plt_kwargs:
            Forwarded to ``_make_metric_plots`` (e.g. ``metric_name_remap``,
            ``attribute_name_remap``, ``append``, ``ax``, plot options).

        Returns
        -------
        dict[str, dict[str, Axes]]
            The metric -> attribute -> axes mapping from ``_make_metric_plots``.
        """
        save_path = Path(save_dir).joinpath("violinplot")
        metric_map = {k: v for k, v in self.optim_metrics.items() if k in metrics}
        # Return the axes mapping, consistent with make_linearplot.
        return self._make_metric_plots(
            path=save_path,
            plot_cls=ViolinPlot,
            metrics=self.results[metrics],
            results=self.results[attribute_names],
            metric_map=metric_map,
            **plt_kwargs,
        )

    def make_linearplot(
        self,
        attribute_names: list[str],
        metrics: list[str],
        save_dir: ty.Union[Path, str],
        **plt_kwargs,
    ):
        """
        Draw linear plots of ``metrics`` against (typically numerical) attributes.

        Images are written under ``<save_dir>/linearplot/<metric>/<attribute>.png``.

        Parameters
        ----------
        attribute_names: list[str]
            Columns of ``self.results`` to use as attributes.
        metrics: list[str]
            Columns of ``self.results`` to use as metrics.
        save_dir: Path | str
            Root directory for the written images.
        plt_kwargs:
            Forwarded to ``_make_metric_plots`` (e.g. ``metric_name_remap``,
            ``attribute_name_remap``, ``append``, ``ax``, plot options).

        Returns
        -------
        dict[str, dict[str, Axes]]
            The metric -> attribute -> axes mapping from ``_make_metric_plots``.
        """
        save_path = Path(save_dir).joinpath("linearplot")
        metric_map = {k: v for k, v in self.optim_metrics.items() if k in metrics}
        return self._make_metric_plots(
            path=save_path,
            plot_cls=LinearPlot,
            metrics=self.results[metrics],
            results=self.results[attribute_names],
            metric_map=metric_map,
            **plt_kwargs,
        )

    def make_figures(
        self,
        metric_name_remap: dict[str, str] | None = None,
        attribute_name_remap: dict[str, str] | None = None,
        **plt_kwargs,
    ):
        """
        Plot all metrics against the categorical and numerical attributes.

        Categorical attributes get violin plots under ``<self.save_dir>/violinplot``.
        Numerical attributes get linear plots under ``<self.save_dir>/linearplot``.

        Parameters
        ----------
        metric_name_remap: dict[str, str] | None
            Mapping of metric names to new names, forwarded to the plot methods.
        attribute_name_remap: dict[str, str] | None
            Mapping of attribute names to new names, forwarded to the plot
            methods. When given, only attributes that appear as keys are plotted.
        plt_kwargs:
            Additional keyword arguments forwarded to the plot methods.
        """
        cat_attrs = list(self.categorical_attributes)
        num_attrs = list(self.numerical_attributes)
        if attribute_name_remap is not None:
            # Keep only remapped attributes. An order-preserving filter (instead of
            # a set intersection) keeps attribute order deterministic across runs.
            cat_attrs = [a for a in cat_attrs if a in attribute_name_remap]
            num_attrs = [a for a in num_attrs if a in attribute_name_remap]
        if len(cat_attrs) > 0:
            self.make_violinplot(
                cat_attrs,
                self.metric_names,
                metric_name_remap=metric_name_remap,
                attribute_name_remap=attribute_name_remap,
                save_dir=self.save_dir,
                **plt_kwargs,
            )
        if len(num_attrs) > 0:
            self.make_linearplot(
                num_attrs,
                self.metric_names,
                metric_name_remap=metric_name_remap,
                attribute_name_remap=attribute_name_remap,
                save_dir=self.save_dir,
                **plt_kwargs,
            )