Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions dpdata/lammps/lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import numpy as np

from dpdata.periodic_table import ELEMENTS, Element

ptr_float_fmt = "%15.10f"
ptr_int_fmt = "%6d"
ptr_key_fmt = "%15s"
Expand Down Expand Up @@ -484,6 +486,47 @@ def rotate_to_lower_triangle(
return cell, coord


def _get_lammps_masses(system) -> np.ndarray | None:
"""Get masses for the LAMMPS ``Masses`` section.

Prefer explicitly stored masses when available. Otherwise, infer masses from
``atom_names`` when all names are valid chemical element symbols.

Parameters
----------
system : dict
System data dictionary

Returns
-------
np.ndarray or None
Per-type masses aligned with ``atom_names``. Returns ``None`` when the
masses cannot be determined safely.

Raises
------
ValueError
If explicit ``system["masses"]`` is present but does not match the
length of ``atom_names``.
"""
atom_names = system["atom_names"]
masses = system.get("masses")
if masses is not None:
masses = np.asarray(masses, dtype=float)
if masses.ndim != 1 or len(masses) != len(atom_names):
raise ValueError(
'Explicit system["masses"] must be a 1D array with the same '
'length as system["atom_names"] to write the LAMMPS Masses '
"section."
)
return masses

if not all(name in ELEMENTS for name in atom_names):
return None

return np.array([Element(name).mass for name in atom_names], dtype=float)


def from_system_data(system, f_idx=0):
ret = ""
ret += "\n"
Expand Down Expand Up @@ -514,6 +557,16 @@ def from_system_data(system, f_idx=0):
cell[2][1],
) # noqa: UP031
ret += "\n"

masses = _get_lammps_masses(system)
if masses is not None:
ret += "Masses\n"
ret += "\n"
mass_fmt = ptr_int_fmt + " " + ptr_float_fmt + " # %s\n" # noqa: UP031
for ii, (mass, atom_name) in enumerate(zip(masses, system["atom_names"])):
ret += mass_fmt % (ii + 1, mass, atom_name)
ret += "\n"

ret += "Atoms # atomic\n"
ret += "\n"
coord_fmt = (
Expand Down
38 changes: 38 additions & 0 deletions tests/test_lammps_lmp_dump.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import tempfile
import unittest

import numpy as np
Expand All @@ -9,6 +10,9 @@

from dpdata.lammps.lmp import rotate_to_lower_triangle

TEST_DIR = os.path.dirname(__file__)
POSCAR_CONF_LMP = os.path.join(TEST_DIR, "poscars", "conf.lmp")


class TestLmpDump(unittest.TestCase, TestPOSCARoh):
def setUp(self):
Expand Down Expand Up @@ -100,5 +104,39 @@ def test_negative_diagonal(self):
)


class TestLmpDumpMasses(unittest.TestCase):
def test_dump_known_elements_writes_masses(self):
system = dpdata.System(POSCAR_CONF_LMP, type_map=["O", "H"])
with tempfile.TemporaryDirectory() as tmpdir:
output = os.path.join(tmpdir, "tmp_masses.lmp")
system.to_lammps_lmp(output)
with open(output) as f:
content = f.read()

self.assertIn("Masses\n", content)
self.assertIn(" 1 15.9994000000 # O", content)
self.assertIn(" 2 1.0079400000 # H", content)
self.assertLess(content.index("Masses\n"), content.index("Atoms # atomic\n"))

def test_dump_unknown_types_skips_masses(self):
system = dpdata.System(POSCAR_CONF_LMP)
with tempfile.TemporaryDirectory() as tmpdir:
output = os.path.join(tmpdir, "tmp_unknown_types.lmp")
system.to_lammps_lmp(output)
with open(output) as f:
content = f.read()

self.assertNotIn("Masses\n", content)

def test_dump_rejects_mismatched_explicit_masses(self):
system = dpdata.System(POSCAR_CONF_LMP, type_map=["O", "H"])
system.data["masses"] = np.array([15.9994, 1.00794, 99.0])

with tempfile.TemporaryDirectory() as tmpdir:
output = os.path.join(tmpdir, "tmp_bad_masses.lmp")
with self.assertRaisesRegex(ValueError, r'system\["masses"\]'):
system.to_lammps_lmp(output)


if __name__ == "__main__":
unittest.main()
Loading