Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(
f"FairChem not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
f"FairChem not installed: {traceback.format_exc()}",
allow_module_level=True,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_fairchem_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(
f"FairChem not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
f"FairChem not installed: {traceback.format_exc()}",
allow_module_level=True,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_graphpes_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch_sim.models.graphpes_framework import GraphPESWrapper
except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(
f"graph-pes not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
f"graph-pes not installed: {traceback.format_exc()}",
allow_module_level=True,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch_sim.models.mace import MaceModel, MaceUrls

except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments]
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)

# mace_omol is optional (added in newer MACE versions)
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_mattersim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(
f"mattersim not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
f"mattersim not installed: {traceback.format_exc()}",
allow_module_level=True,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_metatomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch_sim.models.metatomic import MetatomicModel
except ImportError:
pytest.skip(
f"metatomic not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
f"metatomic not installed: {traceback.format_exc()}",
allow_module_level=True,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_nequip_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch_sim.models.nequip_framework import NequIPFrameworkModel
except (ImportError, ModuleNotFoundError):
pytest.skip(
f"nequip not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
f"nequip not installed: {traceback.format_exc()}",
allow_module_level=True,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch_sim.models.orb import OrbModel

except ImportError:
pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments]
pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True)


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_sevennet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

except ImportError:
pytest.skip(
f"sevenn not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
f"sevenn not installed: {traceback.format_exc()}",
allow_module_level=True,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from torch_sim.models.mace import MaceModel
except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments]
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)


def test_get_strain_zero_deformation(cu_sim_state: ts.SimState) -> None:
Expand Down
165 changes: 165 additions & 0 deletions tests/test_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import pytest
import torch

import torch_sim as ts


DEVICE = torch.device("cpu")
DTYPE = torch.float64


class TestExtras:
def test_system_extras_construction(self):
"""Extras can be passed at construction time."""
field = torch.randn(1, 3)
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"external_E_field": field},
)
assert torch.equal(state.external_E_field, field)

def test_atom_extras_construction(self):
"""Per-atom extras work at construction time."""
tags = torch.tensor([1.0, 2.0])
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_atom_extras={"tags": tags},
)
assert torch.equal(state.tags, tags)

def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState):
with pytest.raises(AttributeError, match="nonexistent_key"):
_ = cu_sim_state.nonexistent_key

def test_post_init_validation_rejects_bad_shape(self):
with pytest.raises(ValueError, match="leading dim must be n_systems"):
ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"bad": torch.randn(5, 3)},
)

def test_from_state_preserves_extras(self, cu_sim_state: ts.SimState):
field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device)
cu_sim_state.set_extras("E", field, scope="per-system")
new = ts.SimState.from_state(cu_sim_state)
assert torch.equal(new.E, field)

def test_extras_cannot_shadow_declared_fields(self, cu_sim_state: ts.SimState):
# set_extras should raise if attempting to shadow
with pytest.raises(ValueError, match="shadows an existing attribute"):
cu_sim_state.set_extras(
"cell", torch.zeros(cu_sim_state.n_systems, 3), scope="per-system"
)

def test_construction_extras_cannot_shadow(self):
# Post-init validation should also catch shadowing during construction
with pytest.raises(ValueError, match="shadows an existing attribute"):
ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"cell": torch.zeros(1, 3)},
)

# store_model_extras
def test_store_model_extras_canonical_keys_not_stored(
self, si_double_sim_state: ts.SimState
):
"""Canonical keys (energy, forces, stress) must not land in extras."""
state = si_double_sim_state.clone()
state.store_model_extras(
{
"energy": torch.randn(state.n_systems),
"forces": torch.randn(state.n_atoms, 3),
"stress": torch.randn(state.n_systems, 3, 3),
}
)
assert not state._system_extras # noqa: SLF001
assert not state._atom_extras # noqa: SLF001

def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState):
"""Tensors with leading dim == n_systems go into system_extras."""
state = si_double_sim_state.clone()
dipole = torch.randn(state.n_systems, 3)
state.store_model_extras(
{"energy": torch.randn(state.n_systems), "dipole": dipole}
)
assert torch.equal(state.dipole, dipole)

def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState):
"""Tensors with leading dim == n_atoms go into atom_extras."""
state = si_double_sim_state.clone()
charges = torch.randn(state.n_atoms)
density = torch.randn(state.n_atoms, 8)
state.store_model_extras(
{
"energy": torch.randn(state.n_systems),
"charges": charges,
"density_coefficients": density,
}
)
assert torch.equal(state.charges, charges)
assert state.density_coefficients.shape == (state.n_atoms, 8)

def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState):
"""0-d tensors and non-Tensor values are silently ignored."""
state = si_double_sim_state.clone()
state.store_model_extras(
{
"scalar": torch.tensor(3.14),
"string": "not a tensor",
}
)
assert not state.has_extras("scalar")
assert not state.has_extras("string")


