diff --git a/exir/_serialize/generated/executorch_flatbuffer/ExecutionPlan.py b/exir/_serialize/generated/executorch_flatbuffer/ExecutionPlan.py index b8ed496b8a8..340a0ad69aa 100644 --- a/exir/_serialize/generated/executorch_flatbuffer/ExecutionPlan.py +++ b/exir/_serialize/generated/executorch_flatbuffer/ExecutionPlan.py @@ -10,6 +10,7 @@ from executorch.exir._serialize.generated.executorch_flatbuffer.Chain import Chain from executorch.exir._serialize.generated.executorch_flatbuffer.ContainerMetadata import ContainerMetadata from executorch.exir._serialize.generated.executorch_flatbuffer.EValue import EValue +from executorch.exir._serialize.generated.executorch_flatbuffer.NonConstBufferDevice import NonConstBufferDevice from executorch.exir._serialize.generated.executorch_flatbuffer.Operator import Operator from typing import Optional np = import_numpy() @@ -230,8 +231,32 @@ def NonConstBufferSizesIsNone(self) -> bool: o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20)) return o == 0 + # ExecutionPlan + def NonConstBufferDevice(self, j: int) -> Optional[NonConstBufferDevice]: + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + obj = NonConstBufferDevice() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # ExecutionPlan + def NonConstBufferDeviceLength(self) -> int: + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # ExecutionPlan + def NonConstBufferDeviceIsNone(self) -> bool: + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) + return o == 0 + def ExecutionPlanStart(builder: flatbuffers.Builder): - builder.StartObject(9) + builder.StartObject(10) def Start(builder: flatbuffers.Builder): ExecutionPlanStart(builder) @@ -332,6 +357,18 @@ def 
ExecutionPlanStartNonConstBufferSizesVector(builder, numElems: int) -> int: def StartNonConstBufferSizesVector(builder, numElems: int) -> int: return ExecutionPlanStartNonConstBufferSizesVector(builder, numElems) +def ExecutionPlanAddNonConstBufferDevice(builder: flatbuffers.Builder, nonConstBufferDevice: int): + builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(nonConstBufferDevice), 0) + +def AddNonConstBufferDevice(builder: flatbuffers.Builder, nonConstBufferDevice: int): + ExecutionPlanAddNonConstBufferDevice(builder, nonConstBufferDevice) + +def ExecutionPlanStartNonConstBufferDeviceVector(builder, numElems: int) -> int: + return builder.StartVector(4, numElems, 4) + +def StartNonConstBufferDeviceVector(builder, numElems: int) -> int: + return ExecutionPlanStartNonConstBufferDeviceVector(builder, numElems) + def ExecutionPlanEnd(builder: flatbuffers.Builder) -> int: return builder.EndObject() @@ -342,6 +379,7 @@ def End(builder: flatbuffers.Builder) -> int: from executorch.exir._serialize.generated.executorch_flatbuffer import Chain from executorch.exir._serialize.generated.executorch_flatbuffer import ContainerMetadata from executorch.exir._serialize.generated.executorch_flatbuffer import EValue +from executorch.exir._serialize.generated.executorch_flatbuffer import NonConstBufferDevice from executorch.exir._serialize.generated.executorch_flatbuffer import Operator try: from typing import List, Optional @@ -361,6 +399,7 @@ def __init__(self): self.operators = None # type: List[executorch_flatbuffer.Operator.OperatorT] self.delegates = None # type: List[executorch_flatbuffer.BackendDelegate.BackendDelegateT] self.nonConstBufferSizes = None # type: List[int] + self.nonConstBufferDevice = None # type: List[executorch_flatbuffer.NonConstBufferDevice.NonConstBufferDeviceT] @classmethod def InitFromBuf(cls, buf, pos): @@ -389,7 +428,8 @@ def __eq__(self, other): self.chains == other.chains and \ self.operators == other.operators 
and \ self.delegates == other.delegates and \ - self.nonConstBufferSizes == other.nonConstBufferSizes + self.nonConstBufferSizes == other.nonConstBufferSizes and \ + self.nonConstBufferDevice == other.nonConstBufferDevice # ExecutionPlanT def _UnPack(self, executionPlan): @@ -451,6 +491,14 @@ def _UnPack(self, executionPlan): self.nonConstBufferSizes.append(executionPlan.NonConstBufferSizes(i)) else: self.nonConstBufferSizes = executionPlan.NonConstBufferSizesAsNumpy() + if not executionPlan.NonConstBufferDeviceIsNone(): + self.nonConstBufferDevice = [] + for i in range(executionPlan.NonConstBufferDeviceLength()): + if executionPlan.NonConstBufferDevice(i) is None: + self.nonConstBufferDevice.append(None) + else: + nonConstBufferDevice_ = executorch_flatbuffer.NonConstBufferDevice.NonConstBufferDeviceT.InitFromObj(executionPlan.NonConstBufferDevice(i)) + self.nonConstBufferDevice.append(nonConstBufferDevice_) # ExecutionPlanT def Pack(self, builder): @@ -514,6 +562,14 @@ def Pack(self, builder): for i in reversed(range(len(self.nonConstBufferSizes))): builder.PrependInt64(self.nonConstBufferSizes[i]) nonConstBufferSizes = builder.EndVector() + if self.nonConstBufferDevice is not None: + nonConstBufferDevicelist = [] + for i in range(len(self.nonConstBufferDevice)): + nonConstBufferDevicelist.append(self.nonConstBufferDevice[i].Pack(builder)) + ExecutionPlanStartNonConstBufferDeviceVector(builder, len(self.nonConstBufferDevice)) + for i in reversed(range(len(self.nonConstBufferDevice))): + builder.PrependUOffsetTRelative(nonConstBufferDevicelist[i]) + nonConstBufferDevice = builder.EndVector() ExecutionPlanStart(builder) if self.name is not None: ExecutionPlanAddName(builder, name) @@ -533,5 +589,7 @@ def Pack(self, builder): ExecutionPlanAddDelegates(builder, delegates) if self.nonConstBufferSizes is not None: ExecutionPlanAddNonConstBufferSizes(builder, nonConstBufferSizes) + if self.nonConstBufferDevice is not None: + ExecutionPlanAddNonConstBufferDevice(builder, 
nonConstBufferDevice) executionPlan = ExecutionPlanEnd(builder) return executionPlan diff --git a/exir/_serialize/generated/executorch_flatbuffer/NonConstBufferDevice.py b/exir/_serialize/generated/executorch_flatbuffer/NonConstBufferDevice.py new file mode 100644 index 00000000000..d82df37d29b --- /dev/null +++ b/exir/_serialize/generated/executorch_flatbuffer/NonConstBufferDevice.py @@ -0,0 +1,130 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: executorch_flatbuffer + +import flatbuffers +from flatbuffers.compat import import_numpy +from typing import Any +np = import_numpy() + +class NonConstBufferDevice(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset: int = 0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = NonConstBufferDevice() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsNonConstBufferDevice(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def NonConstBufferDeviceBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x45\x54\x31\x32", size_prefixed=size_prefixed) + + # NonConstBufferDevice + def Init(self, buf: bytes, pos: int): + self._tab = flatbuffers.table.Table(buf, pos) + + # NonConstBufferDevice + def BufferIdx(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # NonConstBufferDevice + def DeviceType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # NonConstBufferDevice + def DeviceIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return 
self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + +def NonConstBufferDeviceStart(builder: flatbuffers.Builder): + builder.StartObject(3) + +def Start(builder: flatbuffers.Builder): + NonConstBufferDeviceStart(builder) + +def NonConstBufferDeviceAddBufferIdx(builder: flatbuffers.Builder, bufferIdx: int): + builder.PrependInt32Slot(0, bufferIdx, 0) + +def AddBufferIdx(builder: flatbuffers.Builder, bufferIdx: int): + NonConstBufferDeviceAddBufferIdx(builder, bufferIdx) + +def NonConstBufferDeviceAddDeviceType(builder: flatbuffers.Builder, deviceType: int): + builder.PrependInt8Slot(1, deviceType, 0) + +def AddDeviceType(builder: flatbuffers.Builder, deviceType: int): + NonConstBufferDeviceAddDeviceType(builder, deviceType) + +def NonConstBufferDeviceAddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int): + builder.PrependInt8Slot(2, deviceIndex, 0) + +def AddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int): + NonConstBufferDeviceAddDeviceIndex(builder, deviceIndex) + +def NonConstBufferDeviceEnd(builder: flatbuffers.Builder) -> int: + return builder.EndObject() + +def End(builder: flatbuffers.Builder) -> int: + return NonConstBufferDeviceEnd(builder) + + +class NonConstBufferDeviceT(object): + + # NonConstBufferDeviceT + def __init__(self): + self.bufferIdx = 0 # type: int + self.deviceType = 0 # type: int + self.deviceIndex = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + nonConstBufferDevice = NonConstBufferDevice() + nonConstBufferDevice.Init(buf, pos) + return cls.InitFromObj(nonConstBufferDevice) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos+n) + + @classmethod + def InitFromObj(cls, nonConstBufferDevice): + x = NonConstBufferDeviceT() + x._UnPack(nonConstBufferDevice) + return x + + def __eq__(self, other): + return type(self) == type(other) and \ + self.bufferIdx == 
other.bufferIdx and \ + self.deviceType == other.deviceType and \ + self.deviceIndex == other.deviceIndex + + # NonConstBufferDeviceT + def _UnPack(self, nonConstBufferDevice): + if nonConstBufferDevice is None: + return + self.bufferIdx = nonConstBufferDevice.BufferIdx() + self.deviceType = nonConstBufferDevice.DeviceType() + self.deviceIndex = nonConstBufferDevice.DeviceIndex() + + # NonConstBufferDeviceT + def Pack(self, builder): + NonConstBufferDeviceStart(builder) + NonConstBufferDeviceAddBufferIdx(builder, self.bufferIdx) + NonConstBufferDeviceAddDeviceType(builder, self.deviceType) + NonConstBufferDeviceAddDeviceIndex(builder, self.deviceIndex) + nonConstBufferDevice = NonConstBufferDeviceEnd(builder) + return nonConstBufferDevice diff --git a/exir/_serialize/generated/executorch_flatbuffer/__init__.py b/exir/_serialize/generated/executorch_flatbuffer/__init__.py index df59751e724..7cc3b482376 100644 --- a/exir/_serialize/generated/executorch_flatbuffer/__init__.py +++ b/exir/_serialize/generated/executorch_flatbuffer/__init__.py @@ -31,6 +31,7 @@ from . import KernelTypes from . import MoveCall from . import NamedData +from . import NonConstBufferDevice from . import Null from . import Operator from . 
import OptionalTensorList @@ -75,6 +76,7 @@ "KernelTypes", "MoveCall", "NamedData", + "NonConstBufferDevice", "Null", "Operator", "OptionalTensorList", diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 46e8f020a0b..1b6aab94af3 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -38,7 +38,9 @@ ContainerMetadata, DataLocation, DataSegment, + DeviceType, ExecutionPlan, + NonConstBufferDevice, Program, SubsegmentOffsets, ) @@ -477,6 +479,32 @@ def test_round_trip_large_buffer_sizes(self) -> None: program, deserialize_pte_binary(flatbuffer_from_py).program ) + def test_round_trip_with_non_const_buffer_device(self) -> None: + """Tests that non_const_buffer_device survives round-trip + serialization/deserialization. This verifies the schema extension + for per-buffer device mapping works correctly. + """ + program = get_test_program() + program.execution_plan[0].non_const_buffer_device = [ + NonConstBufferDevice(buffer_idx=0, device_type=DeviceType.CPU, device_index=0), + NonConstBufferDevice(buffer_idx=1, device_type=DeviceType.CUDA, device_index=0), + ] + flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program))) + self.assert_programs_equal( + program, deserialize_pte_binary(flatbuffer_from_py).program + ) + + def test_round_trip_without_non_const_buffer_device(self) -> None: + """Tests backward compatibility: a program without non_const_buffer_device + (the default) round-trips correctly and the field remains None. 
+ """ + program = get_test_program() + self.assertIsNone(program.execution_plan[0].non_const_buffer_device) + flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program))) + deserialized = deserialize_pte_binary(flatbuffer_from_py).program + self.assert_programs_equal(program, deserialized) + self.assertIsNone(deserialized.execution_plan[0].non_const_buffer_device) + def test_round_trip_no_segments_and_no_header(self) -> None: """Tests that a Program serialized with extract_delegate_segments=True when there are no segments does not contain an extended header, diff --git a/exir/schema.py b/exir/schema.py index 993a473dabb..add90dec45c 100644 --- a/exir/schema.py +++ b/exir/schema.py @@ -268,6 +268,18 @@ class Operator: overload: str +@dataclass +class NonConstBufferDevice: + """Maps a non-constant buffer to the device where it should be allocated.""" + + # Index into the non_const_buffer_sizes list. + buffer_idx: int = 0 + # The device type for this buffer (CPU, CUDA, etc.). + device_type: DeviceType = DeviceType.CPU + # The device index for multi-device scenarios (e.g., cuda:0, cuda:1). + device_index: int = 0 + + @dataclass class ExecutionPlan: name: str @@ -283,6 +295,9 @@ class ExecutionPlan: # Runtime should use the len(constant_buffer) as the ground truch of # constant memory buffer size, and ignore non_const_buffer_sizes[0]. non_const_buffer_sizes: List[int] + # Per-buffer device mapping. Each entry maps a non-constant buffer to the + # device where it should be allocated. For CPU-only programs, this is None. + non_const_buffer_device: Optional[List[NonConstBufferDevice]] = None @dataclass diff --git a/schema/program.fbs b/schema/program.fbs index f5872633ac8..c6e6edc790f 100644 --- a/schema/program.fbs +++ b/schema/program.fbs @@ -401,6 +401,27 @@ table ExecutionPlan { // constants memory buffer size, and ignore non_const_buffer_sizes[0].
non_const_buffer_sizes: [int64]; + // [Optional] Per-buffer device mapping; entries index non_const_buffer_sizes. + // Each entry maps a non-constant buffer to the device where it should be + // allocated. For CPU-only programs, this field is absent and all buffers + // default to CPU, ensuring zero regression. + non_const_buffer_device: [NonConstBufferDevice]; + +} + +// Maps a non-constant buffer to the device where it should be allocated. +// When present as part of ExecutionPlan.non_const_buffer_device, each entry +// describes the device placement for the buffer selected by buffer_idx. +// For CPU-only programs, this table is absent (all buffers default to CPU). +table NonConstBufferDevice { + // Index into the non_const_buffer_sizes list. + buffer_idx: int; + + // The device type for this buffer (CPU, CUDA, etc.). + device_type: DeviceType = CPU; + + // The device index for multi-device scenarios (e.g., cuda:0, cuda:1). + device_index: byte = 0; } // Constant tensor data stored directly in the flatbuffer.