changed from thundersvm to sklearn ocsvm

This commit is contained in:
Michael eder
2025-06-25 10:18:31 +02:00
parent c552173cb2
commit 207eed14ef
2 changed files with 11 additions and 12 deletions

View File

@@ -11,7 +11,7 @@ from sklearn.metrics import (
roc_auc_score,
roc_curve,
)
from thundersvm import OneClassSVM
from sklearn.svm import OneClassSVM
from base.base_dataset import BaseADDataset
from networks.main import build_autoencoder
@@ -27,7 +27,7 @@ class OCSVM(object):
self.rho = None
self.gamma = None
self.model = OneClassSVM(kernel=kernel, nu=nu, verbose=True, max_mem_size=4048)
self.model = OneClassSVM(kernel=kernel, nu=nu)
self.hybrid = hybrid
self.latent_space_dim = latent_space_dim
@@ -166,8 +166,6 @@ class OCSVM(object):
kernel=self.kernel,
nu=self.nu,
gamma=gamma,
verbose=True,
max_mem_size=4048,
)
# Train
@@ -198,7 +196,7 @@ class OCSVM(object):
# If hybrid, also train a model with linear kernel
if self.hybrid:
self.linear_model = OneClassSVM(
kernel="linear", nu=self.nu, max_mem_size=4048
kernel="linear", nu=self.nu
)
start_time = time.time()
self.linear_model.fit(X)
@@ -479,14 +477,15 @@ class OCSVM(object):
self.ae_net.to(torch.device(device))
self.ae_net.eval()
def save_model(self, export_path: Path):
def save_model(self, export_path):
"""Save OC-SVM model to export_path."""
self.model.save_to_file(export_path)
with open(export_path, "wb") as fp:
pickle.dump(self.model, fp)
def load_model(self, import_path: Path):
def load_model(self, import_path):
"""Load OC-SVM model from import_path."""
self.model.save_to_file(import_path)
pass
with open(import_path, "rb") as fp:
self.model = pickle.load(fp)
def save_results(self, export_pkl):
with open(export_pkl, "wb") as fp:

View File

@@ -597,7 +597,7 @@ def main(
deepSAD.save_model(export_model=xp_path + "/model_deepsad.tar")
if train_ocsvm:
ocsvm.save_results(export_pkl=xp_path + "/results_ocsvm.pkl")
ocsvm.save_model(export_path=xp_path + "/model_ocsvm.bin")
ocsvm.save_model(export_path=xp_path + "/model_ocsvm.pkl")
if train_isoforest:
Isoforest.save_results(
export_pkl=xp_path + "/results_isoforest.pkl"
@@ -616,7 +616,7 @@ def main(
export_pkl=xp_path + f"/results_ocsvm_{fold_idx}.pkl"
)
ocsvm.save_model(
export_path=xp_path + f"/model_ocsvm_{fold_idx}.bin"
export_path=xp_path + f"/model_ocsvm_{fold_idx}.pkl"
)
if train_isoforest:
Isoforest.save_results(