Conversation
Here is the perf comparison on navi48 against mlir:
Pull request overview
This PR extends the `rewrite_reduce` optimization pass to rewrite "skinny" `dot` ops (small M dimension) into an equivalent `mul` + `reduce_sum` formulation, while detecting and exempting attention-related dots (QK^T and softmaxV) to preserve attention fusion opportunities.
Changes:
- Add attention-pattern detection (via decomposed softmax matching) to identify dots that should not be rewritten.
- Rewrite `dot` ops with M (rows) ≤ 2 into `unsqueeze` + `transpose` + `mul` + `reduce_sum` + `squeeze`.
- Expand `test/rewrite_reduce.cpp` with new coverage for skinny-dot rewrites and attention/non-attention cases.
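For intuition, the rewritten chain computes the same values as the original dot: each output row is the elementwise product of a broadcast row of A against B, summed over the K axis. A minimal plain-Python sketch of that equivalence (hypothetical helper names, not MIGraphX code):

```python
def dot(a, b):
    # Reference (M x K) @ (K x N) matrix multiply.
    return [[sum(a[m][k] * b[k][n] for k in range(len(b)))
             for n in range(len(b[0]))] for m in range(len(a))]

def skinny_dot_rewrite(a, b):
    # Mimics the unsqueeze + broadcast + mul + reduce_sum + squeeze chain
    # applied per row of the skinny (M <= 2) matrix a.
    M, K, N = len(a), len(b), len(b[0])
    out = []
    for m in range(M):
        # Broadcast row a[m] (length K) against b (K x N): elementwise mul.
        prod = [[a[m][k] * b[k][n] for n in range(N)] for k in range(K)]
        # reduce_sum over the K axis, then "squeeze" to a length-N row.
        out.append([sum(prod[k][n] for k in range(K)) for n in range(N)])
    return out

a = [[1, 2, 3], [4, 5, 6]]       # M=2, K=3 (skinny)
b = [[7, 8], [9, 10], [11, 12]]  # K=3, N=2
assert skinny_dot_rewrite(a, b) == dot(a, b)
print(skinny_dot_rewrite(a, b))  # -> [[58, 64], [139, 154]]
```

Because M is tiny, the reduction form avoids launching a full GEMM for what is essentially a couple of dot products per output column.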
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `src/rewrite_reduce.cpp` | Adds attention dot collection + a new `find_dot` matcher that rewrites small-M dots into mul/reduce, skipping attention dots. |
| `test/rewrite_reduce.cpp` | Adds/updates unit tests to validate skinny-dot rewrites and to ensure attention patterns are not rewritten. |
```cpp
// If the b matrix is const foldable then make sure its a transposed layout unless its
// broadcasting
if(b_mat->can_eval() and not b_shape.transposed())
{
    b_mat =
        m.insert_instruction(ins, make_op("layout", {{"permutation", permutation}}), b_mat);
}
```
The layout insertion for const-foldable `b_mat` is guarded only by `!b_shape.transposed()`, but the preceding comment says it should be skipped when `b_mat` is broadcasting. If `b_shape.broadcasted()` is true (e.g., a multibroadcasted literal), inserting `layout` can materialize the broadcasted tensor and significantly increase compile-time/memory usage. Consider adding `and not b_shape.broadcasted()` (or an equivalent check) to the condition so broadcasted constants keep their cheap broadcasted representation.
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff            @@
##           develop    #4811    +/-  ##
===========================================
+ Coverage    92.49%    92.50%   +0.01%
===========================================
  Files          583       583
  Lines        29561     29621      +60
===========================================
+ Hits         27342     27399      +57
- Misses        2219      2222       +3
===========================================
```
Regressions detected 🔴
Motivation
This algorithm is usually much faster than MFMAs for these skinny shapes.
Technical Details
Since decode attention also uses skinny GEMMs, the pass skips the GEMMs that are part of attention.
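During decode attention, Q typically has a single row, so QK^T is itself a skinny dot; without the exemption, the pass would break it apart and lose the attention fusion. The screening condition can be thought of as a simple shape test. A hypothetical sketch (illustrative names only, not MIGraphX's actual implementation):

```python
def is_skinny_dot(a_shape, b_shape, max_m=2):
    # a_shape: (..., M, K), b_shape: (..., K, N).
    # A dot is "skinny" when the row count M of the first operand is tiny
    # and the inner dimensions agree.
    return a_shape[-2] <= max_m and a_shape[-1] == b_shape[-2]

# Decode-attention-style QK^T: M=1, so it qualifies as skinny and
# must be exempted by the attention-pattern detection.
assert is_skinny_dot((1, 64), (64, 128))
# A regular GEMM with many rows is left to the normal dot path.
assert not is_skinny_dot((256, 64), (64, 128))
```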
Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`.