• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FLASH_ATTENTION_SCORE_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FLASH_ATTENTION_SCORE_INFO_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include <tuple>
24 #include <utility>
25 
26 #include "utils/hash_map.h"
27 #include "utils/ms_utils.h"
28 #include "ir/value.h"
29 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
30 #include "frontend/parallel/graph_util/generate_graph.h"
31 #include "frontend/parallel/ops_info/operator_info.h"
32 #include "frontend/parallel/strategy.h"
33 
34 namespace mindspore {
35 namespace parallel {
36 class FlashAttentionScoreInfo : public OperatorInfo {
37  public:
FlashAttentionScoreInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)38   FlashAttentionScoreInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
39                           const PrimitiveAttrs &attrs)
40       : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
41   ~FlashAttentionScoreInfo() override = default;
42   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
43 
SetCostUnderStrategy(const StrategyPtr & strategy)44   Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
45   void ReplaceNodeInputOrAttrs() override;
46   void ReComputeBatchSplitFlagList() override;
47   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
48 
input_layout()49   int64_t input_layout() { return input_layout_; }
s1_split_num()50   int64_t s1_split_num() { return s1_split_num_; }
kv_split()51   bool kv_split() { return kv_split_; }
head_num()52   int64_t head_num() { return head_num_; }
real_shift_have_s1_dim()53   bool real_shift_have_s1_dim() { return real_shift_have_s1_dim_; }
real_shift_have_batch_dim()54   bool real_shift_have_batch_dim() { return real_shift_have_batch_dim_; }
is_attn_mask_compressed()55   bool is_attn_mask_compressed() { return is_attn_mask_compressed_; }
attn_mask_have_n1_dim()56   bool attn_mask_have_n1_dim() { return attn_mask_have_n1_dim_; }
attn_mask_have_batch_dim()57   bool attn_mask_have_batch_dim() { return attn_mask_have_batch_dim_; }
is_input_passed()58   std::vector<bool> is_input_passed() { return is_input_passed_; }
59   size_t GetStrategyRealIndex(size_t index);
60   Status InitAttrs();
61   RankList GetSPRankList();
62 
63  protected:
InferForwardCommunication()64   Status InferForwardCommunication() override { return SUCCESS; }
65   Status InferDevMatrixShape() override;
66   Status InferTensorMap() override;
67   Status GetAttrs() override;
68   Status InferAsLossDivisor() override;
69   Status CheckStrategy(const StrategyPtr &strategy) override;
70   Status InferMirrorOps() override;
71   Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) override;
72   Status InferOutputTensorInfo() override;
73   Status CheckInputLayout() override;
74   Status CheckOutputLayout() override;
75   Status InferOutputLayout();
76   Status InferAsLossDivisorByLayout() override;
77   Status InferMirrorOpsByLayout() override;
78   Status InferSplitNumAndDevMatrixShapeByLayout();
79 
80  private:
81   void UpdateDropoutGenMaskSliceShapeAndSeed(const CNodePtr &reshape_cnode);
82   void InitIsInputPassed();
83   Status InitQKVTensorMap();
84   Status InitInputsTensorMap();
85   Status InitSplittableInputs();
86   Status InitAttnMaskSplittableInputs();
87   Status InitExpectedStrategies();
88   Status InitAttnMaskStrategies();
89   Status InitQKVHeadAndSeqDimFromInputLayout();
90   std::vector<int64_t> GetSplitIdAndRank();
91   std::tuple<int64_t, int64_t> GetAttentionMaskAttrs(const int64_t split_id, const int64_t split_num);
92   void LoadBalanceSplitAlongSeqDim(size_t input_index, GenerateGraph *gen_g, AnfNodePtr *split_node,
93                                    AnfNodePtr *keep_node, AnfNodePtr *exchange_node);
94   void LoadBalanceExchange(const int64_t all_gather_idx, const Group &group, const AnfNodePtr &input_node,
95                            AnfNodePtr *exchange_node, GenerateGraph *gen_g);
96   void GetFlashAttentionScoreOpNode(int64_t split_id, int64_t split_num, const AnfNodePtr &q,
97                                     const AnfNodePtr &real_shift, const AnfNodePtr &drop_mask,
98                                     const AnfNodePtr &attn_mask, AnfNodePtr *fa_op, GenerateGraph *gen_g);
99   std::vector<std::pair<AnfNodePtr, int64_t>> ReplaceGraphGetInputNodes(const AnfNodePtr &q_split,
100                                                                         const AnfNodePtr &real_shift_split,
101                                                                         const AnfNodePtr &drop_mask_split,
102                                                                         const AnfNodePtr &attn_mask_split,
103                                                                         const AnfNodePtr &flash_attention_score_keep,
104                                                                         const AnfNodePtr &flash_attention_score_target);
105   Status ComputeReplaceGraphForLoadBalance(const CNodePtr &cnode);
106   Status ReplaceActualSeqLenForSplitSeqInTnd(const CNodePtr &cnode);
107   int64_t head_num_ = 1;
108   float keep_prob_ = 1.0;
109   float scale_value_ = 1.0;
110   size_t qkv_batch_dim_;
111   size_t qkv_head_dim_;
112   size_t qkv_seq_dim_;
113   int64_t pre_tokens_;
114   int64_t next_tokens_;
115   int64_t batch_split_num_;
116   int64_t n1_split_num_;
117   int64_t n2_split_num_;
118   int64_t s1_split_num_;
119   int64_t s2_split_num_;
120   int64_t t1_split_num_;  // The split num of query's T-dim under 'TND'
121   int64_t t2_split_num_;  // The split num of key and value's T=dim under 'TND'
122   int64_t dev_matrix_batch_dim_;
123   int64_t dev_matrix_n1_dim_;
124   int64_t dev_matrix_s1_dim_;
125   bool real_shift_have_s1_dim_ = false;     // true if real_shift and have s1 dim.
126   bool real_shift_have_batch_dim_ = false;  // true if real_shift have batch dim
127   bool attn_mask_have_batch_dim_ = false;   // true if attn_mask have batch dim.
128   bool attn_mask_have_n1_dim_ = false;      // true if attn_mask have n1 dim.
129   bool enable_load_balance_ = false;
130   bool enable_ring_attention_ = false;
131   int64_t input_layout_;  // "BSH": 0; "BNSD": 1;
132   int64_t sparse_mode_;
133   bool kv_split_ = false;
134   bool is_attn_mask_compressed_ = false;
135   bool need_update_op_attrs_mode_ = false;
136   std::vector<bool> is_input_passed_;
137   size_t real_input_size_ = 0;
138   std::vector<Shape> splittable_inputs_;
139   Strategies expect_strategies_;
140   TensorLayout softmax_max_tensor_layout_;
141   TensorLayout softmax_sum_tensor_layout_;
142   TensorLayout softmax_out_tensor_layout_;
143   TensorLayout attention_out_tensor_layout_;
144 };
145 }  // namespace parallel
146 }  // namespace mindspore
147 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FLASH_ATTENTION_SCORE_INFO_H_
148