hw / visualize.py
violet1723's picture
Upload folder using huggingface_hub
00c2650 verified
# -*- coding: utf-8 -*-
import numpy as np
from matplotlib import pyplot as plt
import scipy
import scipy.stats
from imageio import imsave
import cv2
def concat_images(images, image_width, spacer_size):
    """Join images left-to-right with a white gap between neighbours.

    Args:
        images: Sequence of images to join. When there is more than one, each
            must have ``image_width`` rows and 4 channels so it can be
            stacked next to the uint8 spacer.
        image_width: Row count of every image; also used as the spacer height,
            so the images are presumably square.
        spacer_size: Width in pixels of the white gap.

    Returns:
        The images (and gaps) concatenated along axis 1 via ``np.hstack``.
    """
    gap = 255 * np.ones([image_width, spacer_size, 4], dtype=np.uint8)
    pieces = []
    for position, image in enumerate(images):
        # Put a gap *before* every image except the first, which is the same
        # as putting one after every image except the last.
        if position > 0:
            pieces.append(gap)
        pieces.append(image)
    return np.hstack(pieces)
def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
    """Tile images into a grid with ``row_size`` rows and white gaps.

    Each row holds ``len(images) // row_size`` consecutive images, joined
    with ``concat_images``. Rows are separated by horizontal white strips
    ``spacer_size`` pixels tall. Any leftover images beyond
    ``row_size * (len(images) // row_size)`` are not drawn.

    Args:
        images: Sequence of images, each with ``image_width`` rows and 4 channels.
        row_size: Number of grid rows (must be non-zero).
        image_width: Width/height of each image in pixels.
        spacer_size: Gap in pixels between images and between rows.

    Returns:
        The tiled grid as one array (``np.vstack`` of the rows).
    """
    per_row = len(images) // row_size
    grid_width = per_row * image_width + (per_row - 1) * spacer_size
    row_gap = 255 * np.ones([spacer_size, grid_width, 4], dtype=np.uint8)

    stacked = []
    for r in range(row_size):
        # Separator goes before every row except the first.
        if r > 0:
            stacked.append(row_gap)
        start = r * per_row
        row_strip = concat_images(images[start:start + per_row], image_width, spacer_size)
        stacked.append(row_strip)
    return np.vstack(stacked)
def convert_to_colormap(im, cmap):
    """Colorize ``im`` with ``cmap`` and quantize the result to uint8.

    ``cmap`` is any callable mapping values to colors (e.g. a matplotlib
    colormap). Its output is scaled by 255 and cast with ``np.uint8``, so
    fractional values are truncated.
    """
    return np.uint8(255 * cmap(im))
def rgb(im, cmap='jet', smooth=True):
    """Normalize a 2-D map to [0, 1], optionally blur it, and colorize it.

    Args:
        im: 2-D numeric array (e.g. one rate map).
        cmap: Colormap name or matplotlib ``Colormap`` instance.
        smooth: If True, apply a 3x3 Gaussian blur (sigmaX=1) after normalizing.

    Returns:
        The colormapped image scaled by 255 and cast to ``np.uint8``. For a
        matplotlib colormap this has a trailing RGBA channel axis.
    """
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap accepts the same names/instances on old and new releases.
    colormap = plt.get_cmap(cmap)
    # A constant map gives 0/0 -> NaN here. Silence that warning only for this
    # call, instead of changing NumPy's process-wide error state (np.seterr).
    with np.errstate(invalid='ignore', divide='ignore'):
        im = (im - np.min(im)) / (np.max(im) - np.min(im))
        if smooth:
            im = cv2.GaussianBlur(im, (3, 3), sigmaX=1, sigmaY=0)
        im = colormap(im)
        return np.uint8(im * 255)
