-
-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathggml-topk-collapse-vsx.h
More file actions
304 lines (260 loc) · 10.2 KB
/
ggml-topk-collapse-vsx.h
File metadata and controls
304 lines (260 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
/*
* ggml-topk-collapse-vsx.h - Top-K Attention Collapse for POWER8
*
* Scott's Vision: "Quantum-like collapse - throwing lossy extras away"
*
* MATHEMATICAL BASIS:
* - Sparse attention is well-studied (Longformer, BigBird, etc.)
* - Top-K attention keeps K highest scores, zeros the rest
* - This IS mathematically valid - approximates full attention
* - vec_perm used for fast partitioning/selection
*
* Key insight: Don't randomly prune. Prune the WEAKEST signals.
* The strongest attention weights dominate anyway (~80-90% of output).
*/
#ifndef GGML_TOPK_COLLAPSE_VSX_H
#define GGML_TOPK_COLLAPSE_VSX_H
#include <altivec.h>
#include <stdint.h>
#include <math.h>
#include <string.h>
/*===========================================================================
* Configuration
*===========================================================================*/
/* Keep top K attention weights, zero the rest */
#ifndef TOPK_ATTENTION_K
#define TOPK_ATTENTION_K 64 /* Keep top 64 per position */
#endif
/* Minimum score to consider (absolute threshold) */
#ifndef TOPK_MIN_SCORE
#define TOPK_MIN_SCORE 0.01f
#endif
/* Enable entropy mixing for tie-breaking */
#ifndef TOPK_ENTROPY_ENABLED
#define TOPK_ENTROPY_ENABLED 1
#endif
/*===========================================================================
* Hardware Entropy
*===========================================================================*/
static inline uint64_t topk_read_timebase(void) {
#if defined(__powerpc64__) || defined(__powerpc__)
uint64_t tb;
__asm__ __volatile__("mftb %0" : "=r"(tb));
return tb;
#else
return 0;
#endif
}
/*===========================================================================
* Fast Partial Sort using vec_perm (Bitonic-like)
*
* Instead of full sort, use vec_perm to quickly identify
* approximate top-K elements. Not exact, but fast.
*===========================================================================*/
/* Compare-swap patterns for vec_perm based sorting network */
static const unsigned char COMPARE_LO_PATTERN[16] __attribute__((aligned(16))) = {
0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23
};
static const unsigned char COMPARE_HI_PATTERN[16] __attribute__((aligned(16))) = {
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
};
/*
* vec_perm_compare_swap: Compare two float4 vectors, output min/max
* This is a building block for sorting networks.
*/
static inline void vec_perm_compare_swap(
vector float a, vector float b,
vector float* out_min, vector float* out_max
) {
/* Compare: which elements of a are less than b? */
vector bool int mask = vec_cmpgt(b, a);
/* Select min and max */
*out_min = vec_sel(b, a, mask); /* min(a,b) */
*out_max = vec_sel(a, b, mask); /* max(a,b) */
}
/*
* Fast approximate top-4 from 8 floats using vec_perm
* Returns the 4 largest values (not necessarily sorted)
*/
static inline vector float vec_perm_top4_of_8(
vector float v0, vector float v1
) {
vector float min_vals, max_vals;
vec_perm_compare_swap(v0, v1, &min_vals, &max_vals);
return max_vals; /* Top 4 are in max_vals */
}
/*===========================================================================
* Top-K Attention Score Selection
*
* Given attention scores, keep only top-K, zero the rest.
* Uses hybrid approach:
* 1. Quick scan to find approximate K-th value (threshold)
* 2. vec_perm to mask below threshold
*===========================================================================*/
/*
* QuickSelect partition - O(n) average case
* Based on Hoare's selection algorithm (same as std::nth_element)
*/
static inline int partition_descending(float* arr, int lo, int hi) {
float pivot = arr[hi];
int i = lo - 1;
for (int j = lo; j < hi; j++) {
if (arr[j] >= pivot) { /* Descending: >= for largest first */
i++;
float tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp;
}
}
float tmp = arr[i + 1]; arr[i + 1] = arr[hi]; arr[hi] = tmp;
return i + 1;
}
/*
* QuickSelect: Find k-th largest in O(n) average case
* Much faster than O(20n) binary search - ng's suggestion
*/
static inline float find_kth_largest(
const float* scores, int n, int k
) {
if (k >= n) return -INFINITY;
if (k <= 0) return INFINITY;
/* Work on a copy to avoid modifying original */
float* arr = (float*)__builtin_alloca(n * sizeof(float));
memcpy(arr, scores, n * sizeof(float));
int lo = 0, hi = n - 1;
int target = k - 1; /* 0-indexed: k-th largest is at index k-1 */
while (lo < hi) {
int pivot_idx = partition_descending(arr, lo, hi);
if (pivot_idx == target) {
return arr[pivot_idx];
} else if (pivot_idx < target) {
lo = pivot_idx + 1;
} else {
hi = pivot_idx - 1;
}
}
return arr[lo];
}
/*
* Apply top-K mask to attention scores using vec_perm
* Scores below threshold are zeroed.
*/
static inline void apply_topk_mask_vsx(
float* scores, int n, float threshold
) {
vector float thresh_vec = vec_splats(threshold);
vector float zero_vec = vec_splats(0.0f);
int i = 0;
for (; i + 3 < n; i += 4) {
vector float v = vec_ld(0, &scores[i]);
/* Mask: keep if >= threshold */
vector bool int mask = vec_cmpge(v, thresh_vec);
v = vec_sel(zero_vec, v, mask);
vec_st(v, 0, &scores[i]);
}
/* Scalar remainder */
for (; i < n; i++) {
if (scores[i] < threshold) scores[i] = 0.0f;
}
}
/*===========================================================================
* Top-K Collapsed Attention
*
* Full attention computation with top-K sparsification:
* 1. Compute Q·K scores (standard)
* 2. Find top-K threshold
* 3. Zero below threshold (vec_perm accelerated)
* 4. Softmax over surviving scores
* 5. Weighted sum of V
*===========================================================================*/
static inline void attention_topk_collapsed(
float* output, /* Output: [seq_len, head_dim] */
const float* Q, /* Query: [seq_len, head_dim] */
const float* K, /* Key: [seq_len, head_dim] */
const float* V, /* Value: [seq_len, head_dim] */
int seq_len,
int head_dim,
int layer_id,
int head_id,
int top_k
) {
/* Temporary storage for attention scores */
float* scores = (float*)__builtin_alloca(seq_len * sizeof(float));
for (int pos = 0; pos < seq_len; pos++) {
const float* q_row = Q + pos * head_dim;
/* Compute Q·K scores */
for (int t = 0; t <= pos; t++) {
const float* k_row = K + t * head_dim;
/* Standard dot product */
vector float sum = vec_splats(0.0f);
for (int d = 0; d + 3 < head_dim; d += 4) {
vector float qv = vec_ld(0, &q_row[d]);
vector float kv = vec_ld(0, &k_row[d]);
sum = vec_madd(qv, kv, sum);
}
/* Horizontal sum */
vector float s1 = vec_add(sum, vec_sld(sum, sum, 8));
vector float s2 = vec_add(s1, vec_sld(s1, s1, 4));
vec_ste(s2, 0, &scores[t]);
}
/* TOP-K COLLAPSE: Keep only strongest signals */
int actual_k = (top_k < pos + 1) ? top_k : (pos + 1);
float threshold = find_kth_largest(scores, pos + 1, actual_k);
apply_topk_mask_vsx(scores, pos + 1, threshold);
/* Softmax over surviving scores */
float max_score = -INFINITY;
for (int t = 0; t <= pos; t++) {
if (scores[t] > max_score) max_score = scores[t];
}
float sum_exp = 0.0f;
for (int t = 0; t <= pos; t++) {
if (scores[t] != 0.0f) { /* Only non-zero */
scores[t] = expf(scores[t] - max_score);
sum_exp += scores[t];
}
}
if (sum_exp > 0.0f) {
for (int t = 0; t <= pos; t++) {
scores[t] /= sum_exp;
}
}
/* Weighted sum of V (sparse - skip zeros) */
float* out_row = output + pos * head_dim;
memset(out_row, 0, head_dim * sizeof(float));
for (int t = 0; t <= pos; t++) {
float weight = scores[t];
if (weight < TOPK_MIN_SCORE) continue; /* Skip negligible */
const float* v_row = V + t * head_dim;
for (int d = 0; d + 3 < head_dim; d += 4) {
vector float v_vec = vec_ld(0, &v_row[d]);
vector float o_vec = vec_ld(0, &out_row[d]);
vector float w_vec = vec_splats(weight);
o_vec = vec_madd(v_vec, w_vec, o_vec);
vec_st(o_vec, 0, &out_row[d]);
}
}
}
}
/*===========================================================================
* Statistics
*===========================================================================*/
typedef struct {
uint64_t total_scores;
uint64_t scores_kept;
uint64_t scores_pruned;
} topk_stats_t;
static topk_stats_t g_topk_stats = {0};
static inline void topk_report_stats(void) {
fprintf(stderr, "\n");
fprintf(stderr, "╔═══════════════════════════════════════════════════════╗\n");
fprintf(stderr, "║ Top-K Attention Collapse Statistics ║\n");
fprintf(stderr, "╠═══════════════════════════════════════════════════════╣\n");
fprintf(stderr, "║ Total scores: %12lu ║\n",
(unsigned long)g_topk_stats.total_scores);
fprintf(stderr, "║ Scores kept: %12lu (%.1f%%) ║\n",
(unsigned long)g_topk_stats.scores_kept,
g_topk_stats.total_scores > 0 ?
100.0 * g_topk_stats.scores_kept / g_topk_stats.total_scores : 0);
fprintf(stderr, "║ Scores pruned: %12lu ║\n",
(unsigned long)g_topk_stats.scores_pruned);
fprintf(stderr, "╚═══════════════════════════════════════════════════════╝\n");
}
#endif /* GGML_TOPK_COLLAPSE_VSX_H */