From 70cdeec952a56c6ce8d67ea9aeaeaab2ce7d7338 Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Wed, 25 Feb 2026 07:19:59 -0800 Subject: [PATCH] Improvements to inference using int8 compressed KVs Multiplication is done using int16*int16 multiplication instructions to avoid expensive conversion to f32/bf16 PiperOrigin-RevId: 875150774 --- compression/compress-inl.h | 14 +- gemma/flash_attention.cc | 317 ++++++++++++++++++++++++++++++++-- gemma/flash_attention_test.cc | 182 ++++++++++++++++++- ops/ops-inl.h | 171 ++++++++++++++++++ 4 files changed, 667 insertions(+), 17 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index a6aa5e36..67720810 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -511,6 +511,17 @@ struct CompressTraits { raw1 = hn::ConvertTo(df, vec_i32_1); } + template + static HWY_INLINE void Load2(DI16 di16, + const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition di8; + const auto packed_vec = hn::LoadU(di8, packed.ptr + packed_ofs); + raw0 = hn::PromoteLowerTo(di16, packed_vec); + raw1 = hn::PromoteUpperTo(di16, packed_vec); + } + template static HWY_INLINE void Load2(DBF dbf, const PackedSpan& packed, const size_t packed_ofs, hn::Vec& raw0, @@ -745,7 +756,8 @@ HWY_INLINE void VerifyRawAndPackedForDecompress() { using TRaw = hn::TFromD; // We can decompress any Packed to f32 or BF16, or f32 to f64. 
static_assert(hwy::IsSameEither() || - (IsF32() && hwy::IsSame())); + (IsF32() && hwy::IsSame()) || + (IsInt8() && hwy::IsSame())); } } // namespace detail diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 5922d8cc..a1f0c5a9 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -608,7 +608,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1, VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1, float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d, - float* HWY_RESTRICT scales) { + float* HWY_RESTRICT scales, float s_quantization_scale = 1.0f) { using DF4 = hn::CappedTag; const DF4 df4; using VF4 = hn::Vec; @@ -679,8 +679,9 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( const VF4 zero4 = hn::Zero(df4); const VF4 one_over_d = hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf); + const VF4 q_scale = hn::Mul(one_over_d, hn::Set(df4, s_quantization_scale)); HWY_ALIGN float tmp_one_over_d[4]; - hn::Store(one_over_d, df4, tmp_one_over_d); + hn::Store(q_scale, df4, tmp_one_over_d); hn::BlendedStore(old_d_vf, changed_max, df4, old_d); scale = hn::Mul(scale, one_over_d); hn::BlendedStore(scale, changed_max, df4, scales); @@ -713,7 +714,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1, VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1, VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max, - float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales) { + float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, + float s_quantization_scale = 1.0f) { using DF8 = hn::CappedTag; const DF8 df8; using VF8 = hn::Vec; @@ -817,8 +819,9 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( const VF8 zero8 = hn::Zero(df8); const VF8 one_over_d = hn::MaskedDivOr(zero8, non_zero_mask, hn::Set(df8, 
1.0f), old_d_vf); + const VF8 q_scale = hn::Mul(one_over_d, hn::Set(df8, s_quantization_scale)); HWY_ALIGN float tmp_one_over_d[8]; - hn::Store(one_over_d, df8, tmp_one_over_d); + hn::Store(q_scale, df8, tmp_one_over_d); hn::BlendedStore(old_d_vf, changed_max, df8, old_d); scale = hn::Mul(scale, one_over_d); hn::BlendedStore(scale, changed_max, df8, scales); @@ -862,7 +865,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap( VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1, VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, size_t q_group_idx, - size_t kNumQueriesPerGroup) { + size_t kNumQueriesPerGroup, float s_quantization_scale = 1.0f) { constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); [[maybe_unused]] constexpr int kSecondHalfAmountOfQueries = kNumQueries - kFirstHalfAmountOfQueries; @@ -870,25 +873,28 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap( FlashAttentionTileStepAndApplySoftCap4( df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0, x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup, - old_d + (q_group_idx)*kNumQueriesPerGroup, scales); + old_d + (q_group_idx)*kNumQueriesPerGroup, scales, + s_quantization_scale); } else { #if HWY_MAX_BYTES <= 16 FlashAttentionTileStepAndApplySoftCap4<4>( df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0, x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup, - old_d + (q_group_idx)*kNumQueriesPerGroup, scales); + old_d + (q_group_idx)*kNumQueriesPerGroup, scales, + s_quantization_scale); FlashAttentionTileStepAndApplySoftCap4( df, att_cap, one_over_att_cap, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1, x_7_p0, x_7_p1, old_max + (q_group_idx + 1) * kNumQueriesPerGroup, old_d + (q_group_idx + 1) * kNumQueriesPerGroup, - scales + kNumQueriesPerGroup); + scales + kNumQueriesPerGroup, s_quantization_scale); #else 
FlashAttentionTileStepAndApplySoftCap8( df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0, x_2_p1, x_3_p0, x_3_p1, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1, x_7_p0, x_7_p1, old_max + (q_group_idx)*kNumQueriesPerGroup, - old_d + (q_group_idx)*kNumQueriesPerGroup, scales); + old_d + (q_group_idx)*kNumQueriesPerGroup, scales, + s_quantization_scale); #endif } } @@ -998,6 +1004,165 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth( } } +template , typename T> +static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthInt16( + DF df, const int16_t* HWY_RESTRICT q, const int16_t* HWY_RESTRICT q2, + const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0, + VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1, + VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0, + VF& sum5_p1, VF& sum6_p0, VF& sum6_p1, VF& sum7_p0, VF& sum7_p1) { + using DI16 = hn::ScalableTag; + const DI16 di16; + using VI16 = hn::Vec; + [[maybe_unused]] const PackedSpan k_transposed_span = + MakeConstSpan(k_transposed_tile, gcpp::KVCache::kTileSize * qkv_dim); + using DI32 = hn::Repartition; + const DI32 di32; + using VI32 = hn::Vec; + HWY_DASSERT(hn::Lanes(di16) <= gcpp::KVCache::kTileSize); + HWY_DASSERT(kNumQueries <= 8); + HWY_DASSERT(gcpp::KVCache::kTileSize >= hn::Lanes(df) * 2); + + VI32 isum0_p0 = hn::Zero(di32); + VI32 isum0_p1 = hn::Zero(di32); + VI32 isum1_p0 = hn::Zero(di32), isum1_p1 = hn::Zero(di32); + VI32 isum2_p0 = hn::Zero(di32), isum2_p1 = hn::Zero(di32); + VI32 isum3_p0 = hn::Zero(di32), isum3_p1 = hn::Zero(di32); + VI32 isum4_p0 = hn::Zero(di32), isum4_p1 = hn::Zero(di32); + VI32 isum5_p0 = hn::Zero(di32), isum5_p1 = hn::Zero(di32); + VI32 isum6_p0 = hn::Zero(di32), isum6_p1 = hn::Zero(di32); + VI32 isum7_p0 = hn::Zero(di32), isum7_p1 = hn::Zero(di32); + VI32 isum0_odd_p0 = hn::Zero(di32), isum0_odd_p1 = hn::Zero(di32); + VI32 isum1_odd_p0 = hn::Zero(di32), isum1_odd_p1 = hn::Zero(di32); + VI32 
isum2_odd_p0 = hn::Zero(di32), isum2_odd_p1 = hn::Zero(di32); + VI32 isum3_odd_p0 = hn::Zero(di32), isum3_odd_p1 = hn::Zero(di32); + VI32 isum4_odd_p0 = hn::Zero(di32), isum4_odd_p1 = hn::Zero(di32); + VI32 isum5_odd_p0 = hn::Zero(di32), isum5_odd_p1 = hn::Zero(di32); + VI32 isum6_odd_p0 = hn::Zero(di32), isum6_odd_p1 = hn::Zero(di32); + VI32 isum7_odd_p0 = hn::Zero(di32), isum7_odd_p1 = hn::Zero(di32); + + const int32_t* q_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q); + const int32_t* q2_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q2); + constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); + constexpr int kSecondHalfAmountOfQueries = + kNumQueries - kFirstHalfAmountOfQueries; + + const hn::Repartition di8; + const hn::Half di8_half; + const int8_t* k_ptr_base = reinterpret_cast(k_transposed_tile); + for (size_t i = 0; i < qkv_dim / 2; i++) { + auto k_dim0 = + hn::LoadU(di8_half, k_ptr_base + (i * 2) * gcpp::KVCache::kTileSize); + auto k_dim1 = + hn::LoadU(di8_half, k_ptr_base + (i * 2) * gcpp::KVCache::kTileSize + + hn::Lanes(di8_half)); + auto k_vec1 = hn::PromoteTo(di16, k_dim0); + auto k_vec2 = hn::PromoteTo(di16, k_dim1); + + auto q_0_as_int32_vec = + hn::Set(di32, q_int32_ptr[i * kFirstHalfAmountOfQueries]); + VI16 q_0 = hn::BitCast(di16, q_0_as_int32_vec); + isum0_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_0, isum0_p0, + isum0_odd_p0); + isum0_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_0, isum0_p1, + isum0_odd_p1); + if constexpr (kNumQueries >= 2) { + auto q_1_as_int32_vec = + hn::Set(di32, q_int32_ptr[i * kFirstHalfAmountOfQueries + 1]); + VI16 q_1 = hn::BitCast(di16, q_1_as_int32_vec); + isum1_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_1, isum1_p0, + isum1_odd_p0); + isum1_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_1, isum1_p1, + isum1_odd_p1); + } + if constexpr (kNumQueries >= 3) { + auto q_2_as_int32_vec = + hn::Set(di32, q_int32_ptr[i * kFirstHalfAmountOfQueries + 2]); + VI16 q_2 = hn::BitCast(di16, 
q_2_as_int32_vec); + isum2_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_2, isum2_p0, + isum2_odd_p0); + isum2_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_2, isum2_p1, + isum2_odd_p1); + } + if constexpr (kNumQueries >= 4) { + auto q_3_as_int32_vec = + hn::Set(di32, q_int32_ptr[i * kFirstHalfAmountOfQueries + 3]); + VI16 q_3 = hn::BitCast(di16, q_3_as_int32_vec); + isum3_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_3, isum3_p0, + isum3_odd_p0); + isum3_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_3, isum3_p1, + isum3_odd_p1); + } + if constexpr (kNumQueries >= 5) { + auto q_4_as_int32_vec = + hn::Set(di32, q2_int32_ptr[i * kSecondHalfAmountOfQueries + 0]); + VI16 q_4 = hn::BitCast(di16, q_4_as_int32_vec); + isum4_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_4, isum4_p0, + isum4_odd_p0); + isum4_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_4, isum4_p1, + isum4_odd_p1); + } + if constexpr (kNumQueries >= 6) { + auto q_5_as_int32_vec = + hn::Set(di32, q2_int32_ptr[i * kSecondHalfAmountOfQueries + 1]); + VI16 q_5 = hn::BitCast(di16, q_5_as_int32_vec); + isum5_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_5, isum5_p0, + isum5_odd_p0); + isum5_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_5, isum5_p1, + isum5_odd_p1); + } + if constexpr (kNumQueries >= 7) { + auto q_6_as_int32_vec = + hn::Set(di32, q2_int32_ptr[i * kSecondHalfAmountOfQueries + 2]); + VI16 q_6 = hn::BitCast(di16, q_6_as_int32_vec); + isum6_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_6, isum6_p0, + isum6_odd_p0); + isum6_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_6, isum6_p1, + isum6_odd_p1); + } + if constexpr (kNumQueries >= 8) { + auto q_7_as_int32_vec = + hn::Set(di32, q2_int32_ptr[i * kSecondHalfAmountOfQueries + 3]); + VI16 q_7 = hn::BitCast(di16, q_7_as_int32_vec); + isum7_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_7, isum7_p0, + isum7_odd_p0); + isum7_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_7, isum7_p1, + isum7_odd_p1); 
+ } + } + + sum0_p0 = hn::ConvertTo(df, hn::Add(isum0_p0, isum0_odd_p0)); + sum0_p1 = hn::ConvertTo(df, hn::Add(isum0_p1, isum0_odd_p1)); + if constexpr (kNumQueries >= 2) { + sum1_p0 = hn::ConvertTo(df, hn::Add(isum1_p0, isum1_odd_p0)); + sum1_p1 = hn::ConvertTo(df, hn::Add(isum1_p1, isum1_odd_p1)); + } + if constexpr (kNumQueries >= 3) { + sum2_p0 = hn::ConvertTo(df, hn::Add(isum2_p0, isum2_odd_p0)); + sum2_p1 = hn::ConvertTo(df, hn::Add(isum2_p1, isum2_odd_p1)); + } + if constexpr (kNumQueries >= 4) { + sum3_p0 = hn::ConvertTo(df, hn::Add(isum3_p0, isum3_odd_p0)); + sum3_p1 = hn::ConvertTo(df, hn::Add(isum3_p1, isum3_odd_p1)); + } + if constexpr (kNumQueries >= 5) { + sum4_p0 = hn::ConvertTo(df, hn::Add(isum4_p0, isum4_odd_p0)); + sum4_p1 = hn::ConvertTo(df, hn::Add(isum4_p1, isum4_odd_p1)); + } + if constexpr (kNumQueries >= 6) { + sum5_p0 = hn::ConvertTo(df, hn::Add(isum5_p0, isum5_odd_p0)); + sum5_p1 = hn::ConvertTo(df, hn::Add(isum5_p1, isum5_odd_p1)); + } + if constexpr (kNumQueries >= 7) { + sum6_p0 = hn::ConvertTo(df, hn::Add(isum6_p0, isum6_odd_p0)); + sum6_p1 = hn::ConvertTo(df, hn::Add(isum6_p1, isum6_odd_p1)); + } + if constexpr (kNumQueries >= 8) { + sum7_p0 = hn::ConvertTo(df, hn::Add(isum7_p0, isum7_odd_p0)); + sum7_p1 = hn::ConvertTo(df, hn::Add(isum7_p1, isum7_odd_p1)); + } +} + template , typename T> static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16( DF df, const BF16* HWY_RESTRICT q, const BF16* HWY_RESTRICT q2, @@ -1406,6 +1571,63 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( size_t current_kv_start_offset = 0; size_t current_kv_idx = 0; + constexpr bool kQuantizeQ = IsInt8() && IsBF16(); + std::vector> q_int16; + std::vector q_T_int16_groups; + std::vector q_scales; + + if constexpr (kQuantizeQ) { + size_t num_groups = q_T_in_groups_up_to_4.size(); + q_int16.resize(num_groups * kNumQueriesPerGroup * qkv_dim); + q_scales.resize(num_groups * kNumQueriesPerGroup); + q_T_int16_groups.resize(num_groups); + + for 
(size_t g = 0; g < num_groups; ++g) { + size_t queries_in_group = + (g == num_groups - 1 && q_count % kNumQueriesPerGroup != 0) + ? (q_count % kNumQueriesPerGroup) + : kNumQueriesPerGroup; + const Q_T* q_bf16 = q_T_in_groups_up_to_4[g]; + int16_t* q_out = q_int16.data() + g * kNumQueriesPerGroup * qkv_dim; + q_T_int16_groups[g] = q_out; + + for (size_t j = 0; j < queries_in_group; ++j) { + float max_abs = 0.0f; + for (size_t i = 0; i < qkv_dim; i += 2) { + size_t src_idx0 = (i / 2) * (queries_in_group * 2) + j * 2 + 0; + size_t src_idx1 = (i / 2) * (queries_in_group * 2) + j * 2 + 1; + float val0 = std::abs(hwy::ConvertScalarTo(q_bf16[src_idx0])); + float val1 = std::abs(hwy::ConvertScalarTo(q_bf16[src_idx1])); + max_abs = std::max({max_abs, val0, val1}); + } + float scale = max_abs == 0.0f ? 1.0f : 32767.0f / max_abs; + q_scales[g * kNumQueriesPerGroup + j] = 1.0f / scale; + for (size_t i = 0; i < qkv_dim; i += 2) { + size_t src_idx0 = (i / 2) * (queries_in_group * 2) + j * 2 + 0; + size_t src_idx1 = (i / 2) * (queries_in_group * 2) + j * 2 + 1; + // Write to interleaved destination + size_t idx0 = (i / 2) * (queries_in_group * 2) + j * 2 + 0; + size_t idx1 = (i / 2) * (queries_in_group * 2) + j * 2 + 1; + + float val0 = hwy::ConvertScalarTo(q_bf16[src_idx0]) * scale; + int quantized0 = std::round(val0); + quantized0 = std::max(-32768, std::min(32767, quantized0)); + q_out[idx0] = static_cast(quantized0); + + float val1 = hwy::ConvertScalarTo(q_bf16[src_idx1]) * scale; + int quantized1 = std::round(val1); + quantized1 = std::max(-32768, std::min(32767, quantized1)); + q_out[idx1] = static_cast(quantized1); + } + } + } + } + + float s_quantization_scale = 1.0f; + if constexpr (kQuantizeQ) { + s_quantization_scale = 1.0f; + } + auto inner_loop = [&](int q_group_idx) HWY_ATTR { int loop_idx = q_group_idx / (kNumQueriesPerLoop / kNumQueriesPerGroup); if (position + step_size <= min_start_pos_per_group[loop_idx] || @@ -1424,17 +1646,35 @@ HWY_NOINLINE void 
TileFlashAttentionReturnExpSumsAndMaxLogits( const KV_T* v_tile = tile_base + qkv_dim * kTileSize + (pos_in_tile)*qkv_dim; - const Q_T* q_group = q_T_in_groups_up_to_4[q_group_idx]; + const Q_T* q_group = nullptr; const Q_T* q2_group = nullptr; - if (kNumQueries > 4) { - q2_group = q_T_in_groups_up_to_4[q_group_idx + 1]; + const int16_t* q_group_int16 = nullptr; + const int16_t* q2_group_int16 = nullptr; + if constexpr (kQuantizeQ) { + q_group_int16 = q_T_int16_groups[q_group_idx]; + if (kNumQueries > 4) { + q2_group_int16 = q_T_int16_groups[q_group_idx + 1]; + } + } else { + q_group = q_T_in_groups_up_to_4[q_group_idx]; + if (kNumQueries > 4) { + q2_group = q_T_in_groups_up_to_4[q_group_idx + 1]; + } } + if constexpr (IsF32()) { const KV_T* k_transposed_tile = tile_base + pos_in_tile; QDotKTilexUpTo8TransposedKDoubleWidth( df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } else if constexpr (kQuantizeQ) { + const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2; + QDotKTilexUpTo8TransposedKDoubleWidthInt16( + df, q_group_int16, q2_group_int16, k_transposed_tile, qkv_dim, + x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, + x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, + x_7_p_0, x_7_p_1); } else if constexpr (IsBF16()) { const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2; QDotKTilexUpTo8TransposedKDoubleWidthBF16( @@ -1461,6 +1701,50 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); } + if constexpr (kQuantizeQ) { + VF s0 = hn::Set(df, q_scales[q_group_idx * kNumQueriesPerGroup + 0]); + x_0_p_0 = hn::Mul(x_0_p_0, s0); + x_0_p_1 = hn::Mul(x_0_p_1, s0); + if constexpr (kNumQueries >= 2) { + VF s1 = hn::Set(df, q_scales[q_group_idx * kNumQueriesPerGroup + 1]); + x_1_p_0 = 
hn::Mul(x_1_p_0, s1); + x_1_p_1 = hn::Mul(x_1_p_1, s1); + } + if constexpr (kNumQueries >= 3) { + VF s2 = hn::Set(df, q_scales[q_group_idx * kNumQueriesPerGroup + 2]); + x_2_p_0 = hn::Mul(x_2_p_0, s2); + x_2_p_1 = hn::Mul(x_2_p_1, s2); + } + if constexpr (kNumQueries >= 4) { + VF s3 = hn::Set(df, q_scales[q_group_idx * kNumQueriesPerGroup + 3]); + x_3_p_0 = hn::Mul(x_3_p_0, s3); + x_3_p_1 = hn::Mul(x_3_p_1, s3); + } + if constexpr (kNumQueries >= 5) { + VF s4 = + hn::Set(df, q_scales[(q_group_idx + 1) * kNumQueriesPerGroup + 0]); + x_4_p_0 = hn::Mul(x_4_p_0, s4); + x_4_p_1 = hn::Mul(x_4_p_1, s4); + } + if constexpr (kNumQueries >= 6) { + VF s5 = + hn::Set(df, q_scales[(q_group_idx + 1) * kNumQueriesPerGroup + 1]); + x_5_p_0 = hn::Mul(x_5_p_0, s5); + x_5_p_1 = hn::Mul(x_5_p_1, s5); + } + if constexpr (kNumQueries >= 7) { + VF s6 = + hn::Set(df, q_scales[(q_group_idx + 1) * kNumQueriesPerGroup + 2]); + x_6_p_0 = hn::Mul(x_6_p_0, s6); + x_6_p_1 = hn::Mul(x_6_p_1, s6); + } + if constexpr (kNumQueries >= 8) { + VF s7 = + hn::Set(df, q_scales[(q_group_idx + 1) * kNumQueriesPerGroup + 3]); + x_7_p_0 = hn::Mul(x_7_p_0, s7); + x_7_p_1 = hn::Mul(x_7_p_1, s7); + } + } constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); constexpr int kSecondHalfAmountOfQueries = @@ -1493,7 +1777,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( df, 0.0f, 1.0f, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx, - kNumQueriesPerGroup); + kNumQueriesPerGroup, s_quantization_scale); if constexpr (kUseMicroScaling) { const BF16* microscaling_scales_v = reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + @@ -1508,7 +1792,12 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, 
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]); - } else if constexpr (IsBF16()) { + } else if (kQuantizeQ) { + MulByConstAndAddTileUpTo8_BF16_Int16( + df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, + x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, + x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]); + } else { MulByConstAndAddTileUpTo8_BF16( df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 64f3c0d2..745c52da 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -486,8 +486,11 @@ void TestTiledFlashAttentionBF16() { EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f) << "i=" << i; EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; - for (int j = 0; j < qkv_dim; ++j) { - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f); + for (int j = 0; j < std::min(8, (int)qkv_dim); ++j) { + if (i == 0) { + std::cerr << "att_out[0][" << j << "]=" << att_out.Row(i)[j] << " gold=" << att_out_gold[i * qkv_dim + j] << "\n"; + } + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 0.3f); } } } @@ -625,6 +628,180 @@ void TestTiledFlashAttentionInt8() { } } +void TestTiledFlashAttentionInt8BF16() { + int qkv_dim = 64; + int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by + // tiles size to test the padding logic. 
+ int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + float att_cap = 10.0f; + int num_queries = 8; + int num_queries_per_timestep = 4; + int num_tokens = num_queries / num_queries_per_timestep; + int kv_seq_end = + kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + + int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; + int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize + + 2 * sizeof(BF16) * gcpp::KVCache::kTileSize; + + MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes), + ctx.allocator, MatPadding::kPacked); + + // fill in kvs with predictable, synthetic data matching BF16 paired layout + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + int8_t* tile_ptr = kv.Row(tile_idx); + BF16* scales_ptr = reinterpret_cast( + tile_ptr + 2 * qkv_dim * gcpp::KVCache::kTileSize); + + // K values (interleaved by in_tile_offset pairs across qkv_dim) + for (int in_tile_offset = 0; in_tile_offset < gcpp::KVCache::kTileSize; + ++in_tile_offset) { + int i = tile_idx * gcpp::KVCache::kTileSize + in_tile_offset; + + std::vector k_vals(qkv_dim); + float max_abs_k = 0.0f; + for (int j = 0; j < qkv_dim; ++j) { + k_vals[j] = 0.01f * (i + 1) / (j + 1); + max_abs_k = std::max(max_abs_k, std::abs(k_vals[j])); + } + + // Quantize K + float scale_k = max_abs_k / 127.0f; + if (scale_k == 0.0f) scale_k = 1.0f; + scales_ptr[in_tile_offset] = hwy::ConvertScalarTo(scale_k); + + for (int j = 0; j < qkv_dim; j += 2) { + int val0 = std::round(k_vals[j] / scale_k); + val0 = std::max(-127, std::min(127, val0)); + int val1 = std::round(k_vals[j + 1] / scale_k); + val1 = std::max(-127, std::min(127, val1)); + + size_t offset0 = j * gcpp::KVCache::kTileSize + in_tile_offset * 2; + size_t offset1 = j * gcpp::KVCache::kTileSize + in_tile_offset * 2 + 1; + tile_ptr[offset0] = static_cast(val0); + tile_ptr[offset1] = static_cast(val1); + } + } + + size_t 
v_offset = qkv_dim * gcpp::KVCache::kTileSize; + + // V values (paired by sequence length elements - i.e., in_tile_offset & + // in_tile_offset + 1) + for (int in_tile_offset = 0; in_tile_offset < gcpp::KVCache::kTileSize; + in_tile_offset += 2) { + int i1 = tile_idx * gcpp::KVCache::kTileSize + in_tile_offset; + int i2 = i1 + 1; + + std::vector v_vals1(qkv_dim); + std::vector v_vals2(qkv_dim); + float max_abs_v1 = 0.0f; + float max_abs_v2 = 0.0f; + + for (int j = 0; j < qkv_dim; ++j) { + v_vals1[j] = 0.02f * (i1 + 1) / (j + 1); + v_vals2[j] = 0.02f * (i2 + 1) / (j + 1); + max_abs_v1 = std::max(max_abs_v1, std::abs(v_vals1[j])); + max_abs_v2 = std::max(max_abs_v2, std::abs(v_vals2[j])); + } + + // Quantize V + float scale_v1 = max_abs_v1 / 127.0f; + if (scale_v1 == 0.0f) scale_v1 = 1.0f; + scales_ptr[gcpp::KVCache::kTileSize + in_tile_offset] = + hwy::ConvertScalarTo(scale_v1); + + float scale_v2 = max_abs_v2 / 127.0f; + if (scale_v2 == 0.0f) scale_v2 = 1.0f; + scales_ptr[gcpp::KVCache::kTileSize + in_tile_offset + 1] = + hwy::ConvertScalarTo(scale_v2); + + for (int j = 0; j < qkv_dim; ++j) { + int val1 = std::round(v_vals1[j] / scale_v1); + val1 = std::max(-127, std::min(127, val1)); + int val2 = std::round(v_vals2[j] / scale_v2); + val2 = std::max(-127, std::min(127, val2)); + + tile_ptr[v_offset + in_tile_offset * qkv_dim + j * 2] = + static_cast(val1); + tile_ptr[v_offset + in_tile_offset * qkv_dim + j * 2 + 1] = + static_cast(val2); + } + } + } + + std::vector q_bf16(num_queries_per_timestep * qkv_dim); + std::vector q_bf16_2(num_queries_per_timestep * qkv_dim); + + // fill in qs with predictable, synthetic data matching BF16 layout (paired + // adjacently) + for (int i = 0; i < num_queries_per_timestep; ++i) { + for (int j = 0; j < qkv_dim; j += 2) { + q_bf16[j * num_queries_per_timestep + i * 2] = + hwy::ConvertScalarTo(0.01f * (i + 1) / (j + 1)); + q_bf16[j * num_queries_per_timestep + i * 2 + 1] = + hwy::ConvertScalarTo(0.01f * (i + 1) / (j + 2)); + + 
q_bf16_2[j * num_queries_per_timestep + i * 2] = + hwy::ConvertScalarTo( + 0.01f * (i + num_queries_per_timestep + 1) / (j + 1)); + q_bf16_2[j * num_queries_per_timestep + i * 2 + 1] = + hwy::ConvertScalarTo( + 0.01f * (i + num_queries_per_timestep + 1) / (j + 2)); + } + } + const BF16* q_T[2] = {q_bf16.data(), q_bf16_2.data()}; + + MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), + ctx.allocator, MatPadding::kPacked); + using DF = hn::ScalableTag; + const DF df; + HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); + size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); + std::vector exp_denominator_sums(num_queries_rounded_to_laness); + std::vector max_logits(num_queries_rounded_to_laness); + for (size_t i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out.Row(i), + qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); + exp_denominator_sums[i] = 0.0f; + max_logits[i] = -std::numeric_limits::max() / 2.0f; + } + std::vector> start_pos_per_query; + std::vector> last_pos_per_query; + start_pos_per_query.reserve(num_queries); + last_pos_per_query.reserve(num_queries); + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + ssize_t query_last_pos = kv_seq_end + token_idx; + ssize_t query_start_pos = + std::max(query_last_pos - 100000 + 1, static_cast(0)); + for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; + ++q_head_idx) { + start_pos_per_query.push_back(query_start_pos); + last_pos_per_query.push_back(query_last_pos); + } + } + + hwy::Span kvs(&kv, 1); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + kvs, num_queries, hwy::Span(q_T, 2), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + + PrintMatPtr(att_out); + for (int i = 0; i < num_queries; ++i) { + std::cerr << "exp_d: " << exp_denominator_sums[i] + << " max_logit: " << max_logits[i] << std::endl; + EXPECT_NEAR(exp_denominator_sums[i], 
exp_denominator_sums_gold[i], 2e-1f) + << "i=" << i; + EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-2f) << "i=" << i; + for (int j = 0; j < qkv_dim; ++j) { + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-1f); + } + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp @@ -638,6 +815,7 @@ HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttention); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionBF16); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8BF16); HWY_AFTER_TEST(); } // namespace gcpp diff --git a/ops/ops-inl.h b/ops/ops-inl.h index c678d13f..1a458ef7 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -1078,6 +1078,177 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8( HWY_DASSERT(qkv_dim == i); } + +// Specialized version for BF16 models that uses int16 quantization for V. 
+template , typename VType> +HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16_Int16( + DF df, const float* HWY_RESTRICT scales, const VF& c0_p0, const VF& c0_p1, + const VF& c1_p0, const VF& c1_p1, const VF& c2_p0, const VF& c2_p1, + const VF& c3_p0, const VF& c3_p1, const VF& c4_p0, const VF& c4_p1, + const VF& c5_p0, const VF& c5_p1, const VF& c6_p0, const VF& c6_p1, + const VF& c7_p0, const VF& c7_p1, const VType* HWY_RESTRICT v_tile, + MatPtrT& out, float s_quantization_scale = 2047.0f) { + static_assert(N <= 8); + namespace hn = hwy::HWY_NAMESPACE; + const size_t qkv_dim = out.Cols(); + constexpr size_t kMaxLanes = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + + using DI16 = hn::Repartition; + const DI16 di16; + const auto di16_half = hn::Half(); + using DI32 = hn::Repartition; + const DI32 di32; + using VI16 = hn::Vec; + using VI32 = hn::Vec; + + float q_scales_s[8]; + HWY_ALIGN int16_t cs_i16[N * kMaxLanes * 2]; + // float q_scale_v = 1.0f; + VF inv_v_vec = hn::Set(df, 1.0f); + + if constexpr (sizeof(VType) == 1) { + // OPTIMIZED PATH FOR INT8: + // `S` vectors (c0_p0 etc.) are already multiplied by `V` micro-scales + // upstream in `FlashAttentionTileStepAndApplySoftCap`. We just max & + // quantize S. + auto quantize_s_and_store = [&](int j, const VF& p0, const VF& p1) { + VF max_v = hn::Max(p0, p1); + float max_s = hn::GetLane(hn::MaxOfLanes(df, max_v)); + + float q_scale = max_s > 0.0f ? 32767.0f / max_s : 1.0f; + q_scales_s[j] = max_s > 0.0f ? 
max_s / 32767.0f : 1.0f; + + VF qs0 = hn::Mul(p0, hn::Set(df, q_scale)); + VF qs1 = hn::Mul(p1, hn::Set(df, q_scale)); + + auto i0 = hn::DemoteTo(di16_half, hn::NearestInt(qs0)); + auto i1 = hn::DemoteTo(di16_half, hn::NearestInt(qs1)); + hn::Store(hn::Combine(di16, i1, i0), di16, cs_i16 + j * kMaxLanes * 2); + }; + + quantize_s_and_store(0, c0_p0, c0_p1); + if constexpr (N >= 2) quantize_s_and_store(1, c1_p0, c1_p1); + if constexpr (N >= 3) quantize_s_and_store(2, c2_p0, c2_p1); + if constexpr (N >= 4) quantize_s_and_store(3, c3_p0, c3_p1); + if constexpr (N >= 5) quantize_s_and_store(4, c4_p0, c4_p1); + if constexpr (N >= 6) quantize_s_and_store(5, c5_p0, c5_p1); + if constexpr (N >= 7) quantize_s_and_store(6, c6_p0, c6_p1); + if constexpr (N >= 8) quantize_s_and_store(7, c7_p0, c7_p1); + } else { + // + } + + using DI8 = hn::Repartition; + const hn::Half di8_half; + PackedSpan v_span = MakeConstSpan(v_tile, qkv_dim * 2 * NF); + + size_t i = 0; + HWY_DASSERT(qkv_dim % (NF * 2) == 0); + while (i + 2 * NF <= qkv_dim) { + VI32 acc0_0 = hn::Zero(di32), acc0_1 = hn::Zero(di32); + VI32 acc1_0 = hn::Zero(di32), acc1_1 = hn::Zero(di32); + VI32 acc2_0 = hn::Zero(di32), acc2_1 = hn::Zero(di32); + VI32 acc3_0 = hn::Zero(di32), acc3_1 = hn::Zero(di32); + VI32 acc4_0 = hn::Zero(di32), acc4_1 = hn::Zero(di32); + VI32 acc5_0 = hn::Zero(di32), acc5_1 = hn::Zero(di32); + VI32 acc6_0 = hn::Zero(di32), acc6_1 = hn::Zero(di32); + VI32 acc7_0 = hn::Zero(di32), acc7_1 = hn::Zero(di32); + + VI32 acc0_o_0 = hn::Zero(di32), acc0_o_1 = hn::Zero(di32); + VI32 acc1_o_0 = hn::Zero(di32), acc1_o_1 = hn::Zero(di32); + VI32 acc2_o_0 = hn::Zero(di32), acc2_o_1 = hn::Zero(di32); + VI32 acc3_o_0 = hn::Zero(di32), acc3_o_1 = hn::Zero(di32); + VI32 acc4_o_0 = hn::Zero(di32), acc4_o_1 = hn::Zero(di32); + VI32 acc5_o_0 = hn::Zero(di32), acc5_o_1 = hn::Zero(di32); + VI32 acc6_o_0 = hn::Zero(di32), acc6_o_1 = hn::Zero(di32); + VI32 acc7_o_0 = hn::Zero(di32), acc7_o_1 = hn::Zero(di32); + + for (int lane 
= 0; lane < NF; ++lane) { + VI16 vi_first8, vi_next8; + + if constexpr (sizeof(VType) == 1) { // int8_t (ZERO-COPY SCALING) + const int8_t* v_ptr = reinterpret_cast(v_tile) + + 2 * qkv_dim * lane + i * 2; + + auto v8_t0 = hn::LoadU(di8_half, v_ptr); + auto v16_t0 = hn::PromoteTo(di16, v8_t0); + + auto v8_t1 = hn::LoadU(di8_half, v_ptr + hn::Lanes(di8_half)); + auto v16_t1 = hn::PromoteTo(di16, v8_t1); + + vi_first8 = v16_t0; + vi_next8 = v16_t1; + } else { // BF16 (Fallback logic dynamically tracking V) + // unreachable + } + + auto mul_acc = [&](int j, VI32& a0, VI32& a_o0, VI32& a1, VI32& a_o1) { + // --- FIX 1: Fetch 2*lane and 2*lane+1 --- + int16_t s0 = cs_i16[2 * lane + j * kMaxLanes * 2]; + int16_t s1 = cs_i16[2 * lane + 1 + j * kMaxLanes * 2]; + // ---------------------------------------- + int32_t s01; + hwy::CopySameSize(&s0, reinterpret_cast(&s01)); + hwy::CopySameSize(&s1, reinterpret_cast(&s01) + 1); + VI16 sj = hn::BitCast(di16, hn::Set(di32, s01)); + + a0 = hn::ReorderWidenMulAccumulate(di32, vi_first8, sj, a0, a_o0); + a1 = hn::ReorderWidenMulAccumulate(di32, vi_next8, sj, a1, a_o1); + }; + + mul_acc(0, acc0_0, acc0_o_0, acc0_1, acc0_o_1); + if constexpr (N >= 2) mul_acc(1, acc1_0, acc1_o_0, acc1_1, acc1_o_1); + if constexpr (N >= 3) mul_acc(2, acc2_0, acc2_o_0, acc2_1, acc2_o_1); + if constexpr (N >= 4) mul_acc(3, acc3_0, acc3_o_0, acc3_1, acc3_o_1); + if constexpr (N >= 5) mul_acc(4, acc4_0, acc4_o_0, acc4_1, acc4_o_1); + if constexpr (N >= 6) mul_acc(5, acc5_0, acc5_o_0, acc5_1, acc5_o_1); + if constexpr (N >= 7) mul_acc(6, acc6_0, acc6_o_0, acc6_1, acc6_o_1); + if constexpr (N >= 8) mul_acc(7, acc7_0, acc7_o_0, acc7_1, acc7_o_1); + } + + auto convert_and_add = [&](int j, VI32& a0, VI32& a_o0, VI32& a1, + VI32& a_o1) { + VF f0 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(a0, a_o0)); + VF f1 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(a1, a_o1)); + + VF o0 = hn::Load(df, out.Row(j) + i); + VF o1 = hn::Load(df, out.Row(j) + i + NF); + + // --- 
FIX 2: Correctly decay the old accumulator and add the new --- + VF scale_old = hn::Set(df, scales[j]); + o0 = hn::Mul(o0, scale_old); + o1 = hn::Mul(o1, scale_old); + + VF scale_new = hn::Set(df, q_scales_s[j]); + o0 = hn::MulAdd(f0, scale_new, o0); + o1 = hn::MulAdd(f1, scale_new, o1); + // ---------------------------------------------------------------- + + hn::Store(o0, df, out.Row(j) + i); + hn::Store(o1, df, out.Row(j) + i + NF); + }; + + convert_and_add(0, acc0_0, acc0_o_0, acc0_1, acc0_o_1); + if constexpr (N >= 2) + convert_and_add(1, acc1_0, acc1_o_0, acc1_1, acc1_o_1); + if constexpr (N >= 3) + convert_and_add(2, acc2_0, acc2_o_0, acc2_1, acc2_o_1); + if constexpr (N >= 4) + convert_and_add(3, acc3_0, acc3_o_0, acc3_1, acc3_o_1); + if constexpr (N >= 5) + convert_and_add(4, acc4_0, acc4_o_0, acc4_1, acc4_o_1); + if constexpr (N >= 6) + convert_and_add(5, acc5_0, acc5_o_0, acc5_1, acc5_o_1); + if constexpr (N >= 7) + convert_and_add(6, acc6_0, acc6_o_0, acc6_1, acc6_o_1); + if constexpr (N >= 8) + convert_and_add(7, acc7_0, acc7_o_0, acc7_1, acc7_o_1); + + i += 2 * NF; + } +} + template , typename VType> HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16( DF df, const float* HWY_RESTRICT scales, VF c0_p0, VF c0_p1, VF c1_p0,