Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from executorch.exir._serialize.generated.executorch_flatbuffer.Chain import Chain
from executorch.exir._serialize.generated.executorch_flatbuffer.ContainerMetadata import ContainerMetadata
from executorch.exir._serialize.generated.executorch_flatbuffer.EValue import EValue
from executorch.exir._serialize.generated.executorch_flatbuffer.NonConstBufferDevice import NonConstBufferDevice
from executorch.exir._serialize.generated.executorch_flatbuffer.Operator import Operator
from typing import Optional
np = import_numpy()
Expand Down Expand Up @@ -230,8 +231,32 @@ def NonConstBufferSizesIsNone(self) -> bool:
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
return o == 0

# ExecutionPlan
def NonConstBufferDevice(self, j: int) -> Optional[NonConstBufferDevice]:
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
if o != 0:
x = self._tab.Vector(o)
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
x = self._tab.Indirect(x)
obj = NonConstBufferDevice()
obj.Init(self._tab.Bytes, x)
return obj
return None

# ExecutionPlan
def NonConstBufferDeviceLength(self) -> int:
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
if o != 0:
return self._tab.VectorLen(o)
return 0

# ExecutionPlan
def NonConstBufferDeviceIsNone(self) -> bool:
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
return o == 0

def ExecutionPlanStart(builder: flatbuffers.Builder):
builder.StartObject(9)
builder.StartObject(10)

def Start(builder: flatbuffers.Builder):
ExecutionPlanStart(builder)
Expand Down Expand Up @@ -332,6 +357,18 @@ def ExecutionPlanStartNonConstBufferSizesVector(builder, numElems: int) -> int:
def StartNonConstBufferSizesVector(builder, numElems: int) -> int:
return ExecutionPlanStartNonConstBufferSizesVector(builder, numElems)

def ExecutionPlanAddNonConstBufferDevice(builder: flatbuffers.Builder, nonConstBufferDevice: int):
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(nonConstBufferDevice), 0)

def AddNonConstBufferDevice(builder: flatbuffers.Builder, nonConstBufferDevice: int):
ExecutionPlanAddNonConstBufferDevice(builder, nonConstBufferDevice)

def ExecutionPlanStartNonConstBufferDeviceVector(builder, numElems: int) -> int:
return builder.StartVector(4, numElems, 4)

def StartNonConstBufferDeviceVector(builder, numElems: int) -> int:
return ExecutionPlanStartNonConstBufferDeviceVector(builder, numElems)

def ExecutionPlanEnd(builder: flatbuffers.Builder) -> int:
return builder.EndObject()

Expand All @@ -342,6 +379,7 @@ def End(builder: flatbuffers.Builder) -> int:
from executorch.exir._serialize.generated.executorch_flatbuffer import Chain
from executorch.exir._serialize.generated.executorch_flatbuffer import ContainerMetadata
from executorch.exir._serialize.generated.executorch_flatbuffer import EValue
from executorch.exir._serialize.generated.executorch_flatbuffer import NonConstBufferDevice
from executorch.exir._serialize.generated.executorch_flatbuffer import Operator
try:
from typing import List, Optional
Expand All @@ -361,6 +399,7 @@ def __init__(self):
self.operators = None # type: List[executorch_flatbuffer.Operator.OperatorT]
self.delegates = None # type: List[executorch_flatbuffer.BackendDelegate.BackendDelegateT]
self.nonConstBufferSizes = None # type: List[int]
self.nonConstBufferDevice = None # type: List[executorch_flatbuffer.NonConstBufferDevice.NonConstBufferDeviceT]

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down Expand Up @@ -389,7 +428,8 @@ def __eq__(self, other):
self.chains == other.chains and \
self.operators == other.operators and \
self.delegates == other.delegates and \
self.nonConstBufferSizes == other.nonConstBufferSizes
self.nonConstBufferSizes == other.nonConstBufferSizes and \
self.nonConstBufferDevice == other.nonConstBufferDevice

# ExecutionPlanT
def _UnPack(self, executionPlan):
Expand Down Expand Up @@ -451,6 +491,14 @@ def _UnPack(self, executionPlan):
self.nonConstBufferSizes.append(executionPlan.NonConstBufferSizes(i))
else:
self.nonConstBufferSizes = executionPlan.NonConstBufferSizesAsNumpy()
if not executionPlan.NonConstBufferDeviceIsNone():
self.nonConstBufferDevice = []
for i in range(executionPlan.NonConstBufferDeviceLength()):
if executionPlan.NonConstBufferDevice(i) is None:
self.nonConstBufferDevice.append(None)
else:
nonConstBufferDevice_ = executorch_flatbuffer.NonConstBufferDevice.NonConstBufferDeviceT.InitFromObj(executionPlan.NonConstBufferDevice(i))
self.nonConstBufferDevice.append(nonConstBufferDevice_)

