Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion runtime/executor/tensor_parser_portable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@ Result<Tensor> parseTensor(
Internal,
"dim_order_to_stride returned invalid status");

// Extract device info from serialized tensor metadata.
// Defaults to CPU/0 for backward compatibility when extra_tensor_info is
// absent (e.g., older PTE files without device annotations).
auto device_type = executorch::runtime::etensor::DeviceType::CPU;
executorch::runtime::etensor::DeviceIndex device_index = 0;
if (s_tensor->extra_tensor_info() != nullptr) {
device_type = static_cast<executorch::runtime::etensor::DeviceType>(
s_tensor->extra_tensor_info()->device_type());
device_index = static_cast<executorch::runtime::etensor::DeviceIndex>(
s_tensor->extra_tensor_info()->device_index());
}

auto* tensor_impl = method_allocator->allocateInstance<TensorImpl>();
if (tensor_impl == nullptr) {
return Error::MemoryAllocationFailed;
Expand All @@ -161,7 +173,9 @@ Result<Tensor> parseTensor(
/*data=*/nullptr,
dim_order,
strides,
dynamism);
dynamism,
device_type,
device_index);

// Now that we know how big the tensor is, find and assign its memory.
Result<void*> data_ptr = getTensorDataPtr(
Expand Down
16 changes: 16 additions & 0 deletions runtime/executor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,19 @@ def define_common_targets(is_fbcode = False):
],
env = modules_env,
)

runtime.cxx_test(
name = "tensor_parser_device_test",
srcs = [
"tensor_parser_device_test.cpp",
],
deps = [
":managed_memory_manager",
"//executorch/runtime/executor:program",
"//executorch/extension/data_loader:file_data_loader",
"//executorch/schema:program",
],
env = {
"ET_MODULE_ADD_WITH_DEVICE_PATH": "$(location fbcode//executorch/test/models:exported_program_with_device_info[ModuleAddWithDevice.pte])",
},
)
171 changes: 171 additions & 0 deletions runtime/executor/test/tensor_parser_device_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

/**
* Tests that device info (device_type) is correctly parsed from serialized
* tensors in .pte files into TensorImpl at runtime.
*
* Uses a .pte exported with DeviceAwarePartitioner (CUDA device annotation)
* so that delegate output tensors carry device_type=CUDA in ExtraTensorInfo.
*/

#include <executorch/runtime/executor/tensor_parser.h>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/schema/program_generated.h>

#include <gtest/gtest.h>

using executorch::aten::Tensor;
using executorch::runtime::Error;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::deserialization::parseTensor;
using executorch::runtime::testing::ManagedMemoryManager;
using torch::executor::util::FileDataLoader;

constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;

namespace executorch {
namespace runtime {
namespace testing {
class ProgramTestFriend final {
public:
const static executorch_flatbuffer::Program* GetInternalProgram(
const Program* program) {
return program->internal_program_;
}
};
} // namespace testing
} // namespace runtime
} // namespace executorch

using executorch::runtime::testing::ProgramTestFriend;

class TensorParserDeviceTest : public ::testing::Test {
protected:
void SetUp() override {
const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH");
ASSERT_NE(path, nullptr)
<< "ET_MODULE_ADD_WITH_DEVICE_PATH env var not set";
Result<FileDataLoader> loader = FileDataLoader::from(path);
ASSERT_EQ(loader.error(), Error::Ok);
loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
}

std::unique_ptr<FileDataLoader> loader_;
};

TEST_F(TensorParserDeviceTest, CUDADeviceParsedFromPteFile) {
Result<Program> program =
Program::load(loader_.get(), Program::Verification::Minimal);
ASSERT_EQ(program.error(), Error::Ok);

ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);

const executorch_flatbuffer::Program* internal_program =
ProgramTestFriend::GetInternalProgram(&program.get());
auto* execution_plan =
internal_program->execution_plan()->GetMutableObject(0);
auto* flatbuffer_values = execution_plan->values();

int cuda_tensor_count = 0;
int cpu_tensor_count = 0;
int total_tensor_count = 0;

for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
auto* serialization_value = flatbuffer_values->Get(i);
if (serialization_value->val_type() !=
executorch_flatbuffer::KernelTypes::Tensor) {
continue;
}
total_tensor_count++;

auto* s_tensor = serialization_value->val_as_Tensor();

Result<Tensor> tensor = parseTensor(&program.get(), &mmm.get(), s_tensor);
if (!tensor.ok()) {
bool has_cuda = s_tensor->extra_tensor_info() != nullptr &&
s_tensor->extra_tensor_info()->device_type() ==
executorch_flatbuffer::DeviceType::CUDA;
if (has_cuda) {
cuda_tensor_count++;
}
continue;
}

Tensor t = tensor.get();
auto device_type = t.unsafeGetTensorImpl()->device_type();

if (device_type == executorch::runtime::etensor::DeviceType::CUDA) {
cuda_tensor_count++;
EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
<< "CUDA tensor should have device_index=0";
} else {
EXPECT_EQ(device_type, executorch::runtime::etensor::DeviceType::CPU);
EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
<< "CPU tensor should have device_index=0";
cpu_tensor_count++;
}
}

