Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: executorch_flatbuffer

class DeviceType(object):
CPU = 0
CUDA = 1
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,22 @@ def Location(self):
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
return 0

# ExtraTensorInfo
def DeviceType(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
return 0

# ExtraTensorInfo
def DeviceIndex(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
return 0

def ExtraTensorInfoStart(builder: flatbuffers.Builder):
builder.StartObject(3)
builder.StartObject(5)

def Start(builder: flatbuffers.Builder):
ExtraTensorInfoStart(builder)
Expand All @@ -75,6 +89,18 @@ def ExtraTensorInfoAddLocation(builder: flatbuffers.Builder, location: int):
def AddLocation(builder: flatbuffers.Builder, location: int):
ExtraTensorInfoAddLocation(builder, location)

def ExtraTensorInfoAddDeviceType(builder: flatbuffers.Builder, deviceType: int):
builder.PrependInt8Slot(3, deviceType, 0)

def AddDeviceType(builder: flatbuffers.Builder, deviceType: int):
ExtraTensorInfoAddDeviceType(builder, deviceType)

def ExtraTensorInfoAddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
builder.PrependInt8Slot(4, deviceIndex, 0)

def AddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
ExtraTensorInfoAddDeviceIndex(builder, deviceIndex)

def ExtraTensorInfoEnd(builder: flatbuffers.Builder) -> int:
return builder.EndObject()

Expand All @@ -89,6 +115,8 @@ def __init__(self):
self.mutableDataSegmentsIdx = 0 # type: int
self.fullyQualifiedName = None # type: str
self.location = 0 # type: int
self.deviceType = 0 # type: int
self.deviceIndex = 0 # type: int

@classmethod
def InitFromBuf(cls, buf, pos):
Expand All @@ -111,7 +139,9 @@ def __eq__(self, other):
return type(self) == type(other) and \
self.mutableDataSegmentsIdx == other.mutableDataSegmentsIdx and \
self.fullyQualifiedName == other.fullyQualifiedName and \
self.location == other.location
self.location == other.location and \
self.deviceType == other.deviceType and \
self.deviceIndex == other.deviceIndex

# ExtraTensorInfoT
def _UnPack(self, extraTensorInfo):
Expand All @@ -120,6 +150,8 @@ def _UnPack(self, extraTensorInfo):
self.mutableDataSegmentsIdx = extraTensorInfo.MutableDataSegmentsIdx()
self.fullyQualifiedName = extraTensorInfo.FullyQualifiedName()
self.location = extraTensorInfo.Location()
self.deviceType = extraTensorInfo.DeviceType()
self.deviceIndex = extraTensorInfo.DeviceIndex()

# ExtraTensorInfoT
def Pack(self, builder):
Expand All @@ -130,5 +162,7 @@ def Pack(self, builder):
if self.fullyQualifiedName is not None:
ExtraTensorInfoAddFullyQualifiedName(builder, fullyQualifiedName)
ExtraTensorInfoAddLocation(builder, self.location)
ExtraTensorInfoAddDeviceType(builder, self.deviceType)
ExtraTensorInfoAddDeviceIndex(builder, self.deviceIndex)
extraTensorInfo = ExtraTensorInfoEnd(builder)
return extraTensorInfo
2 changes: 2 additions & 0 deletions exir/_serialize/generated/executorch_flatbuffer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from . import DataLocation
from . import DataSegment
from . import DelegateCall
from . import DeviceType
from . import Double
from . import DoubleList
from . import EValue
Expand Down Expand Up @@ -56,6 +57,7 @@
"DataLocation",
"DataSegment",
"DelegateCall",
"DeviceType",
"Double",
"DoubleList",
"EValue",
Expand Down
7 changes: 7 additions & 0 deletions exir/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ class TensorDataLocation(IntEnum):
EXTERNAL = 1


class DeviceType(IntEnum):
CPU = 0
CUDA = 1


@dataclass
class ExtraTensorInfo:
"""
Expand All @@ -57,6 +62,8 @@ class ExtraTensorInfo:
mutable_data_segments_idx: int = 0
fully_qualified_name: Optional[str] = None
location: TensorDataLocation = TensorDataLocation.SEGMENT
device_type: DeviceType = DeviceType.CPU
device_index: int = 0


@dataclass
Expand Down
15 changes: 15 additions & 0 deletions schema/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ enum TensorDataLocation : byte {
EXTERNAL = 1,
}

// Device type enum indicating where a tensor resides or should be allocated.
// Please keep this in sync with executorch/runtime/core/portable_type/device.h
enum DeviceType : byte {
CPU = 0,
CUDA = 1,
}

// Table to put additional information about tensors in that is not applicable
// to the vast majority of tensors in the vast majority of programs.
table ExtraTensorInfo {
Expand All @@ -80,6 +87,14 @@ table ExtraTensorInfo {
// must be non-empty, and is used as a key to find the tensor's external
// data. Tensor.data_buffer_idx is ignored.
location: TensorDataLocation;

// [Optional] The device type where this tensor resides or should be allocated.
// Defaults to CPU for backward compatibility with existing PTE files.
device_type: DeviceType = CPU;

// [Optional] The device index for multi-device scenarios (e.g., cuda:0, cuda:1).
// Defaults to 0 (the first device of the given type).
device_index: byte = 0;
}

table Tensor {
Expand Down
Loading