Rewrite mul reduce to use fdot2 instructions #4787
Conversation
Pull request overview
This PR adds a GPU-side optimization to fuse reduce_sum(pointwise(mul(x,y))) into a specialized reduction path that can leverage vector dot-product instructions (e.g., fdot2) during codegen, and adds kernel-level tests for the new vector dot helper.
Changes:
- Add `migraphx::vec_dot` (generic + HW-specialized overloads) to support dot-product style reads in GPU reductions.
- Introduce `gpu::mul_reduce_sum` and a `prepare_reduce` rewrite to replace `reduce_sum(pointwise(mul))` with the specialized op.
- Extend GPU reduce codegen to emit `vec_dot` as the read function for `gpu::mul_reduce_sum`, plus add kernel tests for `vec.hpp`.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `test/gpu/kernels/vec.cpp` | Adds kernel-level tests covering vec utilities and new `vec_dot` behavior. |
| `src/targets/gpu/prepare_reduce.cpp` | Adds the `gpu::mul_reduce_sum` op and a rewrite pass to detect and rewrite `mul` + `reduce_sum` patterns. |
| `src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp` | Adds `vec_dot` (generic + half/bf16/int8 HW overloads). |
| `src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp` | Tightens the `MIGRAPHX_REQUIRES` macro by parenthesizing the condition. |
| `src/targets/gpu/compile_gen.cpp` | Updates reduce codegen to handle `gpu::mul_reduce_sum` by using `vec_dot` as the read. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
```cpp
// rewrite argmin/argmax to handle tuples
rewrite_arg_reduce(m);
rewrite_mul_reduce_sum(m);
fuse_reductions(m);
```
This PR introduces a new `rewrite_mul_reduce_sum` transformation that changes the reduce kernel IR (introducing `gpu::mul_reduce_sum`), but there are no corresponding unit tests validating when the rewrite should and should not apply. There is already test coverage for `gpu::prepare_reduce` in `test/gpu/prepare_reduce.cpp`; adding cases for the `mul` + `reduce_sum` pattern (including a negative case for dynamic/runtime axes) would help prevent regressions.
```cpp
if(ins->name() != "reduce_sum")
    continue;
auto pw = ins->inputs().front();
if(pw->name() != "pointwise")
    continue;
if(not is_only_mul(*pw->module_inputs().front()))
    continue;
auto axes = ins->get_operator().to_value()["axes"].to_vector<std::int64_t>();
auto mrs  = m.insert_instruction(ins, mul_reduce_sum{std::move(axes)}, pw->inputs());
m.replace_instruction(ins, mrs);
```
`rewrite_mul_reduce_sum` rewrites every `reduce_sum` whose first input is a `pointwise(mul)` without checking whether the reduction axes are provided as a runtime input (i.e., the `axes` attribute is empty and `reduce_sum` has a second input). In that dynamic-axes case the rewrite drops the axes input entirely (since it always uses `pw->inputs()`), changing semantics and potentially producing an incorrect output shape.
Consider guarding the rewrite so it only triggers when `reduce_sum` has static axes (e.g., `ins->inputs().size() == 1` and/or the `axes` attribute is non-empty), or alternatively preserve the axes input and implement `mul_reduce_sum` to support runtime axes as well.
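One possible shape for such a guard, inside the rewrite loop shown above, is the following fragment. This is a hedged sketch against the instruction API visible in the diff, not a tested patch:

```cpp
// Sketch: skip reduce_sum instructions whose axes arrive as a runtime
// input (empty "axes" attribute plus a second input), so the rewrite
// only fires for the static-axes form it can actually represent.
if(ins->inputs().size() != 1)
    continue;
auto axes = ins->get_operator().to_value()["axes"].to_vector<std::int64_t>();
if(axes.empty())
    continue;
```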
Motivation
Technical Details
Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`.