Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
This project is a Python implementation of v1.4.0 of the [MATLAB toolbox k-Wave](http://www.k-wave.org/) as well as an
interface to the pre-compiled v1.3 of k-Wave simulation binaries, which support NVIDIA sm 5.0 (Maxwell) to sm 9.0a (Hopper) GPUs.

**New in v0.6.0:** Unified `kspaceFirstOrder()` API with a pure NumPy/CuPy solver. See the [API guide](https://k-wave-python.readthedocs.io/en/latest/get_started/new_api.html).
**New in v0.6.0:** Unified `kspaceFirstOrder()` API with a pure NumPy/CuPy solver. See the [API guide](https://k-wave-python.readthedocs.io/en/latest/get_started/new_api.html). The `kspaceFirstOrder()` API is experimental and may change before v1.0.0.

## Mission

Expand Down
34 changes: 31 additions & 3 deletions kwave/kspaceFirstOrder.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _expand_for_pml_outside(kgrid, medium, source, sensor, pml_size):
return expanded_kgrid, expanded_medium, expanded_source, expanded_sensor


_FULL_GRID_SUFFIXES = ("_final", "_max", "_min", "_rms")
_FULL_GRID_SUFFIXES = ("_final", "_max", "_min", "_rms", "_max_all", "_min_all", "_rms_all")


def _strip_pml(result, pml_size, ndim):
Expand Down Expand Up @@ -156,6 +156,10 @@ def kspaceFirstOrder(
Returns:
dict: Recorded sensor data keyed by field name (e.g.
``"p"``, ``"p_final"``, ``"ux"``, ``"uy"``).

All time-series are ``(n_sensor, Nt)`` with sensor points in
C-flattened order. Use :func:`reshape_to_grid` to recover spatial
structure for full-grid masks.
"""
if device not in ("cpu", "gpu"):
raise ValueError(f"device must be 'cpu' or 'gpu', got {device!r}")
Expand All @@ -181,7 +185,7 @@ def kspaceFirstOrder(
from kwave.utils.filters import smooth

source = copy.copy(source)
source.p0 = smooth(np.asarray(source.p0, dtype=float).reshape(tuple(int(n) for n in kgrid.N), order="F"), restore_max=True)
source.p0 = smooth(np.asarray(source.p0, dtype=float).reshape(tuple(int(n) for n in kgrid.N)), restore_max=True)

# --- Backend dispatch ---

Expand Down Expand Up @@ -225,7 +229,7 @@ def kspaceFirstOrder(
from kwave.utils.conversion import cart2grid

sensor = copy.copy(sensor)
sensor.mask, _, _ = cart2grid(kgrid, np.asarray(sensor.mask))
sensor.mask, _, _ = cart2grid(kgrid, np.asarray(sensor.mask), order="C")

cpp_sim = CppSimulation(kgrid, medium, source, sensor, pml_size=pml_size, pml_alpha=pml_alpha, use_sg=use_sg)
if save_only:
Expand All @@ -244,3 +248,27 @@ def kspaceFirstOrder(
result = _strip_pml(result, pml_size, kgrid.dim)

return result


def reshape_to_grid(data, grid_shape):
"""Reshape flat sensor data to grid shape.

Convenience helper for full-grid sensor masks where ``n_sensor``
equals the total number of grid points.

Args:
data: sensor array — ``(n_sensor, Nt)`` time-series or
``(n_sensor,)`` aggregate.
grid_shape: tuple of grid dimensions, e.g. ``(Nx, Ny)``.

Returns:
For time-series: ``(*grid_shape, Nt)``
For aggregates: ``(*grid_shape)``
"""
data = np.asarray(data)
if data.ndim == 2:
n_sensor, Nt = data.shape
return data.reshape(*grid_shape, Nt)
elif data.ndim == 1:
return data.reshape(grid_shape)
return data
50 changes: 49 additions & 1 deletion kwave/solvers/cpp_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,62 @@ def run(self, *, device="cpu", num_threads=None, device_num=None, quiet=False, d
data_dir = os.path.dirname(input_file)
try:
self._execute(input_file, output_file, device=device, num_threads=num_threads, device_num=device_num, quiet=quiet, debug=debug)
return self._parse_output(output_file)
result = self._parse_output(output_file)
result = self._fix_output_order(result)
return result
finally:
if cleanup:
try:
shutil.rmtree(data_dir)
except OSError as exc:
warnings.warn(f"Could not clean up temp directory {data_dir!r}: {exc}", RuntimeWarning, stacklevel=2)

_FULL_GRID_SUFFIXES = ("_final", "_max", "_min", "_rms", "_max_all", "_min_all", "_rms_all")

def _fix_output_order(self, result):
"""Convert C++ output from F-order to C-order.

The C++ binary writes arrays in Fortran order. HDF5/h5py reads them
with reversed dimensions. We fix full-grid fields via transpose and
reorder sensor time-series rows from F-indexed to C-indexed.
"""
ndim = self.ndim
grid_shape = tuple(int(n) for n in self.kgrid.N)

# 1. Transpose full-grid fields from reversed F-order to C-order
for key, val in result.items():
if not isinstance(val, np.ndarray):
continue
is_grid = any(key.endswith(s) for s in self._FULL_GRID_SUFFIXES)
if is_grid and val.ndim == ndim:
result[key] = val.transpose(tuple(range(ndim - 1, -1, -1)))

# 2. Reorder sensor time-series from F-indexed to C-indexed rows
if self.sensor is None or self.sensor.mask is None:
mask = np.ones(grid_shape, dtype=bool)
else:
mask = np.asarray(self.sensor.mask, dtype=bool).reshape(grid_shape)

n_sensor = int(mask.sum())
if n_sensor > 0 and ndim >= 2:
f_nz = np.where(mask.ravel(order="F"))[0]
c_nz = np.where(mask.ravel())[0]
f_equiv = np.ravel_multi_index(np.unravel_index(c_nz, grid_shape), grid_shape, order="F")
perm = np.searchsorted(f_nz, f_equiv)

for key, val in result.items():
if not isinstance(val, np.ndarray):
continue
is_grid = any(key.endswith(s) for s in self._FULL_GRID_SUFFIXES)
if is_grid:
continue
if val.ndim == 2 and val.shape[0] == n_sensor:
result[key] = val[perm]
elif val.ndim == 1 and val.shape[0] == n_sensor:
result[key] = val[perm]

return result

# -- HDF5 serialization --

def _write_hdf5(self, filepath):
Expand Down
91 changes: 69 additions & 22 deletions kwave/solvers/kspace_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def _to_cpu(x):
def _expand_to_grid(val, grid_shape, xp, name="parameter"):
if val is None:
raise ValueError(f"Missing required parameter: {name}")
arr = xp.array(val, dtype=float).flatten(order="F")
arr = xp.array(val, dtype=float).ravel()
grid_size = int(np.prod(grid_shape))
if arr.size == 1:
return xp.full(grid_shape, float(arr[0]), dtype=float)
if arr.size == grid_size:
return arr.reshape(grid_shape, order="F")
return arr.reshape(grid_shape)
raise ValueError(f"{name} size {arr.size} incompatible with grid size {grid_size}")


Expand All @@ -48,16 +48,16 @@ def _build_source_op(mask_raw, signal_raw, mode, scale, *, xp, grid_shape, grid_

Returns a callable (t, field) → field that injects scaled source values.
"""
mask = xp.array(mask_raw, dtype=bool).flatten(order="F")
mask = xp.array(mask_raw, dtype=bool).ravel()
if mask.size == 1:
mask = xp.full(grid_shape, bool(mask[0]), dtype=bool).flatten(order="F")
mask = xp.full(grid_shape, bool(mask[0]), dtype=bool).ravel()
n_src = int(xp.sum(mask))

signal_arr = xp.array(signal_raw, dtype=float, order="F")
signal_arr = xp.array(signal_raw, dtype=float)
if signal_arr.ndim == 1:
signal = signal_arr.reshape(1, -1)
else:
signal = signal_arr.reshape(-1, signal_arr.shape[-1], order="F") if signal_arr.ndim > 2 else signal_arr
signal = signal_arr.reshape(-1, signal_arr.shape[-1]) if signal_arr.ndim > 2 else signal_arr

scaled = signal * xp.atleast_1d(xp.asarray(scale))[:, None]
signal_len = scaled.shape[1]
Expand All @@ -70,9 +70,9 @@ def get_val(t):
def dirichlet(t, field):
if t >= signal_len:
return field
flat = field.flatten(order="F") # copy — mutation is intentional
flat = field.flatten() # copy — mutation is intentional
flat[mask] = get_val(t)
return flat.reshape(grid_shape, order="F")
return flat.reshape(grid_shape)

# Pre-allocate buffer to avoid per-step allocation
_src_buf = xp.zeros(grid_size, dtype=float)
Expand All @@ -82,15 +82,15 @@ def additive_kspace(t, field):
return field
_src_buf[:] = 0
_src_buf[mask] = get_val(t)
src = _src_buf.reshape(grid_shape, order="F")
src = _src_buf.reshape(grid_shape)
return field + diff_fn(src, source_kappa)

def additive_no_correction(t, field):
if t >= signal_len:
return field
_src_buf[:] = 0
_src_buf[mask] = get_val(t)
return field + _src_buf.reshape(grid_shape, order="F")
return field + _src_buf.reshape(grid_shape)

ops = {"dirichlet": dirichlet, "additive": additive_kspace, "additive-no-correction": additive_no_correction}
if mode not in ops:
Expand Down Expand Up @@ -210,19 +210,19 @@ def _is_cartesian(arr):

if mask_raw is None:
self.n_sensor_points = grid_numel
self._extract = lambda f: f.flatten(order="F")
self._extract = lambda f: f.ravel()
else:
mask_arr = np.asarray(mask_raw, dtype=float)
# Check Cartesian first to avoid ambiguity when size == grid_numel
if _is_cartesian(mask_arr):
self._setup_cartesian_extract(mask_arr)
elif _is_binary(mask_arr):
bmask = xp.array(mask_arr, dtype=bool).flatten(order="F")
bmask = xp.array(mask_arr, dtype=bool).ravel()
if bmask.size == 1:
bmask = xp.full(grid_numel, bool(bmask[0]), dtype=bool)
self.n_sensor_points = int(xp.sum(bmask))
idx = xp.where(bmask)[0]
self._extract = lambda f, _i=idx: f.flatten(order="F")[_i]
self._extract = lambda f, _i=idx: f.ravel()[_i]
else:
raise ValueError(
f"Sensor mask shape {mask_arr.shape} is neither binary " f"(numel={grid_numel}) nor Cartesian ({self.ndim}, N_points)"
Expand Down Expand Up @@ -289,7 +289,7 @@ def _setup_cartesian_extract(self, cart_pos):
x_vec, cart_x = axis_coords[0], cart.flatten()

def _extract_1d_interp(f):
return xp.asarray(np.interp(cart_x, x_vec, _to_cpu(f).flatten(order="F")))
return xp.asarray(np.interp(cart_x, x_vec, _to_cpu(f).ravel()))

self._extract = _extract_1d_interp
else:
Expand All @@ -298,8 +298,8 @@ def _extract_1d_interp(f):
int_idx = np.clip(np.floor(frac_idx).astype(int), 0, np.array(self.grid_shape)[:, None] - 2)
local = frac_idx - int_idx

# F-order strides and 2^ndim corner enumeration
strides = np.cumprod([1] + list(self.grid_shape[:-1]))
# C-order strides and 2^ndim corner enumeration
strides = np.cumprod([1] + list(self.grid_shape[:0:-1]))[::-1]
n_corners = 2**self.ndim
corner_indices = np.zeros((self.n_sensor_points, n_corners), dtype=int)
corner_weights = np.ones((self.n_sensor_points, n_corners))
Expand All @@ -313,7 +313,7 @@ def _extract_1d_interp(f):
corner_weights = xp.array(corner_weights)

def _extract_bilinear(f):
return (f.flatten(order="F")[corner_indices] * corner_weights).sum(axis=1)
return (f.ravel()[corner_indices] * corner_weights).sum(axis=1)

self._extract = _extract_bilinear

Expand Down Expand Up @@ -460,15 +460,15 @@ def _setup_source_operators(self):
grid_size = int(np.prod(self.grid_shape))

def _expand_mask(mask_raw):
mask = xp.array(mask_raw, dtype=bool).flatten(order="F")
mask = xp.array(mask_raw, dtype=bool).ravel()
if mask.size == 1:
mask = xp.full(self.grid_shape, bool(mask[0]), dtype=bool).flatten(order="F")
mask = xp.full(self.grid_shape, bool(mask[0]), dtype=bool).ravel()
return mask

def source_scale(mask_raw, c0):
"""Get per-source-point sound speed values."""
mask = _expand_mask(mask_raw)
c0_flat = c0.flatten(order="F")
c0_flat = c0.ravel()
n_src = int(xp.sum(mask))
return c0_flat[mask] if c0_flat.size > 1 else xp.full(n_src, float(c0_flat))

Expand Down Expand Up @@ -566,7 +566,7 @@ def _setup_fields(self):
if self.smooth_p0 and self.ndim >= 2:
from kwave.utils.filters import smooth

# p0 is F-order from _expand_to_grid; smooth() is order-agnostic (uses FFT on shape)
# smooth() is order-agnostic (uses FFT on shape)
p0 = xp.asarray(smooth(_to_cpu(p0), restore_max=True))
self._p0_initial = p0
else:
Expand Down Expand Up @@ -779,6 +779,53 @@ def create_simulation(kgrid, medium, source, sensor, device="auto", smooth_p0=Fa
)


def _f_to_c_source_reorder(source, grid_shape):
"""Reorder multi-row source signals from MATLAB F-flat to C-flat mask order.

MATLAB sends source signal rows ordered by F-flattened mask indices.
The solver uses C-flat ordering internally. For single-row (uniform)
sources, no reordering is needed.
"""
ndim = len(grid_shape)
if ndim < 2:
return source
source = dict(source) # shallow copy — don't mutate caller's dict

for mask_key, signal_keys in [("p_mask", ["p"]), ("u_mask", ["ux", "uy", "uz"])]:
mask_raw = source.get(mask_key)
if mask_raw is None:
continue
mask = np.asarray(mask_raw, dtype=bool)
if mask.size <= 1:
continue
mask_grid = mask.reshape(grid_shape)
n_src = int(mask_grid.sum())
if n_src < 2:
continue

# Build F→C permutation for mask points
f_nz = np.where(mask_grid.ravel(order="F"))[0]
c_nz = np.where(mask_grid.ravel())[0]
f_equiv = np.ravel_multi_index(np.unravel_index(c_nz, grid_shape), grid_shape, order="F")
perm = np.searchsorted(f_nz, f_equiv)

for sig_key in signal_keys:
sig = source.get(sig_key)
if sig is None:
continue
sig = np.asarray(sig)
if sig.ndim >= 2 and sig.shape[0] == n_src:
source[sig_key] = sig[perm]

return source


def simulate_from_dicts(kgrid, medium, source, sensor, device="auto", smooth_p0=False):
"""MATLAB interop entry point."""
"""MATLAB interop entry point.

Reorders multi-row source signals from MATLAB's F-flat mask ordering
to the solver's C-flat ordering before running the simulation.
"""
grid_shape = tuple(kgrid[k] for k in ["Nx", "Ny", "Nz"] if k in kgrid)
source = _f_to_c_source_reorder(source, grid_shape)
return create_simulation(kgrid, medium, source, sensor, device, smooth_p0=smooth_p0).run()
Loading
Loading