data loading and plotting for results wip
This commit is contained in:
71
tools/demo_loaded_data.py
Normal file
71
tools/demo_loaded_data.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
|
||||
from load_results import load_pretraining_results_dataframe, load_results_dataframe
|
||||
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Example “analysis-ready” queries (Polars idioms)
|
||||
# ------------------------------------------------------------
|
||||
def demo_queries(df: pl.DataFrame):
    """Run three example queries against the results frame.

    Args:
        df: Results frame. The queries read the columns ``network``,
            ``latent_dim``, ``semi_normals``, ``semi_anomalous``, ``eval``,
            ``model``, ``auc``, ``fold`` and ``roc_curve`` (a struct with
            ``fpr``, ``tpr`` and ``thr`` fields).

    Returns:
        A tuple ``(q1, q2, roc_subset)``:

        * ``q1`` - mean ``auc`` per ``model`` (column ``mean_auc``), sorted
          descending, for LeNet / latent_dim 1024 / ``exp_based`` rows.
        * ``q2`` - ``auc`` of the ``deepsad`` model pivoted to one row per
          ``fold`` and one column per ``latent_dim``, sorted by ``fold``.
        * ``roc_subset`` - ``fold`` and ``roc_curve`` of the selected
          ``ocsvm`` rows, plus ``fpr`` / ``tpr`` / ``thr`` columns taken from
          the struct.
    """
    # Every query restricts to rows where both semi_* columns equal 0.
    no_semi = (pl.col("semi_normals") == 0) & (pl.col("semi_anomalous") == 0)

    # --- q1: built lazily, materialised with a single collect() ------------
    lenet_1024_exp = (
        (pl.col("network") == "LeNet")
        & (pl.col("latent_dim") == 1024)
        & no_semi
        & (pl.col("eval") == "exp_based")
    )
    mean_auc = pl.col("auc").mean().alias("mean_auc")
    q1 = (
        df.lazy()
        .filter(lenet_1024_exp)
        .group_by("model")
        .agg(mean_auc)
        .sort("mean_auc", descending=True)
        .collect()
    )

    # --- q2: eager filter, then pivot --------------------------------------
    # The pivot runs on the eager frame (original author's note: LazyFrame
    # has no .pivot).
    deepsad_lenet_exp = (
        (pl.col("model") == "deepsad")
        & (pl.col("eval") == "exp_based")
        & (pl.col("network") == "LeNet")
        & no_semi
    )
    per_fold = df.filter(deepsad_lenet_exp).select("fold", "latent_dim", "auc")
    # NOTE(review): newer Polars releases name the `columns` parameter `on`;
    # confirm against the pinned Polars version before changing it.
    wide = per_fold.pivot(
        values="auc",
        index="fold",
        columns="latent_dim",
        aggregate_function="first",  # or "mean" if duplicates exist
    )
    q2 = wide.sort("fold")

    # --- roc_subset: eager filter/select, then lift struct fields ----------
    ocsvm_efficient_manual = (
        (pl.col("model") == "ocsvm")
        & (pl.col("eval") == "manual_based")
        & (pl.col("network") == "efficient")
        & (pl.col("latent_dim") == 1024)
        & no_semi
    )
    # Each struct field becomes a top-level column of the same name; the
    # original `roc_curve` column is kept and nothing is exploded here.
    curve_fields = [
        pl.col("roc_curve").struct.field(name).alias(name)
        for name in ("fpr", "tpr", "thr")
    ]
    roc_subset = (
        df.filter(ocsvm_efficient_manual)
        .select("fold", "roc_curve")
        .with_columns(curve_fields)
    )

    return q1, q2, roc_subset
|
||||
|
||||
|
||||
def main():
    """Load the results frame from the fixed results directory and run the demo queries.

    The tuple returned by ``demo_queries`` is not used here.
    """
    results_dir = Path("/home/fedex/mt/results/done")
    # allow_cache=True is passed through to the loader; presumably it lets
    # the loader reuse a cached frame - verify in load_results.
    demo_queries(load_results_dataframe(results_dir, allow_cache=True))
|
||||
|
||||
|
||||
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user