1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/flash_attention_score_info.h"
18
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include <tuple>
23 #include <map>
24 #include <algorithm>
25
26 #include "ir/value.h"
27 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
28 #include "frontend/parallel/device_matrix.h"
29 #include "frontend/parallel/dynamic_creator.h"
30 #include "frontend/parallel/step_parallel_utils.h"
31 #include "frontend/parallel/graph_util/generate_graph.h"
32 #include "frontend/parallel/graph_util/graph_utils.h"
33 #include "mindspore/core/ops/nn_ops.h"
34 #include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"
35 #include "ops/op_enum.h"
36
37 namespace mindspore {
38 using mindspore::ops::FASInputLayoutMode;
39 namespace parallel {
40 namespace {
41 constexpr size_t kInputRealShiftSeqDim = 2;
42 constexpr size_t kInputDropMaskSeqDim = 2;
43 constexpr size_t kOutputSoftmaxSeqDim = 2;
44 constexpr int64_t kLoadBalanceSplitNum = 2;
45 enum OpAttrUpdateMode : int64_t {
46 kLeftUpToLeftUp = 0,
47 kLeftUpToRightDown = 1,
48 kRightDownToRightDown = 2,
49 };
50 const std::vector<int64_t> needCompressAttnMask = {ops::kSparseLeftUpCausal, ops::kSparseRightDownCausal,
51 ops::kSparseBand, ops::kSparseBlockLocal};
52 const std::map<int64_t, int64_t> opAttrUpdateMap = {{ops::kSparseDefaultMask, kLeftUpToLeftUp},
53 {ops::kSparseLeftUpCausal, kLeftUpToRightDown},
54 {ops::kSparseRightDownCausal, kRightDownToRightDown},
55 {ops::kSparseBand, kRightDownToRightDown},
56 {ops::kSparseBlockLocal, kLeftUpToRightDown}};
57
58 size_t GetNonMonadInputSize(const CNodePtr &cnode) {
59 size_t cnode_non_monad_size = cnode->size();
60 for (auto &input : cnode->inputs()) {
61 if (HasAbstractMonad(input)) {
62 cnode_non_monad_size--;
63 }
64 }
65 return cnode_non_monad_size;
66 }
67
68 int64_t NewSeedGeneration() {
69 static int64_t seed_generation = 0;
70 ++seed_generation;
71 return seed_generation;
72 }
73
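// Note: LongAdd is a saturating addition helper. The result is clamped to the 32-bit INT_MAX /
// INT_MIN limits rather than the int64 limits, presumably because pre_tokens / next_tokens use
// INT_MAX as an "unbounded" sentinel. For example, LongAdd(INT_MAX - 1, 5) returns INT_MAX
// instead of overflowing.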
74 int64_t LongAdd(int64_t base, int64_t shift) {
75 int64_t result;
76 if (shift > 0) {
77 if (base > INT_MAX - shift) {
78 result = INT_MAX;
79 } else {
80 result = base + shift;
81 }
82 } else {
83 if (base < INT_MIN - shift) {
84 result = INT_MIN;
85 } else {
86 result = base + shift;
87 }
88 }
89 return result;
90 }
91
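// Note: GetSplitNumByMapId translates a tensor-map id into the split number along the corresponding
// device-matrix axis (tensor maps index the device matrix from the right). For example, with
// dev_matrix = {2, 4, 8}: map_id 0 -> dev_matrix[2] = 8, map_id 2 -> dev_matrix[0] = 2, and
// MAP_NONE means the dimension is not split.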
92 int64_t GetSplitNumByMapId(const Shape &dev_matrix, int64_t map_id) {
93 if (map_id == MAP_NONE) {
94 return NO_SPLIT_STRATEGY;
95 }
96 auto axis = dev_matrix.size() - 1 - LongToSize(map_id);
97 if (axis >= dev_matrix.size()) {
98 MS_LOG(EXCEPTION) << "The tensor map id (" << map_id
99 << ") is out of device matrix's range. device_matrix: " << dev_matrix;
100 }
101 return dev_matrix[axis];
102 }
103
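// Note: GetSplitNumByTensorMap multiplies the split numbers of every map id in the tensor map,
// e.g. tensor_map = {2, 0} with dev_matrix = {2, 4, 8} yields 2 * 8 = 16.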
104 int64_t GetSplitNumByTensorMap(const Shape &dev_matrix, const Shape &tensor_map) {
105 auto split_num = std::accumulate(tensor_map.begin(), tensor_map.end(), int64_t(1), [&dev_matrix](int64_t a, int64_t map_id) {
106 return a * GetSplitNumByMapId(dev_matrix, map_id);
107 });
108 return split_num;
109 }
110 } // namespace
111
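// Note: after sharding, each rank must generate a different dropout mask. This routine biases the
// seeds of the upstream DropoutGenMask by the rank's position within the repeated-calculation group
// and rewrites its shape input to the local slice shape (the last dim is multiplied by 8 because the
// mask is bit-packed, 8 mask bits per byte).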
112 void FlashAttentionScoreInfo::UpdateDropoutGenMaskSliceShapeAndSeed(const CNodePtr &dropout_gen_mask_cnode) {
113 if (!IsPrimitiveCNode(dropout_gen_mask_cnode, prim::kPrimDropoutGenMask)) {
114 return;
115 }
116
117 // Update the seed according to rank_id for DropoutGenMask
118 PrimitivePtr prim = GetCNodePrimitive(dropout_gen_mask_cnode);
119 auto seed_0 = GetValue<int64_t>(prim->GetAttr(SEED0));
120 auto seed_1 = GetValue<int64_t>(prim->GetAttr(SEED1));
121 int64_t rank_id = g_device_manager->rank_index_in_stage();
122 int64_t seed_bias = 0;
123 // When seed0 and seed1 are both 0, ensure that the 0th card in each group has the same result
124 if (seed_0 == 0 && seed_1 == 0) {
125 seed_bias = NewSeedGeneration();
126 }
127 MS_EXCEPTION_IF_ZERO("repeated_calc_num_", repeated_calc_num_);
128 if (repeated_num_in_dev_matrix_right_) {
129 seed_bias += rank_id / repeated_calc_num_;
130 } else {
131 int64_t device_num = stage_device_size_;
132 MS_EXCEPTION_IF_ZERO("device_num", device_num);
133 seed_bias += rank_id % (device_num / repeated_calc_num_);
134 }
135 auto clone_prim = prim->Clone();
136 clone_prim->set_attr(SEED0, MakeValue<int64_t>(seed_0 + seed_bias));
137 clone_prim->set_attr(SEED1, MakeValue<int64_t>(seed_1 + seed_bias));
138 auto func_graph = dropout_gen_mask_cnode->func_graph();
139 MS_EXCEPTION_IF_NULL(func_graph);
140 auto manager = func_graph->manager();
141 MS_EXCEPTION_IF_NULL(manager);
142 manager->SetEdge(dropout_gen_mask_cnode, 0, NewValueNode(clone_prim)->cast<AnfNodePtr>());
143
144 // Update slice shape for DropoutGenMask and Reshape
145 Shape input_slice_shape = inputs_tensor_info_.at(ops::kFlashAttentionScoreInputDropMaskIndex).slice_shape();
146 constexpr int64_t BITS_NUM_PER_BYTE = 8;
147 input_slice_shape[input_slice_shape.size() - 1] *= BITS_NUM_PER_BYTE; // Restores the shape of DropoutGenMask input
148 size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode);
149 if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
150 MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
151 }
152 if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(kIndex1))) {
153 MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
154 }
155 ValuePtr new_shape = MakeValue(input_slice_shape);
156 AnfNodePtr val = NewValueNode(new_shape);
157 manager->SetEdge(dropout_gen_mask_cnode, kIndex1, val);
158 MS_LOG(DEBUG) << "The input slice shape dropout is " << ShapeToString(input_slice_shape);
159 }
160
161 void FlashAttentionScoreInfo::InitIsInputPassed() {
162 is_input_passed_.resize(input_value_.size());
163 for (size_t i = 0; i < input_value_.size(); ++i) {
164 is_input_passed_[i] = (input_value_[i] == nullptr || !input_value_[i]->isa<None>());
165 }
166 }
167
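// Note: several FlashAttentionScore inputs are optional, and inputs_shape_ only contains the inputs
// that were actually passed. GetStrategyRealIndex maps an absolute input index to its position among
// the passed inputs, e.g. with is_input_passed_ = {1, 1, 1, 0, 1}, index 4 maps to real index 3.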
168 size_t FlashAttentionScoreInfo::GetStrategyRealIndex(size_t index) {
169 if (index >= is_input_passed_.size() || !is_input_passed_[index]) {
170 MS_LOG(INTERNAL_EXCEPTION) << name_ << ": GetStrategyRealIndex failed, index is " << index;
171 }
172 auto real_index = -1;
173 for (size_t i = 0; i <= index; ++i) {
174 if (is_input_passed_[i]) {
175 ++real_index;
176 }
177 }
178 return real_index;
179 }
180
181 RankList FlashAttentionScoreInfo::GetSPRankList() {
182 CheckGlobalDeviceManager();
183 int64_t rank = g_device_manager->global_rank();
184 DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
185 RankList group_devices;
186 int64_t seq_dim = SizeToLong(dev_matrix_shape_.size()) - dev_matrix_s1_dim_ - 1;
187 if (dev_matrix.GetDevicesAlongDim(seq_dim, &group_devices) != SUCCESS) {
188 MS_LOG(ERROR) << name_ << " get group devices along dim " << seq_dim << " failed.";
189 }
190 return group_devices;
191 }
192
193 Status FlashAttentionScoreInfo::InitAttnMaskStrategies() {
194 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
195 auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
196 int64_t s1_split_num_attn_mask = is_attn_mask_compressed_ ? 1 : s1_split_num_;
197 int64_t s2_split_num_attn_mask = enable_ring_attention_ ? s1_split_num_attn_mask : 1;
198 if (attn_mask_shape.size() == kSizeTwo) {
199 // attn_mask_shape: (S1, S2)
200 expect_strategies_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {s1_split_num_attn_mask,
201 s2_split_num_attn_mask};
202 } else if (attn_mask_shape.size() == kSizeFour) {
203 // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
204 auto attn_mask_n1_split_num = attn_mask_have_n1_dim_ ? n1_split_num_ : 1;
205 auto attn_batch_split_num = attn_mask_have_batch_dim_ ? batch_split_num_ : 1;
206 expect_strategies_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_batch_split_num, attn_mask_n1_split_num,
207 s1_split_num_attn_mask, 1};
208 }
209 }
210 return SUCCESS;
211 }
212
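// Note: InitExpectedStrategies builds, for every passed input, the only strategy shape accepted for
// the current split numbers (B = batch_split_num_, S1/S2 = query/key sequence splits, N1/N2 =
// query/key head splits). CheckStrategy later compares the user strategy against this table.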
213 Status FlashAttentionScoreInfo::InitExpectedStrategies() {
214 expect_strategies_ = Strategies(ops::kFlashAttentionScoreInputsNum);
215 switch (input_layout_) {
216 case FASInputLayoutMode::BSH:
217 expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_, s1_split_num_, n1_split_num_};
218 expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, s2_split_num_, n2_split_num_};
219 expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, s2_split_num_, n2_split_num_};
220 break;
221 case FASInputLayoutMode::SBH:
222 expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {s1_split_num_, batch_split_num_, n1_split_num_};
223 expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {s2_split_num_, batch_split_num_, n2_split_num_};
224 expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {s2_split_num_, batch_split_num_, n2_split_num_};
225 break;
226 case FASInputLayoutMode::BNSD:
227 expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_, n1_split_num_, s1_split_num_,
228 1};
229 expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, n2_split_num_, s2_split_num_, 1};
230 expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, n2_split_num_, s2_split_num_,
231 1};
232 break;
233 case FASInputLayoutMode::BSND:
234 expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_, s1_split_num_, n1_split_num_,
235 1};
236 expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, s2_split_num_, n2_split_num_, 1};
237 expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, s2_split_num_, n2_split_num_,
238 1};
239 break;
240 case FASInputLayoutMode::TND:
241 expect_strategies_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_split_num_ * s1_split_num_, n1_split_num_,
242 1};
243 expect_strategies_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_split_num_, n2_split_num_, 1};
244 expect_strategies_[ops::kFlashAttentionScoreInputValueIndex] = {batch_split_num_, n2_split_num_, 1};
245 break;
246 default:
247 MS_LOG(ERROR) << name_ << ": Unsupported layout: " << input_layout_;
248 return FAILED;
249 }
250
251 if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
252 int64_t real_shift_s1_split_num = real_shift_have_s1_dim_ ? s1_split_num_ : 1;
253 auto real_shift_batch_split_num = real_shift_have_batch_dim_ ? batch_split_num_ : 1;
254 expect_strategies_[ops::kFlashAttentionScoreInputRealShiftIndex] = {real_shift_batch_split_num, n1_split_num_,
255 real_shift_s1_split_num, 1};
256 }
257 if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
258 expect_strategies_[ops::kFlashAttentionScoreInputDropMaskIndex] = {batch_split_num_, n1_split_num_, s1_split_num_,
259 1};
260 }
261 if (is_input_passed_[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
262 expect_strategies_[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
263 }
264 InitAttnMaskStrategies();
265
266 // padding_mask is not supported yet, skip it.
267
268 if (is_input_passed_[ops::kFlashAttentionScoreInputPrefixIndex]) {
269 expect_strategies_[ops::kFlashAttentionScoreInputPrefixIndex] = {batch_split_num_};
270 }
271 if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
272 expect_strategies_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {batch_split_num_};
273 }
274 if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
275 expect_strategies_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {batch_split_num_};
276 }
277 expect_strategies_.erase(std::remove(expect_strategies_.begin(), expect_strategies_.end(), Shape{}),
278 expect_strategies_.end());
279 return SUCCESS;
280 }
281
282 Status FlashAttentionScoreInfo::InitQKVTensorMap() {
283 int64_t kv_head_num_map = kv_split_ ? dev_matrix_n1_dim_ : -1;
284 auto dev_matrix_s2_dim = enable_ring_attention_ ? dev_matrix_s1_dim_ : -1;
285 switch (input_layout_) {
286 case FASInputLayoutMode::BSH:
287 inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_s1_dim_,
288 dev_matrix_n1_dim_};
289 inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
290 kv_head_num_map};
291 inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
292 kv_head_num_map};
293 break;
294 case FASInputLayoutMode::SBH:
295 inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_s1_dim_, dev_matrix_batch_dim_,
296 dev_matrix_n1_dim_};
297 inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_s2_dim, dev_matrix_batch_dim_,
298 kv_head_num_map};
299 inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_s2_dim, dev_matrix_batch_dim_,
300 kv_head_num_map};
301 break;
302 case FASInputLayoutMode::BNSD:
303 inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_n1_dim_,
304 dev_matrix_s1_dim_, -1};
305 inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, kv_head_num_map,
306 dev_matrix_s2_dim, -1};
307 inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, kv_head_num_map,
308 dev_matrix_s2_dim, -1};
309 break;
310 case FASInputLayoutMode::BSND:
311 inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_s1_dim_,
312 dev_matrix_n1_dim_, -1};
313 inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
314 kv_head_num_map, -1};
315 inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, dev_matrix_s2_dim,
316 kv_head_num_map, -1};
317 break;
318 case FASInputLayoutMode::TND:
319 inputs_tensor_map_[ops::kFlashAttentionScoreInputQueryIndex] = {dev_matrix_batch_dim_, dev_matrix_n1_dim_, -1};
320 inputs_tensor_map_[ops::kFlashAttentionScoreInputKeyIndex] = {dev_matrix_batch_dim_, kv_head_num_map, -1};
321 inputs_tensor_map_[ops::kFlashAttentionScoreInputValueIndex] = {dev_matrix_batch_dim_, kv_head_num_map, -1};
322 break;
323 default:
324 MS_LOG(ERROR) << name_ << ": Unsupported layout: " << input_layout_;
325 return FAILED;
326 }
327 return SUCCESS;
328 }
329
330 Status FlashAttentionScoreInfo::InitInputsTensorMap() {
331 inputs_tensor_map_ = std::vector<Shape>(ops::kFlashAttentionScoreInputsNum);
332 if (InitQKVTensorMap() != SUCCESS) {
333 return FAILED;
334 }
335
336 if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
337 auto real_shift_s1_map = real_shift_have_s1_dim_ ? dev_matrix_s1_dim_ : -1;
338 auto real_shift_batch_map = real_shift_have_batch_dim_ ? dev_matrix_batch_dim_ : -1;
339 inputs_tensor_map_[ops::kFlashAttentionScoreInputRealShiftIndex] = {real_shift_batch_map, dev_matrix_n1_dim_,
340 real_shift_s1_map, -1};
341 }
342 if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
343 inputs_tensor_map_[ops::kFlashAttentionScoreInputDropMaskIndex] = {dev_matrix_batch_dim_, dev_matrix_n1_dim_,
344 dev_matrix_s1_dim_, -1};
345 }
346 if (is_input_passed_[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
347 inputs_tensor_map_[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
348 }
349 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
350 auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
351 int64_t dev_matrix_s1_dim_attn_mask = is_attn_mask_compressed_ ? -1 : dev_matrix_s1_dim_;
352 if (attn_mask_shape.size() == kSizeTwo) {
353 // attn_mask_shape: (S1, S2)
354 inputs_tensor_map_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {dev_matrix_s1_dim_attn_mask, -1};
355 } else if (attn_mask_shape.size() == kSizeFour) {
356 // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
357 auto attn_mask_batch_map = attn_mask_have_batch_dim_ ? dev_matrix_batch_dim_ : -1;
358 auto attn_mask_n1_map = attn_mask_have_n1_dim_ ? dev_matrix_n1_dim_ : -1;
359 inputs_tensor_map_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_mask_batch_map, attn_mask_n1_map,
360 dev_matrix_s1_dim_attn_mask, -1};
361 }
362 }
363 if (is_input_passed_[ops::kFlashAttentionScoreInputPrefixIndex]) {
364 inputs_tensor_map_[ops::kFlashAttentionScoreInputPrefixIndex] = {dev_matrix_batch_dim_};
365 }
366 if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
367 inputs_tensor_map_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {dev_matrix_batch_dim_};
368 }
369 if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
370 inputs_tensor_map_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {dev_matrix_batch_dim_};
371 }
372 inputs_tensor_map_.erase(std::remove(inputs_tensor_map_.begin(), inputs_tensor_map_.end(), Shape{}),
373 inputs_tensor_map_.end());
374 return SUCCESS;
375 }
376
377 Status FlashAttentionScoreInfo::InitAttnMaskSplittableInputs() {
378 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
379 int64_t s1_group = 2;
380 auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
381 int64_t attn_s1_group = is_attn_mask_compressed_ ? 0 : s1_group;
382 int64_t attn_s2_group = enable_ring_attention_ ? attn_s1_group : 0;
383 if (attn_mask_shape.size() == kSizeTwo) {
384 // attn_mask_shape: (S1, S2)
385 splittable_inputs_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_s1_group, attn_s2_group};
386 } else if (attn_mask_shape.size() == kSizeFour) {
387 int64_t n1_group = 1;
388 int64_t batch_group = 3;
389 // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
390 auto attn_mask_n1_group = attn_mask_shape[kIndex1] == 1 ? 0 : n1_group;
391 splittable_inputs_[ops::kFlashAttentionScoreInputAttnMaskIndex] = {batch_group, attn_mask_n1_group, attn_s1_group,
392 0};
393 }
394 }
395 return SUCCESS;
396 }
397
398 Status FlashAttentionScoreInfo::InitSplittableInputs() {
399 splittable_inputs_ = std::vector<Shape>(ops::kFlashAttentionScoreInputsNum);
400 int64_t batch_group = 3;
401 int64_t s1_group = 2;
402 int64_t n1_group = 1;
403 int64_t n2_group = kv_split_ ? n1_group : 0;
404 int64_t s2_group = enable_ring_attention_ ? s1_group : 0;
405 switch (input_layout_) {
406 case FASInputLayoutMode::BSH:
407 splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, s1_group, n1_group};
408 splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, s2_group, n2_group};
409 splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, s2_group, n2_group};
410 break;
411 case FASInputLayoutMode::SBH:
412 splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {s1_group, batch_group, n1_group};
413 splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {s2_group, batch_group, n2_group};
414 splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {s2_group, batch_group, n2_group};
415 break;
416 case FASInputLayoutMode::BNSD:
417 splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, n1_group, s1_group, 0};
418 splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, n2_group, s2_group, 0};
419 splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, n2_group, s2_group, 0};
420 break;
421 case FASInputLayoutMode::BSND:
422 splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, s1_group, n1_group, 0};
423 splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, s2_group, n2_group, 0};
424 splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, s2_group, n2_group, 0};
425 break;
426 case FASInputLayoutMode::TND:
427 splittable_inputs_[ops::kFlashAttentionScoreInputQueryIndex] = {batch_group, n1_group, 0};
428 splittable_inputs_[ops::kFlashAttentionScoreInputKeyIndex] = {batch_group, n2_group, 0};
429 splittable_inputs_[ops::kFlashAttentionScoreInputValueIndex] = {batch_group, n2_group, 0};
430 break;
431 default:
432 MS_LOG(ERROR) << name_ << ": Unsupported layout: " << input_layout_;
433 return FAILED;
434 }
435
436 if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
437 auto real_shift_s1_group = real_shift_have_s1_dim_ ? s1_group : 0;
438 splittable_inputs_[ops::kFlashAttentionScoreInputRealShiftIndex] = {batch_group, n1_group, real_shift_s1_group, 0};
439 }
440 if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
441 splittable_inputs_[ops::kFlashAttentionScoreInputDropMaskIndex] = {batch_group, n1_group, s1_group, 0};
442 }
443 if (is_input_passed_[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
444 splittable_inputs_[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
445 }
446 InitAttnMaskSplittableInputs();
447 if (is_input_passed_[ops::kFlashAttentionScoreInputPrefixIndex]) {
448 splittable_inputs_[ops::kFlashAttentionScoreInputPrefixIndex] = {batch_group};
449 }
450 if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
451 splittable_inputs_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {batch_group};
452 }
453 if (is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
454 splittable_inputs_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {batch_group};
455 }
456 splittable_inputs_.erase(std::remove(splittable_inputs_.begin(), splittable_inputs_.end(), Shape{}),
457 splittable_inputs_.end());
458 return SUCCESS;
459 }
460
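// Note: this maps the input_layout enum to the axis positions of batch, sequence and head inside
// query/key/value, e.g. BNSD -> (batch = 0, head = 1, seq = 2). For TND, batch and sequence are
// packed into the same leading T axis, so qkv_batch_dim_ and qkv_seq_dim_ both point to axis 0.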
461 Status FlashAttentionScoreInfo::InitQKVHeadAndSeqDimFromInputLayout() {
462 switch (input_layout_) {
463 case FASInputLayoutMode::BSH:
464 qkv_batch_dim_ = kSizeZero;
465 qkv_seq_dim_ = kSizeOne;
466 qkv_head_dim_ = kSizeTwo;
467 break;
468 case FASInputLayoutMode::SBH:
469 qkv_seq_dim_ = kSizeZero;
470 qkv_batch_dim_ = kSizeOne;
471 qkv_head_dim_ = kSizeTwo;
472 break;
473 case FASInputLayoutMode::BNSD:
474 qkv_batch_dim_ = kSizeZero;
475 qkv_head_dim_ = kSizeOne;
476 qkv_seq_dim_ = kSizeTwo;
477 break;
478 case FASInputLayoutMode::BSND:
479 qkv_batch_dim_ = kSizeZero;
480 qkv_seq_dim_ = kSizeOne;
481 qkv_head_dim_ = kSizeTwo;
482 break;
483 case FASInputLayoutMode::TND:
484 qkv_batch_dim_ = kSizeZero;
485 qkv_seq_dim_ = kSizeZero;
486 qkv_head_dim_ = kSizeOne;
487 break;
488 default:
489 MS_LOG(ERROR) << name_ << ": The layout is not supported in parallel currently.";
490 return FAILED;
491 }
492 return SUCCESS;
493 }
494
495 Status FlashAttentionScoreInfo::InitAttrs() { return GetAttrs(); }
496
497 Status FlashAttentionScoreInfo::CheckInputLayout() {
498 if (InferSplitNumAndDevMatrixShapeByLayout() != SUCCESS) {
499 MS_LOG(ERROR) << name_ << ": Infer device matrix shape by layout failed.";
500 return FAILED;
501 }
502
503 auto query_shape = inputs_shape_[ops::kFlashAttentionScoreInputQueryIndex];
504 auto key_shape = inputs_shape_[ops::kFlashAttentionScoreInputKeyIndex];
505 if (s1_split_num_ > 1 && input_layout_ == FASInputLayoutMode::TND &&
506 (sparse_mode_ != ops::kSparseRightDownCausal || query_shape[0] != key_shape[0])) {
507 MS_LOG(ERROR)
508 << name_
509 << ": When input_layout is TND, sparse_mode is 3, and the T-dimension of query and key are the same, the "
510 "T-dimension of query can be sliced. query_shape: "
511 << query_shape << ", key_shape: " << key_shape << ", sparse_mode: " << sparse_mode_;
512 return FAILED;
513 }
514
515 // Check all device matrix should be the same
516 if (ops::kFlashAttentionScoreInputQueryIndex >= inputs_tensor_info_.size()) {
517 return FAILED;
518 }
519 auto query_tensor_info = inputs_tensor_info_[GetStrategyRealIndex(ops::kFlashAttentionScoreInputQueryIndex)];
520 dev_matrix_shape_ = query_tensor_info.tensor_layout().device_arrangement_origin().array();
521 return SUCCESS;
522 }
523
524 Status FlashAttentionScoreInfo::CheckOutputLayout() { return SUCCESS; }
525
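// Note: the softmax_max / softmax_sum outputs are laid out as (B, N1, S1, 8), so their layout is
// rebuilt from the query layout's batch, head and sequence maps plus an unsplit trailing axis of
// size 8; softmax_out stays unsplit and attention_out reuses the query layout directly.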
526 Status FlashAttentionScoreInfo::InferOutputLayout() {
527 auto query_layout = inputs_tensor_info_[ops::kFlashAttentionScoreInputQueryIndex].tensor_layout();
528
529 // Construct layout for softmax_max and softmax_sum
530 std::vector<Shape> softmax_max_sum_tensor_map;
531 Shape softmax_max_sum_tensor_shape;
532 if (input_layout_ == FASInputLayoutMode::TND) {
533 softmax_max_tensor_layout_ = query_layout;
534 softmax_sum_tensor_layout_ = query_layout;
535 } else {
536 softmax_max_sum_tensor_map.push_back(query_layout.tensor_map_before()[qkv_batch_dim_]); // B
537 softmax_max_sum_tensor_shape.push_back(query_layout.tensor_shape_before().array()[qkv_batch_dim_]); // B
538 softmax_max_sum_tensor_map.push_back(query_layout.tensor_map_before()[qkv_head_dim_]); // N
539 softmax_max_sum_tensor_shape.push_back(head_num_); // N
540 softmax_max_sum_tensor_map.push_back(query_layout.tensor_map_before()[qkv_seq_dim_]); // S
541 softmax_max_sum_tensor_shape.push_back(query_layout.tensor_shape_before().array()[qkv_seq_dim_]); // S
542 softmax_max_sum_tensor_map.push_back({MAP_NONE}); // 8
543 softmax_max_sum_tensor_shape.push_back(8); // 8
544 softmax_max_tensor_layout_.InitFromExtendVector(query_layout.device_arrangement_origin().array(),
545 softmax_max_sum_tensor_map,
546 outputs_shape()[ops::kFlashAttentionScoreOutputSoftmaxMaxIndex]);
547 softmax_sum_tensor_layout_.InitFromExtendVector(query_layout.device_arrangement_origin().array(),
548 softmax_max_sum_tensor_map,
549 outputs_shape()[ops::kFlashAttentionScoreOutputSoftmaxSumIndex]);
550 }
551
552 // Construct layout for softmax_out
553 softmax_out_tensor_layout_.InitFromExtendVector(query_layout.device_arrangement_origin().array(),
554 std::vector<Shape>{{MAP_NONE}},
555 outputs_shape()[ops::kFlashAttentionScoreOutputSoftmaxOutIndex]);
556 attention_out_tensor_layout_ = query_layout;
557 return SUCCESS;
558 }
559
560 Status FlashAttentionScoreInfo::InferOutputTensorInfo() {
561 auto status = InferOutputLayout();
562 if (status != SUCCESS) {
563 return status;
564 }
565 (void)outputs_tensor_info_.emplace_back(TensorInfo(softmax_max_tensor_layout_));
566 (void)outputs_tensor_info_.emplace_back(TensorInfo(softmax_sum_tensor_layout_));
567 (void)outputs_tensor_info_.emplace_back(TensorInfo(softmax_out_tensor_layout_));
568 (void)outputs_tensor_info_.emplace_back(TensorInfo(attention_out_tensor_layout_));
569 return SUCCESS;
570 }
571
572 Status FlashAttentionScoreInfo::InferAsLossDivisorByLayout() {
573 if (outputs_tensor_info_.size() != ops::kFlashAttentionScoreOutputsNum) {
574 MS_LOG(ERROR)
575 << name_
576 << ": The size of outputs tensor info must be equal to the size of FlashAttentionScore's output size, but got "
577 << outputs_tensor_info_.size() << " and " << ops::kFlashAttentionScoreOutputsNum;
578 return FAILED;
579 }
580
581 auto attention_out_tensor_info = outputs_tensor_info_[ops::kFlashAttentionScoreOutputAttentionOutIndex];
582 TensorMaps attention_out_tensor_map = attention_out_tensor_info.tensor_layout().tensor_map_before();
583 if (attention_out_tensor_map.empty()) {
584 as_loss_divisor_ = stage_device_size_;
585 MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << " as loss divisor.";
586 return SUCCESS;
587 }
588
589 auto out_dev_matrix_shape = attention_out_tensor_info.tensor_layout().device_arrangement_origin().array();
590 if (out_dev_matrix_shape.empty()) {
591 MS_LOG(INFO) << name_ << ": out_dev_matrix_shape is empty";
592 out_dev_matrix_shape = dev_matrix_shape_;
593 }
594 Shape squashed_tensor_map;
595 for (const auto &tensor_map : attention_out_tensor_map) {
596 std::copy(tensor_map.begin(), tensor_map.end(), std::back_inserter(squashed_tensor_map));
597 }
598
599 as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape, squashed_tensor_map);
600 MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape)
601 << ", the output tensor map is " << ShapeToString(squashed_tensor_map) << ", loss divisor is "
602 << as_loss_divisor_;
603 return SUCCESS;
604 }
605
606 Status FlashAttentionScoreInfo::InferMirrorOpsByLayout() {
607 mirror_ops_.clear();
608 if (inputs_shape_.empty()) {
609 MS_LOG(INFO) << name_ << ": The inputs size is empty";
610 return SUCCESS;
611 }
612
613 bool group_is_empty = true;
614 for (size_t i = 0; i < inputs_tensor_info_.size(); ++i) {
615 if (inputs_tensor_info_[i] == TensorInfo()) {
616 (void)mirror_ops_.emplace_back(OperatorVector());
617 continue;
618 }
619 auto input_tensor_layout = inputs_tensor_info_[i].tensor_layout();
620 auto repeated_rank_list = input_tensor_layout.InferRepeatedGroup();
621
622 OperatorVector mirror_op;
623 if (repeated_rank_list.size() == 1) {
624 MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
625 mirror_ops_.push_back(mirror_op);
626 continue;
627 }
628 if (is_auto_parallel_) {
629 if (g_device_manager->CheckDeviceList(repeated_rank_list) != SUCCESS) {
630 MS_LOG(INFO) << name_ << ": Try to create communication group : " << repeated_rank_list
631 << " failed in auto parallel mode, "
632 "this error can be ignored in parallel strategies searching step";
633 return FAILED;
634 }
635 return SUCCESS;
636 }
637
638 Group mirror_group;
639 if (g_device_manager->CreateGroup(repeated_rank_list, &mirror_group) != SUCCESS) {
640 MS_LOG(ERROR) << name_
641 << ": Create communication group by tensor_map failed, the rank_list is: " << repeated_rank_list
642 << ", the full_name of node is: " << cnode_->fullname_with_scope();
643 return FAILED;
644 }
645 group_is_empty = false;
646 mirror_op = CreateMirrorOps(mirror_group.name(), mirror_group.GetDevNum());
647 mirror_ops_.push_back(mirror_op);
648 }
649
650 if (group_is_empty) {
651 mirror_ops_.clear();
652 MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
653 }
654 return SUCCESS;
655 }
656
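// Note: for this operator the scalar arguments (head_num, keep_prob, scale_value, pre_tokens,
// next_tokens, input_layout, sparse_mode) are operands of the CNode rather than primitive
// attributes, so GetAttrs reads them from the CNode inputs. It also derives kv_split_, i.e. whether
// key/value carry more than one head so that their head dimension may be split (false under MQA).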
657 Status FlashAttentionScoreInfo::GetAttrs() {
658 InitIsInputPassed();
659 head_num_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputHeadNumIndex + 1);
660 keep_prob_ = GetInputValueFromCNode<float>(cnode_, ops::kFlashAttentionScoreInputKeepProbIndex + 1);
661 scale_value_ = GetInputValueFromCNode<float>(cnode_, ops::kFlashAttentionScoreInputScaleValueIndex + 1);
662 pre_tokens_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputPreTokensIndex + 1);
663 next_tokens_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputNextTokensIndex + 1);
664 input_layout_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputLayoutIndex + 1);
665 sparse_mode_ = GetInputValueFromCNode<int64_t>(cnode_, ops::kFlashAttentionScoreInputSparseModeIndex + 1);
666 auto ms_context = MsContext::GetInstance();
667 MS_EXCEPTION_IF_NULL(ms_context);
668 enable_load_balance_ = ms_context->get_param<bool>(MS_CTX_ENABLE_FLASH_ATTENTION_LOAD_BALANCE);
669
670 if (input_layout_ == FASInputLayoutMode::TND && enable_load_balance_) {
671 MS_LOG(WARNING) << name_ << ": Load balancing is not supported in the layout 'TND' and will be disabled.";
672 enable_load_balance_ = false;
673 }
674
675 auto enable_ring_attention_iter = attrs_.find(ENABLE_RING_ATTENTION);
676 if (enable_ring_attention_iter != attrs_.end()) {
677 MS_EXCEPTION_IF_NULL(enable_ring_attention_iter->second);
678 if (enable_ring_attention_iter->second->isa<BoolImm>()) {
679 enable_ring_attention_ = enable_ring_attention_iter->second->cast<BoolImmPtr>()->value();
680 enable_load_balance_ = false;
681 MS_LOG(DEBUG) << "enable_ring_attention_: " << enable_ring_attention_;
682 } else {
683 MS_LOG(ERROR) << "enable_ring_attention should be bool";
684 }
685 }
686 if (enable_ring_attention_) {
687 if (input_layout_ != FASInputLayoutMode::BSH && input_layout_ != FASInputLayoutMode::BNSD) {
688 MS_LOG(ERROR) << "Ring attention currently only supports BSH and BNSD layout";
689 }
690 if (sparse_mode_ != 0) {
691 MS_LOG(ERROR) << "Ring attention currently only supports sparse mode 0";
692 }
693 if (keep_prob_ != 1.0) {
694 MS_LOG(ERROR) << "Ring attention currently only supports keep prob 1.0";
695 }
696 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
697 MS_LOG(ERROR) << "Ring attention does not need the attn_mask input";
698 }
699 }
700
701 is_attn_mask_compressed_ =
702 std::find(needCompressAttnMask.begin(), needCompressAttnMask.end(), sparse_mode_) != needCompressAttnMask.end();
703 need_update_op_attrs_mode_ = sparse_mode_ != ops::kSparseAllMask;
704 if (InitQKVHeadAndSeqDimFromInputLayout() != Status::SUCCESS) {
705 return FAILED;
706 }
707
708 kv_split_ = inputs_shape_[ops::kFlashAttentionScoreInputQueryIndex][qkv_head_dim_] !=
709 inputs_shape_[ops::kFlashAttentionScoreInputKeyIndex][qkv_head_dim_] * head_num_;
710
711 if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
712 auto real_shift_s1_dim =
713 inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputRealShiftIndex)).at(kIndex3);
714 real_shift_have_s1_dim_ = real_shift_s1_dim > 1;
715 auto real_shift_batch_dim =
716 inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputRealShiftIndex)).at(kIndex0);
717 real_shift_have_batch_dim_ = real_shift_batch_dim > 1;
718 }
719
720 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
721 auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
722 if (attn_mask_shape.size() == kSizeFour) {
723 attn_mask_have_batch_dim_ = attn_mask_shape.at(kIndex0) > 1;
724 attn_mask_have_n1_dim_ = attn_mask_shape.at(kIndex1) > 1;
725 }
726 }
727 return SUCCESS;
728 }
729
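// Note: CheckStrategy enforces, among others, that key and value share one strategy, that head_num
// is divisible by the N1 split, that the key sequence dim is only split for ring attention, and that
// the N1/N2 splits match when key/value have more than one head. As an illustrative example (not
// from the source), a BSH strategy could be query (dp, sp, mp) with key/value (dp, 1, mp).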
730 Status FlashAttentionScoreInfo::CheckStrategy(const StrategyPtr &strategy) {
731 if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
732 return FAILED;
733 }
734 auto strategies = strategy->GetInputDim();
735 auto query_strategy = strategies[ops::kFlashAttentionScoreInputQueryIndex];
736 auto key_strategy = strategies[ops::kFlashAttentionScoreInputKeyIndex];
737 auto value_strategy = strategies[ops::kFlashAttentionScoreInputValueIndex];
738 if (key_strategy != value_strategy) {
739 MS_LOG(ERROR) << name_ << ": The in_strategy of 'key' (" << key_strategy << ") and 'value' (" << value_strategy
740 << ") must be the same.";
741 return FAILED;
742 }
743 if (head_num_ % query_strategy[qkv_head_dim_] != 0) {
744 MS_LOG(ERROR) << name_ << ": head_num % query_strategy[" << qkv_head_dim_ << "] must be 0, but got " << head_num_
745 << "(head_num) and " << query_strategy[qkv_head_dim_] << "(query_strategy[" << qkv_head_dim_ << "])";
746 return FAILED;
747 }
748 if (!kv_split_ && key_strategy[qkv_head_dim_] != 1) {
749 MS_LOG(ERROR) << name_ << ": Under MQA, the hidden-dim of input 'key' cannot be split.";
750 return FAILED;
751 }
752
753 if (input_layout_ == FASInputLayoutMode::TND) {
754 if (query_strategy[qkv_seq_dim_] != key_strategy[qkv_seq_dim_]) {
755 MS_LOG(ERROR)
756 << name_ << ": The split num of seq-dim between query and key must be the same when layout is 'TND'. But got "
757 << query_strategy[qkv_seq_dim_] << " and " << key_strategy[qkv_seq_dim_];
758 return FAILED;
759 }
760 } else {
761 auto s2_split_num = key_strategy[qkv_seq_dim_];
762 if (s2_split_num != 1 && !enable_ring_attention_) {
763 MS_LOG(ERROR) << name_ << ": The S-Dimension of input 'key' cannot be split, but got the strategy of key is "
764 << key_strategy;
765 return FAILED;
766 }
767 }
768
769 if (input_layout_ == FASInputLayoutMode::TND) {
770 batch_split_num_ = key_strategy[qkv_batch_dim_];
771 s1_split_num_ = query_strategy[qkv_batch_dim_] / batch_split_num_;
772 } else {
773 batch_split_num_ = query_strategy[qkv_batch_dim_];
774 s1_split_num_ = query_strategy[qkv_seq_dim_];
775 }
776 n1_split_num_ = query_strategy[qkv_head_dim_];
777
778 s2_split_num_ = enable_ring_attention_ ? s1_split_num_ : 1;
779
780 n2_split_num_ = key_strategy[qkv_head_dim_];
781
782 if (kv_split_ && n1_split_num_ != n2_split_num_) {
783 MS_LOG(ERROR) << name_ << ": The split num of N1-dim and N2-dim must be equal if N2 > 1, but got " << n1_split_num_
784 << " and " << n2_split_num_;
785 return FAILED;
786 }
787
788 if (s1_split_num_ > 1 && input_layout_ == FASInputLayoutMode::TND) {
789 MS_LOG(ERROR)
790 << name_
791 << ": Currently, input_layout is TND, and the seq dimension of query is segmented. Please use Layout to "
792 "set the strategy.";
793 return FAILED;
794 }
795
796 if (InitExpectedStrategies() != SUCCESS) {
797 return FAILED;
798 }
799 if (strategies != expect_strategies_) {
800 MS_LOG(ERROR) << name_ << ": The input strategy must be " << expect_strategies_ << ", but got " << strategies;
801 return FAILED;
802 }
803
804 return SUCCESS;
805 }
806
807 Status FlashAttentionScoreInfo::CheckStrategyForDynamicShape(const StrategyPtr &) {
808 for (auto &cnode : cnodes_) {
809 // If DropoutGenMask -> Reshape -> FlashAttentionScore
810 auto reshape_node = cnode->input(ops::kFlashAttentionScoreInputDropMaskIndex + 1);
811 MS_EXCEPTION_IF_NULL(reshape_node);
812 if (!IsPrimitiveCNode(reshape_node, prim::kPrimReshape)) {
813 continue;
814 }
815
816 MS_LOG(ERROR)
817 << name_ << ": it does not support dynamic shape if it need to replace dst-shape for reshape, the inputs' shape: "
818 << ShapesToString(inputs_shape_);
819 return FAILED;
820 }
821 return SUCCESS;
822 }
823
824 Status FlashAttentionScoreInfo::InferDevMatrixShape() {
825 switch (input_layout_) {
826 case FASInputLayoutMode::BSH:
827 case FASInputLayoutMode::BSND:
828 case FASInputLayoutMode::TND:
829 dev_matrix_shape_ = {batch_split_num_, s1_split_num_, n1_split_num_};
830 dev_matrix_batch_dim_ = kIndex2;
831 dev_matrix_s1_dim_ = kIndex1;
832 dev_matrix_n1_dim_ = kIndex0;
833 break;
834 case FASInputLayoutMode::SBH:
835 dev_matrix_shape_ = {s1_split_num_, batch_split_num_, n1_split_num_};
836 dev_matrix_s1_dim_ = kIndex2;
837 dev_matrix_batch_dim_ = kIndex1;
838 dev_matrix_n1_dim_ = kIndex0;
839 break;
840 case FASInputLayoutMode::BNSD:
841 dev_matrix_shape_ = {batch_split_num_, n1_split_num_, s1_split_num_};
842 dev_matrix_batch_dim_ = kIndex2;
843 dev_matrix_n1_dim_ = kIndex1;
844 dev_matrix_s1_dim_ = kIndex0;
845 break;
846 default:
847 MS_LOG(ERROR) << name_ << ": Unsupported layout: " << input_layout_;
848 return FAILED;
849 }
850 return SUCCESS;
851 }
852
853 Status FlashAttentionScoreInfo::InferSplitNumAndDevMatrixShapeByLayout() {
854 dev_matrix_shape_ =
855 inputs_tensor_info_[ops::kFlashAttentionScoreInputQueryIndex].tensor_layout().device_arrangement_origin().array();
856 auto query_layout = inputs_tensor_info_[ops::kFlashAttentionScoreInputQueryIndex].tensor_layout();
857 auto key_layout = inputs_tensor_info_[ops::kFlashAttentionScoreInputKeyIndex].tensor_layout();
858 auto query_tensor_map = query_layout.tensor_map_before();
859 auto query_batch_map = query_tensor_map.at(qkv_batch_dim_);
860 auto query_seq_map = query_tensor_map.at(qkv_seq_dim_);
861 auto query_head_map = query_tensor_map.at(qkv_head_dim_);
862 auto key_seq_map = key_layout.tensor_map_before().at(qkv_seq_dim_);
863
864 auto dev_matrix_shape = dev_matrix_shape_;
865 if (input_layout_ == FASInputLayoutMode::TND) {
866 if (query_batch_map.size() == kSizeOne) {
867 dev_matrix_batch_dim_ = query_batch_map[0];
868 dev_matrix_s1_dim_ = MAP_NONE;
869 } else if (query_batch_map.size() == kSizeTwo) {
870 dev_matrix_batch_dim_ = query_batch_map[0];
871 dev_matrix_s1_dim_ = query_batch_map[1];
872 } else {
873 MS_LOG(ERROR) << name_
874 << ": The seq-dimension of query can only be mapped upto 2 device matrix dimension, but got "
875 << query_batch_map;
876 return FAILED;
877 }
878 n1_split_num_ = 1;
879 for (auto map_id : query_head_map) {
880 n1_split_num_ *= GetSplitNumByMapId(dev_matrix_shape, map_id);
881 }
882 } else {
883 if (query_batch_map.size() != 1 || query_seq_map.size() != 1 || query_head_map.size() != 1) {
884 MS_LOG(ERROR) << name_
885 << ": Each dimension of query can only be mapped to one device matrix dimension, but got the "
886 "tensor info of query is "
887 << query_layout.ToString();
888 return FAILED;
889 }
890 dev_matrix_batch_dim_ = query_batch_map[0];
891 dev_matrix_s1_dim_ = query_seq_map[0];
892 dev_matrix_n1_dim_ = query_head_map[0];
893 n1_split_num_ = GetSplitNumByMapId(dev_matrix_shape, dev_matrix_n1_dim_);
894 }
895 batch_split_num_ = GetSplitNumByMapId(dev_matrix_shape, dev_matrix_batch_dim_);
896 s1_split_num_ = GetSplitNumByMapId(dev_matrix_shape, dev_matrix_s1_dim_);
897 if (s1_split_num_ > 1 && GetSplitNumByTensorMap(dev_matrix_shape, query_seq_map) !=
898 GetSplitNumByTensorMap(dev_matrix_shape, key_seq_map) * s1_split_num_) {
899 MS_LOG(EXCEPTION) << name_ << ": Cannot split the seq-dimension of key. query_seq_slice: "
900 << GetSplitNumByTensorMap(dev_matrix_shape, query_seq_map)
901 << ", key_seq_slice: " << GetSplitNumByTensorMap(dev_matrix_shape, key_seq_map)
902 << ", s1_split_num: " << s1_split_num_;
903 }
904 return SUCCESS;
905 }
906
907 Status FlashAttentionScoreInfo::InferTensorMap() {
908 if (InitInputsTensorMap() != SUCCESS) {
909 return FAILED;
910 }
911 if (input_layout_ == FASInputLayoutMode::TND) {
912 outputs_tensor_map_.push_back({inputs_tensor_map_[0]}); // softmax_max
913 outputs_tensor_map_.push_back({inputs_tensor_map_[0]}); // softmax_sum
914 } else {
915 outputs_tensor_map_.push_back({dev_matrix_batch_dim_, dev_matrix_n1_dim_, dev_matrix_s1_dim_, -1}); // softmax_max
916 outputs_tensor_map_.push_back({dev_matrix_batch_dim_, dev_matrix_n1_dim_, dev_matrix_s1_dim_, -1}); // softmax_sum
917 }
918 outputs_tensor_map_.push_back({-1}); // softmax_out
919 outputs_tensor_map_.push_back(inputs_tensor_map_[0]); // attention_out
920 return SUCCESS;
921 }
922
923 std::vector<int64_t> FlashAttentionScoreInfo::GetSplitIdAndRank() {
924 CheckGlobalDeviceManager();
925 int64_t rank = g_device_manager->global_rank();
926 DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
927 RankList group_devices;
928 int64_t seq_dim = SizeToLong(dev_matrix_shape_.size()) - dev_matrix_s1_dim_ - 1;
929 if (dev_matrix.GetDevicesAlongDim(seq_dim, &group_devices) != SUCCESS) {
930 MS_LOG(ERROR) << name_ << " get group devices along dim " << seq_dim << " failed.";
931 }
932 auto iter = std::find(group_devices.begin(), group_devices.end(), rank);
933 if (iter == group_devices.end()) {
934 MS_LOG(EXCEPTION) << "FlashAttentionScore S1 sequence parallel get split id failed. "
935 << "rank " << rank << " not in group " << group_devices;
936 }
937 int64_t split_id = iter - group_devices.begin();
938 int64_t target_split_id = s1_split_num_ - split_id - 1;
939 int64_t target_rank_id = group_devices[target_split_id];
940 return std::vector<int64_t>({rank, target_rank_id, split_id, target_split_id});
941 }
942
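// Note: when S1 is split without load balancing, each rank only holds one horizontal band of the
// attention matrix, so pre_tokens / next_tokens are shifted per split_id to describe that band.
// opAttrUpdateMap selects how the shift is computed, depending on whether the original sparse mode
// anchors the mask at the top-left or the bottom-right corner.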
943 std::tuple<int64_t, int64_t> FlashAttentionScoreInfo::GetAttentionMaskAttrs(const int64_t split_id,
944 const int64_t split_num) {
945 int64_t kv_seq_length;
946 int64_t q_seq_length;
947 kv_seq_length = inputs_shape_[ops::kFlashAttentionScoreInputKeyIndex][qkv_seq_dim_];
948 q_seq_length = inputs_shape_[ops::kFlashAttentionScoreInputQueryIndex][qkv_seq_dim_];
949 int64_t q_len_each_split = q_seq_length / split_num;
950 int64_t new_pre_tokens =
951 (sparse_mode_ == ops::kSparseDefaultMask || sparse_mode_ == ops::kSparseBand) ? pre_tokens_ : kv_seq_length;
952 int64_t new_next_tokens =
953 (sparse_mode_ == ops::kSparseDefaultMask || sparse_mode_ == ops::kSparseBand) ? next_tokens_ : 0;
954 switch (opAttrUpdateMap.at(sparse_mode_)) {
955 case kLeftUpToLeftUp:
956 new_pre_tokens = LongAdd(new_pre_tokens, -split_id * q_len_each_split);
957 new_next_tokens = LongAdd(new_next_tokens, split_id * q_len_each_split);
958 break;
959 case kLeftUpToRightDown:
960 new_pre_tokens = LongAdd(new_pre_tokens, (kv_seq_length - (split_id + 1) * q_len_each_split));
961 new_next_tokens = LongAdd(new_next_tokens, -(kv_seq_length - (split_id + 1) * q_len_each_split));
962 break;
963 case kRightDownToRightDown:
964 new_pre_tokens = LongAdd(new_pre_tokens, (split_num - split_id - 1) * (q_seq_length / split_num));
965 new_next_tokens = LongAdd(new_next_tokens, -(split_num - split_id - 1) * (q_seq_length / split_num));
966 break;
967 default:
968 MS_LOG(EXCEPTION) << "Invalid sparse mode " << sparse_mode_ << ", sparse mode should be one of [0, 2, 3, 4].";
969 }
970 return std::make_tuple(new_pre_tokens, new_next_tokens);
971 }
972
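// Note: under TND with the T (packed batch*seq) dimension of query split, the actual_seq_qlen /
// actual_seq_kvlen cumulative-length lists must be re-expressed for the local slice. The inserted
// Sub/ClipByValue/ReLU/Equal/Select nodes compute, per sequence,
//   new_qlen  = clip(qlen - offset, 0, slice_tq)
//   new_kvlen = kvlen - (relu(qlen - offset) - new_qlen), kept at slice_tk where kvlen == slice_tk.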
973 Status FlashAttentionScoreInfo::ReplaceActualSeqLenForSplitSeqInTnd(const CNodePtr &cnode) {
974 std::vector<int64_t> split_info = GetSplitIdAndRank();
975 int64_t tq = inputs_shape_[GetStrategyRealIndex(ops::kFlashAttentionScoreInputQueryIndex)][qkv_batch_dim_];
976 int64_t tk = inputs_shape_[GetStrategyRealIndex(ops::kFlashAttentionScoreInputKeyIndex)][qkv_batch_dim_];
977 int64_t slice_tq = tq / s1_split_num_ / batch_split_num_;
978 int64_t slice_tk = tk / batch_split_num_;
979 int64_t split_id = split_info[kIndex2];
980 int64_t offset = slice_tq * split_id;
981 if (!is_input_passed_[ops::kFlashAttentionScoreInputActualSeqQlenIndex] ||
982 !is_input_passed_[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
983 MS_LOG(ERROR) << name_ << ": The input 'actual_seq_qlen' and 'actual_seq_kvlen' cannot be None under 'TND'.";
984 return FAILED;
985 }
986 auto actual_seq_qlen_input_index = ops::kFlashAttentionScoreInputActualSeqQlenIndex + 1;
987 auto actual_seq_kvlen_input_index = ops::kFlashAttentionScoreInputActualSeqKVlenIndex + 1;
988 auto actual_seq_qlen_node = cnode->input(actual_seq_qlen_input_index);
989 auto actual_seq_kvlen_node = cnode->input(actual_seq_kvlen_input_index);
990
991 auto func_graph = cnode->func_graph();
992 MS_EXCEPTION_IF_NULL(func_graph);
993 auto manager = func_graph->manager();
994 MS_EXCEPTION_IF_NULL(manager);
995
996 // new_actual_seq_qlen = clip(actual_seq_qlen - offset, 0, slice_tq)
997 auto qlen_offset_sub_cnode =
998 func_graph->NewCNode({NewValueNode(prim::kPrimSub), actual_seq_qlen_node, CreateInt32Tensor(offset, true)});
999 auto new_actual_seq_qlen_cnode =
1000 func_graph->NewCNode({NewValueNode(prim::kPrimClipByValue), qlen_offset_sub_cnode, CreateInt32Tensor(0, true),
1001 CreateInt32Tensor(slice_tq, true)});
1002 manager->SetEdge(cnode, actual_seq_qlen_input_index, new_actual_seq_qlen_cnode);
1003
1004 // new_actual_seq_kvlen = actual_seq_kvlen - (ReLU(actual_seq_qlen - offset) - new_actual_seq_qlen)
1005 auto relu_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimReLU), qlen_offset_sub_cnode});
1006 auto kvlen_offset_sub_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimSub), actual_seq_qlen_node, relu_cnode});
1007 auto tmp_new_actual_seq_kvlen_cnode =
1008 func_graph->NewCNode({NewValueNode(prim::kPrimSub), actual_seq_kvlen_node, kvlen_offset_sub_cnode});
1009
1010 // new_actual_seq_kvlen[actual_seq_kvlen == slice_tk] = slice_tk
1011 auto equal =
1012 func_graph->NewCNode({NewValueNode(prim::kPrimEqual), actual_seq_kvlen_node, CreateInt32Tensor(slice_tk, true)});
1013 auto new_actual_seq_kvlen_cnode = func_graph->NewCNode(
1014 {NewValueNode(prim::kPrimSelect), equal, actual_seq_kvlen_node, tmp_new_actual_seq_kvlen_cnode});
1015 manager->SetEdge(cnode, actual_seq_kvlen_input_index, new_actual_seq_kvlen_cnode);
1016
1017 return SUCCESS;
1018 }
1019
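// Note: ReplaceNodeInputOrAttrs patches the sharded operator in place: head_num becomes the
// per-rank head count (head_num_ / n1_split_num_); when S1 is split without load balancing, the
// sparse_mode / pre_tokens / next_tokens (or actual_seq_* under TND) operands are rewritten; and a
// DropoutGenMask -> Reshape producer of drop_mask gets its destination shape and seeds updated.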
1020 void FlashAttentionScoreInfo::ReplaceNodeInputOrAttrs() {
1021 for (auto &cnode : cnodes_) {
1022 SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputHeadNumIndex + 1, head_num_ / n1_split_num_);
1023 if (s1_split_num_ > 1 && !enable_load_balance_ && need_update_op_attrs_mode_) {
1024 if (input_layout_ == FASInputLayoutMode::TND) {
1025 if (ReplaceActualSeqLenForSplitSeqInTnd(cnode) != SUCCESS) {
1026 MS_LOG(EXCEPTION) << name_ << ": Replace actual_seq_qlen and actual_seq_kvlen failed.";
1027 }
1028 } else {
1029 int64_t new_pre_tokens, new_next_tokens;
1030 std::vector<int64_t> split_info = GetSplitIdAndRank();
1031 int64_t split_id = split_info[kIndex2];
1032 std::tie(new_pre_tokens, new_next_tokens) = GetAttentionMaskAttrs(split_id, s1_split_num_);
1033 int64_t new_sparse_mode = is_attn_mask_compressed_ ? ops::kSparseBand : sparse_mode_;
1034 SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputSparseModeIndex + 1, new_sparse_mode);
1035 SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputPreTokensIndex + 1, new_pre_tokens);
1036 SetValueInputToCNode<int64_t>(cnode, ops::kFlashAttentionScoreInputNextTokensIndex + 1, new_next_tokens);
1037 }
1038 }
1039 // If DropoutGenMask -> Reshape -> FlashAttentionScore, update their inputs.
1040 auto reshape_node = cnode->input(ops::kFlashAttentionScoreInputDropMaskIndex + 1);
1041 MS_EXCEPTION_IF_NULL(reshape_node);
1042 if (!IsPrimitiveCNode(reshape_node, prim::kPrimReshape)) {
1043 continue;
1044 }
1045 auto reshape_cnode = reshape_node->cast<CNodePtr>();
1046 if (!IsPrimitiveCNode(reshape_cnode->input(kIndex1), prim::kPrimDropoutGenMask)) {
1047 continue;
1048 }
1049 auto dropout_gen_mask_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
1050 // Update slice_shape for ReShape
1051 Shape input_slice_shape = inputs_tensor_info_.at(ops::kFlashAttentionScoreInputDropMaskIndex).slice_shape();
1052 ValuePtr new_shape = MakeValue(input_slice_shape);
1053 AnfNodePtr val = NewValueNode(new_shape);
1054 auto manager = cnode->func_graph()->manager();
1055 MS_EXCEPTION_IF_NULL(manager);
1056 manager->SetEdge(reshape_cnode, kIndex2, val);
1057 // Update slice shape and seed for DropoutGenMask
1058 UpdateDropoutGenMaskSliceShapeAndSeed(dropout_gen_mask_cnode);
1059 }
1060 }
1061
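// Note: load balancing (presumably to even out causal-mask work across the S1 group) splits each
// rank's local S1 chunk into two halves (kLoadBalanceSplitNum), keeps one half locally and exchanges
// the other with the mirror rank chosen in GetSplitIdAndRank. Only query, real_shift, drop_mask and
// a non-compressed attn_mask have an S1 axis to split here.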
1062 void FlashAttentionScoreInfo::LoadBalanceSplitAlongSeqDim(size_t input_index, GenerateGraph *gen_g,
1063 AnfNodePtr *split_node, AnfNodePtr *keep_node,
1064 AnfNodePtr *exchange_node) {
1065 OperatorAttrs split_attrs;
1066 int64_t q_split_axis;
1067 switch (input_index) {
1068 case ops::kFlashAttentionScoreInputQueryIndex:
1069 q_split_axis = SizeToLong(qkv_seq_dim_);
1070 split_attrs = {std::make_pair(AXIS, MakeValue(q_split_axis)),
1071 std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1072 *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1073 *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1074 *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1075 break;
1076 case ops::kFlashAttentionScoreInputRealShiftIndex:
1077 if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1078 split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(kInputRealShiftSeqDim)),
1079 std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1080 *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1081 *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1082 *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1083 } else {
1084 *keep_node = gen_g->virtual_input_node();
1085 *exchange_node = gen_g->virtual_input_node();
1086 }
1087 break;
1088 case ops::kFlashAttentionScoreInputDropMaskIndex:
1089 if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1090 split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(kInputDropMaskSeqDim)),
1091 std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1092 *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1093 *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1094 *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1095 } else {
1096 *keep_node = gen_g->virtual_input_node();
1097 *exchange_node = gen_g->virtual_input_node();
1098 }
1099 break;
1100 case ops::kFlashAttentionScoreInputAttnMaskIndex:
1101 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] && !is_attn_mask_compressed_) {
1102 auto attn_mask_shape = inputs_shape_.at(GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
1103 if (attn_mask_shape.size() == kSizeTwo) {
1104 split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(0)),
1105 std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1106 } else {
1107 split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(2)),
1108 std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1109 }
1110 *split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, split_attrs), gen_g->virtual_input_node()});
1111 *keep_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(0)});
1112 *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), *split_node, CreatInt64Imm(1)});
1113 } else {
1114 *keep_node = gen_g->virtual_input_node();
1115 *exchange_node = gen_g->virtual_input_node();
1116 }
1117 break;
1118 default:
1119 MS_LOG(EXCEPTION) << "Invalid input index. Only 0(query), 3(real_shift), 4(drop_mask) and 6(attn_mask)"
1120 << "support sequence dim parallel, but got " << input_index;
1121 }
1122 }
1123
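// Note: LoadBalanceExchange swaps a half-chunk with the peer rank by AllGather-ing the halves inside
// `group`, re-splitting the gathered tensor along axis 0 and keeping the piece at all_gather_idx.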
1124 void FlashAttentionScoreInfo::LoadBalanceExchange(const int64_t all_gather_idx, const Group &group,
1125 const AnfNodePtr &input_node, AnfNodePtr *exchange_node,
1126 GenerateGraph *gen_g) {
1127 OperatorAttrs all_gather_attrs = {std::make_pair(GROUP, MakeValue(group.name()))};
1128 OperatorAttrs all_gather_split_attrs = {std::make_pair(AXIS, MakeValue<int64_t>(0)),
1129 std::make_pair(OUTPUT_NUM, MakeValue(kLoadBalanceSplitNum))};
1130 auto all_gather_node = gen_g->PushBack({gen_g->NewOpInst(ALL_GATHER, all_gather_attrs), input_node});
1131 auto split_node = gen_g->PushBack({gen_g->NewOpInst(SPLIT, all_gather_split_attrs), all_gather_node});
1132 *exchange_node = gen_g->PushBack({gen_g->NewOpInst(TUPLE_GETITEM), split_node, CreatInt64Imm(all_gather_idx)});
1133 }
1134
1135 void FlashAttentionScoreInfo::GetFlashAttentionScoreOpNode(int64_t split_id, int64_t split_num, const AnfNodePtr &q,
1136 const AnfNodePtr &real_shift, const AnfNodePtr &drop_mask,
1137 const AnfNodePtr &attn_mask, AnfNodePtr *fa_op,
1138 GenerateGraph *gen_g) {
1139 int64_t new_sparse_mode = is_attn_mask_compressed_ ? ops::kSparseBand : sparse_mode_;
1140 int64_t new_pre_tokens, new_next_tokens;
1141 if (!need_update_op_attrs_mode_) {
1142 new_pre_tokens = pre_tokens_;
1143 new_next_tokens = next_tokens_;
1144 } else {
1145 std::tie(new_pre_tokens, new_next_tokens) = GetAttentionMaskAttrs(split_id, split_num);
1146 }
1147 OperatorAttrs fa_attrs = {std::make_pair(HEAD_NUM, MakeValue(head_num_ / n1_split_num_)),
1148 std::make_pair(KEEP_PROB, MakeValue(keep_prob_)),
1149 std::make_pair(SCALE_VALUE, MakeValue(scale_value_)),
1150 std::make_pair(PRE_TOKENS, MakeValue(new_pre_tokens)),
1151 std::make_pair(NEXT_TOKENS, MakeValue(new_next_tokens)),
1152 std::make_pair(INNER_PRECISE, MakeValue<int64_t>(0)),
1153 std::make_pair(INPUT_LAYOUT, MakeValue(input_layout_)),
1154 std::make_pair(SPARSE_MODE, MakeValue<int64_t>(new_sparse_mode))};
1155 *fa_op = gen_g->PushBack({gen_g->NewOpInst(FLASH_ATTENTION_SCORE, fa_attrs), q, gen_g->virtual_input_node(),
1156 gen_g->virtual_input_node(), real_shift, drop_mask, gen_g->virtual_input_node(), attn_mask,
1157 gen_g->virtual_input_node(), gen_g->virtual_input_node(), gen_g->virtual_input_node()});
1158 }
1159
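// Map the replacement graph's virtual inputs back to the operator's real input indices. Optional inputs
// (real_shift, drop_mask, attn_mask) that are absent, or an attn_mask that is compressed, are routed to
// the keep/target FlashAttentionScore nodes directly instead of going through a Split node.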
1160 std::vector<std::pair<AnfNodePtr, int64_t>> FlashAttentionScoreInfo::ReplaceGraphGetInputNodes(
1161 const AnfNodePtr &q_split, const AnfNodePtr &real_shift_split, const AnfNodePtr &drop_mask_split,
1162 const AnfNodePtr &attn_mask_split, const AnfNodePtr &flash_attention_score_keep,
1163 const AnfNodePtr &flash_attention_score_target) {
1164 std::pair<AnfNodePtr, int64_t> real_shift_input;
1165 if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1166 real_shift_input = std::make_pair(real_shift_split, kIndex4);
1167 } else {
1168 real_shift_input = std::make_pair(flash_attention_score_keep, kIndex4);
1169 }
1170 std::pair<AnfNodePtr, int64_t> drop_mask_input;
1171 if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1172 drop_mask_input = std::make_pair(drop_mask_split, kIndex5);
1173 } else {
1174 drop_mask_input = std::make_pair(flash_attention_score_keep, kIndex5);
1175 }
1176 std::pair<AnfNodePtr, int64_t> attn_mask_input;
1177 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] && !is_attn_mask_compressed_) {
1178 attn_mask_input = std::make_pair(attn_mask_split, kIndex7);
1179 } else {
1180 attn_mask_input = std::make_pair(flash_attention_score_keep, kIndex7);
1181 }
1182
1183 std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes = {std::make_pair(q_split, kIndex1),
1184 std::make_pair(flash_attention_score_keep, kIndex2),
1185 std::make_pair(flash_attention_score_keep, kIndex3),
1186 real_shift_input,
1187 drop_mask_input,
1188 std::make_pair(flash_attention_score_keep, kIndex6),
1189 attn_mask_input,
1190 std::make_pair(flash_attention_score_keep, kIndex8),
1191 std::make_pair(flash_attention_score_keep, kIndex9),
1192 std::make_pair(flash_attention_score_keep, kIndex10),
1193 std::make_pair(flash_attention_score_target, kIndex2),
1194 std::make_pair(flash_attention_score_target, kIndex3)};
1195 if (!is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1196 (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex4));
1197 }
1198 if (!is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1199 (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex5));
1200 }
1201 (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex6));
1202 if (!is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] || is_attn_mask_compressed_) {
1203 (void)inputs_nodes.emplace_back(std::make_pair(flash_attention_score_target, kIndex7));
1204 }
1205 inputs_nodes.insert(inputs_nodes.end(), {std::make_pair(flash_attention_score_target, kIndex8),
1206 std::make_pair(flash_attention_score_target, kIndex9),
1207 std::make_pair(flash_attention_score_target, kIndex10)});
1208 return inputs_nodes;
1209 }
1210
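// Load-balanced sequence parallelism: each rank halves its Q/real_shift/drop_mask/attn_mask chunk,
// computes FlashAttentionScore on the half it keeps, exchanges the other half with its mirror rank,
// computes that half as well, and finally concatenates the softmax statistics and attention outputs.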
1211 Status FlashAttentionScoreInfo::ComputeReplaceGraphForLoadBalance(const CNodePtr &cnode) {
1212 GenerateGraph gen_g = GenerateGraph(attrs_);
1213 if (gen_g.Init(cnode) != SUCCESS) {
1214 return FAILED;
1215 }
1216 CheckGlobalDeviceManager();
1217 std::vector<int64_t> split_info = GetSplitIdAndRank();
1218 int64_t rank_id = split_info[kIndex0];
1219 int64_t target_rank_id = split_info[kIndex1];
1220 int64_t split_id = split_info[kIndex2];
1221 int64_t target_split_id = split_info[kIndex3];
1222 Group group;
1223 RankList swap_group_devices = {rank_id, target_rank_id};
1224 if (g_device_manager->CreateGroup(swap_group_devices, &group) != SUCCESS) {
1225 MS_LOG(ERROR) << "Create communication group for " << swap_group_devices << " failed";
1226 return FAILED;
1227 }
1228
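// Split each sequence-parallel input into the half kept locally and the half exchanged with the peer.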
1229 AnfNodePtr q_split, q_keep, q_exchange;
1230 LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputQueryIndex, &gen_g, &q_split, &q_keep, &q_exchange);
1231 AnfNodePtr real_shift_split, real_shift_keep, real_shift_exchange;
1232 LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputRealShiftIndex, &gen_g, &real_shift_split, &real_shift_keep,
1233 &real_shift_exchange);
1234 AnfNodePtr drop_mask_split, drop_mask_keep, drop_mask_exchange;
1235 LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputDropMaskIndex, &gen_g, &drop_mask_split, &drop_mask_keep,
1236 &drop_mask_exchange);
1237 AnfNodePtr attn_mask_split, attn_mask_keep, attn_mask_exchange;
1238 LoadBalanceSplitAlongSeqDim(ops::kFlashAttentionScoreInputAttnMaskIndex, &gen_g, &attn_mask_split, &attn_mask_keep,
1239 &attn_mask_exchange);
1240
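// Run FlashAttentionScore on the locally kept half and unpack its four outputs.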
1241 AnfNodePtr flash_attention_score_keep;
1242 GetFlashAttentionScoreOpNode(split_id * kLoadBalanceSplitNum, s1_split_num_ * kLoadBalanceSplitNum, q_keep,
1243 real_shift_keep, drop_mask_keep, attn_mask_keep, &flash_attention_score_keep, &gen_g);
1244 auto softmax_max_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1245 CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxMaxIndex)});
1246 auto softmax_sum_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1247 CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxSumIndex)});
1248 auto softmax_out_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1249 CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxOutIndex)});
1250 auto attention_out_keep = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_keep,
1251 CreatInt64Imm(ops::kFlashAttentionScoreOutputAttentionOutIndex)});
1252
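// Exchange the other halves with the mirror rank, then run FlashAttentionScore on the received chunk.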
1253 const int64_t all_gather_idx = (split_id < target_split_id) ? 1 : 0;
1254 AnfNodePtr q_target;
1255 LoadBalanceExchange(all_gather_idx, group, q_exchange, &q_target, &gen_g);
1256 AnfNodePtr real_shift_target;
1257 if (is_input_passed_[ops::kFlashAttentionScoreInputRealShiftIndex]) {
1258 LoadBalanceExchange(all_gather_idx, group, real_shift_exchange, &real_shift_target, &gen_g);
1259 } else {
1260 real_shift_target = gen_g.virtual_input_node();
1261 }
1262 AnfNodePtr drop_mask_target;
1263 if (is_input_passed_[ops::kFlashAttentionScoreInputDropMaskIndex]) {
1264 LoadBalanceExchange(all_gather_idx, group, drop_mask_exchange, &drop_mask_target, &gen_g);
1265 } else {
1266 drop_mask_target = gen_g.virtual_input_node();
1267 }
1268 AnfNodePtr attn_mask_target;
1269 if (is_input_passed_[ops::kFlashAttentionScoreInputAttnMaskIndex] && !is_attn_mask_compressed_) {
1270 LoadBalanceExchange(all_gather_idx, group, attn_mask_exchange, &attn_mask_target, &gen_g);
1271 } else {
1272 attn_mask_target = gen_g.virtual_input_node();
1273 }
1274
1275 AnfNodePtr flash_attention_score_target;
1276 GetFlashAttentionScoreOpNode(target_split_id * kLoadBalanceSplitNum + 1, s1_split_num_ * kLoadBalanceSplitNum,
1277 q_target, real_shift_target, drop_mask_target, attn_mask_target,
1278 &flash_attention_score_target, &gen_g);
1279 auto softmax_max_target = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_target,
1280 CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxMaxIndex)});
1281 auto softmax_sum_target = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_target,
1282 CreatInt64Imm(ops::kFlashAttentionScoreOutputSoftmaxSumIndex)});
1283 auto attention_out_target = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), flash_attention_score_target,
1284 CreatInt64Imm(ops::kFlashAttentionScoreOutputAttentionOutIndex)});
1285
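// Fetch the attention output of the exchanged half (computed on the mirror rank), then stitch the kept
// and exchanged pieces: softmax statistics are concatenated along the softmax sequence dim and the
// attention output along the Q/K/V sequence dim.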
1286 AnfNodePtr attention_out_exchange;
1287 LoadBalanceExchange(all_gather_idx, group, attention_out_target, &attention_out_exchange, &gen_g);
1288
1289 int64_t softmax_concat_axis = kOutputSoftmaxSeqDim;
1290 auto softmax_max_maketuple =
1291 gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), softmax_max_keep, softmax_max_target});
1292 auto softmax_max =
1293 gen_g.PushBack({gen_g.NewOpInst(CONCAT), softmax_max_maketuple, CreatInt64Imm(softmax_concat_axis)});
1294 auto softmax_sum_maketuple =
1295 gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), softmax_sum_keep, softmax_sum_target});
1296 auto softmax_sum =
1297 gen_g.PushBack({gen_g.NewOpInst(CONCAT), softmax_sum_maketuple, CreatInt64Imm(softmax_concat_axis)});
1298 int64_t attention_out_concat_axis = SizeToLong(qkv_seq_dim_);
1299 auto attention_out_maketuple =
1300 gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), attention_out_keep, attention_out_exchange});
1301 auto attention_out =
1302 gen_g.PushBack({gen_g.NewOpInst(CONCAT), attention_out_maketuple, CreatInt64Imm(attention_out_concat_axis)});
1303 auto output_maketuple =
1304 gen_g.PushBack({NewValueNode(prim::kPrimMakeTuple), softmax_max, softmax_sum, softmax_out_keep, attention_out});
1305
1306 std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes =
1307 ReplaceGraphGetInputNodes(q_split, real_shift_split, drop_mask_split, attn_mask_split, flash_attention_score_keep,
1308 flash_attention_score_target);
1309
1310 replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
1311 std::make_pair(inputs_nodes, output_maketuple));
1312 return SUCCESS;
1313 }
1314
1315 ReplaceGraphPtr FlashAttentionScoreInfo::replace_graph(const CNodePtr &cnode) {
1316 if (s1_split_num_ > 1 && enable_load_balance_) {
1317 if (ComputeReplaceGraphForLoadBalance(cnode) != SUCCESS) {
1318 MS_LOG(EXCEPTION) << name_
1319 << ": FlashAttentionScore S1 sequence parallel with load balance get replace graph failed";
1320 }
1321 }
1322 return replace_graph_;
1323 }
1324
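// The loss divisor is the number of devices holding identical copies of output[0], derived from the
// device matrix and the output tensor map.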
1325 Status FlashAttentionScoreInfo::InferAsLossDivisor() {
1326 if (outputs_tensor_map_.empty()) {
1327 MS_LOG(ERROR) << name_ << ": The size of outputs tensor map is empty";
1328 return FAILED;
1329 }
1330 as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
1331 MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
1332 << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
1333 << ", as_loss_divisor_ is " << as_loss_divisor_;
1334 return SUCCESS;
1335 }
1336
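// Auto-parallel strategy generation: enumerate candidate strategies over the splittable input
// dimensions prepared by InitSplittableInputs().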
1337 std::vector<StrategyPtr> FlashAttentionScoreInfo::GenerateOpStrategies(int64_t stage_id) {
1338 InitSplittableInputs();
1339 std::vector<StrategyPtr> sp_vector;
1340 if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, splittable_inputs_, &sp_vector) != SUCCESS) {
1341 MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForDependentInputs failed.";
1342 }
1343 if (sp_vector.empty()) {
1344 MS_LOG(EXCEPTION) << name_ << ": No valid strategy.";
1345 }
1346 return sp_vector;
1347 }
1348
1349 void FlashAttentionScoreInfo::ReComputeBatchSplitFlagList() {
1350 split_flag_list_ = std::vector<bool>(inputs_shape_.size(), true);
1351 }
1352
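// Expand the mirror ops produced by the base class to cover all inputs, inserting empty entries for
// optional inputs that were not passed.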
1353 Status FlashAttentionScoreInfo::InferMirrorOps() {
1354 if (OperatorInfo::InferMirrorOps() != SUCCESS) {
1355 return FAILED;
1356 }
1357 // No need to insert mirror ops
1358 if (mirror_ops_.empty()) {
1359 return SUCCESS;
1360 }
1361 // Insert empty OperatorInfo for optional input
1362 size_t cur_index = 0;
1363 std::vector<OperatorVector> real_mirror_ops(input_value_.size(), OperatorVector());
1364 for (size_t i = 0; i < input_value_.size(); ++i) {
1365 if (is_input_passed_[i]) {
1366 real_mirror_ops[i] = mirror_ops_[cur_index++];
1367 }
1368 }
1369 mirror_ops_ = real_mirror_ops;
1370 return SUCCESS;
1371 }
1372
1373 REGISTER(FlashAttentionScoreInfo);
1374 } // namespace parallel
1375 } // namespace mindspore
1376