Source code for spiketimes.plots

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import warnings
from .alignment import align_around, split_by_trial


def _raster(spiketrain: np.ndarray, ax=None, y_data_ind: int = 1, **kwargs):
    """
    Construct a raster plot of a single spiketrain over one trial

    Args:
        spiketrain: A numpy array of spiketimes in seconds
        ax: A matplotlib axes object to plot on
        y_data_ind: The y tick for spiketrain
        kwargs: Kwargs to pass to matplotlib.pyplot.plot
    Returns:
        A matplotlib axes object
    """
    try:
        (_ for _ in spiketrain[0])
        raise TypeError(
            f"Must Pass in a single numpy array. Nested iterable found.\n"
            f"Spike times: {spiketrain}"
        )
    except TypeError:
        pass
    if ax is None:
        _, ax = plt.subplots()
    y_data = np.zeros(shape=(1, len(spiketrain))).flatten() + y_data_ind
    ax.scatter(spiketrain, y_data, marker="|", **kwargs)
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_yticks([y_data_ind])
    ax.set_xlabel("Time [sec]")
    ax.set_ylabel("Spiketrain")
    return ax


def raster(
    spiketrain_list: list,
    ax=None,
    skip_empty=True,
    t_start: float = None,
    t_stop: float = None,
    _starting_ytick=None,
    **kwargs,
):
    """
    Construct a raster plot of multiple spiketrains.

    Each spiketrain is drawn on its own row of the y axis, starting at
    ``_starting_ytick`` and counting upwards.

    Args:
        spiketrain_list: A sequence of arrays containing the timings of
            spiking events, one array per row. A list containing a single
            spiketrain is accepted.
        ax: A matplotlib axes object on which to plot. A new figure and axes
            are created if None.
        skip_empty: Controls row numbering for spiketrains with no spikes in
            the plotting interval. Such spiketrains are never drawn and always
            trigger a warning. If True (default), the empty spiketrain still
            consumes its row, so row positions match list positions. If False,
            the next spiketrain takes over its row.
        t_start: If given, only spikes strictly after this timepoint are plotted
        t_stop: If given, only spikes strictly before this timepoint are plotted
        _starting_ytick: The y position of the first spiketrain (defaults
            to 0). Used by grouped_raster to stack groups.
        kwargs: Additional key-word arguments passed to
            matplotlib.axes.Axes.scatter
    Returns:
        A matplotlib axes object
    Raises:
        TypeError: If spiketrain_list is empty or its first element is not
            itself iterable (e.g. a single flat array of spiketimes was passed)
    """
    if _starting_ytick is None:
        _starting_ytick = 0
    try:
        # Require a non-empty collection whose first element is itself
        # iterable, i.e. a list of spiketrains rather than one spiketrain.
        iter(spiketrain_list[0])
    except (TypeError, IndexError):
        raise TypeError(
            f"spiketrain_list must be an iterable containing at least one array of spiketimes\n"
            f"Passed spiketimes: {spiketrain_list}"
        ) from None

    spiketrain_list = [np.asarray(spikes) for spikes in spiketrain_list]
    # Compare against None so that a bound of 0 is honoured.
    if t_start is not None:
        spiketrain_list = [spikes[spikes > t_start] for spikes in spiketrain_list]
    if t_stop is not None:
        spiketrain_list = [spikes[spikes < t_stop] for spikes in spiketrain_list]

    if ax is None:
        _, ax = plt.subplots()

    ytick = _starting_ytick
    for spikes in spiketrain_list:
        if len(spikes):
            ax = _raster(spikes, ax=ax, y_data_ind=ytick, **kwargs)
            ytick += 1
        else:
            warnings.warn(
                "A spiketrain with no spikes in the plotting window was passed, skipping."
            )
            # Row-numbering behaviour for empty spiketrains; see skip_empty.
            if skip_empty:
                ytick += 1
    # Replace the FixedLocator left by _raster with integer auto-ticks.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    return ax
def grouped_raster(
    st_lists: list,
    color_list: list = None,
    ax=None,
    skip_empty=True,
    t_start: float = None,
    t_stop: float = None,
    plot_kwargs=None,
    space_between_groups: int = 2,
):
    """
    Construct a raster plot of multiple groups of spiketrains.

    Groups are stacked vertically; each group is drawn in a single color and
    followed by ``space_between_groups`` empty rows.

    Args:
        st_lists: A list of lists of spiketrains. Each sublist contains one
            group of spiketrains.
        color_list: A list of colors, one per group. Colors are reused
            cyclically when there are more groups than colors. If None or
            empty, a built-in palette of 8 colors is used.
        ax: A matplotlib axes object. A new figure and axes are created if None.
        skip_empty: Forwarded to raster for every group
        t_start: Minimum timepoint, forwarded to raster
        t_stop: Maximum timepoint, forwarded to raster
        plot_kwargs: Additional key-word arguments passed to raster (and from
            there to matplotlib's scatter). Any "color" entry is overridden by
            the group color. The dict is not modified.
        space_between_groups: Number of empty rows between groups in the y
            direction.
    Returns:
        A matplotlib axes object
    """
    DEFAULT_COLORS = [
        "black",
        "red",
        "green",
        "blue",
        "pink",
        "purple",
        "orange",
        "yellow",
    ]
    if not color_list:
        color_list = DEFAULT_COLORS
    if ax is None:
        _, ax = plt.subplots()
    if plot_kwargs is None:
        plot_kwargs = {}
    starting_ytick = 0
    for i, st_list in enumerate(st_lists):
        # Cycle through the palette when there are more groups than colors.
        group_kwargs = {**plot_kwargs, "color": color_list[i % len(color_list)]}
        ax = raster(
            st_list,
            ax=ax,
            skip_empty=skip_empty,
            t_start=t_start,
            t_stop=t_stop,
            _starting_ytick=starting_ytick,
            **group_kwargs,
        )
        starting_ytick += len(st_list) + space_between_groups
    return ax
