Note
Click here to download the full example code
Fit a Population of Grid Cells
import math
import os
from typing import Optional
import jax
import matplotlib.pyplot as plt
import nemos as nmo
import numpy as np
import pynapple as nap
import nemos as nmo
from scipy.ndimage import gaussian_filter
import workshop_utils
# Turn on 64-bit (float64) support in JAX once, up front, before any JAX
# computation runs — presumably so the nemos model fits below use double precision.
jax.config.update("jax_enable_x64", True)
DATA STREAMING
- Stream the data
# Stream the NWB session of subject 11265 from DANDI dataset 000582 using the
# workshop helper; the returned handle is read and wrapped by pynapple next.
dandiset_id = "000582"
nwb_asset = "sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb"
io = workshop_utils.data.download_dandi_data(dandiset_id, nwb_asset)
PYNAPPLE
- Load the data with pynapple
# Read the NWB content and wrap it in pynapple's NWBFile so its fields can be
# looked up by name (e.g. data["units"]).
nwb_content = io.read()
data = nap.NWBFile(nwb_content)
- Print the data
# Print the pynapple NWBFile to see which fields are available before
# selecting the ones used below.
print(data)
- extract the spike times and the position of the animal from the `data` object
# Spike times of the recorded units, and the tracked position of LED 1, which
# the rest of the script uses as the animal's 2D position.
spikes, position = data["units"], data["SpatialSeriesLED1"]
- compute the head-direction of the animal from `SpatialSeriesLED1` and `SpatialSeriesLED2`
# Head direction = angle of the vector between the two tracked LEDs.
led1 = data['SpatialSeriesLED1']
diff = led1.values - data['SpatialSeriesLED2'].values
# NOTE(review): the columns of `diff` are passed to arctan2 in their stored
# order (column 0 as arctan2's first argument); confirm this matches the
# intended angle convention.
angle = np.arctan2(*diff.T)
# arctan2 returns values in [-pi, pi]; shift and wrap them into [0, 2*pi).
head_dir = np.remainder(angle + 2 * np.pi, 2 * np.pi)
# Attach LED1's timestamps and drop samples where tracking gave NaN.
head_dir = nap.Tsd(led1.index, head_dir).dropna()
- compute the head-direction and position tuning curves
# Empirical tuning of every unit.
# 2D position tuning on a 12 x 12 grid; the second output is kept as binsxy.
pos_tuning, binsxy = nap.compute_2d_tuning_curves(
    group=spikes, features=position, nb_bins=12
)
# Head-direction tuning with 61 angular bins spanning [0, 2*pi].
hd_tuning = nap.compute_1d_tuning_curves(
    group=spikes, feature=head_dir, nb_bins=61, minmax=(0, 2 * np.pi)
)
- plot the tuning curves of each neuron
# One column per unit: polar head-direction tuning on top, smoothed 2D
# position tuning underneath.
# Iterate over the TsGroup's own unit labels rather than range(len(spikes)).
# The tuning outputs are looked up by unit label, so this stays correct even
# when unit ids are not the contiguous range 0..n-1.
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(2, len(spikes))
for col, unit in enumerate(spikes.keys()):
    ax = plt.subplot(gs[0, col], projection='polar')
    ax.plot(hd_tuning.loc[:, unit])
    ax.set_title(str(unit))
    ax = plt.subplot(gs[1, col])
    # Light Gaussian smoothing so the spatial fields are easier to see.
    ax.imshow(gaussian_filter(pos_tuning[unit], sigma=1))
plt.tight_layout()
NEMOS
- bin the spike trains into 10 ms bins
# Spike counts in 10 ms bins, restricted to the epochs over which the
# position signal is defined.
bin_size = 0.01  # seconds
tracked_epochs = position.time_support
counts = spikes.count(bin_size, ep=tracked_epochs)
- interpolate the position to the timestamps of `counts` using pynapple's `interpolate` function
# Resample the position onto the timestamps of the spike-count bins so each
# row of `counts` has a matching (x, y) sample. This rebinds `position`.
position = position.interpolate(counts)
- define a 2D basis in nemos as the product of two 1D `RaisedCosineBasisLinear` bases
# 2D basis built as the product of two 1D raised-cosine bases with 10
# functions each, i.e. 10 * 10 = 100 two-dimensional basis functions.
n_basis_per_axis = 10
basis_2d = (
    nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis_per_axis)
    * nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis_per_axis)
)
- evaluate the basis on a 100x100 grid using `evaluate_on_grid`
# Sample every 2D basis function on a 100 x 100 grid: X and Y hold the grid
# coordinates, and Z stacks one surface per basis function along its last
# axis (indexed as Z[:, :, i] in the plot below).
X, Y, Z = basis_2d.evaluate_on_grid(100, 100)
- plot the evaluated basis
# Show ten distinct 2D basis functions (indices 50, 52, ..., 68), one per panel.
# Number the panels row-major (0..9) so every panel gets its own basis
# function. Indexing by k + h alone would repeat functions, because different
# (row, column) pairs share the same sum.
fig, axs = plt.subplots(2, 5, figsize=(10, 4))
for panel, ax in enumerate(axs.flat):
    idx = 50 + 2 * panel
    ax.contourf(X, Y, Z[:, :, idx], cmap='Blues')
    ax.set_title(f"basis {idx}", fontsize=8)
