
Fit Head-direction population

Learning objectives

  • Learn how to add history-related predictors to a nemos GLM
  • Learn about nemos Basis objects
  • Learn how to use Basis objects with convolution
import jax
import matplotlib.pyplot as plt
import nemos as nmo
import numpy as np
import pynapple as nap

import workshop_utils

# Set the default precision to float64, which is generally a good idea for
# optimization purposes.
jax.config.update("jax_enable_x64", True)
# configure the plot style
plt.style.use(workshop_utils.STYLE_FILE)

Data Streaming

  • Stream the head-direction neurons data
path = workshop_utils.data.download_data("Mouse32-140822.nwb", "https://osf.io/jb2gd/download",
                                         '../data')

Pynapple

  • load_file: open the NWB file and get a preview of its contents.
data = nap.load_file(path)

data
  • Load the units
spikes = data["units"]

spikes
  • Load the epochs and take only wakefulness
epochs = data["epochs"]
wake_ep = data["epochs"]["wake"]
  • Load the angular head-direction of the animal (in radians)
angle = data["ry"]
  • Select only those units that are in ADn
spikes = spikes.getby_category("location")["adn"]
  • Restrict the activity to wakefulness (both the spiking activity and the angle)
spikes = spikes.restrict(wake_ep).getby_threshold("rate", 1.0)
angle = angle.restrict(wake_ep)
  • Compute tuning curves as a function of head-direction
tuning_curves = nap.compute_1d_tuning_curves(
    group=spikes, feature=angle, nb_bins=61, minmax=(0, 2 * np.pi)
)
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(tuning_curves.iloc[:, 0])
ax[0].set_xlabel("Angle (rad)")
ax[0].set_ylabel("Firing rate (Hz)")
ax[1].plot(tuning_curves.iloc[:, 1])
ax[1].set_xlabel("Angle (rad)")
plt.tight_layout()
  • Let's visualize the data at the population level.
fig = workshop_utils.plotting.plot_head_direction_tuning(
    tuning_curves, spikes, angle, threshold_hz=1, start=8910, end=8960
)
  • Take the first 3 minutes of wakefulness to speed up optimization
wake_ep = nap.IntervalSet(
    start=wake_ep.loc[0, "start"], end=wake_ep.loc[0, "start"] + 3 * 60
)
  • Bin the spike trains in 10 ms bins
bin_size = 0.01
count = spikes.count(bin_size, ep=wake_ep)
  • Sort the neurons by their preferred direction using pandas
pref_ang = tuning_curves.idxmax()

count = nap.TsdFrame(
    t=count.t,
    d=count.values[:, np.argsort(pref_ang.values)],
)
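  • (Optional) A minimal NumPy sketch, not part of the original workshop code, of what the argsort-based column reordering above does (with made-up preferred angles):
vals = np.array([0.3, 2.1, 1.0])   # toy preferred angles of three units, in radians
order = np.argsort(vals)           # indices that sort the angles: [0 2 1]
print(vals[order])                 # angles in increasing order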

Nemos

Self-Connected Single Neuron

  • Start with modeling a self-connected single neuron
  • Select a neuron and visualize the spike count time course
# select a neuron's spike count time series
neuron_count = count.loc[[0]]

# restrict to a smaller time interval
epoch_one_spk = nap.IntervalSet(
    start=count.time_support["start"][0], end=count.time_support["start"][0] + 1.2
)
plt.figure(figsize=(8, 3.5))
plt.step(
    neuron_count.restrict(epoch_one_spk).t, neuron_count.restrict(epoch_one_spk).d, where="post"
)
plt.title("Spike Count Time Series")
plt.xlabel("Time (sec)")
plt.ylabel("Counts")
plt.tight_layout()

Features Construction

  • Use the past counts over a fixed window to predict the current sample
# set the size of the spike history window in seconds
window_size_sec = 0.8

workshop_utils.plotting.plot_history_window(neuron_count, epoch_one_spk, window_size_sec)
  • Roll your window one bin at a time to predict the subsequent samples
workshop_utils.plotting.run_animation(neuron_count, float(epoch_one_spk.start))
  • Form a predictor matrix by vertically stacking all the windows (you can use a convolution; a small NumPy sketch of this equivalence follows the code below).
# convert the prediction window to bins (by multiplying with the sampling rate)
window_size = int(window_size_sec * neuron_count.rate)

# convolve the counts with the identity matrix.
input_feature = nmo.utils.convolve_1d_trials(
    np.eye(window_size), np.expand_dims(neuron_count.d, axis=1)
)
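  • (Optional) A pure-NumPy sketch of why convolving with the identity matrix is equivalent to stacking the past-count windows. This is an illustration only; the exact alignment conventions of convolve_1d_trials may differ:
toy_counts = np.array([0., 1., 0., 2., 0., 0., 1.])
w = 3  # toy window size in bins

# manual stacking: row t holds the w consecutive counts starting at sample t
stacked = np.stack([toy_counts[t:t + w] for t in range(len(toy_counts) - w + 1)])

# convolving with each column of the identity matrix shifts the counts by a different lag
conv = np.stack(
    [np.convolve(toy_counts, np.eye(w)[:, k], mode="valid") for k in range(w)],
    axis=1,
)

# same content, with the lag axis reversed
print(np.allclose(stacked, conv[:, ::-1]))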
  • Check the shape of the counts and features.
print(f"Time bins in counts: {neuron_count.shape[0]}")
print(f"Convolution window size in bins: {window_size}")
print(f"Feature shape: {input_feature.shape}")
  • Match the time axes.
# get rid of the last time point.
input_feature = np.squeeze(input_feature[:-1])

print(f"Feature shape: {input_feature.shape}")
print(f"Time bins in counts: {neuron_count.shape[0]}")
print(f"Convolution window size in bins: {window_size}")
  • Plot the convolution output.
suptitle = "Input feature: Count History"
neuron_id = 0
workshop_utils.plotting.plot_features(input_feature, count.rate, suptitle)
  • Convert the features back to a pynapple TsdFrame.
# convert features to TsdFrame
input_feature = nap.TsdFrame(t=neuron_count.t[window_size:], d=np.asarray(input_feature))

Fitting the model

  • Split your epochs in two for validation purposes.
# construct the train and test epochs
duration = input_feature.time_support.tot_length("s")
start = input_feature.time_support["start"]
end = input_feature.time_support["end"]
first_half = nap.IntervalSet(start, start + duration / 2)
second_half = nap.IntervalSet(start + duration / 2, end)
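  • (Optional) A quick sanity check, not in the original code, that the two epochs split the recording roughly in half:
print(f"first half:  {first_half.tot_length('s'):.1f} s")
print(f"second half: {second_half.tot_length('s'):.1f} s")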
  • Fit a GLM to the first half.
# define the GLM object
model = workshop_utils.model.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))

# Fit over the training epochs
model.fit(input_feature.restrict(first_half), neuron_count.restrict(first_half))
  • Plot the weights.
plt.figure()
plt.title("Spike History Weights")
plt.plot(np.arange(window_size) / count.rate, model.coef_, lw=2, label="GLM raw history 1st Half")
plt.axhline(0, color="k", lw=0.5)
plt.xlabel("Time From Spike (sec)")
plt.ylabel("Kernel")
plt.legend()

Inspecting the results

  • Fit on the other half and compare results.
# fit a second model on the second half of the data

model_second_half = workshop_utils.model.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))
model_second_half.fit(input_feature.restrict(second_half), neuron_count.restrict(second_half))

plt.figure()
plt.title("Spike History Weights")
plt.plot(np.arange(window_size) / count.rate, model.coef_, label="GLM raw history 1st Half", lw=2)
plt.plot(np.arange(window_size) / count.rate, model_second_half.coef_, color="orange", label="GLM raw history 2nd Half", lw=2)
plt.axhline(0, color="k", lw=0.5)
plt.xlabel("Time From Spike (sec)")
plt.ylabel("Kernel")
plt.legend()

Reducing feature dimensionality

  • Visualize the raised cosine basis.
workshop_utils.plotting.plot_basis()
  • Define the raised cosine basis through the "nemos.basis" module.
basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=8)
  • Create the basis kernel matrix (window_size, n_basis_funcs) with the "evaluate_on_grid" method.
# `basis.evaluate_on_grid` is a convenience method to view all basis functions
# across their whole domain:
time, basis_kernels = basis.evaluate_on_grid(window_size)

print(basis_kernels.shape)

# time takes equi-spaced values between 0 and 1; multiply it by the duration
# of our window to scale it to seconds.
time *= window_size_sec
  • Check that we can approximate the "decay" in the history filter with the basis. Use least-squares to choose appropriate weights.
# compute the least-squares weights
lsq_coef, _, _, _ = np.linalg.lstsq(basis_kernels, model.coef_, rcond=-1)

# plot the basis and the approximation
workshop_utils.plotting.plot_weighted_sum_basis(time, model.coef_, basis_kernels, lsq_coef)
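  • (Optional) We can also quantify how well the weighted sum of basis functions reproduces the raw filter, e.g. with a relative error. This is a sketch, assuming lsq_coef has shape (n_basis_funcs,):
approx = basis_kernels @ lsq_coef
rel_err = np.linalg.norm(model.coef_ - approx) / np.linalg.norm(model.coef_)
print(f"relative approximation error: {rel_err:.3f}")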
  • Convolve the counts with the basis functions.
conv_spk = nmo.utils.convolve_1d_trials(basis_kernels, np.expand_dims(neuron_count, 1))
conv_spk = nap.TsdFrame(t=count[window_size:].t, d=np.asarray(conv_spk[:-1, 0]))

print(f"Raw count history as feature: {input_feature.shape}")
print(f"Compressed count history as feature: {conv_spk.shape}")
  • Visualize the output.
# Visualize the convolution results
epoch_one_spk = nap.IntervalSet(8917.5, 8918.5)
epoch_multi_spk = nap.IntervalSet(8979.2, 8980.2)

workshop_utils.plotting.plot_convolved_counts(neuron_count, conv_spk, epoch_one_spk, epoch_multi_spk)

# the second interval contains multiple spikes, showing in a second row how the convolved features accumulate

Fit and compare the models

  • Fit the model using the compressed features.
# restrict the features and counts to the training interval set
model_basis = workshop_utils.model.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))
model_basis.fit(conv_spk.restrict(first_half), neuron_count.restrict(first_half))
print(model_basis.coef_)
  • Reconstruct the history filter.
self_connection = np.matmul(basis_kernels, model_basis.coef_)

print(self_connection.shape)
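  • (Optional) The basis model recovers a similar filter with far fewer parameters; a quick check of the parameter counts (not part of the original code):
print(f"raw-history parameters: {model.coef_.shape[0]}")
print(f"basis parameters:       {model_basis.coef_.shape[0]}")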
  • Compare with the raw count history model.
plt.figure()
plt.title("Spike History Weights")
plt.plot(time, model.coef_, alpha=0.3, label="GLM raw history")
plt.plot(time, self_connection, "--k", label="GLM basis", lw=2)
plt.axhline(0, color="k", lw=0.5)
plt.xlabel("Time from spike (sec)")
plt.ylabel("Weight")
plt.legend()
  • Fit the other half of the data.
  • Plot and compare the results.
model_basis_second_half = workshop_utils.model.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))
model_basis_second_half.fit(conv_spk.restrict(second_half), neuron_count.restrict(second_half))

# compute responses for the 2nd half fit
self_connection_second_half = np.matmul(basis_kernels, model_basis_second_half.coef_)

plt.figure()
plt.title("Spike History Weights")
plt.plot(time, model.coef_, "k", alpha=0.3, label="GLM raw history 1st half")
plt.plot(time, model_second_half.coef_, alpha=0.3, color="orange", label="GLM raw history 2nd half")
plt.plot(time, self_connection, "--k", lw=2, label="GLM basis 1st half")
plt.plot(time, self_connection_second_half, color="orange", lw=2, ls="--", label="GLM basis 2nd half")
plt.axhline(0, color="k", lw=0.5)
plt.xlabel("Time from spike (sec)")
plt.ylabel("Weight")
plt.legend()
  • Use the score function to evaluate the GLM predictions.
# compare model scores; as expected, the training score is better with more parameters,
# but this could be over-fitting.
print(f"full history train score: {model.score(input_feature.restrict(first_half), neuron_count.restrict(first_half), score_type='pseudo-r2-Cohen')}")
print(f"basis train score: {model_basis.score(conv_spk.restrict(first_half), neuron_count.restrict(first_half), score_type='pseudo-r2-Cohen')}")
print(f"\nfull history test score: {model.score(input_feature.restrict(second_half), neuron_count.restrict(second_half), score_type='pseudo-r2-Cohen')}")
print(f"basis test score: {model_basis.score(conv_spk.restrict(second_half), neuron_count.restrict(second_half), score_type='pseudo-r2-Cohen')}")
  • Predict the rates and plot the results.
rate_basis = nap.Tsd(t=conv_spk.t, d=np.asarray(model_basis.predict(conv_spk.d))) * conv_spk.rate
rate_history = nap.Tsd(t=conv_spk.t, d=np.asarray(model.predict(input_feature))) * conv_spk.rate
ep = nap.IntervalSet(start=8819.4, end=8821)

# plot the rates
workshop_utils.plotting.plot_rates_and_smoothed_counts(
    neuron_count,
    {"Self-connection raw history":rate_history, "Self-connection bsais": rate_basis}
)

All-to-all Connectivity

Preparing the features

  • Convolve all counts.
  • Print the output shape.
convolved_count = nmo.utils.convolve_1d_trials(basis_kernels, count.values)
convolved_count = np.asarray(convolved_count[:-1])
print(f"Convolved count shape: {convolved_count.shape}")
  • Reshape the convolved counts to define the feature matrix.
convolved_count = convolved_count.reshape(convolved_count.shape[0], -1)
print(f"Convolved count reshaped: {convolved_count.shape}")
convolved_count = nap.TsdFrame(t=neuron_count.t[window_size:], d=convolved_count)
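  • (Optional) A small NumPy sketch of how the reshape above orders the columns, assuming the convolved array is laid out as (time, neuron, basis function) as the printed shapes suggest: column j corresponds to neuron j // n_basis_funcs and basis function j % n_basis_funcs.
arr = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # toy (time, neurons, basis) array
flat = arr.reshape(arr.shape[0], -1)          # toy (time, neurons * basis) array
print(np.array_equal(flat[:, 1 * 4 + 2], arr[:, 1, 2]))  # True: column 6 is (neuron 1, basis 2)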

Fitting the Model

  • Loop over the neurons
  • Fit each neuron
  • Store the result in a list
models = []
for neu in range(count.shape[1]):
    print(f"fitting neuron {neu}...")
    count_neu = count[:, neu]
    model = workshop_utils.model.GLM(
        regularizer=nmo.regularizer.Ridge(regularizer_strength=0.1, solver_name="LBFGS")
    )
    models.append(model.fit(convolved_count, count_neu.restrict(convolved_count.time_support)))
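  • (Optional) Each fitted model should have one coefficient per (presynaptic neuron, basis function) pair; a quick check, not part of the original code:
print(models[0].coef_.shape)  # expected: (count.shape[1] * basis.n_basis_funcs,)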

Comparing model predictions

  • Predict the firing rate of each neuron and store it in an array of shape (num_sample_points - window_size, num_neurons).
  • Convert the array to a pynapple TsdFrame
predicted_firing_rate = np.zeros((count.shape[0] - window_size, count.shape[1]))
for receiver_neu in range(count.shape[1]):
    predicted_firing_rate[:, receiver_neu] = models[receiver_neu].predict(
        convolved_count
    ) * conv_spk.rate

predicted_firing_rate = nap.TsdFrame(t=count[window_size:].t, d=predicted_firing_rate)
  • Visualize the predicted rate and tuning function.
# pynapple keeps a common time axis across all plotted variables, which is used for the imshow tick labels
workshop_utils.plotting.plot_head_direction_tuning_model(tuning_curves, predicted_firing_rate, spikes, angle, threshold_hz=1,
                                                start=8910, end=8960, cmap_label="hsv")
  • Visually compare all the models.
workshop_utils.plotting.plot_rates_and_smoothed_counts(
    neuron_count,
    {"Self-connection: raw history": rate_history,
     "Self-connection: bsais": rate_basis,
     "All-to-all: basis": predicted_firing_rate[:, 0]}
)

Visualizing the connectivity

  • Compute tuning curves from the predicted rates using pynapple.
tuning = nap.compute_1d_tuning_curves_continuous(predicted_firing_rate,
                                                 feature=angle,
                                                 nb_bins=61,
                                                 minmax=(0, 2 * np.pi))

  • Extract the weights and store them in an array of shape (num_neurons, num_neurons, num_features).
weights = np.zeros((count.shape[1], count.shape[1], basis.n_basis_funcs))
for receiver_neu in range(count.shape[1]):
    weights[receiver_neu] = models[receiver_neu].coef_.reshape(
        count.shape[1], basis.n_basis_funcs
    )
  • Multiply the weights by the basis to obtain the history filters.
responses = np.einsum("ijk, tk->ijt", weights, basis_kernels)

print(responses.shape)
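  • (Optional) The einsum above is equivalent to applying the basis to each weight vector in an explicit loop; a sketch for clarity only:
responses_loop = np.zeros_like(responses)
for i in range(weights.shape[0]):
    for j in range(weights.shape[1]):
        responses_loop[i, j] = basis_kernels @ weights[i, j]
print(np.allclose(responses, responses_loop))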
  • Plot the connectivity map.
workshop_utils.plotting.plot_coupling(responses, tuning)

Exercise

# 1. What would happen if we regressed explicitly the head direction?
# 2. What would happen to the connectivity if we fit on the sleep epochs?
# 3. How would we sparsify the connectivity?