# ExecutionPlanT
def Pack(self, builder):
Expand Down Expand Up @@ -514,6 +562,14 @@ def Pack(self, builder):
for i in reversed(range(len(self.nonConstBufferSizes))):
builder.PrependInt64(self.nonConstBufferSizes[i])
nonConstBufferSizes = builder.EndVector()
if self.nonConstBufferDevice is not None:
nonConstBufferDevicelist = []
for i in range(len(self.nonConstBufferDevice)):
nonConstBufferDevicelist.append(self.nonConstBufferDevice[i].Pack(builder))
ExecutionPlanStartNonConstBufferDeviceVector(builder, len(self.nonConstBufferDevice))
for i in reversed(range(len(self.nonConstBufferDevice))):
builder.PrependUOffsetTRelative(nonConstBufferDevicelist[i])
nonConstBufferDevice = builder.EndVector()
ExecutionPlanStart(builder)
if self.name is not None:
ExecutionPlanAddName(builder, name)
Expand All @@ -533,5 +589,7 @@ def Pack(self, builder):
ExecutionPlanAddDelegates(builder, delegates)
if self.nonConstBufferSizes is not None:
ExecutionPlanAddNonConstBufferSizes(builder, nonConstBufferSizes)
if self.nonConstBufferDevice is not None:
ExecutionPlanAddNonConstBufferDevice(builder, nonConstBufferDevice)
executionPlan = ExecutionPlanEnd(builder)
return executionPlan
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: executorch_flatbuffer

import flatbuffers
from flatbuffers.compat import import_numpy
from typing import Any
np = import_numpy()

class NonConstBufferDevice(object):
__slots__ = ['_tab']

@classmethod
def GetRootAs(cls, buf, offset: int = 0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = NonConstBufferDevice()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsNonConstBufferDevice(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def NonConstBufferDeviceBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x45\x54\x31\x32", size_prefixed=size_prefixed)

# NonConstBufferDevice
def Init(self, buf: bytes, pos: int):
self._tab = flatbuffers.table.Table(buf, pos)

# NonConstBufferDevice
def BufferIdx(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0

# NonConstBufferDevice
def DeviceType(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
return 0

# NonConstBufferDevice
def DeviceIndex(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
return 0

def NonConstBufferDeviceStart(builder: flatbuffers.Builder):
builder.StartObject(3)

def Start(builder: flatbuffers.Builder):
NonConstBufferDeviceStart(builder)

def NonConstBufferDeviceAddBufferIdx(builder: flatbuffers.Builder, bufferIdx: int):
builder.PrependInt32Slot(0, bufferIdx, 0)

def AddBufferIdx(builder: flatbuffers.Builder, bufferIdx: int):
NonConstBufferDeviceAddBufferIdx(builder, bufferIdx)

def NonConstBufferDeviceAddDeviceType(builder: flatbuffers.Builder, deviceType: int):
builder.PrependInt8Slot(1, deviceType, 0)

def AddDeviceType(builder: flatbuffers.Builder, deviceType: int):
NonConstBufferDeviceAddDeviceType(builder, deviceType)

def NonConstBufferDeviceAddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
builder.PrependInt8Slot(2, deviceIndex, 0)

def AddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
NonConstBufferDeviceAddDeviceIndex(builder, deviceIndex)

def NonConstBufferDeviceEnd(builder: flatbuffers.Builder) -> int:
return builder.EndObject()

def End(builder: flatbuffers.Builder) -> int:
return NonConstBufferDeviceEnd(builder)


class NonConstBufferDeviceT(object):

# NonConstBufferDeviceT
def __init__(self):
self.bufferIdx = 0 # type: int
self.deviceType = 0 # type: int
self.deviceIndex = 0 # type: int

@classmethod
def InitFromBuf(cls, buf, pos):
nonConstBufferDevice = NonConstBufferDevice()
nonConstBufferDevice.Init(buf, pos)
return cls.InitFromObj(nonConstBufferDevice)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, nonConstBufferDevice):
x = NonConstBufferDeviceT()
x._UnPack(nonConstBufferDevice)
return x

def __eq__(self, other):
return type(self) == type(other) and \
self.bufferIdx == other.bufferIdx and \
self.deviceType == other.deviceType and \
self.deviceIndex == other.deviceIndex

# NonConstBufferDeviceT
def _UnPack(self, nonConstBufferDevice):
if nonConstBufferDevice is None:
return
self.bufferIdx = nonConstBufferDevice.BufferIdx()
self.deviceType = nonConstBufferDevice.DeviceType()
self.deviceIndex = nonConstBufferDevice.DeviceIndex()

# NonConstBufferDeviceT
def Pack(self, builder):
NonConstBufferDeviceStart(builder)
NonConstBufferDeviceAddBufferIdx(builder, self.bufferIdx)
NonConstBufferDeviceAddDeviceType(builder, self.deviceType)
NonConstBufferDeviceAddDeviceIndex(builder, self.deviceIndex)
nonConstBufferDevice = NonConstBufferDeviceEnd(builder)
return nonConstBufferDevice
2 changes: 2 additions & 0 deletions exir/_serialize/generated/executorch_flatbuffer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from . import KernelTypes
from . import MoveCall
from . import NamedData
from . import NonConstBufferDevice
from . import Null
from . import Operator
from . import OptionalTensorList
Expand Down Expand Up @@ -75,6 +76,7 @@
"KernelTypes",
"MoveCall",
"NamedData",
"NonConstBufferDevice",
"Null",
"Operator",
"OptionalTensorList",
Expand Down
28 changes: 28 additions & 0 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
Expand Down Expand Up @@ -38,7 +38,9 @@
ContainerMetadata,
DataLocation,
DataSegment,
DeviceType,
ExecutionPlan,
NonConstBufferDevice,
Program,
SubsegmentOffsets,
)
Expand Down Expand Up @@ -477,6 +479,32 @@
program, deserialize_pte_binary(flatbuffer_from_py).program
)

def test_round_trip_with_non_const_buffer_device(self) -> None:
"""Tests that non_const_buffer_device survives round-trip
serialization/deserialization. This verifies the schema extension
for per-buffer device mapping works correctly.
"""
program = get_test_program()
program.execution_plan[0].non_const_buffer_device = [
NonConstBufferDevice(buffer_idx=0, device_type=DeviceType.CPU, device_index=0),
NonConstBufferDevice(buffer_idx=1, device_type=DeviceType.CUDA, device_index=0),
]
flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program)))
self.assert_programs_equal(
program, deserialize_pte_binary(flatbuffer_from_py).program
)

