Source code for ablator.analysis.plot
import logging
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from ablator.main.configs import Optim
# Logger for this module, named after the module's import path.
logger = logging.getLogger(name=__name__)
class Plot(ABC):
    """Abstract base for plotting a metric against experiment attributes.

    Subclasses implement :meth:`_make`, which must return a ``(figure, axes)``
    pair. :meth:`make` calls it and then applies the shared formatting: axis
    titles, optional categorical x tick labels, tick label sizes, the legend
    and a tight layout.
    """

    def __init__(
        self,
        metric: pd.Series,
        attributes: pd.Series,
        metric_obj_fn: Optim,
        y_axis: str | None = None,
        x_axis: str | None = None,
        x_ticks: np.ndarray | None = None,
        ax: Axes | None = None,
    ) -> None:
        """Store the (NaN-filtered) data and set up the target axes.

        Parameters
        ----------
        metric : pd.Series
            Metric values. NaN entries are dropped.
        attributes : pd.Series
            Attribute values paired with ``metric``. Rows where ``metric`` is
            NaN are dropped too. Expected to share ``metric``'s index, because
            the NaN mask computed from ``metric`` is used to index it.
        metric_obj_fn : Optim
            Objective of the metric. Stored as ``self.metrics_obj_fn``; this
            base class does not read it.
        y_axis, x_axis : str | None
            Optional axis titles, applied by :meth:`make`.
        x_ticks : np.ndarray | None
            Optional x tick labels, placed at positions ``1..len(x_ticks)``
            by :meth:`make`.
        ax : Axes | None
            Axes to draw on. When ``None``, a new 4x4-inch figure with a
            single subplot is created.
        """
        # Filter attributes with the *unfiltered* metric's NaN mask before
        # the metric itself is filtered, so both keep the same rows.
        self.attributes = self._parse_attributes(metric, attributes)
        self.metric = self._parse_metrics(metric)
        # Attribute name differs from the parameter name; it is kept as-is
        # because code outside this class may read ``metrics_obj_fn``.
        self.metrics_obj_fn = metric_obj_fn
        self.y_axis = y_axis
        self.x_axis = x_axis
        self.x_ticks = x_ticks
        self.figure, self.ax = self._make_figure(ax)

    def _make_figure(self, ax: Axes | None = None) -> tuple[Figure | None, Axes]:
        """Return ``(figure, ax)``, creating a new figure only if ``ax`` is None.

        When an existing ``ax`` is supplied, the returned figure is ``None``.
        """
        figure = None
        if ax is None:
            figure = plt.figure(figsize=(4, 4))
            ax = figure.add_subplot(1, 1, 1)
        return figure, ax

    def _parse_attributes(self, metric: pd.Series, attributes: pd.Series) -> pd.Series:
        """Return the ``attributes`` rows whose corresponding ``metric`` is not NaN."""
        return attributes[metric.notna()]

    def _parse_metrics(self, metric: pd.Series) -> pd.Series:
        """Return ``metric`` with its NaN entries removed."""
        return metric[metric.notna()]

    def _parse_legend(self, ax: Axes) -> None:
        """Place a legend centred just below ``ax``, if anything is labelled.

        Calling ``ax.legend`` with no labelled artists only emits a warning
        and would replace any existing legend with an empty one, so in that
        case the axes are left untouched.
        """
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05))

    def _parse_figure_axis(
        self,
        ax: Axes,
        x_axis: str | None = None,
        y_axis: str | None = None,
        labels: list[str] | None = None,
    ) -> None:
        """Apply shared formatting to ``ax`` and tighten its figure's layout.

        Parameters
        ----------
        ax : Axes
            Axes to format.
        x_axis, y_axis : str | None
            Axis titles (font size 18); skipped when ``None``.
        labels : list[str] | None
            Categorical x tick labels, placed at positions ``1..len(labels)``
            with font size 14; skipped when ``None``.
        """
        # Resize tick labels in place instead of re-setting their text:
        # re-setting would freeze the current labels (FixedFormatter without
        # a FixedLocator), making them wrong if the limits change later.
        # Done first so the explicit categorical size below takes precedence.
        ax.tick_params(axis="both", labelsize=12)
        if labels is not None:
            # Fix the tick positions before assigning labels so each label is
            # paired with a known tick (avoids matplotlib's FixedFormatter warning).
            ax.set_xticks(np.arange(len(labels)) + 1)
            ax.set_xticklabels(labels, size=14)
        if x_axis is not None:
            ax.set_xlabel(x_axis, size=18)
        if y_axis is not None:
            ax.set_ylabel(y_axis, size=18)
        self._parse_legend(ax)
        ax.figure.tight_layout()

    def make(self, **kwargs):
        """Build the plot via :meth:`_make`, format its axes, and return ``(fig, ax)``.

        ``kwargs`` are forwarded unchanged to :meth:`_make`.
        """
        fig, ax = self._make(**kwargs)
        self._parse_figure_axis(ax, self.x_axis, self.y_axis, self.x_ticks)
        return fig, ax

    @abstractmethod
    def _make(self, **kwargs):
        """Draw the plot and return a ``(figure, axes)`` pair; implemented by subclasses."""