1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ 19 20 #include <ir/value.h> 21 22 #include <memory> 23 #include <string> 24 #include <vector> 25 #include <utility> 26 27 #include "utils/hash_map.h" 28 #include "frontend/parallel/ops_info/operator_info.h" 29 #include "frontend/parallel/strategy.h" 30 31 namespace mindspore { 32 namespace parallel { 33 /* 34 * parallel class for Reshape Primitive 35 */ 36 class ReshapeInfo : public OperatorInfo { 37 public: ReshapeInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)38 ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 39 const PrimitiveAttrs &attrs) 40 : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()), 41 dev_num_(0), 42 pre_operator_index_(0), 43 next_operator_index_(0), 44 input_layout_set_flag_(false), 45 output_layout_set_flag_(false) {} 46 ~ReshapeInfo() override = default; 47 Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy, 48 const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {}, 49 const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {}) override; SetInputLayout(const TensorLayout & input_layout)50 void SetInputLayout(const TensorLayout &input_layout) { 51 input_layout_ = input_layout; 52 input_layout_set_flag_ = true; 53 } SetOutputLayout(const TensorLayout & output_layout)54 void SetOutputLayout(const TensorLayout &output_layout) { 55 output_layout_ = output_layout; 56 output_layout_set_flag_ = true; 57 } 58 void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy); 59 void SetCostForReshapeWithParameter(); set_pre_operator_name(const std::string & pre_name)60 void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; } set_next_operator_name(const std::string & next_name)61 void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } set_pre_operator_index(int64_t pre_index)62 void set_pre_operator_index(int64_t pre_index) { pre_operator_index_ = pre_index; } set_next_operator_index(int64_t next_index)63 void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; } 64 StrategyPtr get_input_shard_strategy(); 65 Status GenerateStrategyCosts( 66 const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs, 67 std::vector<std::pair<std::vector<std::shared_ptr<StrategyWithCost>>, int64_t>> next_costs_index, int64_t out_index, 68 bool is_prev_param, bool is_next_reshape); 69 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; 70 Status SetCostUnderStrategy(const StrategyPtr &strategy) override; pre_operator_name()71 std::string pre_operator_name() const { return pre_operator_name_; } next_operator_name()72 std::string next_operator_name() const { return next_operator_name_; } pre_operator_index()73 int64_t pre_operator_index() const { return pre_operator_index_; } next_operator_index()74 int64_t next_operator_index() const { return next_operator_index_; } 75 76 int64_t GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &output_layout); 77 int64_t GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &output_layout); 78 int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout); 79 int64_t GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &input_layout); 80 bool CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) const; 81 bool CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) const; 82 83 TensorLayout GetInputLayoutBySWCIndex(int64_t swc_index) const; 84 TensorLayout GetOutputLayoutBySWCIndex(int64_t swc_index) const; 85 InterleavedParallel()86 bool InterleavedParallel() const { return interleaved_parallel_; } 87 TensorRedistributionPtr ReshapeRedistribution(); 88 89 protected: 90 Status CheckStrategy(const StrategyPtr &strategy) override; 91 Status InferMirrorOps() override; 92 Status InferForwardCommunication() override; 93 Status InferTensorMap() override; 94 Status InferTensorInfo() override; 95 Status InferDevMatrixShape() override; 96 Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); GetAttrs()97 Status GetAttrs() override { return SUCCESS; } 98 99 private: 100 Status ComputeReplaceOp(); 101 Status ComputeReplaceOpForDynamicShape(); 102 void InferTensorInfoByLayout(); 103 void device_number(); 104 Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); 105 std::vector<int64_t> GetInputShape(const AnfNodePtr &shape_input_node); 106 void ChangeDynamicDstShapeForSkipRedistribution(const AnfNodePtr &shape_input_node); 107 void ChangeDstShape(); 108 109 int64_t dev_num_; 110 int64_t pre_operator_index_; 111 int64_t next_operator_index_; 112 std::vector<int64_t> parameter_input_v_; 113 std::vector<StrategyPtr> sp_vector_; 114 Dimensions input_strategy_; 115 TensorLayout input_layout_; 116 TensorLayout output_layout_; 117 bool input_layout_set_flag_; 118 bool output_layout_set_flag_; 119 bool is_generating_costs_ = false; 120 bool is_skip_ = false; 121 bool interleaved_parallel_ = false; 122 std::string pre_operator_name_; 123 std::string next_operator_name_; 124 }; 125 } // namespace parallel 126 } // namespace mindspore 127 128 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ 129