def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
    """Render the first ``n_plots`` rate maps as one tiled uint8 image.

    Args:
        activations: Array of rate maps; its last axis length is used as the
            tile size (``compute_ratemaps`` returns shape [Ng, res, res]).
        n_plots: Maximum number of maps to draw.
        cmap: Colormap passed to ``rgb``.
        smooth: Whether ``rgb`` blurs each map.
        width: Target number of maps per row. The grid has
            ``max(1, n_maps // width)`` rows, so rows may be wider than
            ``width`` when the count is not a multiple of it.

    Returns:
        The tiled image from ``concat_images_in_rows``.

    Raises:
        ValueError: If there are no maps to plot.
    """
    images = [rgb(im, cmap, smooth) for im in activations[:n_plots]]
    if not images:
        raise ValueError("plot_ratemaps: no rate maps to plot")
    # Fewer maps than `width` used to give 0 rows and a ZeroDivisionError in
    # concat_images_in_rows; always lay out at least one row, and size the
    # grid from the maps actually produced rather than the requested count.
    n_rows = max(1, len(images) // width)
    return concat_images_in_rows(images, n_rows, activations.shape[-1])
def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching/moving torch-style tensors to CPU first."""
    if hasattr(x, 'detach'):
        x = x.detach()
    if hasattr(x, 'cpu'):
        x = x.cpu()
    if hasattr(x, 'numpy'):
        x = x.numpy()
    return np.asarray(x)


def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None, return_raw=False):
    '''Compute spatial firing fields by averaging unit activity per position bin.

    Args:
        model: The RNN model. ``model.g(inputs)`` must return unit activations
            (torch tensor or array-like) whose last axis indexes units and
            whose leading axes flatten to one row per sample.
        trajectory_generator: Generator for test trajectories.
            ``get_test_batch()`` must return ``(inputs, positions, _)``, where
            positions flatten to ``[-1, 2]`` (x, y) rows aligned with the rows
            of ``model.g(inputs)``.
        options: Training options; reads ``sequence_length``, ``batch_size``,
            ``box_width`` and ``box_height``.
        res: Resolution of the rate map grid (res x res bins).
        n_avg: Number of batches to average over. Falsy means
            ``1000 // options.sequence_length`` (at least 1).
        Ng: Maximum number of units to analyze.
        idxs: Indices of specific units to analyze. None or empty selects
            units ``0..Ng-1``. At most ``Ng`` indices are used; if fewer are
            given, outputs have ``len(idxs)`` units instead of ``Ng``.
        return_raw: If True, also return raw activations (g) and positions (pos).
            Warning: This uses significant memory for large batch_size/n_avg.
            If False, returns None for g and pos to save memory.

    Returns:
        activations: Mean activation per bin [Ng, res, res]; unvisited bins are 0.
        rate_map: Flattened rate maps [Ng, res*res]
        g: Raw activations [n_avg * batch_size * sequence_length, Ng]
            (None if return_raw=False)
        pos: Raw positions [n_avg * batch_size * sequence_length, 2]
            (None if return_raw=False)
    '''
    if not n_avg:
        n_avg = max(1, 1000 // options.sequence_length)
    # `not np.any(idxs)` wrongly treated idxs=[0] as "no indices"; only None
    # or an empty selection means "use the default range".
    if idxs is None or len(idxs) == 0:
        idxs = np.arange(Ng)
    idxs = np.asarray(idxs)[:Ng]
    # With fewer indices than Ng, reshape(-1, Ng) would raise or scramble the
    # data; size every output to the units actually selected.
    Ng = len(idxs)
    n_samples = options.batch_size * options.sequence_length

    # Only allocate large arrays if return_raw is True
    if return_raw:
        g = np.zeros([n_avg, n_samples, Ng])
        pos = np.zeros([n_avg, n_samples, 2])
    else:
        g = None
        pos = None

    # Per-bin activation sums and visit counts; bin (x, y) is stored at x * res + y.
    sums = np.zeros([res * res, Ng])
    counts = np.zeros(res * res)
    for index in range(n_avg):
        inputs, pos_batch, _ = trajectory_generator.get_test_batch()
        g_batch = _to_numpy(model.g(inputs))[..., idxs].reshape(-1, Ng)
        pos_batch = _to_numpy(pos_batch).reshape(-1, 2)
        if return_raw:
            g[index] = g_batch
            pos[index] = pos_batch

        # Map positions (box centred on the origin) to bin coordinates in [0, res).
        x_batch = (pos_batch[:, 0] + options.box_width / 2) / options.box_width * res
        y_batch = (pos_batch[:, 1] + options.box_height / 2) / options.box_height * res
        inside = (x_batch >= 0) & (x_batch < res) & (y_batch >= 0) & (y_batch < res)
        bins = x_batch[inside].astype(int) * res + y_batch[inside].astype(int)
        # np.add.at accumulates repeated bin indices; plain fancy-index += would not.
        np.add.at(counts, bins, 1)
        np.add.at(sums, bins, g_batch[inside])

    visited = counts > 0
    sums[visited] /= counts[visited, None]
    activations = sums.T.reshape(Ng, res, res)
    rate_map = activations.reshape(Ng, -1)

    if return_raw:
        g = g.reshape([-1, Ng])
        pos = pos.reshape([-1, 2])
    return activations, rate_map, g, pos
def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
    """Compute rate maps for all analyzed units and save them as one PNG.

    The tiled image is written to ``<options.save_dir>/<options.run_ID>/<step>.png``.
    The run directory is created if it does not exist yet.

    Args:
        model: The RNN model (passed to ``compute_ratemaps``).
        trajectory_generator: Test-trajectory generator (passed to ``compute_ratemaps``).
        options: Training options; reads ``sequence_length``, ``save_dir``
            and ``run_ID`` here, plus whatever ``compute_ratemaps`` reads.
        step: Training step, used as the image file name.
        res: Resolution of the rate map grid.
        n_avg: Number of batches to average over. Falsy means
            ``1000 // options.sequence_length``.
    """
    import os

    if not n_avg:
        n_avg = 1000 // options.sequence_length
    activations, _, _, _ = compute_ratemaps(model, trajectory_generator,
                                            options, res=res, n_avg=n_avg)
    rm_fig = plot_ratemaps(activations, n_plots=len(activations))
    imdir = os.path.join(options.save_dir, options.run_ID)
    # imsave does not create missing parent directories.
    os.makedirs(imdir, exist_ok=True)
    imsave(os.path.join(imdir, str(step) + ".png"), rm_fig)
def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
    """Collect positions/activations over 100 batches, then score and plot them.

    NOTE(review): this function uses ``scores`` and ``utils``, which are not
    imported anywhere in this file as visible here, so calling it raises
    NameError unless they are injected elsewhere — confirm. It also relies on
    a session-style ``sess.run(fetches, feed_dict=...)`` API (presumably
    TensorFlow 1.x).

    Args:
        sess: Session object whose ``run`` evaluates the fetches dict.
        model: Object exposing ``target_pos`` and ``g`` fetchables.
        save_name: Prefix directory for the output PDF name.
        trajectory_generator: Object with ``feed_dict(box_width, box_height)``.
        step: Training step, embedded in the output file name.
        flags: Options; reads ``box_width``, ``box_height`` and ``save_dir``.

    Returns:
        Whatever ``utils.get_scores_and_plot`` returns (previously discarded).
    """
    starts = [0.2] * 10
    ends = np.linspace(0.4, 1.0, num=10)
    coord_range = ((-1.1, 1.1), (-1.1, 1.1))
    # Materialize the (start, end) pairs: a bare zip iterator can only be consumed once.
    masks_parameters = list(zip(starts, ends.tolist()))
    latest_epoch_scorer = scores.GridScorer(20, coord_range, masks_parameters)

    results = dict()
    n_batches = 100
    for _ in range(n_batches):
        feed_dict = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
        mb_res = sess.run({
            'pos_xy': model.target_pos,
            'bottleneck': model.g,
        }, feed_dict=feed_dict)
        # Accumulate this minibatch into the running results dict.
        results = utils.concat_dict(results, mb_res)

    filename = save_name + '/autocorrs_' + str(step) + '.pdf'
    imdir = flags.save_dir + '/'
    return utils.get_scores_and_plot(
        latest_epoch_scorer, results['pos_xy'], results['bottleneck'],
        imdir, filename)