1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/llm/custom_ops/op_sdpa.h>
10 
11 #include <executorch/kernels/optimized/blas/CPUBlas.h>
12 #include <executorch/kernels/optimized/vec/functional.h>
13 #include <executorch/kernels/optimized/vec/vec.h>
14 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
15 // @lint-ignore CLANGTIDY facebook-unused-include-check
16 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
17 
18 #include <array>
19 #include <vector>
20 
21 #ifdef ET_USE_THREADPOOL
22 #include <executorch/extension/parallel/thread_parallel.h>
23 #include <executorch/extension/threadpool/threadpool.h>
24 #endif
25 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
26 
27 namespace torch {
28 namespace executor {
29 
30 namespace native {
31 
32 namespace util {
33 
34 constexpr size_t kKVDim = 4;
35 
36 template <typename T>
_store(T * dst,::executorch::vec::Vectorized<T> src)37 inline void _store(T* dst, ::executorch::vec::Vectorized<T> src) {
38   src.store(dst);
39 }
40 
41 /*
42 inline void _store(::Half* dst, at::vec::Vectorized<float> src) {
43   //fp16_ieee_to_fp32_value
44   auto res = at::vec::convert_float_half(src, src);
45   res.store(dst, at::vec::Vectorized<float>::size());
46 }
47 */
48 
// Recursion terminator: with no (index, extent) pairs left, the remaining
// linear offset is handed back unchanged.
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

// Decomposes a linear `offset` into a multi-dimensional index.
//
// The arguments after `offset` are (index, extent) pairs ordered from the
// outermost dimension to the innermost one. The trailing pairs are resolved
// first, so the last pair receives `offset % extent` and every earlier pair
// works on the quotient left over by the pairs after it.
//
// Returns whatever is left of `offset` after dividing by all extents.
template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  // Let the inner dimensions consume their share of the offset first.
  const T remaining = data_index_init(offset, std::forward<Args>(args)...);
  x = remaining % X;
  return remaining / X;
}
60 
// Recursion terminator: stepping "nothing" always produces a carry, which is
// what makes the innermost real dimension advance.
inline bool data_index_step() {
  return true;
}

// Advances a multi-dimensional index by one position, odometer style.
//
// The arguments are (index, extent) pairs ordered from the outermost
// dimension to the innermost one. An index is incremented only when every
// dimension after it wrapped around; it is reset to 0 once it reaches its
// extent.
//
// Returns true when this dimension ended up at 0 after being advanced, i.e.
// the carry has to propagate to the dimension in front of it.
template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  const bool carry_in = data_index_step(std::forward<Args>(args)...);
  if (!carry_in) {
    // Some inner dimension absorbed the step; nothing to do here.
    return false;
  }
  const auto next = x + 1;
  if (next == X) {
    x = 0;
  } else {
    x = next;
  }
  return x == 0;
}
73 
calculate_scale(const Tensor & query,optional<double> scale)74 inline double calculate_scale(const Tensor& query, optional<double> scale) {
75   const auto softmax_scale =
76       scale.has_value() ? scale.value() : 1.0 / std::sqrt(query.size(3));
77   return softmax_scale;
78 }
79 
80 } // namespace util
81 namespace vec = ::executorch::vec;
82 using Tensor = exec_aten::Tensor;
83 
84 namespace {
85 
86 // 1) out = exp(a - val)
87 // 2) val = sum(out)
88 template <typename T1, typename T2>
89 inline void
_exp_reduce_sum_fusion_kernel(T1 * a,const int & size,T2 * out,T1 & val)90 _exp_reduce_sum_fusion_kernel(T1* a, const int& size, T2* out, T1& val) {
91   auto vec_size = vec::Vectorized<T1>::size();
92   auto vec_max = vec::Vectorized<T1>(val);
93   T1 tmp_sum = 0;
94   auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
95   for (int i = 0; i < vec_size * (size / vec_size); i += vec_size) {
96     auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
97     auto tmp1 = tmp0 - vec_max;
98     // Replace with exp_u20 later
99     // auto tmp2 = tmp1.exp_u20();
100     auto tmp2 = tmp1.exp();
101     vec_tmp_sum += tmp2;
102     util::_store(out + i, tmp2);
103   }
104   tmp_sum = vec::vec_reduce_all<T1>(
105       [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) { return x + y; },
106       vec_tmp_sum);
107   for (int i = vec_size * (size / vec_size); i < size; i++) {
108     auto tmp0 = a[i];
109     auto tmp1 = tmp0 - val;
110     auto tmp2 = exp(tmp1);
111     tmp_sum += tmp2;
112     out[i] = tmp2;
113   }
114   val = tmp_sum;
115 }
116 
117 // 1) out = a * scale
118 // 2) max = max(out)
119 template <typename scalar_t>
_mul_reduce_max_fusion_kernel(const scalar_t * a,const scalar_t & scale,const int & size,scalar_t * out,scalar_t & max)120 inline void _mul_reduce_max_fusion_kernel(
121     const scalar_t* a,
122     const scalar_t& scale,
123     const int& size,
124     scalar_t* out,
125     scalar_t& max) {
126   auto vec_size = vec::Vectorized<scalar_t>::size();
127   auto vec_scale = vec::Vectorized<scalar_t>(scale);
128   scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
129   auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
130   for (int i = 0; i < vec_size * (size / vec_size); i += vec_size) {
131     auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
132     auto tmp1 = tmp0 * vec_scale;
133     vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
134     util::_store(out + i, tmp1);
135   }
136   for (int i = vec_size * (size / vec_size); i < size; i++) {
137     auto tmp0 = a[i];
138     auto tmp1 = tmp0 * scale;
139     tmp_max = std::max(tmp_max, tmp1);
140     out[i] = tmp1;
141   }
142   max = std::max(
143       tmp_max,
144       vec::vec_reduce_all<scalar_t>(
145           [](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
146             return vec::maximum(x, y);
147           },
148           vec_tmp_max));
149 }
150 
151 template <typename scalar_t>
conditional_data_ptr(scalar_t * ptr,scalar_t * ptr2)152 static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) {
153   ET_CHECK(ptr2 == nullptr);
154   return ptr;
155 }
156 
157 template <
158     typename scalar_t,
159     typename std::enable_if_t<
160         ::executorch::runtime::is_reduced_floating_point_v<scalar_t>,
161         int> = 0>
conditional_data_ptr(float * ptr,scalar_t * ptr2)162 static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) {
163   (void)ptr;
164   return ptr2;
165 }
166 
167 template <typename scalar_t>
fill_stub(scalar_t * data,scalar_t val,int64_t size)168 inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
169   using Vec = vec::Vectorized<scalar_t>;
170   Vec data_vec = Vec(val);
171   int64_t d = 0;
172   for (; d < size - (size % Vec::size()); d += Vec::size()) {
173     data_vec.store(data + d);
174   }
175   for (; d < size; d++) {
176     data[d] = val;
177   }
178 }
179 
180 /*
181 Note on start_pos as a parameter:
182 What is start_pos?
183 - start_pos is the position of the first element of the current query. That is,
184 in LLMs during generate phase, when we generate one token a time, the query
185 will correspond to monotonically increasing start_pos. e.g. the first token
186 is at start_pos = 0, the second token is at start_pos = 1, and so on.
187 If we do prefill with prompt which has 4 tokens, then during the decode phase,
188 start_pos = 4.
189 
Why is start_pos needed?
191 - Attention should not need to know start_pos. However, to apply causal mask,
192 we can use is_causal parameter (aten API for SDPA is thinking of getting rid
193 of it). However, the current handling of is_causal assumes that start_pos = 0.
194 Meaning when we have a query during decode at start_pos = 4, it will be a
195 single vector of [1, head_dim] for a given head. Key param, derived from kv
196 cache, will be of size [start_pos + 1, head_dim]. That is all the past tokens
197 contained in kv cache. If we apply causal mask naively, then the query is
198 assumed to be at start_pos = 0, and thus all the future tokens (indices 1...4)
199 in q @ k.T = [1, start_pos], will be masked out for attention calculation.
200 However, that is not right. Since query is at pos 4, that is 4th token, it
201 should attend to all previous tokens in the cache. That is 0...start_pos. Thus
202 we need to pass start_pos.
203 
Can we use attn_mask?
- Yes. An attention mask can be used for the same purpose. However, at the
moment the attention mask for our llama model is a boolean mask, which requires
conversion to -inf for the masked-out section. This requires a change that may
have perf implications, although we haven't really validated this. It is
possible that there is no perf implication. If the mask were a float mask,
things would work out-of-the-box. In our llama definition each layer is storing
the mask, and if we move to a float mask, that can increase the memory
footprint, which is right now optimized away since sdpa_with_kv_cache does not
use attn_mask.
213 
214 TODO: Just handle conversion of bool mask to float
215 */
216 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
cpu_flash_attention(Tensor & output,const Tensor & query,const Tensor & key,const Tensor & value,double dropout_p,bool is_causal,const optional<Tensor> & attn_mask,const optional<double> & scale,bool is_seq_at_dim_1=false,const int64_t start_pos=0)217 void cpu_flash_attention(
218     Tensor& output,
219     const Tensor& query,
220     const Tensor& key,
221     const Tensor& value,
222     double dropout_p,
223     bool is_causal,
224     const optional<Tensor>& attn_mask,
225     const optional<double>& scale,
226     bool is_seq_at_dim_1 = false,
227     const int64_t start_pos = 0) {
228   (void)dropout_p;
229   // Query (Batch x Num_heads  x Q_seq_len  x Dim_per_head)
230   // Key   (Batch x Num_heads  x KV_seq_len x Dim_per_head)
231   // Value (Batch x Num_heads  x KV_seq_len x Dim_per_head)
232 
233   /*
234   //    -> (Batch x Q_seq_len  x Num_heads  x Dim_per_head)
235   at::Tensor query = q.transpose(1, 2);
236   //    -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
237   at::Tensor key = k.transpose(1, 2);
238   //    -> (Batch x KV_seq_len x Num_heads  x Dim_per_head)
239   at::Tensor value = v.transpose(1, 2);
240   */
241 
242   // Without this we have out-of-bounds writes for
243   // causal masking
244   static_assert(
245       kv_split_size > q_split_size,
246       "KV_split_size must be greater than q_split_size");
247 
248   constexpr bool is_reduced_type =
249       ::executorch::runtime::is_reduced_floating_point_v<scalar_t>;
250 
251   ET_CHECK_MSG(
252       !is_reduced_type, "FlashAttention does not support reduced types.");
253   // Figure out mixed precision a little later
254   // using accum_t = at::opmath_type<scalar_t>;
255   using accum_t = scalar_t;
256   using Vec = vec::Vectorized<accum_t>;
257   accum_t scaling_factor =
258       static_cast<accum_t>(util::calculate_scale(query, scale));
259 
260   int64_t batchSize = query.size(0);
261   int64_t num_head = query.size(1);
262   int64_t qSize = query.size(2);
263   int64_t headSize = query.size(3);
264   int64_t kvSize = value.size(2);
265   int64_t num_heads_kv = key.size(1);
266 
267   if (is_seq_at_dim_1) {
268     num_head = query.size(2);
269     num_heads_kv = key.size(2);
270     qSize = query.size(1);
271     kvSize = value.size(1);
272   }
273 
274   ET_CHECK_MSG(
275       num_heads_kv <= num_head,
276       "FlashAttention does not support num kv heads > num query heads.Got num query heads=%" PRId64
277       " num key heads:%" PRId64,
278       num_head,
279       num_heads_kv);
280   ET_CHECK_MSG(
281       num_head % num_heads_kv == 0,
282       "FlashAttention: num qyery heads must be divisible by num kv heads but got num query heads=%" PRId64
283       " and num kv heads=%" PRId64,
284       num_head,
285       num_heads_kv);
286   int64_t num_reps = num_head / num_heads_kv;
287 
288   bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
289   if (has_attn_mask) {
290     /*
291     TODO: fix this for upcasting attn mask
292     if (is_reduced_type) {
293       // SHould not come here for now.
294       attn_mask.value() = attn_mask.value().to(at::kFloat);
295     }
296     */
297     ET_CHECK_MSG(attn_mask.value().dim() == 2, "attn_mask must be 2D");
298     ET_CHECK_MSG(
299         attn_mask.value().size(0) == qSize, "attn_mask shape mismatch");
300     ET_CHECK_MSG(
301         attn_mask.value().size(1) == kvSize,
302         "attn_mask shape mismatch"
303         "attn_mask.size(1)=%zd kvSize=%" PRId64,
304         attn_mask.value().size(1),
305         kvSize);
306   }
307 
308   auto strides = query.strides();
309   int64_t qStrideB = strides[0];
310   int64_t qStrideH = strides[1];
311   int64_t qStrideM = strides[2];
312 
313   if (is_seq_at_dim_1) {
314     qStrideH = strides[2];
315     qStrideM = strides[1];
316   }
317 
318   strides = key.strides();
319   int64_t kStrideB = strides[0];
320   int64_t kStrideH = strides[1];
321   int64_t kStrideN = strides[2];
322 
323   if (is_seq_at_dim_1) {
324     kStrideH = strides[2];
325     kStrideN = strides[1];
326   }
327 
328   strides = value.strides();
329   int64_t vStrideB = strides[0];
330   int64_t vStrideH = strides[1];
331   int64_t vStrideN = strides[2];
332 
333   if (is_seq_at_dim_1) {
334     vStrideH = strides[2];
335     vStrideN = strides[1];
336   }
337 
338   strides = output.strides();
339   int64_t oStrideB = strides[0];
340   int64_t oStrideH = strides[1];
341   int64_t oStrideM = strides[2];
342 
343   if (is_seq_at_dim_1) {
344     oStrideH = strides[2];
345     oStrideM = strides[1];
346   }
347 
348   int64_t mStrideB = 0;
349   int64_t mStrideH = 0;
350   int64_t mStrideM = 0;
351   if (has_attn_mask) {
352     // int64_t mStrideB = 0;
353     //(has_attn_mask && attn_mask.value().size(0) > 1)
354     //    ? attn_mask.value().stride(0)
355     //    : 0;
356     // int64_t mStrideH = 0;
357     //(has_attn_mask && attn_mask.value().size(1) > 1)
358     //    ? attn_mask.value().stride(1)
359     //    : 0;
360     strides = attn_mask.value().strides();
361     mStrideM = strides[0];
362   }
363 
364   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
365   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
366   int64_t qSlice = (qSize - 1) / qSplitSize + 1;
367 #ifdef ET_USE_THREADPOOL
368   int64_t num_thread =
369       ::executorch::extension::threadpool::get_threadpool()->get_thread_count();
370 #else
371   int64_t num_thread = 1;
372 #endif
373 
374   // const auto dtype = query.scalar_type();
375   // Following will be revisited in the future
376   // const auto accumulate_dtype = dtype; // toOpMathType(dtype);
377 
378   // allocate per thread temp buf (accumulate type)
379   int64_t size_per_thread =
380       /* qk     */ qSplitSize * kvSplitSize +
381       /* qk_max */ qSplitSize +
382       /* qk_sum */ qSplitSize +
383       /* dst    */ qSplitSize * headSize;
384 
385   int64_t size_bytes = size_per_thread * num_thread * query.element_size();
386   std::vector<char> buf_vec(size_bytes);
387   void* buf = reinterpret_cast<void*>(buf_vec.data());
388   // Need to double check the following
389   size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
390   std::vector<char> buf_reduced_vec(size_bytes);
391   void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
392   // at::Tensor buf_reduced = at::empty(
393   //    {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
394   //    query.options());
395 
396   // Data ptrs
397   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
398   const scalar_t* k_data = key.const_data_ptr<scalar_t>();
399   const scalar_t* v_data = value.const_data_ptr<scalar_t>();
400   const accum_t* mask_data =
401       has_attn_mask ? attn_mask.value().const_data_ptr<accum_t>() : nullptr;
402   scalar_t* out_data = output.mutable_data_ptr<scalar_t>();
403   accum_t* buf_data = reinterpret_cast<accum_t*>(buf);
404   scalar_t* buf_reduced_data =
405       is_reduced_type ? reinterpret_cast<scalar_t*>(buf_reduced) : nullptr;
406 
407   auto compute_lambda = [&](int64_t begin, int64_t end) {
408     int64_t i = 0, j = 0, k = 0;
409     util::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
410     int ompIdx = torch::executor::get_thread_num();
411     accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
412     accum_t* qk_data = buf_ptr;
413     accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
414     accum_t* qk_sum_data = qk_max_data + qSplitSize;
415     accum_t* dst_data = qk_sum_data + qSplitSize;
416     scalar_t* qk_reduced_data = is_reduced_type
417         ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
418         : nullptr;
419 
420     for (int64_t z = begin; z < end; z++) {
421       int64_t m = k * qSplitSize;
422       int64_t qBlockSize = std::min(qSplitSize, qSize - m);
423       // Initialize max and sum
424       fill_stub(
425           qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
426       // Original flash sdpa wasnt really meant to be used
427       // for decode the way we are using via start_pos here.
428       // Thus when num_keys is 1 during decode phase, we
429       // still need to iterate through all the kv_splits
430       // Take start_pos = 130 and k_split_size = 128
431       // Here we have to produce [1x130] of q @ k.T
432       // when seq_len = 1
433       // But if num_keys = 1 then we dont really loop over
434       // all kv_splits.
435       // When k_split_size > 130, this is not an issue because
436       // there is only one iteration of the following loop anyway.
437       // Outside of determining how many loop iterations are needed
438       // num_keys participates only in causal attention.
439       // Rest of the calculation of q @ k.T and @ v.T is same.
440       // We dont run into this bug when k_split_size < start_pos + seqlen
441       // since there is only one iteration and that applies
442       // causal attention correctly.
443       // Howeve when k_split_size > start_pos + seqlen, we have
444       // more than one iteration, however if we dont adjust num_keys
445       // we dont get more than one iteration
446       // This is unique to this deployment of flash attention since
447       // original implementation wasnt deployed on this way.
448 
449       // Some of these bugs can be resolved by relying on attention mask
450       // but that requires storing attention mask in float as the current
451       // code doesnt support bool attention mask.
452       // However, lets just fix that as well.
453       int64_t num_keys =
454           is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
455       auto j_kv = j / num_reps;
456       for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
457         int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
458         // Calculate scale * q @ k.T
459         fill_stub(qk_data, static_cast<accum_t>(0), qSplitSize * kvSplitSize);
460         ::executorch::cpublas::gemm(
461             ::executorch::cpublas::TransposeType::Transpose,
462             ::executorch::cpublas::TransposeType::NoTranspose,
463             kvBlockSize,
464             qBlockSize,
465             headSize,
466             static_cast<accum_t>(1),
467             k_data + i * kStrideB + j_kv * kStrideH + n * kStrideN,
468             kStrideN,
469             q_data + i * qStrideB + j * qStrideH + m * qStrideM,
470             qStrideM,
471             static_cast<accum_t>(0),
472             qk_data,
473             kvBlockSize);
474         // Apply causal mask, fill unused, i.e. future values, with -inf
475         // Say you have q @ k.T size = [16, 32]
476         // With qblock size = 4, say you are processing
477         // q seq len dim = 8:11.
478         // Say kvSplitSize = 4
479         // Then for causal mask, the entries that needs to be
480         // ignored are
481         // [8, 9:31], [9, 10:31], [10, 10:31], [11, 11:31]
482         // Following condition says that num_keys = 8 + 4 =12
483         // (num_keys - n) <= kvSplitSize
484         // num_keys <= n + kvSplitSize
485         // If n + kvSplitSize is larger than 12, then some
486         // entries need masked out. In our example n = 4
487         // will qualify for that
488         if (is_causal && num_keys - n <= kvSplitSize) {
489           // For this fn to work k_split_size > q_split_size
490           for (int32_t row = 0; row < qBlockSize; ++row) {
491             int64_t last_col = m + (row + start_pos) - n;
492             accum_t* row_ptr = qk_data + row * kvBlockSize;
493             fill_stub(
494                 row_ptr + last_col + 1,
495                 -std::numeric_limits<accum_t>::infinity(),
496                 kvBlockSize - last_col - 1);
497           }
498         }
499         // Update attention weights with attention mask
500         // And apply scaling factor
501         // qk <- qk * scaling + attn_mask
502         if (has_attn_mask) {
503           for (int64_t row = 0; row < qBlockSize; ++row) {
504             vec::map2<accum_t>(
505                 [scaling_factor](Vec x, Vec y) {
506                   return x * Vec(scaling_factor) + y;
507                 },
508                 qk_data + row * kvBlockSize,
509                 qk_data + row * kvBlockSize,
510                 mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM +
511                     n,
512                 kvBlockSize);
513           }
514         }
515         // Update coefficients with Softmax
516         accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
517         for (int64_t row = 0; row < qBlockSize; ++row) {
518           if (has_attn_mask) {
519             // max per row
520             tmp_max = vec::reduce_all<accum_t>(
521                 [](Vec& x, Vec& y) { return vec::maximum(x, y); },
522                 qk_data + row * kvBlockSize,
523                 kvBlockSize);
524           } else {
525             // apply scaling factor and max per row in fusion
526             _mul_reduce_max_fusion_kernel(
527                 qk_data + row * kvBlockSize,
528                 scaling_factor,
529                 kvBlockSize,
530                 qk_data + row * kvBlockSize,
531                 tmp_max);
532           }
533           tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
534           // qk <- exp(qk - max) and sum per row
535           tmp_sum = tmp_max;
536           _exp_reduce_sum_fusion_kernel(
537               qk_data + row * kvBlockSize,
538               kvBlockSize,
539               conditional_data_ptr(qk_data, qk_reduced_data) +
540                   row * kvBlockSize,
541               tmp_sum);
542           // exp_tmp <- exp(max[row] - max)
543           exp_tmp = std::exp(qk_max_data[row] - tmp_max);
544           // sum[row] <- sum + exp_tmp * sum[row]
545           qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
546           // max[row] <- max
547           qk_max_data[row] = tmp_max;
548           // dst <- dst * exp_tmp
549           if (n > 0) {
550             vec::map<accum_t>(
551                 [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
552                 dst_data + row * headSize,
553                 dst_data + row * headSize,
554                 headSize);
555           }
556         }
557         // Calculate Softmax(q @ k.T) @ v
558         ::executorch::cpublas::gemm(
559             ::executorch::cpublas::TransposeType::NoTranspose,
560             ::executorch::cpublas::TransposeType::NoTranspose,
561             headSize,
562             qBlockSize,
563             kvBlockSize,
564             static_cast<accum_t>(1),
565             v_data + i * vStrideB + j_kv * vStrideH + n * vStrideN,
566             vStrideN,
567             conditional_data_ptr(qk_data, qk_reduced_data),
568             kvBlockSize,
569             n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
570             dst_data,
571             headSize);
572       }
573       // dst <- dst / sum[row]
574       // reorder MHA output with strides
575       for (int64_t row = 0; row < qBlockSize; ++row) {
576         accum_t sum_reciprocal = 1 / qk_sum_data[row];
577         vec::map<scalar_t>(
578             [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
579             out_data + i * oStrideB + j * oStrideH + m * oStrideM +
580                 row * oStrideM,
581             dst_data + row * headSize,
582             headSize);
583       }
584       // Move to the next query
585       util::data_index_step(i, batchSize, j, num_head, k, qSlice);
586     }
587   };
588   torch::executor::parallel_for(
589       0, batchSize * num_head * qSlice, 1, compute_lambda);
590 }
591 
validate_flash_attention_args(const Tensor & query,const Tensor & key,const Tensor & value,const optional<Tensor> & attn_mask)592 bool validate_flash_attention_args(
593     const Tensor& query,
594     const Tensor& key,
595     const Tensor& value,
596     const optional<Tensor>& attn_mask) {
597   ET_LOG_MSG_AND_RETURN_IF_FALSE(query.dim() == 4, "query must be a 4D tensor");
598   ET_LOG_MSG_AND_RETURN_IF_FALSE(key.dim() == 4, "key must be a 4D tensor");
599   ET_LOG_MSG_AND_RETURN_IF_FALSE(value.dim() == 4, "value must be a 4D tensor");
600 
601   // Sizes
602   ET_LOG_MSG_AND_RETURN_IF_FALSE(
603       (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
604       "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
605 
606   ET_LOG_MSG_AND_RETURN_IF_FALSE(
607       (query.scalar_type() == ScalarType::Float), "Query must be Float type");
608 
609   ET_LOG_MSG_AND_RETURN_IF_FALSE(
610       (query.scalar_type() == key.scalar_type()) &&
611           (query.scalar_type() == value.scalar_type()),
612       "Key and Value must have the same data type as Query");
613 
614   ET_LOG_MSG_AND_RETURN_IF_FALSE(
615       !attn_mask.has_value() || attn_mask.value().dim() == 2,
616       "Attention mask must be a 2D tensor");
617 
618   ET_LOG_MSG_AND_RETURN_IF_FALSE(
619       !attn_mask.has_value() ||
620           attn_mask.value().scalar_type() == query.scalar_type(),
621       "Attention mask must be a 2D tensor");
622 
623   ET_LOG_MSG_AND_RETURN_IF_FALSE(
624       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
625       "key cache must be in contiguous dim order");
626 
627   ET_LOG_MSG_AND_RETURN_IF_FALSE(
628       is_contiguous_dim_order(key.dim_order().data(), key.dim()),
629       "value cache must be in contiguous dim order");
630 
631   ET_LOG_MSG_AND_RETURN_IF_FALSE(
632       is_contiguous_dim_order(value.dim_order().data(), value.dim()),
633       "value cache must be in contiguous dim order");
634 
635   if (attn_mask.has_value()) {
636     ET_LOG_MSG_AND_RETURN_IF_FALSE(
637         is_contiguous_dim_order(
638             attn_mask.value().dim_order().data(), attn_mask.value().dim()),
639         "value cache must be in contiguous dim order");
640   }
641 
642   return true;
643 }
644 
validate_cache_params(const Tensor & k_cache,const Tensor & v_cache,int64_t start_pos,int64_t seq_length)645 bool validate_cache_params(
646     const Tensor& k_cache,
647     const Tensor& v_cache,
648     int64_t start_pos,
649     int64_t seq_length) {
650   ET_LOG_MSG_AND_RETURN_IF_FALSE(
651       k_cache.dim() == 4, "kcache must be a 4D tensor");
652 
653   ET_LOG_MSG_AND_RETURN_IF_FALSE(
654       v_cache.dim() == 4, "v_cache must be a 4D tensor");
655 
656   ET_LOG_MSG_AND_RETURN_IF_FALSE(
657       start_pos < k_cache.size(1),
658       "start_pos must be less than key cache at dim 1");
659 
660   ET_LOG_MSG_AND_RETURN_IF_FALSE(
661       start_pos < v_cache.size(1),
662       "start_pos must be less than value cache at dim 1");
663 
664   ET_LOG_MSG_AND_RETURN_IF_FALSE(
665       (start_pos + seq_length) <= k_cache.size(1),
666       "start_post + seq_length must be less than max seq length supported by key cache."
667       "start pos: %" PRId64 ", seq_length: %" PRId64
668       "."
669       "key cache size: %zd",
670       start_pos,
671       seq_length,
672       k_cache.size(1));
673 
674   ET_LOG_MSG_AND_RETURN_IF_FALSE(
675       (start_pos + seq_length) <= v_cache.size(1),
676       "start_post + seq_length must be less than max seq length supported by key cache."
677       "start pos: %" PRId64 ", seq_length: %" PRId64
678       "."
679       "value cache size: %zd",
680       start_pos,
681       seq_length,
682       v_cache.size(1));
683 
684   // Make sure they are in contiguous dim order
685   ET_LOG_MSG_AND_RETURN_IF_FALSE(
686       is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
687       "key cache must be in contiguous dim order");
688 
689   ET_LOG_MSG_AND_RETURN_IF_FALSE(
690       is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
691       "value cache must be in contiguous dim order");
692 
693   return true;
694 }
695 
696 // TODO: seq_length is not yet used for copy
update_cache(const Tensor & projected_value,const Tensor & cache,int64_t start_pos,int64_t seq_length)697 void update_cache(
698     const Tensor& projected_value,
699     const Tensor& cache,
700     int64_t start_pos,
701     int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
702   // 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
703   // 2) projected_value shape should be [bs, seq_len, num heads, head dim]
704   // 3) We're updating the cache with projected_value, at position start_pos
705 
706   ET_CHECK_MSG(
707       projected_value.size(0) == cache.size(0),
708       "projected_value batch size should be equal to the cache batch size.");
709   ET_CHECK_MSG(
710       projected_value.size(2) == cache.size(2),
711       "projected_value number of heads should be equal to the cache number of heads.");
712   ET_CHECK_MSG(
713       projected_value.size(3) == cache.size(3),
714       "projected_value embedding dimension should be equal to the cache embedding dimension.");
715   ET_CHECK_MSG(
716       projected_value.element_size() == cache.element_size(),
717       "projected_value data type size should be equal to the cache data type size.");
718 
719   ET_CHECK_MSG(
720       is_contiguous_dim_order(
721           projected_value.dim_order().data(), projected_value.dim()),
722       "projected value must be in contiguous dim order");
723   const void* projected_value_data = projected_value.const_data_ptr();
724   void* cache_data = cache.mutable_data_ptr();
725 
726   ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
727   ET_CHECK_MSG(cache_data, "cache data is null");
728 
729   auto cache_strides = cache.strides();
730   exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
731   exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
732 
733   auto value_strides = projected_value.strides();
734   exec_aten::StridesType value_batch_dim_stride = value_strides[0];
735 
736   exec_aten::SizesType num_bytes_to_copy =
737       (projected_value.numel() / projected_value.size(0)) *
738       projected_value.element_size();
739 
740   for (int64_t batch_line = 0; batch_line < projected_value.size(0);
741        ++batch_line) {
742     exec_aten::SizesType cache_pos_offset =
743         (batch_line * cache_batch_dim_stride +
744          start_pos * cache_seq_dim_stride) *
745         cache.element_size();
746     exec_aten::SizesType value_pos_offset =
747         (batch_line * value_batch_dim_stride) * cache.element_size();
748 
749     std::memcpy(
750         (uint8_t*)cache_data + cache_pos_offset,
751         (uint8_t*)projected_value_data + value_pos_offset,
752         num_bytes_to_copy);
753   }
754 }
755 
756 } // anonymous namespace
757 
flash_attention_kernel_out(RuntimeContext & ctx,const Tensor & query,const Tensor & key,const Tensor & value,const optional<Tensor> & attn_mask,const double dropout_p,const bool is_causal,const optional<double> scale,Tensor & output)758 Tensor& flash_attention_kernel_out(
759     RuntimeContext& ctx,
760     const Tensor& query,
761     const Tensor& key,
762     const Tensor& value,
763     const optional<Tensor>& attn_mask,
764     const double dropout_p,
765     const bool is_causal,
766     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
767     const optional<double> scale,
768     Tensor& output) {
769   (void)ctx;
770   ET_KERNEL_CHECK(
771       ctx,
772       validate_flash_attention_args(query, key, value, attn_mask),
773       InvalidArgument,
774       output);
775 
776   ET_KERNEL_CHECK(
777       ctx,
778       resize_tensor(output, query.sizes()) == Error::Ok,
779       InvalidArgument,
780       output);
781 
782   auto q_seq_len = query.size(2);
783 
784   ET_SWITCH_FLOAT_TYPES(
785       query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
786         // TODO we need to re-evaluate this for ARM CPUs
787         // And there can be many so instead of templatizing
788         // we might consider another appraoch
789         if (q_seq_len >= 768) {
790           cpu_flash_attention<CTYPE, 256, 512>(
791               output,
792               query,
793               key,
794               value,
795               dropout_p,
796               is_causal,
797               attn_mask,
798               scale);
799         } else if (q_seq_len >= 192) {
800           cpu_flash_attention<CTYPE, 64, 512>(
801               output,
802               query,
803               key,
804               value,
805               dropout_p,
806               is_causal,
807               attn_mask,
808               scale);
809         } else {
810           cpu_flash_attention<CTYPE, 32, 512>(
811               output,
812               query,
813               key,
814               value,
815               dropout_p,
816               is_causal,
817               attn_mask,
818               scale);
819         }
820       });
821   return output;
822 }
823 
824 /*
825   Input params
826   @param[in] q_projected Projected query with query weights.
827   Format [n_layers, batch size, seq_len, num heads, head dim]
828   @param[in] k_projected Projected query with key weights.
829   Format [n_layers, batch size, seq_len, num heads, head dim]
830   @param[in] v_projected Projected query with value weights.
831   Format [n_layers, batch size, seq_len, num heads, head dim]
832   @param[in] key_cache Cache of previous k_projected.
833   Format [n_layers, batch size, max_seq_len, num heads, head dim]
834   @param[in] key_cache Cache of previous v_projected.
835   Format [n_layers, batch size, max_seq_len, num heads, head dim]
836   ....
837   @param[in] start_pos: sequence position
838   @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
839 */
custom_sdpa_out(RuntimeContext & ctx,const Tensor & q,const Tensor & k,const Tensor & v,const int64_t start_pos,const optional<Tensor> & attn_mask,const double dropout_p,const bool is_causal,const optional<double> scale,Tensor & output)840 Tensor& custom_sdpa_out(
841     RuntimeContext& ctx,
842     const Tensor& q,
843     const Tensor& k,
844     const Tensor& v,
845     const int64_t start_pos,
846     const optional<Tensor>& attn_mask,
847     const double dropout_p,
848     const bool is_causal,
849     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
850     const optional<double> scale,
851     Tensor& output) {
852   ET_KERNEL_CHECK_MSG(
853       ctx,
854       !attn_mask.has_value() || !is_causal,
855       InvalidArgument,
856       output,
857       "attn_mask and is_causal cannot be set at the same time");
858 
859   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
860 
861   const int64_t seq_len = q.size(1);
862   auto q_seq_len = q.size(1);
863 
864   // Refactor the following into create_view util perhaps using
865   // TensorPtr
866   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
867       0, 1, 2, 3};
868   std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
869   sliced_key_sizes[0] = k.size(0);
870   sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
871   sliced_key_sizes[2] = k.size(2);
872   sliced_key_sizes[3] = k.size(3);
873   std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
874   dim_order_to_stride_nocheck(
875       sliced_key_sizes.data(),
876       sliced_key_dim_order.data(),
877       util::kKVDim,
878       sliced_key_strides.data());
879   // since the cache is sliced, the batch stride needs to stay the same.
880   sliced_key_strides[0] = k.strides()[0];
881   void* key_cache_data = k.mutable_data_ptr();
882   TensorImpl k_impl = TensorImpl(
883       k.scalar_type(),
884       util::kKVDim,
885       sliced_key_sizes.data(),
886       key_cache_data,
887       sliced_key_dim_order.data(),
888       sliced_key_strides.data(),
889       TensorShapeDynamism::STATIC);
890   Tensor sliced_key_cache(&k_impl);
891 
892   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
893       0, 1, 2, 3};
894   std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
895   sliced_value_sizes[0] = v.size(0);
896   sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
897   sliced_value_sizes[2] = v.size(2);
898   sliced_value_sizes[3] = v.size(3);
899   std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
900   dim_order_to_stride_nocheck(
901       sliced_value_sizes.data(),
902       sliced_value_dim_order.data(),
903       util::kKVDim,
904       sliced_value_strides.data());
905   // since the cache is sliced, the batch stride needs to stay the same.
906   sliced_value_strides[0] = v.strides()[0];
907   void* value_cache_data = v.mutable_data_ptr();
908   TensorImpl value_impl = TensorImpl(
909       v.scalar_type(),
910       util::kKVDim,
911       sliced_value_sizes.data(),
912       value_cache_data,
913       sliced_value_dim_order.data(),
914       sliced_value_strides.data(),
915       TensorShapeDynamism::STATIC);
916   Tensor sliced_value_cache(&value_impl);
917 
918   ET_KERNEL_CHECK(
919       ctx,
920       resize_tensor(output, q.sizes()) == Error::Ok,
921       InvalidArgument,
922       output);
923 
924   // TODO(task): replace the template param selection logic
925   // with whatever apprpriately makes more sense for
926   ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
927     // TODO we need to re-evaluate this for ARM CPUs
928     // And there can be many so instead of templatizing
929     // we might consider another appraoch
930     if (q_seq_len >= 768) {
931       cpu_flash_attention<CTYPE, 256, 512>(
932           output,
933           q,
934           sliced_key_cache,
935           sliced_value_cache,
936           dropout_p,
937           is_causal,
938           attn_mask,
939           scale,
940           true, /* is_seq_at_dim_1 */
941           start_pos);
942     } else if (q_seq_len >= 192) {
943       cpu_flash_attention<CTYPE, 64, 512>(
944           output,
945           q,
946           sliced_key_cache,
947           sliced_value_cache,
948           dropout_p,
949           is_causal,
950           attn_mask,
951           scale,
952           true, /* is_seq_at_dim_1 */
953           start_pos);
954     } else {
955       cpu_flash_attention<CTYPE, 32, 512>(
956           output,
957           q,
958           sliced_key_cache,
959           sliced_value_cache,
960           dropout_p,
961           is_causal,
962           attn_mask,
963           scale,
964           true, /* is_seq_at_dim_1 */
965           start_pos);
966     }
967   });
968   return output;
969 }
970 /*
971   Input params
972   @param[in] q_projected Projected query with query weights.
973   Format [n_layers, batch size, seq_len, num heads, head dim]
974   @param[in] k_projected Projected query with key weights.
975   Format [n_layers, batch size, seq_len, num heads, head dim]
976   @param[in] v_projected Projected query with value weights.
977   Format [n_layers, batch size, seq_len, num heads, head dim]
978   @param[in] key_cache Cache of previous k_projected.
979   Format [n_layers, batch size, max_seq_len, num heads, head dim]
980   @param[in] key_cache Cache of previous v_projected.
981   Format [n_layers, batch size, max_seq_len, num heads, head dim]
982   ....
983   @param[in] start_pos: sequence position
984   @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
985 */
sdpa_with_kv_cache_out(KernelRuntimeContext & ctx,const Tensor & q_projected,const Tensor & k_projected,const Tensor & v_projected,Tensor & key_cache,Tensor & value_cache,const int64_t start_pos,const int64_t seq_len,const optional<Tensor> & attn_mask,const double dropout_p,const bool is_causal,const optional<double> scale,Tensor & output)986 Tensor& sdpa_with_kv_cache_out(
987     KernelRuntimeContext& ctx,
988     const Tensor& q_projected,
989     const Tensor& k_projected,
990     const Tensor& v_projected,
991     Tensor& key_cache,
992     Tensor& value_cache,
993     const int64_t start_pos,
994     const int64_t seq_len,
995     const optional<Tensor>& attn_mask,
996     const double dropout_p,
997     const bool is_causal,
998     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
999     const optional<double> scale,
1000     Tensor& output) {
1001   (void)ctx;
1002   ET_KERNEL_CHECK(
1003       ctx,
1004       validate_cache_params(key_cache, value_cache, start_pos, seq_len),
1005       InvalidArgument,
1006       output);
1007 
1008   ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");
1009 
1010   update_cache(k_projected, key_cache, start_pos, seq_len);
1011   update_cache(v_projected, value_cache, start_pos, seq_len);
1012 
1013   custom_sdpa_out(
1014       ctx,
1015       q_projected,
1016       key_cache,
1017       value_cache,
1018       start_pos,
1019       attn_mask,
1020       dropout_p,
1021       is_causal,
1022       scale,
1023       output);
1024 
1025   return output;
1026 }
1027 } // namespace native
1028 } // namespace executor
1029 } // namespace torch
1030 
1031 EXECUTORCH_LIBRARY(
1032     llama,
1033     "sdpa_with_kv_cache.out",
1034     torch::executor::native::sdpa_with_kv_cache_out);
1035 
1036 EXECUTORCH_LIBRARY(
1037     llama,
1038     "custom_sdpa.out",
1039     torch::executor::native::custom_sdpa_out);
1040