black formatted files before changes
This commit is contained in:
@@ -12,8 +12,16 @@ import numpy as np
|
||||
|
||||
class CIFAR10_Dataset(TorchvisionDataset):
|
||||
|
||||
def __init__(self, root: str, normal_class: int = 5, known_outlier_class: int = 3, n_known_outlier_classes: int = 0,
|
||||
ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0):
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
normal_class: int = 5,
|
||||
known_outlier_class: int = 3,
|
||||
n_known_outlier_classes: int = 0,
|
||||
ratio_known_normal: float = 0.0,
|
||||
ratio_known_outlier: float = 0.0,
|
||||
ratio_pollution: float = 0.0,
|
||||
):
|
||||
super().__init__(root)
|
||||
|
||||
# Define normal and outlier classes
|
||||
@@ -28,28 +36,48 @@ class CIFAR10_Dataset(TorchvisionDataset):
|
||||
elif n_known_outlier_classes == 1:
|
||||
self.known_outlier_classes = tuple([known_outlier_class])
|
||||
else:
|
||||
self.known_outlier_classes = tuple(random.sample(self.outlier_classes, n_known_outlier_classes))
|
||||
self.known_outlier_classes = tuple(
|
||||
random.sample(self.outlier_classes, n_known_outlier_classes)
|
||||
)
|
||||
|
||||
# CIFAR-10 preprocessing: feature scaling to [0, 1]
|
||||
transform = transforms.ToTensor()
|
||||
target_transform = transforms.Lambda(lambda x: int(x in self.outlier_classes))
|
||||
|
||||
# Get train set
|
||||
train_set = MyCIFAR10(root=self.root, train=True, transform=transform, target_transform=target_transform,
|
||||
download=True)
|
||||
train_set = MyCIFAR10(
|
||||
root=self.root,
|
||||
train=True,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
download=True,
|
||||
)
|
||||
|
||||
# Create semi-supervised setting
|
||||
idx, _, semi_targets = create_semisupervised_setting(np.array(train_set.targets), self.normal_classes,
|
||||
self.outlier_classes, self.known_outlier_classes,
|
||||
ratio_known_normal, ratio_known_outlier, ratio_pollution)
|
||||
train_set.semi_targets[idx] = torch.tensor(semi_targets) # set respective semi-supervised labels
|
||||
idx, _, semi_targets = create_semisupervised_setting(
|
||||
np.array(train_set.targets),
|
||||
self.normal_classes,
|
||||
self.outlier_classes,
|
||||
self.known_outlier_classes,
|
||||
ratio_known_normal,
|
||||
ratio_known_outlier,
|
||||
ratio_pollution,
|
||||
)
|
||||
train_set.semi_targets[idx] = torch.tensor(
|
||||
semi_targets
|
||||
) # set respective semi-supervised labels
|
||||
|
||||
# Subset train_set to semi-supervised setup
|
||||
self.train_set = Subset(train_set, idx)
|
||||
|
||||
# Get test set
|
||||
self.test_set = MyCIFAR10(root=self.root, train=False, transform=transform, target_transform=target_transform,
|
||||
download=True)
|
||||
self.test_set = MyCIFAR10(
|
||||
root=self.root,
|
||||
train=False,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
download=True,
|
||||
)
|
||||
|
||||
|
||||
class MyCIFAR10(CIFAR10):
|
||||
@@ -71,7 +99,11 @@ class MyCIFAR10(CIFAR10):
|
||||
Returns:
|
||||
tuple: (image, target, semi_target, index)
|
||||
"""
|
||||
img, target, semi_target = self.data[index], self.targets[index], int(self.semi_targets[index])
|
||||
img, target, semi_target = (
|
||||
self.data[index],
|
||||
self.targets[index],
|
||||
int(self.semi_targets[index]),
|
||||
)
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
|
||||
Reference in New Issue
Block a user