diff --git a/doc/api.rst b/doc/api.rst index 8482e139..05bb03e9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -65,7 +65,7 @@ Plotting .. autofunction:: plot_probe - .. autofunction:: plot_probe_group + .. autofunction:: plot_probegroup Library ------- diff --git a/doc/generate_format_example.py b/doc/generate_format_example.py index d247bf64..eb23d428 100644 --- a/doc/generate_format_example.py +++ b/doc/generate_format_example.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from probeinterface import Probe, ProbeGroup, combine_probes, write_probeinterface -from probeinterface.plotting import plot_probe, plot_probe_group +from probeinterface.plotting import plot_probe, plot_probegroup from probeinterface import generate_tetrode diff --git a/examples/ex_03_generate_probe_group.py b/examples/ex_03_generate_probe_group.py index cb68bbff..8a640d3a 100644 --- a/examples/ex_03_generate_probe_group.py +++ b/examples/ex_03_generate_probe_group.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt from probeinterface import Probe, ProbeGroup -from probeinterface.plotting import plot_probe_group +from probeinterface.plotting import plot_probegroup from probeinterface import generate_dummy_probe ############################################################################## @@ -39,11 +39,11 @@ ############################################################################## #  We can now plot all probes in the same axis: -plot_probe_group(probegroup, same_axes=True) +plot_probegroup(probegroup, same_axes=True) ############################################################################## #  or in separate axes: -plot_probe_group(probegroup, same_axes=False, with_contact_id=True) +plot_probegroup(probegroup, same_axes=False, with_contact_id=True) plt.show() diff --git a/examples/ex_05_device_channel_indices.py b/examples/ex_05_device_channel_indices.py index f928c801..5731c910 100644 --- a/examples/ex_05_device_channel_indices.py +++ b/examples/ex_05_device_channel_indices.py @@ -17,7 +17,7 @@ import matplotlib.pyplot as plt from probeinterface import Probe, ProbeGroup -from probeinterface.plotting import plot_probe, plot_probe_group +from probeinterface.plotting import plot_probe, plot_probegroup from probeinterface import generate_multi_columns_probe ############################################################################## @@ -85,6 +85,6 @@ # The indices of the probe group can also be plotted: fig, ax = plt.subplots() -plot_probe_group(probegroup, with_contact_id=True, same_axes=True, ax=ax) +plot_probegroup(probegroup, with_contact_id=True, same_axes=True, ax=ax) plt.show() diff --git a/examples/ex_06_import_export_to_file.py b/examples/ex_06_import_export_to_file.py index 6d34b153..438cef65 100644 --- a/examples/ex_06_import_export_to_file.py +++ b/examples/ex_06_import_export_to_file.py @@ -24,7 +24,7 @@ import matplotlib.pyplot as plt from probeinterface import Probe, ProbeGroup -from probeinterface.plotting import plot_probe, plot_probe_group +from probeinterface.plotting import plot_probe, plot_probegroup from probeinterface import generate_dummy_probe from probeinterface import write_probeinterface, read_probeinterface from probeinterface import write_prb, read_prb @@ -48,7 +48,7 @@ write_probeinterface('my_two_probe_setup.json', probegroup) probegroup2 = read_probeinterface('my_two_probe_setup.json') -plot_probe_group(probegroup2) +plot_probegroup(probegroup2) ############################################################################## # The format looks like this: @@ -98,6 +98,6 @@ f.write(prb_two_tetrodes) two_tetrode = read_prb('two_tetrodes.prb') -plot_probe_group(two_tetrode, same_axes=False, with_contact_id=True) +plot_probegroup(two_tetrode, same_axes=False, with_contact_id=True) plt.show() diff --git a/examples/ex_07_probe_generator.py b/examples/ex_07_probe_generator.py index 19781648..960682bc 100644 --- a/examples/ex_07_probe_generator.py +++ b/examples/ex_07_probe_generator.py @@ -17,7 +17,7 @@ import matplotlib.pyplot as plt from probeinterface import Probe, ProbeGroup -from probeinterface.plotting import plot_probe, plot_probe_group +from probeinterface.plotting import plot_probe, plot_probegroup ############################################################################## # Generate 4 tetrodes: @@ -35,7 +35,7 @@ df = probegroup.to_dataframe() df -plot_probe_group(probegroup, with_contact_id=True, same_axes=True) +plot_probegroup(probegroup, with_contact_id=True, same_axes=True) ############################################################################## # Generate a linear probe: diff --git a/examples/ex_08_more_plotting_examples.py b/examples/ex_08_more_plotting_examples.py index 9b2bf8f6..223163d1 100644 --- a/examples/ex_08_more_plotting_examples.py +++ b/examples/ex_08_more_plotting_examples.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt from probeinterface import Probe, ProbeGroup -from probeinterface.plotting import plot_probe, plot_probe_group +from probeinterface.plotting import plot_probe, plot_probegroup from probeinterface import generate_multi_columns_probe, generate_linear_probe ############################################################################## diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index ce535cf6..35cd2a8f 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -17,8 +17,9 @@ def create_probe_polygons( contacts_colors: list | None = None, contacts_values: np.ndarray | None = None, cmap: str = "viridis", - contacts_kargs: dict = {}, + contact_kwargs: dict = {}, probe_shape_kwargs: dict = {}, + contacts_kargs=None, # DEPRECATED ): """Create PolyCollection objects for a Probe. @@ -32,7 +33,7 @@ def create_probe_polygons( Values to color the contacts with cmap : str, default: "viridis" A colormap color - contacts_kargs : dict, default: {} + contact_kwargs : dict, default: {} Dict with kwargs for contacts (e.g. alpha, edgecolor, lw) probe_shape_kwargs : dict, default: {} Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw) @@ -44,6 +45,16 @@ def create_probe_polygons( poly_contour : PolyCollection | None The polygon collection for the probe shape """ + if contacts_kargs is not None: + import warnings + + warnings.warn( + "contacts_kargs is deprecated and will be removed in 0.3.4. Please use `contacts_kwargs` instead.", + category=DeprecationWarning, + stacklevel=2, + ) + contact_kwargs = contacts_kargs + if probe.ndim == 2: from matplotlib.collections import PolyCollection @@ -59,7 +70,7 @@ def create_probe_polygons( _probe_shape_kwargs.update(probe_shape_kwargs) _contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5) - _contacts_kargs.update(contacts_kargs) + _contacts_kargs.update(contact_kwargs) n = probe.get_contact_count() @@ -93,7 +104,7 @@ def plot_probe( with_contact_id: bool = False, with_device_index: bool = False, text_on_contact: list | np.ndarray | None = None, - contacts_values: np.ndarray | None = None, + contacts_values: list | np.ndarray | None = None, cmap: str = "viridis", title: bool = True, contacts_kargs: dict = {}, @@ -119,9 +130,9 @@ def plot_probe( If True, channel ids are displayed on top of the channels with_device_index : bool, default: False If True, device channel indices are displayed on top of the channels - text_on_contact: None | list | numpy.array, default: None + text_on_contact: None | list | np.ndarray, default: None Addintional text to plot on each contact - contacts_values : np.array, default: None + contacts_values : list | np.ndarray | None, default: None Values to color the contacts with cmap : a colormap color, default: "viridis" A colormap color @@ -248,7 +259,7 @@ def on_press(event): return poly, poly_contour -def plot_probegroup(probegroup, same_axes: bool = True, **kargs): +def plot_probegroup(probegroup, same_axes: bool = True, **kwargs): """Plot all probes from a ProbeGroup Can be in an existing set of axes or separate axes. @@ -258,19 +269,37 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs): The ProbeGroup to plot same_axes : bool, default: True If True, the probes are plotted on the same axis - kargs: dict - see docstring for plot_probe for possible kargs + kwargs: dict + Additional keyword arguments to pass to plot_probe. + If same_axes is True, the same kwargs are passed to all probes. + If same_axes is False, the kwargs are passed separately to each probe, + if they have the same length as the total number of contacts in the ProbeGroup. + For example, if contacts_colors is given and has the same length as the total + number of contacts in the ProbeGroup, then the colors are split and passed + separately to each probe. + + Available kwargs: + + - contacts_colors + - with_contact_id + - with_device_index + - text_on_contact + - contacts_values + - cmap + - title + - contacts_kargs + - probe_shape_kwargs """ import matplotlib.pyplot as plt - figsize = kargs.pop("figsize", None) + figsize = kwargs.pop("figsize", None) n = len(probegroup.probes) if same_axes: - if "ax" in kargs: - ax = kargs.pop("ax") + if "ax" in kwargs: + ax = kwargs.pop("ax") else: if probegroup.ndim == 2: fig, ax = plt.subplots(figsize=figsize) @@ -279,14 +308,16 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs): ax = fig.add_subplot(1, 1, 1, projection="3d") axs = [ax] * n else: - if "ax" in kargs: + if "ax" in kwargs: raise ValueError("when same_axes=False, an axes object cannot be passed into this function.") if probegroup.ndim == 2: fig, axs = plt.subplots(ncols=n, nrows=1, figsize=figsize) if n == 1: axs = [axs] else: - raise NotImplementedError + raise NotImplementedError( + "same_axes=False is currently only implemented for 2D probes. For 3D probes, please set same_axes=True." + ) if same_axes: # global lims @@ -297,36 +328,34 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs): ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) if zlims is not None: zlims = min(zlims[0], zlims2[0]), max(zlims[1], zlims2[1]) - kargs["xlims"] = xlims - kargs["ylims"] = ylims - kargs["zlims"] = zlims + kwargs["xlims"] = xlims + kwargs["ylims"] = ylims + kwargs["zlims"] = zlims else: # will be auto for each probe in each axis - kargs["xlims"] = None - kargs["ylims"] = None - kargs["zlims"] = None + kwargs["xlims"] = None + kwargs["ylims"] = None + kwargs["zlims"] = None - kargs["title"] = False - for i, probe in enumerate(probegroup.probes): - plot_probe(probe, ax=axs[i], **kargs) - - -def plot_probe_group(probegroup, same_axes: bool = True, **kargs): - """ - This function is deprecated and will be removed in 0.2.23 - Please use plot_probegroup instead""" + kwargs["title"] = False - from warnings import warn + cum_contact_count = 0 + total_contacts = sum(p.get_contact_count() for p in probegroup.probes) - warn( - "`plot_probe_group` is deprecated and will be removed in 2.23. Use plot_probegroup instead", - category=DeprecationWarning, - stacklevel=2, - ) + for i, probe in enumerate(probegroup.probes): + n = probe.get_contact_count() + kwargs_probe = kwargs.copy() + for key in ["contacts_colors", "contacts_values", "text_on_contact"]: + if kwargs.get(key) is not None: + val = np.array(kwargs[key]) + if len(val) == total_contacts: + kwargs_probe[key] = val[cum_contact_count : cum_contact_count + n] - plot_probegroup(probegroup, same_axes=same_axes, **kargs) + plot_probe(probe, ax=axs[i], **kwargs_probe) + cum_contact_count += n +### MATPLOTLIB INTERACTION ### def _on_press(probe, event): ax = event.inaxes x, y = event.xdata, event.ydata diff --git a/tests/test_plotting.py b/tests/test_plotting.py index eaa676f4..12ec258e 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -2,9 +2,6 @@ from probeinterface import generate_dummy_probe, generate_dummy_probe_group from probeinterface.plotting import plot_probe, plot_probegroup -# remove once plot_probe_group is removed -from probeinterface.plotting import plot_probe_group - import matplotlib.pyplot as plt import numpy as np @@ -38,10 +35,6 @@ def test_plot_probegroup(): plot_probegroup(probegroup, same_axes=True, with_contact_id=True) plot_probegroup(probegroup, same_axes=False) - # remove when plot_probe_group has been removed - with pytest.warns(DeprecationWarning): - plot_probe_group(probegroup) - # 3d probegroup_3d = ProbeGroup() for probe in probegroup.probes: @@ -74,6 +67,5 @@ def test_plot_probe_two_side(): if __name__ == "__main__": # test_plot_probe() - # test_plot_probe_group() test_plot_probe_two_side() plt.show()