Source code for pathflowai.sampler

"""
sampler.py
=======================
Balanced sampling based on one of the columns of the patch information.
"""

from collections import Counter

import numpy as np
import torch
import torch.utils.data
import torchvision


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Sample dataset indices with replacement, weighting rare labels up.

    Every index is weighted by the inverse of how often its label occurs
    among ``indices``, so the total weight of each label that is present is
    the same and each label is equally likely to be drawn.

    Adapted from
    https://raw.githubusercontent.com/ufoym/imbalanced-dataset-sampler/master/sampler.py

    Arguments:
        dataset: dataset to draw from. It must expose ``targets`` (its length
            is stored as ``n_targets``). Unless it is exactly a torchvision
            ``MNIST`` or ``ImageFolder``, it must also expose a
            ``patch_info`` table that is indexed positionally by row and by
            ``dataset.targets`` for the label column(s).
        indices (list, optional): a list of indices. Defaults to every
            index of ``dataset``.
        num_samples (int, optional): number of samples to draw in each
            iteration. Defaults to ``len(indices)``.
    """

    def __init__(self, dataset, indices=None, num_samples=None):
        # Without an explicit subset, every element of the dataset is eligible.
        self.indices = list(range(len(dataset))) if indices is None else indices
        self.n_targets = len(dataset.targets)

        # Without an explicit size, one pass draws as many samples as there
        # are candidate indices.
        self.num_samples = (
            len(self.indices) if num_samples is None else num_samples
        )

        # Look each label up exactly once: the lookup can be a pandas row
        # access, which is too slow to repeat for counting and for weighting.
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        self.weights = self._inverse_frequency_weights(labels)

    @staticmethod
    def _inverse_frequency_weights(labels):
        """Return a DoubleTensor holding ``1 / count(label)`` for each label."""
        label_to_count = Counter(labels)
        return torch.DoubleTensor(
            [1.0 / label_to_count[label] for label in labels]
        )

    def _get_label(self, dataset, idx):
        """Return the class label of element ``idx`` of ``dataset``.

        The dataset type is matched exactly, so subclasses of ``MNIST`` and
        ``ImageFolder`` fall through to the ``patch_info`` branch.
        """
        dataset_type = type(dataset)
        if dataset_type is torchvision.datasets.MNIST:
            # NOTE(review): newer torchvision releases appear to expose these
            # labels as `targets`; confirm before upgrading torchvision.
            return dataset.train_labels[idx].item()
        if dataset_type is torchvision.datasets.ImageFolder:
            return dataset.imgs[idx][1]

        y = dataset.patch_info.iloc[idx][dataset.targets].values
        if self.n_targets > 1:
            # Several target columns: the label is the position of the
            # largest value (presumably one-hot or score columns; verify
            # against the patch_info schema).
            y = np.argmax(y)
        elif isinstance(y, (list, np.ndarray)):
            # Single target column: unwrap the one-element array.
            y = y[0]
        return int(y)

    def __iter__(self):
        # Draw positions into `indices` with replacement, then map each
        # position back to the dataset index stored there.
        positions = torch.multinomial(
            self.weights, self.num_samples, replacement=True
        ).tolist()
        return (self.indices[i] for i in positions)

    def __len__(self):
        return self.num_samples