import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem, t
from sklearn.metrics import PrecisionRecallDisplay, auc


def confidence_interval(data, confidence=0.95):
    """Compute the mean and Student-t margin of error for a list of scores.

    Args:
        data: Sequence of per-fold scores.
        confidence: Two-sided confidence level of the interval.

    Returns:
        ``(mean, h)`` such that ``mean ± h`` is the confidence interval.
        ``h`` is NaN when fewer than two scores are given (the t-distribution
        needs at least one degree of freedom); both values are NaN for empty
        input.
    """
    n = len(data)
    if n == 0:
        return float("nan"), float("nan")
    mean = np.mean(data)
    if n < 2:
        # t.ppf with df <= 0 and sem of a single value are undefined.
        return mean, float("nan")
    # Standard error of the mean
    std_err = sem(data)
    # Confidence interval radius
    h = std_err * t.ppf((1 + confidence) / 2.0, n - 1)
    return mean, h


def _load_fold_prc(path):
    """Load one fold's pickled results; return ``(precision, recall, area)``.

    The pickle is expected to contain ``data["test_prc"]`` as a
    ``(precision, recall, thresholds)`` triple — presumably the output of
    ``sklearn.metrics.precision_recall_curve``; confirm against the code that
    writes these files. ``area`` is the trapezoidal area under the
    (recall, precision) curve computed by ``sklearn.metrics.auc``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    with path.open("rb") as f:
        data = pickle.load(f)
    precision, recall, _ = data["test_prc"]
    return precision, recall, auc(recall, precision)


def _interpolate_precision(precision, recall, recall_grid):
    """Sample a PR curve's precision at each recall value in ``recall_grid``.

    ``np.interp`` needs its x-coordinates (recall) in increasing order, while
    ``precision_recall_curve`` returns recall in decreasing order. The arrays
    are reversed and then stable-sorted by recall: for sklearn output this is
    a plain reversal, and arbitrarily ordered input is still handled.
    """
    precision = np.asarray(precision)[::-1]
    recall = np.asarray(recall)[::-1]
    order = np.argsort(recall, kind="stable")
    return np.interp(recall_grid, recall[order], precision[order])


def _plot_folds(ax, fold_curves, color, model_name):
    """Draw every fold's PR curve faintly; only the first fold gets a legend label."""
    for i, (precision, recall) in enumerate(fold_curves):
        disp = PrecisionRecallDisplay(precision=precision, recall=recall)
        # Label only the first fold to avoid legend clutter
        disp.plot(
            ax=ax,
            color=color,
            alpha=0.3,
            label=f"{model_name} Fold {i + 1}" if i == 0 else None,
        )


def _plot_mean(ax, recall_grid, mean_prec, std_prec, color, label):
    """Draw the mean PR curve with a ±1 std band clipped to the [0, 1] range."""
    disp = PrecisionRecallDisplay(precision=mean_prec, recall=recall_grid)
    disp.plot(ax=ax, color=color, label=label)
    ax.fill_between(
        recall_grid,
        np.clip(mean_prec - std_prec, 0.0, 1.0),
        np.clip(mean_prec + std_prec, 0.0, 1.0),
        color=color,
        alpha=0.2,
    )


# 1) LOAD PRECISION-RECALL DATA
N_FOLDS = 5  # Number of cross-validation folds (adjust if you have a different number)

results_path = Path(
    "/home/fedex/mt/projects/thesis-kowalczyk-jan/Deep-SAD-PyTorch/log/DeepSAD/subter_kfold_800_3000_new"
)

prc_data = []  # Stores (precision, recall) for each DeepSAD fold
ap_scores = []  # Area under the PR curve (trapezoidal AP estimate) for each DeepSAD fold
isoforest_prc_data = []  # Stores (precision, recall) for each IsoForest fold
isoforest_ap_scores = []  # Area under the PR curve for each IsoForest fold

for i in range(N_FOLDS):
    precision, recall, area = _load_fold_prc(results_path / f"results_{i}.pkl")
    prc_data.append((precision, recall))
    ap_scores.append(area)

    precision, recall, area = _load_fold_prc(
        results_path / f"results_isoforest_{i}.pkl"
    )
    isoforest_prc_data.append((precision, recall))
    isoforest_ap_scores.append(area)

# 2) CALCULATE PER-FOLD STATISTICS
# Note: the scores are trapezoidal AUC-PR values, which can differ slightly
# from sklearn's step-wise average_precision_score.
mean_ap, ap_ci = confidence_interval(ap_scores)
isoforest_mean_ap, isoforest_ap_ci = confidence_interval(isoforest_ap_scores)

# 3) INTERPOLATE EACH FOLD'S PRC ON A COMMON RECALL GRID FOR MEAN CURVE
mean_recall = np.linspace(0, 1, 100)

# -- DeepSAD
deep_sad_precisions_interp = [
    _interpolate_precision(precision, recall, mean_recall)
    for precision, recall in prc_data
]
mean_precision = np.mean(deep_sad_precisions_interp, axis=0)
std_precision = np.std(deep_sad_precisions_interp, axis=0)

# -- IsoForest
isoforest_precisions_interp = [
    _interpolate_precision(precision, recall, mean_recall)
    for precision, recall in isoforest_prc_data
]
isoforest_mean_precision = np.mean(isoforest_precisions_interp, axis=0)
isoforest_std_precision = np.std(isoforest_precisions_interp, axis=0)

# 4) CREATE PLOT USING PrecisionRecallDisplay
fig, ax = plt.subplots(figsize=(8, 6))

# (A) + (B) Per-fold curves; drawn before the means so the means sit on top.
_plot_folds(ax, prc_data, "b", "DeepSAD")
_plot_folds(ax, isoforest_prc_data, "r", "IsoForest")

# (C) Mean curve for DeepSAD
_plot_mean(
    ax,
    mean_recall,
    mean_precision,
    std_precision,
    "b",
    f"DeepSAD Mean PR (AP={mean_ap:.2f} ± {ap_ci:.2f})",
)

# (D) Mean curve for IsoForest
_plot_mean(
    ax,
    mean_recall,
    isoforest_mean_precision,
    isoforest_std_precision,
    "r",
    f"IsoForest Mean PR (AP={isoforest_mean_ap:.2f} ± {isoforest_ap_ci:.2f})",
)

# 5) FINAL PLOT ADJUSTMENTS
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title(f"Precision-Recall Curve with {N_FOLDS}-Fold Cross-Validation")
ax.legend(loc="upper right")
plt.savefig("pr_curve_800_3000_2.png")