From ff787c17314762cf467b2b11e147bd65996ea08a Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 24 Mar 2026 12:10:29 +0000 Subject: [PATCH] Fix CUDA context stack leak on crash by tracking outstanding push count. --- src/loch/_platforms/_cuda.py | 15 ++++-- tests/test_platform.py | 93 ++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 tests/test_platform.py diff --git a/src/loch/_platforms/_cuda.py b/src/loch/_platforms/_cuda.py index fbe298c..ee220f2 100644 --- a/src/loch/_platforms/_cuda.py +++ b/src/loch/_platforms/_cuda.py @@ -114,6 +114,7 @@ def __init__( # Use the primary context (shared with OpenMM and other CUDA users). self._pycuda_context = self._cuda_device.retain_primary_context() self._pycuda_context.push() + self._push_count = 1 self._device = self._pycuda_context.get_device() @@ -256,22 +257,26 @@ def push_context(self): Push the primary context onto the calling thread's context stack. """ self._pycuda_context.push() + self._push_count += 1 def pop_context(self): """ Pop the primary context from the calling thread's context stack. """ self._pycuda_context.pop() + self._push_count -= 1 def cleanup(self): """ - Clean up CUDA resources and pop the context pushed during __init__. + Clean up CUDA resources and pop all outstanding context pushes. """ if self._pycuda_context is not None: - try: - self._pycuda_context.pop() - except Exception: - pass + for _ in range(self._push_count): + try: + self._pycuda_context.pop() + except Exception: + pass + self._push_count = 0 self._pycuda_context = None @property diff --git a/tests/test_platform.py b/tests/test_platform.py new file mode 100644 index 0000000..81649b4 --- /dev/null +++ b/tests/test_platform.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock, patch + +import pytest + +# Skip the entire module if PyCUDA is not installed. +pytest.importorskip("pycuda") + + +def _make_backend(mock_driver): + """Instantiate CUDAPlatform with a mocked PyCUDA driver.""" + from loch._platforms._cuda import CUDAPlatform + + mock_driver.Device.count.return_value = 1 + mock_context = MagicMock() + mock_driver.Device.return_value.retain_primary_context.return_value = mock_context + + backend = CUDAPlatform( + device=0, + num_points=3, + num_batch=10, + num_waters=5, + num_atoms=100, + num_threads=32, + ) + return backend, mock_context + + +class TestCUDAPushCount: + """Tests for CUDAPlatform context push-count tracking.""" + + def test_initial_push_count(self): + """Push count starts at 1 after __init__ (one push for the lifetime context).""" + with patch("loch._platforms._cuda._cuda") as mock_driver: + backend, mock_context = _make_backend(mock_driver) + assert backend._push_count == 1 + mock_context.push.assert_called_once() + + def test_push_increments_count(self): + """push_context() increments _push_count.""" + with patch("loch._platforms._cuda._cuda") as mock_driver: + backend, _ = _make_backend(mock_driver) + backend.push_context() + assert backend._push_count == 2 + backend.push_context() + assert backend._push_count == 3 + + def test_pop_decrements_count(self): + """pop_context() decrements _push_count.""" + with patch("loch._platforms._cuda._cuda") as mock_driver: + backend, _ = _make_backend(mock_driver) + backend.push_context() + backend.pop_context() + assert backend._push_count == 1 + + def test_cleanup_pops_once_normally(self): + """cleanup() pops exactly once when no extra pushes are outstanding.""" + with patch("loch._platforms._cuda._cuda") as mock_driver: + backend, mock_context = _make_backend(mock_driver) + backend.cleanup() + assert mock_context.pop.call_count == 1 + assert backend._push_count == 0 + assert backend._pycuda_context is None + + def test_cleanup_pops_all_outstanding(self): + """cleanup() pops all outstanding pushes, simulating a crash mid-move.""" + with patch("loch._platforms._cuda._cuda") as mock_driver: + backend, mock_context = _make_backend(mock_driver) + # Simulate two push_context() calls that were never popped (e.g. + # two GCMC moves crashed before their paired pop_context()). + backend.push_context() + backend.push_context() + assert backend._push_count == 3 + backend.cleanup() + assert mock_context.pop.call_count == 3 + assert backend._push_count == 0 + assert backend._pycuda_context is None + + def test_cleanup_handles_pop_exception(self): + """cleanup() continues safely if pop() raises (e.g. stack already empty).""" + with patch("loch._platforms._cuda._cuda") as mock_driver: + backend, mock_context = _make_backend(mock_driver) + mock_context.pop.side_effect = Exception("context stack is empty") + backend.cleanup() + assert backend._pycuda_context is None + + def test_cleanup_idempotent(self): + """Calling cleanup() a second time is a no-op.""" + with patch("loch._platforms._cuda._cuda") as mock_driver: + backend, mock_context = _make_backend(mock_driver) + backend.cleanup() + pop_count_after_first = mock_context.pop.call_count + backend.cleanup() + assert mock_context.pop.call_count == pop_count_after_first