wip
This commit is contained in:
@@ -8,11 +8,11 @@ from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from matplotlib.lines import Line2D
|
||||
from scipy.stats import sem, t
|
||||
|
||||
# CHANGE THIS IMPORT IF YOUR LOADER MODULE NAME IS DIFFERENT
|
||||
from plot_scripts.load_results import load_results_dataframe
|
||||
from load_results import load_results_dataframe
|
||||
from matplotlib.lines import Line2D
|
||||
from scipy.stats import sem, t
|
||||
|
||||
# ---------------------------------
|
||||
# Config
|
||||
@@ -23,6 +23,10 @@ OUTPUT_DIR = Path("/home/fedex/mt/plots/results_semi_labels_comparison")
|
||||
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
|
||||
SEMI_REGIMES = [(0, 0), (50, 10), (500, 100)]
|
||||
EVALS = ["exp_based", "manual_based"]
|
||||
EVALS_LABELS = {
|
||||
"exp_based": "Experiment-Based Labels",
|
||||
"manual_based": "Manually-Labeled",
|
||||
}
|
||||
|
||||
# Interp grids
|
||||
ROC_GRID = np.linspace(0.0, 1.0, 200)
|
||||
@@ -30,6 +34,10 @@ PRC_GRID = np.linspace(0.0, 1.0, 200)
|
||||
|
||||
# Baselines are duplicated across nets; use Efficient-only to avoid repetition
|
||||
BASELINE_NET = "Efficient"
|
||||
BASELINE_LABELS = {
|
||||
"isoforest": "Isolation Forest",
|
||||
"ocsvm": "One-Class SVM",
|
||||
}
|
||||
|
||||
# Colors/styles
|
||||
COLOR_BASELINES = {
|
||||
@@ -147,12 +155,8 @@ def _select_rows(
|
||||
return df.filter(pl.all_horizontal(exprs))
|
||||
|
||||
|
||||
def _auc_list(sub: pl.DataFrame) -> list[float]:
|
||||
return [x for x in sub.select("auc").to_series().to_list() if x is not None]
|
||||
|
||||
|
||||
def _ap_list(sub: pl.DataFrame) -> list[float]:
|
||||
return [x for x in sub.select("ap").to_series().to_list() if x is not None]
|
||||
def _auc_list(sub: pl.DataFrame, kind: str) -> list[float]:
|
||||
return [x for x in sub.select(f"{kind}_auc").to_series().to_list() if x is not None]
|
||||
|
||||
|
||||
def _plot_panel(
|
||||
@@ -165,7 +169,7 @@ def _plot_panel(
|
||||
kind: str,
|
||||
):
|
||||
"""
|
||||
Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + baselines (from Efficient).
|
||||
Plot one panel: DeepSAD (net_for_deepsad) with 3 regimes + Baselines (from Efficient).
|
||||
Legend entries include mean±CI of AUC/AP.
|
||||
"""
|
||||
ax.grid(True, alpha=0.3)
|
||||
@@ -200,9 +204,9 @@ def _plot_panel(
|
||||
continue
|
||||
|
||||
# Metric for legend
|
||||
metric_vals = _auc_list(sub_b) if kind == "roc" else _ap_list(sub_b)
|
||||
metric_vals = _auc_list(sub_b, kind)
|
||||
m, ci = mean_ci(metric_vals)
|
||||
lab = f"{model} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
|
||||
lab = f"{BASELINE_LABELS[model]}\n(AUC={m:.3f}±{ci:.3f})"
|
||||
|
||||
color = COLOR_BASELINES[model]
|
||||
h = ax.plot(grid, mean_y, lw=2, color=color, label=lab)[0]
|
||||
@@ -230,9 +234,9 @@ def _plot_panel(
|
||||
if np.all(np.isnan(mean_y)):
|
||||
continue
|
||||
|
||||
metric_vals = _auc_list(sub_d) if kind == "roc" else _ap_list(sub_d)
|
||||
metric_vals = _auc_list(sub_d, kind)
|
||||
m, ci = mean_ci(metric_vals)
|
||||
lab = f"DeepSAD {net_for_deepsad} — semi {sn}/{sa} ({'AUC' if kind == 'roc' else 'AP'}={m:.3f}±{ci:.3f})"
|
||||
lab = f"DeepSAD {net_for_deepsad} — {sn}/{sa}\n(AUC={m:.3f}±{ci:.3f})"
|
||||
|
||||
color = COLOR_REGIMES[regime]
|
||||
ls = LINESTYLES[regime]
|
||||
@@ -246,7 +250,7 @@ def _plot_panel(
|
||||
ax.plot([0, 1], [0, 1], "k--", alpha=0.6, label="Chance")
|
||||
|
||||
# Legend
|
||||
ax.legend(loc="lower right", fontsize=9, frameon=True)
|
||||
ax.legend(loc="upper right", fontsize=9, frameon=True)
|
||||
|
||||
|
||||
def make_figures_for_dim(
|
||||
@@ -254,9 +258,11 @@ def make_figures_for_dim(
|
||||
):
|
||||
# ROC: 2×1
|
||||
fig_roc, axes = plt.subplots(
|
||||
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
|
||||
nrows=2, ncols=1, figsize=(7, 10), constrained_layout=True
|
||||
)
|
||||
fig_roc.suptitle(
|
||||
f"ROC — {EVALS_LABELS[eval_type]} — Latent Dim.={latent_dim}", fontsize=14
|
||||
)
|
||||
fig_roc.suptitle(f"ROC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
|
||||
|
||||
_plot_panel(
|
||||
axes[0],
|
||||
@@ -266,7 +272,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="roc",
|
||||
)
|
||||
axes[0].set_title("DeepSAD (LeNet) + baselines")
|
||||
axes[0].set_title("DeepSAD (LeNet) + Baselines")
|
||||
|
||||
_plot_panel(
|
||||
axes[1],
|
||||
@@ -276,7 +282,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="roc",
|
||||
)
|
||||
axes[1].set_title("DeepSAD (Efficient) + baselines")
|
||||
axes[1].set_title("DeepSAD (Efficient) + Baselines")
|
||||
|
||||
out_roc = out_dir / f"roc_{latent_dim}_{eval_type}.png"
|
||||
fig_roc.savefig(out_roc, dpi=150, bbox_inches="tight")
|
||||
@@ -284,9 +290,11 @@ def make_figures_for_dim(
|
||||
|
||||
# PRC: 2×1
|
||||
fig_prc, axes = plt.subplots(
|
||||
nrows=1, ncols=2, figsize=(14, 5), constrained_layout=True
|
||||
nrows=2, ncols=1, figsize=(7, 10), constrained_layout=True
|
||||
)
|
||||
fig_prc.suptitle(
|
||||
f"PRC — {EVALS_LABELS[eval_type]} — Latent Dim.={latent_dim}", fontsize=14
|
||||
)
|
||||
fig_prc.suptitle(f"PRC — {eval_type} — latent_dim={latent_dim}", fontsize=14)
|
||||
|
||||
_plot_panel(
|
||||
axes[0],
|
||||
@@ -296,7 +304,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="prc",
|
||||
)
|
||||
axes[0].set_title("DeepSAD (LeNet) + baselines")
|
||||
axes[0].set_title("DeepSAD (LeNet) + Baselines")
|
||||
|
||||
_plot_panel(
|
||||
axes[1],
|
||||
@@ -306,7 +314,7 @@ def make_figures_for_dim(
|
||||
latent_dim=latent_dim,
|
||||
kind="prc",
|
||||
)
|
||||
axes[1].set_title("DeepSAD (Efficient) + baselines")
|
||||
axes[1].set_title("DeepSAD (Efficient) + Baselines")
|
||||
|
||||
out_prc = out_dir / f"prc_{latent_dim}_{eval_type}.png"
|
||||
fig_prc.savefig(out_prc, dpi=150, bbox_inches="tight")
|
||||
|
||||
Reference in New Issue
Block a user