Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions cadet/cadet.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,38 @@ def __setstate__(self, state):
# Restore the state and cast to addict.Dict() to add __frozen attributes
state = Dict(state)
self.__dict__.update(state)

def initialize_simulation(self) -> ReturnInformation:
"""
Initialize a CADET simulation without running it.

Returns
-------
ReturnInformation
Information about the initialization result.
"""

return_information = self.cadet_runner.initialize_simulation(self)

return return_information

def perform_simulation_step(self, t_end: float) -> tuple[ReturnInformation, float]:
"""
Perform a single simulation step until time t_end.

Parameters
----------
t_end : float
Target end time for this simulation step.

Returns
-------
tuple[ReturnInformation, float]
- ReturnInformation: Information about the step result
- float: Actually reached time (may differ from t_end)
"""
return self.cadet_runner.perform_simulation_step(t_end)

def end_simulation(self) -> ReturnInformation:

return self.cadet_runner.end_simulation()
127 changes: 126 additions & 1 deletion cadet/cadet_dll.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
point_bool = ctypes.POINTER(ctypes.c_bool)
point_int = ctypes.POINTER(ctypes.c_int)
point_double = ctypes.POINTER(ctypes.c_double)
double = ctypes.c_double

# Values of cdtResult
c_cadet_result = ctypes.c_int
Expand Down Expand Up @@ -106,6 +107,11 @@ class CADET_API_V1_SIGNATURES:
signatures_1_1_0a1 = {}
signatures_1_1_0a1['timeout'] = ('return', 'drv', 'timeout')

signatures_1_1_0a2 = {}
signatures_1_1_0a2['initializeSimulation'] = ('return', 'drv', 'parameterProvider')
signatures_1_1_0a2['performSimulationStep'] = ('return', 'drv', 'tEnd','tReached')
signatures_1_1_0a2['endSimulation'] = ('return', 'drv')

# Mappings for common ctypes parameters
lookup_prototype = {
None: None,
Expand Down Expand Up @@ -136,6 +142,8 @@ class CADET_API_V1_SIGNATURES:
'keepParticleSingletonDimension': point_bool,
'timeSim': point_double,
'timeout': point_double,
'tEnd': double,
'tReached': point_double
}

lookup_output_argument_type = {
Expand Down Expand Up @@ -169,6 +177,12 @@ class CADET_API_V1_SIGNATURES:
_sigs_1_1_0a1.update(CADET_API_V1_SIGNATURES.signatures_1_1_0a1)
_VERSION_SIGNATURES[Version("1.1.0a1")] = _sigs_1_1_0a1

_sigs_1_1_0a2 = dict(CADET_API_V1_SIGNATURES.signatures_1_0_0)
_sigs_1_1_0a2.update(CADET_API_V1_SIGNATURES.signatures_1_1_0a1)
_sigs_1_1_0a2.update(CADET_API_V1_SIGNATURES.signatures_1_1_0a2)
_VERSION_SIGNATURES[Version("1.1.0a2")] = _sigs_1_1_0a2


def _get_api_signatures(api: Any) -> dict[str, tuple[str, ...]]:
return _VERSION_SIGNATURES[api._version]

Expand Down Expand Up @@ -200,6 +214,12 @@ class CADETAPI_V1_1_0a1(ctypes.Structure):
_version = Version("1.1.0a1")
_fields_ = _setup_api(_version)


class CADETAPI_V1_1_0a2(ctypes.Structure):
"""Mimic cdtAPIv1.1.0a.2 struct of CADET C-API in ctypes."""
_version = Version("1.1.0a2")
_fields_ = _setup_api(_version)


class SimulationResult:
"""
Expand Down Expand Up @@ -1722,8 +1742,14 @@ def _initialize_dll(self):
"This version of CADET-Python does not support CADET-CAPI version "
f"({self._cadet_capi_version})."
)
elif self._cadet_capi_version >= Version("1.1.0a2"):
cdtGetAPIv1_1_0a2 = self._lib.cdtGetAPIv1_1_0a2
cdtGetAPIv1_1_0a2.argtypes = [ctypes.POINTER(CADETAPI_V1_1_0a2)]
cdtGetAPIv1_1_0a2.restype = c_cadet_result
self._api = CADETAPI_V1_1_0a2()
cdtGetAPIv1_1_0a2(ctypes.byref(self._api))

elif self._cadet_capi_version >= Version("1.1.0a1"):
elif self._cadet_capi_version == Version("1.1.0a1"):
cdtGetAPIv1_1_0a1 = self._lib.cdtGetAPIv1_1_0a1
cdtGetAPIv1_1_0a1.argtypes = [ctypes.POINTER(CADETAPI_V1_1_0a1)]
cdtGetAPIv1_1_0a1.restype = c_cadet_result
Expand Down Expand Up @@ -2227,3 +2253,102 @@ def cadet_commit_hash(self) -> str:
@property
def cadet_path(self) -> os.PathLike:
return self._cadet_path

def initialize_simulation(self, sim: "Cadet") -> ReturnInformation:

"""
Initialize a CADET simulation without running it.

Parameters
----------
simulation : Cadet
Simulation object containing input data.

Returns
-------
ReturnInformation
Information about the initialization result.
"""

if self._cadet_capi_version < Version("1.1.0a2"):
raise RuntimeError(
f"Used Cadet-Core C API Version: ({self._cadet_capi_version})"
"To use initialize_simulation Cadet-Core needs to support at least Version 1.1.0a2"
)

pp = cadet_dll_parameterprovider.PARAMETERPROVIDER(sim)

log_buffer = self.setup_log_buffer()

returncode = self._api.initializeSimulation(self._driver, ctypes.byref(pp))

if returncode != 0:
log = ""
error_message = log_buffer.getvalue()
else:
log = log_buffer.getvalue()
error_message = ""


return_info = ReturnInformation(
return_code=returncode,
error_message=error_message,
log=log
)

return return_info

def perform_simulation_step(self, t_end:float) -> tuple[ReturnInformation,float]:
"""
Perform a single simulation step until time t_end.

Parameters
----------
t_end : float
Target end time for this simulation step.

Returns
-------
tuple[ReturnInformation, float]
- ReturnInformation: Information about the step result
- float: Actually reached time (may differ from t_end)
"""
if self._cadet_capi_version < Version("1.1.0a2"):
raise RuntimeError(
f"Used Cadet-Core C API Version: ({self._cadet_capi_version})"
"To use perform_simulation_step Cadet-Core needs to support at least Version 1.1.0a2"
)

log_buffer = self.setup_log_buffer()

t_reached = ctypes.c_double(0.0)

returncode = self._api.performSimulationStep( self._driver, ctypes.c_double(t_end), ctypes.byref(t_reached))

if returncode != 0:
log = ""
error_message = log_buffer.getvalue()
else:
log = log_buffer.getvalue()
error_message = ""

self.res = SimulationResult(self._api, self._driver)

return_info = ReturnInformation(
return_code=returncode,
error_message=error_message,
log=log
)

return return_info, t_reached.value

def end_simulation(self) -> ReturnInformation:
if self._cadet_capi_version < Version("1.1.0a2"):
raise RuntimeError(
f"Used Cadet-Core C API Version: ({self._cadet_capi_version})"
"To use end_simulation Cadet-Core needs to support at least Version 1.1.0a2"
)

returncode = self._api.endSimulation(self._driver)

return returncode
Loading
Loading