"""
sampler.py
=======================
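Class-balanced sampling, with each sample's label read from one target column (or, for several columns, the index of the largest value) of the dataset's patch information.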
"""
from collections import Counter

import numpy as np
import torch
import torch.utils.data
import torchvision
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Sample indices with probability inversely proportional to class frequency.

    Adapted from
    https://raw.githubusercontent.com/ufoym/imbalanced-dataset-sampler/master/sampler.py

    Each index in ``indices`` gets weight ``1 / count(label)``, so every class
    carries the same total weight. Each iteration draws ``num_samples``
    indices with replacement according to those weights.

    Arguments:
        dataset: the dataset to sample from. If it has a ``patch_info`` table
            (accessed with ``.iloc``, presumably a pandas DataFrame), labels
            are read from the column(s) named by ``dataset.targets``.
            Otherwise torchvision ``MNIST`` / ``ImageFolder`` datasets (and
            their subclasses) are supported.
        indices (list, optional): a list of indices; defaults to every index
            of ``dataset``.
        num_samples (int, optional): number of samples to draw per iteration;
            defaults to ``len(indices)``.
    """

    def __init__(self, dataset, indices=None, num_samples=None):
        # if indices is not provided, all elements in the dataset are considered
        self.indices = list(range(len(dataset))) if indices is None else indices
        # Kept for backward compatibility: number of entries in
        # ``dataset.targets`` (0 when the dataset has no ``targets``, which
        # previously raised AttributeError here).
        targets = getattr(dataset, "targets", None)
        self.n_targets = len(targets) if targets is not None else 0
        # if num_samples is not provided, draw len(indices) samples per iteration
        self.num_samples = (
            len(self.indices) if num_samples is None else num_samples
        )
        # Look each label up exactly once: _get_label may be an expensive
        # per-row table access, and the original looked every label up twice.
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        label_to_count = Counter(labels)
        # inverse class frequency -> every class has equal total weight
        self.weights = torch.tensor(
            [1.0 / label_to_count[label] for label in labels],
            dtype=torch.double,
        )

    def _get_label(self, dataset, idx):
        """Return the integer class label of sample ``idx`` in ``dataset``.

        Raises:
            AttributeError: if ``dataset`` has no ``patch_info`` and is not a
                torchvision ``MNIST`` or ``ImageFolder`` dataset.
        """
        # Checked first so that subclasses of torchvision datasets that carry
        # their own patch_info keep using it, and so torchvision is only
        # consulted when actually needed.
        if hasattr(dataset, "patch_info"):
            row = dataset.patch_info.iloc[idx][dataset.targets]
            # atleast_1d accepts both a scalar (one column named by a string)
            # and a Series (a list of column names).
            y = np.atleast_1d(np.asarray(row)).ravel()
            # Several target columns: the label is the position of the largest
            # value (e.g. one-hot columns). One column: use its value directly.
            return int(np.argmax(y)) if y.size > 1 else int(y[0])
        # isinstance (not exact type) so subclasses such as FashionMNIST work.
        if isinstance(dataset, torchvision.datasets.MNIST):
            # `targets` replaces the deprecated `train_labels` alias.
            return dataset.targets[idx].item()
        if isinstance(dataset, torchvision.datasets.ImageFolder):
            return dataset.imgs[idx][1]
        raise AttributeError(
            f"{type(dataset).__name__} has no 'patch_info' attribute and is "
            "not a torchvision MNIST or ImageFolder dataset"
        )

    def __iter__(self):
        # Convert draws to Python ints before indexing self.indices.
        draws = torch.multinomial(
            self.weights, self.num_samples, replacement=True
        ).tolist()
        return (self.indices[i] for i in draws)

    def __len__(self):
        return self.num_samples