
Rewrite skinny gemms to mul+reduce_sum #4811

Open

pfultz2 wants to merge 13 commits into develop from skinny-dot-2-reduce

Conversation

Collaborator

@pfultz2 pfultz2 commented Apr 22, 2026

Motivation

For these skinny shapes, the mul+reduce_sum formulation is usually much faster than MFMA-based gemm kernels.

Technical Details

Since decoding attention also uses skinny gemms, the pass detects the gemms used in attention and skips rewriting them, so attention fusion is not disturbed.
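The identity behind the rewrite can be sanity-checked in plain Python (a sketch for illustration, not MIGraphX code): for M = 1, a gemm is an elementwise multiply of the broadcast A row against B, followed by a sum over K.

```python
# Illustrative only: for A of shape (1, K) and B of shape (K, N),
# dot(A, B)[0][n] == sum over k of A[0][k] * B[k][n],
# i.e. a broadcast mul followed by a reduce_sum over the K axis.
def skinny_dot_as_mul_reduce(a_row, b):
    k, n = len(b), len(b[0])
    # mul: scale row i of B by a_row[i] (A broadcast along N)
    prod = [[a_row[i] * b[i][j] for j in range(n)] for i in range(k)]
    # reduce_sum over the K axis
    return [sum(prod[i][j] for i in range(k)) for j in range(n)]

def naive_dot_row(a_row, b):
    return [sum(a_row[i] * b[i][j] for i in range(len(b)))
            for j in range(len(b[0]))]
```

Both paths compute the same values; the rewrite just exposes the reduction to the reduce fusion machinery.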

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@pfultz2 pfultz2 requested a review from causten as a code owner April 22, 2026 00:25
Copilot AI review requested due to automatic review settings April 22, 2026 00:25
Collaborator Author

pfultz2 commented Apr 22, 2026

Here is the perf comparison on navi48 against mlir:

  ================================================================================
  mlir-base  (skinny=M, type=float_type)
  ================================================================================
       M        N        K    Baseline (ms)      Time (ms)     TFLOPS     Change
  ------ -------- --------   --------------   ------------   --------   --------
       1       10      256           0.0090         0.0074       0.00     +21.6%
       1     1000     1024           0.0139         0.0079       0.26     +75.9%
       1     4096     9216           0.2391         0.2421       0.31      -1.2%
       1     4096     4096           0.0870         0.0406       0.83    +114.3%
       1      200     4096           0.0459         0.0076       0.21    +503.9%
       1     1000     4096           0.0463         0.0142       0.58    +226.1%
       1     4096    18432           0.4747         0.4785       0.32      -0.8%
       1     1024     4096           0.0482         0.0144       0.58    +234.7%
       1      512    25088           0.2700         0.0390       0.66    +592.3%
       1     4096    25088           0.6535         0.6598       0.31      -1.0%
       1    16000     8192           0.8334         0.8334       0.31      +0.0%
       1    28672     8192           1.4788         1.4821       0.32      -0.2%
       1     5120     8192           0.2763         0.2682       0.31      +3.0%
       1     8192    14336           0.7446         0.7482       0.31      -0.5%
       1     8192     4096           0.2126         0.2160       0.31      -1.6%

  
================================================================================
  mlir-base  (skinny=M, type=half)
  ================================================================================
       M        N        K    Baseline (ms)      Time (ms)     TFLOPS     Change
  ------ -------- --------   --------------   ------------   --------   --------
       1       10      256           0.0050         0.0079       0.00     -36.7%
       1     1000     1024           0.0053         0.0076       0.27     -30.3%
       1     4096     9216           0.1199         0.1233       0.61      -2.8%
       1     4096     4096           0.0204         0.0230       1.46     -11.3%
       1      200     4096           0.0120         0.0076       0.22     +57.9%
       1     1000     4096           0.0127         0.0101       0.81     +25.7%
       1     4096    18432           0.2389         0.2421       0.62      -1.3%
       1     1024     4096           0.0129         0.0102       0.82     +26.5%
       1      512    25088           0.0550         0.0194       1.32    +183.5%
       1     4096    25088           0.3239         0.3273       0.63      -1.0%
       1    16000     8192           0.4127         0.4156       0.63      -0.7%
       1    28672     8192           0.7447         0.7486       0.63      -0.5%
       1     5120     8192           0.1259         0.1368       0.61      -8.0%
       1     8192    14336           0.3696         0.3728       0.63      -0.9%
       1     8192     4096           0.0382         0.0407       1.65      -6.1%

Contributor

Copilot AI left a comment

Pull request overview

This PR extends the rewrite_reduce optimization pass to rewrite “skinny” dot ops (small M dimension) into an equivalent mul + reduce_sum formulation, while detecting and exempting attention-related dots (QK^T and softmaxV) to preserve attention fusion opportunities.

Changes:

  • Add attention-pattern detection (via decomposed softmax matching) to identify dots that should not be rewritten.
  • Rewrite dot ops with M (rows) ≤ 2 into unsqueeze + transpose + mul + reduce_sum + squeeze.
  • Expand test/rewrite_reduce.cpp with new coverage for skinny-dot rewrites and attention/non-attention cases.
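A shape-level sketch of that op sequence (NumPy used for illustration; the exact axis placement in the pass is an assumption, not the MIGraphX implementation):

```python
import numpy as np

# Illustrative sketch of the rewrite for a small-M dot: unsqueeze both
# operands to rank 3, broadcast-multiply, reduce_sum over K, then squeeze.
M, K, N = 2, 8, 5
rng = np.random.default_rng(0)
a = rng.standard_normal((M, K)).astype(np.float32)   # A: (M, K)
b = rng.standard_normal((K, N)).astype(np.float32)   # B: (K, N)

a_exp = a[:, :, np.newaxis]               # unsqueeze A: (M, K, 1)
b_exp = b[np.newaxis, :, :]               # unsqueeze B: (1, K, N)
prod = a_exp * b_exp                      # mul with broadcasting: (M, K, N)
summed = prod.sum(axis=1, keepdims=True)  # reduce_sum over K: (M, 1, N)
out = summed.squeeze(axis=1)              # squeeze: (M, N)

assert np.allclose(out, a @ b, atol=1e-5)
```

For M ≤ 2 the intermediate (M, K, N) product stays small along M, which is what makes this profitable relative to an MFMA gemm.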

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
src/rewrite_reduce.cpp Adds attention dot collection + a new find_dot matcher that rewrites small-M dots into mul/reduce, skipping attention dots.
test/rewrite_reduce.cpp Adds/updates unit tests to validate skinny-dot rewrites and to ensure attention patterns are not rewritten.


Comment thread src/rewrite_reduce.cpp
Comment on lines +136 to +142
// If the b matrix is const foldable then make sure its a transposed layout unless its
// broadcasting
if(b_mat->can_eval() and not b_shape.transposed())
{
b_mat =
m.insert_instruction(ins, make_op("layout", {{"permutation", permutation}}), b_mat);
}
Copilot AI Apr 22, 2026

The layout insertion for const-foldable b_mat is guarded only by !b_shape.transposed(), but the preceding comment says it should be skipped when b_mat is broadcasting. If b_shape.broadcasted() is true (e.g., a multibroadcasted literal), inserting layout can materialize the broadcasted tensor and significantly increase compile-time/memory usage. Consider adding and not b_shape.broadcasted() (or an equivalent check) to the condition so broadcasted constants keep their cheap broadcasted representation.
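The cost concern can be illustrated outside MIGraphX (a NumPy analogy, not the pass itself): a broadcasted constant is a stride-0 view over a single small buffer, while forcing a concrete layout materializes the full tensor.

```python
import numpy as np

# Analogy for the review comment: a broadcasted tensor is cheap (stride-0
# view over one element), but materializing a dense layout copies it all.
c = np.full((1, 1), 3.0, dtype=np.float32)   # small constant
view = np.broadcast_to(c, (1024, 1024))      # broadcast: no data copied
dense = np.ascontiguousarray(view)           # forced layout: full copy

assert view.strides == (0, 0)                # stride-0 broadcast view
assert dense.nbytes == 1024 * 1024 * 4       # 4 MiB actually allocated
```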


codecov Bot commented Apr 22, 2026

Codecov Report

❌ Patch coverage is 94.91525% with 3 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/rewrite_reduce.cpp 94.92% 3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4811      +/-   ##
===========================================
+ Coverage    92.49%   92.50%   +0.01%     
===========================================
  Files          583      583              
  Lines        29561    29621      +60     
===========================================
+ Hits         27342    27399      +57     
- Misses        2219     2222       +3     
Files with missing lines Coverage Δ
src/rewrite_reduce.cpp 98.16% <94.92%> (-1.21%) ⬇️

... and 30 files with indirect coverage changes


Collaborator

causten commented Apr 23, 2026

Test                          Batch   New Rate (f2f08f)   Old Rate (0681c6)     Diff   Status
torchvision-resnet50 64 3,136.33 3,166.53 -0.95%
torchvision-resnet50_fp16 64 6,239.58 6,820.89 -8.52% 🔴
torchvision-densenet121 32 2,593.68 2,422.10 7.08% 🔆
torchvision-densenet121_fp16 32 4,092.38 4,086.15 0.15%
torchvision-inceptionv3 32 1,795.88 1,658.01 8.31% 🔆
torchvision-inceptionv3_fp16 32 2,706.77 2,677.22 1.10%
cadene-inceptionv4 16 448.09 789.08 -43.21% 🔴
cadene-resnext64x4 16 630.09 790.86 -20.33% 🔴
slim-mobilenet 64 6,246.83 8,521.12 -26.69% 🔴
slim-nasnetalarge 64 nan 218.37 nan
slim-resnet50v2 64 1,474.63 3,288.39 -55.16% 🔴
bert-mrpc-onnx 8 134.03 1,108.72 -87.91% 🔴
bert-mrpc-tf 1 nan 456.31 nan
pytorch-examples-wlang-gru 1 311.14 338.72 -8.14% 🔴
pytorch-examples-wlang-lstm 1 547.96 452.66 21.05% 🔆
torchvision-resnet50_1 1 nan 748.69 nan
cadene-dpn92_1 1 380.01 403.71 -5.87% 🔴
cadene-resnext101_1 1 nan 323.10 nan
onnx-taau-downsample 1 256.08 391.86 -34.65% 🔴
dlrm-criteoterabyte 1 14.46 31.56 -54.19% 🔴
dlrm-criteoterabyte_fp16 1 23.00 50.03 -54.03% 🔴
agentmodel 1 5,733.52 10,347.89 -44.59% 🔴
unet_fp16 2 15.68 55.66 -71.82% 🔴
resnet50v1_fp16 1 nan 994.78 nan
resnet50v1_int8 1 88.21 943.47 -90.65% 🔴
bert_base_cased_fp16 64 957.12 1,086.36 -11.90% 🔴
bert_large_uncased_fp16 32 93.98 338.39 -72.23% 🔴
bert_large_fp16 1 19.76 198.67 -90.06% 🔴
distilgpt2_fp16 16 776.85 2,066.44 -62.41% 🔴
yolov5s 1 387.28 549.43 -29.51% 🔴
tinyllama 1 14.23 43.80 -67.50% 🔴
vicuna-fastchat 1 9.51 42.63 -77.69% 🔴
whisper-tiny-encoder 1 91.65 406.13 -77.43% 🔴
whisper-tiny-decoder 1 94.67 395.21 -76.04% 🔴
llama2_7b 1 4.87 19.10 -74.50% 🔴
qwen1.5-7b 1 7.40 23.42 -68.41% 🔴
phi3-3.8b 1 26.69 26.63 0.25%
llama3-8b 1 20.71 21.68 -4.45%
whisper-large-encoder 1 5.62 10.12 -44.47% 🔴
whisper-large-decoder 1 18.17 101.03 -82.01% 🔴
mistral-7b 1 23.76 23.64 0.51%
FLUX.1-schnell 1 736.07 737.25 -0.16%

Regressions detected 🔴

Collaborator

causten commented Apr 23, 2026

Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf ERROR - check error output
traceback
2026-04-22 20:08:26.250759 [ERROR] [/src/AMDMIGraphX/src/pass_manager.cpp:185] Error rewrite_reduce: /src/AMDMIGraphX/src/include/migraphx/check_shapes.hpp:220: same_dims: mul: Dimensions do not match
2026-04-22 20:08:26.250829 [ERROR] [/src/AMDMIGraphX/src/pass_manager.cpp:195] Dump: "/tmp/migraphx/rewrite_reduce64510854628877.mxr"
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 360, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 233, in main
model.compile(
RuntimeError: /src/AMDMIGraphX/src/include/migraphx/check_shapes.hpp:220: same_dims: mul: Dimensions do not match
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 ERROR - check error output
traceback
2026-04-22 20:18:37.159212 [ERROR] [/src/AMDMIGraphX/src/pass_manager.cpp:185] Error simplify_reshapes: /src/AMDMIGraphX/src/include/migraphx/op/multibroadcast.hpp:87: compute_shape: MULTIBROADCAST: input shape {1, 1, 2050048} cannot be broadcasted to {1, 1001, 2048}!
2026-04-22 20:18:37.159293 [ERROR] [/src/AMDMIGraphX/src/pass_manager.cpp:195] Dump: "/tmp/migraphx/simplify_reshapes65121763086641.mxr"
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 360, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 233, in main
model.compile(
RuntimeError: /src/AMDMIGraphX/src/include/migraphx/op/multibroadcast.hpp:87: compute_shape: MULTIBROADCAST: input shape {1, 1, 2050048} cannot be broadcasted to {1, 1001, 2048}!
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder ERROR - check error output
traceback
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 360, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 211, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/include/migraphx/op/convolution.hpp:103: normalize_compute_shape: CONVOLUTION: mismatched channel numbers
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance
