added error for incorrect preprocessing

This commit is contained in:
Jan Kowalczyk
2024-07-03 17:39:32 +02:00
parent 42fb437fe1
commit 61424bf053

View File

@@ -1,5 +1,6 @@
import torch
import numpy as np
import logging
def create_semisupervised_setting(
@@ -58,6 +59,15 @@ def create_semisupervised_setting(
n_unlabeled_outlier = int(x[2])
n_known_outlier = int(x[3])
if (
sum((n_known_normal, n_unlabeled_normal, n_unlabeled_outlier, n_known_outlier))
> labels.shape[0]
):
logger = logging.getLogger()
logger.error(
"Given ratios for the semi-supervised setting are not possible due to data restraints. Please change the ratios or provide more/different data."
)
# Sample indices
perm_normal = np.random.permutation(n_normal)
perm_outlier = np.random.permutation(len(idx_outlier))