"""Some plotting functions that receive :py:mod:`numpy` matrices as input."""

import matplotlib.pyplot as plt
import numpy as np

def plot_matrix(img, vmin=0, vmax=1):
    r"""Draw a numpy matrix as a grayscale image.

    Args:
        img: Matrix handed to ``plt.imshow``.
        vmin: Lower normalization bound forwarded to ``plt.imshow``.
        vmax: Upper normalization bound forwarded to ``plt.imshow``.
    """
    # Collect the display options first so the drawing call stays short.
    display_options = {'cmap': 'gray', 'vmin': vmin, 'vmax': vmax}
    plt.imshow(img, **display_options)
def plot_list_matrices(imgs):
    """Draw every matrix of an iterable with :py:func:`plot_matrix`.

    Each matrix is drawn with the default ``vmin``/``vmax`` of
    :py:func:`plot_matrix`. This function neither creates figures nor
    shows them itself.

    Args:
        imgs: Iterable of numpy matrices.
    """
    for matrix in imgs:
        plot_matrix(matrix)
def plot_dict_matrices(imgs):
    """Draw each matrix of a dictionary, using its key as the plot title.

    Args:
        imgs: Dictionary of {title: numpy matrix} items
    """
    for title, matrix in imgs.items():
        # The title is set before drawing, mirroring the call order callers
        # already rely on.
        plt.title(title)
        plot_matrix(matrix)
def tensor_to_numpy_matrix(tensor, shape):
    """Convert a tensor into a numpy array with the requested shape.

    Args:
        tensor: Tensor to convert. It may live on an accelerator device and
            may require grad.
        shape: Target shape handed to :py:func:`numpy.reshape`.

    Returns:
        Numpy version of tensor with shape = shape
    """
    # ``Tensor.numpy()`` raises for tensors that require grad, so detach from
    # the autograd graph first. ``cpu()`` hands back the same tensor when it
    # is already in host memory, so CPU inputs are not copied.
    host_tensor = tensor.detach().cpu()
    return np.reshape(host_tensor.numpy(), shape)
def plot_tensor(tensor, shape, vmin=0, vmax=1):
    """Draw a tensor as a grayscale image after reshaping it to ``shape``.

    Args:
        tensor: Tensor converted through :py:func:`tensor_to_numpy_matrix`.
        shape: Shape the tensor data is reshaped to before drawing.
        vmin: Lower normalization bound forwarded to :py:func:`plot_matrix`.
        vmax: Upper normalization bound forwarded to :py:func:`plot_matrix`.
    """
    plot_matrix(tensor_to_numpy_matrix(tensor, shape), vmin, vmax)