|
|
|
|
|
import numpy as np |
|
|
from matplotlib import pyplot as plt |
|
|
|
|
|
import scipy |
|
|
import scipy.stats |
|
|
from imageio import imsave |
|
|
import cv2 |
|
|
|
|
|
|
|
|
def concat_images(images, image_width, spacer_size):
    """ Join images side by side with a white spacer between neighbours

    Args:
        images: Sequence of image arrays to place left to right.
        image_width: Number of rows given to the spacer (the spacer is built
            as [image_width, spacer_size, 4], so the images have to match
            that height and channel count to be stackable).
        spacer_size: Number of spacer columns inserted between two images.

    Returns:
        The horizontally stacked array; no spacer follows the last image.
    """
    gap = np.ones([image_width, spacer_size, 4], dtype=np.uint8) * 255

    pieces = []
    for position, picture in enumerate(images):
        # A spacer goes in front of every image except the first one
        if position > 0:
            pieces.append(gap)
        pieces.append(picture)

    return np.hstack(pieces)
|
|
|
|
|
|
|
|
def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
    """ Tile images into a grid of `row_size` rows separated by white spacers

    Args:
        images: Sequence of image arrays, consumed row by row.
        row_size: Number of rows in the grid; each row receives
            len(images) // row_size consecutive images.
        image_width: Width of one image, also forwarded to concat_images.
        spacer_size: Thickness of the spacers between images and between rows.

    Returns:
        The vertically stacked rows; no spacer follows the last row.
    """
    per_row = len(images) // row_size

    # A horizontal divider as wide as one finished row (images plus inner spacers)
    strip_width = image_width * per_row + (per_row - 1) * spacer_size
    divider = np.ones([spacer_size, strip_width, 4], dtype=np.uint8) * 255

    stacked = []
    for row_index in range(row_size):
        # A divider goes in front of every row except the first one
        if row_index > 0:
            stacked.append(divider)
        first = per_row * row_index
        stacked.append(
            concat_images(images[first:first + per_row], image_width, spacer_size))

    return np.vstack(stacked)
|
|
|
|
|
|
|
|
def convert_to_colormap(im, cmap):
    """ Apply a colormap callable and convert the result to 8-bit values

    Args:
        im: Array handed to `cmap` unchanged.
        cmap: Callable returning an array; its output is multiplied by 255.

    Returns:
        The scaled colormap output cast to uint8.
    """
    scaled = cmap(im) * 255
    return np.uint8(scaled)
|
|
|
|
|
|
|
|
def rgb(im, cmap='jet', smooth=True):
    """ Render an array as a colormapped 8-bit image

    Args:
        im: Array of values; it is min-max rescaled to [0, 1] before colouring.
            A constant array makes the rescaling 0/0, so NaNs are handed to
            the colormap in that case.
        cmap: Colormap name resolved through matplotlib, or a callable
            (for example a Colormap instance) that is applied directly.
        smooth: If True, blur the rescaled array with a 3x3 Gaussian kernel
            before colouring.

    Returns:
        The colormap output multiplied by 255 and cast to uint8.
    """
    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
    # lookup that exists across versions. Callables are used as they are.
    if not callable(cmap):
        cmap = plt.get_cmap(cmap)

    # Silence the invalid-value warning only while this function runs instead
    # of changing NumPy's error settings for the whole process.
    with np.errstate(invalid='ignore'):
        lowest = np.min(im)
        im = (im - lowest) / (np.max(im) - lowest)
        if smooth:
            im = cv2.GaussianBlur(im, (3, 3), sigmaX=1, sigmaY=0)
        im = cmap(im)
        im = np.uint8(im * 255)
    return im
