# mt/tools/plot_scripts/results_latent_space_tables.py
# Last commit: 95867bde7a "table plot" by Jan Kowalczyk, 2025-09-17 11:07:07 +02:00
# (Original file: 256 lines, 8.1 KiB. The repository viewer flagged this file
# as containing ambiguous Unicode characters that may be confused with other
# characters; verify that any non-ASCII characters are intentional.)
from __future__ import annotations
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import polars as pl
# CHANGE THIS IMPORT IF YOUR LOADER MODULE IS NAMED DIFFERENTLY
from load_results import load_results_dataframe
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Experiments root handed to the results loader.
ROOT = Path("/home/fedex/mt/results/copy")
# Destination for the generated .tex tables (archive/ and latest/ live here).
OUTPUT_DIR = Path("/home/fedex/mt/plots/results_latent_space_tables")

# Semi-supervised labeling regimes as (semi_normals, semi_anomalous) pairs.
SEMI_LABELING_REGIMES: list[tuple[int, int]] = [
    (0, 0),
    (50, 10),
    (500, 100),
]

# Evaluation variants; one table is emitted per (eval, regime) combination.
EVALS: list[str] = [
    "exp_based",
    "manual_based",
]

# Table rows, top to bottom.
LATENT_DIMS: list[int] = [32, 64, 128, 256, 512, 768, 1024]

# Table columns, left to right, as (model, net_label) pairs. DeepSAD appears
# once per network backbone; the IsolationForest / OC-SVM baselines are shown
# for the Efficient backbone only.
METHOD_COLUMNS: list[tuple[str, str]] = [
    ("deepsad", "LeNet"),
    ("deepsad", "Efficient"),
    ("isoforest", "Efficient"),
    ("ocsvm", "Efficient"),
]

# Number formatting.
DECIMALS: int = 3  # digits after the decimal point for both mean and std
STD_FMT: str = r"\textpm"  # LaTeX token placed between mean and std
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _with_net_label(df: pl.DataFrame) -> pl.DataFrame:
    """Return *df* with an extra ``net_label`` column derived from ``network``.

    Matching is case-insensitive on the string form of ``network``:
    values containing "lenet" become "LeNet", otherwise values containing
    "efficient" become "Efficient", and anything else is kept verbatim
    (as a string).
    """
    network = pl.col("network").cast(pl.Utf8)
    lowered = network.str.to_lowercase()
    label = (
        pl.when(lowered.str.contains("lenet"))
        .then(pl.lit("LeNet"))
        .when(lowered.str.contains("efficient"))
        .then(pl.lit("Efficient"))
        .otherwise(network)
    )
    return df.with_columns(label.alias("net_label"))
