"""
Helpers to calculate the orientation preferences of a LISSOM network.
"""
from collections import Counter
from functools import lru_cache
import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.transform import rotate
from pylissom.utils.stimuli import translate, generate_horizontal_bar
from pylissom.utils.training.pipeline import Pipeline
class OrientationMap(object):
    """Preferred-orientation map of a model's output units.

    ``inputs`` maps an orientation key (e.g. the integer degrees produced by
    ``get_oriented_lines``) to an iterable of stimuli. For every key, each
    unit's response is its maximum activation over that key's stimuli; a
    unit's preferred orientation is the key with the largest response.

    The map and its histogram are computed once per instance and cached;
    later changes to ``model`` or ``inputs`` are not reflected.
    """

    # TODO: optimize using vectorization to calculate activations
    def __init__(self, model, inputs, use_tqdm_notebook=True):
        # Stored for callers; not read anywhere inside this class.
        self.use_tqdm_notebook = use_tqdm_notebook
        self.model = model
        self.inputs = inputs
        # Per-instance caches. functools.lru_cache on methods would key on
        # ``self`` and keep every instance (and its model/inputs) alive.
        self._orientation_map = None
        self._orientation_hist = None

    def maximum_activations(self, model, inputs):
        """Return the element-wise maximum of ``model``'s output over ``inputs``.

        Each stimulus is passed through ``Pipeline.process_input`` before being
        fed to ``model``; the outputs are stacked and reduced with ``torch.max``
        along the new leading dimension.
        """
        activations = [model(Pipeline.process_input(inp)) for inp in inputs]
        maximums, _ = torch.max(torch.stack(activations), 0)
        return maximums

    def calculate_keys_activations(self, model, inputs):
        """Return ``{key: maximum_activations(model, stimuli)}`` for each entry of ``inputs``."""
        return {k: self.maximum_activations(model, array) for k, array in inputs.items()}

    @staticmethod
    def _preferred_keys(activations):
        """Return, for every unit (flattened), the key whose activation is largest.

        :param activations: dict mapping key -> activation tensor; all tensors
            must share one shape.
        :returns: list of keys, one per element of the activation tensors.
        """
        keys = list(activations.keys())
        mat = torch.stack([activations[k] for k in keys])
        _, preferences = torch.max(mat, 0)
        # Flatten and convert to plain ints: indexing the 0-dim tensors obtained
        # by iterating a 1-D tensor (``idx.data[0]``) raises IndexError on
        # PyTorch >= 0.4, and ``squeeze()`` made a single unit non-iterable.
        return [keys[i] for i in preferences.data.view(-1).tolist()]

    def get_orientation_map(self):
        """Return the preferred-orientation key of every unit as a square array.

        The side length is ``int(sqrt(self.model.out_features))``; ``np.reshape``
        raises ValueError if the number of units is not a perfect square.
        The result is cached on the instance.
        """
        cached = getattr(self, '_orientation_map', None)
        if cached is not None:
            return cached
        activations = self.calculate_keys_activations(self.model, self.inputs)
        orientation_map = self._preferred_keys(activations)
        # Assumes Square Maps
        rows = int(np.sqrt(self.model.out_features))
        self._orientation_map = np.reshape(np.asarray(orientation_map), (rows, rows))
        return self._orientation_map

    @staticmethod
    def orientation_hist(orientation_map):
        """Count how many units prefer each key in ``orientation_map``."""
        orientation_hist = Counter(orientation_map.flatten().tolist())
        return orientation_hist

    def get_orientation_hist(self):
        """Return (and cache) the histogram of ``get_orientation_map()``."""
        cached = getattr(self, '_orientation_hist', None)
        if cached is None:
            cached = self.orientation_hist(self.get_orientation_map())
            self._orientation_hist = cached
        return cached
def plot_orientation_map(orientation_map):
    """Render ``orientation_map`` with ``plt.imshow`` and the 'gist_rainbow' colormap.

    Returns the image object created by ``plt.imshow``.
    """
    colormap_name = 'gist_rainbow'
    image = plt.imshow(orientation_map, cmap=colormap_name)
    return image
def plot_orientation_hist(orientation_hist):
    """Draw ``orientation_hist`` as a pie chart with one wedge per key.

    Each wedge is labelled with its key followed by a degree sign and is
    annotated with its percentage to two decimals. Returns the value of
    ``plt.pie``.
    """
    labels, values = [], []
    for orientation, count in orientation_hist.items():
        labels.append(str(orientation) + '°')
        values.append(float(count))
    return plt.pie(values, labels=labels, autopct='%.2f')
def metrics_orientation_hist(orientation_hist):
    """Return ``(mean, std)`` of the histogram counts after L1 normalization.

    The counts are divided by the sum of their absolute values, so they sum
    to one before the NumPy mean and (population) standard deviation are
    taken.
    """
    counts = np.array([float(count) for count in orientation_hist.values()])
    total = np.linalg.norm(counts, ord=1)
    proportions = counts / total
    return np.mean(proportions), np.std(proportions)
def get_oriented_lines(size, orientations=180):
    """Build translated bar stimuli rotated through evenly spaced angles.

    A horizontal bar is generated with ``generate_horizontal_bar(size)`` and
    translated with ``translate(..., over_x=False)``. Every translated image
    is then rotated by each angle of ``np.linspace(0, 180, num=orientations)``
    (both endpoints included), keeping the image's original dtype.

    :param size: forwarded to ``generate_horizontal_bar``.
    :param orientations: number of angles to sample in [0, 180].
    :returns: the dict produced by ``numpy_dict_to_tensors``, keyed by
        ``int(angle)``, with one entry per translated image under each key.
    :raises ValueError: if two sampled angles truncate to the same integer
        key (this happens once ``orientations`` exceeds 181). Previously the
        later angle silently overwrote the earlier one.
    """
    vertical_bar = generate_horizontal_bar(size)
    # Materialized so it can be re-iterated for every angle, in case
    # ``translate`` returns a one-shot iterator.
    move_vertical = list(translate(vertical_bar, over_x=False))
    degrees = np.linspace(0, 180, num=orientations)
    keys = [int(degree) for degree in degrees]
    if len(set(keys)) != len(keys):
        raise ValueError(
            'orientations={} yields angles that truncate to the same integer '
            'degree; use at most 181 orientations'.format(orientations))
    inputs = {}
    for key, degree in zip(keys, degrees):
        rotated = []
        for im in move_vertical:
            # mode : {'constant', 'edge', 'symmetric', 'reflect', 'wrap'}
            # preserve_range=True: otherwise skimage rescales integer images to
            # [0, 1] floats, which astype(im.dtype) would then truncate.
            rot = rotate(im, degree, mode='reflect', preserve_range=True)
            rotated.append(rot.astype(im.dtype))
        inputs[key] = rotated
    return numpy_dict_to_tensors(inputs)
def numpy_dict_to_tensors(d):
    """Convert each sequence of NumPy arrays in ``d`` to a list of tensors.

    Lists (not lazy ``map`` objects) are returned so the values can be
    iterated more than once. A Python 3 ``map`` is exhausted after its first
    pass, which left later consumers with no stimuli.
    Tensors come from ``torch.from_numpy`` and therefore share memory with
    the source arrays.

    :param d: dict mapping key -> iterable of ``numpy.ndarray``.
    :returns: dict mapping the same keys -> list of ``torch.Tensor``.
    """
    return {k: [torch.from_numpy(arr) for arr in v] for k, v in d.items()}