plt.tight_layout()
- rescale the position between 0 and 1 to match the basis functions
# Min-max rescale each coordinate into [0, 1] to match the domain of the
# basis functions.
# nanmin/nanmax ignore untracked (NaN) samples. Plain min/max would return
# NaN and turn the entire rescaled trajectory into NaN. The minimum is
# computed once and reused.
pos_min = np.nanmin(position.values, axis=0)
pos_max = np.nanmax(position.values, axis=0)
position = (position - pos_min) / (pos_max - pos_min)
- evaluate the basis for each position of the animal
# Evaluate all 2D basis functions at every (x, y) sample. Row i of the result
# holds the 100 basis activations at time bin i.
x_pos, y_pos = position['x'], position['y']
position_basis = basis_2d.evaluate(x_pos, y_pos)
print(position_basis.shape)
# Visualise the basis activation at five time points (rows 0, 200, ..., 800
# of position_basis) and mark where the animal was at those same instants.
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(2, 5)
xt = np.arange(0, 1000, 200)
cmap = plt.get_cmap("rainbow")
colors = np.linspace(0, 1, len(xt))

# Top row: each snapshot's 100 activations laid out on the 10 x 10 basis grid,
# framed in that snapshot's colour.
for col, t_idx in enumerate(xt):
    frame_color = cmap(colors[col])
    ax = plt.subplot(gs[0, col])
    ax.imshow(position_basis[t_idx].reshape(10, 10).T, origin='lower')
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_color(frame_color)
        ax.spines[side].set_linewidth(3)
    plt.title("T " + str(t_idx))

# Bottom-centre panel: the first 1000 samples of the trajectory, with the
# snapshot positions drawn in matching colours.
ax = plt.subplot(gs[1, 2])
ax.plot(position['x'][0:1000], position['y'][0:1000])
for t_idx, c in zip(xt, colors):
    ax.plot(position['x'][t_idx], position['y'][t_idx], 'o', color=cmap(c))
plt.tight_layout()
- instantiate a GLM model with Ridge regularization.
- set `regularizer_strength=1.0`
# GLM with a Ridge penalty of strength 1.0, optimised with the LBFGS solver.
ridge = nmo.regularizer.Ridge(regularizer_strength=1.0, solver_name="LBFGS")
model = workshop_utils.model.GLM(regularizer=ridge)
- fit the model only to neuron 7 for faster computation
# Fit a single unit (column 7 of the count matrix) to keep the run short.
neuron = 7
target_counts = counts[:, neuron]
model.fit(position_basis, target_counts)
- predict the rate and compute a tuning curve using `compute_2d_tuning_curves_continuous` from pynapple
# Predicted firing rate of the fitted model in every time bin.
# nemos GLM.predict returns the expected spike count per bin. Dividing by
# bin_size expresses the rate in spikes/s, the same units as the empirical
# pos_tuning it is compared with.
rate_pos = np.asarray(model.predict(position_basis)) / bin_size
rate_pos = nap.TsdFrame(t=counts.t, d=rate_pos, columns=[neuron])
# Average the continuous predicted rate within each of the 12 x 12 position bins.
model_tuning, binsxy = nap.compute_2d_tuning_curves_continuous(
    tsdframe=rate_pos,
    features=position,
    nb_bins=12,
)
- compare the tuning curves
# Side by side: empirical position tuning vs. the GLM's predicted tuning,
# each lightly smoothed.
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(1, 2)
panels = (pos_tuning[neuron], model_tuning[neuron])
for col, tuning in enumerate(panels):
    ax = plt.subplot(gs[0, col])
    ax.imshow(gaussian_filter(tuning, sigma=1))
plt.tight_layout()
- find the best `regularizer_strength` using `sklearn.model_selection.GridSearchCV`
from sklearn.model_selection import GridSearchCV
# Candidate Ridge strengths. The double-underscore key uses scikit-learn's
# nested-parameter syntax to reach the regularizer's strength.
param_grid = {"regularizer__regularizer_strength": [1e-6, 1e-3, 1.0]}
cls = GridSearchCV(model, param_grid=param_grid)
cls.fit(position_basis, counts[:, neuron])
- retrieve the best model found by the scikit-learn grid search
# GridSearchCV is constructed without `refit`, so by scikit-learn's default
# (refit=True) best_estimator_ is the best setting refitted on all the data.
best_model = cls.best_estimator_
- predict the rate of the best model and compute a 2D tuning curve
# Predicted firing rate of the cross-validated model in every time bin.
# nemos GLM.predict returns the expected spike count per bin. Dividing by
# bin_size expresses the rate in spikes/s, the same units as the empirical
# pos_tuning it is compared with.
best_rate_pos = np.asarray(best_model.predict(position_basis)) / bin_size
best_rate_pos = nap.TsdFrame(t=counts.t, d=best_rate_pos, columns=[neuron])
# Average the continuous predicted rate within each of the 12 x 12 position bins.
best_model_tuning, binsxy = nap.compute_2d_tuning_curves_continuous(
    tsdframe=best_rate_pos,
    features=position,
    nb_bins=12,
)
- compare the 2D tuning curves
# Three panels: empirical tuning, GLM with strength 1.0, and the
# cross-validated GLM, each lightly smoothed.
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(1, 3)
panels = (pos_tuning[neuron], model_tuning[neuron], best_model_tuning[neuron])
for col, tuning in enumerate(panels):
    ax = plt.subplot(gs[0, col])
    ax.imshow(gaussian_filter(tuning, sigma=1))
plt.tight_layout()
Total running time of the script: ( 0 minutes 0.000 seconds)
Download Python source code: 03_grid_cells_code.py