def _filter_base(
    df: pl.DataFrame,
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> pl.DataFrame:
    """Restrict *df* to one semi-labeling regime and evaluation type.

    Only rows whose ``latent_dim`` is in ``LATENT_DIMS`` and whose ``model``
    is deepsad / isoforest / ocsvm are kept, and the result is projected to
    the columns needed for aggregation.
    """
    predicates = [
        pl.col("semi_normals") == semi_normals,
        pl.col("semi_anomalous") == semi_anomalous,
        pl.col("eval") == eval_type,
        pl.col("latent_dim").is_in(LATENT_DIMS),
        pl.col("model").is_in(["deepsad", "isoforest", "ocsvm"]),
    ]
    # Combine left-to-right with '&', exactly as a chained a & b & c ... would.
    mask = predicates[0]
    for predicate in predicates[1:]:
        mask = mask & predicate

    wanted_columns = ["model", "net_label", "latent_dim", "fold", "auc"]
    return df.filter(mask).select(*wanted_columns)
def _format_mean_std(mean: float | None, std: float | None) -> str:
    """Render a table cell as ``"<mean> STD_FMT <std>"`` using ``DECIMALS`` digits.

    Returns ``"--"`` when the mean is missing (None or NaN) and only the mean
    when the std is missing (None or NaN).
    """
    # 'x != x' is True only for NaN.
    mean_missing = mean is None or mean != mean
    if mean_missing:
        return "--"

    text = f"{mean:.{DECIMALS}f}"
    std_missing = std is None or std != std
    if not std_missing:
        text += f" {STD_FMT} {std:.{DECIMALS}f}"
    return text
@dataclass(frozen=True)
class Cell:
mean: float | None
std: float | None
def _compute_cells(df: pl.DataFrame) -> dict[tuple[int, str, str], Cell]:
    """Aggregate fold-level AUCs into one ``Cell`` per table position.

    Groups *df* by (latent_dim, model, net_label), takes the mean and std of
    ``auc`` within each group, and returns a dict keyed by
    ``(int(latent_dim), str(model), str(net_label))``. An empty frame yields
    an empty dict.
    """
    if df.is_empty():
        return {}

    stats = df.group_by(["latent_dim", "model", "net_label"]).agg(
        pl.col("auc").mean().alias("mean_auc"),
        pl.col("auc").std().alias("std_auc"),
    )
    return {
        (int(rec["latent_dim"]), str(rec["model"]), str(rec["net_label"])): Cell(
            mean=rec.get("mean_auc"), std=rec.get("std_auc")
        )
        for rec in stats.iter_rows(named=True)
    }
def _bold_best_in_row(values: list[float | None]) -> list[bool]:
"""Return a mask of which entries are the (tied) maximum among non-None values."""
clean = [(v if (v is not None and v == v) else None) for v in values]
finite_vals = [v for v in clean if v is not None]
if not finite_vals:
return [False] * len(values)
maxv = max(finite_vals)
return [(v is not None and abs(v - maxv) < 1e-12) for v in clean]
def _latex_table(
    cells: dict[tuple[int, str, str], Cell],
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> str:
    """Render one AUC results table as LaTeX source.

    Rows follow ``LATENT_DIMS`` and columns follow ``METHOD_COLUMNS``. Each
    cell shows the fold mean and std of the AUC via ``_format_mean_std``;
    positions missing from *cells* render as ``--``. In every row the highest
    mean (including ties) is set in bold.

    The output uses ``tabularx`` with a ``Y`` column type and booktabs rules
    (``\\toprule`` etc.); presumably ``Y`` is defined via ``\\newcolumntype``
    in the including document's preamble.

    Args:
        cells: Mapping (latent_dim, model, net_label) -> Cell, as produced by
            ``_compute_cells``.
        eval_type: "exp_based" selects the experiment-based caption wording;
            any other value is described as handlabeling-based.
        semi_normals: Number of labeled normal samples (caption/label only).
        semi_anomalous: Number of labeled anomalous samples (caption/label only).

    Returns:
        The LaTeX ``table`` environment as a string without a trailing newline.
    """
    # Headers are derived from METHOD_COLUMNS (instead of a separate hard-coded
    # list) so that header count, column spec and row cells always agree.
    known_headers = {
        ("deepsad", "LeNet"): r"\textbf{DeepSAD (LeNet)}",
        ("deepsad", "Efficient"): r"\textbf{DeepSAD (Efficient)}",
        ("isoforest", "Efficient"): r"\textbf{IsolationForest}",
        ("ocsvm", "Efficient"): r"\textbf{OC\text{-}SVM}",
    }
    header_cols = [
        known_headers.get((model, net), rf"\textbf{{{model} ({net})}}")
        for model, net in METHOD_COLUMNS
    ]
    # One centered column for the latent dim plus one Y column per method.
    col_spec = "c" + "Y" * len(header_cols)

    eval_type_str = (
        "experiment-based evaluation"
        if eval_type == "exp_based"
        else "handlabeling-based evaluation"
    )

    lines: list[str] = [r"\begin{table}[t]", r"\centering"]
    lines.append(
        rf"\caption{{AUC (mean {STD_FMT} std) across 5 folds for \texttt{{{eval_type_str}}}, "
        rf"semi-labeling regime: {semi_normals} normal samples and {semi_anomalous} anomalous samples.}}"
    )
    lines.append(rf"\label{{tab:auc_{eval_type}_semi_{semi_normals}_{semi_anomalous}}}")
    lines.append(rf"\begin{{tabularx}}{{\textwidth}}{{{col_spec}}}")
    lines.append(r"\toprule")
    lines.append(r"\textbf{Latent Dim.} & " + " & ".join(header_cols) + r" \\")
    lines.append(r"\midrule")

    for dim in LATENT_DIMS:
        row_cells = [
            cells.get((dim, model, net), Cell(None, None))
            for model, net in METHOD_COLUMNS
        ]
        texts = [_format_mean_std(c.mean, c.std) for c in row_cells]
        bold_mask = _bold_best_in_row([c.mean for c in row_cells])
        rendered = [
            rf"\textbf{{{text}}}" if do_bold and text != "--" else text
            for text, do_bold in zip(texts, bold_mask)
        ]
        lines.append(f"{dim} & " + " & ".join(rendered) + r" \\")

    lines.append(r"\bottomrule")
    lines.append(r"\end{tabularx}")
    lines.append(r"\end{table}")
    return "\n".join(lines)
def _make_run_dir() -> Path:
    """Create OUTPUT_DIR/archive/<timestamp>/ (with parents) and return it."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    archive_root = OUTPUT_DIR / "archive"
    archive_root.mkdir(parents=True, exist_ok=True)
    run_dir = archive_root / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


def _write_table(
    df: "pl.DataFrame",
    run_dir: Path,
    *,
    eval_type: str,
    semi_normals: int,
    semi_anomalous: int,
) -> Path:
    """Build the table for one (regime, eval) pair, write it into *run_dir*, return its path."""
    subset = _filter_base(
        df,
        eval_type=eval_type,
        semi_normals=semi_normals,
        semi_anomalous=semi_anomalous,
    )
    # Baselines (isoforest/ocsvm) are only reported for the Efficient backbone;
    # DeepSAD rows are kept for every backbone.
    is_baseline = pl.col("model").is_in(["isoforest", "ocsvm"])
    subset = subset.filter(
        pl.when(is_baseline).then(pl.col("net_label") == "Efficient").otherwise(True)
    )
    tex = _latex_table(
        _compute_cells(subset),
        eval_type=eval_type,
        semi_normals=semi_normals,
        semi_anomalous=semi_anomalous,
    )
    target = run_dir / f"auc_table_{eval_type}_semi_{semi_normals}_{semi_anomalous}.tex"
    target.write_text(tex, encoding="utf-8")
    return target


def _refresh_latest(run_dir: Path) -> Path:
    """Replace the plain files in OUTPUT_DIR/latest/ with copies of those in *run_dir*."""
    latest_dir = OUTPUT_DIR / "latest"
    latest_dir.mkdir(exist_ok=True, parents=True)
    for stale in latest_dir.iterdir():
        if stale.is_file():
            stale.unlink()
    for fresh in run_dir.iterdir():
        if fresh.is_file():
            shutil.copy2(fresh, latest_dir / fresh.name)
    return latest_dir


def main():
    """Generate every AUC table, archive them with this script, and mirror to latest/."""
    # Cache behavior is delegated to the loader.
    df = _with_net_label(load_results_dataframe(ROOT, allow_cache=True))

    run_dir = _make_run_dir()
    written = [
        _write_table(
            df,
            run_dir,
            eval_type=eval_type,
            semi_normals=n_normal,
            semi_anomalous=n_anomalous,
        )
        for n_normal, n_anomalous in SEMI_LABELING_REGIMES
        for eval_type in EVALS
    ]

    # Keep a copy of the exact code that produced these outputs.
    this_script = Path(__file__)
    shutil.copy2(this_script, run_dir / this_script.name)

    latest_dir = _refresh_latest(run_dir)

    print(f"Saved tables to: {run_dir}")
    print(f"Also updated: {latest_dir}")
    for p in written:
        print(f" - {p.name}")


if __name__ == "__main__":
    main()