1 #pragma once
2 #include <ATen/Context.h>
3 #include <ATen/NestedTensorImpl.h>
4 #include <ATen/TensorSubclassLikeUtils.h>
5 #include <ATen/TensorUtils.h>
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/grad_mode.h>
8 #include <ATen/native/DispatchStub.h>
9 #include <c10/core/ScalarType.h>
10 
11 #include <c10/util/Exception.h>
12 #include <c10/util/env.h>
13 #include <c10/util/irange.h>
14 
15 #include <c10/core/SymInt.h>
16 #include <c10/core/SymFloat.h>
17 #include <c10/util/string_view.h>
18 #include <c10/util/Array.h>
19 #include <cmath>
20 #include <cstdint>
21 #include <functional>
22 
23 namespace sdp {
24 
// Backend count; presumably the number of non-error enumerators in
// SDPBackend below (five are listed) -- keep the two in sync.
constexpr int32_t num_backends = 5;

// Identifiers for the scaled-dot-product-attention backends.
// Every enumerator carries an explicit integer value; `error` is the only
// negative one.
enum class SDPBackend {
  error = -1,
  math = 0,
  flash_attention = 1,
  efficient_attention = 2,
  cudnn_attention = 3,
  overrideable = 4
};
34 
// Kinds of custom attention mask.
// NOTE: whenever this enum changes, the templated enum in
// mem_eff/kernel_forward.h and mem_eff/kernel_backward.h must be updated
// to match.
enum class CustomMaskType {
  NoCustomMask = 0,
  CausalFromTopLeft = 1,
  CausalFromBottomRight = 2,
  // Trailing enumerator with an implicit value (previous + 1), i.e. the
  // number of mask kinds listed above.
  NumCustomMaskTypes,
};
43 
44 struct sdp_params {
45   at::Tensor query;
46   at::Tensor key;
47   at::Tensor value;
48   std::optional<at::Tensor> attn_mask;
49   double dropout;
50   bool is_causal;
51   bool enable_gqa;
52 };
53 
54 SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
55 
calculate_scale(const at::Tensor & query,std::optional<double> scale)56 inline c10::SymFloat calculate_scale(
57     const at::Tensor& query,
58     std::optional<double> scale) {
59   const auto softmax_scale = scale.has_value()
60       ? scale.value()
61       : (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
62   return c10::SymFloat(softmax_scale);
63 }
64 
65 using c10::array_of;
66 
input_requires_grad(sdp_params const & params)67 inline bool input_requires_grad(sdp_params const& params) {
68   const bool any_inputs_require_grad = params.query.requires_grad() ||
69       params.key.requires_grad() || params.value.requires_grad();
70   const bool gradmode_enabled = at::GradMode::is_enabled();
71   return any_inputs_require_grad && gradmode_enabled;
72 }
73 
has_for_nested_inputs(sdp_params const & params)74 inline bool has_for_nested_inputs(sdp_params const& params) {
75   return
76       (params.query.is_nested() && params.query.layout() == c10::kStrided) ||
77       (params.key.is_nested() && params.key.layout() == c10::kStrided) ||
78       (params.value.is_nested() && params.value.layout() == c10::kStrided);
79 }
80 
has_for_dense_inputs(sdp_params const & params)81 inline bool has_for_dense_inputs(sdp_params const& params) {
82   return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested();
83 }
84 
has_only_dense_inputs(sdp_params const & params)85 inline bool has_only_dense_inputs(sdp_params const& params) {
86   return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested();
87 }
88 
89 template <typename dtype_vector>
check_tensor_dtype(sdp_params const & params,dtype_vector allowed_dtypes,bool debug)90 inline bool check_tensor_dtype(
91     sdp_params const& params,
92     dtype_vector allowed_dtypes,
93     bool debug) {
94   auto query_dtype = params.query.dtype();
95   if (!(query_dtype == params.key.dtype() &&
96         query_dtype == params.value.dtype() &&
97         (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
98          allowed_dtypes.end()))) {
99     if (debug) {
100       TORCH_WARN(
101           "Expected query, key and value to all be of dtype: {",
102           c10::Join(", ", allowed_dtypes),
103           "}. Got ",
104           "Query dtype: ",
105           params.query.dtype(),
106           ", Key dtype: ",
107           params.key.dtype(),
108           ", and Value dtype: ",
109           params.value.dtype(),
110           " instead.");
111     }
112     return false;
113   }
114   return true;
115 }
116 
117 
try_broadcast_param_size(const c10::SymInt q_size,const c10::SymInt k_size,const c10::SymInt v_size,c10::string_view param_name,bool debug)118 inline bool try_broadcast_param_size(
119     const c10::SymInt q_size,
120     const c10::SymInt k_size,
121     const c10::SymInt v_size,
122     c10::string_view param_name,
123     bool debug) {
124   auto max_size = std::max({q_size, k_size, v_size});
125   if ((q_size != max_size && q_size != 1) ||
126       (k_size != max_size && k_size != 1) ||
127       (v_size != max_size && v_size != 1)) {
128     if (debug) {
129       TORCH_WARN(
130           "Both fused kernels require query, key and value to have broadcastable ",
131           param_name,
132           "got Query ",
133           param_name,
134           q_size,
135           ", Key ",
136           param_name,
137           k_size,
138           ", Value ",
139           param_name,
140           v_size,
141           " instead.");
142     }
143     return false;
144   }
145   return true;
146 }
147 
check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(at::Tensor const & param,c10::string_view param_name,bool debug)148 inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
149     at::Tensor const& param,
150     c10::string_view param_name,
151     bool debug) {
152   const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
153   const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes();
154   auto num_head_dims = nt_tensor_impl->opt_size(1);
155   if (!num_head_dims.has_value()) {
156     // num_head_dims is ragged
157     if (debug) {
158       TORCH_WARN(
159           "Fused kernels do not support ragged num_head_dims, ",
160           param_name,
161           "has a ragged num_heads.");
162     }
163     return false;
164   }
165 
166   auto* sizes_ptr = sizes.data_ptr<int64_t>();
167   const int64_t n_tensors = param.size(0);
168   const int64_t size_tensor_stride = sizes.stride(0);
169 
170   // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
171   for (const auto i : c10::irange(n_tensors)) {
172     if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
173       if (debug) {
174         TORCH_WARN(
175             "Fused kernels do not support seq_len == 0, ",
176             param_name,
177             "has a seq len of 0.");
178       }
179       return false;
180     }
181   }
182   return true;
183 }
184 
check_for_seq_len_0_nested_tensor(sdp_params const & params,bool debug)185 inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) {
186   // When this function is called we are assured that the nt is dim==4
187   bool q_is_safe = params.query.is_nested()
188       ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
189             params.query, "query ", debug)
190       : true;
191   // short circuit if any is unsafe
192   if (!q_is_safe) {
193     return false;
194   }
195 
196   bool k_is_safe = params.key.is_nested()
197       ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
198             params.key, "key ", debug)
199       : true;
200   if (!k_is_safe) {
201     return false;
202   }
203 
204   bool v_is_safe = params.value.is_nested()
205       ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
206             params.value, "value ", debug)
207       : true;
208   if (!v_is_safe) {
209     return false;
210   }
211 
212   // We now know none of the inputs have ragged num_heads, so we can safely
213   // access .size(1)
214   auto q_num_heads = params.query.size(1);
215   auto k_num_heads = params.key.size(1);
216   auto v_num_heads = params.value.size(1);
217   bool same_num_heads =
218       q_num_heads == k_num_heads && q_num_heads == v_num_heads;
219 
220   if (!same_num_heads) {
221     if (input_requires_grad(params)){
222       if (debug) {
223         TORCH_WARN(
224               "Both fused kernels do not support training with broadcasted NT inputs.");
225       }
226       return false;
227     }
228     return try_broadcast_param_size(
229         q_num_heads, k_num_heads, v_num_heads, "num heads ", debug);
230   }
231 
232   return true;
233 }
234 
check_nested_tensor(sdp_params const & params,bool debug)235 inline bool check_nested_tensor(sdp_params const& params, bool debug) {
236   // Return false if have nested tensor
237   if (!has_only_dense_inputs(params)) {
238     if (debug) {
239       TORCH_WARN(
240           "Both fused kernels of cpp version currently do not support Nested Tensor inputs.");
241     }
242     return false;
243   }
244   return true;
245 }
246 
check_for_dropout(sdp_params const & params,bool debug)247 inline bool check_for_dropout(sdp_params const& params, bool debug) {
248   if (params.dropout > 0.0) {
249     if (debug) {
250       TORCH_WARN("Both fused kernels do not support non-zero dropout.");
251     }
252     return false;
253   }
254   return true;
255 }
256 
check_requires_grad_and_nested(sdp_params const & params,bool debug)257 inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) {
258   if (input_requires_grad(params)) {
259     if (debug) {
260       TORCH_WARN(
261           "Memory efficient attention currently doesn't support training with NT inputs.");
262     }
263     return false;
264   }
265   return true;
266 }
267 
check_for_attn_mask(sdp_params const & params,bool debug)268 inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
269   if (params.attn_mask.has_value()) {
270     if (debug) {
271       TORCH_WARN("Flash Attention does not support non-null attn_mask.");
272     }
273     return false;
274   }
275   return true;
276 }
277 
check_attn_mask_shape(sdp_params const & params,bool debug)278 inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
279   auto attn_mask = params.attn_mask;
280   if (!attn_mask.has_value()) {
281     return true;
282   }
283   if (attn_mask.value().requires_grad()) {
284     return false;
285   }
286   auto batchSize = params.query.sym_size(0);
287   auto qSize = params.query.sym_size(2);
288   auto kvSize = params.key.sym_size(2);
289   auto num_head = params.query.sym_size(1);
290   if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) {
291     return false;
292   }
293   if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) {
294     return false;
295   }
296   if (attn_mask.value().dim() == 2) {
297     return true;
298   } else if (attn_mask.value().dim() == 4) {
299     if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize)
300         && (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) {
301       return true;
302     }
303   }
304   if (debug) {
305     TORCH_WARN("Please use the following attn mask shapes: ",
306         "2d - ({Q_seq_len, 1}  x {KV_seq_len, 1}); ",
307         "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1}  x {KV_seq_len, 1})");
308   }
309   return false;
310 }
311 
check_tensor_shapes(sdp_params const & params,bool debug)312 inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
313   auto query_dim = params.query.dim();
314   if (!(query_dim == params.key.dim() && query_dim == params.value.dim() &&
315         (query_dim == 4))) {
316     if (debug) {
317       TORCH_WARN(
318           "Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
319           query_dim,
320           ", Key dim: ",
321           params.key.dim(),
322           ", Value dim: ",
323           params.value.dim(),
324           " instead.");
325     }
326     return false;
327   }
328   return true;
329 }
330 
check_safe_kv_broadcast(at::Tensor const & param,bool debug)331 inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
332   const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
333   auto seq_len = nt_tensor_impl->opt_size(2);
334   if (!seq_len.has_value()) {
335     if (debug) {
336       TORCH_WARN(
337           "For both fused kernels, if one of key/value batch_size requires "
338           "broadcasting and the other does not, then the other must have a ",
339           "consistent seq_len dim.")
340     }
341     return false;
342   }
343   return true;
344 }
345 
check_grouped_query_attention(sdp_params const & params,bool debug)346 inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
347   const auto q_num_heads = params.query.sym_size(-3);
348   const auto k_num_heads = params.key.sym_size(-3);
349   const auto v_num_heads = params.value.sym_size(-3);
350   const bool same_kv_heads = k_num_heads == v_num_heads;
351 
352   if (!(same_kv_heads)){
353     if (debug) {
354       TORCH_WARN(
355           "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
356           "Key sizes: ",
357           params.key.sizes(),
358           ", Value sizes: ",
359           params.value.sizes(),
360           ", Query sizes: ",
361           params.query.sizes(),
362           " instead.");
363     }
364     return false;
365   }
366   // Check if grouped query attention is supported and validate the number of
367   // heads
368   if (q_num_heads % k_num_heads != 0) {
369     if (debug) {
370       TORCH_WARN(
371           "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
372           "Got input Key sizes(): ",
373           params.key.sym_size(-3),
374           ", Value sizes(): ",
375           params.value.sym_size(-3),
376           ", Query sizes(): ",
377           params.query.sym_size(-3),
378           " instead.");
379     }
380     return false;
381   }
382   return true;
383 }
384 
385 template <bool supports_gqa>
check_batch_size_and_num_heads_dense(sdp_params const & params,bool debug)386 inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
387   // This is expected to be called after check_tensor_shapes ensuring that the
388   // size() calls won't error since the inputs are all 4 dimensional
389 
390   auto q_batch_size = params.query.sym_size(0);
391   auto k_batch_size = params.key.sym_size(0);
392   auto v_batch_size = params.value.sym_size(0);
393 
394   bool same_batch_size =
395       q_batch_size == k_batch_size && q_batch_size == v_batch_size;
396 
397   auto q_num_heads = params.query.sym_size(-3);
398   auto k_num_heads = params.key.sym_size(-3);
399   auto v_num_heads = params.value.sym_size(-3);
400 
401   bool same_num_heads =
402       q_num_heads == k_num_heads && q_num_heads == v_num_heads;
403 
404   if (!same_batch_size){
405     if(debug) {
406       TORCH_WARN(
407           "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
408           "Query.sizes(): ",
409           params.query.sizes(),
410           ", Key.sizes(): ",
411           params.key.sizes(),
412           ", Value.sizes(): ",
413           params.value.sizes(),
414           " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
415     }
416     return false;
417   }
418 
419   if(params.enable_gqa && supports_gqa){
420     return check_grouped_query_attention(params, debug);
421   }
422 
423   if (!same_num_heads){
424     if (debug) {
425       TORCH_WARN(
426           "For dense input, both fused kernels require query, key and value to have the same num_heads. ",
427           "Query.sizes(): ",
428           params.query.sizes(),
429           ", Key sizes(): ",
430           params.key.sizes(),
431           ", Value sizes(): ",
432           params.value.sizes(),
433           " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
434     }
435     return false;
436   }
437   // If all checks pass, return true
438   return true;
439 }
440 
check_batch_size_nested(sdp_params const & params,bool debug)441 inline bool check_batch_size_nested(sdp_params const& params, bool debug) {
442   // This is expected to be called after check_tensor_shapes ensuring that the
443   // size() calls won't error since the inputs are all 4 dimensional
444   auto q_batch_size = params.query.sym_size(0);
445   auto k_batch_size = params.key.sym_size(0);
446   auto v_batch_size = params.value.sym_size(0);
447 
448   bool same_batch_size =
449       q_batch_size == k_batch_size && q_batch_size == v_batch_size;
450 
451   // num_heads logic for nested input is checked in
452   // check_for_seq_len_0_nested_tensor as there is handling there to make sure
453   // num_heads is not ragged
454   bool broadcastable_batch_size = true;
455   if (!same_batch_size) {
456     if (input_requires_grad(params)){
457       if (debug) {
458         TORCH_WARN(
459             "Both fused kernels do not support training with broadcasted NT inputs.");
460       }
461       return false;
462     }
463     // try to broadcast batchsize
464     broadcastable_batch_size = try_broadcast_param_size(
465         q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);
466 
467     // if only one of k or v require broadcasting of batch size, the other
468     // must have a consistent seq_len dim
469     if (broadcastable_batch_size) {
470       if (k_batch_size == 1 && v_batch_size != 1 &&
471           !check_safe_kv_broadcast(params.value, debug)) {
472         return false;
473       }
474       if (v_batch_size == 1 && k_batch_size != 1 &&
475           !check_safe_kv_broadcast(params.key, debug)) {
476         return false;
477       }
478     }
479   }
480   return broadcastable_batch_size;
481 }
482 
check_nonzero_sequence_lengths_dense(sdp_params const & params,bool debug)483 inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
484   // In some cases people will pass in 0 sized tensors, this will
485   // cause the fused path to error with unaligned mask
486   bool zero_seq_len_q = params.query.sym_size(-2) == 0;
487   bool zero_seq_len_k = params.key.sym_size(-2) == 0;
488   if (zero_seq_len_q || zero_seq_len_k) {
489     if (debug) {
490       TORCH_WARN(
491           "Both fused kernels do not support zero seq_len_q or seq_len_kv.");
492     }
493     return false;
494   }
495   return true;
496 }
497 
498 template<bool ignore_singleton_dim>
check_last_dim_stride_equals_1_dense(sdp_params const & params,bool debug)499 inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
500   // The stride checking for NestedTensors is done within the kernel
501   // And .contiguous will be called if needed
502 
503   // This function checks that the last dimension of the inputs to
504   // fused_attention have stride 1
505   bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
506       params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;
507 
508   // https://github.com/pytorch/pytorch/issues/116333
509   // If the head_dim is size 1 the stride won't matter, but we
510   // check this condition before padding the head_dim to 1
511   if (ignore_singleton_dim){
512     qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
513   }
514   bool mask_stride_equal_1 = params.attn_mask.has_value()
515       ? params.attn_mask.value().sym_stride(-1) == 1
516       : true;
517   if (!(qkv_strides_equal_1 && mask_stride_equal_1)) {
518     if (debug) {
519       std::ostringstream epilogue_message;
520       if (params.attn_mask.has_value()) {
521         epilogue_message << ", Attn_mask.stride(-1): "
522                          << params.attn_mask.value().sym_stride(-1);
523       }
524       epilogue_message << " instead.";
525       TORCH_WARN(
526           "Both fused kernels require the last dimension of the input to have stride 1. ",
527           "Got Query.stride(-1): ",
528           params.query.sym_stride(-1),
529           ", Key.stride(-1): ",
530           params.key.sym_stride(-1),
531           ", Value.stride(-1): ",
532           params.value.sym_stride(-1),
533           epilogue_message.str());
534     }
535 
536     return false;
537   }
538   return true;
539 }
540 
check_runtime_disabled_flash(sdp_params const & params,bool debug)541 inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) {
542   // We check the global context to see if user has explicitly turned of flash
543   // sdp kernels
544   if (!at::globalContext().userEnabledFlashSDP()) {
545     if (debug) {
546       TORCH_WARN("Flash attention has been runtime disabled.");
547     }
548     return false;
549   }
550   return true;
551 }
552 
check_runtime_disabled_mem_efficient(sdp_params const & params,bool debug)553 inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) {
554   // We check the global context to see if user has explicitly turned of
555   // mem_efficient sdp kernels
556   if (!at::globalContext().userEnabledMemEfficientSDP()) {
557     if (debug) {
558       TORCH_WARN("Memory Efficient attention has been runtime disabled.");
559     }
560     return false;
561   }
562   return true;
563 }
564 
565 
566 } // namespace sdp
567