"""
Source code for pylissom.datasets.datasets

Extends :py:mod:`torchvision.datasets` with two common Lissom stimuli, Oriented Gaussians and "Gaussian" Faces
"""

from random import shuffle

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

from pylissom.utils.stimuli import random_gaussians_generator, faces_generator

def get_dataset(train, args):
    """Build a :py:class:`~torch.utils.data.DataLoader` for the dataset named in ``args``.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.
        args: namespace-like object; reads ``args.cuda``, ``args.dataset``
            and ``args.batch_size``.

    Returns:
        A ``DataLoader`` for ``'ck'`` or ``'mnist'``; ``None`` for any other
        dataset name (behavior inherited from the original fall-through).

    Raises:
        NotImplementedError: if ``args.dataset == 'number_one'``.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'ck':
        # Reuse one CKDataset instance for both the loader and the sampler.
        # The original constructed CKDataset() twice, loading every .npy
        # file from disk a second time for no benefit.
        dataset = CKDataset()
        return DataLoader(dataset, batch_size=args.batch_size,
                          sampler=train_test_ck_samplers(dataset, train),
                          **kwargs)
    elif args.dataset == 'mnist':
        return DataLoader(
            datasets.MNIST('../data', train=train, download=True,
                           transform=transforms.ToTensor()),
            batch_size=args.batch_size, **kwargs)
    elif args.dataset == 'number_one':
        raise NotImplementedError
class RandomDataset(Dataset):
    r"""Abstract Dataset representing random samples, subclasses must implement
    :py:func:`pylissom.datasets.RandomDataset._gen`"""

    def __init__(self, length):
        # Number of samples this dataset reports; fixes the original's
        # misspelled private attribute name `_lenght`.
        self._length = length

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        if index >= len(self):
            # IndexError is the sequence-protocol signal for out-of-range;
            # the original raised StopIteration, which PEP 479 discourages
            # and which breaks callers that index past the end directly.
            raise IndexError(index)
        # Second element is a dummy (uninitialized 2-element) target tensor.
        return torch.from_numpy(next(self._gen)), torch.Tensor(2)

    @property
    def _gen(self):
        # Subclasses must return a generator yielding numpy sample arrays.
        raise NotImplementedError
class OrientatedGaussians(RandomDataset):
    r"""Dataset of random Oriented Gaussian samples, as used in
    Computational Maps in the Visual Cortex."""

    def __init__(self, size, length, gaussians=2):
        super(OrientatedGaussians, self).__init__(length)
        self.gaussians = gaussians
        self.size = size

    @property
    def _gen(self):
        # A fresh generator is built on every property access.
        return random_gaussians_generator(self.size, self.gaussians)
class ThreeDotFaces(RandomDataset):
    r"""Dataset of random Faces made of three Gaussian disks, as used in
    Computational Maps in the Visual Cortex."""

    def __init__(self, size, length, faces=2):
        super(ThreeDotFaces, self).__init__(length)
        self.faces = faces
        self.size = size

    @property
    def _gen(self):
        # A fresh generator is built on every property access.
        return faces_generator(self.size, self.faces)
class CKDataset(Dataset):
    """Facial-expression dataset backed by pre-extracted ``.npy`` arrays.

    Args:
        path_images: ``.npy`` file holding the image array ``X``.
        path_labels: ``.npy`` file holding labels in the range 1-7.
        path_subjects: ``.npy`` file holding one subject id per sample.
    """

    def __init__(self, path_images='/home/hbari/data/X.npy',
                 path_labels='/home/hbari/data/y.npy',
                 path_subjects='/home/hbari/data/subjs.npy'):
        self.path_labels = path_labels
        self.path_images = path_images
        # Stored for consistency with the other two paths
        # (the original discarded this parameter after use).
        self.path_subjects = path_subjects
        self.X = np.load(self.path_images)
        # Subtract 1 because labels are 1-7 and need to start from 0
        self.y = np.load(self.path_labels) - 1
        self.subjs = np.load(self.path_subjects)

    def __getitem__(self, item):
        return torch.Tensor(self.X[item]), int(self.y[item])

    def __len__(self):
        return len(self.X)
def train_test_ck_samplers(ck_dataset, train, train_pct=0.5):
    """Return a :py:class:`SubsetRandomSampler` over a subject-independent split.

    Args:
        ck_dataset: dataset exposing a ``subjs`` array (one subject id per sample).
        train: ``True`` selects the train indices, ``False`` the test indices.
        train_pct: fraction of subjects assigned to the train split.
    """
    train_idxs, test_idxs = subj_indep_train_test_samplers(ck_dataset.subjs,
                                                           pct=train_pct)
    chosen = train_idxs if train else test_idxs
    return SubsetRandomSampler(chosen)
def subj_indep_train_test_samplers(subjs, pct):
    """Split sample indices so that no subject appears in both splits.

    Subjects (not samples) are shuffled and partitioned, then each sample
    index is routed to the split its subject belongs to.

    Args:
        subjs: sequence with one subject id per sample.
        pct: fraction of *subjects* assigned to the train split.

    Returns:
        ``(train_idxs, test_idxs)`` — two lists of sample indices, each in
        ascending order. (Despite the function name, these are index lists,
        not sampler objects; see :py:func:`train_test_ck_samplers`.)
    """
    set_subjs = list(set(subjs))
    shuffle(set_subjs)
    split = int(len(set_subjs) * pct)
    # A set makes the per-sample membership test below O(1) instead of O(n).
    train_subjs = set(set_subjs[:split])
    train_idxs = []
    test_idxs = []
    for idx, subj in enumerate(subjs):
        if subj in train_subjs:
            train_idxs.append(idx)
        else:
            test_idxs.append(idx)
    return train_idxs, test_idxs