Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Plotting

.. autofunction:: plot_probe

.. autofunction:: plot_probe_group
.. autofunction:: plot_probegroup

Library
-------
Expand Down
2 changes: 1 addition & 1 deletion doc/generate_format_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import matplotlib.pyplot as plt

from probeinterface import Probe, ProbeGroup, combine_probes, write_probeinterface
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface.plotting import plot_probe, plot_probegroup

from probeinterface import generate_tetrode

Expand Down
6 changes: 3 additions & 3 deletions examples/ex_03_generate_probe_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import matplotlib.pyplot as plt

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe_group
from probeinterface.plotting import plot_probegroup
from probeinterface import generate_dummy_probe

##############################################################################
Expand All @@ -39,11 +39,11 @@
##############################################################################
#  We can now plot all probes in the same axis:

plot_probe_group(probegroup, same_axes=True)
plot_probegroup(probegroup, same_axes=True)

##############################################################################
#  or in separate axes:

plot_probe_group(probegroup, same_axes=False, with_contact_id=True)
plot_probegroup(probegroup, same_axes=False, with_contact_id=True)

plt.show()
4 changes: 2 additions & 2 deletions examples/ex_05_device_channel_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import matplotlib.pyplot as plt

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface.plotting import plot_probe, plot_probegroup
from probeinterface import generate_multi_columns_probe

##############################################################################
Expand Down Expand Up @@ -85,6 +85,6 @@
# The indices of the probe group can also be plotted:

fig, ax = plt.subplots()
plot_probe_group(probegroup, with_contact_id=True, same_axes=True, ax=ax)
plot_probegroup(probegroup, with_contact_id=True, same_axes=True, ax=ax)

plt.show()
6 changes: 3 additions & 3 deletions examples/ex_06_import_export_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import matplotlib.pyplot as plt

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface.plotting import plot_probe, plot_probegroup
from probeinterface import generate_dummy_probe
from probeinterface import write_probeinterface, read_probeinterface
from probeinterface import write_prb, read_prb
Expand All @@ -48,7 +48,7 @@
write_probeinterface('my_two_probe_setup.json', probegroup)

probegroup2 = read_probeinterface('my_two_probe_setup.json')
plot_probe_group(probegroup2)
plot_probegroup(probegroup2)

##############################################################################
# The format looks like this:
Expand Down Expand Up @@ -98,6 +98,6 @@
f.write(prb_two_tetrodes)

two_tetrode = read_prb('two_tetrodes.prb')
plot_probe_group(two_tetrode, same_axes=False, with_contact_id=True)
plot_probegroup(two_tetrode, same_axes=False, with_contact_id=True)

plt.show()
4 changes: 2 additions & 2 deletions examples/ex_07_probe_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import matplotlib.pyplot as plt

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface.plotting import plot_probe, plot_probegroup

##############################################################################
# Generate 4 tetrodes:
Expand All @@ -35,7 +35,7 @@
df = probegroup.to_dataframe()
df

plot_probe_group(probegroup, with_contact_id=True, same_axes=True)
plot_probegroup(probegroup, with_contact_id=True, same_axes=True)

##############################################################################
# Generate a linear probe:
Expand Down
2 changes: 1 addition & 1 deletion examples/ex_08_more_plotting_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import matplotlib.pyplot as plt

from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface.plotting import plot_probe, plot_probegroup
from probeinterface import generate_multi_columns_probe, generate_linear_probe

##############################################################################
Expand Down
101 changes: 65 additions & 36 deletions src/probeinterface/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def create_probe_polygons(
contacts_colors: list | None = None,
contacts_values: np.ndarray | None = None,
cmap: str = "viridis",
contacts_kargs: dict = {},
contact_kwargs: dict = {},
probe_shape_kwargs: dict = {},
contacts_kargs=None, # DEPRECATED
):
"""Create PolyCollection objects for a Probe.

Expand All @@ -32,7 +33,7 @@ def create_probe_polygons(
Values to color the contacts with
cmap : str, default: "viridis"
A colormap color
contacts_kargs : dict, default: {}
contact_kwargs : dict, default: {}
Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
probe_shape_kwargs : dict, default: {}
Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)
Expand All @@ -44,6 +45,16 @@ def create_probe_polygons(
poly_contour : PolyCollection | None
The polygon collection for the probe shape
"""
if contacts_kargs is not None:
import warnings

warnings.warn(
"contacts_kargs is deprecated and will be removed in 0.3.4. Please use `contacts_kwargs` instead.",
category=DeprecationWarning,
stacklevel=2,
)
contact_kwargs = contacts_kargs

if probe.ndim == 2:
from matplotlib.collections import PolyCollection

