diff --git a/crystal_toolkit/components/structure.py b/crystal_toolkit/components/structure.py index 9eee713b..fbeaad9e 100644 --- a/crystal_toolkit/components/structure.py +++ b/crystal_toolkit/components/structure.py @@ -19,6 +19,7 @@ from pymatgen.analysis.local_env import NearNeighbors from pymatgen.core import Composition, Molecule, Species, Structure from pymatgen.core.periodic_table import DummySpecie +from pymatgen.io.lobster.lobsterenv import LobsterNeighbors from pymatgen.io.vasp.sets import MPRelaxSet from pymatgen.symmetry.analyzer import SpacegroupAnalyzer @@ -55,7 +56,10 @@ class StructureMoleculeComponent(MPComponent): """ available_bonding_strategies = frozendict( - {subcls.__name__: subcls for subcls in NearNeighbors.__subclasses__()} + { + **{subcls.__name__: subcls for subcls in NearNeighbors.__subclasses__()}, + "LobsterNeighbors": LobsterNeighbors, + } ) default_scene_settings = frozendict( @@ -112,6 +116,7 @@ def __init__( show_export_button: bool = DEFAULTS["show_export_button"], show_position_button: bool = DEFAULTS["show_position_button"], scene_kwargs: dict | None = None, + site_get_scene_kwargs: dict | None = None, **kwargs, ) -> None: """Create a StructureMoleculeComponent from a structure or molecule. @@ -218,6 +223,7 @@ def __init__( graph, scene_additions=self.initial_data["scene_additions"], **self.initial_data["display_options"], + site_get_scene_kwargs=site_get_scene_kwargs, ) if hasattr(struct_or_mol, "lattice"): self._lattice = struct_or_mol.lattice @@ -968,6 +974,7 @@ def _preprocess_input_to_graph( valid_bond_strategies = ( StructureMoleculeComponent.available_bonding_strategies ) + if bonding_strategy not in valid_bond_strategies: raise ValueError( "Bonding strategy not supported. Please supply a name of a NearNeighbor " @@ -1032,6 +1039,7 @@ def get_scene_and_legend( scene_additions=None, show_compass=DEFAULTS["show_compass"], group_by_site_property=None, + site_get_scene_kwargs=None, ) -> tuple[Scene, dict[str, str]]: """Get the scene and legend for a given graph. @@ -1078,9 +1086,10 @@ def get_scene_and_legend( explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull, group_by_site_property=group_by_site_property, legend=legend, + **(site_get_scene_kwargs or {}), ) elif isinstance(graph, MoleculeGraph): - scene = graph.get_scene(legend=legend) + scene = graph.get_scene(legend=legend, **(site_get_scene_kwargs or {})) scene.name = "StructureMoleculeComponentScene" diff --git a/crystal_toolkit/renderables/moleculegraph.py b/crystal_toolkit/renderables/moleculegraph.py index 4704b8f6..53a22e94 100644 --- a/crystal_toolkit/renderables/moleculegraph.py +++ b/crystal_toolkit/renderables/moleculegraph.py @@ -22,6 +22,7 @@ def get_molecule_graph_scene( show_bond_order=True, show_bond_length=False, visualize_bond_orders=False, + edge_weight_name_mapping: dict[str, str] | None = None, ) -> Scene: """Create a Molecule Graph scene. @@ -33,6 +34,7 @@ def get_molecule_graph_scene( show_bond_length: Defaults to False, shows the calculated length between two connected atoms visualize_bpnd_orders: Defaults False, will show the 'integral' number of bonds calculated from the OpenBabelNN strategy in the Molecule Graph + edge_weight_name_mapping: A custom mapping from the edge weight name in the MoleculeGraph, which will be shown in the tooltip if show_bond_order is True. If None, defaults to {"weight": "bond order"}. Returns: A Molecule Graph scene. @@ -47,6 +49,9 @@ def get_molecule_graph_scene( primitives: dict[str, list] = defaultdict(list) + if edge_weight_name_mapping is None: + edge_weight_name_mapping = {"weight": "bond order"} + for idx, site in enumerate(self.molecule): connected_sites = vis_mol_graph.get_connected_sites(idx) @@ -62,6 +67,9 @@ def get_molecule_graph_scene( show_bond_length=show_bond_length, visualize_bond_orders=visualize_bond_orders, draw_polyhedra=draw_polyhedra, + edge_weight_name=vis_mol_graph.edge_weight_name, + edge_weight_unit=vis_mol_graph.edge_weight_unit, + edge_weight_name_mapping=edge_weight_name_mapping, ) for scene in site_scene.contents: primitives[scene.name] += scene.contents diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index ac9579e9..4ce31bc1 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -49,6 +49,9 @@ def get_site_scene( legend: Legend | None = None, retain_atom_idx: bool = False, total_repeat_cell_cnt: int = 1, + edge_weight_name_mapping: dict[str, str] | None = None, + edge_weight_name: str = "bond order", + edge_weight_unit: str = "", ) -> Scene: """Get a Scene object for a Site. @@ -74,6 +77,9 @@ def get_site_scene( legend (Legend | None, optional): Defaults to None. retain_atom_idx (bool, optional): Defaults to False. total_repeat_cell_cnt (int, optional): Defaults to 1. + edge_weight_name_mapping (dict[str, str] | None, optional): Mapping of ConnectedSite attribute names to display names for edge weights. If None, defaults to {"weight": "bond order"}. + edge_weight_name (str, optional): Defaults to "bond order". + edge_weight_unit (str, optional): Defaults to "". Returns: Scene: The scene object containing atoms, bonds, polyhedra, magmoms. @@ -190,9 +196,17 @@ def get_site_scene( all_positions = [self.coords] name_cyl = " " + if edge_weight_name_mapping is None: + edge_weight_name_mapping = {"weight": "bond order"} + for idx, connected_site in enumerate(connected_sites): if show_bond_order and connected_site.weight is not None: - name_cyl = f"bond order:{connected_site.weight:.2f}" + edge_weight_name = edge_weight_name_mapping.get( + edge_weight_name, edge_weight_name + ) + name_cyl = f"{edge_weight_name}:{connected_site.weight:.2f}" + if edge_weight_unit: + name_cyl += f" ({edge_weight_unit})" if show_bond_length and connected_site.dist is not None: name_cyl += f"\nbond length:{connected_site.dist:.3f}"