def aligned_raster(
    spiketrain: np.ndarray,
    trial_starts: np.ndarray,
    before: float = None,
    max_latency: float = None,
    ax=None,
    raster_kwargs=None,
):
    """
    Construct a raster plot with each row containing spikes from a single trial.

    Args:
        spiketrain: a spiketrain containing spiketimes in seconds.
        trial_starts: an array of trial starts in seconds.
        before: if specified, include this amount of time (in seconds) before
            each trial
        max_latency: if specified, exclude spikes occurring this amount of
            time (in seconds) after the final event.
        ax: matplotlib axes object to plot on.
        raster_kwargs: optional dict of extra key-word arguments passed to
            raster (e.g. skip_empty, t_start, or scatter styling).
    Returns:
        A matplotlib axes object
    """
    if raster_kwargs is None:
        raster_kwargs = {}
    st_list = split_by_trial(
        spiketrain=spiketrain,
        trial_starts=trial_starts,
        max_latency=max_latency,
        before=before,
    )
    # raster_kwargs configure the plot, so they go to raster, not split_by_trial.
    ax = raster(st_list, ax=ax, **raster_kwargs)
    ax.set_ylabel("Trial")
    return ax
def psth(
    spiketimes: np.ndarray,
    events: np.ndarray,
    binwidth: float = 0.01,
    t_before: float = 0.2,
    max_latency: float = 2,
    ax=None,
    hist_kwargs: dict = None,
    vline_kwargs: dict = None,
):
    """
    Construct a peristimulus time histogram of spiketime latencies to events.

    Latencies are computed with align_around and binned from the smallest to
    the largest latency (inclusive). t_before defines the time before time 0
    (when the event occurred) to include in the histogram. A vertical line is
    drawn at time 0.

    Args:
        spiketimes: A numpy array of spiketimes
        events: A numpy array of event times in the same units as spiketimes
        binwidth: The width of time bins
        t_before: The time before the aligned event to include in the psth
        max_latency: The maximum allowed latency. Useful for excluding spikes
            occurring after the final event.
        ax: An optional matplotlib axes object to use
        hist_kwargs: A dict of kwargs passed to matplotlib's hist ("alpha"
            defaults to 0.5). The caller's dict is not modified.
        vline_kwargs: A dict of kwargs for the line at time 0, forwarded to
            add_event_vlines ("linewidth" defaults to 2.5). The caller's dict
            is not modified.
    Returns:
        A matplotlib axes object
    """
    if ax is None:
        _, ax = plt.subplots()
    # Build fresh dicts so the caller's kwargs are never mutated.
    hist_kwargs = {"alpha": 0.5, **(hist_kwargs or {})}
    vline_kwargs = {"linewidth": 2.5, **(vline_kwargs or {})}

    latencies = align_around(spiketimes, events, t_before, max_latency, drop=True)
    if np.size(latencies) > 0:
        lo, hi = np.min(latencies), np.max(latencies)
        bins = np.arange(lo, hi + binwidth, binwidth)
        # np.arange excludes its stop value and can fall short through
        # floating-point error. Make sure the last edge covers the largest
        # latency and that there is at least one bin.
        if bins.size < 2 or bins[-1] < hi:
            bins = np.append(bins, bins[-1] + binwidth)
        ax.hist(latencies, bins=bins, **hist_kwargs)
    else:
        warnings.warn("No spikes found within the PSTH window; histogram is empty.")

    ax = add_event_vlines(ax, 0, vline_kwargs=vline_kwargs)
    ax.set_xlabel("Time [sec]")
    ax.set_ylabel("Counts")
    return ax
def add_event_vlines(
    ax,
    events: np.ndarray,
    t_min: float = None,
    t_max: float = None,
    vline_kwargs: dict = None,
):
    """
    Add vertical lines to a matplotlib axes object at the point(s) in events.

    t_min and t_max define exclusive lower and upper bounds for events: only
    events strictly inside these limits are plotted.

    Args:
        ax: the axes to plot on top of
        events: a single point or an iterable of points on the x axis to plot
        t_min: if specified, only points strictly greater than this are plotted
        t_max: if specified, only points strictly less than this are plotted
        vline_kwargs: optional dict of kwargs passed to ax.axvline ("color"
            defaults to "black", "linestyle" to "--"). The caller's dict is
            not modified.
    Returns:
        matplotlib axes
    """
    # Merge defaults into a new dict instead of mutating the caller's.
    vline_kwargs = {"color": "black", "linestyle": "--", **(vline_kwargs or {})}
    try:
        events = list(events)
    except TypeError:
        # A single scalar event was passed.
        events = [events]
    events = np.asarray(events)
    # Compare against None so that a bound of 0 is honoured.
    if t_min is not None:
        events = events[events > t_min]
    if t_max is not None:
        events = events[events < t_max]
    for event in events:
        ax.axvline(event, **vline_kwargs)
    return ax