diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 44284c26..d259e2ec 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -22,7 +22,7 @@ except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"FairChem not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_fairchem_legacy.py b/tests/models/test_fairchem_legacy.py index 88423ad7..f7977c6b 100644 --- a/tests/models/test_fairchem_legacy.py +++ b/tests/models/test_fairchem_legacy.py @@ -22,7 +22,7 @@ except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"FairChem not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_graphpes_framework.py b/tests/models/test_graphpes_framework.py index 7422e84b..7487914a 100644 --- a/tests/models/test_graphpes_framework.py +++ b/tests/models/test_graphpes_framework.py @@ -21,7 +21,7 @@ from torch_sim.models.graphpes_framework import GraphPESWrapper except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"graph-pes not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"graph-pes not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 322f3d12..8642f67b 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -20,7 +20,7 @@ from torch_sim.models.mace import MaceModel, MaceUrls except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # mace_omol is optional (added in newer MACE versions) try: diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index ee495aa7..b8ed7809 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -19,7 +19,7 @@ except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"mattersim not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"mattersim not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index 1519425f..c42fa845 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -18,7 +18,7 @@ from torch_sim.models.metatomic import MetatomicModel except ImportError: pytest.skip( - f"metatomic not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"metatomic not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index 4d238ee6..51f73200 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -19,7 +19,7 @@ from torch_sim.models.nequip_framework import NequIPFrameworkModel except (ImportError, ModuleNotFoundError): pytest.skip( - f"nequip not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"nequip not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 98311e72..6bdf1376 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -17,7 +17,7 @@ from torch_sim.models.orb import OrbModel except ImportError: - pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] + pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) @pytest.fixture diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 5a751bb6..b5e759e4 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -20,7 +20,7 @@ except ImportError: pytest.skip( - f"sevenn not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"sevenn not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/test_elastic.py b/tests/test_elastic.py index c7753bc1..6ad6af76 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -22,7 +22,7 @@ from torch_sim.models.mace import MaceModel except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) def test_get_strain_zero_deformation(cu_sim_state: ts.SimState) -> None: diff --git a/tests/test_extras.py b/tests/test_extras.py new file mode 100644 index 00000000..a6c42381 --- /dev/null +++ b/tests/test_extras.py @@ -0,0 +1,165 @@ +import pytest +import torch + +import torch_sim as ts + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 + + +class TestExtras: + def test_system_extras_construction(self): + """Extras can be passed at construction time.""" + field = torch.randn(1, 3) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"external_E_field": field}, + ) + assert torch.equal(state.external_E_field, field) + + def test_atom_extras_construction(self): + """Per-atom extras work at construction time.""" + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + assert torch.equal(state.tags, tags) + + def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState): + with pytest.raises(AttributeError, match="nonexistent_key"): + _ = cu_sim_state.nonexistent_key + + def test_post_init_validation_rejects_bad_shape(self): + with pytest.raises(ValueError, match="leading dim must be n_systems"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"bad": torch.randn(5, 3)}, + ) + + def test_from_state_preserves_extras(self, cu_sim_state: ts.SimState): + field = torch.randn(cu_sim_state.n_systems, 3, device=cu_sim_state.device) + cu_sim_state.set_extras("E", field, scope="per-system") + new = ts.SimState.from_state(cu_sim_state) + assert torch.equal(new.E, field) + + def test_extras_cannot_shadow_declared_fields(self, cu_sim_state: ts.SimState): + # set_extras should raise if attempting to shadow + with pytest.raises(ValueError, match="shadows an existing attribute"): + cu_sim_state.set_extras( + "cell", torch.zeros(cu_sim_state.n_systems, 3), scope="per-system" + ) + + def test_construction_extras_cannot_shadow(self): + # Post-init validation should also catch shadowing during construction + with pytest.raises(ValueError, match="shadows an existing attribute"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"cell": torch.zeros(1, 3)}, + ) + + # store_model_extras + def test_store_model_extras_canonical_keys_not_stored( + self, si_double_sim_state: ts.SimState + ): + """Canonical keys (energy, forces, stress) must not land in extras.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "forces": torch.randn(state.n_atoms, 3), + "stress": torch.randn(state.n_systems, 3, 3), + } + ) + assert not state._system_extras # noqa: SLF001 + assert not state._atom_extras # noqa: SLF001 + + def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_systems go into system_extras.""" + state = si_double_sim_state.clone() + dipole = torch.randn(state.n_systems, 3) + state.store_model_extras( + {"energy": torch.randn(state.n_systems), "dipole": dipole} + ) + assert torch.equal(state.dipole, dipole) + + def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_atoms go into atom_extras.""" + state = si_double_sim_state.clone() + charges = torch.randn(state.n_atoms) + density = torch.randn(state.n_atoms, 8) + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "charges": charges, + "density_coefficients": density, + } + ) + assert torch.equal(state.charges, charges) + assert state.density_coefficients.shape == (state.n_atoms, 8) + + def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState): + """0-d tensors and non-Tensor values are silently ignored.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "scalar": torch.tensor(3.14), + "string": "not a tensor", + } + ) + assert not state.has_extras("scalar") + assert not state.has_extras("string") + + +def test_system_extras_atoms_roundtrip(): + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"external_E_field": torch.tensor([[1.0, 0.0, 0.0]])}, + ) + atoms_list = state.to_atoms() + assert "external_E_field" in atoms_list[0].info + restored = ts.io.atoms_to_state( + atoms_list, + system_extras_keys=["external_E_field"], + ) + assert torch.allclose(restored.external_E_field, state.external_E_field) + + +def test_atom_extras_atoms_roundtrip(): + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + atoms_list = state.to_atoms() + assert "tags" in atoms_list[0].arrays + restored = ts.io.atoms_to_state( + atoms_list, + atom_extras_keys=["tags"], + ) + assert torch.allclose(restored.tags, state.tags) diff --git a/tests/test_io.py b/tests/test_io.py index 2bb4f017..8e1de1ac 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -5,6 +5,7 @@ import pytest import torch from ase import Atoms +from ase.build import molecule from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure @@ -91,6 +92,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: ) +@pytest.mark.parametrize( + ("charge", "spin", "expected_charge", "expected_spin"), + [ + (1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin + (0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin + (None, None, 0.0, 0.0), # No charge/spin set, defaults to zero + ], +) +def test_atoms_to_state_with_charge_spin( + charge: float | None, + spin: float | None, + expected_charge: float, + expected_spin: float, +) -> None: + """Test conversion from ASE Atoms with charge and spin to state tensors.""" + mol = molecule("H2O") + if charge is not None: + mol.info["charge"] = charge + if spin is not None: + mol.info["spin"] = spin + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (1,) + assert state.spin.shape == (1,) + assert state.charge[0].item() == expected_charge + assert state.spin[0].item() == expected_spin + + +def test_multiple_atoms_to_state_with_charge_spin() -> None: + """Test conversion from multiple ASE Atoms with different charge/spin values.""" + mol1 = molecule("H2O") + mol1.info["charge"] = 1.0 + mol1.info["spin"] = 1.0 + + mol2 = molecule("CH4") + mol2.info["charge"] = -1.0 + mol2.info["spin"] = 0.0 + + mol3 = molecule("NH3") + mol3.info["charge"] = 0.0 + mol3.info["spin"] = 2.0 + + state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (3,) + assert state.spin.shape == (3,) + assert state.charge[0].item() == 1.0 + assert state.charge[1].item() == -1.0 + assert state.charge[2].item() == 0.0 + assert state.spin[0].item() == 1.0 + assert state.spin[1].item() == 0.0 + assert state.spin[2].item() == 2.0 + + def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of pymatgen Structure.""" structures = ts.io.state_to_structures(ar_supercell_sim_state) @@ -117,6 +181,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: assert len(atoms[0]) == 32 +def test_state_to_atoms_with_charge_spin() -> None: + """Test conversion from state with charge/spin to ASE Atoms preserves charge/spin.""" + mol = molecule("H2O") + mol.info["charge"] = 1.0 + mol.info["spin"] = 1.0 + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + atoms = ts.io.state_to_atoms(state) + + assert len(atoms) == 1 + assert isinstance(atoms[0], Atoms) + assert "charge" in atoms[0].info + assert "spin" in atoms[0].info + assert atoms[0].info["charge"] == 1 + assert atoms[0].info["spin"] == 1 + + def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_double_sim_state) @@ -259,6 +340,9 @@ def test_state_round_trip( # since both use their own isotope masses based on species, # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses) + # Check charge/spin round trip + assert torch.allclose(sim_state.charge, round_trip_state.charge) + assert torch.allclose(sim_state.spin, round_trip_state.spin) def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index 328507f4..d37ea25e 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -18,7 +18,7 @@ from torch_sim.models.mace import MaceModel, MaceUrls except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) if TYPE_CHECKING: diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 76ff69c1..49d9b62a 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -233,6 +233,7 @@ def velocity_verlet_step[T: MDState]( state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_2) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 094a7deb..63cf56cd 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -616,7 +616,7 @@ def npt_langevin_init( logger.warning(msg) # Create the initial state - return NPTLangevinState.from_state( + npt_state = NPTLangevinState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -630,6 +630,8 @@ def npt_langevin_init( cell_masses=cell_masses, cell_alpha=cell_alpha, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1063/1.4901303") @@ -691,6 +693,7 @@ def npt_langevin_step( model_output = model(state) state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Store initial values for integration forces = state.forces @@ -730,6 +733,7 @@ def npt_langevin_step( state.energy = model_output["energy"] state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Compute updated pressure force F_p_n_new = _compute_cell_force( @@ -1271,6 +1275,7 @@ def _npt_nose_hoover_inner_step( state.set_constrained_momenta(momenta) state.forces = model_output["forces"] state.energy = model_output["energy"] + state.store_model_extras(model_output) state.cell_position = cell_position state.cell_momentum = cell_momentum state.cell_mass = cell_mass @@ -1423,7 +1428,7 @@ def npt_nose_hoover_init( logger.warning(msg) # Create initial state - return NPTNoseHooverState.from_state( + npt_state = NPTNoseHooverState.from_state( state, momenta=momenta, energy=energy, @@ -1438,6 +1443,8 @@ def npt_nose_hoover_init( barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1080/00268979600100761") @@ -2062,6 +2069,7 @@ def npt_crescale_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt_tensor / 2) @@ -2137,6 +2145,7 @@ def npt_crescale_independent_lengths_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2213,6 +2222,7 @@ def npt_crescale_average_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2290,6 +2300,7 @@ def npt_crescale_isotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2363,7 +2374,7 @@ def npt_crescale_init( ) # Create the initial state - return NPTCRescaleState.from_state( + npt_state = NPTCRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -2372,3 +2383,5 @@ def npt_crescale_init( tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, ) + npt_state.store_model_extras(model_output) + return npt_state diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 07f3064b..316ef78c 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -57,12 +57,14 @@ def nve_init( state.rng, ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state def nve_step( @@ -100,5 +102,6 @@ def nve_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt / 2) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 8e74bf85..4cb6bb51 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -132,6 +132,8 @@ def nvt_langevin_init( energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state @dcite("10.1098/rspa.2016.0138") @@ -191,6 +193,7 @@ def nvt_langevin_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_tensor / 2) @@ -321,7 +324,7 @@ def nvt_nose_hoover_init( ) # n_atoms * n_dimensions # Initialize state - return NVTNoseHooverState.from_state( + nh_state = NVTNoseHooverState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -330,6 +333,8 @@ def nvt_nose_hoover_init( chain=chain_fns.initialize(dof_per_system, KE, kT_tensor), _chain_fns=chain_fns, ) + nh_state.store_model_extras(model_output) + return nh_state @dcite("10.1080/00268979600100761") @@ -609,12 +614,14 @@ def nvt_vrescale_init( state.rng, ) - return NVTVRescaleState.from_state( + vr_state = NVTVRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + vr_state.store_model_extras(model_output) + return vr_state @dcite("10.1063/1.2408420") diff --git a/torch_sim/io.py b/torch_sim/io.py index edce091e..9df153d7 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -97,6 +97,14 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: if spin is not None: atoms.info["spin"] = int(spin[sys_idx].item()) + # Write system extras to atoms.info + for key, val in state.system_extras.items(): + atoms.info[key] = val[sys_idx].detach().cpu().numpy() + + # Write atom extras to atoms.arrays + for key, val in state.atom_extras.items(): + atoms.arrays[key] = val[mask].detach().cpu().numpy() + atoms_list.append(atoms) return atoms_list @@ -244,6 +252,8 @@ def atoms_to_state( atoms: "Atoms | list[Atoms]", device: torch.device | None = None, dtype: torch.dtype | None = None, + system_extras_keys: list[str] | None = None, + atom_extras_keys: list[str] | None = None, ) -> "ts.SimState": """Convert an ASE Atoms object or list of Atoms objects to a SimState. @@ -252,6 +262,10 @@ def atoms_to_state( device (torch.device): Device to create tensors on dtype (torch.dtype): Data type for tensors (typically torch.float32 or torch.float64) + system_extras_keys (list[str]): Optional list of keys to read from atoms.info + into _system_extras + atom_extras_keys (list[str]): Optional list of keys to read from atoms.arrays + into _atom_extras Returns: SimState: TorchSim SimState object. @@ -305,6 +319,24 @@ def atoms_to_state( [at.info.get("spin", 0.0) for at in atoms_list], dtype=dtype, device=device ) + _system_extras: dict[str, torch.Tensor] = {} + if system_extras_keys: + for key in system_extras_keys: + vals = [at.info.get(key) for at in atoms_list] + if all(v is not None for v in vals): + _system_extras[key] = torch.tensor( + np.stack(vals), dtype=dtype, device=device + ) + + _atom_extras: dict[str, torch.Tensor] = {} + if atom_extras_keys: + for key in atom_extras_keys: + arrays = [at.arrays.get(key) for at in atoms_list] + if all(a is not None for a in arrays): + _atom_extras[key] = torch.tensor( + np.concatenate(arrays), dtype=dtype, device=device + ) + return ts.SimState( positions=positions, masses=masses, @@ -314,6 +346,8 @@ def atoms_to_state( system_idx=system_idx, charge=charge, spin=spin, + _system_extras=_system_extras, + _atom_extras=_atom_extras, ) diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py index 05e452f1..432c0445 100644 --- a/torch_sim/models/fairchem_legacy.py +++ b/torch_sim/models/fairchem_legacy.py @@ -423,6 +423,9 @@ def forward( # noqa: C901 for idx, (n, c) in enumerate( zip(natoms, torch.cumsum(natoms, dim=0), strict=False) ): + # NOTE: Legacy FairChem models (v1) do not support charge/spin, + # so we don't pass these fields to the Data object. + # The model will simply ignore charge/spin and treat all systems as neutral. data_list.append( Data( pos=sim_state.positions[c - n : c].detach().clone(), diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 63681c8e..c395a64c 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -239,7 +239,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 This validator creates small test systems (silicon and iron) for validation. It tests both single and multi-batch processing capabilities. """ - from ase.build import bulk + from ase.build import bulk, molecule for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): @@ -269,6 +269,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 system_idx = sim_state.system_idx og_system_idx = system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() + og_charge = sim_state.charge.clone() + og_spin = sim_state.spin.clone() if check_detached and hasattr(model, "retain_graph"): model.__dict__["retain_graph"] = True @@ -289,6 +291,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") + if not torch.allclose(og_charge, sim_state.charge): + raise ValueError(f"{og_charge=} != {sim_state.charge=}") + if not torch.allclose(og_spin, sim_state.spin): + raise ValueError(f"{og_spin=} != {sim_state.spin=}") # assert model output has the correct keys if "energy" not in model_output: @@ -348,3 +354,43 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)") if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)") + + # Test that models can handle non-zero charge and spin + benzene_atoms = molecule("C6H6") + benzene_atoms.info["charge"] = 1.0 + benzene_atoms.info["spin"] = 1.0 + charged_state = ts.io.atoms_to_state([benzene_atoms], device, dtype) + + # Ensure state has charge/spin before testing model + if charged_state.charge is None or charged_state.spin is None: + raise ValueError( + "atoms_to_state did not extract charge/spin. " + "Cannot test model charge/spin handling." + ) + + # Test that model can handle charge/spin without crashing + og_charged_charge = charged_state.charge.clone() + og_charged_spin = charged_state.spin.clone() + try: + charged_output = model.forward(charged_state) + except Exception as e: + raise ValueError( + "Model failed to handle non-zero charge/spin. " + "Models must be able to process states with charge and spin values. " + ) from e + + # Verify model didn't mutate charge/spin + if not torch.allclose(og_charged_charge, charged_state.charge): + raise ValueError( + f"Model mutated charge: {og_charged_charge=} != {charged_state.charge=}" + ) + if not torch.allclose(og_charged_spin, charged_state.spin): + raise ValueError( + f"Model mutated spin: {og_charged_spin=} != {charged_state.spin=}" + ) + # Verify output shape is still correct + if charged_output["energy"].shape != (1,): + raise ValueError( + f"energy shape incorrect with charge/spin: " + f"{charged_output['energy'].shape=} != (1,)" + ) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 7dfbb765..97d478a3 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -347,6 +347,13 @@ def forward( # noqa: C901 if stress is not None: results["stress"] = stress.detach() + # Propagate additional model outputs (e.g. dipole, charges, etc.) + for key, val in out.items(): + if key not in ("energy", "forces", "stress") and isinstance( + val, torch.Tensor + ): + results[key] = val.detach() + return results diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 04dfde31..8a4a0d37 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -223,7 +223,7 @@ def swap_mc_init( """ model_output = model(state) - return SwapMCState( + mc_state = SwapMCState( positions=state.positions, masses=state.masses, cell=state.cell, @@ -233,6 +233,8 @@ def swap_mc_init( energy=model_output["energy"], _constraints=state.constraints, ) + mc_state.store_model_extras(model_output) + return mc_state def swap_mc_step( @@ -292,5 +294,6 @@ def swap_mc_step( state.energy = torch.where(accepted, energies_new, energies_old) state.last_permutation = permutation[reverse_rejected_swaps].clone() + state.store_model_extras(model_output) return state diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index 5cccf96c..5c617c86 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -139,6 +139,19 @@ def bfgs_init( n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) # [S] + bfgs_attrs = { + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "alpha": alpha_t, # [S] + "max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # [S] + } + if cell_filter is not None: # Extended Hessian: (3*global_max_atoms + 9) x (3*global_max_atoms + 9) # The extra 9 DOFs are for cell parameters (3x3 matrix flattened) @@ -149,59 +162,31 @@ def bfgs_init( cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - # Note (AG): At initialization, deform_grad is identity, so we have: - # fractional = Cartesian / cell and scaled forces = forces @ I = forces - # For ASE compatibility, we need to store prev_positions as fractional coords - # and prev_forces as scaled forces - - # Get initial deform_grad (identity at start since reference_cell = current_cell) + # At initialization, deform_grad is identity, so fractional = Cartesian + # and scaled forces = forces. For ASE compatibility, store prev_positions + # as fractional coords and prev_forces as scaled forces. reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = cell_filters.deform_grad( reference_cell.mT, state.cell.mT ) # [S, 3, 3] - # Initial fractional positions = solve(deform_grad, positions) = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D_ext, D_ext] - # Note (AG): Store fractional positions and scaled forces - # for ASE compatibility - "prev_forces": scaled_forces, # [N, 3] (scaled) - "prev_positions": frac_positions, # [N, 3] (fractional) - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "reference_cell": reference_cell, # [S, 3, 3] - "cell_filter": cell_filter_funcs, - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - cell_state = CellBFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_attrs["hessian"] = hessian # [S, D_ext, D_ext] + bfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) + bfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + bfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + bfgs_attrs["cell_filter"] = cell_filter_funcs + + cell_state = CellBFGSState.from_state(state, **bfgs_attrs) # Initialize cell-specific attributes (cell_positions, cell_forces, etc.) # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -211,6 +196,7 @@ def bfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms @@ -218,31 +204,11 @@ def bfgs_init( hessian = torch.eye(dim, device=device, dtype=dtype).unsqueeze(0).repeat( n_systems, 1, 1 ) * alpha_t.view(n_systems, 1, 1) # [S, D, D] + bfgs_attrs["hessian"] = hessian # [S, D, D] - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D, D] - "prev_forces": forces.clone(), # [N, 3] - "prev_positions": state.positions.clone(), # [N, 3] - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - return BFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_state = BFGSState.from_state(state, **bfgs_attrs) + bfgs_state.store_model_extras(model_output) + return bfgs_state def bfgs_step( # noqa: C901, PLR0915 @@ -542,6 +508,7 @@ def bfgs_step( # noqa: C901, PLR0915 state.energy = model_output["energy"] # [S] if "stress" in model_output: state.stress = model_output["stress"] # [S, 3, 3] + state.store_model_extras(model_output) # Update cell forces for next step # Update cell forces for cell state: [S, 3, 3] diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index e35ec1c1..836485bf 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -105,9 +105,12 @@ def fire_init( cell_state.cell_forces.shape, torch.nan, device=device, dtype=dtype ) + cell_state.store_model_extras(model_output) return cell_state # Create regular FireState without cell optimization - return FireState.from_state(state, **fire_attrs) + fire_state = FireState.from_state(state, **fire_attrs) + fire_state.store_model_extras(model_output) + return fire_state def fire_step( @@ -214,6 +217,7 @@ def _vv_fire_step[T: "FireState | CellFireState"]( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): @@ -460,6 +464,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 6f940ff0..7356ffe8 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -53,6 +53,8 @@ def gradient_descent_init( "stress": stress, } + state.store_model_extras(model_output) + if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) optim_attrs["reference_cell"] = state.cell.clone() @@ -112,6 +114,7 @@ def gradient_descent_step( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellOptimState): diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 0221f92d..e4791157 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -191,22 +191,10 @@ def lbfgs_init( if step_size_tensor.ndim == 0: step_size_tensor = step_size_tensor.expand(n_systems) - common_args = { - # Copy SimState attributes - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - # Optimization state + lbfgs_attrs = { "forces": forces, # [N, 3] "energy": energy, # [S] "stress": stress, # [S, 3, 3] or None - # L-BFGS specific state "prev_forces": forces.clone(), # [N, 3] "prev_positions": state.positions.clone(), # [N, 3] "s_history": s_history, # [S, 0, M, 3] @@ -226,41 +214,35 @@ def lbfgs_init( reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = deform_grad(reference_cell.mT, state.cell.mT) # [S, 3, 3] - # Initial fractional positions = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] -> [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) # [N, 3] - common_args["reference_cell"] = reference_cell # [S, 3, 3] - common_args["cell_filter"] = cell_filter_funcs - # Store fractional positions and scaled forces for ASE compatibility - common_args["prev_positions"] = frac_positions # [N, 3] - common_args["prev_forces"] = scaled_forces # [N, 3] + lbfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + lbfgs_attrs["cell_filter"] = cell_filter_funcs + lbfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + lbfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) # Extended per-system history includes cell DOFs (3 "virtual atoms" per system) - # History shape: [S, H, M+3, 3] where M = global_max_atoms extended_size_per_system = global_max_atoms + 3 # M_ext = M + 3 - common_args["s_history"] = torch.zeros( + lbfgs_attrs["s_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - common_args["y_history"] = torch.zeros( + lbfgs_attrs["y_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - cell_state = CellLBFGSState(**common_args) # ty: ignore[invalid-argument-type] + cell_state = CellLBFGSState.from_state(state, **lbfgs_attrs) # Initialize cell-specific attributes # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -270,9 +252,12 @@ def lbfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state - return LBFGSState(**common_args) # ty: ignore[invalid-argument-type] + lbfgs_state = LBFGSState.from_state(state, **lbfgs_attrs) + lbfgs_state.store_model_extras(model_output) + return lbfgs_state def lbfgs_step( # noqa: PLR0915, C901 @@ -531,6 +516,7 @@ def lbfgs_step( # noqa: PLR0915, C901 new_forces = model_output["forces"] # [N, 3] new_energy = model_output["energy"] # [S] new_stress = model_output.get("stress") # [S, 3, 3] or None + state.store_model_extras(model_output) # Update cell forces for next step: [S, 3, 3] if isinstance(state, CellLBFGSState): diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 4f60d3cd..273adfd8 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -22,7 +22,7 @@ from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState -from torch_sim.state import SimState +from torch_sim.state import _CANONICAL_MODEL_KEYS, SimState from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike from torch_sim.units import UnitSystem @@ -731,7 +731,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) -def static( +def static( # noqa: C901 system: StateLike, model: ModelInterface, *, @@ -835,8 +835,25 @@ class StaticState(SimState): else torch.full_like(sub_state.cell, fill_value=float("nan")) ), ) + static_state.store_model_extras(model_outputs) props = trajectory_reporter.report(static_state, 0, model=model) + + # Merge extra model outputs into per-system property dicts + # TODO: this should be cleaner? + extra_keys = {k for k in model_outputs if k not in _CANONICAL_MODEL_KEYS} + if extra_keys: + for sys_idx, sys_props in enumerate(props): + for key in extra_keys: + val = model_outputs[key] + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + if val.shape[0] == static_state.n_atoms: + mask = static_state.system_idx == sys_idx + sys_props[key] = val[mask] + elif val.shape[0] == static_state.n_systems: + sys_props[key] = val[sys_idx : sys_idx + 1] + all_props.extend(props) if tqdm_pbar: diff --git a/torch_sim/state.py b/torch_sim/state.py index b3cb8f07..7c9fe110 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -32,6 +32,10 @@ ) +# Canonical model output keys that are handled explicitly by integrators/runners +_CANONICAL_MODEL_KEYS = frozenset({"energy", "forces", "stress"}) + + def coerce_prng(rng: PRNGLike, device: DeviceLikeType | None) -> torch.Generator: """Coerce an int seed or existing Generator into a ``torch.Generator``. @@ -132,7 +136,9 @@ class SimState: charge: torch.Tensor | None = field(default=None) spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ - _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 + _constraints: list["Constraint"] = field(default_factory=list) + _system_extras: dict[str, torch.Tensor] = field(default_factory=dict) + _atom_extras: dict[str, torch.Tensor] = field(default_factory=dict) _rng: PRNGLike = field(default=None, repr=False) if TYPE_CHECKING: @@ -246,6 +252,29 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") + # Validate extras shapes and prevent shadowing + all_attrs = self._get_all_attributes() + for key, val in self._system_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"System extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"System extra '{key}' must be a torch.Tensor") + if val.shape[0] != n_systems: + raise ValueError( + f"System extra '{key}' leading dim must be " + f"n_systems={n_systems}, got {val.shape[0]}" + ) + for key, val in self._atom_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"Atom extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"Atom extra '{key}' must be a torch.Tensor") + if val.shape[0] != self.n_atoms: + raise ValueError( + f"Atom extra '{key}' leading dim must be " + f"n_atoms={self.n_atoms}, got {val.shape[0]}" + ) + @classmethod def _get_all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" @@ -253,9 +282,72 @@ def _get_all_attributes(cls) -> set[str]: cls._atom_attributes | cls._system_attributes | cls._global_attributes - | {"_constraints"} + | {"_constraints", "_system_extras", "_atom_extras"} + ) + + def __getattr__(self, name: str) -> Any: + """Allow attribute-style access to extras dict entries.""" + # Guard: don't look up private attrs in extras (avoids recursion during init) + if name.startswith("_"): + raise AttributeError(name) + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + return extras[name] + + # Raise AttributeError so that Python's getattr(obj, name, default), + # hasattr(obj, name), and other descriptor-protocol machinery work correctly. + raise AttributeError( + f"'{type(self).__name__}' has no attribute or extra '{name}'" ) + @property + def system_extras(self) -> dict[str, torch.Tensor]: + """Get the system extras.""" + return self._system_extras + + @property + def atom_extras(self) -> dict[str, torch.Tensor]: + """Get the atom extras.""" + return self._atom_extras + + def has_extras(self, key: str) -> bool: + """Check if an extras key exists.""" + return key in self._system_extras or key in self._atom_extras + + def store_model_extras(self, model_output: dict[str, torch.Tensor]) -> None: + """Store non-canonical model outputs into state extras (in-place). + + Any key in *model_output* that is not in ``{"energy", "forces", "stress"}`` + is classified by its leading dimension: + + * ``n_atoms`` → stored in ``_atom_extras`` + * ``n_systems`` → stored in ``_system_extras`` + * otherwise → skipped (ambiguity or scalar) + + When ``n_atoms == n_systems`` (single-atom system), the tensor is stored as + per-atom by convention. + + Args: + model_output: Full dict returned by ``model.forward()``. + """ + n_atoms = self.n_atoms + n_systems = self.n_systems + + for key, val in model_output.items(): + if key in _CANONICAL_MODEL_KEYS: + continue + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + leading = val.shape[0] + if leading == n_atoms: + self._atom_extras[key] = val + elif leading == n_systems: + self._system_extras[key] = val + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, @@ -797,11 +889,25 @@ def _state_to_device[T: SimState]( elif isinstance(attr_value, torch.Generator): attrs[attr_name] = coerce_prng(attr_value, device) + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(device=device) for k, v in attrs[extras_key].items() + } + if dtype is not None: attrs["positions"] = attrs["positions"].to(dtype=dtype) attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + + # Update floating point extras to new dtype + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(dtype=dtype) if v.is_floating_point() else v + for k, v in attrs[extras_key].items() + } return type(state)(**attrs) @@ -888,6 +994,13 @@ def _filter_attrs_by_index( val[system_indices] if isinstance(val, torch.Tensor) else val ) + filtered_attrs["_system_extras"] = { + key: val[system_indices] for key, val in state.system_extras.items() + } + filtered_attrs["_atom_extras"] = { + key: val[atom_indices] for key, val in state.atom_extras.items() + } + return filtered_attrs @@ -920,6 +1033,14 @@ def _split_state[T: SimState](state: T) -> list[T]: global_attrs = dict(get_attrs_for_scope(state, "global")) + split_system_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._system_extras.items(): # noqa: SLF001 + split_system_extras[key] = list(torch.split(val, 1, dim=0)) + + split_atom_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._atom_extras.items(): # noqa: SLF001 + split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0)) + # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) @@ -946,6 +1067,12 @@ def _split_state[T: SimState](state: T) -> list[T]: **per_system_dict, # Add the global attributes **global_attrs, + "_system_extras": { + key: split_system_extras[key][sys_idx] for key in split_system_extras + }, + "_atom_extras": { + key: split_atom_extras[key][sys_idx] for key in split_atom_extras + }, } start_idx = int(cumsum_atoms[sys_idx].item()) @@ -1092,6 +1219,8 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + system_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) + atom_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) new_system_indices = [] system_offset = 0 num_atoms_per_state = [] @@ -1113,6 +1242,12 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 for prop, val in get_attrs_for_scope(state, "per-system"): per_system_tensors[prop].append(val) + # Collect extras + for key, val in state.system_extras.items(): + system_extras_tensors[key].append(val) + for key, val in state.atom_extras.items(): + atom_extras_tensors[key].append(val) + # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset @@ -1183,6 +1318,14 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Concatenate extras + concatenated["_system_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in system_extras_tensors.items() + } + concatenated["_atom_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in atom_extras_tensors.items() + } + # Merge constraints constraint_lists = [state.constraints for state in states] num_systems_per_state = [state.n_systems for state in states]