1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ 19 20 #include <memory> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 25 #include "ir/value.h" 26 #include "frontend/parallel/auto_parallel/operator_costmodel.h" 27 #include "frontend/parallel/ops_info/operator_info.h" 28 #include "frontend/parallel/strategy.h" 29 30 namespace mindspore { 31 namespace parallel { 32 class GatherPInfo : public OperatorInfo { 33 public: 34 GatherPInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 35 const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2) OperatorInfo(name,inputs_shape,outputs_shape,attrs,std::make_shared<GatherV2PCost> ())36 : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()), 37 axis_(0), 38 bias_(0), 39 index_offset_(0), 40 slice_size_(0), 41 replace_op_name_(replace_op_name) {} 42 ~GatherPInfo() override = default; 43 Status Init(const StrategyPtr &strategy) override; 44 Status InitForCostModel(const StrategyPtr &strategy) override; 45 46 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; 47 Status SetCostUnderStrategy(const StrategyPtr &strategy) override; 48 ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; 49 std::shared_ptr<Strategys> GenerateBatchStrategies() override; param_split_shapes()50 const std::vector<int64_t> ¶m_split_shapes() const { return param_split_shapes_; } index_offsets()51 const std::vector<int64_t> &index_offsets() const { return index_offsets_; } 52 53 protected: 54 Status CheckStrategy(const StrategyPtr &strategy) override; 55 Status InferMirrorOps() override; 56 Status InferForwardCommunication() override; 57 Status InferTensorInfo() override; 58 Status InferDevMatrixShape() override; 59 Status InferTensorMap() override; 60 void InferInputsTensorMap(); 61 void InferOutputsTensorMap(); 62 Status GetAttrs() override; 63 64 Status ComputeReplaceGraph(const CNodePtr &cnode); 65 Status CheckManualSplit(const Strategys &strategy); 66 Status CheckSplitAxisStrategy(const StrategyPtr &strategy); 67 void SetAttribute(const StrategyPtr &strategy); 68 Status GetManualSplitAttr(); 69 Status GetManualSplitWithoutOffsetAttr(); 70 Status ComputeReplaceOp(); 71 Status InferBias(); 72 Status InferOffset(); 73 Status InferGroup(); 74 bool ShardBatchAndAxis(const Strategys &strategy) const; 75 76 int64_t axis_; 77 std::string target_ = DEVICE; 78 int64_t bias_; 79 int64_t index_offset_; 80 int64_t slice_size_; 81 std::string replace_op_name_ = GATHERV2; 82 Group group_; 83 bool manual_split_ = false; 84 bool dynamic_shape_indices_ = false; 85 bool axis_split_forward_allreduce_ = false; // when axis is split, use reducescatter as default in forward 86 bool shard_batch_and_axis_ = false; 87 std::vector<int64_t> param_split_shapes_; 88 std::vector<int64_t> index_offsets_; 89 }; 90 91 class SparseGatherV2Info : public GatherPInfo { 92 public: 93 SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 94 const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2) GatherPInfo(name,inputs_shape,outputs_shape,attrs,replace_op_name)95 : GatherPInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {} 96 ~SparseGatherV2Info() override = default; 97 }; 98 99 class EmbeddingLookupInfo : public GatherPInfo { 100 public: EmbeddingLookupInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)101 EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 102 const PrimitiveAttrs &attrs) 103 : GatherPInfo(name, inputs_shape, outputs_shape, attrs) {} 104 ~EmbeddingLookupInfo() override = default; 105 }; 106 } // namespace parallel 107 } // namespace mindspore 108 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ 109