/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cmath>
#include <memory>
#include <queue>
#include <sstream>
#include <utility>
#include <list>
#include <vector>
#include <string>
#include <algorithm>

#include "ops/other_ops.h"
#include "ops/array_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/make_tuple.h"
#include "utils/anf_utils.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "include/common/utils/comm_manager.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "frontend/parallel/device_matrix.h"
#include "pipeline/jit/ps/action.h"
#include "mindspore/ccsrc/include/backend/optimizer/helper.h"
#include "mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h"
#include "mindspore/core/ops/op_enum.h"
#include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "mindspore/ccsrc/frontend/parallel/ops_info/flash_attention_score_info.h"
#include "frontend/optimizer/flash_sp.h"
#include "frontend/parallel/graph_util/graph_info.h"

namespace mindspore {
using mindspore::ops::FASInputLayoutMode;
namespace parallel {
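// FlashSPInfo caches the sequence-parallel (SP) topology for one
// FlashAttentionScore node: the SP ring size, this device's global rank, and
// the neighbouring ranks it sends to / receives from along the ring.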
FlashSPInfo::FlashSPInfo(CNodePtr fa_score_node) {
  MS_EXCEPTION_IF_NULL(fa_score_node);
  std::shared_ptr<OperatorInfo> operator_info = fa_score_node->user_data<parallel::OperatorInfo>();
  MS_EXCEPTION_IF_NULL(operator_info);
  auto flash_score_info_ptr = std::dynamic_pointer_cast<FlashAttentionScoreInfo>(operator_info);
  MS_EXCEPTION_IF_NULL(flash_score_info_ptr);

  flashsp_num_ = flash_score_info_ptr->s1_split_num();
  dev_rank_id_ = g_device_manager->global_rank();

  auto rankList = flash_score_info_ptr->GetSPRankList();
  size_t pos = -1;
  for (size_t i = 0; i < rankList.size(); ++i) {
    if (dev_rank_id_ == rankList[i]) {
      pos = i;
    }
  }
  send_rank_id_ = rankList[(pos + 1) % rankList.size()];
  recv_rank_id_ = rankList[(pos + rankList.size() - 1) % rankList.size()];
}
namespace {
using CNodePtrPair = std::pair<CNodePtr, CNodePtr>;
using FSPInfo = FlashSPInfo;

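// Collect every forward FlashAttentionScore CNode in topological order.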
std::vector<CNodePtr> FindFWFlashAttentionScore(const FuncGraphManagerPtr &manager,
                                                const std::vector<AnfNodePtr> &origin_nodes_topological) {
  std::vector<CNodePtr> result;
  for (size_t i = 0; i < origin_nodes_topological.size(); ++i) {
    auto node = origin_nodes_topological[i];
    if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
      result.push_back(node->cast<CNodePtr>());
    }
  }
  return result;
}

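// The helpers below each build a single new CNode (Reshape, Concat, MakeTuple,
// Split, TupleGetItem, elementwise ops, Cast, Transpose, Tile, ...) in the
// func_graph of their first input, inheriting that input's scope.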
CNodePtr NewReshapeNode(const AnfNodePtr &input_node, const ShapeVector &output_shape, const TypeId &output_type) {
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            input_node, NewValueNode(MakeValue(output_shape))};
  auto reshape = input_node->func_graph()->NewCNode(reshape_inputs);
  MS_EXCEPTION_IF_NULL(reshape);

  common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(output_shape), reshape);
  reshape->set_scope(input_node->scope());
  return reshape;
}

CNodePtr NewConcatNode(const AnfNodePtr &input_node, size_t concat_dim) {
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           input_node, NewValueNode(MakeValue(static_cast<int64_t>(concat_dim)))};
  auto concat = input_node->func_graph()->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat);
  concat->set_scope(input_node->scope());
  return concat;
}

CNodePtr NewMakeTupleNode(const std::vector<AnfNodePtr> &input_nodes) {
  // input_nodes are getitem nodes
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    make_tuple_inputs.push_back(input_nodes[i]);
  }
  auto make_tuple = input_nodes[0]->func_graph()->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  make_tuple->set_scope(input_nodes[0]->scope());
  return make_tuple;
}

CNodePtr NewSplitNode(const AnfNodePtr &input_node, size_t split_dim, size_t split_num) {
  if (split_num == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
  }
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                          input_node, NewValueNode<int64_t>(split_dim),
                                          NewValueNode<int64_t>(split_num)};
  auto split = input_node->func_graph()->NewCNode(split_inputs);
  MS_EXCEPTION_IF_NULL(split);
  split->set_scope(input_node->scope());
  return split;
}

CNodePtr NewTupleGetItemNode(const AnfNodePtr &input_node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto idx = NewValueNode(SizeToLong(output_index));
  MS_EXCEPTION_IF_NULL(idx);
  auto getitem = input_node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input_node, idx});
  MS_EXCEPTION_IF_NULL(getitem);
  getitem->set_scope(input_node->scope());
  return getitem;
}

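// Build a NeighborExchange CNode that sends the (concatenated) KV tensor to
// send_rank_ids and receives the peer's KV block from recv_rank_ids in the same
// step. Send and receive buffers share the shape neigh_shape; the receive dtype
// is fixed to float16 here. The "FLASH_INDEX" attribute tags the node with
// "<fa_index>_<ne_index>" so later passes can match each exchange to its
// FlashAttentionScore step.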
CNodePtr NewNeighborExchangeNode(const AnfNodePtr &input_node, const std::vector<int64_t> &send_rank_ids,
                                 const std::vector<int64_t> &recv_rank_ids, int fa_index, int ne_index,
                                 parallel::Shape neigh_shape) {
  MS_EXCEPTION_IF_NULL(input_node);
  // input_node is maketuple node
  std::vector<AnfNodePtr> ne_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimNeighborExchange->name())),
                                       input_node};
  auto neighbor_exchange = input_node->func_graph()->NewCNode(ne_inputs);
  MS_EXCEPTION_IF_NULL(neighbor_exchange);

  // RECV_TYPE
  auto dtype = TypeId::kNumberTypeFloat16;
  common::AnfAlgo::SetNodeAttr(parallel::RECV_TYPE, TypeIdToType(dtype), neighbor_exchange);

  std::stringstream ss;
  ss << fa_index << "_" << ne_index;
  std::string ss_result = ss.str();
  common::AnfAlgo::SetNodeAttr("FLASH_INDEX", MakeValue<std::string>(ss_result), neighbor_exchange);

  // GROUP
  std::string group = g_device_manager->world_group();
  common::AnfAlgo::SetNodeAttr(parallel::GROUP, MakeValue<std::string>(group), neighbor_exchange);

  // SEND_RANK_IDS, RECV_RANK_IDS
  common::AnfAlgo::SetNodeAttr(parallel::SEND_RANK_IDS, parallel::MakeListValue(send_rank_ids), neighbor_exchange);
  common::AnfAlgo::SetNodeAttr(parallel::RECV_RANK_IDS, parallel::MakeListValue(recv_rank_ids), neighbor_exchange);

  // SEND_SHAPES, RECV_SHAPES
  parallel::Shape shape = neigh_shape;
  parallel::Shapes send_shapes;
  parallel::Shapes recv_shapes;
  for (size_t i = 0; i < send_rank_ids.size(); ++i) {
    send_shapes.push_back(shape);
    recv_shapes.push_back(shape);
  }
  common::AnfAlgo::SetNodeAttr(parallel::SEND_SHAPES, parallel::MakeTupleListValue(send_shapes), neighbor_exchange);
  common::AnfAlgo::SetNodeAttr(parallel::RECV_SHAPES, parallel::MakeTupleListValue(recv_shapes), neighbor_exchange);

  common::AnfAlgo::SetNodeAttr(parallel::COMM_REUSE, MakeValue(true), neighbor_exchange);

  neighbor_exchange->set_scope(input_node->scope());
  return neighbor_exchange;
}

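// Build a FlashAttentionScore CNode from a prepared input list and tag it with
// the same "FLASH_INDEX" ("<fa_index>_<ne_index>") used on the matching
// NeighborExchange node.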
CNodePtr NewFlashAttentionScoreNode(const std::vector<AnfNodePtr> &input_nodes, int fa_index, int ne_index) {
  std::vector<AnfNodePtr> fa_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimFlashAttentionScore->name()))};

  for (size_t i = 0; i < input_nodes.size(); ++i) {
    fa_inputs.push_back(input_nodes[i]);
  }
  auto fa_score = input_nodes[0]->func_graph()->NewCNode(fa_inputs);
  MS_EXCEPTION_IF_NULL(fa_score);

  std::stringstream ss;
  ss << fa_index << "_" << ne_index;
  std::string ss_result = ss.str();
  common::AnfAlgo::SetNodeAttr(FLASH_INDEX, MakeValue<std::string>(ss_result), fa_score);
  fa_score->set_scope(input_nodes[0]->scope());
  return fa_score;
}

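// Elementwise arithmetic builders (Add, Sub, Mul, RealDiv, Exp, Maximum) used
// below to merge softmax statistics across ring steps.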
CNodePtr NewAddNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> add_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAdd->name())), left_node,
                                        right_node};
  auto add_node = left_node->func_graph()->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add_node);
  add_node->set_scope(left_node->scope());
  return add_node;
}

CNodePtr NewSubNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> sub_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSub->name())), left_node,
                                        right_node};
  auto sub_node = left_node->func_graph()->NewCNode(sub_inputs);
  MS_EXCEPTION_IF_NULL(sub_node);
  sub_node->set_scope(left_node->scope());
  return sub_node;
}

CNodePtr NewMulNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMul->name())), left_node,
                                        right_node};
  auto mul_node = left_node->func_graph()->NewCNode(mul_inputs);
  MS_EXCEPTION_IF_NULL(mul_node);
  mul_node->set_scope(left_node->scope());
  return mul_node;
}

CNodePtr NewDivNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> div_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimRealDiv->name())),
                                        left_node, right_node};
  auto div_node = left_node->func_graph()->NewCNode(div_inputs);
  MS_EXCEPTION_IF_NULL(div_node);
  div_node->set_scope(left_node->scope());
  return div_node;
}

CNodePtr NewExpNode(const AnfNodePtr &left_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  std::vector<AnfNodePtr> exp_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimExp->name())), left_node};
  auto exp_node = left_node->func_graph()->NewCNode(exp_inputs);
  MS_EXCEPTION_IF_NULL(exp_node);
  exp_node->set_scope(left_node->scope());
  return exp_node;
}

CNodePtr NewMaxNode(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  std::vector<AnfNodePtr> max_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMaximum->name())),
                                        left_node, right_node};
  auto max_node = left_node->func_graph()->NewCNode(max_inputs);
  MS_EXCEPTION_IF_NULL(max_node);
  max_node->set_scope(left_node->scope());
  return max_node;
}

CNodePtr NewCastNode(const AnfNodePtr &tensor_node, const TypeId &dtype) {
  MS_EXCEPTION_IF_NULL(tensor_node);
  auto type_node = NewValueNode(static_cast<int64_t>(dtype));
  std::vector<AnfNodePtr> cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())),
                                         tensor_node, type_node};
  auto cast_node = tensor_node->func_graph()->NewCNode(cast_inputs);

  MS_EXCEPTION_IF_NULL(cast_node);
  common::AnfAlgo::SetNodeAttrSafely(kAttrDstType, TypeIdToType(dtype), cast_node);
  cast_node->set_scope(tensor_node->scope());
  return cast_node;
}

CNodePtr NewTransposeNode(const AnfNodePtr &tensor_node, const AnfNodePtr &tuple, ShapeVector output_shape) {
  MS_EXCEPTION_IF_NULL(tensor_node);
  MS_EXCEPTION_IF_NULL(tuple);
  std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimTranspose->name())),
                                              tensor_node, tuple};
  auto transpose_node = tensor_node->func_graph()->NewCNode(transpose_inputs);
  MS_EXCEPTION_IF_NULL(transpose_node);
  transpose_node->set_scope(tensor_node->scope());
  return transpose_node;
}

CNodePtr NewTileNode(const AnfNodePtr &tensor_node, const AnfNodePtr &tuple) {
  MS_EXCEPTION_IF_NULL(tensor_node);
  MS_EXCEPTION_IF_NULL(tuple);
  std::vector<AnfNodePtr> tile_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimTile->name())),
                                         tensor_node, tuple};
  auto tile_node = tensor_node->func_graph()->NewCNode(tile_inputs);
  MS_EXCEPTION_IF_NULL(tile_node);
  tile_node->set_scope(tensor_node->scope());
  return tile_node;
}

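// Build a dense uint8 attention-mask tensor (1 = masked). With is_causal set,
// produce a causal mask (position j > i masked); otherwise fill the whole
// tensor with value (all-visible 0 or all-masked 1).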
tensor::TensorPtr make_mask_tensor(TypeId type_id, ShapeVector shape, uint8_t value, bool is_causal) {
  tensor::TensorPtr mask_tensor = std::make_shared<mindspore::tensor::Tensor>(type_id, shape);
  int tensor_size = SizeToInt(mask_tensor->data().size());
  uint8_t *uint8_data = reinterpret_cast<uint8_t *>(mask_tensor->data_c());
  if (!is_causal) {
    for (int i = 0; i < tensor_size; ++i) {
      uint8_data[i] = value;
    }
  } else {
    // Row-major layout: the row stride is the column count shape[kIndex1].
    for (int i = 0; i < shape[kIndex0]; ++i) {
      for (int j = 0; j < shape[kIndex1]; ++j) {
        if (i >= j) {
          uint8_data[i * shape[kIndex1] + j] = 0;
        } else {
          uint8_data[i * shape[kIndex1] + j] = 1;
        }
      }
    }
  }
  return mask_tensor;
}

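// Select the attention mask for ring step index on device position rank_id:
// step 0 attends to the local KV block (causal mask); blocks arriving from
// lower-positioned ranks are fully visible (all 0); the rest are fully masked
// (all 1).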
AnfNodePtr GetActualMask(int index, int64_t rank_id, TypeId mask_dtype, ShapeVector mask_shape) {
  AnfNodePtr actual_mask;
  if (index == 0) {
    auto mask_tensor = make_mask_tensor(mask_dtype, mask_shape, 0, true);
    actual_mask = NewValueNode(MakeValue(mask_tensor));
  } else if (index <= rank_id) {
    auto mask_tensor = make_mask_tensor(mask_dtype, mask_shape, 0, false);
    actual_mask = NewValueNode(MakeValue(mask_tensor));
  } else {
    auto mask_tensor = make_mask_tensor(mask_dtype, mask_shape, 1, false);
    actual_mask = NewValueNode(MakeValue(mask_tensor));
  }
  return actual_mask;
}

int64_t GetPosInSpDevice(std::shared_ptr<FlashAttentionScoreInfo> flash_score_info_ptr, int64_t rank_id) {
  auto rankList = flash_score_info_ptr->GetSPRankList();
  int64_t pos = -1;
  for (size_t rank_list_idx = 0; rank_list_idx < rankList.size(); ++rank_list_idx) {
    if (rank_id == rankList[rank_list_idx]) {
      pos = rank_list_idx;
    }
  }
  return pos;
}

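// Derive batch size (fa_b), query sequence length (fa_s1), hidden size (fa_h1)
// and key/value sequence length (fa_s2) from the sliced Q/KV shapes, for both
// BSH and BNSD input layouts.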
void GetBSHFromShape(int64_t input_layout, Shape q_shape, Shape kv_shape, int64_t *fa_b, int64_t *fa_s1,
                     int64_t *fa_h1, int64_t *fa_s2) {
  if (input_layout == FASInputLayoutMode::BSH) {
    *fa_b = q_shape[kIndex0];
    *fa_s1 = q_shape[kIndex1];
    *fa_h1 = q_shape[kIndex2];
    *fa_s2 = kv_shape[kIndex1];
  } else if (input_layout == FASInputLayoutMode::BNSD) {
    *fa_b = q_shape[kIndex0];
    *fa_s1 = q_shape[kIndex2];
    *fa_h1 = q_shape[kIndex1] * q_shape[kIndex3];
    *fa_s2 = kv_shape[kIndex2];
  }
}

ValuePtr GetFlashIndexString(int fa_index, int index) {
  std::stringstream ss;
  ss << fa_index << "_" << index;
  std::string ss_result = ss.str();
  return MakeValue<std::string>(ss_result);
}

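// Fold one ring step's FlashAttentionScore output into the running result with
// the usual streaming-softmax (log-sum-exp) rescaling:
//   m = max(m_hist, m_i)
//   l = exp(m_hist - m) * l_hist + exp(m_i - m) * l_i
//   O = (exp(m_hist - m) * l_hist / l) * O_hist + (exp(m_i - m) * l_i / l) * O_i
// softmax_max/softmax_sum appear to carry 8 duplicated values along their last
// axis, so Split(dim 3, 8) + TupleGetItem(0) picks one copy and Tile broadcasts
// it across the per-head hidden dimension (fa_h1 / fa_n1). BSH outputs are
// reshaped/transposed to BNSD for the update and back afterwards.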
void UpdateAttentionOutput(CNodePtr *history_max, CNodePtr *history_sum, CNodePtr *acc_attention,
                           const CNodePtr &softmax_max, const CNodePtr &softmax_sum, CNodePtr attention_output,
                           int64_t fa_b, int64_t fa_s1, int64_t fa_n1, int64_t fa_h1, int64_t input_layout,
                           int fa_index, int index) {
  auto temp_max = NewMaxNode(*history_max, softmax_max);
  auto m_h_sub_temp = NewSubNode(*history_max, temp_max);
  auto m_i_sub_temp = NewSubNode(softmax_max, temp_max);
  auto e_m_h_temp = NewExpNode(m_h_sub_temp);
  auto e_m_i_temp = NewExpNode(m_i_sub_temp);
  auto e_l_h = NewMulNode(e_m_h_temp, *history_sum);
  auto e_l_i = NewMulNode(e_m_i_temp, softmax_sum);
  auto l = NewAddNode(e_l_h, e_l_i);
  auto e_m_h_div = NewDivNode(e_l_h, l);
  auto e_m_i_div = NewDivNode(e_l_i, l);
  auto e_m_h_div_split = NewSplitNode(e_m_h_div, 3, 8);
  auto e_m_h_div_item = NewTupleGetItemNode(e_m_h_div_split, 0);
  auto e_m_h_div_concat = NewTileNode(e_m_h_div_item, parallel::CreateTuple({1, 1, 1, fa_h1 / fa_n1}));
  auto e_m_i_div_split = NewSplitNode(e_m_i_div, 3, 8);
  auto e_m_i_div_item = NewTupleGetItemNode(e_m_i_div_split, 0);
  auto e_m_i_div_concat = NewTileNode(e_m_i_div_item, parallel::CreateTuple({1, 1, 1, fa_h1 / fa_n1}));
  if (input_layout == FASInputLayoutMode::BSH) {
    (*acc_attention) = NewReshapeNode(*acc_attention, {fa_b, fa_s1, fa_n1, fa_h1 / fa_n1}, TypeId::kNumberTypeFloat16);
    attention_output =
      NewReshapeNode(attention_output, {fa_b, fa_s1, fa_n1, fa_h1 / fa_n1}, TypeId::kNumberTypeFloat16);
    AnfNodePtr tmp_tup = parallel::CreateTuple({0, 2, 1, 3});
    (*acc_attention) = NewTransposeNode(*acc_attention, tmp_tup, {fa_b, fa_n1, fa_s1, fa_h1 / fa_n1});
    attention_output = NewTransposeNode(attention_output, tmp_tup, {fa_b, fa_n1, fa_s1, fa_h1 / fa_n1});
  }
  (*acc_attention) = NewCastNode(*acc_attention, TypeId::kNumberTypeFloat32);
  attention_output = NewCastNode(attention_output, TypeId::kNumberTypeFloat32);
  auto weighted_history = NewMulNode(e_m_h_div_concat, *acc_attention);
  auto weighted_attention = NewMulNode(e_m_i_div_concat, attention_output);
  (*acc_attention) = NewAddNode(weighted_history, weighted_attention);
  common::AnfAlgo::SetNodeAttr(kAttrAccumulatedAttention, MakeValue(1), *acc_attention);
  common::AnfAlgo::SetNodeAttr("FLASH_INDEX", GetFlashIndexString(fa_index, index), *acc_attention);
  if (input_layout == FASInputLayoutMode::BSH) {
    auto tmp_tup1 = parallel::CreateTuple({0, 2, 1, 3});
    (*acc_attention) = NewTransposeNode(*acc_attention, tmp_tup1, {fa_b, fa_s1, fa_n1, fa_h1 / fa_n1});
    (*acc_attention) = NewReshapeNode(*acc_attention, {fa_b, fa_s1, fa_h1}, TypeId::kNumberTypeFloat32);
  }
  (*history_max) = temp_max;
  (*history_sum) = l;
}

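// Rebuild one FlashAttentionScore as a ring-attention subgraph: for each of the
// sp_num steps, run FlashAttentionScore on the local Q against the currently
// held KV block, merge the partial result into the accumulators, and (except on
// the last step) NeighborExchange the concatenated KV pair with the ring
// neighbours. Returns a MakeTuple matching the original FA output order
// (softmax_max, softmax_sum, softmax_out, attention_output).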
CNodePtr CreateReplaceFSPGraph(const FuncGraphManagerPtr &manager,
                               const std::vector<CNodePtr> &origin_nodes_topological, const CNodePtr &fa_score_node,
                               FSPInfo *fsp_info, int fa_index) {
  std::vector<AnfNodePtr> fa_inputs;
  for (size_t i = 0; i < ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputsNum; ++i) {
    fa_inputs.push_back(fa_score_node->input(i + 1));
  }

  auto key_node = fa_score_node->input(ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputKeyIndex + 1);
  auto value_node = fa_score_node->input(ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputValueIndex + 1);

  int64_t sp_num = fsp_info->GetSPNum(), rank_id = fsp_info->GetRankId();
  int64_t send_rank_id = fsp_info->GetSendRankId(), recv_rank_id = fsp_info->GetRecvRankId();

  std::shared_ptr<OperatorInfo> operator_info = fa_score_node->user_data<parallel::OperatorInfo>();
  auto flash_score_info_ptr = std::dynamic_pointer_cast<FlashAttentionScoreInfo>(operator_info);
  auto q_shape = operator_info->inputs_tensor_info()[kIndex0].tensor_layout().base_slice_shape().array();
  auto kv_shape = operator_info->inputs_tensor_info()[kIndex1].tensor_layout().base_slice_shape().array();
  auto input_layout = flash_score_info_ptr->input_layout();
  if (input_layout != FASInputLayoutMode::BSH && input_layout != FASInputLayoutMode::BNSD) {
    return nullptr;
  }

  int64_t fa_n1 = GetValue<int64_t>(
    fa_score_node->input(ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputHeadNumIndex + 1)
      ->cast<ValueNodePtr>()
      ->value());
  int64_t fa_b = 0, fa_s1 = 0, fa_h1 = 0, fa_s2 = 0;
  GetBSHFromShape(input_layout, q_shape, kv_shape, &fa_b, &fa_s1, &fa_h1, &fa_s2);

  CNodePtr local_fa_node, kv_received_tuple, softmax_max, softmax_sum, softmax_out, attention_output;
  CNodePtr history_max, history_sum, acc_attention;
  AnfNodePtr actual_mask;
  for (int i = 0; i < sp_num; ++i) {
    std::vector<AnfNodePtr> kv_nodes = {key_node, value_node};
    auto kv_tuple = NewMakeTupleNode(kv_nodes);
    auto kv_concat = NewConcatNode(kv_tuple, 0);
    std::vector<AnfNodePtr> concat_tuple = {kv_concat};
    auto kv_concat_tuple = NewMakeTupleNode(concat_tuple);
    if (i != sp_num - 1) {
      auto neigh_shape = kv_shape;
      neigh_shape[0] = neigh_shape[0] * kIndex2;
      kv_received_tuple =
        NewNeighborExchangeNode(kv_concat_tuple, {send_rank_id}, {recv_rank_id}, fa_index, i, neigh_shape);
    }

    auto pos = GetPosInSpDevice(flash_score_info_ptr, rank_id);
    actual_mask = GetActualMask(i, pos, TypeId::kNumberTypeUInt8, Shape{fa_s1, fa_s2});
    fa_inputs[ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputKeyIndex] = key_node;
    fa_inputs[ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputValueIndex] = value_node;
    fa_inputs[ops::FlashAttentionScoreInputIndex::kFlashAttentionScoreInputAttnMaskIndex] = actual_mask;
    local_fa_node = NewFlashAttentionScoreNode(fa_inputs, fa_index, i);
    common::AnfAlgo::CopyNodeAttrs(fa_score_node, local_fa_node);

    if (i != sp_num - 1) {
      auto kv_exchanged_item = NewTupleGetItemNode(kv_received_tuple, kIndex0);
      auto kv_split = NewSplitNode(kv_exchanged_item, kIndex0, kIndex2);
      key_node = NewTupleGetItemNode(kv_split, kIndex0);
      value_node = NewTupleGetItemNode(kv_split, kIndex1);
    }

    softmax_max = NewTupleGetItemNode(local_fa_node, kIndex0);
    softmax_sum = NewTupleGetItemNode(local_fa_node, kIndex1);
    attention_output = NewTupleGetItemNode(local_fa_node, kIndex3);

    if (i == 0) {
      acc_attention = attention_output->cast<CNodePtr>();
      history_max = softmax_max->cast<CNodePtr>();
      history_sum = softmax_sum->cast<CNodePtr>();
    } else {
      UpdateAttentionOutput(&history_max, &history_sum, &acc_attention, softmax_max, softmax_sum, attention_output,
                            fa_b, fa_s1, fa_n1, fa_h1, input_layout, fa_index, i);
    }
  }
  acc_attention = NewCastNode(acc_attention, TypeId::kNumberTypeFloat16);
  softmax_out = NewTupleGetItemNode(local_fa_node, kIndex2);
  std::vector<AnfNodePtr> output_tuple = {history_max, history_sum, softmax_out, acc_attention};
  auto attention_results = NewMakeTupleNode(output_tuple);
  return attention_results;
}

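// Build the ring-attention subgraph for one FlashAttentionScore node and swap
// it into the graph via the manager.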
void CreateAndReplaceFAScore(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
                             const CNodePtr &fa_score_node, FSPInfo *fsp_info, int i) {
  auto cnode = CreateReplaceFSPGraph(manager, origin_nodes_topological, fa_score_node, fsp_info, i);
  MS_EXCEPTION_IF_NULL(cnode);
  (void)manager->Replace(fa_score_node, cnode);
}

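// Validate user settings: the pass only applies when the SP ring spans at
// least two devices.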
bool CheckUserSettings(const FuncGraphPtr &fg, FSPInfo *fsp_info) {
  fsp_info->DisplayInfo();

  int64_t sp_num = fsp_info->GetSPNum();
  if (sp_num <= 1) {
    MS_LOG(WARNING) << "FSP: To activate the pass, sp num " << sp_num << " should be larger than 1";
    return false;
  }
  return true;
}
}  // namespace

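// Pass entry point: under semi-auto parallel, rewrite every forward
// FlashAttentionScore that carries the ENABLE_RING_ATTENTION attribute into the
// ring (flash-SP) form built above.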
bool SetFlashSP(const FuncGraphPtr &func_graph) {
  auto parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode != kSemiAutoParallel) {
    return false;
  }

  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto ret = func_graph->get_return();
  auto origin_nodes_topological = DeepScopedGraphSearch(ret);

  std::vector<CNodePtr> fa_score_nodes = FindFWFlashAttentionScore(manager, origin_nodes_topological);
  if (fa_score_nodes.empty()) {
    return false;
  }

  for (size_t i = 0; i < fa_score_nodes.size(); ++i) {
    auto fa_score_node = fa_score_nodes[i];
    auto fa_score_node_prim = GetCNodePrimitive(fa_score_node);
    MS_EXCEPTION_IF_NULL(fa_score_node_prim);
    if (!fa_score_node_prim->HasAttr(parallel::ENABLE_RING_ATTENTION) ||
        !GetValue<bool>((fa_score_node_prim->GetAttr(parallel::ENABLE_RING_ATTENTION)))) {
      continue;
    }

    auto fsp_info = FSPInfo(fa_score_node);
    if (!CheckUserSettings(func_graph, &fsp_info)) {
      return false;
    }

    manager = func_graph->manager();
    MS_EXCEPTION_IF_NULL(manager);
    auto orders = func_graph->GetOrderedCnodes();
    std::vector<CNodePtr> nodes_topological(orders.cbegin(), orders.cend());
    CreateAndReplaceFAScore(manager, nodes_topological, fa_score_node, &fsp_info, i);
  }
  return true;
}
}  // namespace parallel
}  // namespace mindspore