
Add fused MatMul operator#4806

Open
ahsan-ca wants to merge 1 commit into develop from fusedMatMul-op

Conversation

Contributor

@ahsan-ca ahsan-ca commented Apr 20, 2026

Motivation

Add fused MatMul operator

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@ahsan-ca ahsan-ca self-assigned this Apr 20, 2026
@ahsan-ca ahsan-ca requested a review from causten as a code owner April 20, 2026 22:13
Copilot AI review requested due to automatic review settings April 20, 2026 22:13
Contributor

Copilot AI left a comment


Pull request overview

This PR adds support for the com.microsoft::FusedMatMul ONNX contrib operator in MIGraphX, along with parser/verification tests and generated ONNX fixtures to validate transpose/batch-transpose/alpha behaviors.

Changes:

  • Implemented a new ONNX parser for FusedMatMul that lowers to MIGraphX ops (transpose, dot, optional mul) with 1-D promotion/squeeze behavior.
  • Added ONNX parse tests covering transA, transB, transBatchA, transBatchB, alpha, and an invalid-rank error case.
  • Added a numerical verify test and updated gen_onnx.py + embedded .onnx models for the new operator.
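The lowering described above can be sketched as a plain-Python reference, restricted to 2-D inputs for clarity. This is an illustrative sketch of the intended math (transpose flags, dot, optional alpha scale), not the actual MIGraphX parser code; the function name is made up.

```python
# Hypothetical reference for the FusedMatMul semantics described above,
# restricted to 2-D inputs. transA/transB transpose the inputs, then a
# plain dot is taken, then the result is scaled by alpha (the trailing mul).
def fused_matmul_2d(a, b, trans_a=False, trans_b=False, alpha=1.0):
    def transpose(m):
        # swap the two axes of a 2-D nested list
        return [list(col) for col in zip(*m)]

    if trans_a:
        a = transpose(a)
    if trans_b:
        b = transpose(b)
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[alpha * sum(a[i][k] * b[k][j] for k in range(inner))
             for j in range(cols)]
            for i in range(rows)]
```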

Reviewed changes

Copilot reviewed 22 out of 25 changed files in this pull request and generated 1 comment.

File | Description
src/onnx/parse_fused_matmul.cpp | Adds the FusedMatMul parser/lowering logic.
test/onnx/gen_onnx.py | Generates ONNX models for FusedMatMul test cases.
test/onnx/parse/fused_matmul_*.cpp | Adds parser equivalence tests for various attribute combinations and an error case.
test/onnx/verify/fused_matmul_verify_test.cpp | Adds numeric verification of FusedMatMul output against gold values.
test/onnx/fused_matmul_*.onnx | Adds embedded ONNX fixtures used by the new tests.


Comment on lines +119 to +123
if(s0.dynamic() or s1.dynamic())
{
MIGRAPHX_THROW("PARSE_FUSEDMATMUL: dynamic inputs not supported");
}


Copilot AI Apr 20, 2026


FusedMatMul currently throws on any dynamic input shapes, even though the underlying ops used here (transpose, unsqueeze, and the op-builder dot which already handles dynamic broadcasting via broadcast_for_dot) support dynamic shapes. This makes FusedMatMul unnecessarily less capable than the existing MatMul parser. Consider removing this restriction and letting dot/transpose handle dynamic shapes (or only rejecting specific unsupported dynamic cases, if any).

Suggested change
if(s0.dynamic() or s1.dynamic())
{
MIGRAPHX_THROW("PARSE_FUSEDMATMUL: dynamic inputs not supported");
}

Contributor


+1
Is there an issue you'd run into if you just applied the same transposes, etc., in the dynamic case?

@codecov

codecov Bot commented Apr 20, 2026

Codecov Report

❌ Patch coverage is 86.36364% with 9 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/onnx/parse_fused_matmul.cpp | 86.36% | 9 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4806      +/-   ##
===========================================
- Coverage    92.49%   92.48%   -0.01%     
===========================================
  Files          583      584       +1     
  Lines        29562    29628      +66     
===========================================
+ Hits         27343    27400      +57     
- Misses        2219     2228       +9     
Files with missing lines | Coverage Δ
src/onnx/parse_fused_matmul.cpp | 86.36% <86.36%> (ø)

@causten
Collaborator

causten commented Apr 21, 2026

Test Batch New Rate (113810) Old Rate (0681c6) Diff Status
torchvision-resnet50 64 1,495.31 3,166.53 -52.78% 🔴
torchvision-resnet50_fp16 64 3,522.06 6,820.89 -48.36% 🔴
torchvision-densenet121 32 1,563.76 2,422.10 -35.44% 🔴
torchvision-densenet121_fp16 32 4,200.31 4,086.15 2.79%
torchvision-inceptionv3 32 1,009.51 1,658.01 -39.11% 🔴
torchvision-inceptionv3_fp16 32 1,638.54 2,677.22 -38.80% 🔴
cadene-inceptionv4 16 419.62 789.08 -46.82% 🔴
cadene-resnext64x4 16 374.65 790.86 -52.63% 🔴
slim-mobilenet 64 4,261.81 8,521.12 -49.99% 🔴
slim-nasnetalarge 64 108.23 218.37 -50.43% 🔴
slim-resnet50v2 64 3,144.93 3,288.39 -4.36%
bert-mrpc-onnx 8 1,126.43 1,108.72 1.60%
bert-mrpc-tf 1 456.51 456.31 0.04%
pytorch-examples-wlang-gru 1 361.04 338.72 6.59% 🔆
pytorch-examples-wlang-lstm 1 459.18 452.66 1.44%
torchvision-resnet50_1 1 714.98 748.69 -4.50%
cadene-dpn92_1 1 442.29 403.71 9.56% 🔆
cadene-resnext101_1 1 367.53 323.10 13.75% 🔆
onnx-taau-downsample 1 396.53 391.86 1.19%
dlrm-criteoterabyte 1 16.20 31.56 -48.67% 🔴
dlrm-criteoterabyte_fp16 1 32.12 50.03 -35.78% 🔴
agentmodel 1 10,088.51 10,347.89 -2.51%
unet_fp16 2 30.44 55.66 -45.32% 🔴
resnet50v1_fp16 1 920.08 994.78 -7.51% 🔴
resnet50v1_int8 1 846.40 943.47 -10.29% 🔴
bert_base_cased_fp16 64 588.98 1,086.36 -45.78% 🔴
bert_large_uncased_fp16 32 177.98 338.39 -47.40% 🔴
bert_large_fp16 1 198.46 198.67 -0.11%
distilgpt2_fp16 16 1,303.88 2,066.44 -36.90% 🔴
yolov5s 1 442.30 549.43 -19.50% 🔴
tinyllama 1 23.79 43.80 -45.68% 🔴
vicuna-fastchat 1 22.42 42.63 -47.41% 🔴
whisper-tiny-encoder 1 251.94 406.13 -37.97% 🔴
whisper-tiny-decoder 1 279.35 395.21 -29.31% 🔴
llama2_7b 1 9.75 19.10 -48.96% 🔴
qwen1.5-7b 1 14.54 23.42 -37.93% 🔴
phi3-3.8b 1 15.68 26.63 -41.13% 🔴
llama3-8b 1 13.00 21.68 -40.01% 🔴
whisper-large-encoder 1 5.04 10.12 -50.18% 🔴
whisper-large-decoder 1 72.50 101.03 -28.24% 🔴
mistral-7b 1 14.29 23.64 -39.54% 🔴
FLUX.1-schnell 1 675.18 737.25 -8.42% 🔴

Regressions detected 🔴

@causten
Collaborator

causten commented Apr 21, 2026

Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder ERROR - check error output
traceback
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 360, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 211, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/include/migraphx/op/convolution.hpp:103: normalize_compute_shape: CONVOLUTION: mismatched channel numbers
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

Comment on lines +45 to +46
// transBatchA : int, for rank-R A, permute [1, 2, ..., R-2, 0, R-1] (default 0)
// transBatchB : int, for rank-R B, permute [1, 2, ..., R-2, 0, R-1] (default 0)
Collaborator


Are we sure that this is the transpose that is intended? Reading the spec I would think it's
[0, 1, 2, ..., R-1] -> [R-2, 0, 1, 2, ..., R-3, R-1]. The spec is really imprecise about it, however.
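For concreteness, the two readings of the transBatch permutation discussed here can be compared with a small sketch (function names are illustrative):

```python
# The permutation as written in the code comment: [1, 2, ..., R-2, 0, R-1],
# i.e. axis 0 is moved to just before the last axis.
def perm_as_commented(r):
    return list(range(1, r - 1)) + [0, r - 1]

# The alternative reading suggested above: [R-2, 0, 1, ..., R-3, R-1],
# i.e. axis R-2 is moved to the front.
def perm_alternative(r):
    return [r - 2] + list(range(r - 2)) + [r - 1]
```

For a rank-4 tensor these give [1, 2, 0, 3] and [2, 0, 1, 3] respectively; note that these two are inverse permutations of each other, which may be part of why the spec wording reads either way.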

static instruction_ref apply_trans_last_two(const onnx_parser::node_info& info,
instruction_ref x)
{
auto r = x->get_shape().ndim();
Contributor


It might be good to assert r >= 2 here, to help debugging if this parser is modified in the future in a way that lets a rank-1 x sneak in.


Comment thread test/onnx/gen_onnx.py

return ([node], [m1, m2], [y])


Contributor


There seem to be a few missing scenarios for test case completeness:

  1. 1D tensors and tensors that need batch broadcasting
  2. Verify with more variations of attributes (setting transA, the transBatch attrs, etc.)
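As a sketch of what scenario 1 would need to verify: per the PR description's 1-D promotion/squeeze behavior, a 1-D A should be unsqueezed to a row vector before the dot and the extra axis squeezed off the result. A plain-Python illustration (the function name is made up):

```python
# Illustrative reference for the 1-D promotion case: A of shape (K) is
# unsqueezed to (1, K), multiplied against B of shape (K, N), and the
# leading axis is squeezed off the (1, N) result, giving shape (N).
def matmul_1d_a(a, b):
    a2 = [a]                                # unsqueeze: (K) -> (1, K)
    k, n = len(b), len(b[0])
    out = [[sum(a2[0][i] * b[i][j] for i in range(k))
            for j in range(n)]]
    return out[0]                           # squeeze: (1, N) -> (N)
```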
