Skip to content
13 changes: 11 additions & 2 deletions crystal_toolkit/components/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core import Composition, Molecule, Species, Structure
from pymatgen.core.periodic_table import DummySpecie
from pymatgen.io.lobster.lobsterenv import LobsterNeighbors
from pymatgen.io.vasp.sets import MPRelaxSet
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

Expand Down Expand Up @@ -55,7 +56,10 @@ class StructureMoleculeComponent(MPComponent):
"""

available_bonding_strategies = frozendict(
{subcls.__name__: subcls for subcls in NearNeighbors.__subclasses__()}
{
**{subcls.__name__: subcls for subcls in NearNeighbors.__subclasses__()},
"LobsterNeighbors": LobsterNeighbors,
}
)

default_scene_settings = frozendict(
Expand Down Expand Up @@ -112,6 +116,7 @@ def __init__(
show_export_button: bool = DEFAULTS["show_export_button"],
show_position_button: bool = DEFAULTS["show_position_button"],
scene_kwargs: dict | None = None,
site_get_scene_kwargs: dict | None = None,
**kwargs,
) -> None:
"""Create a StructureMoleculeComponent from a structure or molecule.
Expand Down Expand Up @@ -218,6 +223,7 @@ def __init__(
graph,
scene_additions=self.initial_data["scene_additions"],
**self.initial_data["display_options"],
site_get_scene_kwargs=site_get_scene_kwargs,
)
if hasattr(struct_or_mol, "lattice"):
self._lattice = struct_or_mol.lattice
Expand Down Expand Up @@ -968,6 +974,7 @@ def _preprocess_input_to_graph(
valid_bond_strategies = (
StructureMoleculeComponent.available_bonding_strategies
)

if bonding_strategy not in valid_bond_strategies:
raise ValueError(
"Bonding strategy not supported. Please supply a name of a NearNeighbor "
Expand Down Expand Up @@ -1032,6 +1039,7 @@ def get_scene_and_legend(
scene_additions=None,
show_compass=DEFAULTS["show_compass"],
group_by_site_property=None,
site_get_scene_kwargs=None,
) -> tuple[Scene, dict[str, str]]:
"""Get the scene and legend for a given graph.

Expand Down Expand Up @@ -1078,9 +1086,10 @@ def get_scene_and_legend(
explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull,
group_by_site_property=group_by_site_property,
legend=legend,
**(site_get_scene_kwargs or {}),
)
elif isinstance(graph, MoleculeGraph):
scene = graph.get_scene(legend=legend)
scene = graph.get_scene(legend=legend, **(site_get_scene_kwargs or {}))

scene.name = "StructureMoleculeComponentScene"

Expand Down
8 changes: 8 additions & 0 deletions crystal_toolkit/renderables/moleculegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_molecule_graph_scene(
show_bond_order=True,
show_bond_length=False,
visualize_bond_orders=False,
edge_weight_name_mapping: dict[str, str] | None = None,
) -> Scene:
"""Create a Molecule Graph scene.

Expand All @@ -33,6 +34,7 @@ def get_molecule_graph_scene(
show_bond_length: Defaults to False, shows the calculated length between two connected atoms
visualize_bpnd_orders: Defaults False, will show the 'integral' number of bonds calculated
from the OpenBabelNN strategy in the Molecule Graph
edge_weight_name_mapping: A custom mapping from the edge weight name in the MoleculeGraph, which will be shown in the tooltip if show_bond_order is True. If None, defaults to {"weight": "bond order"}.

Returns:
A Molecule Graph scene.
Expand All @@ -47,6 +49,9 @@ def get_molecule_graph_scene(

primitives: dict[str, list] = defaultdict(list)

if edge_weight_name_mapping is None:
edge_weight_name_mapping = {"weight": "bond order"}

for idx, site in enumerate(self.molecule):
connected_sites = vis_mol_graph.get_connected_sites(idx)

Expand All @@ -62,6 +67,9 @@ def get_molecule_graph_scene(
show_bond_length=show_bond_length,
visualize_bond_orders=visualize_bond_orders,
draw_polyhedra=draw_polyhedra,
edge_weight_name=vis_mol_graph.edge_weight_name,
edge_weight_unit=vis_mol_graph.edge_weight_unit,
edge_weight_name_mapping=edge_weight_name_mapping,
)
for scene in site_scene.contents:
primitives[scene.name] += scene.contents
Expand Down
16 changes: 15 additions & 1 deletion crystal_toolkit/renderables/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def get_site_scene(
legend: Legend | None = None,
retain_atom_idx: bool = False,
total_repeat_cell_cnt: int = 1,
edge_weight_name_mapping: dict[str, str] | None = None,
edge_weight_name: str = "bond order",
edge_weight_unit: str = "",
) -> Scene:
"""Get a Scene object for a Site.

Expand All @@ -74,6 +77,9 @@ def get_site_scene(
legend (Legend | None, optional): Defaults to None.
retain_atom_idx (bool, optional): Defaults to False.
total_repeat_cell_cnt (int, optional): Defaults to 1.
edge_weight_name_mapping (dict[str, str] | None, optional): Mapping of ConnectedSite attribute names to display names for edge weights. If None, defaults to {"weight": "bond order"}.
edge_weight_name (str, optional): Defaults to "bond order".
edge_weight_unit (str, optional): Defaults to "".

Returns:
Scene: The scene object containing atoms, bonds, polyhedra, magmoms.
Expand Down Expand Up @@ -190,9 +196,17 @@ def get_site_scene(
all_positions = [self.coords]
name_cyl = " "

if edge_weight_name_mapping is None:
edge_weight_name_mapping = {"weight": "bond order"}

for idx, connected_site in enumerate(connected_sites):
if show_bond_order and connected_site.weight is not None:
name_cyl = f"bond order:{connected_site.weight:.2f}"
edge_weight_name = edge_weight_name_mapping.get(
edge_weight_name, edge_weight_name
)
name_cyl = f"{edge_weight_name}:{connected_site.weight:.2f}"
if edge_weight_unit:
name_cyl += f" ({edge_weight_unit})"

if show_bond_length and connected_site.dist is not None:
name_cyl += f"\nbond length:{connected_site.dist:.3f}"
Expand Down
Loading