Expand All @@ -59,7 +70,7 @@ def create_probe_polygons(
_probe_shape_kwargs.update(probe_shape_kwargs)

_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
_contacts_kargs.update(contacts_kargs)
_contacts_kargs.update(contact_kwargs)

n = probe.get_contact_count()

Expand Down Expand Up @@ -93,7 +104,7 @@ def plot_probe(
with_contact_id: bool = False,
with_device_index: bool = False,
text_on_contact: list | np.ndarray | None = None,
contacts_values: np.ndarray | None = None,
contacts_values: list | np.ndarray | None = None,
cmap: str = "viridis",
title: bool = True,
contacts_kargs: dict = {},
Expand All @@ -119,9 +130,9 @@ def plot_probe(
If True, channel ids are displayed on top of the channels
with_device_index : bool, default: False
If True, device channel indices are displayed on top of the channels
text_on_contact: None | list | numpy.array, default: None
text_on_contact: None | list | np.ndarray, default: None
Addintional text to plot on each contact
contacts_values : np.array, default: None
contacts_values : list | np.ndarray | None, default: None
Values to color the contacts with
cmap : a colormap color, default: "viridis"
A colormap color
Expand Down Expand Up @@ -248,7 +259,7 @@ def on_press(event):
return poly, poly_contour


def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
def plot_probegroup(probegroup, same_axes: bool = True, **kwargs):
"""Plot all probes from a ProbeGroup
Can be in an existing set of axes or separate axes.

Expand All @@ -258,19 +269,37 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
The ProbeGroup to plot
same_axes : bool, default: True
If True, the probes are plotted on the same axis
kargs: dict
see docstring for plot_probe for possible kargs
kwargs: dict
Additional keyword arguments to pass to plot_probe.
If same_axes is True, the same kwargs are passed to all probes.
If same_axes is False, the kwargs are passed separately to each probe,
if they have the same length as the total number of contacts in the ProbeGroup.
For example, if contacts_colors is given and has the same length as the total
number of contacts in the ProbeGroup, then the colors are split and passed
separately to each probe.

Available kwargs:

- contacts_colors
- with_contact_id
- with_device_index
- text_on_contact
- contacts_values
- cmap
- title
- contacts_kargs
- probe_shape_kwargs
"""

import matplotlib.pyplot as plt

figsize = kargs.pop("figsize", None)
figsize = kwargs.pop("figsize", None)

n = len(probegroup.probes)

if same_axes:
if "ax" in kargs:
ax = kargs.pop("ax")
if "ax" in kwargs:
ax = kwargs.pop("ax")
else:
if probegroup.ndim == 2:
fig, ax = plt.subplots(figsize=figsize)
Expand All @@ -279,14 +308,16 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
ax = fig.add_subplot(1, 1, 1, projection="3d")
axs = [ax] * n
else:
if "ax" in kargs:
if "ax" in kwargs:
raise ValueError("when same_axes=False, an axes object cannot be passed into this function.")
if probegroup.ndim == 2:
fig, axs = plt.subplots(ncols=n, nrows=1, figsize=figsize)
if n == 1:
axs = [axs]
else:
raise NotImplementedError
raise NotImplementedError(
"same_axes=False is currently only implemented for 2D probes. For 3D probes, please set same_axes=True."
)

if same_axes:
# global lims
Expand All @@ -297,36 +328,34 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1])
if zlims is not None:
zlims = min(zlims[0], zlims2[0]), max(zlims[1], zlims2[1])
kargs["xlims"] = xlims
kargs["ylims"] = ylims
kargs["zlims"] = zlims
kwargs["xlims"] = xlims
kwargs["ylims"] = ylims
kwargs["zlims"] = zlims
else:
# will be auto for each probe in each axis
kargs["xlims"] = None
kargs["ylims"] = None
kargs["zlims"] = None
kwargs["xlims"] = None
kwargs["ylims"] = None
kwargs["zlims"] = None

kargs["title"] = False
for i, probe in enumerate(probegroup.probes):
plot_probe(probe, ax=axs[i], **kargs)


def plot_probe_group(probegroup, same_axes: bool = True, **kargs):
"""
This function is deprecated and will be removed in 0.2.23
Please use plot_probegroup instead"""
kwargs["title"] = False

from warnings import warn
cum_contact_count = 0
total_contacts = sum(p.get_contact_count() for p in probegroup.probes)

warn(
"`plot_probe_group` is deprecated and will be removed in 2.23. Use plot_probegroup instead",
category=DeprecationWarning,
stacklevel=2,
)
for i, probe in enumerate(probegroup.probes):
n = probe.get_contact_count()
kwargs_probe = kwargs.copy()
for key in ["contacts_colors", "contacts_values", "text_on_contact"]:
if kwargs.get(key) is not None:
val = np.array(kwargs[key])
if len(val) == total_contacts:
kwargs_probe[key] = val[cum_contact_count : cum_contact_count + n]

plot_probegroup(probegroup, same_axes=same_axes, **kargs)
plot_probe(probe, ax=axs[i], **kwargs_probe)
cum_contact_count += n


### MATPLOTLIB INTERACTION ###
def _on_press(probe, event):
ax = event.inaxes
x, y = event.xdata, event.ydata
Expand Down
8 changes: 0 additions & 8 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
from probeinterface import generate_dummy_probe, generate_dummy_probe_group
from probeinterface.plotting import plot_probe, plot_probegroup

# remove once plot_probe_group is removed
from probeinterface.plotting import plot_probe_group

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -38,10 +35,6 @@ def test_plot_probegroup():
plot_probegroup(probegroup, same_axes=True, with_contact_id=True)
plot_probegroup(probegroup, same_axes=False)

# remove when plot_probe_group has been removed
with pytest.warns(DeprecationWarning):
plot_probe_group(probegroup)

# 3d
probegroup_3d = ProbeGroup()
for probe in probegroup.probes:
Expand Down Expand Up @@ -74,6 +67,5 @@ def test_plot_probe_two_side():

if __name__ == "__main__":
# test_plot_probe()
# test_plot_probe_group()
test_plot_probe_two_side()
plt.show()
Loading