

Fit V1 cell

Learning objectives

  • Learn how to combine the GLM with other modeling approaches.
  • Review previous tutorials.
import jax
import math
import os
import matplotlib.pyplot as plt
import nemos as nmo
import numpy as np
import pynapple as nap
import requests
import tqdm
import workshop_utils

# required for second order methods (BFGS, Newton-CG)
jax.config.update("jax_enable_x64", True)
# configure plot style
plt.style.use(workshop_utils.STYLE_FILE)

Data Streaming

Here we load the data from OSF. This data comes from Sonica Saraf, in Tony Movshon's lab.

path = workshop_utils.data.download_data("m691l1.nwb", "https://osf.io/xesdm/download",
                                         '../data')

Pynapple

The data have been copied to your local machine. We're going to open the NWB file with pynapple:

data = nap.load_file(path)

What does it look like?

print(data)

Out:

m691l1
┍━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys        Type        │
┝━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units       TsGroup     │
│ epochs      IntervalSet │
│ whitenoise  TsdTensor   │
┕━━━━━━━━━━━━┷━━━━━━━━━━━━━┙

Let's extract the data.

epochs = data["epochs"]
spikes = data["units"]
stimulus = data["whitenoise"]

The stimulus is white noise shown at 40 Hz. White noise is a good stimulus for mapping the basic stimulus properties of V1 simple cells.

fig, ax = plt.subplots(1, 1, figsize=(12,4))
ax.imshow(stimulus[0], cmap='Greys_r')
stimulus.shape


Out:

(96001, 51, 51)

There are 73 neurons recorded together in V1. To fit the GLM faster, we will focus on one neuron.

print(spikes)
spikes = spikes[[34]]

Out:

  Index    rate  location      group
-------  ------  ----------  -------
      1    0.34  v1                0
     11    0.73  v1                0
     19    0.58  v1                0
     20    5.97  v1                0
     23    2.86  v1                0
     26    3.67  v1                0
     30    1.48  v1                0
     33    1.02  v1                0
     34    8.56  v1                0
     36    0.46  v1                0
     38    0.25  v1                0
     40   21.81  v1                0
     41    2.13  v1                0
     50    0.29  v1                0
     54   12.08  v1                0
     56    1.1   v1                0
     60    1.47  v1                0
     64    0     v1                0
     69    9.09  v1                0
     72   13.87  v1                0
     75    1.08  v1                0
     76    0.97  v1                0
     81    0.9   v1                0
     82    1.37  v1                0
     86    3.21  v1                0
     88    0.01  v1                0
     90    2.85  v1                0
     97    0.6   v1                0
     98    0.75  v1                0
    109    1.63  v1                0
    110    0.01  v1                0
    112    9.67  v1                0
    116    3.47  v1                0
    121    1.61  v1                0
    126    3.05  v1                0
    131    0.9   v1                0
    137    5.21  v1                0
    141   10.43  v1                0
    146    2.46  v1                0
    151    8.92  v1                0
    154    5.9   v1                0
    159    4.84  v1                0
    160    1.11  v1                0
    169    4.65  v1                0
    171    0.37  v1                0
    175    2.56  v1                0
    176    1.64  v1                0
    179    6.01  v1                0
    180    1.26  v1                0
    185    6.37  v1                0
    187   27.05  v1                0
    188    1.07  v1                0
    192    1.36  v1                0
    197    0.99  v1                0
    202    2.57  v1                0
    205    1.74  v1                0
    215    0.03  v1                0
    219    0.82  v1                0
    222    0.08  v1                0
    224    8.35  v1                0
    231    0.33  v1                0
    233    3.88  v1                0
    235    5.29  v1                0
    238    3.59  v1                0
    245    1.93  v1                0
    249    0.01  v1                0
    251    1.42  v1                0
    255    6.52  v1                0
    257    5.83  v1                0
    261    1.13  v1                0
    262    1.75  v1                0
    266    1.79  v1                0
    269   11.61  v1                0
Our goal is to predict the neuron's response to this white noise stimulus, and there are several ways we could do this.

How could we predict the neuron's response to the white noise stimulus?

  • We could fit the instantaneous spatial response, that is, predict the neuron's response to a given frame of white noise. This gives an x-by-y filter, and implicitly assumes there is no temporal information: only the frame we've just seen matters.

  • We could fit a spatiotemporal filter: instead of an x-by-y filter applied independently to each frame, fit an (x, y, t) filter over, say, 100 ms, fitting each of these weights independently (as in the head direction example).

  • That's a lot of parameters! We can simplify by assuming that the response is separable: fit a single (x, y) filter and then modulate it over time. This wouldn't capture, e.g., direction selectivity, because it assumes that the phase preference is constant over time.

  • We could make use of our knowledge of V1 and try to fit a more complex functional form, e.g., a Gabor.
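To make the parameter savings of the separable option concrete, here is a rough count using this tutorial's dimensions (51 x 51 frames and 8 temporal basis functions). It assumes, for illustration only, that the full spatiotemporal filter would use the same 8 basis functions at every pixel; the filters are random placeholders, not fitted values. The rank check shows why separability rules out direction selectivity: every pixel shares one time course, up to a scale factor.

```python
import numpy as np

# Dimensions from this tutorial: 51 x 51 pixel frames, 8 temporal basis functions.
# Assumption: the full filter uses the same temporal basis at every pixel.
n_h, n_w, n_basis = 51, 51, 8

# Option 2 (full spatiotemporal): one weight per basis function per pixel
n_params_full = n_h * n_w * n_basis

# Option 3 (separable): one spatial map plus one temporal time course
n_params_separable = n_h * n_w + n_basis

# Random placeholder filters, purely illustrative
rng = np.random.default_rng(0)
spatial = rng.normal(size=(n_h, n_w))
temporal = rng.normal(size=n_basis)

# A separable filter is the outer product: filt[b, y, x] = temporal[b] * spatial[y, x]
separable_filter = np.einsum("b,hw->bhw", temporal, spatial)

# Reshaped to a (time, pixels) matrix, a separable filter has rank 1
rank = np.linalg.matrix_rank(separable_filter.reshape(n_basis, -1))

print(n_params_full, n_params_separable, rank)  # -> 20808 2609 1
```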

That last option is very nonlinear, and fitting it is thus a non-convex problem. We'll do the third one.

In this example, we'll fit the spatial filter outside of the GLM framework, using the spike-triggered average, and then use the GLM to fit the temporal time course.

Spike-triggered average

The spike-triggered average (STA) works like this: every time our neuron spikes, we store the stimulus that was on the screen. Over the whole recording we collect many of these, which we then average to get the STA, the "optimal stimulus" or spatial filter.
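A minimal NumPy sketch of this idea, on synthetic data: count the spikes in each stimulus frame and average the frames, weighting each by its spike count. The `lags` argument anticipates the time window discussed next (lag 0 is exactly the description above). The helper `spike_triggered_average` and the synthetic neuron (a hidden filter acting with a 2-frame delay) are hypothetical illustrations, not pynapple's implementation of `compute_event_trigger_average`.

```python
import numpy as np

def spike_triggered_average(counts, stimulus, lags):
    """Spike-count-weighted average of the stimulus at each lag.

    counts:   (n_frames,) spike count during each stimulus frame
    stimulus: (n_frames, h, w) stimulus frames
    lags:     non-negative ints; lag k averages the frame shown k frames
              before each spike (lag 0 = the frame on screen)
    """
    stas = []
    for k in lags:
        # pair the count in frame t with the stimulus in frame t - k
        c = counts[k:]
        s = stimulus[: len(stimulus) - k]
        stas.append(np.tensordot(c, s, axes=(0, 0)) / c.sum())
    return np.stack(stas)  # (n_lags, h, w)

# Synthetic neuron: responds to a hidden filter applied 2 frames in the past
rng = np.random.default_rng(0)
n_frames, h, w = 20_000, 8, 8
stim = rng.choice([-1.0, 1.0], size=(n_frames, h, w))
true_rf = np.zeros((h, w))
true_rf[2:6, 3] = 1.0
true_rf[2:6, 4] = -1.0
delay = 2
drive = np.zeros(n_frames)
drive[delay:] = np.einsum("thw,hw->t", stim[:-delay], true_rf)
counts = rng.poisson(np.exp(0.5 * drive - 1.0))

sta = spike_triggered_average(counts, stim, lags=range(4))
corr = [np.corrcoef(s.ravel(), true_rf.ravel())[0, 1] for s in sta]
print(np.argmax(corr))  # the lag whose STA best matches the hidden filter
```

Only the STA at the true delay recovers the filter; at other lags the stimulus is independent of the spikes and the average is noise, which is why choosing the window matters.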

In practice, we don't use only the stimulus on screen at the moment of the spike, but the stimuli in a window of time preceding it (it takes some time for information to travel through the eye and LGN to V1). Pynapple makes this easy:

# compute the spike-triggered average to visualize the receptive field
sta = nap.compute_event_trigger_average(spikes, stimulus, binsize=0.025,
                                        windowsize=(-0.15, 0.0))

sta is a TsdTensor, which gives us the 2d receptive field at each of the time points.

sta

Out:

Time (s)
----------  -------------------------------------------------------------
-0.15       [[[0.009472777724619654 ... 0.00899435460721462] ...] ...]
-0.125      [[[0.01100373170031576 ... 0.0012439001052530858] ...] ...]
-0.1        [[[-0.0033968041335757345 ... 0.004449334991866807] ...] ...]
-0.075      [[[-0.00449717730360731 ... 0.005166969667974357] ...] ...]
-0.05       [[[0.00885082767199311 ... -0.005549708161898383] ...] ...]
-0.025      [[[-0.0011482154817720792 ... 0.009807673906803177] ...] ...]
0           [[[0.0007654769878480528 ... 0.0018180078461391255] ...] ...]
dtype: float64, shape: (7, 1, 51, 51)

We index into this by time bin, then by neuron (here we only have one neuron), which gives the 2D receptive field at that time lag:

sta[1, 0]

Out:

array([[ 0.01100373, -0.00052627,  0.00186585, ..., -0.00459286,
        -0.01066884,  0.0012439 ],
       [ 0.00138743,  0.00999904, -0.00478423, ..., -0.00019137,
        -0.00162664,  0.01636207],
       [ 0.0065544 ,  0.00200938, -0.01114726, ..., -0.0046407 ,
        -0.0083724 ,  0.00516697],
       ...,
       [ 0.0003349 ,  0.00291838, -0.00688929, ..., -0.00755909,
        -0.00956846, -0.01789302],
       [-0.01908908,  0.00301407,  0.00478423, ..., -0.00066979,
        -0.00483207,  0.00138743],
       [-0.00172232, -0.00794182, -0.00492776, ...,  0.00315759,
         0.00990336, -0.0012439 ]])

We can easily plot this:

# visualize the spike-triggered average and decide on our spatial filter
fig, axes = plt.subplots(1, len(sta), figsize=(3*len(sta),3))
for i, t in enumerate(sta.t):
    axes[i].imshow(sta[i,0], vmin = np.min(sta), vmax = np.max(sta),
                   cmap='Greys_r')
    axes[i].set_title(str(t)+" s")

[Figure: spike-triggered average at each time lag; panels titled -0.15 s, -0.125 s, -0.1 s, -0.075 s, -0.05 s, -0.025 s, 0.0 s]

That looks pretty reasonable for a V1 simple cell: localized in space, orientation, and spatial frequency. That is, it looks Gabor-ish.

To convert this into the spatial filter we'll use for the GLM, let's average across the bins that look informative: -0.125 s to -0.05 s.

receptive_field = np.mean(sta.get(-0.125, -0.05), axis=0)[0]

fig, ax = plt.subplots(1, 1, figsize=(4,4))
ax.imshow(receptive_field, cmap='Greys_r')


Out:

<matplotlib.image.AxesImage object at 0x7fab568a6bf0>

This receptive field gives us the spatial part of the linear response: it gives a map of weights that we use for a weighted sum on an image. There are multiple ways of performing this operation:

# use the spike-triggered average to preprocess our visual input
# element-wise multiplication and sum
print((receptive_field * stimulus[0]).sum())
# dot product of flattened versions
print(np.dot(receptive_field.flatten(), stimulus[0].flatten()))

Out:

-0.1176203234140274
-0.11762032341402737

When performing this operation on many frames at once, things become slightly more complicated. Looping over the frames with either of the methods above would work, but would be slow. Reshaping and taking a dot product is one common approach, as are functions like np.tensordot.

We'll use einsum to do this, which is a convenient way of representing many different matrix operations:

filtered_stimulus = np.einsum('t h w, h w -> t', stimulus, receptive_field)
# add the extra dimension for feature
filtered_stimulus = np.expand_dims(filtered_stimulus, 1)

This notation says: take arrays with dimensions (t, h, w) and (h, w), multiply them element-wise, and sum over h and w to get an array of shape (t,). This performs the same operation as above, for every frame at once.
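To check that these approaches agree, here is a small sketch on random stand-in arrays (100 frames rather than the full ~96k, and random values rather than the real stimulus and receptive field). It compares an explicit loop, a reshape plus matrix-vector product, `np.tensordot`, and `np.einsum`.

```python
import numpy as np

rng = np.random.default_rng(1)
# Small random stand-ins for the stimulus movie and the receptive field
stimulus = rng.normal(size=(100, 51, 51))
receptive_field = rng.normal(size=(51, 51))

# 1. Explicit loop over frames: clear, but slow for ~96k frames
loop = np.array([(receptive_field * frame).sum() for frame in stimulus])

# 2. Flatten each frame and take a matrix-vector product
flat = stimulus.reshape(len(stimulus), -1) @ receptive_field.ravel()

# 3. tensordot: contract stimulus axes (1, 2) with filter axes (0, 1)
tdot = np.tensordot(stimulus, receptive_field, axes=([1, 2], [0, 1]))

# 4. einsum: same contraction as 't h w, h w -> t' in the tutorial;
#    h and w are absent from the output, so they are summed over
esum = np.einsum("thw,hw->t", stimulus, receptive_field)

print(np.allclose(loop, flat), np.allclose(loop, tdot), np.allclose(loop, esum))
```

The einsum string reads almost like the math, which is why it's convenient here; the reshape version is often just as fast.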

And this remains a pynapple object, so we can easily visualize it!

fig, ax = plt.subplots(1, 1, figsize=(12,4))
ax.plot(filtered_stimulus)


Out:

[<matplotlib.lines.Line2D object at 0x7fab568d29e0>]

But what is this? It's how much each frame in the video should drive our neuron, based on the receptive field we fit using the spike-triggered average.

This, then, is the spatial component of our input, as described above.

Preparing data for nemos

We'll now use the GLM to fit the temporal component. To do that, let's get this and our spike counts into the proper format for nemos:

# get counts and filtered_stimulus into the proper shape for nemos
# grab spikes from when we were showing our stimulus, and bin at 1 msec
# resolution
bin_size = .001
counts = spikes.restrict(filtered_stimulus.time_support).count(bin_size)
print(counts.rate)
print(filtered_stimulus.rate)

Out:

1000.0001425044869
39.973157342501494

Hold on: our stimulus is sampled at a much lower rate (about 40 Hz) than the 1 kHz at which we binned our spike counts. In previous examples, our input had a higher rate than our spikes, so we used bin_average to down-sample it to the appropriate rate. When the input has a lower rate, we need to think a little more carefully about how to up-sample.

print(counts[:5])
print(filtered_stimulus[:5])

Out:

  Time (s)    0
----------  ---
    0.0005    0
    0.0015    0
    0.0025    0
    0.0035    0
    0.0045    0
dtype: int64, shape: (5, 1)
  Time (s)           0
----------  ----------
 0          -0.11762
 0.025017    0.224512
 0.0500341   0.0305712
 0.0750511   0.297902
 0.100068   -0.0934241
dtype: float64, shape: (5, 1)

What was the visual input to the neuron at time 0.0005? It was the same input as at time 0. At time 0.0015? Same thing, up until we pass time 0.025017. Thus, we want to "fill forward" the values of our input, and the workshop_utils module provides a convenience function to do so:

filtered_stimulus = workshop_utils.data.fill_forward(counts, filtered_stimulus)
filtered_stimulus

Out:

Time (s)                 0
--------------  ----------
0.0005          -0.11762
0.0015          -0.11762
0.0025          -0.11762
0.0035          -0.11762
0.0045          -0.11762
...
2401.632500031  -0.0683786
2401.633500031  -0.0683786
2401.634500031  -0.0683786
2401.635500031  -0.0683786
2401.636500031  -0.0683786
dtype: float64, shape: (2401637, 1)

We can see that the time points are now aligned, and we've filled forward the values the way we'd like.
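For intuition about what filling forward does, here is a hypothetical NumPy reimplementation (not the actual code of `workshop_utils.data.fill_forward`). It assumes the stimulus timestamps are sorted and uses `np.searchsorted` to find, for each 1 ms bin, the most recent frame shown at or before that bin. The timestamps and values are toy numbers loosely modeled on the output above.

```python
import numpy as np

def fill_forward_sketch(target_t, source_t, source_values):
    """For each target time, take the most recent source value at or before it.

    Assumes source_t is sorted. Target times before the first source time get NaN.
    """
    # index of the last source sample whose time is <= each target time
    idx = np.searchsorted(source_t, target_t, side="right") - 1
    out = np.full(len(target_t), np.nan)
    valid = idx >= 0
    out[valid] = source_values[idx[valid]]
    return out

# Toy example: stimulus frames every 25 ms, spike-count bins centered every 1 ms
frame_t = np.array([0.0, 0.025, 0.05])
frame_values = np.array([-0.1, 0.2, 0.03])
bin_t = np.arange(0.0005, 0.06, 0.001)

filled = fill_forward_sketch(bin_t, frame_t, frame_values)
# bins 0-24 take frame 0's value, bins 25-49 frame 1's, bins 50+ frame 2's
print(filled[0], filled[24], filled[25], filled[50])
```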

Now, similar to the head direction tutorial, we'll use the log-stretched raised cosine basis to create the predictor for our GLM:

# set up the basis and prepare the temporal predictor for the GLM
basis = nmo.basis.RaisedCosineBasisLog(8)
window_size = 100
time, basis_kernels = basis.evaluate_on_grid(window_size)
time *= bin_size * window_size
convolved_input = nmo.utils.convolve_1d_trials(basis_kernels, filtered_stimulus)
# convolved_input has shape (n_time_pts, n_features, n_basis_funcs), and
# n_features is the singleton dimension from filtered_stimulus, so let's
# squeeze it out:
convolved_input = np.squeeze(convolved_input)
# and, as also described in the head direction tutorial, when doing this we
# need to remove the first window_size time points from the neuron counts and
# the last time point from the convolved input:
counts = counts[window_size:]
convolved_input = convolved_input[:-1]
# and grab the counts for our single neuron
counts = counts[:, 0]
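Why drop the first window_size counts and the last convolved sample? Here is a small NumPy sketch, assuming `convolve_1d_trials` performs a "valid"-mode convolution (the shapes in the code above imply this, since the trimmed arrays must have equal length). It uses `np.convolve` and `sliding_window_view` on random stand-in data rather than nemos, and shows that after trimming, the predictor paired with each count depends only on stimulus values strictly before that bin.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(2)
n_time, window_size = 1_000, 100
stim = rng.normal(size=n_time)                   # stand-in for the filled-forward stimulus
counts = rng.poisson(0.5, size=n_time)           # stand-in spike counts, same 1 ms bins
kernel = np.exp(-np.arange(window_size) / 20.0)  # stand-in for one basis function

def valid_convolve(x, k):
    # "valid" convolution: n - len(k) + 1 outputs, one per complete window.
    # Output i only sees x[i : i + len(k)]; k[0] weights the most recent sample.
    return sliding_window_view(x, len(k)) @ k[::-1]

feature = valid_convolve(stim, kernel)
print(len(feature), np.allclose(feature, np.convolve(stim, kernel, mode="valid")))

# The same trimming as in the tutorial code
counts_aligned = counts[window_size:]
feature_aligned = feature[:-1]
print(len(counts_aligned), len(feature_aligned))  # both n_time - window_size

# Causality check: perturb the stimulus at bin T. The predictor paired with the
# count at bin T (index T - window_size after trimming) uses only
# stim[T - window_size : T], so it is unchanged; the next predictor does change.
T = 500
stim_perturbed = stim.copy()
stim_perturbed[T] += 10.0
feature_perturbed = valid_convolve(stim_perturbed, kernel)[:-1]
j = T - window_size
print(np.isclose(feature_perturbed[j], feature_aligned[j]),
      np.isclose(feature_perturbed[j + 1], feature_aligned[j + 1]))
```

Without the trimming, the predictor for a bin would include that bin's own stimulus sample, and the counts and predictors would also differ in length.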

Fitting the GLM

Now we're ready to fit the model! Let's do it, same as before:

# fit the GLM
model = workshop_utils.model.GLM(regularizer=nmo.regularizer.UnRegularized(solver_name="LBFGS"))
model.fit(convolved_input, counts)

Out:

<workshop_utils.model.GLM object at 0x7fab5db9d150>
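For intuition about what the fit optimizes, here is a sketch of an unregularized Poisson GLM with an exponential inverse link (which, to our knowledge, is nemos's default): it maximizes the Poisson log-likelihood of the counts. The optimizer is a swap for brevity: damped Newton's method in plain NumPy, not the LBFGS solver used above. `fit_poisson_glm` is a hypothetical helper, and the data are synthetic with known weights so we can check recovery.

```python
import numpy as np

def fit_poisson_glm(X, y, n_iter=50):
    """Maximum-likelihood Poisson GLM with log link, via damped Newton steps.

    Minimizes sum(mu - y * eta), where eta = intercept + X @ coef and
    mu = exp(eta); the constant log(y!) term is dropped. Returns (intercept, coef).
    """
    Xa = np.column_stack([np.ones(len(X)), X])  # prepend an intercept column
    theta = np.zeros(Xa.shape[1])

    def nll(th):
        eta = Xa @ th
        return np.sum(np.exp(eta) - y * eta)

    for _ in range(n_iter):
        mu = np.exp(Xa @ theta)
        grad = Xa.T @ (mu - y)              # gradient of the negative log-likelihood
        hess = Xa.T @ (Xa * mu[:, None])    # Hessian: X^T diag(mu) X
        step = np.linalg.solve(hess, grad)
        # halve the step until the objective does not increase
        t = 1.0
        while nll(theta - t * step) > nll(theta) and t > 1e-8:
            t /= 2
        theta = theta - t * step
    return theta[0], theta[1:]

# Synthetic data with known parameters
rng = np.random.default_rng(3)
X = 0.5 * rng.normal(size=(20_000, 3))
true_coef, true_intercept = np.array([0.5, -0.3, 0.2]), -1.0
y = rng.poisson(np.exp(true_intercept + X @ true_coef))

intercept, coef = fit_poisson_glm(X, y)
print(intercept, coef)  # should land close to -1.0 and [0.5, -0.3, 0.2]
```

The objective is convex in the weights, so any reasonable solver (LBFGS, Newton) reaches the same optimum; they differ mainly in speed and memory.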

We have a coefficient for each of our 8 basis functions; let's combine them to get the temporal time course of our input:

# examine the resulting temporal filter
temp_weights = np.einsum('b, t b -> t', model.coef_, basis_kernels)
plt.plot(time, temp_weights)


Out:

[<matplotlib.lines.Line2D object at 0x7fab56489d50>]
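The temporal filter is simply a weighted sum of the basis functions, with the GLM coefficients as weights. The einsum 'b, t b -> t' is the same as a matrix-vector product, as this sketch shows; random stand-ins replace the real basis kernels (shape (time, basis), as returned by evaluate_on_grid above) and the fitted `model.coef_`.

```python
import numpy as np

rng = np.random.default_rng(4)
window_size, n_basis = 100, 8
basis_kernels = rng.random((window_size, n_basis))  # stand-in, shape (time, basis)
coef = rng.normal(size=n_basis)                     # stand-in for model.coef_

# 'b,tb->t': for each time point t, sum coef[b] * basis_kernels[t, b] over b
via_einsum = np.einsum("b,tb->t", coef, basis_kernels)
via_matmul = basis_kernels @ coef
via_loop = sum(coef[b] * basis_kernels[:, b] for b in range(n_basis))

print(np.allclose(via_einsum, via_matmul), np.allclose(via_einsum, via_loop))
```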

When taken together, the results of the GLM and the spike-triggered average give us the linear component of our linear-nonlinear-Poisson (LNP) model: the separable spatiotemporal filter.
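A sketch of why the two-step procedure gives a separable spatiotemporal filter: applying the full 3D filter (temporal outer spatial) to a window of past frames gives exactly the same drive as applying the spatial filter to each frame first and then the temporal filter. The data and filters are random stand-ins at the frame rate (not the tutorial's 1 ms bins), and the convention that lag 0 is the current frame is our choice for the sketch. The last lines add the exponential nonlinearity and Poisson noise of the LNP model.

```python
import numpy as np

rng = np.random.default_rng(5)
n_frames, h, w, n_lags = 200, 10, 10, 5
stimulus = rng.normal(size=(n_frames, h, w))
spatial = 0.1 * rng.normal(size=(h, w))   # stand-in for receptive_field
temporal = 0.3 * rng.normal(size=n_lags)  # stand-in temporal filter; lag 0 = current frame
intercept = -2.0

# Full spatiotemporal filter implied by the separable model
full_filter = np.einsum("l,hw->lhw", temporal, spatial)

# Route 1: apply the full 3D filter to each window of past frames
t_idx = np.arange(n_lags - 1, n_frames)
direct = np.array([
    sum((full_filter[k] * stimulus[t - k]).sum() for k in range(n_lags))
    for t in t_idx
])

# Route 2: spatial filter first (one number per frame), then temporal filter
spatial_drive = np.einsum("thw,hw->t", stimulus, spatial)
two_stage = np.array([temporal @ spatial_drive[t - np.arange(n_lags)] for t in t_idx])

print(np.allclose(direct, two_stage))

# Nonlinearity and noise of the LNP model: exponential rate, Poisson counts
rate = np.exp(intercept + two_stage)
counts = rng.poisson(rate)
```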

Further exercises

There's more that could (and should) be done here. First, we should probably split our data into separate test and train sets, to see how consistent our estimates of the spatial and temporal filters are. Then, using the test and train sets, we can:

  • try different choices for the spatial receptive field: modify the parameters of the STA, pick one of the time bins directly (instead of averaging), lowpass filter the receptive field (to remove the high frequency noise), manually create or fit a Gabor to match the STA results.

  • try different choices for the temporal filter: change basis functions, change the parameters of the basis object.

  • try adding regularization to the GLM for fitting the temporal filter.
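As a starting point for the Gabor exercise, here is a hypothetical helper that builds a Gabor patch (a Gaussian envelope times a sinusoidal grating) on the same 51 x 51 grid as the stimulus. The parameter values are arbitrary; they would need to be tuned by hand, or fitted (e.g., by maximizing the correlation with `receptive_field`), to match the STA.

```python
import numpy as np

def gabor(size, sigma, wavelength, theta, phase=0.0, center=None):
    """2D Gabor patch: a sinusoidal grating under a Gaussian envelope.

    size:       side length in pixels (square image)
    sigma:      standard deviation of the Gaussian envelope, in pixels
    wavelength: period of the grating, in pixels
    theta:      orientation of the grating, in radians
    phase:      phase offset of the grating, in radians
    center:     (row, col) of the envelope; defaults to the image center
    """
    if center is None:
        center = ((size - 1) / 2, (size - 1) / 2)
    rows, cols = np.mgrid[0:size, 0:size]
    y = rows - center[0]
    x = cols - center[1]
    # coordinate along the direction in which the grating varies
    x_rot = x * np.cos(theta) + y * np.sin(theta)
    envelope = np.exp(-(x**2 + y**2) / (2 * sigma**2))
    carrier = np.cos(2 * np.pi * x_rot / wavelength + phase)
    return envelope * carrier

g = gabor(51, sigma=4.0, wavelength=10.0, theta=np.pi / 4, center=(20, 30))
print(g.shape, np.unravel_index(np.argmax(g), g.shape))  # peak at the envelope center
```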
