Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,11 @@ class ExecutorchBackendConfig:
# If set to true, we run quant fusion and constant propagation passes
do_quant_fusion_and_const_prop: bool = False

# Experimental: If set to true, we run a pass to reinplace ops in the graph.
# If set to true, we run a pass to reinplace ops in the graph.
run_reinplace_pass: bool = False

# When True, memory planning partitions specs by device and runs the
# algorithm independently per device, producing separate buffers for CPU
# vs. accelerator memory. Default False preserves the legacy behavior
# where all tensors are planned into CPU memory regardless of device.
enable_non_cpu_memory_planning: bool = False
116 changes: 88 additions & 28 deletions exir/memory_planning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -28,6 +28,7 @@
import torch
from executorch.exir import memory
from executorch.exir.control_flow import while_loop as exir_while
from executorch.exir.schema import DeviceType, NonConstBufferDevice
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.error import internal_assert, InternalError
from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
Expand Down Expand Up @@ -1211,10 +1212,19 @@
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
alloc_mutable_buffers: bool = True,
enable_non_cpu_memory_planning: bool = False,
) -> list[int]:
"""
Recursively apply algo to graph_module and its submodules for control flow.

Partitions specs by device type and device idx, and runs the memory planning
algorithm independently per device, then merges results into separate buffers.
This ensures device memory and CPU memory are never mixed.

When enable_non_cpu_memory_planning is False (default), all specs are planned
into a single CPU memory pool regardless of their device attribute. This
preserves the legacy behavior. Set to True to enable per-device partitioning.

Algo implementation should handle one of two meta entries for submodules:
1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
`algo` should start at the offset specified by this list;
Expand All @@ -1229,18 +1239,19 @@
`operand` arg. The memory for operands is unused.
"""
# Extract the nodes and their lifespans from the graph_module
# Difficult to just filter the list of specs returned by this due to
# how we flag trainable weights.
_ = update_all_tensors_lifetime(graph_module, graph_signature)

# Filter specs based on alloc_graph_input and alloc_graph_output
specs = collect_specs_from_nodes(
graph_module.graph.nodes,
graph_signature,
do_assertion=False,
ignore_graph_input=not alloc_graph_input,
ignore_graph_output=not alloc_graph_output,
ignore_mutable_buffers=not alloc_mutable_buffers,
# Collect and materialize specs into a set so we can iterate multiple
# times and partition by device.
all_specs: set[TensorSpec] = set(
collect_specs_from_nodes(
graph_module.graph.nodes,
graph_signature,
do_assertion=False,
ignore_graph_input=not alloc_graph_input,
ignore_graph_output=not alloc_graph_output,
ignore_mutable_buffers=not alloc_mutable_buffers,
)
)

# Get temporary specs for submodules to set aside space during execution
Expand All @@ -1249,29 +1260,78 @@
algo, graph_module, alignment, graph_signature
)

# Update `input_mem_buffer_sizes` in graph_module. This will allow existing
# algos to work using `input_mem_buffer_sizes` or use
# `non_const_buffer_sizes` directly.
# pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
graph_module.input_mem_buffer_sizes = submodule_bufsizes

# Get extra padding for XNNPACK if needed
extra_padding = 0
if _contains_xnnpack_delegate(graph_module):
extra_padding = 64

