import copy
from sklearn.model_selection import GroupKFold
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from pylissom.datasets.datasets import subj_indep_train_test_samplers
from pylissom.utils.helpers import save_model
from pylissom.utils.training.pipeline import Pipeline
# from pylissom.nn.functional import images as images
class CVSubjectIndependent(object):
    """Subject-independent cross-validation splitter.

    The dataset's samples are first divided into a hold-out test set and
    "other" samples by ``subj_indep_train_test_samplers``.  The other samples
    are then split into ``k - 1`` train/validation folds with ``GroupKFold``,
    using ``ck_dataset.subjs`` as the group labels so that one subject never
    appears on both sides of a fold.

    Args:
        ck_dataset: dataset exposing ``X``, ``y`` and ``subjs`` attributes,
            each indexable by a collection of indices (presumably NumPy
            arrays -- TODO confirm against the dataset class).
        k: total number of partitions.  ``1 - 1 / k`` is forwarded to
            ``subj_indep_train_test_samplers`` (presumably the fraction kept
            for train/validation -- TODO confirm) and ``k - 1`` folds are
            built from the remainder.  ``GroupKFold`` needs at least two
            splits, so ``k`` must be >= 3.
    """

    def __init__(self, ck_dataset, k=5):
        other_idxs, self.test_idxs = subj_indep_train_test_samplers(ck_dataset.subjs, 1 - 1 / k)
        self.folds = self._generate_folds(k - 1, ck_dataset, other_idxs)

    def train_val_samplers(self):
        """Return the list of ``(train_sampler, val_sampler)`` pairs, one per fold."""
        return self.folds

    def test_sampler(self):
        """Return a new ``SubsetRandomSampler`` over the hold-out test indices."""
        return SubsetRandomSampler(self.test_idxs)

    @staticmethod
    def _to_dataset_indices(other_idxs, fold_idxs):
        """Translate positions inside ``other_idxs`` into the indices stored there.

        Args:
            other_idxs: indexable collection of dataset indices.
            fold_idxs: iterable of positions into ``other_idxs``.

        Returns:
            A list of plain ``int`` dataset indices, in ``fold_idxs`` order.
        """
        return [int(other_idxs[i]) for i in fold_idxs]

    @staticmethod
    def _generate_folds(k, ck_dataset, other_idxs):
        """Build ``k`` grouped train/validation sampler pairs over ``other_idxs``.

        Args:
            k: number of folds handed to ``GroupKFold``.
            ck_dataset: dataset providing ``X``, ``y`` and ``subjs``.
            other_idxs: dataset indices available for training/validation.

        Returns:
            A list of ``(SubsetRandomSampler, SubsetRandomSampler)`` tuples whose
            indices refer to positions in ``ck_dataset`` itself.
        """
        kf = GroupKFold(n_splits=k)
        splits = kf.split(ck_dataset.X[other_idxs], ck_dataset.y[other_idxs],
                          ck_dataset.subjs[other_idxs])
        # ``split`` was given the ``other_idxs`` subset, so the indices it yields
        # are positions within that subset.  They must be mapped back through
        # ``other_idxs`` before being used to sample from the full dataset;
        # otherwise the folds would pick unrelated samples (possibly test ones).
        to_dataset = CVSubjectIndependent._to_dataset_indices
        return [(SubsetRandomSampler(to_dataset(other_idxs, train_index)),
                 SubsetRandomSampler(to_dataset(other_idxs, val_index)))
                for train_index, val_index in splits]
def _register_image_hook(model, prefix):
    """Attach a forward hook to ``model[0]`` that calls ``images.generate_images`` with *prefix*."""
    # Imported lazily: the module-level import of ``images`` is commented out in
    # this file, so the name would otherwise be undefined here.  The path is the
    # one from that commented-out import -- TODO confirm the module still exists.
    from pylissom.nn.functional import images

    def hook(*hook_args):
        # ``prefix`` is a parameter of the enclosing call, so each hook keeps
        # the value it was registered with.
        return images.generate_images(*hook_args, prefix=prefix)

    model[0].register_forward_hook(hook)


def run_cross_validation(model_fn, ck_dataset, cv_sampler, args):
    """Train one model per fold, keep the best weights and evaluate them on the test split.

    For every ``(train_sampler, val_sampler)`` pair returned by
    ``cv_sampler.train_val_samplers()`` a fresh model is built with
    ``model_fn`` and trained for ``args.epochs`` epochs.  After each epoch the
    model is evaluated on the validation sampler; the state dict with the
    highest value returned by ``Pipeline.test`` across all folds and epochs is
    kept.  Finally a new model is created, loaded with that state dict, passed
    to ``save_model`` and evaluated on ``cv_sampler.test_sampler()``.

    Args:
        model_fn: zero-argument callable returning ``(model, optimizer, loss_fn)``.
        ck_dataset: dataset handed to every ``DataLoader``.
        cv_sampler: object providing ``train_val_samplers()`` and ``test_sampler()``.
        args: namespace with ``cuda``, ``batch_size``, ``log_interval``,
            ``dataset_len``, ``save_images`` and ``epochs`` attributes.

    Raises:
        ValueError: if no validation pass ran (no folds, or ``args.epochs < 1``),
            so there is no trained state to evaluate.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    best_accuracy = 0
    best_model = None
    for fold, (train_sampler, val_sampler) in enumerate(cv_sampler.train_val_samplers()):
        prefix = 'fold_' + str(fold)
        train_loader = DataLoader(ck_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
        val_loader = DataLoader(ck_dataset, batch_size=args.batch_size, sampler=val_sampler, **kwargs)
        model, optimizer, loss_fn = model_fn()
        pipeline = Pipeline(model, optimizer, loss_fn, log_interval=args.log_interval,
                            dataset_len=args.dataset_len, cuda=args.cuda, prefix=prefix)
        if args.save_images:
            _register_image_hook(model, prefix)
        # TODO: Change epochs to 0
        for epoch in range(1, args.epochs + 1):
            pipeline.train(train_data_loader=train_loader, epoch=epoch)
            curr_accuracy = pipeline.test(test_data_loader=val_loader, epoch=epoch)
            if best_model is None or curr_accuracy > best_accuracy:
                # Deep copy: the live state dict keeps changing as training continues.
                best_model = copy.deepcopy(model.state_dict())
                best_accuracy = curr_accuracy

    if best_model is None:
        # Fail with a clear message instead of an obscure error from
        # ``load_state_dict(None)`` further down.
        raise ValueError('run_cross_validation: no fold/epoch was evaluated, '
                         'so there is no trained model to test')

    prefix = 'fold_test'
    model, optimizer, loss_fn = model_fn()
    model.load_state_dict(best_model)
    save_model(model)
    if args.save_images:
        _register_image_hook(model, prefix)
    test_sampler = cv_sampler.test_sampler()
    test_loader = DataLoader(ck_dataset, batch_size=args.batch_size, sampler=test_sampler, **kwargs)
    pipeline = Pipeline(model, optimizer, loss_fn, log_interval=args.log_interval,
                        dataset_len=args.dataset_len, cuda=args.cuda, prefix=prefix)
    pipeline.test(test_data_loader=test_loader, epoch=1)