Source code for ablator.analysis.plot.cat_plot
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from ablator.analysis.plot import Plot
from ablator.main.configs import Optim
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Categorical(Plot):
    """Base class for plots whose attribute takes categorical values.

    On construction the metric values are grouped by attribute category and
    stored in ``attribute_metric_map``; rows whose attribute is missing are
    collected in a dedicated group.
    """

    DATA_TYPE = "categorical"

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the parent plot and build the per-category metric map.

        All arguments are forwarded unchanged to the parent initializer,
        which is expected to have set ``self.metric`` and ``self.attributes``
        by the time it returns.
        """
        super().__init__(*args, **kwargs)
        self.attribute_metric_map = self._make_attribute_metric_map(
            self.metric, self.attributes
        )

    @classmethod
    def _make_attribute_metric_map(
        cls,
        metric: pd.Series,
        attributes: pd.Series,
    ) -> dict[str, pd.Series]:
        """Group ``metric`` by the categorical values found in ``attributes``.

        Parameters
        ----------
        metric : pd.Series
            Metric values; indexed with boolean masks derived from
            ``attributes``.
        attributes : pd.Series
            Categorical value of each row. Missing entries (``None`` or
            ``np.nan``) are grouped together.

        Returns
        -------
        dict
            Maps each category to the metric values of its rows. The
            missing-value group, when present, comes first under the key
            ``"None"`` (or ``"Type(None)"`` if ``"None"`` is itself a
            category); the remaining categories follow in sorted order.

        Raises
        ------
        AssertionError
            If both ``"None"`` and ``"Type(None)"`` are real categories while
            missing values are also present, so no unambiguous name is left
            for the missing-value group.
        """
        missing = attributes.isna()
        # Drop only genuinely missing entries. A truthiness filter would also
        # discard legitimate falsy categories such as 0, False or "".
        unique_values = list(attributes[~missing].unique())
        metrics: dict[str, pd.Series] = {}
        if missing.any():
            none_name = "None"
            if "None" in unique_values:
                logger.warning(
                    "`None` name is present as categorical value as well as np.nan."
                )
                none_name = "Type(None)"
            # Raised explicitly (rather than via `assert`) so the check still
            # runs when Python is started with -O.
            if none_name in unique_values:
                raise AssertionError(
                    f"{none_name} is also present as a categorical. Highly "
                    "unlikely it is by accident."
                )
            metrics[none_name] = metric[missing]
        for u in sorted(unique_values):
            metrics[u] = metric[attributes == u]
        return metrics

    def _sort_vals_obj(self, vals: pd.Series, obj_fn: Optim) -> np.ndarray:
        """Return ``vals`` ordered from best to worst for the given objective.

        Values are sorted ascending when ``Optim(obj_fn)`` equals
        ``Optim.min`` and descending otherwise; NaNs are always placed last.
        """
        ascending = Optim(obj_fn) == Optim.min
        return vals.sort_values(ascending=ascending, na_position="last").values
class ViolinPlot(Categorical):
    """Violin plot of the metric distribution for each categorical value."""

    def __init__(self, *args, **kwargs) -> None:
        """Apply the seaborn theme and initialize the categorical plot.

        All arguments are forwarded unchanged to the parent initializer.
        """
        # These seaborn calls change the global plotting theme.
        sns.set()
        sns.set_style("whitegrid")
        # NOTE(review): assigned before the parent initializer runs,
        # presumably because the parent reads it; `_make_figure` below uses
        # its own hard-coded size instead -- confirm which one is intended.
        self.figsize = (8, 4)
        super().__init__(*args, **kwargs)

    def _make_figure(self, ax: Axes | None = None) -> tuple[Figure | None, Axes]:
        """Return the figure and axes to draw on.

        When ``ax`` is given it is used as-is and the returned figure is
        ``None``; otherwise a new 10x8 figure with a single subplot is
        created.
        """
        figure = None
        if ax is None:
            figure = plt.figure(figsize=(10, 8))
            ax = figure.add_subplot(1, 1, 1)
        return figure, ax

    def _make(
        self,
        **kwargs,
    ):
        """Draw one violin per category and label each with its statistics.

        Each x tick label shows the mean and the best metric value of the
        category (best according to ``self.metrics_obj_fn``) followed by the
        category name. ``kwargs`` is accepted but not used.

        Returns
        -------
        tuple
            ``(self.figure, self.ax)``.
        """
        groups = self.attribute_metric_map
        # `data` is passed by keyword: it is only the first positional
        # parameter of `violinplot` in recent seaborn releases.
        sns.violinplot(
            data=[vals.values for vals in groups.values()],
            ax=self.ax,
            palette="Set3",
        )
        obj_fn = self.metrics_obj_fn
        labels = []
        for name, vals in groups.items():
            mean = np.mean(vals)
            # First element of the best-to-worst ordering is the top value.
            best = self._sort_vals_obj(vals, obj_fn)[0]
            labels.append(f"Mean: {mean:.2e}\nBest: {best:.2e}\n{name}")
        self.ax.set_xticks(
            np.arange(len(groups)),
            labels=labels,
        )
        sns.despine(left=True, bottom=True)
        return self.figure, self.ax

    def _parse_legend(self, ax):
        """Do nothing; this plot performs no legend processing."""
        pass