/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cmath>
#include <memory>
#include <queue>
#include <sstream>
#include <utility>
#include <list>
#include <vector>
#include <string>
#include <algorithm>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/make_tuple.h"
#include "utils/anf_utils.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "include/common/utils/comm_manager.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "frontend/parallel/device_matrix.h"
#include "pipeline/jit/ps/action.h"
#include "mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h"
#include "mindspore/core/ops/op_enum.h"
#include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "mindspore/ccsrc/frontend/parallel/ops_info/flash_attention_score_info.h"
#include "frontend/optimizer/flash_sp.h"
#include "frontend/parallel/graph_util/graph_info.h"

namespace mindspore {
using mindspore::ops::FASInputLayoutMode;
namespace parallel {
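// FlashSPInfo caches the sequence-parallel (ring attention) layout of one
// FlashAttentionScore node: the number of ring steps (the s1 split number),
// this device's global rank, and its send/recv neighbors on the SP rank ring.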
FlashSPInfo::FlashSPInfo(CNodePtr fa_score_node) {
  MS_EXCEPTION_IF_NULL(fa_score_node);
  std::shared_ptr<OperatorInfo> operator_info = fa_score_node->user_data<parallel::OperatorInfo>();
  MS_EXCEPTION_IF_NULL(operator_info);
  auto flash_score_info_ptr = std::dynamic_pointer_cast<FlashAttentionScoreInfo>(operator_info);
  MS_EXCEPTION_IF_NULL(flash_score_info_ptr);

  flashsp_num_ = flash_score_info_ptr->s1_split_num();
  dev_rank_id_ = g_device_manager->global_rank();

  auto rankList = flash_score_info_ptr->GetSPRankList();
  size_t pos = 0;
  for (size_t i = 0; i < rankList.size(); ++i) {
    if (dev_rank_id_ == rankList[i]) {
      pos = i;
      break;
    }
  }
  send_rank_id_ = rankList[(pos + 1) % rankList.size()];
  recv_rank_id_ = rankList[(pos + rankList.size() - 1) % rankList.size()];
}
namespace {
using CNodePtrPair = std::pair<CNodePtr, CNodePtr>;
using FSPInfo = FlashSPInfo;

std::vector<CNodePtr> FindFWFlashAttentionScore(const FuncGraphManagerPtr &manager,
                                                const std::vector<AnfNodePtr> &origin_nodes_topological) {
  std::vector<CNodePtr> result;
  for (size_t i = 0; i < origin_nodes_topological.size(); ++i) {
    auto node = origin_nodes_topological[i];
    if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
      result.push_back(node->cast<CNodePtr>());
    }
  }
  return result;
}

CNodePtr NewReshapeNode(const AnfNodePtr &input_node, const ShapeVector &output_shape, const TypeId &output_type) {
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            input_node, NewValueNode(MakeValue(output_shape))};
  auto reshape = input_node->func_graph()->NewCNode(reshape_inputs);
  MS_EXCEPTION_IF_NULL(reshape);

  common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(output_shape), reshape);
  reshape->set_scope(input_node->scope());
  return reshape;
}

CNodePtr NewConcatNode(const AnfNodePtr &input_node, size_t concat_dim) {
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           input_node, NewValueNode(MakeValue(static_cast<int64_t>(concat_dim)))};
  auto concat = input_node->func_graph()->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat);
  concat->set_scope(input_node->scope());
  return concat;
}

CNodePtr NewMakeTupleNode(const std::vector<AnfNodePtr> &input_nodes) {
  // input_nodes are getitem nodes
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    make_tuple_inputs.push_back(input_nodes[i]);
  }
  auto make_tuple = input_nodes[0]->func_graph()->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  make_tuple->set_scope(input_nodes[0]->scope());
  return make_tuple;
}

CNodePtr NewSplitNode(const AnfNodePtr &input_node, size_t split_dim, size_t split_num) {
  if (split_num == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
  }
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                          input_node, NewValueNode<int64_t>(split_dim),
                                          NewValueNode<int64_t>(split_num)};
  auto split = input_node->func_graph()->NewCNode(split_inputs);
  MS_EXCEPTION_IF_NULL(split);
  split->set_scope(input_node->scope());
  return split;
}

CNodePtr NewTupleGetItemNode(const AnfNodePtr &input_node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto idx = NewValueNode(SizeToLong(output_index));
  MS_EXCEPTION_IF_NULL(idx);
  auto getitem = input_node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input_node, idx});
  MS_EXCEPTION_IF_NULL(getitem);
  getitem->set_scope(input_node->scope());
  return getitem;
}

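// Builds a NeighborExchange that ships the local KV buffer to send_rank_ids
// while receiving the remote buffer from recv_rank_ids in the same step; send
// and recv shapes match because every rank holds an equally sized KV slice.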
CNodePtr NewNeighborExchangeNode(const AnfNodePtr &input_node, const std::vector<int64_t> &send_rank_ids,
                                 const std::vector<int64_t> &recv_rank_ids, int fa_index, int ne_index,
                                 parallel::Shape neigh_shape) {
  MS_EXCEPTION_IF_NULL(input_node);
  // input_node is a maketuple node
  std::vector<AnfNodePtr> ne_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimNeighborExchange->name())),
                                       input_node};
  auto neighbor_exchange = input_node->func_graph()->NewCNode(ne_inputs);
  MS_EXCEPTION_IF_NULL(neighbor_exchange);

  // RECV_TYPE
  auto dtype = TypeId::kNumberTypeFloat16;
  common::AnfAlgo::SetNodeAttr(parallel::RECV_TYPE, TypeIdToType(dtype), neighbor_exchange);

  std::stringstream ss;
  ss << fa_index << "_" << ne_index;
  std::string ss_result = ss.str();
  common::AnfAlgo::SetNodeAttr("FLASH_INDEX", MakeValue<std::string>(ss_result), neighbor_exchange);

  // GROUP
  std::string group = g_device_manager->world_group();
  common::AnfAlgo::SetNodeAttr(parallel::GROUP, MakeValue<std::string>(group), neighbor_exchange);

  // SEND_RANK_IDS, RECV_RANK_IDS
  common::AnfAlgo::SetNodeAttr(parallel::SEND_RANK_IDS, parallel::MakeListValue(send_rank_ids), neighbor_exchange);
  common::AnfAlgo::SetNodeAttr(parallel::RECV_RANK_IDS, parallel::MakeListValue(recv_rank_ids), neighbor_exchange);

  // SEND_SHAPES, RECV_SHAPES
  parallel::Shape shape = neigh_shape;
  parallel::Shapes send_shapes;
  parallel::Shapes recv_shapes;
  for (size_t i = 0; i < send_rank_ids.size(); ++i) {
    send_shapes.push_back(shape);
    recv_shapes.push_back(shape);
  }
  common::AnfAlgo::SetNodeAttr(parallel::SEND_SHAPES, parallel::MakeTupleListValue(send_shapes), neighbor_exchange);
  common::AnfAlgo::SetNodeAttr(parallel::RECV_SHAPES, parallel::MakeTupleListValue(recv_shapes), neighbor_exchange);

  common::AnfAlgo::SetNodeAttr(parallel::COMM_REUSE, MakeValue(true), neighbor_exchange);

  neighbor_exchange->set_scope(input_node->scope());
  return neighbor_exchange;
}

CNodePtr NewFlashAttentionScoreNode(const std::vector<AnfNodePtr> &input_nodes, int fa_index, int ne_index) {
  std::vector<AnfNodePtr> fa_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimFlashAttentionScore->name()))};

  for (size_t i = 0; i < input_nodes.size(); ++i) {
    fa_inputs.push_back(input_nodes[i]);
  }
  auto fa_score = input_nodes[0]->func_graph()->NewCNode(fa_inputs);
  MS_EXCEPTION_IF_NULL(fa_score);

  // Tag the per-step FA node so it can be matched with its NeighborExchange.
  std::stringstream ss;
  ss << fa_index << "_" << ne_index;
  std::string ss_result = ss.str();
  common::AnfAlgo::SetNodeAttr(FLASH_INDEX, MakeValue<std::string>(ss_result), fa_score);
  fa_score->set_scope(input_nodes[0]->scope());
  return fa_score;
}

CNodePtr NewAddNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> add_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAdd->name())), left_node,
                                        right_node};
  auto add_node = left_node->func_graph()->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add_node);
  add_node->set_scope(left_node->scope());
  return add_node;
}

CNodePtr NewSubNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> sub_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSub->name())), left_node,
                                        right_node};
  auto sub_node = left_node->func_graph()->NewCNode(sub_inputs);
  MS_EXCEPTION_IF_NULL(sub_node);
  sub_node->set_scope(left_node->scope());
  return sub_node;
}

CNodePtr NewMulNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMul->name())), left_node,
                                        right_node};
  auto mul_node = left_node->func_graph()->NewCNode(mul_inputs);
  MS_EXCEPTION_IF_NULL(mul_node);
  mul_node->set_scope(left_node->scope());
  return mul_node;
}

CNodePtr NewDivNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> div_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimRealDiv->name())),
                                        left_node, right_node};
  auto div_node = left_node->func_graph()->NewCNode(div_inputs);
  MS_EXCEPTION_IF_NULL(div_node);
  div_node->set_scope(left_node->scope());
  return div_node;
}

CNodePtr NewExpNode(const AnfNodePtr &left_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  std::vector<AnfNodePtr> exp_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimExp->name())), left_node};
  auto exp_node = left_node->func_graph()->NewCNode(exp_inputs);
  MS_EXCEPTION_IF_NULL(exp_node);
  exp_node->set_scope(left_node->scope());
  return exp_node;
}

CNodePtr NewMaxNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> max_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMaximum->name())),
                                        left_node, right_node};
  auto max_node = left_node->func_graph()->NewCNode(max_inputs);
  MS_EXCEPTION_IF_NULL(max_node);
  max_node->set_scope(left_node->scope());
  return max_node;
}

CNodePtr NewCastNode(const AnfNodePtr &tensor_node, const TypeId &dtype) {
  MS_EXCEPTION_IF_NULL(tensor_node);
  auto type_node = NewValueNode(static_cast<int64_t>(dtype));
  std::vector<AnfNodePtr> cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())),
                                         tensor_node, type_node};
  auto cast_node = tensor_node->func_graph()->NewCNode(cast_inputs);

  MS_EXCEPTION_IF_NULL(cast_node);
  common::AnfAlgo::SetNodeAttrSafely(kAttrDstType, TypeIdToType(dtype), cast_node);
  cast_node->set_scope(tensor_node->scope());
  return cast_node;
}

CNodePtr NewTransposeNode(const AnfNodePtr &tensor_node, const AnfNodePtr &tuple, ShapeVector output_shape) {
  MS_EXCEPTION_IF_NULL(tensor_node);
  MS_EXCEPTION_IF_NULL(tuple);
  std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimTranspose->name())),
                                              tensor_node, tuple};
  auto transpose_node = tensor_node->func_graph()->NewCNode(transpose_inputs);
  MS_EXCEPTION_IF_NULL(transpose_node);
  transpose_node->set_scope(tensor_node->scope());
  return transpose_node;
}

CNodePtr NewTileNode(const AnfNodePtr &tensor_node, const AnfNodePtr &tuple) {
  MS_EXCEPTION_IF_NULL(tensor_node);
  MS_EXCEPTION_IF_NULL(tuple);
  std::vector<AnfNodePtr> tile_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimTile->name())),
                                         tensor_node, tuple};
  auto tile_node = tensor_node->func_graph()->NewCNode(tile_inputs);
  MS_EXCEPTION_IF_NULL(tile_node);
  tile_node->set_scope(tensor_node->scope());
  return tile_node;
}

tensor::TensorPtr make_mask_tensor(TypeId type_id, ShapeVector shape, uint8_t value, bool is_causal) {
  tensor::TensorPtr mask_tensor = std::make_shared<mindspore::tensor::Tensor>(type_id, shape);
  int tensor_size = SizeToInt(mask_tensor->data().size());
  uint8_t *uint8_data = reinterpret_cast<uint8_t *>(mask_tensor->data_c());
  if (!is_causal) {
    // Uniform mask: 0 keeps every position visible, 1 masks every position out.
    for (int i = 0; i < tensor_size; ++i) {
      uint8_data[i] = value;
    }
  } else {
    // Causal mask: positions above the diagonal (i < j) are masked out.
    for (int64_t i = 0; i < shape[kIndex0]; ++i) {
      for (int64_t j = 0; j < shape[kIndex1]; ++j) {
        uint8_data[i * shape[kIndex1] + j] = (i >= j) ? 0 : 1;
      }
    }
  }
  return mask_tensor;
}

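// Picks the attention mask for ring step `index` on a device at ring position
// `rank_id`: step 0 processes the local (diagonal) KV block and needs a causal
// mask; blocks arriving from earlier ranks (index <= rank_id) are fully
// visible (all zeros); blocks from later ranks are fully masked (all ones).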
AnfNodePtr GetActualMask(int index, int64_t rank_id, TypeId mask_dtype, ShapeVector mask_shape) {
  AnfNodePtr actual_mask;
  if (index == 0) {
    auto mask_tensor = make_mask_tensor(mask_dtype, mask_shape, 0, true);
    actual_mask = NewValueNode(MakeValue(mask_tensor));
  } else if (index <= rank_id) {
    auto mask_tensor = make_mask_tensor(mask_dtype, mask_shape, 0, false);
    actual_mask = NewValueNode(MakeValue(mask_tensor));
  } else {
    auto mask_tensor = make_mask_tensor(mask_dtype, mask_shape, 1, false);
    actual_mask = NewValueNode(MakeValue(mask_tensor));
  }
  return actual_mask;
}

int64_t GetPosInSpDevice(const std::shared_ptr<FlashAttentionScoreInfo> &flash_score_info_ptr, int64_t rank_id) {
  auto rankList = flash_score_info_ptr->GetSPRankList();
  int64_t pos = -1;
  for (size_t rank_list_idx = 0; rank_list_idx < rankList.size(); ++rank_list_idx) {
    if (rank_id == rankList[rank_list_idx]) {
      pos = SizeToLong(rank_list_idx);
      break;
    }
  }
  return pos;
}

void GetBSHFromShape(int64_t input_layout, Shape q_shape, Shape kv_shape, int64_t *fa_b, int64_t *fa_s1, int64_t *fa_h1,
                     int64_t *fa_s2) {
  if (input_layout == FASInputLayoutMode::BSH) {
    *fa_b = q_shape[kIndex0];
    *fa_s1 = q_shape[kIndex1];
    *fa_h1 = q_shape[kIndex2];
    *fa_s2 = kv_shape[kIndex1];
  } else if (input_layout == FASInputLayoutMode::BNSD) {
    *fa_b = q_shape[kIndex0];
    *fa_s1 = q_shape[kIndex2];
    *fa_h1 = q_shape[kIndex1] * q_shape[kIndex3];
    *fa_s2 = kv_shape[kIndex2];
  }
}

ValuePtr GetFlashIndexString(int fa_index, int index) {
  std::stringstream ss;
  ss << fa_index << "_" << index;
  std::string ss_result = ss.str();
  return MakeValue<std::string>(ss_result);
}

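// Merges one ring step into the running (max, sum, output) state with the
// standard online-softmax rescaling:
//   m_new = max(m_h, m_i)
//   l_new = e^(m_h - m_new) * l_h + e^(m_i - m_new) * l_i
//   O_new = (e^(m_h - m_new) * l_h / l_new) * O_h + (e^(m_i - m_new) * l_i / l_new) * O_i
// where (m_h, l_h, O_h) is the accumulated history and (m_i, l_i, O_i) comes
// from the current FlashAttentionScore step.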
void UpdateAttentionOutput(CNodePtr *history_max, CNodePtr *history_sum, CNodePtr *acc_attention,
                           const CNodePtr &softmax_max, const CNodePtr &softmax_sum, CNodePtr attention_output,
                           int64_t fa_b, int64_t fa_s1, int64_t fa_n1, int64_t fa_h1, int64_t input_layout,
                           int fa_index, int index) {
  auto temp_max = NewMaxNode(*history_max, softmax_max);
  auto m_h_sub_temp = NewSubNode(*history_max, temp_max);
  auto m_i_sub_temp = NewSubNode(softmax_max, temp_max);
  auto e_m_h_temp = NewExpNode(m_h_sub_temp);
  auto e_m_i_temp = NewExpNode(m_i_sub_temp);
  auto e_l_h = NewMulNode(e_m_h_temp, *history_sum);
  auto e_l_i = NewMulNode(e_m_i_temp, softmax_sum);
  auto l = NewAddNode(e_l_h, e_l_i);
  auto e_m_h_div = NewDivNode(e_l_h, l);
  auto e_m_i_div = NewDivNode(e_l_i, l);
  // The softmax max/sum outputs carry 8 duplicated values in their last dim:
  // keep one copy and tile it across head_dim so it can scale the output.
  auto e_m_h_div_split = NewSplitNode(e_m_h_div, 3, 8);
  auto e_m_h_div_item = NewTupleGetItemNode(e_m_h_div_split, 0);
  auto e_m_h_div_concat = NewTileNode(e_m_h_div_item, parallel::CreateTuple({1, 1, 1, fa_h1 / fa_n1}));
  auto e_m_i_div_split = NewSplitNode(e_m_i_div, 3, 8);
  auto e_m_i_div_item = NewTupleGetItemNode(e_m_i_div_split, 0);
  auto e_m_i_div_concat = NewTileNode(e_m_i_div_item, parallel::CreateTuple({1, 1, 1, fa_h1 / fa_n1}));
  if (input_layout == FASInputLayoutMode::BSH) {
    // Reshape (B, S1, H1) -> (B, S1, N1, D), then transpose to (B, N1, S1, D)
    // so the scaling factors broadcast per head.
    (*acc_attention) = NewReshapeNode(*acc_attention, {fa_b, fa_s1, fa_n1, fa_h1 / fa_n1}, TypeId::kNumberTypeFloat16);
    attention_output =
      NewReshapeNode(attention_output, {fa_b, fa_s1, fa_n1, fa_h1 / fa_n1}, TypeId::kNumberTypeFloat16);
    AnfNodePtr tmp_tup = parallel::CreateTuple({0, 2, 1, 3});
    (*acc_attention) = NewTransposeNode(*acc_attention, tmp_tup, {fa_b, fa_n1, fa_s1, fa_h1 / fa_n1});
    attention_output = NewTransposeNode(attention_output, tmp_tup, {fa_b, fa_n1, fa_s1, fa_h1 / fa_n1});
  }
  (*acc_attention) = NewCastNode(*acc_attention, TypeId::kNumberTypeFloat32);
  attention_output = NewCastNode(attention_output, TypeId::kNumberTypeFloat32);
  auto weighted_history = NewMulNode(e_m_h_div_concat, *acc_attention);
  auto weighted_attention = NewMulNode(e_m_i_div_concat, attention_output);
  (*acc_attention) = NewAddNode(weighted_history, weighted_attention);
  common::AnfAlgo::SetNodeAttr(kAttrAccumulatedAttention, MakeValue(1), *acc_attention);
  common::AnfAlgo::SetNodeAttr("FLASH_INDEX", GetFlashIndexString(fa_index, index), *acc_attention);
  if (input_layout == FASInputLayoutMode::BSH) {
    // Transpose and reshape back to (B, S1, H1).
    auto tmp_tup1 = parallel::CreateTuple({0, 2, 1, 3});
    (*acc_attention) = NewTransposeNode(*acc_attention, tmp_tup1, {fa_b, fa_s1, fa_n1, fa_h1 / fa_n1});
    (*acc_attention) = NewReshapeNode(*acc_attention, {fa_b, fa_s1, fa_h1}, TypeId::kNumberTypeFloat32);
  }
  (*history_max) = temp_max;
  (*history_sum) = l;
}

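// Rewrites one FlashAttentionScore node into a ring-attention subgraph: at
// every step each rank runs FlashAttentionScore on its current KV block and
// merges the partial result via UpdateAttentionOutput, while (except on the
// last step) shipping its KV pair to the next rank through NeighborExchange.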
CNodePtr CreateReplaceFSPGraph(const FuncGraphManagerPtr &manager,
                               const std::vector<CNodePtr> &origin_nodes_topological, const CNodePtr &fa_score_node,
                               FSPInfo *fsp_info, int fa_index) {
  std::vector<AnfNodePtr> fa_inputs;
  for (size_t i = 0; i < ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputsNum; ++i) {
    fa_inputs.push_back(fa_score_node->input(i + 1));
  }

  auto key_node = fa_score_node->input(ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputKeyIndex + 1);
  auto value_node = fa_score_node->input(ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputValueIndex + 1);

  int64_t sp_num = fsp_info->GetSPNum(), rank_id = fsp_info->GetRankId();
  int64_t send_rank_id = fsp_info->GetSendRankId(), recv_rank_id = fsp_info->GetRecvRankId();

  std::shared_ptr<OperatorInfo> operator_info = fa_score_node->user_data<parallel::OperatorInfo>();
  MS_EXCEPTION_IF_NULL(operator_info);
  auto flash_score_info_ptr = std::dynamic_pointer_cast<FlashAttentionScoreInfo>(operator_info);
  MS_EXCEPTION_IF_NULL(flash_score_info_ptr);
  auto q_shape = operator_info->inputs_tensor_info()[kIndex0].tensor_layout().base_slice_shape().array();
  auto kv_shape = operator_info->inputs_tensor_info()[kIndex1].tensor_layout().base_slice_shape().array();
  auto input_layout = flash_score_info_ptr->input_layout();
  if (input_layout != FASInputLayoutMode::BSH && input_layout != FASInputLayoutMode::BNSD) {
    return nullptr;
  }

  int64_t fa_n1 = GetValue<int64_t>(
    fa_score_node->input(ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputHeadNumIndex + 1)
      ->cast<ValueNodePtr>()
      ->value());
  int64_t fa_b, fa_s1, fa_h1, fa_s2;
  GetBSHFromShape(input_layout, q_shape, kv_shape, &fa_b, &fa_s1, &fa_h1, &fa_s2);

  CNodePtr local_fa_node, kv_received_tuple, softmax_max, softmax_sum, softmax_out, attention_output;
  CNodePtr history_max, history_sum, acc_attention;
  AnfNodePtr actual_mask;
  for (int i = 0; i < sp_num; ++i) {
    // Pack K and V into one buffer so a single NeighborExchange moves both.
    std::vector<AnfNodePtr> kv_nodes = {key_node, value_node};
    auto kv_tuple = NewMakeTupleNode(kv_nodes);
    auto kv_concat = NewConcatNode(kv_tuple, 0);
    std::vector<AnfNodePtr> concat_tuple = {kv_concat};
    auto kv_concat_tuple = NewMakeTupleNode(concat_tuple);
    if (i != sp_num - 1) {
      // K and V are concatenated along dim 0, so the exchanged buffer is
      // twice the size of one KV slice.
      auto neigh_shape = kv_shape;
      neigh_shape[0] = neigh_shape[0] * kIndex2;
      kv_received_tuple =
        NewNeighborExchangeNode(kv_concat_tuple, {send_rank_id}, {recv_rank_id}, fa_index, i, neigh_shape);
    }

    auto pos = GetPosInSpDevice(flash_score_info_ptr, rank_id);
    actual_mask = GetActualMask(i, pos, TypeId::kNumberTypeUInt8, Shape{fa_s1, fa_s2});
    fa_inputs[ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputKeyIndex] = key_node;
    fa_inputs[ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputValueIndex] = value_node;
    fa_inputs[ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputAttnMaskIndex] = actual_mask;
    local_fa_node = NewFlashAttentionScoreNode(fa_inputs, fa_index, i);
    common::AnfAlgo::CopyNodeAttrs(fa_score_node, local_fa_node);

    if (i != sp_num - 1) {
      // Unpack the received buffer into the next step's K and V.
      auto kv_exchanged_item = NewTupleGetItemNode(kv_received_tuple, kIndex0);
      auto kv_split = NewSplitNode(kv_exchanged_item, kIndex0, kIndex2);
      key_node = NewTupleGetItemNode(kv_split, kIndex0);
      value_node = NewTupleGetItemNode(kv_split, kIndex1);
    }

    softmax_max = NewTupleGetItemNode(local_fa_node, kIndex0);
    softmax_sum = NewTupleGetItemNode(local_fa_node, kIndex1);
    attention_output = NewTupleGetItemNode(local_fa_node, kIndex3);

    if (i == 0) {
      acc_attention = attention_output->cast<CNodePtr>();
      history_max = softmax_max->cast<CNodePtr>();
      history_sum = softmax_sum->cast<CNodePtr>();
    } else {
      UpdateAttentionOutput(&history_max, &history_sum, &acc_attention, softmax_max, softmax_sum, attention_output,
                            fa_b, fa_s1, fa_n1, fa_h1, input_layout, fa_index, i);
    }
  }
  acc_attention = NewCastNode(acc_attention, TypeId::kNumberTypeFloat16);
  softmax_out = NewTupleGetItemNode(local_fa_node, kIndex2);
  std::vector<AnfNodePtr> output_tuple = {history_max, history_sum, softmax_out, acc_attention};
  auto attention_results = NewMakeTupleNode(output_tuple);
  return attention_results;
}

void CreateAndReplaceFAScore(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
                             const CNodePtr &fa_score_node, FSPInfo *fsp_info, int i) {
  auto cnode = CreateReplaceFSPGraph(manager, origin_nodes_topological, fa_score_node, fsp_info, i);
  MS_EXCEPTION_IF_NULL(cnode);
  (void)manager->Replace(fa_score_node, cnode);
}

bool CheckUserSettings(const FuncGraphPtr &fg, FSPInfo *fsp_info) {
  fsp_info->DisplayInfo();

  int64_t sp_num = fsp_info->GetSPNum();
  if (sp_num <= 1) {
    MS_LOG(WARNING) << "FSP: to activate the pass, the sp num " << sp_num << " should be larger than 1";
    return false;
  }
  return true;
}
}  // namespace

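// Entry point of the pass: under semi-auto parallel, finds every forward
// FlashAttentionScore node with ENABLE_RING_ATTENTION set and replaces it
// with the ring-attention subgraph built above.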
bool SetFlashSP(const FuncGraphPtr &func_graph) {
  auto parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode != kSemiAutoParallel) {
    return false;
  }

  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto ret = func_graph->get_return();
  auto origin_nodes_topological = DeepScopedGraphSearch(ret);

  std::vector<CNodePtr> fa_score_nodes = FindFWFlashAttentionScore(manager, origin_nodes_topological);
  if (fa_score_nodes.empty()) {
    return false;
  }

  for (size_t i = 0; i < fa_score_nodes.size(); ++i) {
    auto fa_score_node = fa_score_nodes[i];
    auto fa_score_node_prim = GetCNodePrimitive(fa_score_node);
    MS_EXCEPTION_IF_NULL(fa_score_node_prim);
    if (!fa_score_node_prim->HasAttr(parallel::ENABLE_RING_ATTENTION) ||
        !GetValue<bool>((fa_score_node_prim->GetAttr(parallel::ENABLE_RING_ATTENTION)))) {
      continue;
    }

    auto fsp_info = FSPInfo(fa_score_node);
    if (!CheckUserSettings(func_graph, &fsp_info)) {
      return false;
    }

    manager = func_graph->manager();
    MS_EXCEPTION_IF_NULL(manager);
    auto orders = func_graph->GetOrderedCnodes();
    std::vector<CNodePtr> nodes_topological(orders.cbegin(), orders.cend());
    CreateAndReplaceFAScore(manager, nodes_topological, fa_score_node, &fsp_info, SizeToInt(i));
  }
  return true;
}
}  // namespace parallel
}  // namespace mindspore