/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FLASH_ATTENTION_SCORE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FLASH_ATTENTION_SCORE_INFO_H_

#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include <utility>

#include "utils/hash_map.h"
#include "utils/ms_utils.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class FlashAttentionScoreInfo : public OperatorInfo {
 public:
  FlashAttentionScoreInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                          const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
  ~FlashAttentionScoreInfo() override = default;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;

  Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
  void ReplaceNodeInputOrAttrs() override;
  void ReComputeBatchSplitFlagList() override;
  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

  int64_t input_layout() { return input_layout_; }
  int64_t s1_split_num() { return s1_split_num_; }
  bool kv_split() { return kv_split_; }
  int64_t head_num() { return head_num_; }
  bool real_shift_have_s1_dim() { return real_shift_have_s1_dim_; }
  bool real_shift_have_batch_dim() { return real_shift_have_batch_dim_; }
  bool is_attn_mask_compressed() { return is_attn_mask_compressed_; }
  bool attn_mask_have_n1_dim() { return attn_mask_have_n1_dim_; }
  bool attn_mask_have_batch_dim() { return attn_mask_have_batch_dim_; }
  std::vector<bool> is_input_passed() { return is_input_passed_; }
  size_t GetStrategyRealIndex(size_t index);
  Status InitAttrs();
  RankList GetSPRankList();

 protected:
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;
  Status InferAsLossDivisor() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;
  Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) override;
  Status InferOutputTensorInfo() override;
  Status CheckInputLayout() override;
  Status CheckOutputLayout() override;
  Status InferOutputLayout();
  Status InferAsLossDivisorByLayout() override;
  Status InferMirrorOpsByLayout() override;
  Status InferSplitNumAndDevMatrixShapeByLayout();

 private:
  void UpdateDropoutGenMaskSliceShapeAndSeed(const CNodePtr &reshape_cnode);
  void InitIsInputPassed();
  Status InitQKVTensorMap();
  Status InitInputsTensorMap();
  Status InitSplittableInputs();
  Status InitAttnMaskSplittableInputs();
  Status InitExpectedStrategies();
  Status InitAttnMaskStrategies();
  Status InitQKVHeadAndSeqDimFromInputLayout();
  std::vector<int64_t> GetSplitIdAndRank();
  std::tuple<int64_t, int64_t> GetAttentionMaskAttrs(const int64_t split_id, const int64_t split_num);
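  // Load-balance helpers for splitting along the sequence dim (S1). Assumption
  // based on the method names, not stated explicitly in this header: with a
  // causal/compressed attention mask, lower sequence chunks contain more
  // unmasked positions than upper ones, so slices are exchanged between ranks
  // to even out per-rank work.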
  void LoadBalanceSplitAlongSeqDim(size_t input_index, GenerateGraph *gen_g, AnfNodePtr *split_node,
                                   AnfNodePtr *keep_node, AnfNodePtr *exchange_node);
  void LoadBalanceExchange(const int64_t all_gather_idx, const Group &group, const AnfNodePtr &input_node,
                           AnfNodePtr *exchange_node, GenerateGraph *gen_g);
  void GetFlashAttentionScoreOpNode(int64_t split_id, int64_t split_num, const AnfNodePtr &q,
                                    const AnfNodePtr &real_shift, const AnfNodePtr &drop_mask,
                                    const AnfNodePtr &attn_mask, AnfNodePtr *fa_op, GenerateGraph *gen_g);
  std::vector<std::pair<AnfNodePtr, int64_t>> ReplaceGraphGetInputNodes(const AnfNodePtr &q_split,
                                                                        const AnfNodePtr &real_shift_split,
                                                                        const AnfNodePtr &drop_mask_split,
                                                                        const AnfNodePtr &attn_mask_split,
                                                                        const AnfNodePtr &flash_attention_score_keep,
                                                                        const AnfNodePtr &flash_attention_score_target);
  Status ComputeReplaceGraphForLoadBalance(const CNodePtr &cnode);
  Status ReplaceActualSeqLenForSplitSeqInTnd(const CNodePtr &cnode);
  int64_t head_num_ = 1;
  float keep_prob_ = 1.0;
  float scale_value_ = 1.0;
  size_t qkv_batch_dim_;
  size_t qkv_head_dim_;
  size_t qkv_seq_dim_;
  int64_t pre_tokens_;
  int64_t next_tokens_;
  int64_t batch_split_num_;
  int64_t n1_split_num_;
  int64_t n2_split_num_;
  int64_t s1_split_num_;
  int64_t s2_split_num_;
  int64_t t1_split_num_;  // The split num of query's T-dim under 'TND'.
  int64_t t2_split_num_;  // The split num of key and value's T-dim under 'TND'.
  int64_t dev_matrix_batch_dim_;
  int64_t dev_matrix_n1_dim_;
  int64_t dev_matrix_s1_dim_;
  bool real_shift_have_s1_dim_ = false;     // true if real_shift has s1 dim.
  bool real_shift_have_batch_dim_ = false;  // true if real_shift has batch dim.
  bool attn_mask_have_batch_dim_ = false;   // true if attn_mask has batch dim.
  bool attn_mask_have_n1_dim_ = false;      // true if attn_mask has n1 dim.
  bool enable_load_balance_ = false;
  bool enable_ring_attention_ = false;
  int64_t input_layout_;  // "BSH": 0; "BNSD": 1;
  int64_t sparse_mode_;
  bool kv_split_ = false;
  bool is_attn_mask_compressed_ = false;
  bool need_update_op_attrs_mode_ = false;
  std::vector<bool> is_input_passed_;
  size_t real_input_size_ = 0;
  std::vector<Shape> splittable_inputs_;
  Strategies expect_strategies_;
  TensorLayout softmax_max_tensor_layout_;
  TensorLayout softmax_sum_tensor_layout_;
  TensorLayout softmax_out_tensor_layout_;
  TensorLayout attention_out_tensor_layout_;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FLASH_ATTENTION_SCORE_INFO_H_