1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/flash_attention_score_info.h"
18 
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include <tuple>
23 #include <map>
24 #include <algorithm>
25 
26 #include "ir/value.h"
27 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
28 #include "frontend/parallel/device_matrix.h"
29 #include "frontend/parallel/dynamic_creator.h"
30 #include "frontend/parallel/step_parallel_utils.h"
31 #include "frontend/parallel/graph_util/generate_graph.h"
32 #include "frontend/parallel/graph_util/graph_utils.h"
33 #include "mindspore/core/ops/nn_ops.h"
34 #include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"
35 #include "ops/op_enum.h"
36 
37 namespace mindspore {
38 using mindspore::ops::FASInputLayoutMode;
39 namespace parallel {
40 namespace {
41 constexpr size_t kInputRealShiftSeqDim = 2;
42 constexpr size_t kInputDropMaskSeqDim = 2;
43 constexpr size_t kOutputSoftmaxSeqDim = 2;
44 constexpr int64_t kLoadBalanceSplitNum = 2;
45 enum OpAttrUpdateMode : int64_t {
46   kLeftUpToLeftUp = 0,
47   kLeftUpToRightDown = 1,
48   kRightDownToRightDown = 2,
49 };
50 const std::vector<int64_t> needCompressAttnMask = {ops::kSparseLeftUpCausal, ops::kSparseRightDownCausal,
51                                                    ops::kSparseBand, ops::kSparseBlockLocal};
52 const std::map<int64_t, int64_t> opAttrUpdateMap = {{ops::kSparseDefaultMask, kLeftUpToLeftUp},
53                                                     {ops::kSparseLeftUpCausal, kLeftUpToRightDown},
54                                                     {ops::kSparseRightDownCausal, kRightDownToRightDown},
55                                                     {ops::kSparseBand, kRightDownToRightDown},
56                                                     {ops::kSparseBlockLocal, kLeftUpToRightDown}};
57 
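// Counts the CNode's inputs excluding monad inputs (the U/IO monads attached for side effects).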
58 size_t GetNonMonadInputSize(const CNodePtr &cnode) {
59   size_t cnode_non_monad_size = cnode->size();
60   for (auto &input : cnode->inputs()) {
61     if (HasAbstractMonad(input)) {
62       cnode_non_monad_size--;
63     }
64   }
65   return cnode_non_monad_size;
66 }
67 
68 int64_t NewSeedGeneration() {
69   static int64_t seed_generation = 0;
70   ++seed_generation;
71   return seed_generation;
72 }
73 
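// Saturating int64_t addition used for the pre_tokens/next_tokens updates below: the sum is clamped to
// [INT_MIN, INT_MAX] instead of overflowing, e.g. LongAdd(INT_MAX - 5, 10) yields INT_MAX and
// LongAdd(INT_MIN + 5, -10) yields INT_MIN.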
74 int64_t LongAdd(int64_t base, int64_t shift) {
75   int64_t result;
76   if (shift > 0) {
77     if (base > INT_MAX - shift) {
78       result = INT_MAX;
79     } else {
80       result = base + shift;
81     }
82   } else {
83     if (base < INT_MIN - shift) {
84       result = INT_MIN;
85     } else {
86       result = base + shift;
87     }
88   }
89   return result;
90 }
91 
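// A tensor-map id indexes the device matrix from the right (MAP_NONE means the dimension is not split),
// so map_id corresponds to dev_matrix[dev_matrix.size() - 1 - map_id]; the returned value is that axis's split number.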
92 int64_t GetSplitNumByMapId(const Shape &dev_matrix, int64_t map_id) {
93   if (map_id == MAP_NONE) {
94     return NO_SPLIT_STRATEGY;
95   }
96   auto axis = dev_matrix.size() - 1 - LongToSize(map_id);
97   if (axis >= dev_matrix.size()) {
98     MS_LOG(EXCEPTION) << "The tensor map id (" << map_id
99                       << ") is out of device matrix's range. device_matrix: " << dev_matrix;
100   }
101   return dev_matrix[axis];
102 }
103 
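// When one tensor dimension is mapped to several device-matrix axes, its total split number is the product of
// the split numbers of those axes.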
104 int64_t GetSplitNumByTensorMap(const Shape &dev_matrix, const Shape &tensor_map) {
105   auto split_num = std::accumulate(tensor_map.begin(), tensor_map.end(), 1, [&dev_matrix](int64_t a, int64_t map_id) {
106     return a * GetSplitNumByMapId(dev_matrix, map_id);
107   });
108   return split_num;
109 }
110 }  // namespace
111 
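// Rewrites the DropoutGenMask that feeds this FlashAttentionScore: offsets seed0/seed1 by a bias derived from
// the rank's position so that ranks holding the same data slice keep an identical mask while different slices
// differ, and replaces its shape input with the drop mask's local slice shape (last dimension multiplied by 8
// to restore the element-level shape).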
112 void FlashAttentionScoreInfo::UpdateDropoutGenMaskSliceShapeAndSeed(const CNodePtr &dropout_gen_mask_cnode) {
113   if (!IsPrimitiveCNode(dropout_gen_mask_cnode, prim::kPrimDropoutGenMask)) {
114     return;
115   }
116 
117   // Update seed according to rank_id for DropoutGenMask
118   PrimitivePtr prim = GetCNodePrimitive(dropout_gen_mask_cnode);
119   auto seed_0 = GetValue<int64_t>(prim->GetAttr(SEED0));
120   auto seed_1 = GetValue<int64_t>(prim->GetAttr(SEED1));
121   int64_t rank_id = g_device_manager->rank_index_in_stage();
122   int64_t seed_bias = 0;
123   // When seed0 and seed1 are both 0, ensure that the 0th card in each group has the same result
124   if (seed_0 == 0 && seed_1 == 0) {
125     seed_bias = NewSeedGeneration();
126   }
127   MS_EXCEPTION_IF_ZERO("repeated_calc_num_", repeated_calc_num_);
128   if (repeated_num_in_dev_matrix_right_) {
129     seed_bias += rank_id / repeated_calc_num_;
130   } else {
131     int64_t device_num = stage_device_size_;
132     MS_EXCEPTION_IF_ZERO("device_num", device_num);
133     seed_bias += rank_id % (device_num / repeated_calc_num_);
134   }
135   auto clone_prim = prim->Clone();
136   clone_prim->set_attr(SEED0, MakeValue<int64_t>(seed_0 + seed_bias));
137   clone_prim->set_attr(SEED1, MakeValue<int64_t>(seed_1 + seed_bias));
138   auto func_graph = dropout_gen_mask_cnode->func_graph();
139   MS_EXCEPTION_IF_NULL(func_graph);
140   auto manager = func_graph->manager();
141   MS_EXCEPTION_IF_NULL(manager);
142   manager->SetEdge(dropout_gen_mask_cnode, 0, NewValueNode(clone_prim)->cast<AnfNodePtr>());
143 
144   // Update slice shape for DropoutGenMask and Reshape
145   Shape input_slice_shape = inputs_tensor_info_.at(ops::kFlashAttentionScoreInputDropMaskIndex).slice_shape();
146   constexpr int64_t BITS_NUM_PER_BYTE = 8;
147   input_slice_shape[input_slice_shape.size() - 1] *= BITS_NUM_PER_BYTE;  // Restores the shape of DropoutGenMask input
148   size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode);
149   if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
150     MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
151   }
152   if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(kIndex1))) {
153     MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
154   }
155   ValuePtr new_shape = MakeValue(input_slice_shape);
156   AnfNodePtr val = NewValueNode(new_shape);
157   manager->SetEdge(dropout_gen_mask_cnode, kIndex1, val);
158   MS_LOG(DEBUG) << "The input slice shape dropout is " << ShapeToString(input_slice_shape);
159 }
160 
161 void FlashAttentionScoreInfo::InitIsInputPassed() {
162   is_input_passed_.resize(input_value_.size());
163   for (size_t i = 0; i < input_value_.size(); ++i) {
164     is_input_passed_[i] = (input_value_[i] == nullptr || !input_value_[i]->isa<None>());
165   }
166 }
167 
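// Converts an absolute FlashAttentionScore input index into its position among the inputs that were actually
// passed (inputs given as None are skipped).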
168 size_t FlashAttentionScoreInfo::GetStrategyRealIndex(size_t index) {
169   if (index >= is_input_passed_.size() || !is_input_passed_[index]) {
170     MS_LOG(INTERNAL_EXCEPTION) << name_ << ": GetStrategyRealIndex failed, index is " << index;
171   }
172   auto real_index = -1;
173   for (size_t i = 0; i <= index; ++i) {
174     if (is_input_passed_[i]) {
175       ++real_index;
176     }
177   }
178   return real_index;
179 }
180 
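// Returns the ranks that lie along the S1 (sequence-parallel) dimension of the device matrix, i.e. the peers
// of this rank when the query sequence is split.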
181 RankList FlashAttentionScoreInfo::GetSPRankList() {
182   CheckGlobalDeviceManager();
183   int64_t rank = g_device_manager->global_rank();
184   DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
185   RankList group_devices;
186   int64_t seq_dim = SizeToLong(dev_matrix_shape_.size()) - dev_matrix_s1_dim_ - 1;
187   if (dev_matrix.GetDevicesAlongDim(seq_dim, &group_devices) != SUCCESS) {
188     MS_LOG(ERROR) << name_ << " get group devices along dim " << seq_dim << " failed.";
189   }
190   return group_devices;
191 }
192 
193 Status FlashAttentionScoreInfo::InitAttnMaskStrategies() {
194   if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
195     auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
196     int64_t s1_split_num_attn_mask = is_attn_mask_compressed_ ? 1 : s1_split_num_;
197     int64_t s2_split_num_attn_mask = enable_ring_attention_ ? s1_split_num_attn_mask : 1;
198     if (attn_mask_shape.size() == kSizeTwo) {
199       // attn_mask_shape: (S1, S2)
200       expect_strategies_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {s1_split_num_attn_mask,
201                                                                          s2_split_num_attn_mask};
202     } else if (attn_mask_shape.size() == kSizeFour) {
203       // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
204       auto attn_mask_n1_split_num = attn_mask_have_n1_dim_ ? n1_split_num_ : 1;
205       auto attn_batch_split_num = attn_mask_have_batch_dim_ ? batch_split_num_ : 1;
206       expect_strategies_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_batch_split_num, attn_mask_n1_split_num,
207                                                                          s1_split_num_attn_mask, 1};
208     }
209   }
210   return SUCCESS;
211 }
212 
213 Status FlashAttentionScoreInfo::InitExpectedStrategies() {
214   expect_strategies_ = Strategies(ops::kFlashAttentionScoreInputsNum);
215   switch (input_layout_) {
216     case FASInputLayoutMode::BSH:
217       expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_, s1_split_num_, n1_split_num_};
218       expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, s2_split_num_, n2_split_num_};
219       expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, s2_split_num_, n2_split_num_};
220       break;
221     case FASInputLayoutMode::SBH:
222       expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {s1_split_num_, batch_split_num_, n1_split_num_};
223       expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {s2_split_num_, batch_split_num_, n2_split_num_};
224       expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {s2_split_num_, batch_split_num_, n2_split_num_};
225       break;
226     case FASInputLayoutMode::BNSD:
227       expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_, n1_split_num_, s1_split_num_,
228                                                                       1};
229       expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, n2_split_num_, s2_split_num_, 1};
230       expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, n2_split_num_, s2_split_num_,
231                                                                       1};
232       break;
233     case FASInputLayoutMode::BSND:
234       expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_, s1_split_num_, n1_split_num_,
235                                                                       1};
236       expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, s2_split_num_, n2_split_num_, 1};
237       expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, s2_split_num_, n2_split_num_,
238                                                                       1};
239       break;
240     case FASInputLayoutMode::TND:
241       expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_ * s1_split_num_, n1_split_num_,
242                                                                       1};
243       expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, n2_split_num_, 1};
244       expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, n2_split_num_, 1};
245       break;
246     default:
247       MS_LOG(ERROR) << name_ << "Not support layout: " << input_layout_;
248       return FAILED;
249   }
250 
251   if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
252     int64_t real_shift_s1_split_num = real_shift_have_s1_dim_ ? s1_split_num_ : 1;
253     auto real_shift_batch_split_num = real_shift_have_batch_dim_ ? batch_split_num_ : 1;
254     expect_strategies_[ops::kFlashAttentionScoreInputRealShiftIndex] = {real_shift_batch_split_num, n1_split_num_,
255                                                                         real_shift_s1_split_num, 1};
256   }
257   if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
258     expect_strategies_[ops::kFlashAttentionScoreInputDropMaskIndex] = {batch_split_num_, n1_split_num_, s1_split_num_,
259                                                                        1};
260   }
261   if (is_input_passed_[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
262     expect_strategies_[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
263   }
264   InitAttnMaskStrategies();
265 
266   // padding_mask is not supported yet, skip it.
267 
268   if (is_input_passed_[ops::kFlashAttentionScoreInputPrefixIndex]) {
269     expect_strategies_[ops::kFlashAttentionScoreInputPrefixIndex] = {batch_split_num_};
270   }
271   if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
272     expect_strategies_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {batch_split_num_};
273   }
274   if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
275     expect_strategies_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {batch_split_num_};
276   }
277   expect_strategies_.erase(std::remove(expect_strategies_.begin(), expect_strategies_.end(), Shape{}),
278                            expect_strategies_.end());
279   return SUCCESS;
280 }
281 
282 Status FlashAttentionScoreInfo::InitQKVTensorMap() {
283   int64_t kv_head_num_map = kv_split_ ? dev_matrix_n1_dim_ : -1;
284   auto dev_matrix_s2_dim = enable_ring_attention_ ? dev_matrix_s1_dim_ : -1;
285   switch (input_layout_) {
286     case FASInputLayoutMode::BSH:
287       inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_s1_dim_,
288                                                                       dev_matrix_n1_dim_};
289       inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
290                                                                     kv_head_num_map};
291       inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
292                                                                       kv_head_num_map};
293       break;
294     case FASInputLayoutMode::SBH:
295       inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_s1_dim_, dev_matrix_batch_dim_,
296                                                                       dev_matrix_n1_dim_};
297       inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_s2_dim, dev_matrix_batch_dim_,
298                                                                     kv_head_num_map};
299       inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_s2_dim, dev_matrix_batch_dim_,
300                                                                       kv_head_num_map};
301       break;
302     case FASInputLayoutMode::BNSD:
303       inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_n1_dim_,
304                                                                       dev_matrix_s1_dim_, -1};
305       inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, kv_head_num_map,
306                                                                     dev_matrix_s2_dim, -1};
307       inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, kv_head_num_map,
308                                                                       dev_matrix_s2_dim, -1};
309       break;
310     case FASInputLayoutMode::BSND:
311       inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_s1_dim_,
312                                                                       dev_matrix_n1_dim_, -1};
313       inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
314                                                                     kv_head_num_map, -1};
315       inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
316                                                                       kv_head_num_map, -1};
317       break;
318     case FASInputLayoutMode::TND:
319       inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_n1_dim_, -1};
320       inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, kv_head_num_map, -1};
321       inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, kv_head_num_map, -1};
322       break;
323     default:
324       MS_LOG(ERROR) << name_ << "Not support layout: " << input_layout_;
325       return FAILED;
326   }
327   return SUCCESS;
328 }
329 
330 Status FlashAttentionScoreInfo::InitInputsTensorMap() {
331   inputs_tensor_map_ = std::vector<Shape>(ops::kFlashAttentionScoreInputsNum);
332   if (InitQKVTensorMap() != SUCCESS) {
333     return FAILED;
334   }
335 
336   if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
337     auto real_shift_s1_map = real_shift_have_s1_dim_ ? dev_matrix_s1_dim_ : -1;
338     auto real_shift_batch_map = real_shift_have_batch_dim_ ? dev_matrix_batch_dim_ : -1;
339     inputs_tensor_map_[ops::kFlashAttentionScoreInputRealShiftIndex] = {real_shift_batch_map, dev_matrix_n1_dim_,
340                                                                         real_shift_s1_map, -1};
341   }
342   if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
343     inputs_tensor_map_[ops::kFlashAttentionScoreInputDropMaskIndex] = {dev_matrix_batch_dim_, dev_matrix_n1_dim_,
344                                                                        dev_matrix_s1_dim_, -1};
345   }
346   if (is_input_passed_[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
347     inputs_tensor_map_[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
348   }
349   if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
350     auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
351     int64_t dev_matrix_s1_dim_attn_mask = is_attn_mask_compressed_ ? -1 : dev_matrix_s1_dim_;
352     if (attn_mask_shape.size() == kSizeTwo) {
353       // attn_mask_shape: (S1, S2)
354       inputs_tensor_map_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {dev_matrix_s1_dim_attn_mask, -1};
355     } else if (attn_mask_shape.size() == kSizeFour) {
356       // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
357       auto attn_mask_batch_map = attn_mask_have_batch_dim_ ? dev_matrix_batch_dim_ : -1;
358       auto attn_mask_n1_map = attn_mask_have_n1_dim_ ? dev_matrix_n1_dim_ : -1;
359       inputs_tensor_map_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_mask_batch_map, attn_mask_n1_map,
360                                                                          dev_matrix_s1_dim_attn_mask, -1};
361     }
362   }
363   if (is_input_passed_[ops::kFlashAttentionScoreInputPrefixIndex]) {
364     inputs_tensor_map_[ops::kFlashAttentionScoreInputPrefixIndex] = {dev_matrix_batch_dim_};
365   }
366   if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
367     inputs_tensor_map_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {dev_matrix_batch_dim_};
368   }
369   if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
370     inputs_tensor_map_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {dev_matrix_batch_dim_};
371   }
372   inputs_tensor_map_.erase(std::remove(inputs_tensor_map_.begin(), inputs_tensor_map_.end(), Shape{}),
373                            inputs_tensor_map_.end());
374   return SUCCESS;
375 }
376 
377 Status FlashAttentionScoreInfo::InitAttnMaskSplittableInputs() {
378   if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
379     int64_t s1_group = 2;
380     auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
381     int64_t attn_s1_group = is_attn_mask_compressed_ ? 0 : s1_group;
382     int64_t attn_s2_group = enable_ring_attention_ ? attn_s1_group : 0;
383     if (attn_mask_shape.size() == kSizeTwo) {
384       // attn_mask_shape: (S1, S2)
385       splittable_inputs_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_s1_group, attn_s2_group};
386     } else if (attn_mask_shape.size() == kSizeFour) {
387       int64_t n1_group = 1;
388       int64_t batch_group = 3;
389       // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
390       auto attn_mask_n1_group = attn_mask_shape[kIndex1] == 1 ? 0 : n1_group;
391       splittable_inputs_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {batch_group, attn_mask_n1_group, attn_s1_group,
392                                                                          0};
393     }
394   }
395   return SUCCESS;
396 }
397 
398 Status FlashAttentionScoreInfo::InitSplittableInputs() {
399   splittable_inputs_ = std::vector<Shape>(ops::kFlashAttentionScoreInputsNum);
400   int64_t batch_group = 3;
401   int64_t s1_group = 2;
402   int64_t n1_group = 1;
403   int64_t n2_group = kv_split_ ? n1_group : 0;
404   int64_t s2_group = enable_ring_attention_ ? s1_group : 0;
405   switch (input_layout_) {
406     case FASInputLayoutMode::BSH:
407       splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, s1_group, n1_group};
408       splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, s2_group, n2_group};
409       splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, s2_group, n2_group};
410       break;
411     case FASInputLayoutMode::SBH:
412       splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {s1_group, batch_group, n1_group};
413       splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {s2_group, batch_group, n2_group};
414       splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {s2_group, batch_group, n2_group};
415       break;
416     case FASInputLayoutMode::BNSD:
417       splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, n1_group, s1_group, 0};
418       splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, n2_group, s2_group, 0};
419       splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, n2_group, s2_group, 0};
420       break;
421     case FASInputLayoutMode::BSND:
422       splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, s1_group, n1_group, 0};
423       splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, s2_group, n2_group, 0};
424       splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, s2_group, n2_group, 0};
425       break;
426     case FASInputLayoutMode::TND:
427       splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, n1_group, 0};
428       splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, n2_group, 0};
429       splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, n2_group, 0};
430       break;
431     default:
432       MS_LOG(ERROR) << name_ << "Not support layout: " << input_layout_;
433       return FAILED;
434   }
435 
436   if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
437     auto real_shift_s1_group = real_shift_have_s1_dim_ ? s1_group : 0;
438     splittable_inputs_[ops::kFlashAttentionScoreInputRealShiftIndex] = {batch_group, n1_group, real_shift_s1_group, 0};
439   }
440   if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
441     splittable_inputs_[ops::kFlashAttentionScoreInputDropMaskIndex] = {batch_group, n1_group, s1_group, 0};
442   }
443   if (is_input_passed_[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
444     splittable_inputs_[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
445   }
446   InitAttnMaskSplittableInputs();
447   if (is_input_passed_[ops::kFlashAttentionScoreInputPrefixIndex]) {
448     splittable_inputs_[ops::kFlashAttentionScoreInputPrefixIndex] = {batch_group};
449   }
450   if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
451     splittable_inputs_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {batch_group};
452   }
453   if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
454     splittable_inputs_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {batch_group};
455   }
456   splittable_inputs_.erase(std::remove(splittable_inputs_.begin(), splittable_inputs_.end(), Shape{}),
457                            splittable_inputs_.end());
458   return SUCCESS;
459 }
460 
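// Records which tensor dimension of query/key/value carries batch, sequence and head for the given layout.
// For TND, the leading T dimension plays both the batch and the sequence role, so both indices point to it.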
461 Status FlashAttentionScoreInfo::InitQKVHeadAndSeqDimFromInputLayout() {
462   switch (input_layout_) {
463     case FASInputLayoutMode::BSH:
464       qkv_batch_dim_ = kSizeZero;
465       qkv_seq_dim_ = kSizeOne;
466       qkv_head_dim_ = kSizeTwo;
467       break;
468     case FASInputLayoutMode::SBH:
469       qkv_seq_dim_ = kSizeZero;
470       qkv_batch_dim_ = kSizeOne;
471       qkv_head_dim_ = kSizeTwo;
472       break;
473     case FASInputLayoutMode::BNSD:
474       qkv_batch_dim_ = kSizeZero;
475       qkv_head_dim_ = kSizeOne;
476       qkv_seq_dim_ = kSizeTwo;
477       break;
478     case FASInputLayoutMode::BSND:
479       qkv_batch_dim_ = kSizeZero;
480       qkv_seq_dim_ = kSizeOne;
481       qkv_head_dim_ = kSizeTwo;
482       break;
483     case FASInputLayoutMode::TND:
484       qkv_batch_dim_ = kSizeZero;
485       qkv_seq_dim_ = kSizeZero;
486       qkv_head_dim_ = kSizeOne;
487       break;
488     default:
489       MS_LOG(ERROR) << name_ << ": Not support layout in parallel currently.";
490       return FAILED;
491   }
492   return SUCCESS;
493 }
494 
495 Status FlashAttentionScoreInfo::InitAttrs() { return GetAttrs(); }
496 
497 Status FlashAttentionScoreInfo::CheckInputLayout() {
498   if (InferSplitNumAndDevMatrixShapeByLayout() != SUCCESS) {
499     MS_LOG(ERROR) << name_ << ": Infer device matrix shape by layout failed.";
500     return FAILED;
501   }
502 
503   auto query_shape = inputs_shape_[ops::kFlashAttentionScoreInputQueryIndex];
504   auto key_shape = inputs_shape_[ops::kFlashAttentionScoreInputKeyIndex];
505   if (s1_split_num_ > 1 && input_layout_ == FASInputLayoutMode::TND &&
506       (sparse_mode_ != ops::kSparseRightDownCausal || query_shape[0] != key_shape[0])) {
507     MS_LOG(ERROR)
508       << name_
509       << ": When input_layout is TND, sparse_mode is 3, and the T-dimension of query and key are the same, the "
510          "T-dimension of query can be sliced. query_shape: "
511       << query_shape << ", key_shape: " << key_shape << ", sparse_mode: " << sparse_mode_;
512     return FAILED;
513   }
514 
515   // Check all device matrix should be the same
516   if (ops::kFlashAttentionScoreInputQueryIndex >= inputs_tensor_info_.size()) {
517     return FAILED;
518   }
519   auto query_tensor_info = inputs_tensor_info_[GetStrategyRealIndex(ops::kFlashAttentionScoreInputQueryIndex)];
520   dev_matrix_shape_ = query_tensor_info.tensor_layout().device_arrangement_origin().array();
521   return SUCCESS;
522 }
523 
524 Status FlashAttentionScoreInfo::CheckOutputLayout() { return SUCCESS; }
525 
526 Status FlashAttentionScoreInfo::InferOutputLayout() {
527   auto query_layout = inputs_tensor_info_[ops::kFlashAttentionScoreInputQueryIndex].tensor_layout();
528 
529   // Construct layout for softmax_max and softmax_sum
530   std::vector<Shape> softmax_max_sum_tensor_map;
531   Shape softmax_max_sum_tensor_shape;
532   if (input_layout_ == FASInputLayoutMode::TND) {
533     softmax_max_tensor_layout_ = query_layout;
534     softmax_sum_tensor_layout_ = query_layout;
535   } else {
536     softmax_max_sum_tensor_map.push_back(query_layout.tensor_map_before()[qkv_batch_dim_]);              // B
537     softmax_max_sum_tensor_shape.push_back(query_layout.tensor_shape_before().array()[qkv_batch_dim_]);  // B
538     softmax_max_sum_tensor_map.push_back(query_layout.tensor_map_before()[qkv_head_dim_]);               // N
539     softmax_max_sum_tensor_shape.push_back(head_num_);                                                   // N
540     softmax_max_sum_tensor_map.push_back(query_layout.tensor_map_before()[qkv_seq_dim_]);                // S
541     softmax_max_sum_tensor_shape.push_back(query_layout.tensor_shape_before().array()[qkv_seq_dim_]);    // S
542     softmax_max_sum_tensor_map.push_back({MAP_NONE});                                                    // 8
543     softmax_max_sum_tensor_shape.push_back(8);                                                           // 8
544     softmax_max_tensor_layout_.InitFromExtendVector(query_layout.device_arrangement_origin().array(),
545                                                     softmax_max_sum_tensor_map,
546                                                     outputs_shape()[ops::kFlashAttentionScoreOutputSoftmaxMaxIndex]);
547     softmax_sum_tensor_layout_.InitFromExtendVector(query_layout.device_arrangement_origin().array(),
548                                                     softmax_max_sum_tensor_map,
549                                                     outputs_shape()[ops::kFlashAttentionScoreOutputSoftmaxSumIndex]);
550   }
551 
552   // Construct layout for softmax_out
553   softmax_out_tensor_layout_.InitFromExtendVector(query_layout.device_arrangement_origin().array(),
554                                                   std::vector<Shape>{{MAP_NONE}},
555                                                   outputs_shape()[ops::kFlashAttentionScoreOutputSoftmaxOutIndex]);
556   attention_out_tensor_layout_ = query_layout;
557   return SUCCESS;
558 }
559 
560 Status FlashAttentionScoreInfo::InferOutputTensorInfo() {
561   auto status = InferOutputLayout();
562   if (status != SUCCESS) {
563     return status;
564   }
565   (void)outputs_tensor_info_.emplace_back(TensorInfo(softmax_max_tensor_layout_));
566   (void)outputs_tensor_info_.emplace_back(TensorInfo(softmax_sum_tensor_layout_));
567   (void)outputs_tensor_info_.emplace_back(TensorInfo(softmax_out_tensor_layout_));
568   (void)outputs_tensor_info_.emplace_back(TensorInfo(attention_out_tensor_layout_));
569   return SUCCESS;
570 }
571 
572 Status FlashAttentionScoreInfo::InferAsLossDivisorByLayout() {
573   if (outputs_tensor_info_.size() != ops::kFlashAttentionScoreOutputsNum) {
574     MS_LOG(ERROR)
575       << name_
576       << ": The size of outputs tensor info must be equal to the size of FlashAttentionScore's output size, but got  "
577       << outputs_tensor_info_.size() << " and " << ops::kFlashAttentionScoreOutputsNum;
578     return FAILED;
579   }
580 
581   auto attention_out_tensor_info = outputs_tensor_info_[ops::kFlashAttentionScoreOutputAttentionOutIndex];
582   TensorMaps attention_out_tensor_map = attention_out_tensor_info.tensor_layout().tensor_map_before();
583   if (attention_out_tensor_map.empty()) {
584     as_loss_divisor_ = stage_device_size_;
585     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
586     return SUCCESS;
587   }
588 
589   auto out_dev_matrix_shape = attention_out_tensor_info.tensor_layout().device_arrangement_origin().array();
590   if (out_dev_matrix_shape.empty()) {
591     MS_LOG(INFO) << name_ << ": out_dev_matrix_shape is empty";
592     out_dev_matrix_shape = dev_matrix_shape_;
593   }
594   Shape squashed_tensor_map;
595   for (const auto &tensor_map : attention_out_tensor_map) {
596     std::copy(tensor_map.begin(), tensor_map.end(), std::back_inserter(squashed_tensor_map));
597   }
598 
599   as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape, squashed_tensor_map);
600   MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape)
601                << ", the output tensor map is " << ShapeToString(squashed_tensor_map) << ", loss divisor is "
602                << as_loss_divisor_;
603   return SUCCESS;
604 }
605 
606 Status FlashAttentionScoreInfo::InferMirrorOpsByLayout() {
607   mirror_ops_.clear();
608   if (inputs_shape_.empty()) {
609     MS_LOG(INFO) << name_ << ": The inputs size is empty";
610     return SUCCESS;
611   }
612 
613   bool group_is_empty = true;
614   for (size_t i = 0; i < inputs_tensor_info_.size(); ++i) {
615     if (inputs_tensor_info_[i] == TensorInfo()) {
616       (void)mirror_ops_.emplace_back(OperatorVector());
617       continue;
618     }
619     auto input_tensor_layout = inputs_tensor_info_[i].tensor_layout();
620     auto repeated_rank_list = input_tensor_layout.InferRepeatedGroup();
621 
622     OperatorVector mirror_op;
623     if (repeated_rank_list.size() == 1) {
624       MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
625       mirror_ops_.push_back(mirror_op);
626       continue;
627     }
628     if (is_auto_parallel_) {
629       if (g_device_manager->CheckDeviceList(repeated_rank_list) != SUCCESS) {
630         MS_LOG(INFO) << name_ << ": Try to create communication group : " << repeated_rank_list
631                      << " failed in auto parallel mode, "
632                         "this error can be ignored in parallel strategies searching step";
633         return FAILED;
634       }
635       return SUCCESS;
636     }
637 
638     Group mirror_group;
639     if (g_device_manager->CreateGroup(repeated_rank_list, &mirror_group) != SUCCESS) {
640       MS_LOG(ERROR) << name_
641                     << ": Create communication group by tensor_map failed, the rank_list is: " << repeated_rank_list
642                     << ", the full_name of node is: " << cnode_->fullname_with_scope();
643       return FAILED;
644     }
645     group_is_empty = false;
646     mirror_op = CreateMirrorOps(mirror_group.name(), mirror_group.GetDevNum());
647     mirror_ops_.push_back(mirror_op);
648   }
649 
650   if (group_is_empty) {
651     mirror_ops_.clear();
652     MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
653   }
654   return SUCCESS;
655 }
656 
657 Status FlashAttentionScoreInfo::GetAttrs() {
658   InitIsInputPassed();
659   head_num_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputHeadNumIndex + 1);
660   keep_prob_ = GetInputValueFromCNode<float>(cnode_, ops::kFlashAttentionScoreInputKeepProbIndex + 1);
661   scale_value_ = GetInputValueFromCNode<float>(cnode_, ops::kFlashAttentionScoreInputScaleValueIndex + 1);
662   pre_tokens_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputPreTokensIndex + 1);
663   next_tokens_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputNextTokensIndex + 1);
664   input_layout_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputLayoutIndex + 1);
665   sparse_mode_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputSparseModeIndex + 1);
666   auto ms_context = MsContext::GetInstance();
667   MS_EXCEPTION_IF_NULL(ms_context);
668   enable_load_balance_ = ms_context->get_param<bool>(MS_CTX_ENABLE_FLASH_ATTENTION_LOAD_BALANCE);
669 
670   if (input_layout_ == FASInputLayoutMode::TND && enable_load_balance_) {
671     MS_LOG(WARNING) << name_ << ": Load balancing is not supported in the layout 'TND' and will be disabled.";
672     enable_load_balance_ = false;
673   }
674 
675   auto enable_ring_attention_iter = attrs_.find(ENABLE_RING_ATTENTION);
676   if (enable_ring_attention_iter != attrs_.end()) {
677     MS_EXCEPTION_IF_NULL(enable_ring_attention_iter->second);
678     if (enable_ring_attention_iter->second->isa<BoolImm>()) {
679       enable_ring_attention_ = enable_ring_attention_iter->second->cast<BoolImmPtr>()->value();
680       enable_load_balance_ = false;
681       MS_LOG(DEBUG) << "enable_ring_attention_: " << enable_ring_attention_;
682     } else {
683       MS_LOG(ERROR) << "enable_ring_attention should be bool";
684     }
685   }
686   if (enable_ring_attention_) {
687     if (input_layout_ != FASInputLayoutMode::BSH && input_layout_ != FASInputLayoutMode::BNSD) {
688       MS_LOG(ERROR) << "Ring attention currently only supports BSH and BNSD layout";
689     }
690     if (sparse_mode_ != 0) {
691       MS_LOG(ERROR) << "Ring attention currently only supports sparse mode 0";
692     }
693     if (keep_prob_ != 1.0) {
694       MS_LOG(ERROR) << "Ring attention currently only supports keep prob 1.0";
695     }
696     if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
697       MS_LOG(ERROR) << "Ring attention do not need input attn mask";
698     }
699   }
700 
701   is_attn_mask_compressed_ =
702     std::find(needCompressAttnMask.begin(), needCompressAttnMask.end(), sparse_mode_) != needCompressAttnMask.end();
703   need_update_op_attrs_mode_ = sparse_mode_ != ops::kSparseAllMask;
704   if (InitQKVHeadAndSeqDimFromInputLayout() != Status::SUCCESS) {
705     return FAILED;
706   }
707 
708   kv_split_ = inputs_shape_[ops::kFlashAttentionScoreInputQueryIndex][qkv_head_dim_] !=
709               inputs_shape_[ops::kFlashAttentionScoreInputKeyIndex][qkv_head_dim_] * head_num_;
710 
711   if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
712     auto real_shift_s1_dim =
713       inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputRealShiftIndex)).at(kIndex3);
714     real_shift_have_s1_dim_ = real_shift_s1_dim > 1;
715     auto real_shift_batch_dim =
716       inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputRealShiftIndex)).at(kIndex0);
717     real_shift_have_batch_dim_ = real_shift_batch_dim > 1;
718   }
719 
720   if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
721     auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
722     if (attn_mask_shape.size() == kSizeFour) {
723       attn_mask_have_batch_dim_ = attn_mask_shape.at(kIndex0) > 1;
724       attn_mask_have_n1_dim_ = attn_mask_shape.at(kIndex1) > 1;
725     }
726   }
727   return SUCCESS;
728 }
729 
730 Status FlashAttentionScoreInfo::CheckStrategy(const StrategyPtr &strategy) {
731   if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
732     return FAILED;
733   }
734   auto strategies = strategy->GetInputDim();
735   auto query_strategy = strategies[ops::kFlashAttentionScoreInputQueryIndex];
736   auto key_strategy = strategies[ops::kFlashAttentionScoreInputKeyIndex];
737   auto value_strategy = strategies[ops::kFlashAttentionScoreInputValueIndex];
738   if (key_strategy != value_strategy) {
739     MS_LOG(ERROR) << name_ << ": The in_strategy both of 'key'( " << key_strategy << ") and 'value'" << value_strategy
740                   << ") must be same.";
741     return FAILED;
742   }
743   if (head_num_ % query_strategy[qkv_head_dim_] != 0) {
744     MS_LOG(ERROR) << name_ << ": head_num % query_strategy[" << qkv_head_dim_ << "] must be 0, but got " << head_num_
745                   << "(head_num) and " << query_strategy[qkv_head_dim_] << "(query_strategy[" << qkv_head_dim_ << "])";
746     return FAILED;
747   }
748   if (!kv_split_ && key_strategy[qkv_head_dim_] != 1) {
749     MS_LOG(ERROR) << name_ << ": Under the MQA,the hidden-dim of input 'key' cannot be split.";
750     return FAILED;
751   }
752 
753   if (input_layout_ == FASInputLayoutMode::TND) {
754     if (query_strategy[qkv_seq_dim_] != key_strategy[qkv_seq_dim_]) {
755       MS_LOG(ERROR)
756         << name_ << ": The split num of seq-dim between query and key must be the same when layout is 'TND'. But got "
757         << query_strategy[qkv_seq_dim_] << " and " << key_strategy[qkv_seq_dim_];
758       return FAILED;
759     }
760   } else {
761     auto s2_split_num = key_strategy[qkv_seq_dim_];
762     if (s2_split_num != 1 && !enable_ring_attention_) {
763       MS_LOG(ERROR) << name_ << ": The S-Dimension of input 'key' cannot be split, but got the strategy of key is "
764                     << key_strategy;
765       return FAILED;
766     }
767   }
768 
769   if (input_layout_ == FASInputLayoutMode::TND) {
770     batch_split_num_ = key_strategy[qkv_batch_dim_];
771     s1_split_num_ = query_strategy[qkv_batch_dim_] / batch_split_num_;
772   } else {
773     batch_split_num_ = query_strategy[qkv_batch_dim_];
774     s1_split_num_ = query_strategy[qkv_seq_dim_];
775   }
776   n1_split_num_ = query_strategy[qkv_head_dim_];
777 
778   s2_split_num_ = enable_ring_attention_ ? s1_split_num_ : 1;
779 
780   n2_split_num_ = key_strategy[qkv_head_dim_];
781 
782   if (kv_split_ && n1_split_num_ != n2_split_num_) {
783     MS_LOG(ERROR) << name_ << ": The split num of N1-dim and N2-dim must be equal if N2 > 1, but got " << n1_split_num_
784                   << " and " << n2_split_num_;
785     return FAILED;
786   }
787 
788   if (s1_split_num_ > 1 && input_layout_ == FASInputLayoutMode::TND) {
789     MS_LOG(ERROR)
790       << name_
791       << ": Currently, input_layout is TND, and the seq dimension of query is segmented. Please use Layout to "
792          "set the strategy.";
793     return FAILED;
794   }
795 
796   if (InitExpectedStrategies() != SUCCESS) {
797     return FAILED;
798   }
799   if (strategies != expect_strategies_) {
800     MS_LOG(ERROR) << name_ << ": The input strategy must be " << expect_strategies_ << ", but got " << strategies;
801     return FAILED;
802   }
803 
804   return SUCCESS;
805 }
806 
807 Status FlashAttentionScoreInfo::CheckStrategyForDynamicShape(const StrategyPtr &) {
808   for (auto &cnode : cnodes_) {
809     // If DropoutGenMask -> Reshape -> FlashAttentionScore
810     auto reshape_node = cnode->input(ops::kFlashAttentionScoreInputDropMaskIndex + 1);
811     MS_EXCEPTION_IF_NULL(reshape_node);
812     if (!IsPrimitiveCNode(reshape_node, prim::kPrimReshape)) {
813       continue;
814     }
815 
816     MS_LOG(ERROR)
817       << name_ << ": it does not support dynamic shape if it need to replace dst-shape for reshape, the inputs' shape: "
818       << ShapesToString(inputs_shape_);
819     return FAILED;
820   }
821   return SUCCESS;
822 }
823 
824 Status FlashAttentionScoreInfo::InferDevMatrixShape() {
825   switch (input_layout_) {
826     case FASInputLayoutMode::BSH:
827     case FASInputLayoutMode::BSND:
828     case FASInputLayoutMode::TND:
829       dev_matrix_shape_ = {batch_split_num_, s1_split_num_, n1_split_num_};
830       dev_matrix_batch_dim_ = kIndex2;
831       dev_matrix_s1_dim_ = kIndex1;
832       dev_matrix_n1_dim_ = kIndex0;
833       break;
834     case FASInputLayoutMode::SBH:
835       dev_matrix_shape_ = {s1_split_num_, batch_split_num_, n1_split_num_};
836       dev_matrix_s1_dim_ = kIndex2;
837       dev_matrix_batch_dim_ = kIndex1;
838       dev_matrix_n1_dim_ = kIndex0;
839       break;
840     case FASInputLayoutMode::BNSD:
841       dev_matrix_shape_ = {batch_split_num_, n1_split_num_, s1_split_num_};
842       dev_matrix_batch_dim_ = kIndex2;
843       dev_matrix_n1_dim_ = kIndex1;
844       dev_matrix_s1_dim_ = kIndex0;
845       break;
846     default:
847       MS_LOG(ERROR) << name_ << ": Not support layout: " << input_layout_;
848       return FAILED;
849   }
850   return SUCCESS;
851 }
852 
853 Status FlashAttentionScoreInfo::InferSplitNumAndDevMatrixShapeByLayout() {
854   dev_matrix_shape_ =
855     inputs_tensor_info_[ops::kFlashAttentionScoreInputQueryIndex].tensor_layout().device_arrangement_origin().array();
856   auto query_layout = inputs_tensor_info_[ops::kFlashAttentionScoreInputQueryIndex].tensor_layout();
857   auto key_layout = inputs_tensor_info_[ops::kFlashAttentionScoreInputKeyIndex].tensor_layout();
858   auto query_tensor_map = query_layout.tensor_map_before();
859   auto query_batch_map = query_tensor_map.at(qkv_batch_dim_);
860   auto query_seq_map = query_tensor_map.at(qkv_seq_dim_);
861   auto query_head_map = query_tensor_map.at(qkv_head_dim_);
862   auto key_seq_map = key_layout.tensor_map_before().at(qkv_seq_dim_);
863 
864   auto dev_matrix_shape = dev_matrix_shape_;
865   if (input_layout_ == FASInputLayoutMode::TND) {
866     if (query_batch_map.size() == kSizeOne) {
867       dev_matrix_batch_dim_ = query_batch_map[0];
868       dev_matrix_s1_dim_ = MAP_NONE;
869     } else if (query_batch_map.size() == kSizeTwo) {
870       dev_matrix_batch_dim_ = query_batch_map[0];
871       dev_matrix_s1_dim_ = query_batch_map[1];
872     } else {
873       MS_LOG(ERROR) << name_
874                     << ": The seq-dimension of query can only be mapped upto 2 device matrix dimension, but got "
875                     << query_batch_map;
876       return FAILED;
877     }
878     n1_split_num_ = 1;
879     for (auto map_id : query_head_map) {
880       n1_split_num_ *= GetSplitNumByMapId(dev_matrix_shape, map_id);
881     }
882   } else {
883     if (query_batch_map.size() != 1 || query_seq_map.size() != 1 || query_head_map.size() != 1) {
884       MS_LOG(ERROR) << name_
885                     << ": Each dimension of query can only be mapped to one device matrix dimension, but got the "
886                        "tensor info of query is "
887                     << query_layout.ToString();
888       return FAILED;
889     }
890     dev_matrix_batch_dim_ = query_batch_map[0];
891     dev_matrix_s1_dim_ = query_seq_map[0];
892     dev_matrix_n1_dim_ = query_head_map[0];
893     n1_split_num_ = GetSplitNumByMapId(dev_matrix_shape, dev_matrix_n1_dim_);
894   }
895   batch_split_num_ = GetSplitNumByMapId(dev_matrix_shape, dev_matrix_batch_dim_);
896   s1_split_num_ = GetSplitNumByMapId(dev_matrix_shape, dev_matrix_s1_dim_);
897   if (s1_split_num_ > 1 && GetSplitNumByTensorMap(dev_matrix_shape, query_seq_map) !=
898                              GetSplitNumByTensorMap(dev_matrix_shape, key_seq_map) * s1_split_num_) {
899     MS_LOG(EXCEPTION) << name_ << ": Cannot split the seq-dimension of key. query_seq_slice: "
900                       << GetSplitNumByTensorMap(dev_matrix_shape, query_seq_map)
901                       << ", key_seq_slice: " << GetSplitNumByTensorMap(dev_matrix_shape, key_seq_map)
902                       << ", s1_split_num: " << s1_split_num_;
903   }
904   return SUCCESS;
905 }
906 
907 Status FlashAttentionScoreInfo::InferTensorMap() {
908   if (InitInputsTensorMap() != SUCCESS) {
909     return FAILED;
910   }
911   if (input_layout_ == FASInputLayoutMode::TND) {
912     outputs_tensor_map_.push_back({inputs_tensor_map_[0]});  // softmax_max
913     outputs_tensor_map_.push_back({inputs_tensor_map_[0]});  // softmax_sum
914   } else {
915     outputs_tensor_map_.push_back({dev_matrix_batch_dim_, dev_matrix_n1_dim_, dev_matrix_s1_dim_, -1});  // softmax_max
916     outputs_tensor_map_.push_back({dev_matrix_batch_dim_, dev_matrix_n1_dim_, dev_matrix_s1_dim_, -1});  // softmax_sum
917   }
918   outputs_tensor_map_.push_back({-1});                   // softmax_out
919   outputs_tensor_map_.push_back(inputs_tensor_map_[0]);  // attention_out
920   return SUCCESS;
921 }
922 
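// Returns {rank, target_rank, split_id, target_split_id}: split_id is this rank's position along the S1
// dimension and target_split_id is its mirror (s1_split_num_ - split_id - 1), presumably the peer used for
// sequence-dimension load balancing.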
923 std::vector<int64_t> FlashAttentionScoreInfo::GetSplitIdAndRank() {
924   CheckGlobalDeviceManager();
925   int64_t rank = g_device_manager->global_rank();
926   DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
927   RankList group_devices;
928   int64_t seq_dim = SizeToLong(dev_matrix_shape_.size()) - dev_matrix_s1_dim_ - 1;
929   if (dev_matrix.GetDevicesAlongDim(seq_dim, &group_devices) != SUCCESS) {
930     MS_LOG(ERROR) << name_ << " get group devices along dim " << seq_dim << " failed.";
931   }
932   auto iter = std::find(group_devices.begin(), group_devices.end(), rank);
933   if (iter == group_devices.end()) {
934     MS_LOG(EXCEPTION) << "FlashAttentionScore S1 sequence parallel get split id failed. "
935                       << "rank " << rank << " not in group " << group_devices;
936   }
937   int64_t split_id = iter - group_devices.begin();
938   int64_t target_split_id = s1_split_num_ - split_id - 1;
939   int64_t target_rank_id = group_devices[target_split_id];
940   return std::vector<int64_t>({rank, target_rank_id, split_id, target_split_id});
941 }
942 
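// Recomputes pre_tokens/next_tokens for one S1 slice. Depending on how the sparse mode's mask anchor moves
// (see opAttrUpdateMap), the band is shifted by the slice's offset along the query sequence so that the
// compressed attention mask stays aligned after the split.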
943 std::tuple<int64_t, int64_t> FlashAttentionScoreInfo::GetAttentionMaskAttrs(const int64_t split_id,
944                                                                             const int64_t split_num) {
945   int64_t kv_seq_length;
946   int64_t q_seq_length;
947   kv_seq_length = inputs_shape_[ops::kFlashAttentionScoreInputKeyIndex][qkv_seq_dim_];
948   q_seq_length = inputs_shape_[ops::kFlashAttentionScoreInputQueryIndex][qkv_seq_dim_];
949   int64_t q_len_each_split = q_seq_length / split_num;
950   int64_t new_pre_tokens =
951     (sparse_mode_ == ops::kSparseDefaultMask || sparse_mode_ == ops::kSparseBand) ? pre_tokens_ : kv_seq_length;
952   int64_t new_next_tokens =
953     (sparse_mode_ == ops::kSparseDefaultMask || sparse_mode_ == ops::kSparseBand) ? next_tokens_ : 0;
954   switch (opAttrUpdateMap.at(sparse_mode_)) {
955     case kLeftUpToLeftUp:
956       new_pre_tokens = LongAdd(new_pre_tokens, -split_id * q_len_each_split);
957       new_next_tokens = LongAdd(new_next_tokens, split_id * q_len_each_split);
958       break;
959     case kLeftUpToRightDown:
960       new_pre_tokens = LongAdd(new_pre_tokens, (kv_seq_length - (split_id + 1) * q_len_each_split));
961       new_next_tokens = LongAdd(new_next_tokens, -(kv_seq_length - (split_id + 1) * q_len_each_split));
962       break;
963     case kRightDownToRightDown:
964       new_pre_tokens = LongAdd(new_pre_tokens, (split_num - split_id - 1) * (q_seq_length / split_num));
965       new_next_tokens = LongAdd(new_next_tokens, -(split_num - split_id - 1) * (q_seq_length / split_num));
966       break;
967     default:
968       MS_LOG(EXCEPTION) << "Invalid sparse mode " << sparse_mode_ << ", sparse mode should be one of [0, 2, 3, 4].";
969   }
970   return std::make_tuple(new_pre_tokens, new_next_tokens);
971 }
972 
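// When the T dimension of query is split under the TND layout, rewrites actual_seq_qlen/actual_seq_kvlen so
// that every rank sees sequence lengths expressed relative to its own [offset, offset + slice_tq) window; the
// exact formulas are spelled out in the inline comments below.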
973 Status FlashAttentionScoreInfo::ReplaceActualSeqLenForSplitSeqInTnd(const CNodePtr &cnode) {
974   std::vector<int64_t> split_info = GetSplitIdAndRank();
975   int64_t tq = inputs_shape_[GetStrategyRealIndex(ops::kFlashAttentionScoreInputQueryIndex)][qkv_batch_dim_];
976   int64_t tk = inputs_shape_[GetStrategyRealIndex(ops::kFlashAttentionScoreInputKeyIndex)][qkv_batch_dim_];
977   int64_t slice_tq = tq / s1_split_num_ / batch_split_num_;
978   int64_t slice_tk = tk / batch_split_num_;
979   int64_t split_id = split_info[kIndex2];
980   int64_t offset = slice_tq * split_id;
981   if (!is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] ||
982       !is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
983     MS_LOG(ERROR) << name_ << ": The input 'actual_seq_qlen' and 'actual_seq_kvlen' cannot be None under 'TND'.";
984     return FAILED;
985   }
986   auto actual_seq_qlen_input_index = ops::kFlashAttentionScoreInputActualSeqQlenIndex + 1;
987   auto actual_seq_kvlen_input_index = ops::kFlashAttentionScoreInputActualSeqKVlenIndex + 1;
988   auto actual_seq_qlen_node = cnode->input(actual_seq_qlen_input_index);
989   auto actual_seq_kvlen_node = cnode->input(actual_seq_kvlen_input_index);
990 
991   auto func_graph = cnode->func_graph();
992   MS_EXCEPTION_IF_NULL(func_graph);
993   auto manager = func_graph->manager();
994   MS_EXCEPTION_IF_NULL(manager);
995 
996   // new_actual_seq_qlen = clip(actual_seq_qlen - offset, 0, slice_tq)
997   auto qlen_offset_sub_cnode =
998     func_graph->NewCNode({NewValueNode(prim::kPrimSub), actual_seq_qlen_node, CreateInt32Tensor(offset, true)});
999   auto new_actual_seq_qlen_cnode =
1000     func_graph->NewCNode({NewValueNode(prim::kPrimClipByValue), qlen_offset_sub_cnode, CreateInt32Tensor(0, true),
1001                           CreateInt32Tensor(slice_tq, true)});
1002   manager->SetEdge(cnode, actual_seq_qlen_input_index, new_actual_seq_qlen_cnode);
1003 
1004   // new_actual_seq_kvlen = actual_seq_kvlen - (ReLU(actual_seq_qlen - offset) - new_actual_seq_qlen)
1005   auto relu_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimReLU), qlen_offset_sub_cnode});
1006   auto kvlen_offset_sub_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimSub), actual_seq_qlen_node, relu_cnode});
1007   auto tmp_new_actual_seq_kvlen_cnode =
1008     func_graph->NewCNode({NewValueNode(prim::kPrimSub), actual_seq_kvlen_node, kvlen_offset_sub_cnode});
1009 
1010   // new_actual_seq_kvlen[actual_seq_kvlen == slice_tk] = slice_tk
1011   auto equal =
1012     func_graph->NewCNode({NewValueNode(prim::kPrimEqual), actual_seq_kvlen_node, CreateInt32Tensor(slice_tk, true)});
1013   auto new_actual_seq_kvlen_cnode = func_graph->NewCNode(
1014     {NewValueNode(prim::kPrimSelect), equal, actual_seq_kvlen_node, tmp_new_actual_seq_kvlen_cnode});
1015   manager->SetEdge(cnode, actual_seq_kvlen_input_index, new_actual_seq_kvlen_cnode);
1016 
1017   return SUCCESS;
1018 }
1019 
1020 void FlashAttentionScoreInfo::ReplaceNodeInputOrAttrs() {
1021   for (auto &cnode : cnodes_) {
1022     SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputHeadNumIndex + 1, head_num_ / n1_split_num_);
1023     if (s1_split_num_ > 1 && !enable_load_balance_ && need_update_op_attrs_mode_) {
1024       if (input_layout_ == FASInputLayoutMode::TND) {
1025         if (ReplaceActualSeqLenForSplitSeqInTnd(cnode) != SUCCESS) {
1026           MS_LOG(EXCEPTION) << name_ << ": Replace actual_seq_qlen and actual_seq_kvlen failed.";
1027         }
1028       } else {
1029         int64_t new_pre_tokens, new_next_tokens;
1030         std::vector<int64_t> split_info = GetSplitIdAndRank();
1031         int64_t split_id = split_info[kIndex2];
1032         std::tie(new_pre_tokens, new_next_tokens) = GetAttentionMaskAttrs(split_id, s1_split_num_);
1033         int64_t new_sparse_mode = is_attn_mask_compressed_ ? ops::kSparseBand : sparse_mode_;
1034         SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputSparseModeIndex + 1, new_sparse_mode);
1035         SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputPreTokensIndex + 1, new_pre_tokens);
1036         SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputNextTokensIndex + 1, new_next_tokens);
1037       }
1038     }
1039     // If DropoutGenMask -> Reshape -> FlashAttentionScore, update their inputs.
1040     auto reshape_node = cnode->input(ops::kFlashAttentionScoreInputDropMaskIndex + 1);
1041     MS_EXCEPTION_IF_NULL(reshape_node);
1042     if (!IsPrimitiveCNode(reshape_node, prim::kPrimReshape)) {
1043       continue;
1044     }
1045     auto reshape_cnode = reshape_node->cast<CNodePtr>();
1046     if (!IsPrimitiveCNode(reshape_cnode->input(kIndex1), prim::kPrimDropoutGenMask)) {
1047       continue;
1048     }
1049     auto dropout_gen_mask_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
1050     // Update slice_shape for Reshape
1051     Shape input_slice_shape = inputs_tensor_info_.at(ops::kFlashAttentionScoreInputDropMaskIndex).slice_shape();
1052     ValuePtr new_shape = MakeValue(input_slice_shape);
1053     AnfNodePtr val = NewValueNode(new_shape);
1054     auto manager = cnode->func_graph()->manager();
1055     MS_EXCEPTION_IF_NULL(manager);
1056     manager->SetEdge(reshape_cnode, kIndex2, val);
1057     // Update slice shape and seed for DropoutGenMask
1058     UpdateDropoutGenMaskSliceShapeAndSeed(dropout_gen_mask_cnode);
1059   }
1060 }
1061 
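// For load balancing, splits the given input (query, real_shift, drop_mask or attn_mask) into two chunks
// along its sequence dimension: *keep_node is the half computed locally and *exchange_node is the half
// sent to the paired rank. Inputs that are not passed (or a compressed attn_mask) are forwarded unchanged
// as virtual inputs.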
1062 void FlashAttentionScoreInfo::LoadBalanceSplitAlongSeqDim(size_t input_index, GenerateGraph *gen_g,
1063                                                           AnfNodePtr *split_node, AnfNodePtr *keep_node,
1064                                                           AnfNodePtr *exchange_node) {
1065   OperatorAttrs split_attrs;
1066   int64_t q_split_axis;
1067   switch (input_index) {
1068     case ops::kFlashAttentionScoreInputQueryIndex:
1069       q_split_axis = SizeToLong(qkv_seq_dim_);
1070       split_attrs = {std::make_pair(AXIS, MakeValue(q_split_axis)),
1071                      std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1072       *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1073       *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1074       *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1075       break;
1076     case ops::kFlashAttentionScoreInputRealShiftIndex:
1077       if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1078         split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(kInputRealShiftSeqDim)),
1079                        std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1080         *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1081         *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1082         *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1083       } else {
1084         *keep_node = gen_g->virtual_input_node();
1085         *exchange_node = gen_g->virtual_input_node();
1086       }
1087       break;
1088     case ops::kFlashAttentionScoreInputDropMaskIndex:
1089       if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1090         split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(kInputDropMaskSeqDim)),
1091                        std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1092         *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1093         *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1094         *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1095       } else {
1096         *keep_node = gen_g->virtual_input_node();
1097         *exchange_node = gen_g->virtual_input_node();
1098       }
1099       break;
1100     case ops::kFlashAttentionScoreInputAttnMaskIndex:
1101       if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] && !is_attn_mask_compressed_) {
1102         auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
1103         if (attn_mask_shape.size() == kSizeTwo) {
1104           split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(0)),
1105                          std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1106         } else {
1107           split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(2)),
1108                          std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1109         }
1110         *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1111         *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1112         *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1113       } else {
1114         *keep_node = gen_g->virtual_input_node();
1115         *exchange_node = gen_g->virtual_input_node();
1116       }
1117       break;
1118     default:
1119       MS_LOG(EXCEPTION) << "Invalid input index. Only 0(query), 3(real_shift), 4(drop_mask) and 6(attn_mask) "
1120                         << "support sequence dim parallel, but got " << input_index;
1121   }
1122 }
1123 
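// Swaps a tensor half with the paired rank: AllGather inside the two-rank group, then split the gathered
// result in two along axis 0 and keep the piece produced by the peer (selected by all_gather_idx).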
1124 void FlashAttentionScoreInfo::LoadBalanceExchange(const int64_t all_gather_idx, const Group &group,
1125                                                   const AnfNodePtr &input_node, AnfNodePtr *exchange_node,
1126                                                   GenerateGraph *gen_g) {
1127   OperatorAttrs all_gather_attrs = {std::make_pair(GROUP, MakeValue(group.name()))};
1128   OperatorAttrs all_gather_split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(0)),
1129                                           std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1130   auto all_gather_node = gen_g->PushBack({gen_g->NewOpInst(ALL_GATHER, all_gather_attrs), input_node});
1131   auto split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, all_gather_split_attrs), all_gather_node});
1132   *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), split_node, CreatInt64Imm(all_gather_idx)});
1133 }
1134 
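// Builds a FlashAttentionScore op instance for one local sequence chunk. When the op attrs need updating,
// pre_tokens/next_tokens are recomputed from the chunk's split_id and split_num, and a compressed attn_mask
// forces sparse_mode to kSparseBand.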
1135 void FlashAttentionScoreInfo::GetFlashAttentionScoreOpNode(int64_t split_id, int64_t split_num, const AnfNodePtr &q,
1136                                                            const AnfNodePtr &real_shift, const AnfNodePtr &drop_mask,
1137                                                            const AnfNodePtr &attn_mask, AnfNodePtr *fa_op,
1138                                                            GenerateGraph *gen_g) {
1139   int64_t new_sparse_mode = is_attn_mask_compressed_ ? ops::kSparseBand : sparse_mode_;
1140   int64_t new_pre_tokens, new_next_tokens;
1141   if (!need_update_op_attrs_mode_) {
1142     new_pre_tokens = pre_tokens_;
1143     new_next_tokens = next_tokens_;
1144   } else {
1145     std::tie(new_pre_tokens, new_next_tokens) = GetAttentionMaskAttrs(split_id, split_num);
1146   }
1147   OperatorAttrs fa_attrs = {std::make_pair(HEAD_NUM, MakeValue(head_num_ / n1_split_num_)),
1148                             std::make_pair(KEEP_PROB, MakeValue(keep_prob_)),
1149                             std::make_pair(SCALE_VALUE, MakeValue(scale_value_)),
1150                             std::make_pair(PRE_TOKENS, MakeValue(new_pre_tokens)),
1151                             std::make_pair(NEXT_TOKENS, MakeValue(new_next_tokens)),
1152                             std::make_pair(INNER_PRECISE, MakeValue<int64_t>(0)),
1153                             std::make_pair(INPUT_LAYOUT, MakeValue(input_layout_)),
1154                             std::make_pair(SPARSE_MODE, MakeValue<int64_t>(new_sparse_mode))};
1155   *fa_op = gen_g->PushBack({gen_g->NewOpInst(FLASH_ATTENTION_SCORE, fa_attrs), q, gen_g->virtual_input_node(),
1156                             gen_g->virtual_input_node(), real_shift, drop_mask, gen_g->virtual_input_node(), attn_mask,
1157                             gen_g->virtual_input_node(), gen_g->virtual_input_node(), gen_g->virtual_input_node()});
1158 }
1159 
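// Builds the (node, input_index) pairs that wire the original cnode inputs into the replace graph: query is
// fed to the sequence Split, while key/value and the remaining inputs are routed to both the "keep" and the
// "target" FlashAttentionScore ops; real_shift/drop_mask/attn_mask go through their own splits only when
// they are actually passed (and, for attn_mask, not compressed).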
1160 std::vector<std::pair<AnfNodePtr, int64_t>> FlashAttentionScoreInfo::ReplaceGraphGetInputNodes(
1161   const AnfNodePtr &q_split, const AnfNodePtr &real_shift_split, const AnfNodePtr &drop_mask_split,
1162   const AnfNodePtr &attn_mask_split, const AnfNodePtr &flash_attention_score_keep,
1163   const AnfNodePtr &flash_attention_score_target) {
1164   std::pair<AnfNodePtr, int64_t> real_shift_input;
1165   if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1166     real_shift_input = std::make_pair(real_shift_split, kIndex4);
1167   } else {
1168     real_shift_input = std::make_pair(flash_attention_score_keep, kIndex4);
1169   }
1170   std::pair<AnfNodePtr, int64_t> drop_mask_input;
1171   if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1172     drop_mask_input = std::make_pair(drop_mask_split, kIndex5);
1173   } else {
1174     drop_mask_input = std::make_pair(flash_attention_score_keep, kIndex5);
1175   }
1176   std::pair<AnfNodePtr, int64_t> attn_mask_input;
1177   if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] && !is_attn_mask_compressed_) {
1178     attn_mask_input = std::make_pair(attn_mask_split, kIndex7);
1179   } else {
1180     attn_mask_input = std::make_pair(flash_attention_score_keep, kIndex7);
1181   }
1182 
1183   std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes = {std::make_pair(q_split, kIndex1),
1184                                                               std::make_pair(flash_attention_score_keep, kIndex2),
1185                                                               std::make_pair(flash_attention_score_keep, kIndex3),
1186                                                               real_shift_input,
1187                                                               drop_mask_input,
1188                                                               std::make_pair(flash_attention_score_keep, kIndex6),
1189                                                               attn_mask_input,
1190                                                               std::make_pair(flash_attention_score_keep, kIndex8),
1191                                                               std::make_pair(flash_attention_score_keep, kIndex9),
1192                                                               std::make_pair(flash_attention_score_keep, kIndex10),
1193                                                               std::make_pair(flash_attention_score_target, kIndex2),
1194                                                               std::make_pair(flash_attention_score_target, kIndex3)};
1195   if (!is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1196     (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex4));
1197   }
1198   if (!is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1199     (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex5));
1200   }
1201   (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex6));
1202   if (!is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] || is_attn_mask_compressed_) {
1203     (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex7));
1204   }
1205   inputs_nodes.insert(inputs_nodes.end(), {std::make_pair(flash_attention_score_target, kIndex8),
1206                                            std::make_pair(flash_attention_score_target, kIndex9),
1207                                            std::make_pair(flash_attention_score_target, kIndex10)});
1208   return inputs_nodes;
1209 }
1210 
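// Builds the load-balance replace graph for S1 sequence parallelism: each rank splits its query (and the
// optional real_shift/drop_mask/attn_mask) in two, computes FlashAttentionScore on the kept half, exchanges
// the other half with its paired rank via AllGather over a two-rank group, computes a second
// FlashAttentionScore on the received half, and concatenates softmax_max/softmax_sum/attention_out of the
// two halves along the sequence dimension.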
1211 Status FlashAttentionScoreInfo::ComputeReplaceGraphForLoadBalance(const CNodePtr &cnode) {
1212   GenerateGraph gen_g = GenerateGraph(attrs_);
1213   if (gen_g.Init(cnode) != SUCCESS) {
1214     return FAILED;
1215   }
1216   CheckGlobalDeviceManager();
1217   std::vector<int64_t> split_info = GetSplitIdAndRank();
1218   int64_t rank_id = split_info[kIndex0];
1219   int64_t target_rank_id = split_info[kIndex1];
1220   int64_t split_id = split_info[kIndex2];
1221   int64_t target_split_id = split_info[kIndex3];
1222   Group group;
1223   RankList swap_group_devices = {rank_id, target_rank_id};
1224   if (g_device_manager->CreateGroup(swap_group_devices, &group) != SUCCESS) {
1225     MS_LOG(ERROR) << "Create communication group for " << swap_group_devices << " failed";
1226     return FAILED;
1227   }
1228 
1229   AnfNodePtr q_split, q_keep, q_exchange;
1230   LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputQueryIndex, &gen_g, &q_split, &q_keep, &q_exchange);
1231   AnfNodePtr real_shift_split, real_shift_keep, real_shift_exchange;
1232   LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputRealShiftIndex, &gen_g, &real_shift_split, &real_shift_keep,
1233                               &real_shift_exchange);
1234   AnfNodePtr drop_mask_split, drop_mask_keep, drop_mask_exchange;
1235   LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputDropMaskIndex, &gen_g, &drop_mask_split, &drop_mask_keep,
1236                               &drop_mask_exchange);
1237   AnfNodePtr attn_mask_split, attn_mask_keep, attn_mask_exchange;
1238   LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputAttnMaskIndex, &gen_g, &attn_mask_split, &attn_mask_keep,
1239                               &attn_mask_exchange);
1240 
1241   AnfNodePtr flash_attention_score_keep;
1242   GetFlashAttentionScoreOpNode(split_id * kLoadBalanceSplitNum, s1_split_num_ * kLoadBalanceSplitNum, q_keep,
1243                                real_shift_keep, drop_mask_keep, attn_mask_keep, &flash_attention_score_keep, &gen_g);
1244   auto softmax_max_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1245                                           CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxMaxIndex)});
1246   auto softmax_sum_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1247                                           CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxSumIndex)});
1248   auto softmax_out_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1249                                           CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxOutIndex)});
1250   auto attention_out_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1251                                             CreatInt64Imm(ops::kFlashAttentionScoreOutputAttentionOutIndex)});
1252 
1253   const int64_t all_gather_idx = (split_id < target_split_id) ? 1 : 0;
1254   AnfNodePtr q_target;
1255   LoadBalanceExchange(all_gather_idx, group, q_exchange, &q_target, &gen_g);
1256   AnfNodePtr real_shift_target;
1257   if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1258     LoadBalanceExchange(all_gather_idx, group, real_shift_exchange, &real_shift_target, &gen_g);
1259   } else {
1260     real_shift_target = gen_g.virtual_input_node();
1261   }
1262   AnfNodePtr drop_mask_target;
1263   if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1264     LoadBalanceExchange(all_gather_idx, group, drop_mask_exchange, &drop_mask_target, &gen_g);
1265   } else {
1266     drop_mask_target = gen_g.virtual_input_node();
1267   }
1268   AnfNodePtr attn_mask_target;
1269   if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] && !is_attn_mask_compressed_) {
1270     LoadBalanceExchange(all_gather_idx, group, attn_mask_exchange, &attn_mask_target, &gen_g);
1271   } else {
1272     attn_mask_target = gen_g.virtual_input_node();
1273   }
1274 
1275   AnfNodePtr flash_attention_score_target;
1276   GetFlashAttentionScoreOpNode(target_split_id * kLoadBalanceSplitNum + 1, s1_split_num_ * kLoadBalanceSplitNum,
1277                                q_target, real_shift_target, drop_mask_target, attn_mask_target,
1278                                &flash_attention_score_target, &gen_g);
1279   auto softmax_max_target = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_target,
1280                                             CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxMaxIndex)});
1281   auto softmax_sum_target = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_target,
1282                                             CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxSumIndex)});
1283   auto attention_out_target = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_target,
1284                                               CreatInt64Imm(ops::kFlashAttentionScoreOutputAttentionOutIndex)});
1285 
1286   AnfNodePtr attention_out_exchange;
1287   LoadBalanceExchange(all_gather_idx, group, attention_out_target, &attention_out_exchange, &gen_g);
1288 
1289   int64_t softmax_concat_axis = kOutputSoftmaxSeqDim;
1290   auto softmax_max_maketuple =
1291     gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), softmax_max_keep, softmax_max_target});
1292   auto softmax_max =
1293     gen_g.PushBack({gen_g.NewOpInst(CONCAT), softmax_max_maketuple, CreatInt64Imm(softmax_concat_axis)});
1294   auto softmax_sum_maketuple =
1295     gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), softmax_sum_keep, softmax_sum_target});
1296   auto softmax_sum =
1297     gen_g.PushBack({gen_g.NewOpInst(CONCAT), softmax_sum_maketuple, CreatInt64Imm(softmax_concat_axis)});
1298   int64_t attention_out_concat_axis = SizeToLong(qkv_seq_dim_);
1299   auto attention_out_maketuple =
1300     gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), attention_out_keep, attention_out_exchange});
1301   auto attention_out =
1302     gen_g.PushBack({gen_g.NewOpInst(CONCAT), attention_out_maketuple, CreatInt64Imm(attention_out_concat_axis)});
1303   auto output_maketuple =
1304     gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), softmax_max, softmax_sum, softmax_out_keep, attention_out});
1305 
1306   std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes =
1307     ReplaceGraphGetInputNodes(q_split, real_shift_split, drop_mask_split, attn_mask_split, flash_attention_score_keep,
1308                               flash_attention_score_target);
1309 
1310   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
1311     std::make_pair(inputs_nodes, output_maketuple));
1312   return SUCCESS;
1313 }
1314 
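// Builds the load-balance replace graph only when S1 is split and load balance is enabled; otherwise
// replace_graph_ is returned unchanged.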
1315 ReplaceGraphPtr FlashAttentionScoreInfo::replace_graph(const CNodePtr &cnode) {
1316   if (s1_split_num_ > 1 && enable_load_balance_) {
1317     if (ComputeReplaceGraphForLoadBalance(cnode) != SUCCESS) {
1318       MS_LOG(EXCEPTION) << name_
1319                         << ": FlashAttentionScore S1 sequence parallel with load balance get replace graph failed";
1320     }
1321   }
1322   return replace_graph_;
1323 }
1324 
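// The loss divisor is the number of devices holding identical copies of output[0], derived from the device
// matrix and the output tensor map.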
1325 Status FlashAttentionScoreInfo::InferAsLossDivisor() {
1326   if (outputs_tensor_map_.empty()) {
1327     MS_LOG(ERROR) << name_ << ": The size of outputs tensor map is empty";
1328     return FAILED;
1329   }
1330   as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
1331   MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
1332                << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
1333                << ", as_loss_divisor_ is " << as_loss_divisor_;
1334   return SUCCESS;
1335 }
1336 
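// Enumerates candidate sharding strategies: the splittable dimensions of each input are initialized first,
// then strategies are generated jointly for the dependent inputs.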
1337 std::vector<StrategyPtr> FlashAttentionScoreInfo::GenerateOpStrategies(int64_t stage_id) {
1338   InitSplittableInputs();
1339   std::vector<StrategyPtr> sp_vector;
1340   if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, splittable_inputs_, &sp_vector) != SUCCESS) {
1341     MS_LOG(EXCEPTION) << name_ << ": Generate strategies for dependent inputs failed.";
1342   }
1343   if (sp_vector.empty()) {
1344     MS_LOG(EXCEPTION) << name_ << ": No valid strategy.";
1345   }
1346   return sp_vector;
1347 }
1348 
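// For batch-dimension parallelism, every input of FlashAttentionScore is marked as splittable along its
// batch dimension.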
1349 void FlashAttentionScoreInfo::ReComputeBatchSplitFlagList() {
1350   split_flag_list_ = std::vector<bool>(inputs_shape_.size(), true);
1351 }
1352 
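// Infers mirror operators via the base class, then re-aligns them with the full input list: optional inputs
// that were not passed get an empty OperatorVector so that mirror_ops_ indices match the op's input order.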
1353 Status FlashAttentionScoreInfo::InferMirrorOps() {
1354   if (OperatorInfo::InferMirrorOps() != SUCCESS) {
1355     return FAILED;
1356   }
1357   // No need to insert mirror ops
1358   if (mirror_ops_.empty()) {
1359     return SUCCESS;
1360   }
1361   // Insert empty OperatorInfo for optional input
1362   size_t cur_index = 0;
1363   std::vector<OperatorVector> real_mirror_ops(input_value_.size(), OperatorVector());
1364   for (size_t i = 0; i < input_value_.size(); ++i) {
1365     if (is_input_passed_[i]) {
1366       real_mirror_ops[i] = mirror_ops_[cur_index++];
1367     }
1368   }
1369   mirror_ops_ = real_mirror_ops;
1370   return SUCCESS;
1371 }
1372 
1373 REGISTER(FlashAttentionScoreInfo);
1374 }  // namespace parallel
1375 }  // namespace mindspore
1376