Skip to content
Open
5 changes: 5 additions & 0 deletions exir/emit/test/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ fbcode_target(_kind = runtime.python_test,
"//executorch/exir:schema",
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
"//executorch/exir/backend/test:backend_with_compiler_demo",
"//executorch/exir/emit:lib",
"//executorch/exir/passes:const_prop_pass",
"//executorch/exir/passes:constant_prop_pass",
"//executorch/exir/passes:init_mutable_pass",
"//executorch/exir/passes:propagate_device_pass",
"//executorch/exir/tests:lib",
"//executorch/exir/tests:models",
"//executorch/extension/pybindings:portable_lib",
Expand Down
118 changes: 118 additions & 0 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2518,3 +2518,121 @@ def forward(self):
for j in range(2):
expected_storage.append(j * 16 + i)
self.assertEqual([int(v) for v in storage_values], expected_storage)

def test_emit_device_info_propagated_to_serialized_tensor(self) -> None:
"""Verify that device info from PropagateDevicePass flows through
the emitter into ExtraTensorInfo.device_type on serialized tensors."""
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_pattern_op_partitions,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)
from executorch.exir.passes.propagate_device_pass import (
TARGET_DEVICE_COMPILE_SPEC_KEY,
)
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor,
]

class DevicePartitioner(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
BackendWithCompilerDemo.__name__,
[
CompileSpec("max_value", bytes([4])),
CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
],
)

def partition(self, exported_program) -> PartitionResult:
partition_tags = {}
partition_list = generate_pattern_op_partitions(
exported_program.graph_module,
op_support=any_chain(AddSupport()),
)
for partition in partition_list:
for node in partition.nodes:
tag = f"tag{partition.id}"
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=exported_program,
partition_tags=partition_tags,
)

class Model(torch.nn.Module):
def forward(self, a, b):
return torch.add(a, b)

model = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2))

edge = to_edge(
export(model, inputs),
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
lowered = edge.to_backend(DevicePartitioner())
et_prog = lowered.to_executorch()
program = et_prog._emitter_output.program

plan = program.execution_plan[0]
self.assertGreater(len(plan.delegates), 0)

tensor_values = [v.val for v in plan.values if isinstance(v.val, Tensor)]
cuda_tensors = [
t
for t in tensor_values
if t.extra_tensor_info is not None
and t.extra_tensor_info.device_type == schema.DeviceType.CUDA
]
# add(a, b) has 2 delegate inputs + 1 delegate output = 3 CUDA tensors
self.assertEqual(
len(cuda_tensors),
3,
f"Expected exactly 3 CUDA tensors (2 inputs + 1 output for delegated add), got {len(cuda_tensors)}",
)

def test_emit_cpu_tensors_no_extra_device_info(self) -> None:
"""When all tensors are on CPU (default), ExtraTensorInfo should NOT be
created solely for device info — it should remain None for activation tensors.
"""

class Model(torch.nn.Module):
def forward(self, a, b):
return torch.add(a, b)

model = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2))

edge = to_edge(
export(model, inputs),
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
et_prog = edge.to_executorch()
program = et_prog._emitter_output.program

plan = program.execution_plan[0]
tensor_values = [v.val for v in plan.values if isinstance(v.val, Tensor)]
cuda_tensors = [
t
for t in tensor_values
if t.extra_tensor_info is not None
and t.extra_tensor_info.device_type == schema.DeviceType.CUDA
]
self.assertEqual(
len(cuda_tensors),
0,
"No tensor should have CUDA device when model runs entirely on CPU",
)
12 changes: 11 additions & 1 deletion exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,16 @@ def to_list(
tensor_size = to_list(spec.shape)
tensor_dim_order = to_list(spec.dim_order)

extra_tensor_info = spec.extra_tensor_info
# Propagate device from TensorSpec into ExtraTensorInfo for serialization.
if spec.device != schema.DeviceType.CPU:
if extra_tensor_info is None:
extra_tensor_info = schema.ExtraTensorInfo(
device_type=spec.device,
)
else:
extra_tensor_info.device_type = spec.device

flatbuffer_tensor = schema.Tensor(
scalar_type=scalar_type_enum(spec.scalar_type),
# The runtime currently only supports tensors with offsets of zero.
Expand All @@ -376,7 +386,7 @@ def to_list(
allocation_info=allocation_info,
layout=layout_enum(spec.layout),
shape_dynamism=spec.shape_dynamism,
extra_tensor_info=spec.extra_tensor_info,
extra_tensor_info=extra_tensor_info,
)
return flatbuffer_tensor

Expand Down
Loading