Skip to content

Note

Click here to download the full example code

Fit a population of grid cells

import math
import os
from typing import Optional

import jax
import matplotlib.pyplot as plt
import nemos as nmo
import numpy as np
import pynapple as nap
import nemos as nmo
from scipy.ndimage import gaussian_filter

import workshop_utils

# Switch JAX to 64-bit (double) precision; done right after the imports so that
# every array created later in the script uses it.
jax.config.update("jax_enable_x64", True)

DATA STREAMING

  • Stream the data
# Stream the NWB session file of dandiset 000582 (subject 11265).
io = workshop_utils.data.download_dandi_data(
    "000582",
    "sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb",
)

PYNAPPLE

  • Load the data with pynapple
# Read the streamed NWB file, then wrap it in pynapple's NWBFile interface.
nwb_contents = io.read()
data = nap.NWBFile(nwb_contents)
  • Print the data
print(data)  # Summary of the objects available in the loaded NWB file
  • extract the spike times and the position of the animal from the data object
# Spike trains of all recorded units, and the tracked 2D position (x, y) of LED1.
spikes = data["units"]  # Get spike timings
position = data["SpatialSeriesLED1"]  # Get the tracked position of the animal (LED1)
  • compute the head-direction of the animal from SpatialSeriesLED1 and SpatialSeriesLED2
# Head direction = angle of the vector from LED2 to LED1, wrapped into [0, 2*pi).
led1 = data['SpatialSeriesLED1']
led2 = data['SpatialSeriesLED2']
led_vector = led1.values - led2.values
angle = np.mod(np.arctan2(*led_vector.T) + 2 * np.pi, 2 * np.pi)
# Re-attach LED1's timestamps and drop samples where tracking failed (NaN).
head_dir = nap.Tsd(led1.index, angle).dropna()
  • compute the head-direction and position tuning curves
# Head-direction tuning: 61 angular bins covering [0, 2*pi].
hd_tuning = nap.compute_1d_tuning_curves(
    group=spikes,
    feature=head_dir,
    nb_bins=61,
    minmax=(0, 2 * np.pi),
)

# Position tuning on a 12 x 12 spatial grid; binsxy holds the spatial bins
# returned by pynapple.
pos_tuning, binsxy = nap.compute_2d_tuning_curves(
    group=spikes,
    features=position,
    nb_bins=12,
)
  • plot the tuning curves for each neuron
# One column per unit: polar head-direction tuning on top, smoothed position map below.
fig = plt.figure(figsize=(12, 4))
n_units = len(spikes)
gs = plt.GridSpec(2, n_units)
for unit in range(n_units):
    ax = fig.add_subplot(gs[0, unit], projection='polar')
    ax.plot(hd_tuning.loc[:, unit])

    ax = fig.add_subplot(gs[1, unit])
    ax.imshow(gaussian_filter(pos_tuning[unit], sigma=1))
plt.tight_layout()

NEMOS

  • bin the spike trains using a 10 ms bin size
# Spike counts in 10 ms bins, restricted to the position recording's time support.
bin_size = 1e-2  # seconds
counts = spikes.count(bin_size, ep=position.time_support)
  • interpolate the position to the timestamps of counts using pynapple's interpolate method
# Resample the position onto the timestamps of `counts`, so that every count bin
# has a matching (x, y) sample for the GLM design matrix.
position = position.interpolate(counts)
  • define a 2D basis as the product of two nemos RaisedCosineBasisLinear bases
# 2D basis built as the product of two 1D raised-cosine bases (10 functions each).
basis_2d = (
    nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=10)
    * nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=10)
)
  • evaluate the basis on a 100x100 grid using evaluate_on_grid
# Evaluate the 2D basis on a 100 x 100 grid. X and Y are the grid coordinates
# (used as contourf axes below); Z holds the basis values, indexed as
# Z[:, :, k] for basis function k.
X, Y, Z = basis_2d.evaluate_on_grid(100, 100)
  • plot the evaluated basis
# Show 10 distinct 2D basis functions evaluated on the grid.
# Each (row, col) panel maps to a unique basis index 50, 52, ..., 68 (all below
# the 10 * 10 = 100 product-basis functions). The earlier formula
# 50 + 2 * (k + h) repeated the same function along anti-diagonals and only
# displayed 6 distinct functions across the 10 panels.
n_rows, n_cols = 2, 5
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 4))
for k in range(n_rows):
    for h in range(n_cols):
        basis_idx = 50 + 2 * (n_cols * k + h)
        axs[k, h].contourf(X, Y, Z[:, :, basis_idx], cmap='Blues')

plt.tight_layout()
  • rescale the position between 0 and 1 to match the basis functions
# Min-max normalise each column so the coordinates span [0, 1], matching the
# input range of the basis functions.
pos_min = np.min(position, 0)
pos_max = np.max(position, 0)
position = (position - pos_min) / (pos_max - pos_min)
  • evaluate the basis for each position of the animal