def test_system_extras_atoms_roundtrip():
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"external_E_field": torch.tensor([[1.0, 0.0, 0.0]])},
)
atoms_list = state.to_atoms()
assert "external_E_field" in atoms_list[0].info
restored = ts.io.atoms_to_state(
atoms_list,
system_extras_keys=["external_E_field"],
)
assert torch.allclose(restored.external_E_field, state.external_E_field)


def test_atom_extras_atoms_roundtrip():
tags = torch.tensor([1.0, 2.0])
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_atom_extras={"tags": tags},
)
atoms_list = state.to_atoms()
assert "tags" in atoms_list[0].arrays
restored = ts.io.atoms_to_state(
atoms_list,
atom_extras_keys=["tags"],
)
assert torch.allclose(restored.tags, state.tags)
84 changes: 84 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import torch
from ase import Atoms
from ase.build import molecule
from phonopy.structure.atoms import PhonopyAtoms
from pymatgen.core import Structure

Expand Down Expand Up @@ -91,6 +92,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None:
)


@pytest.mark.parametrize(
("charge", "spin", "expected_charge", "expected_spin"),
[
(1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin
(0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin
(None, None, 0.0, 0.0), # No charge/spin set, defaults to zero
],
)
def test_atoms_to_state_with_charge_spin(
charge: float | None,
spin: float | None,
expected_charge: float,
expected_spin: float,
) -> None:
"""Test conversion from ASE Atoms with charge and spin to state tensors."""
mol = molecule("H2O")
if charge is not None:
mol.info["charge"] = charge
if spin is not None:
mol.info["spin"] = spin

state = ts.io.atoms_to_state([mol], DEVICE, DTYPE)

# Check basic properties
assert isinstance(state, SimState)
assert state.charge is not None
assert state.spin is not None
assert state.charge.shape == (1,)
assert state.spin.shape == (1,)
assert state.charge[0].item() == expected_charge
assert state.spin[0].item() == expected_spin


def test_multiple_atoms_to_state_with_charge_spin() -> None:
"""Test conversion from multiple ASE Atoms with different charge/spin values."""
mol1 = molecule("H2O")
mol1.info["charge"] = 1.0
mol1.info["spin"] = 1.0

mol2 = molecule("CH4")
mol2.info["charge"] = -1.0
mol2.info["spin"] = 0.0

mol3 = molecule("NH3")
mol3.info["charge"] = 0.0
mol3.info["spin"] = 2.0

state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE)

# Check basic properties
assert isinstance(state, SimState)
assert state.charge is not None
assert state.spin is not None
assert state.charge.shape == (3,)
assert state.spin.shape == (3,)
assert state.charge[0].item() == 1.0
assert state.charge[1].item() == -1.0
assert state.charge[2].item() == 0.0
assert state.spin[0].item() == 1.0
assert state.spin[1].item() == 0.0
assert state.spin[2].item() == 2.0


def test_state_to_structure(ar_supercell_sim_state: SimState) -> None:
"""Test conversion from state tensors to list of pymatgen Structure."""
structures = ts.io.state_to_structures(ar_supercell_sim_state)
Expand All @@ -117,6 +181,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None:
assert len(atoms[0]) == 32


def test_state_to_atoms_with_charge_spin() -> None:
"""Test conversion from state with charge/spin to ASE Atoms preserves charge/spin."""
mol = molecule("H2O")
mol.info["charge"] = 1.0
mol.info["spin"] = 1.0

state = ts.io.atoms_to_state([mol], DEVICE, DTYPE)
atoms = ts.io.state_to_atoms(state)

assert len(atoms) == 1
assert isinstance(atoms[0], Atoms)
assert "charge" in atoms[0].info
assert "spin" in atoms[0].info
assert atoms[0].info["charge"] == 1
assert atoms[0].info["spin"] == 1


def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None:
"""Test conversion from state tensors to list of ASE Atoms."""
atoms = ts.io.state_to_atoms(ar_double_sim_state)
Expand Down Expand Up @@ -259,6 +340,9 @@ def test_state_round_trip(
# since both use their own isotope masses based on species,
# not the ones in the state
assert torch.allclose(sim_state.masses, round_trip_state.masses)
# Check charge/spin round trip
assert torch.allclose(sim_state.charge, round_trip_state.charge)
assert torch.allclose(sim_state.spin, round_trip_state.spin)


def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_optimizers_vs_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from torch_sim.models.mace import MaceModel, MaceUrls
except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments]
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)


if TYPE_CHECKING:
Expand Down
1 change: 1 addition & 0 deletions torch_sim/integrators/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def velocity_verlet_step[T: MDState](

state.energy = model_output["energy"]
state.forces = model_output["forces"]
state.store_model_extras(model_output)
return momentum_step(state, dt_2)


Expand Down
Loading
Loading