# -*- coding: utf-8 -*- import numpy as np from matplotlib import pyplot as plt import scipy import scipy.stats from imageio import imsave import cv2 def concat_images(images, image_width, spacer_size): """ Concat image horizontally with spacer """ spacer = np.ones([image_width, spacer_size, 4], dtype=np.uint8) * 255 images_with_spacers = [] image_size = len(images) for i in range(image_size): images_with_spacers.append(images[i]) if i != image_size - 1: # Add spacer images_with_spacers.append(spacer) ret = np.hstack(images_with_spacers) return ret def concat_images_in_rows(images, row_size, image_width, spacer_size=4): """ Concat images in rows """ column_size = len(images) // row_size spacer_h = np.ones([spacer_size, image_width*column_size + (column_size-1)*spacer_size, 4], dtype=np.uint8) * 255 row_images_with_spacers = [] for row in range(row_size): row_images = images[column_size*row:column_size*row+column_size] row_concated_images = concat_images(row_images, image_width, spacer_size) row_images_with_spacers.append(row_concated_images) if row != row_size-1: row_images_with_spacers.append(spacer_h) ret = np.vstack(row_images_with_spacers) return ret def convert_to_colormap(im, cmap): im = cmap(im) im = np.uint8(im * 255) return im def rgb(im, cmap='jet', smooth=True): cmap = plt.cm.get_cmap(cmap) np.seterr(invalid='ignore') # ignore divide by zero err im = (im - np.min(im)) / (np.max(im) - np.min(im)) if smooth: im = cv2.GaussianBlur(im, (3,3), sigmaX=1, sigmaY=0) im = cmap(im) im = np.uint8(im * 255) return im def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16): images = [rgb(im, cmap, smooth) for im in activations[:n_plots]] rm_fig = concat_images_in_rows(images, n_plots//width, activations.shape[-1]) return rm_fig def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None, return_raw=False): '''Compute spatial firing fields Args: model: The RNN model trajectory_generator: Generator for test trajectories options: Training options res: Resolution of the rate map grid n_avg: Number of batches to average over Ng: Number of grid cells to analyze idxs: Indices of specific grid cells to analyze return_raw: If True, also return raw activations (g) and positions (pos). Warning: This uses significant memory for large batch_size/n_avg. If False, returns None for g and pos to save memory. Returns: activations: Spatial firing fields [Ng, res, res] rate_map: Flattened rate maps [Ng, res*res] g: Raw activations (None if return_raw=False) pos: Raw positions (None if return_raw=False) ''' if not n_avg: n_avg = 1000 // options.sequence_length if not np.any(idxs): idxs = np.arange(Ng) idxs = idxs[:Ng] # Only allocate large arrays if return_raw is True if return_raw: g = np.zeros([n_avg, options.batch_size * options.sequence_length, Ng]) pos = np.zeros([n_avg, options.batch_size * options.sequence_length, 2]) else: g = None pos = None activations = np.zeros([Ng, res, res]) counts = np.zeros([res, res]) for index in range(n_avg): inputs, pos_batch, _ = trajectory_generator.get_test_batch() g_batch = model.g(inputs).detach().cpu().numpy() pos_batch = np.reshape(pos_batch.cpu(), [-1, 2]) g_batch = g_batch[:,:,idxs].reshape(-1, Ng) if return_raw: g[index] = g_batch pos[index] = pos_batch x_batch = (pos_batch[:,0] + options.box_width/2) / (options.box_width) * res y_batch = (pos_batch[:,1] + options.box_height/2) / (options.box_height) * res for i in range(options.batch_size*options.sequence_length): x = x_batch[i] y = y_batch[i] if x >=0 and x < res and y >=0 and y < res: counts[int(x), int(y)] += 1 activations[:, int(x), int(y)] += g_batch[i, :] for x in range(res): for y in range(res): if counts[x, y] > 0: activations[:, x, y] /= counts[x, y] if return_raw: g = g.reshape([-1, Ng]) pos = pos.reshape([-1, 2]) # # scipy binned_statistic_2d is slightly slower # activations = scipy.stats.binned_statistic_2d(pos[:,0], pos[:,1], g.T, bins=res)[0] rate_map = activations.reshape(Ng, -1) return activations, rate_map, g, pos def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None): if not n_avg: n_avg = 1000 // options.sequence_length activations, rate_map, g, pos = compute_ratemaps(model, trajectory_generator, options, res=res, n_avg=n_avg) rm_fig = plot_ratemaps(activations, n_plots=len(activations)) imdir = options.save_dir + "/" + options.run_ID imsave(imdir + "/" + str(step) + ".png", rm_fig) def save_autocorr(sess, model, save_name, trajectory_generator, step, flags): starts = [0.2] * 10 ends = np.linspace(0.4, 1.0, num=10) coord_range=((-1.1, 1.1), (-1.1, 1.1)) masks_parameters = zip(starts, ends.tolist()) latest_epoch_scorer = scores.GridScorer(20, coord_range, masks_parameters) res = dict() index_size = 100 for _ in range(index_size): feed_dict = trajectory_generator.feed_dict(flags.box_width, flags.box_height) mb_res = sess.run({ 'pos_xy': model.target_pos, 'bottleneck': model.g, }, feed_dict=feed_dict) res = utils.concat_dict(res, mb_res) filename = save_name + '/autocorrs_' + str(step) + '.pdf' imdir = flags.save_dir + '/' out = utils.get_scores_and_plot( latest_epoch_scorer, res['pos_xy'], res['bottleneck'], imdir, filename)