# Pass the filtered specs to the algorithm
bufsizes: list[int] = algo(
alignment,
specs,
graph_module,
graph_signature,
extra_padding,
# 1. Partition specs by device
specs_by_device: dict[DeviceType, set[TensorSpec]] = defaultdict(set)
if enable_non_cpu_memory_planning:
for spec in all_specs:
specs_by_device[spec.device].add(spec)
else:
# Legacy behavior: all specs planned into CPU memory regardless of device
specs_by_device[DeviceType.CPU] = all_specs

# 2. Plan each device independently
global_bufsizes: list[int] = [0] # index 0 reserved for constants
buffer_device_types: list[DeviceType] = [DeviceType.CPU]

# Process CPU first (if present), then other devices sorted by enum value
device_order = sorted(
specs_by_device.keys(),
key=lambda d: (d != DeviceType.CPU, d.value),
)

# pyre-ignore[6]: Incompatible parameter type [6]
# In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
insert_calls_to_free(graph_module, specs)
for device_type in device_order:
device_specs = specs_by_device[device_type]

graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
return bufsizes
# Only apply submodule pre-allocation for CPU specs; device buffers
# do not share memory space with CPU submodule arenas.
# pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
graph_module.input_mem_buffer_sizes = (
submodule_bufsizes if device_type == DeviceType.CPU else []
)

# Run algorithm independently on this device's specs
device_bufsizes = algo(
alignment, device_specs, graph_module, graph_signature, extra_padding
)

# Calculate base mem_id in global space
base_mem_id = len(global_bufsizes)

# Append buffer sizes (skip index 0 which is constants placeholder)
global_bufsizes.extend(device_bufsizes[1:])

# Track device type for each new buffer slot
for _ in device_bufsizes[1:]:
buffer_device_types.append(device_type)

# Remap spec mem_ids from algo-local to global.
# The algorithm assigns mem_id starting from 1; remap to global position.
for spec in device_specs:
if spec.mem_id is not None:
spec.mem_id = (spec.mem_id - 1) + base_mem_id

# Ensure backward compatibility: at least [0, 0] when no specs exist
if len(global_bufsizes) < 2:
global_bufsizes.append(0)
buffer_device_types.append(DeviceType.CPU)

# 3. Insert free calls and build device buffer mapping
insert_calls_to_free(graph_module, all_specs)

has_device_buffers = any(dt != DeviceType.CPU for dt in buffer_device_types)
non_const_buffer_device: Optional[list[NonConstBufferDevice]] = None
if has_device_buffers:
non_const_buffer_device = [
NonConstBufferDevice(buffer_idx=i, device_type=dt, device_index=0)
for i, dt in enumerate(buffer_device_types)
]

graph_module.meta["non_const_buffer_sizes"] = global_bufsizes
if non_const_buffer_device is not None:
graph_module.meta["non_const_buffer_device"] = non_const_buffer_device
return global_bufsizes
3 changes: 3 additions & 0 deletions exir/passes/memory_planning_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
alloc_mutable_buffers: bool = True,
share_mutable_buffers: bool = False,
alignment: int = ALIGNMENT,
enable_non_cpu_memory_planning: bool = False,
) -> None:
r"""
alloc_graph_input/alloc_graph_output will have 4 different combinations
Expand All @@ -173,6 +174,7 @@ def __init__(
self.alloc_mutable_buffers = alloc_mutable_buffers
self.share_mutable_buffers = share_mutable_buffers
self.alignment = alignment
self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning
self.state = _MemoryPlanningState()

def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
Expand Down Expand Up @@ -250,6 +252,7 @@ def run(
# If mutable buffers are shared, then do not allocate them in the
# main memory planning algo; they are allocated in run_multimethod.
self.alloc_mutable_buffers and not self.share_mutable_buffers,
self.enable_non_cpu_memory_planning,
)

if self.share_mutable_buffers and graph_signature is not None:
Expand Down
6 changes: 6 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,12 @@ def to_executorch( # noqa (FLAKE8) C901
else:
memory_planning_pass = config.memory_planning_pass
# TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
# Propagate enable_non_cpu_memory_planning from the top-level config
# to the pass instance so that device-aware partitioning is applied.
if hasattr(memory_planning_pass, "enable_non_cpu_memory_planning"):
memory_planning_pass.enable_non_cpu_memory_planning = (
config.enable_non_cpu_memory_planning
)
if hasattr(memory_planning_pass, "run"):
new_gm_res = memory_planning_pass.run(new_gm, new_signature)
else:
Expand Down
169 changes: 169 additions & 0 deletions exir/tests/test_memory_planning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -29,6 +29,8 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import (
_do_user_inputs_exist,
apply_algo,
collect_specs_from_nodes,
filter_nodes,
get_node_tensor_specs,
greedy,
Expand All @@ -45,6 +47,7 @@
ToOutVarPass,
)
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.schema import DeviceType
from executorch.exir.tensor import TensorSpec
from functorch.experimental.control_flow import map as torch_map
from parameterized import parameterized
Expand Down Expand Up @@ -1259,3 +1262,169 @@
self.assertEqual(v_cache[0].val.allocation_info.memory_id, 2)
self.assertEqual(v_cache[0].val.allocation_info.memory_offset_low, 256)
self.assertEqual(v_cache[0].val.allocation_info.memory_offset_high, 0)


class TestDeviceAwareMemoryPlanning(unittest.TestCase):
    """Tests for per-device memory planning (separate buffers per device type).

    These tests drive ``apply_algo`` with ``enable_non_cpu_memory_planning``
    toggled on/off and check the resulting buffer layout:
    index 0 of the returned bufsizes is always the constants placeholder,
    followed by one activation buffer per device partition. When any
    non-CPU buffer exists, a ``non_const_buffer_device`` mapping is added
    to ``graph_module.meta``; for CPU-only plans it must be absent.
    """

    def _prepare_model(
        self,
    ) -> Tuple[GraphModule, ExportGraphSignature]:
        """Prepare ToyModelForMemPlanning through SpecPropPass + ToOutVarPass.

        Returns the lowered graph module and its graph signature, ready for
        memory planning (specs populated, out-variants selected).
        """
        model = ToyModelForMemPlanning()
        inputs = model.get_random_inputs()
        edge = to_edge(export(model, inputs, strict=True))
        gm = edge.exported_program().graph_module
        gs = edge.exported_program().graph_signature
        gm = PassManager(passes=[SpecPropPass(), ToOutVarPass()])(gm).graph_module
        return gm, gs

    def _get_planned_specs(
        self,
        gm: GraphModule,
        gs: ExportGraphSignature,
    ) -> list[TensorSpec]:
        """Get the unique set of specs that apply_algo would plan.

        Uses the same collection arguments as the planner (inputs, outputs,
        and mutable buffers all included) so tests can pre-tag specs with a
        device before planning runs.
        """
        return list(
            collect_specs_from_nodes(
                gm.graph.nodes,
                gs,
                do_assertion=False,
                ignore_graph_input=False,
                ignore_graph_output=False,
                ignore_mutable_buffers=False,
            )
        )

    def test_cpu_only_unchanged(self) -> None:
        """CPU-only specs produce bufsizes = [0, X] with no device metadata."""
        gm, gs = self._prepare_model()

        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
        bufsizes = apply_algo(
            algo, gm, 16, gs, enable_non_cpu_memory_planning=True
        )

        # All specs default to CPU, so even with the feature enabled the
        # layout matches the legacy [constants, cpu_activations] shape and
        # no device mapping is emitted.
        self.assertEqual(bufsizes[0], 0)  # constants
        self.assertGreater(bufsizes[1], 0)  # CPU activations
        self.assertNotIn("non_const_buffer_device", gm.meta)

    def test_all_cuda_no_wasted_slots(self) -> None:
        """CUDA-only specs produce [0, X] with CUDA at buffer index 1."""
        gm, gs = self._prepare_model()
        specs = self._get_planned_specs(gm, gs)
        for spec in specs:
            spec.device = DeviceType.CUDA

        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
        bufsizes = apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True)

        # [0, cuda_size] — no wasted CPU buffer slot
        self.assertEqual(len(bufsizes), 2)
        self.assertEqual(bufsizes[0], 0)
        self.assertGreater(bufsizes[1], 0)
        # Device mapping should be present
        self.assertIn("non_const_buffer_device", gm.meta)
        device_map = gm.meta["non_const_buffer_device"]
        self.assertEqual(len(device_map), 2)
        self.assertEqual(device_map[0].device_type, DeviceType.CPU)  # constants
        self.assertEqual(device_map[1].device_type, DeviceType.CUDA)

    def test_mixed_cpu_cuda_separate_buffers(self) -> None:
        """CPU specs at mem_id=1, CUDA specs at mem_id=2, separate sizes."""
        gm, gs = self._prepare_model()
        specs = self._get_planned_specs(gm, gs)

        # Set second half of specs to CUDA
        mid = len(specs) // 2
        self.assertGreater(mid, 0)
        cpu_specs = specs[:mid]
        cuda_specs = specs[mid:]
        for spec in cuda_specs:
            spec.device = DeviceType.CUDA

        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
        bufsizes = apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True)

        # [constants, cpu_activations, cuda_activations]
        # CPU is planned first, so it takes the lower buffer index.
        self.assertEqual(len(bufsizes), 3)
        self.assertEqual(bufsizes[0], 0)
        self.assertGreater(bufsizes[1], 0)
        self.assertGreater(bufsizes[2], 0)

        # CPU specs should have mem_id=1, CUDA specs should have mem_id=2
        for spec in cpu_specs:
            self.assertEqual(spec.mem_id, 1, f"CPU spec has wrong mem_id: {spec.mem_id}")
        for spec in cuda_specs:
            self.assertEqual(spec.mem_id, 2, f"CUDA spec has wrong mem_id: {spec.mem_id}")

    def test_mem_offset_correct_after_remap(self) -> None:
        """After remapping, mem_offset is relative to its own buffer."""
        gm, gs = self._prepare_model()
        specs = self._get_planned_specs(gm, gs)

        # Set the last spec to CUDA (sole CUDA tensor)
        cuda_spec = specs[-1]
        cuda_spec.device = DeviceType.CUDA

        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
        bufsizes = apply_algo(
            algo, gm, 16, gs, enable_non_cpu_memory_planning=True
        )

        # The CUDA spec is the only tensor in its buffer, so offset should be 0
        self.assertEqual(cuda_spec.mem_offset, 0)
        # The CUDA buffer should fit exactly this tensor
        cuda_mem_id = cuda_spec.mem_id
        self.assertIsNotNone(cuda_mem_id)
        # Bare assert narrows Optional[int] for the indexing below.
        assert cuda_mem_id is not None
        self.assertGreaterEqual(bufsizes[cuda_mem_id], cuda_spec.allocated_memory)

    def test_no_cross_device_memory_sharing(self) -> None:
        """Specs on different devices never share buffers, regardless of lifetime."""
        gm, gs = self._prepare_model()
        specs = self._get_planned_specs(gm, gs)
        self.assertGreaterEqual(len(specs), 2)

        # Assign alternating specs to CUDA to ensure some pairs have
        # non-overlapping lifetimes (which greedy would normally share).
        for i, spec in enumerate(specs):
            if i % 2 == 0:
                spec.device = DeviceType.CUDA

        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
        apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True)

        # Verify CPU and CUDA specs have disjoint mem_ids.
        # Bucket by the same parity used to assign devices above.
        cpu_mem_ids: set[int] = set()
        cuda_mem_ids: set[int] = set()
        for i, spec in enumerate(specs):
            if spec.mem_id is not None:
                if i % 2 == 0:
                    cuda_mem_ids.add(spec.mem_id)
                else:
                    cpu_mem_ids.add(spec.mem_id)

        self.assertTrue(
            cpu_mem_ids.isdisjoint(cuda_mem_ids),
            f"CPU {cpu_mem_ids} and CUDA {cuda_mem_ids} should not share buffers",
        )

    def test_disabled_falls_back_to_cpu(self) -> None:
        """With enable_non_cpu_memory_planning=False (default), CUDA specs are
        planned into CPU memory — no device-specific buffers are created."""
        gm, gs = self._prepare_model()
        specs = self._get_planned_specs(gm, gs)
        for spec in specs:
            spec.device = DeviceType.CUDA

        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
        # Default: enable_non_cpu_memory_planning=False
        bufsizes = apply_algo(algo, gm, 16, gs)

        # All specs planned into a single CPU pool — same as CPU-only
        self.assertEqual(len(bufsizes), 2)
        self.assertEqual(bufsizes[0], 0)
        self.assertGreater(bufsizes[1], 0)
        self.assertNotIn("non_const_buffer_device", gm.meta)
Loading