diff --git a/trackpy/plots.py b/trackpy/plots.py index 685e0aa5..087abe0f 100644 --- a/trackpy/plots.py +++ b/trackpy/plots.py @@ -43,6 +43,13 @@ def wrapper(*args, **kwargs): show_plot = (plt.get_backend() != "agg") else: show_plot = False + + #if kwargs.get('fig') is None: + # kwargs['fig'] = plt.gcf() + # # show plot unless the matplotlib backend is headless + # show_plot = (plt.get_backend() != "agg") + #else: + # show_plot = False # Delete legend keyword so remaining ones can be passed to plot(). legend = kwargs.pop('legend', False) @@ -261,9 +268,9 @@ def scatter3d(*args, **kwargs): @make_axes -def plot_traj(traj, colorby='particle', mpp=None, label=False, - superimpose=None, cmap=None, ax=None, t_column=None, - pos_columns=None, plot_style={}, **kwargs): +def plot_traj(traj, colorby='particle', mpp=None, fps=None, label=False, + superimpose=None, cmap=None, fig=None, ax=None, t_column=None, + pos_columns=None, plot_style={}, colorbar_units='frames', **kwargs): """Plot traces of trajectories for each particle. Optionally superimpose it on a frame from the video. @@ -271,15 +278,19 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False, ---------- traj : DataFrame The DataFrame should include time and spatial coordinate columns. - colorby : {'particle', 'frame'}, optional + colorby : {'particle', 'frame', 'velocity'}, optional mpp : float, optional Microns per pixel. If omitted, the labels will have units of pixels. + fps : int, optional + Number of frames in one second. label : boolean, optional Set to True to write particle ID numbers next to trajectories. superimpose : ndarray, optional Background image, default None cmap : colormap, optional This is only used in colorby='frame' mode. Default = mpl.cm.winter + fig : matplotlib figure object, optional + Defaults to current figure ax : matplotlib axes object, optional Defaults to current axes t_column : string, optional @@ -288,6 +299,11 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False, Dataframe column names for spatial coordinates. Default is ['x', 'y']. plot_style : dictionary Keyword arguments passed through to the `Axes.plot(...)` command + colorbar_units : {None, 'frames', 'real'} string, optional + Set to None to disable the colorbar. If 'frames' is selected then + units like frames and px will be used. If 'real' is selected then + units like seconds and um will be used. + Returns ------- @@ -301,6 +317,8 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False, import matplotlib.pyplot as plt from matplotlib.collections import LineCollection + fig = plt.gcf() + if cmap is None: cmap = plt.cm.winter if t_column is None: @@ -312,15 +330,37 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False, _plot_style = dict(linewidth=1) _plot_style.update(**_normalize_kwargs(plot_style, 'line2d')) + colorbar_title = "" + frames_to_seconds_factor = 1.0 + + # Select the colorbar title according to the type of data plotted. + if colorbar_units == 'frames': + if colorby == 'frame': + colorbar_title = "Frame number" + if colorby == 'velocity': + colorbar_title = "Speed (px/s)" + + if colorbar_units == 'real': + if colorby == 'frame': + colorbar_title = "Time (s)" + if fps is None: + fps = 24 # Set default framerate to 24 fps if not specified. + frames_to_seconds_factor = 1.0 / fps + if colorby == 'velocity': + if mpl.rcParams['text.usetex']: + colorbar_title = r'Speed (\textmu m/s)' + else: + colorbar_title = 'Speed (\xb5m/s)' + # Axes labels if mpp is None: - _set_labels(ax, '{} [px]', pos_columns) + _set_labels(ax, '{} (px)', pos_columns) mpp = 1. # for computations of image extent below else: if mpl.rcParams['text.usetex']: - _set_labels(ax, r'{} [\textmu m]', pos_columns) + _set_labels(ax, r'{} (\textmu m)', pos_columns) else: - _set_labels(ax, r'{} [\xb5m]', pos_columns) + _set_labels(ax, '{} (\xb5m)', pos_columns) # Background image if superimpose is not None: ax.imshow(superimpose, cmap=plt.cm.gray, @@ -338,23 +378,63 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False, # Read https://scipy-cookbook.readthedocs.io/items/Matplotlib_MulticoloredLine.html x = traj.set_index([t_column, 'particle'])['x'].unstack() y = traj.set_index([t_column, 'particle'])['y'].unstack() - color_numbers = traj[t_column].values/float(traj[t_column].max()) - norm = plt.Normalize(color_numbers.min(), color_numbers.max()) - color_max = float(traj[t_column].max()) + norm = plt.Normalize(traj[t_column].min() * frames_to_seconds_factor, + traj[t_column].max() * frames_to_seconds_factor) logger.info("Drawing multicolor lines takes a while. " "Come back in a minute.") for particle in x: x_series = x[particle].dropna() y_series = y[particle].dropna() frames = np.array(x_series.index) - frames_map = frames/color_max points = np.array([x_series.values, y_series.values]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) lc = LineCollection(segments, cmap=cmap, norm=norm) - lc.set_array(frames_map) - ax.add_collection(lc) + lc.set_array(frames * frames_to_seconds_factor) + lines = ax.add_collection(lc) + ax.set_xlim(x.apply(np.min).min(), x.apply(np.max).max()) + ax.set_ylim(y.apply(np.min).min(), y.apply(np.max).max()) + if colorbar_units is not None: + fig.colorbar(lines, shrink=0.5, label=colorbar_title) + + if colorby == 'velocity': + x = traj.set_index([t_column, 'particle'])['x'].unstack() + y = traj.set_index([t_column, 'particle'])['y'].unstack() + logger.info("Drawing multicolor lines takes a while. " + "Come back in a minute.") + # Normalization of the colormap has to be done AFTER the velocities + # are calculated, so the results are stored in the following lists. + list_of_segments = [] + list_of_min_velocities = [] + list_of_max_velocities = [] + list_of_velocities = [] + # Computation of velocities and segments to be drawn. + for particle in x: + x_series = x[particle].dropna() + y_series = y[particle].dropna() + if x_series.shape[0] <= 2: + continue + frames = np.array(x_series.index) + xy_displacements = np.gradient(np.array([x_series, y_series]), frames, axis=1, edge_order=2) + displacement_um = np.linalg.norm(xy_displacements, axis=0) * mpp + velocity = displacement_um * frames_to_seconds_factor + list_of_velocities.append(velocity) + list_of_min_velocities.append(np.min(velocity)) + list_of_max_velocities.append(np.max(velocity)) + points = np.array([x_series.values, y_series.values]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + list_of_segments.append(segments) + # Normalization of the colormap. + norm = plt.Normalize(np.min(np.array(list_of_min_velocities)), + np.max(np.array(list_of_max_velocities))) + # Drawing segments. + for segments, velocity in zip(list_of_segments, list_of_velocities): + lc = LineCollection(segments, cmap=cmap, norm=norm) + lc.set_array(velocity) + lines = ax.add_collection(lc) ax.set_xlim(x.apply(np.min).min(), x.apply(np.max).max()) - ax.set_ylim(y.apply(np.min).min(), y.apply(np.max).max()) + ax.set_ylim(y.apply(np.min).min(), y.apply(np.max).max()) + if colorbar_units is not None: + fig.colorbar(lines, shrink=0.5, label=colorbar_title) if label: unstacked = traj.set_index([t_column, 'particle'])[pos_columns].unstack() first_frame = int(traj[t_column].min())