Arm VGF quantizer/PT2E prepare fails on annotated aten.full.default with kwargs #18322

@Rob-Hughes-Arm

Description


🐛 Describe the bug


Summary

VgfQuantizer can fail during PT2E prepare, before any lowering or runtime step, when the exported graph contains aten.full.default(...) and the Arm quantizer annotates that node directly.

The failure is:

AssertionError: expecting kwargs for aten op IR to be empty

This happens because the exported aten.full.default node carries kwargs such as:

{'dtype': torch.float32, 'device': device(type='cpu'), 'pin_memory': False}

and torchao.quantization.pt2e.prepare does not special-case aten.full.default.
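A framework-free sketch of the guard that fires (illustrative only; the names here are hypothetical stand-ins for the real assertion, which lives in torchao/quantization/pt2e/prepare.py):

```python
# Simplified stand-in for the check PT2E prepare applies to each node it is
# about to observe: exported aten-op IR is expected to carry no kwargs, but
# factory ops like aten.full.default do carry dtype/device/pin_memory.
def check_aten_node_kwargs(kwargs: dict) -> None:
    assert not kwargs, "expecting kwargs for aten op IR to be empty"

# The kwargs seen on the exported aten.full.default node in this repro
# (values shown as strings for the sketch):
full_kwargs = {"dtype": "torch.float32", "device": "cpu", "pin_memory": False}

try:
    check_aten_node_kwargs(full_kwargs)
except AssertionError as exc:
    print(exc)  # expecting kwargs for aten op IR to be empty
```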


Self-Contained Minimal Repro

import traceback

import torch
import torch.nn as nn
from executorch.backends.arm.quantizer import (
    VgfQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.vgf import VgfCompileSpec


class MinimalFullModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.predict = nn.Conv2d(4, 2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad = torch.full((1, 2, 8, 8), 0.0, device=x.device, dtype=x.dtype)
        fused = torch.cat([pad, x[:, :2]], dim=1)
        return self.predict(fused)


def main() -> None:
    model = MinimalFullModel().eval()
    example_input = torch.randn(1, 4, 8, 8)

    exported = torch.export.export(model, (example_input,), strict=True)
    graph_module = exported.module(check_guards=False)

    print("Nodes with kwargs in exported graph:")
    for node in graph_module.graph.nodes:
        if node.kwargs:
            print(node.name, node.target, node.kwargs)

    quantizer = VgfQuantizer(VgfCompileSpec("TOSA-1.0+INT"))
    symmetric_int8_config = get_symmetric_quantization_config(
        is_per_channel=True,
        is_qat=False,
        is_dynamic=False,
        act_qmin=-127,
        act_qmax=127,
        weight_qmin=-127,
        weight_qmax=127,
    )
    quantizer.set_global(symmetric_int8_config)

    quantizer.quantize_with_submodules(
        graph_module,
        calibration_samples=[(example_input,)],
        is_qat=False,
    )


if __name__ == "__main__":
    try:
        main()
    except Exception:
        traceback.print_exc()
        raise

Observed exported node

aten.full.default(..., dtype=torch.float32, device=cpu, pin_memory=False)

Observed failure

Traceback (most recent call last):
  ...
  File "...torchao\quantization\pt2e\prepare.py", line 547, in _maybe_insert_input_observers_for_node
    assert (
AssertionError: expecting kwargs for aten op IR to be empty

Why This Looks VGF-Specific

The same explicit torch.full(...) pattern succeeds with XNNPACKQuantizer in this environment.

The key behavior difference is:

  • XNNPACKQuantizer does not annotate aten.full.default
  • VgfQuantizer does annotate aten.full.default

Once full is annotated, PT2E prepare tries to insert input observers for it directly and hits the kwargs assertion.

When full is left unannotated, PT2E can still insert observers at annotated consumer boundaries such as cat, and the graph prepares successfully.
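A minimal, framework-free sketch of that boundary behavior (all names are hypothetical, not torchao code): observers hang off the annotated consumer's input edges, so an unannotated producer like full is never itself validated.

```python
# Toy model of PT2E observer placement: annotated nodes get observers on their
# input edges, while annotated producers are validated directly (mirroring the
# kwargs assertion). An unannotated `full` feeding an annotated `cat` is fine.
class Node:
    def __init__(self, name, kwargs=None, annotated=False):
        self.name = name
        self.kwargs = kwargs or {}
        self.annotated = annotated


def place_observers(edges):
    observers = []
    for src, dst in edges:
        if src.annotated:
            # Annotated producers are processed directly, like torchao's check.
            assert not src.kwargs, "expecting kwargs for aten op IR to be empty"
        if dst.annotated:
            # Observers land on the annotated consumer's input edge.
            observers.append(f"obs({src.name} -> {dst.name})")
    return observers


full = Node("full", kwargs={"dtype": "float32"})  # unannotated: never checked
cat = Node("cat", annotated=True)
print(place_observers([(full, cat)]))  # -> ['obs(full -> cat)']
```

Annotating full in this toy model (the VgfQuantizer behavior) trips the same assertion, while leaving it unannotated (the XNNPACKQuantizer behavior) still observes the full-to-cat edge.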


Root Cause

The Arm quantization annotator appears to over-annotate constant/factory ops.

In executorch/backends/arm/quantizer/quantization_annotator.py, the Arm annotator currently treats these as quantizable outputs:

torch.ops.aten.full.default
torch.ops.aten.full
torch.ops.aten.zeros.default
torch.ops.aten.ones.default
torch.ops.aten.fill_.Scalar
torch.ops.aten.scalar_tensor.default

Annotating any of these is enough to make PT2E prepare observe nodes whose exported ATen form still carries kwargs.


Evidence For A Safe Fix

I tested skipping annotation for:

torch.ops.aten.full.default
torch.ops.aten.full

in the Arm quantization annotator.

Result:

  • the minimal torch.full -> cat -> conv repro passes
  • quantization still propagates through annotated consumers like cat and conv
  • a model that returns torch.full(...) directly still works
    • with IO annotation enabled, PT2E quantizes at the graph output boundary
    • without IO annotation, the output remains float

This matches the more conservative XNNPACK behavior.


Suggested Fix

Do not directly annotate constant/factory producer ops such as:

  • aten.full.default
  • aten.full
  • likely also aten.zeros.default
  • aten.ones.default
  • aten.fill_.Scalar
  • aten.scalar_tensor.default

Instead, rely on annotated consumers (cat, conv, etc.) and graph IO annotation to place observers where needed.

That avoids the PT2E kwargs assertion while preserving quantization at real dataflow boundaries.
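One possible shape for the change, sketched with plain-string op names (hypothetical; the real op list lives in executorch/backends/arm/quantizer/quantization_annotator.py, and whether zeros/ones/fill_/scalar_tensor should also be skipped is still an open question above):

```python
# Hypothetical guard for the Arm annotator: skip direct annotation of
# constant/factory producers and rely on annotated consumers (cat, conv, ...)
# plus graph IO annotation to place observers at real dataflow boundaries.
FACTORY_OPS = {
    "aten.full.default",
    "aten.full",
    # Likely candidates, per the same reasoning:
    "aten.zeros.default",
    "aten.ones.default",
    "aten.fill_.Scalar",
    "aten.scalar_tensor.default",
}


def should_annotate(op_name: str) -> bool:
    return op_name not in FACTORY_OPS


print(should_annotate("aten.full.default"))  # False: left for consumers
print(should_annotate("aten.cat.default"))   # True: annotated as before
```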


Environment

  • executorch==1.2.0.dev20260305+cpu
  • torch==2.10.0
  • torchao==0.15.0