# Design matrix: the 2D basis evaluated at every (x, y) sample of the animal.
position_basis = basis_2d.evaluate(position['x'], position['y'])
print(position_basis.shape)

# Snapshot parameters. Every layout quantity below is derived from these
# two values, so changing them keeps the plot consistent.
step = 200        # samples between displayed snapshots
n_snapshots = 5   # number of snapshots (= columns in the top row)
xt = np.arange(0, n_snapshots * step, step)
cmap = plt.get_cmap("rainbow")
colors = np.linspace(0, 1, len(xt))

fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(2, len(xt))

# Top row: basis activation at each snapshot, reshaped to the 10 x 10 basis
# layout and framed in that snapshot's colour.
for cnt, t in enumerate(xt):
    ax = plt.subplot(gs[0, cnt])
    ax.imshow(position_basis[t].reshape(10, 10).T, origin='lower')
    for spine in ["top", "bottom", "left", "right"]:
        ax.spines[spine].set_color(cmap(colors[cnt]))
        ax.spines[spine].set_linewidth(3)
    ax.set_title(f"T {t}")

# Bottom row (centre): trajectory over the displayed window, with the snapshot
# positions marked in matching colours.
ax = plt.subplot(gs[1, len(xt) // 2])
ax.plot(position['x'][:n_snapshots * step], position['y'][:n_snapshots * step])
for cnt, t in enumerate(xt):
    ax.plot(position['x'][t], position['y'][t], 'o', color=cmap(colors[cnt]))

plt.tight_layout()
  • instantiate a GLM model with Ridge regularization.
  • set regularizer_strength=1.0
# Ridge-penalised GLM (strength 1.0), solved with L-BFGS.
ridge = nmo.regularizer.Ridge(regularizer_strength=1.0, solver_name="LBFGS")
model = workshop_utils.model.GLM(regularizer=ridge)
  • fit the model only to neuron 7 for faster computation
# Fit a single unit (positional column 7 of `counts`) to keep the fit fast.
neuron = 7
neuron_counts = counts[:, neuron]
model.fit(position_basis, neuron_counts)
  • predict the rate and compute the tuning curves using compute_2d_tuning_curves_continuous from pynapple
# Predicted rate wrapped as a pynapple frame on the count timestamps, then
# binned over position into a 12 x 12 model tuning map.
predicted_rate = model.predict(position_basis)
rate_pos = nap.TsdFrame(
    t=counts.t,
    d=np.asarray(predicted_rate),
    columns=[neuron],
)
model_tuning, binsxy = nap.compute_2d_tuning_curves_continuous(
    tsdframe=rate_pos,
    features=position,
    nb_bins=12,
)
  • compare the tuning curves
# Side by side: spike-based position tuning vs. GLM-predicted tuning (both smoothed).
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(1, 2)
panels = [pos_tuning[neuron], model_tuning[neuron]]
for col, tuning_map in enumerate(panels):
    ax = fig.add_subplot(gs[0, col])
    ax.imshow(gaussian_filter(tuning_map, sigma=1))
plt.tight_layout()
  • find the best regularizer_strength using sklearn.model_selection.GridSearchCV
from sklearn.model_selection import GridSearchCV

# Cross-validated search over the ridge strength (nested parameter of the regularizer),
# using GridSearchCV's default splitter.
param_grid = {"regularizer__regularizer_strength": [1e-6, 1e-3, 1.0]}
cls = GridSearchCV(model, param_grid=param_grid)
cls.fit(position_basis, counts[:, neuron])
  • retrieve the best model found by the scikit-learn grid search
# Model refit with the best-scoring regularizer_strength (GridSearchCV refits by default).
best_model = cls.best_estimator_
  • predict the rate of the best model and compute its 2D tuning curves
# Same pipeline as for the initial model: predicted rate -> 12 x 12 position tuning map.
best_predicted_rate = best_model.predict(position_basis)
best_rate_pos = nap.TsdFrame(
    t=counts.t,
    d=np.asarray(best_predicted_rate),
    columns=[neuron],
)

best_model_tuning, binsxy = nap.compute_2d_tuning_curves_continuous(
    tsdframe=best_rate_pos,
    features=position,
    nb_bins=12,
)
  • compare the 2d tuning curves
# Three smoothed maps: spike-based tuning, ridge-1.0 model, grid-search best model.
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(1, 3)
panels = [pos_tuning[neuron], model_tuning[neuron], best_model_tuning[neuron]]
for col, tuning_map in enumerate(panels):
    ax = fig.add_subplot(gs[0, col])
    ax.imshow(gaussian_filter(tuning_map, sigma=1))
plt.tight_layout()

Total running time of the script: ( 0 minutes 0.000 seconds)

Download Python source code: 03_grid_cells_code.py

Download Jupyter notebook: 03_grid_cells_code.ipynb

Gallery generated by mkdocs-gallery