def test_round_trip_without_non_const_buffer_device(self) -> None:
"""Tests backward compatibility: a program without non_const_buffer_device
(the default) round-trips correctly and the field remains None.
"""
program = get_test_program()
self.assertIsNone(program.execution_plan[0].non_const_buffer_device)
flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program)))
deserialized = deserialize_pte_binary(flatbuffer_from_py).program
self.assert_programs_equal(program, deserialized)
self.assertIsNone(deserialized.execution_plan[0].non_const_buffer_device)

def test_round_trip_no_segments_and_no_header(self) -> None:
"""Tests that a Program serialized with extract_delegate_segments=True
when there are no segments does not contain an extended header,
Expand Down
15 changes: 15 additions & 0 deletions exir/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,18 @@ class Operator:
overload: str


@dataclass
class NonConstBufferDevice:
"""Maps a non-constant buffer to the device where it should be allocated."""

# Index into the non_const_buffer_sizes list.
buffer_idx: int = 0
# The device type for this buffer (CPU, CUDA, etc.).
device_type: DeviceType = DeviceType.CPU
# The device index for multi-device scenarios (e.g., cuda:0, cuda:1).
device_index: int = 0


@dataclass
class ExecutionPlan:
name: str
Expand All @@ -283,6 +295,9 @@ class ExecutionPlan:
# Runtime should use the len(constant_buffer) as the ground truch of
# constant memory buffer size, and ignore non_const_buffer_sizes[0].
non_const_buffer_sizes: List[int]
# Per-buffer device mapping. Each entry maps a non-constant buffer to the
# device where it should be allocated. For CPU-only programs, this is empty.
non_const_buffer_device: Optional[List[NonConstBufferDevice]] = None


@dataclass
Expand Down
21 changes: 21 additions & 0 deletions schema/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,27 @@ table ExecutionPlan {
// constants memory buffer size, and ignore non_const_buffer_sizes[0].
non_const_buffer_sizes: [int64];

// [Optional] Per-buffer device mapping, parallel to non_const_buffer_sizes.
// Each entry maps a non-constant buffer to the device where it should be
// allocated. For CPU-only programs, this field is absent and all buffers
// default to CPU, ensuring zero regression.
non_const_buffer_device: [NonConstBufferDevice];

}

// Maps a non-constant buffer to the device where it should be allocated.
// When present as part of ExecutionPlan.non_const_buffer_device, each entry
// describes the device placement for the corresponding planned memory buffer.
// For CPU-only programs, this table is absent (all buffers default to CPU).
table NonConstBufferDevice {
// Index into the non_const_buffer_sizes list.
buffer_idx: int;

// The device type for this buffer (CPU, CUDA, etc.).
device_type: DeviceType = CPU;

// The device index for multi-device scenarios (e.g., cuda:0, cuda:1).
device_index: byte = 0;
}

// Constant tensor data stored directly in the flatbuffer.
Expand Down
Loading