|
|
|
|
|
|
|
|
def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
    """ Render the first `n_plots` rate maps and tile them into one image

    Args:
        activations: Array of rate maps; the first axis is iterated and the
            size of the last axis is used as the width of one tile.
        n_plots: Number of maps to render, taken from the front.
        cmap: Colormap forwarded to rgb.
        smooth: Smoothing flag forwarded to rgb.
        width: Divisor that turns n_plots into the number of grid rows
            (n_plots // width, but never fewer than one row).

    Returns:
        The tiled image produced by concat_images_in_rows.
    """
    images = [rgb(im, cmap, smooth) for im in activations[:n_plots]]

    # n_plots // width is 0 whenever fewer than `width` maps are requested,
    # which made the row layout divide by zero; use a single row instead.
    n_rows = max(1, n_plots // width)

    rm_fig = concat_images_in_rows(images, n_rows, activations.shape[-1])
    return rm_fig
|
|
|
|
|
|
|
|
def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None, return_raw=False):
    '''Compute spatial firing fields

    Positions from each test batch are binned on a res x res grid spanning
    [-box_width/2, box_width/2) x [-box_height/2, box_height/2); samples that
    fall outside the grid are ignored. Each bin holds the mean activation of
    the samples that landed in it (bins never visited stay at 0).

    Args:
        model: The RNN model; `model.g(inputs)` must return a tensor whose
            last axis indexes units.
        trajectory_generator: Generator for test trajectories; must provide
            `get_test_batch()` returning (inputs, positions, _).
        options: Training options (reads sequence_length, batch_size,
            box_width and box_height).
        res: Resolution of the rate map grid
        n_avg: Number of batches to average over
            (defaults to 1000 // options.sequence_length).
        Ng: Number of grid cells to analyze. When `idxs` is given, at most
            Ng of its entries are used.
        idxs: Indices of specific grid cells to analyze. None or an empty
            sequence selects the first Ng cells.
        return_raw: If True, also return raw activations (g) and positions (pos).
            Warning: This uses significant memory for large batch_size/n_avg.
            If False, returns None for g and pos to save memory.

    Returns:
        activations: Spatial firing fields [n_cells, res, res], where n_cells
            is the number of selected cells (Ng unless `idxs` is shorter).
        rate_map: Flattened rate maps [n_cells, res*res]
        g: Raw activations (None if return_raw=False)
        pos: Raw positions (None if return_raw=False)
    '''
    if not n_avg:
        n_avg = 1000 // options.sequence_length

    # Test for "not provided" explicitly: a truthiness test such as np.any
    # would also discard a real selection that only contains index 0.
    if idxs is None or len(idxs) == 0:
        idxs = np.arange(Ng)
    idxs = np.asarray(idxs)[:Ng]
    # Keep every downstream shape in sync with the cells actually selected.
    Ng = len(idxs)

    if return_raw:
        g = np.zeros([n_avg, options.batch_size * options.sequence_length, Ng])
        pos = np.zeros([n_avg, options.batch_size * options.sequence_length, 2])
    else:
        g = None
        pos = None

    activations = np.zeros([Ng, res, res])
    counts = np.zeros([res, res])

    for index in range(n_avg):
        inputs, pos_batch, _ = trajectory_generator.get_test_batch()
        g_batch = model.g(inputs).detach().cpu().numpy()

        # Convert positions to NumPy as well so the binning below runs on
        # arrays rather than on individual tensor elements.
        pos_batch = pos_batch.detach().cpu().numpy().reshape(-1, 2)
        g_batch = g_batch[:, :, idxs].reshape(-1, Ng)

        if return_raw:
            g[index] = g_batch
            pos[index] = pos_batch

        # Map box coordinates to fractional grid coordinates in [0, res)
        x_batch = (pos_batch[:, 0] + options.box_width / 2) / options.box_width * res
        y_batch = (pos_batch[:, 1] + options.box_height / 2) / options.box_height * res

        # Mask on the fractional values first so that coordinates in (-1, 0)
        # are rejected instead of being truncated into bin 0.
        inside = (x_batch >= 0) & (x_batch < res) & (y_batch >= 0) & (y_batch < res)
        x_bin = x_batch[inside].astype(int)
        y_bin = y_batch[inside].astype(int)

        # np.add.at is unbuffered, so samples sharing a bin all accumulate
        # (plain fancy-index `+=` would keep only one of them).
        np.add.at(counts, (x_bin, y_bin), 1)
        np.add.at(activations, (slice(None), x_bin, y_bin), g_batch[inside].T)

    # Turn per-bin sums into means; bins that were never visited stay at 0
    visited = counts > 0
    activations[:, visited] /= counts[visited]

    if return_raw:
        g = g.reshape([-1, Ng])
        pos = pos.reshape([-1, 2])

    rate_map = activations.reshape(Ng, -1)

    return activations, rate_map, g, pos
|
|
|
|
|
|
|
|
def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
    """ Compute rate maps and write them as one tiled PNG named after `step`

    Args:
        model: Forwarded to compute_ratemaps.
        trajectory_generator: Forwarded to compute_ratemaps.
        options: Options object; save_dir and run_ID form the output
            directory, sequence_length sets the default for n_avg.
        step: Value used as the file name (converted with str).
        res: Grid resolution forwarded to compute_ratemaps.
        n_avg: Number of batches to average over
            (defaults to 1000 // options.sequence_length).
    """
    n_avg = n_avg or 1000 // options.sequence_length

    # Only the spatial firing fields are needed for the figure
    activations = compute_ratemaps(model, trajectory_generator,
                                   options, res=res, n_avg=n_avg)[0]
    figure = plot_ratemaps(activations, n_plots=len(activations))

    out_dir = options.save_dir + "/" + options.run_ID
    imsave(out_dir + "/" + str(step) + ".png", figure)
|
|
|
|
|
|
|
|
def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
    """ Collect positions and bottleneck activations, then score and plot them

    Runs 100 mini-batches through `sess`, concatenates the fetched
    `model.target_pos` and `model.g` values with utils.concat_dict, and hands
    them to utils.get_scores_and_plot together with a scores.GridScorer.

    NOTE(review): `scores` and `utils` are not imported in the visible part of
    this file - confirm they are in scope before calling this function.

    Args:
        sess: Session-like object providing `run(fetches, feed_dict=...)`.
        model: Object exposing `target_pos` and `g` as fetchable values.
        save_name: Prefix of the output file name.
        trajectory_generator: Provides `feed_dict(box_width, box_height)`.
        step: Value appended to the output file name.
        flags: Options object (reads box_width, box_height and save_dir).
    """
    starts = [0.2] * 10
    ends = np.linspace(0.4, 1.0, num=10)
    coord_range = ((-1.1, 1.1), (-1.1, 1.1))
    # zip() is a one-shot iterator in Python 3; materialise it so the scorer
    # can iterate over the (start, end) pairs more than once if it needs to.
    masks_parameters = list(zip(starts, ends.tolist()))
    latest_epoch_scorer = scores.GridScorer(20, coord_range, masks_parameters)

    res = {}
    index_size = 100
    for _ in range(index_size):
        feed_dict = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
        mb_res = sess.run({
            'pos_xy': model.target_pos,
            'bottleneck': model.g,
        }, feed_dict=feed_dict)
        res = utils.concat_dict(res, mb_res)

    filename = save_name + '/autocorrs_' + str(step) + '.pdf'
    imdir = flags.save_dir + '/'
    # The return value was never used; the call is made for its plotting output
    utils.get_scores_and_plot(
        latest_epoch_scorer, res['pos_xy'], res['bottleneck'],
        imdir, filename)
|
|
|