diff --git a/dpdata/lammps/lmp.py b/dpdata/lammps/lmp.py
index e259aa5c..c9d60ec5 100644
--- a/dpdata/lammps/lmp.py
+++ b/dpdata/lammps/lmp.py
@@ -3,6 +3,8 @@
 
 import numpy as np
 
+from dpdata.periodic_table import ELEMENTS, Element
+
 ptr_float_fmt = "%15.10f"
 ptr_int_fmt = "%6d"
 ptr_key_fmt = "%15s"
@@ -484,6 +486,47 @@ def rotate_to_lower_triangle(
     return cell, coord
 
 
+def _get_lammps_masses(system) -> np.ndarray | None:
+    """Get masses for the LAMMPS ``Masses`` section.
+
+    Prefer explicitly stored masses when available. Otherwise, infer masses from
+    ``atom_names`` when all names are valid chemical element symbols.
+
+    Parameters
+    ----------
+    system : dict
+        System data dictionary
+
+    Returns
+    -------
+    np.ndarray or None
+        Per-type masses aligned with ``atom_names``. Returns ``None`` when the
+        masses cannot be determined safely.
+
+    Raises
+    ------
+    ValueError
+        If explicit ``system["masses"]`` is present but is not a 1D array with
+        the same length as ``atom_names``.
+    """
+    atom_names = system["atom_names"]
+    masses = system.get("masses")
+    if masses is not None:
+        masses = np.asarray(masses, dtype=float)
+        if masses.ndim != 1 or len(masses) != len(atom_names):
+            raise ValueError(
+                'Explicit system["masses"] must be a 1D array with the same '
+                'length as system["atom_names"] to write the LAMMPS Masses '
+                "section."
+            )
+        return masses
+
+    if not all(name in ELEMENTS for name in atom_names):
+        return None
+
+    return np.array([Element(name).mass for name in atom_names], dtype=float)
+
+
 def from_system_data(system, f_idx=0):
     ret = ""
     ret += "\n"
@@ -514,6 +557,17 @@ def from_system_data(system, f_idx=0):
         cell[2][1],
     )  # noqa: UP031
     ret += "\n"
+
+    # Write the optional Masses section only when masses could be determined.
+    masses = _get_lammps_masses(system)
+    if masses is not None:
+        ret += "Masses\n"
+        ret += "\n"
+        mass_fmt = ptr_int_fmt + " " + ptr_float_fmt + " # %s\n"  # noqa: UP031
+        for ii, (mass, atom_name) in enumerate(zip(masses, system["atom_names"])):
+            ret += mass_fmt % (ii + 1, mass, atom_name)
+        ret += "\n"
+
     ret += "Atoms # atomic\n"
     ret += "\n"
     coord_fmt = (
diff --git a/tests/test_lammps_lmp_dump.py b/tests/test_lammps_lmp_dump.py
index c2c2f811..a717c6cf 100644
--- a/tests/test_lammps_lmp_dump.py
+++ b/tests/test_lammps_lmp_dump.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import tempfile
 import unittest
 
 import numpy as np
@@ -9,6 +10,9 @@
 
 from dpdata.lammps.lmp import rotate_to_lower_triangle
 
+TEST_DIR = os.path.dirname(__file__)
+POSCAR_CONF_LMP = os.path.join(TEST_DIR, "poscars", "conf.lmp")
+
 
 class TestLmpDump(unittest.TestCase, TestPOSCARoh):
     def setUp(self):
@@ -100,5 +104,43 @@ def test_negative_diagonal(self):
         )
 
 
+class TestLmpDumpMasses(unittest.TestCase):
+    """Check when the LAMMPS data writer emits a ``Masses`` section."""
+    def test_dump_known_elements_writes_masses(self):
+        """Element type names produce a Masses section placed before Atoms."""
+        system = dpdata.System(POSCAR_CONF_LMP, type_map=["O", "H"])
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output = os.path.join(tmpdir, "tmp_masses.lmp")
+            system.to_lammps_lmp(output)
+            with open(output) as f:
+                content = f.read()
+
+        self.assertIn("Masses\n", content)
+        self.assertRegex(content, r"(?m)^ *1 +15\.9994000000 # O$")
+        self.assertRegex(content, r"(?m)^ *2 +1\.0079400000 # H$")
+        self.assertLess(content.index("Masses\n"), content.index("Atoms # atomic\n"))
+
+    def test_dump_unknown_types_skips_masses(self):
+        """No Masses section is written when no type_map supplies element names."""
+        system = dpdata.System(POSCAR_CONF_LMP)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output = os.path.join(tmpdir, "tmp_unknown_types.lmp")
+            system.to_lammps_lmp(output)
+            with open(output) as f:
+                content = f.read()
+
+        self.assertNotIn("Masses\n", content)
+
+    def test_dump_rejects_mismatched_explicit_masses(self):
+        """Explicit masses whose length differs from atom_names raise ValueError."""
+        system = dpdata.System(POSCAR_CONF_LMP, type_map=["O", "H"])
+        system.data["masses"] = np.array([15.9994, 1.00794, 99.0])
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output = os.path.join(tmpdir, "tmp_bad_masses.lmp")
+            with self.assertRaisesRegex(ValueError, r'system\["masses"\]'):
+                system.to_lammps_lmp(output)
+
+
 if __name__ == "__main__":
     unittest.main()