retest implemented and fixed missing center in save data
This commit is contained in:
192
Deep-SAD-PyTorch/flake.lock
generated
192
Deep-SAD-PyTorch/flake.lock
generated
@@ -1,192 +0,0 @@
|
||||
{
|
||||
"nodes": {
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_2": {
|
||||
"inputs": {
|
||||
"systems": "systems_2"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix-github-actions": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"poetry2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1703863825,
|
||||
"narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=",
|
||||
"owner": "nix-community",
|
||||
"repo": "nix-github-actions",
|
||||
"rev": "5163432afc817cf8bd1f031418d1869e4c9d5547",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "nix-github-actions",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1719327525,
|
||||
"narHash": "sha256-fPWiFM4aYbK9zGTt3KJ9CwX//iyElRiNHWNj2hk3i0E=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "191a3fd9786d09c8d82e89ed68c4463e7be09b3e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable-small",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-newest": {
|
||||
"locked": {
|
||||
"lastModified": 1749285348,
|
||||
"narHash": "sha256-frdhQvPbmDYaScPFiCnfdh3B/Vh81Uuoo0w5TkWmmjU=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "3e3afe5174c561dee0df6f2c2b2236990146329f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"poetry2nix": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils_2",
|
||||
"nix-github-actions": "nix-github-actions",
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"systems": "systems_3",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1719358925,
|
||||
"narHash": "sha256-ZV/2YB7nyeYCsDm6EMH0EKtlpxuu2ImEd5WrlceNwRE=",
|
||||
"owner": "nix-community",
|
||||
"repo": "poetry2nix",
|
||||
"rev": "bbc1ee74fc1ac4082f617bf32f1c927e759717d2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "poetry2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-newest": "nixpkgs-newest",
|
||||
"poetry2nix": "poetry2nix"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_2": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_3": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"id": "systems",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
"treefmt-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"poetry2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1718522839,
|
||||
"narHash": "sha256-ULzoKzEaBOiLRtjeY3YoGFJMwWSKRYOic6VNw2UyTls=",
|
||||
"owner": "numtide",
|
||||
"repo": "treefmt-nix",
|
||||
"rev": "68eb1dc333ce82d0ab0c0357363ea17c31ea1f81",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "treefmt-nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
{
|
||||
description = "Deepsad devenv with python 3.11";
|
||||
|
||||
inputs = {
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small";
|
||||
# Added newest nixpkgs for an updated poetry package.
|
||||
nixpkgs-newest.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
poetry2nix = {
|
||||
url = "github:nix-community/poetry2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
nixpkgs,
|
||||
nixpkgs-newest,
|
||||
flake-utils,
|
||||
poetry2nix,
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system:
|
||||
let
|
||||
# see https://github.com/nix-community/poetry2nix/tree/master#api for more functions and examples.
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfree = true;
|
||||
config.cudaSupport = true;
|
||||
};
|
||||
pkgsNew = nixpkgs-newest.legacyPackages.${system};
|
||||
thundersvm = import ./nix/thundersvm.nix {
|
||||
inherit pkgs;
|
||||
inherit (pkgs) fetchFromGitHub cmake gcc12Stdenv;
|
||||
cudaPackages = pkgs.cudaPackages;
|
||||
};
|
||||
|
||||
thundersvm-python = import ./nix/thundersvm-python.nix {
|
||||
inherit pkgs;
|
||||
pythonPackages = pkgs.python311Packages;
|
||||
thundersvm = thundersvm;
|
||||
};
|
||||
inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryApplication defaultPoetryOverrides;
|
||||
in
|
||||
{
|
||||
packages = {
|
||||
deepsad = mkPoetryApplication {
|
||||
projectDir = self;
|
||||
preferWheels = true;
|
||||
python = pkgs.python311;
|
||||
overrides = defaultPoetryOverrides.extend (
|
||||
final: prev: {
|
||||
torch-receptive-field = prev.torch-receptive-field.overridePythonAttrs (old: {
|
||||
buildInputs = (old.buildInputs or [ ]) ++ [ prev.setuptools ];
|
||||
});
|
||||
}
|
||||
);
|
||||
};
|
||||
default = self.packages.${system}.deepsad;
|
||||
};
|
||||
|
||||
devShells.default = pkgs.mkShell {
|
||||
inputsFrom = [ self.packages.${system}.deepsad ];
|
||||
buildInputs = with pkgs.python311Packages; [
|
||||
torch-bin
|
||||
torchvision-bin
|
||||
thundersvm-python
|
||||
];
|
||||
#LD_LIBRARY_PATH = with pkgs; lib.makeLibraryPath [
|
||||
#pkgs.stdenv.cc.cc
|
||||
#];
|
||||
};
|
||||
|
||||
devShells.poetry = pkgs.mkShell {
|
||||
packages = [
|
||||
pkgsNew.poetry
|
||||
pkgs.python311
|
||||
];
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
||||
1002
Deep-SAD-PyTorch/poetry.lock
generated
1002
Deep-SAD-PyTorch/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,32 +1,29 @@
|
||||
[tool.poetry]
|
||||
[project]
|
||||
name = "deep-sad-pytorch"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Your Name <you@example.com>"]
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"click>=8.2.1",
|
||||
"cvxopt>=1.3.2",
|
||||
"cycler>=0.12.1",
|
||||
"joblib>=1.5.1",
|
||||
"kiwisolver>=1.4.8",
|
||||
"matplotlib>=3.10.3",
|
||||
"numpy>=2.3.1",
|
||||
"pandas>=2.3.0",
|
||||
"pillow>=11.2.1",
|
||||
"pyparsing>=3.2.3",
|
||||
"python-dateutil>=2.9.0.post0",
|
||||
"pytz>=2025.2",
|
||||
"scikit-learn>=1.7.0",
|
||||
"scipy>=1.16.0",
|
||||
"seaborn>=0.13.2",
|
||||
"six>=1.17.0",
|
||||
"torch-receptive-field",
|
||||
"torchscan>=0.1.1",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.11,<3.12"
|
||||
click = "^8.1.7"
|
||||
matplotlib = "^3.9.0"
|
||||
numpy = "^2.0.0"
|
||||
pandas = "^2.2.2"
|
||||
cvxopt = "^1.3.2"
|
||||
cycler = "^0.12.1"
|
||||
joblib = "^1.4.2"
|
||||
kiwisolver = "^1.4.5"
|
||||
pillow = "^10.3.0"
|
||||
pyparsing = "^3.1.2"
|
||||
python-dateutil = "^2.9.0.post0"
|
||||
pytz = "^2024.1"
|
||||
scikit-learn = "^1.5.0"
|
||||
scipy = "^1.14.0"
|
||||
seaborn = "^0.13.2"
|
||||
six = "^1.16.0"
|
||||
torchscan = "^0.1.2"
|
||||
torch-receptive-field = {git = "https://github.com/Fangyh09/pytorch-receptive-field.git"}
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
[tool.uv.sources]
|
||||
torch-receptive-field = { git = "https://github.com/Fangyh09/pytorch-receptive-field.git" }
|
||||
|
||||
@@ -126,6 +126,8 @@ class DeepSAD(object):
|
||||
)
|
||||
# Get the model
|
||||
self.net = self.trainer.train(dataset, self.net, k_fold_idx=k_fold_idx)
|
||||
|
||||
self.c = self.trainer.c.cpu().data.numpy().tolist() # get as list
|
||||
# Store training results including indices
|
||||
self.results["train"]["time"] = self.trainer.train_time
|
||||
self.results["train"]["indices"] = self.trainer.train_indices
|
||||
@@ -333,7 +335,7 @@ class DeepSAD(object):
|
||||
# load autoencoder parameters if specified
|
||||
if load_ae:
|
||||
if self.ae_net is None:
|
||||
self.ae_net = build_autoencoder(self.net_name)
|
||||
self.ae_net = build_autoencoder(self.net_name, self.rep_dim)
|
||||
self.ae_net.load_state_dict(model_dict["ae_net_dict"])
|
||||
|
||||
def save_results(self, export_pkl):
|
||||
|
||||
@@ -25,7 +25,8 @@ from utils.visualization.plot_images_grid import plot_images_grid
|
||||
[
|
||||
"train",
|
||||
"infer",
|
||||
"ae_elbow_test", # Add new action
|
||||
"ae_elbow_test",
|
||||
"retest",
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -773,6 +774,165 @@ def main(
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown action: {action}")
|
||||
elif action == "retest":
|
||||
if (
|
||||
not load_model
|
||||
or not Path(load_model).exists()
|
||||
or not Path(load_model).is_dir()
|
||||
):
|
||||
logger.error(
|
||||
"For retest mode a model directory has to be loaded! Pass the --load_model option with the model directory path!"
|
||||
)
|
||||
return
|
||||
load_model = Path(load_model)
|
||||
if not load_config:
|
||||
logger.error(
|
||||
"For retest mode a config has to be loaded! Pass the --load_config option with the config path!"
|
||||
)
|
||||
return
|
||||
|
||||
dataset = load_dataset(
|
||||
cfg.settings["dataset_name"],
|
||||
data_path,
|
||||
cfg.settings["normal_class"],
|
||||
cfg.settings["known_outlier_class"],
|
||||
cfg.settings["n_known_outlier_classes"],
|
||||
cfg.settings["ratio_known_normal"],
|
||||
cfg.settings["ratio_known_outlier"],
|
||||
cfg.settings["ratio_pollution"],
|
||||
random_state=np.random.RandomState(cfg.settings["seed"]),
|
||||
k_fold_num=cfg.settings["k_fold_num"],
|
||||
num_known_normal=cfg.settings["num_known_normal"],
|
||||
num_known_outlier=cfg.settings["num_known_outlier"],
|
||||
)
|
||||
|
||||
train_passes = (
|
||||
range(cfg.settings["k_fold_num"]) if cfg.settings["k_fold"] else [None]
|
||||
)
|
||||
|
||||
retest_isoforest = True
|
||||
retest_ocsvm = True
|
||||
retest_deepsad = True
|
||||
|
||||
for fold_idx in train_passes:
|
||||
if fold_idx is None:
|
||||
logger.info("Single train re-testing without k-fold")
|
||||
deepsad_model_path = load_model / "model_deepsad.tar"
|
||||
isoforest_model_path = load_model / "model_ocsvm.pkl"
|
||||
ocsvm_model_path = load_model / "model_isoforest.pkl"
|
||||
ae_model_path = load_model / "model_ae.tar"
|
||||
else:
|
||||
logger.info(f"Fold {fold_idx + 1}/{cfg.settings['k_fold_num']}")
|
||||
|
||||
deepsad_model_path = load_model / f"model_deepsad_{fold_idx}.tar"
|
||||
isoforest_model_path = load_model / f"model_isoforest_{fold_idx}.pkl"
|
||||
ocsvm_model_path = load_model / f"model_ocsvm_{fold_idx}.pkl"
|
||||
ae_model_path = load_model / f"model_ae_{fold_idx}.tar"
|
||||
|
||||
# Check which model files exist and which do not
|
||||
model_paths = [
|
||||
deepsad_model_path,
|
||||
isoforest_model_path,
|
||||
ocsvm_model_path,
|
||||
ae_model_path,
|
||||
]
|
||||
missing_models = [
|
||||
path.name
|
||||
for path in model_paths
|
||||
if not path.exists() or not path.is_file()
|
||||
]
|
||||
if missing_models:
|
||||
logger.error(
|
||||
f"The following model files do not exist: {', '.join(missing_models)}. Please check the load_model path."
|
||||
)
|
||||
return
|
||||
|
||||
# Initialize Isolation Forest model
|
||||
if retest_isoforest:
|
||||
Isoforest = IsoForest(
|
||||
hybrid=False,
|
||||
n_estimators=cfg.settings["isoforest_n_estimators"],
|
||||
max_samples=cfg.settings["isoforest_max_samples"],
|
||||
contamination=cfg.settings["isoforest_contamination"],
|
||||
n_jobs=cfg.settings["isoforest_n_jobs_model"],
|
||||
seed=cfg.settings["seed"],
|
||||
)
|
||||
Isoforest.load_model(import_path=isoforest_model_path, device=device)
|
||||
Isoforest.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
# Initialize DeepSAD model and set neural network phi
|
||||
if retest_deepsad:
|
||||
deepSAD = DeepSAD(cfg.settings["latent_space_dim"], cfg.settings["eta"])
|
||||
deepSAD.set_network(cfg.settings["net_name"])
|
||||
deepSAD.load_model(
|
||||
model_path=deepsad_model_path, load_ae=True, map_location=device
|
||||
)
|
||||
logger.info("Loading model from %s." % load_model)
|
||||
deepSAD.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||
k_fold_idx=fold_idx,
|
||||
)
|
||||
|
||||
if retest_ocsvm:
|
||||
ocsvm = OCSVM(
|
||||
kernel=cfg.settings["ocsvm_kernel"],
|
||||
nu=cfg.settings["ocsvm_nu"],
|
||||
hybrid=True,
|
||||
latent_space_dim=cfg.settings["latent_space_dim"],
|
||||
)
|
||||
ocsvm.load_ae(
|
||||
net_name=cfg.settings["net_name"],
|
||||
model_path=ae_model_path,
|
||||
device=device,
|
||||
)
|
||||
ocsvm.load_model(import_path=ocsvm_model_path)
|
||||
ocsvm.test(
|
||||
dataset,
|
||||
device=device,
|
||||
n_jobs_dataloader=cfg.settings["n_jobs_dataloader"],
|
||||
k_fold_idx=fold_idx,
|
||||
batch_size=256,
|
||||
)
|
||||
|
||||
retest_output_path = load_model / "retest_output"
|
||||
retest_output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save results, model, and configuration
|
||||
if fold_idx is None:
|
||||
if retest_deepsad:
|
||||
deepSAD.save_results(
|
||||
export_pkl=retest_output_path / "results_deepsad.pkl"
|
||||
)
|
||||
if retest_ocsvm:
|
||||
ocsvm.save_results(
|
||||
export_pkl=retest_output_path / "results_ocsvm.pkl"
|
||||
)
|
||||
if retest_isoforest:
|
||||
Isoforest.save_results(
|
||||
export_pkl=retest_output_path / "results_isoforest.pkl"
|
||||
)
|
||||
else:
|
||||
if retest_deepsad:
|
||||
deepSAD.save_results(
|
||||
export_pkl=retest_output_path
|
||||
/ f"results_deepsad_{fold_idx}.pkl"
|
||||
)
|
||||
if retest_ocsvm:
|
||||
ocsvm.save_results(
|
||||
export_pkl=retest_output_path / f"/results_ocsvm_{fold_idx}.pkl"
|
||||
)
|
||||
if retest_isoforest:
|
||||
Isoforest.save_results(
|
||||
export_pkl=retest_output_path
|
||||
/ f"/results_isoforest_{fold_idx}.pkl"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user