EXPECT_GT(total_tensor_count, 0) << "Should have at least one tensor";
// The model has add(a, b) delegated to CUDA — 2 inputs + 1 output = 3 CUDA
EXPECT_EQ(cuda_tensor_count, 3)
<< "Expected 3 CUDA tensors (2 delegate inputs + 1 delegate output)";
}

TEST_F(TensorParserDeviceTest, NonDelegatedTensorsDefaultToCPU) {
Result<Program> program =
Program::load(loader_.get(), Program::Verification::Minimal);
ASSERT_EQ(program.error(), Error::Ok);

ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);

const executorch_flatbuffer::Program* internal_program =
ProgramTestFriend::GetInternalProgram(&program.get());
auto* execution_plan =
internal_program->execution_plan()->GetMutableObject(0);
auto* flatbuffer_values = execution_plan->values();

for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
auto* serialization_value = flatbuffer_values->Get(i);
if (serialization_value->val_type() !=
executorch_flatbuffer::KernelTypes::Tensor) {
continue;
}

auto* s_tensor = serialization_value->val_as_Tensor();
bool has_cuda_device = s_tensor->extra_tensor_info() != nullptr &&
s_tensor->extra_tensor_info()->device_type() ==
executorch_flatbuffer::DeviceType::CUDA;

// Only check tensors that are NOT annotated as CUDA
if (has_cuda_device) {
continue;
}

Result<Tensor> tensor = parseTensor(&program.get(), &mmm.get(), s_tensor);
if (!tensor.ok()) {
continue;
}

Tensor t = tensor.get();
EXPECT_EQ(
t.unsafeGetTensorImpl()->device_type(),
executorch::runtime::etensor::DeviceType::CPU)
<< "Tensor at index " << i
<< " without CUDA annotation should default to CPU";
EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
<< "Tensor at index " << i
<< " without device annotation should have device_index=0";
}
}
142 changes: 142 additions & 0 deletions test/models/export_program_with_device_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""Exports a simple model with device-annotated tensors for C++ testing.

Uses DeviceAwarePartitioner (BackendWithCompilerDemo + target_device=cuda:0)
so that delegate output tensors are annotated with CUDA device in the .pte.
"""

import argparse
import os
from typing import Dict, final

import torch
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_pattern_op_partitions,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.propagate_device_pass import TARGET_DEVICE_COMPILE_SPEC_KEY
from torch import nn
from torch.export import export
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


class _AddOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor,
]


@final
class _DeviceAwarePartitioner(Partitioner):
"""Partitioner that tags add ops for delegation with target_device=cuda:0."""

def __init__(self) -> None:
super().__init__()
self.delegation_spec = DelegationSpec(
BackendWithCompilerDemo.__name__,
[
CompileSpec("max_value", bytes([4])),
CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
],
)

def partition(self, exported_program) -> PartitionResult:
partition_tags: Dict[str, DelegationSpec] = {}
partition_list = generate_pattern_op_partitions(
exported_program.graph_module,
op_support=any_chain(_AddOperatorSupport()),
)
for partition in partition_list:
for node in partition.nodes:
tag = f"tag{partition.id}"
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec
return PartitionResult(
tagged_exported_program=exported_program,
partition_tags=partition_tags,
)


class ModuleAddWithDevice(nn.Module):
"""Simple add model — the add op will be delegated with CUDA device annotation."""

def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return torch.add(a, b)

def get_random_inputs(self):
return (torch.randn(2, 2), torch.randn(2, 2))


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--outdir", type=str, required=True)
args = parser.parse_args()

torch.manual_seed(0)
model = ModuleAddWithDevice()
inputs = model.get_random_inputs()

edge = to_edge(
export(model, inputs),
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
lowered = edge.to_backend(_DeviceAwarePartitioner())
et_prog = lowered.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False))

os.makedirs(args.outdir, exist_ok=True)
outfile = os.path.join(args.outdir, "ModuleAddWithDevice.pte")

# Verify device annotations are present in the serialized program
from executorch.exir.schema import DeviceType, Tensor as SchemaTensor

program = et_prog._emitter_output.program
plan = program.execution_plan[0]
print(f"Delegates: {len(plan.delegates)}")
cuda_count = 0
for i, v in enumerate(plan.values):
if isinstance(v.val, SchemaTensor):
t = v.val
eti = t.extra_tensor_info
dev = eti.device_type if eti else "no_eti"
print(f" Tensor[{i}]: sizes={list(t.sizes)}, device={dev}")
if eti and eti.device_type == DeviceType.CUDA:
cuda_count += 1
print(f"CUDA tensors: {cuda_count}")

# Also check graph module specs
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.tensor import TensorSpec

gm = et_prog.exported_program().graph_module
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == executorch_call_delegate:
specs = node.meta.get("spec")
print(
f" Delegate node '{node.name}' spec.device = {specs.device if isinstance(specs, TensorSpec) else [s.device for s in specs if isinstance(s, TensorSpec)]}"
)

with open(outfile, "wb") as fp:
fp.write(et_prog.buffer)
print(f"Exported ModuleAddWithDevice to {outfile}")


if __name__ == "__main__":
main()
Loading
Loading