Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
}
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m)));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
Expand Down Expand Up @@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
float m = hn::ReduceMax(df, x_max);
m = std::max(m, old_max);
VF m_vec = hn::Set(df, m);
x0 = hn::Exp(df, hn::Sub(x0, m_vec));
x1 = hn::Exp(df, hn::Sub(x1, m_vec));
x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec));
x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec));
float scale = old_d * std::exp(old_max - m);
VF x_sum = hn::Add(x0, x1);
old_d = hn::ReduceSum(df, x_sum) + scale;
Expand Down Expand Up @@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
VF4 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
const VF zero = hn::Zero(df);
Expand Down Expand Up @@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
x_6_sum, x_7_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
VF8 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
const VF zero = hn::Zero(df);
Expand Down
Loading