This commit is contained in:
Jan Kowalczyk
2025-09-27 16:34:52 +02:00
parent cfb77dccab
commit c270783225
6 changed files with 609 additions and 118 deletions

View File

@@ -8,11 +8,11 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.lines import Line2D
from scipy.stats import sem, t
# CHANGE THIS IMPORT IF YOUR LOADER MODULE NAME IS DIFFERENT
from plot_scripts.load_results import load_results_dataframe
from load_results import load_results_dataframe
from matplotlib.lines import Line2D
from scipy.stats import sem, t
# ---------------------------------
# Config
@@ -23,6 +23,10 @@ OUTPUT_DIR = Path("/home/fedex/mt/plots/results_semi_labels_comparison")
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
SEMI_REGIMES = [(0, 0), (50, 10), (500, 100)]
EVALS = ["exp_based", "manual_based"]
EVALS_LABELS = {
"exp_based": "Experiment-Based Labels",
"manual_based": "Manually-Labeled",
}
# Interp grids
ROC_GRID = np.linspace(0.0, 1.0, 200)
@@ -30,6 +34,10 @@ PRC_GRID = np.linspace(0.0, 1.0, 200)
# Baselines are duplicated across nets; use Efficient-only to avoid repetition
BASELINE_NET = "Efficient"
BASELINE_LABELS = {
"isoforest": "Isolation Forest",
"ocsvm": "One-Class SVM",
}
# Colors/styles
COLOR_BASELINES = {
@@ -147,12 +155,8 @@ def _select_rows(
return df.filter(pl.all_horizontal(exprs))
def _auc_list(sub: pl.DataFrame) -> list[float]:
return [x for x in sub.select("auc").to_series().to_list() if x is not None]
def _ap_list(sub: pl.DataFrame) -> list[float]:
return [x for x in sub.select("ap").to_series().to_list() if x is not None]
def _auc_list(sub: pl.DataFrame, kind: str) -> list[float]:
return [x for x in sub.select(f"{kind}_auc").to_series().to_list() if x is not None]
def _plot_panel(
@@ -165,7 +169,7 @@ def _plot_panel(
kind: str,
):
"""
Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + baselines (from Efficient).
Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + Baselines (from Efficient).
Legend entries include mean±CI of AUC/AP.
"""
ax.grid(True, alpha=0.3)
@@ -200,9 +204,9 @@ def _plot_panel(
continue
# Metric for legend
metric_vals = _auc_list(sub_b) if kind == "roc" else _ap_list(sub_b)
metric_vals = _auc_list(sub_b, kind)
m, ci = mean_ci(metric_vals)
lab = f"{model} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
lab = f"{BASELINE_LABELS[model]}\n(AUC={m:.3f}±{ci:.3f})"
color = COLOR_BASELINES[model]
h = ax.plot(grid, mean_y, lw=2, color=color, label=lab)[0]
@@ -230,9 +234,9 @@ def _plot_panel(
if np.all(np.isnan(mean_y)):
continue
metric_vals = _auc_list(sub_d) if kind == "roc" else _ap_list(sub_d)
metric_vals = _auc_list(sub_d, kind)
m, ci = mean_ci(metric_vals)
lab = f"DeepSAD {net_for_deepsad} semi {sn}/{sa} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
lab = f"DeepSAD {net_for_deepsad}{sn}/{sa}\n(AUC={m:.3f}±{ci:.3f})"
color = COLOR_REGIMES[regime]
ls = LINESTYLES[regime]
@@ -246,7 +250,7 @@ def _plot_panel(
ax.plot([0, 1], [0, 1], "k--", alpha=0.6, label="Chance")
# Legend
ax.legend(loc="lower right", fontsize=9, frameon=True)
ax.legend(loc="upper right", fontsize=9, frameon=True)
def make_figures_for_dim(
@@ -254,9 +258,11 @@ def make_figures_for_dim(
):
# ROC: 2×1
fig_roc, axes = plt.subplots(
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
nrows=2, ncols=1, figsize=(7, 10), constrained_layout=True
)
fig_roc.suptitle(
f"ROC — {EVALS_LABELS[eval_type]} — Latent Dim.={latent_dim}", fontsize=14
)
fig_roc.suptitle(f"ROC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
_plot_panel(
axes[0],
@@ -266,7 +272,7 @@ def make_figures_for_dim(
latent_dim=latent_dim,
kind="roc",
)
axes[0].set_title("DeepSAD (LeNet) + baselines")
axes[0].set_title("DeepSAD (LeNet) + Baselines")
_plot_panel(
axes[1],
@@ -276,7 +282,7 @@ def make_figures_for_dim(
latent_dim=latent_dim,
kind="roc",
)
axes[1].set_title("DeepSAD (Efficient) + baselines")
axes[1].set_title("DeepSAD (Efficient) + Baselines")
out_roc = out_dir / f"roc_{latent_dim}_{eval_type}.png"
fig_roc.savefig(out_roc, dpi=150, bbox_inches="tight")
@@ -284,9 +290,11 @@ def make_figures_for_dim(
# PRC: 2×1
fig_prc, axes = plt.subplots(
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
nrows=2, ncols=1, figsize=(7, 10), constrained_layout=True
)
fig_prc.suptitle(
f"PRC — {EVALS_LABELS[eval_type]} — Latent Dim.={latent_dim}", fontsize=14
)
fig_prc.suptitle(f"PRC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
_plot_panel(
axes[0],
@@ -296,7 +304,7 @@ def make_figures_for_dim(
latent_dim=latent_dim,
kind="prc",
)
axes[0].set_title("DeepSAD (LeNet) + baselines")
axes[0].set_title("DeepSAD (LeNet) + Baselines")
_plot_panel(
axes[1],
@@ -306,7 +314,7 @@ def make_figures_for_dim(
latent_dim=latent_dim,
kind="prc",
)
axes[1].set_title("DeepSAD (Efficient) + baselines")
axes[1].set_title("DeepSAD (Efficient) + Baselines")
out_prc = out_dir / f"prc_{latent_dim}_{eval_type}.png"
fig_prc.savefig(out_prc, dpi=150, bbox_inches="tight")