1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ 19 20 #include <ir/value.h> 21 22 #include <memory> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 #include "frontend/parallel/ops_info/operator_info.h" 28 #include "frontend/parallel/strategy.h" 29 30 namespace mindspore { 31 namespace parallel { 32 /* 33 * parallel class for Reshape Primitive 34 */ 35 class ReshapeInfo : public OperatorInfo { 36 public: ReshapeInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)37 ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 38 const PrimitiveAttrs &attrs) 39 : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()), 40 dev_num_(0), 41 pre_operator_index_(0), 42 next_operator_index_(0), 43 input_layout_set_flag_(false), 44 output_layout_set_flag_(false) {} 45 ~ReshapeInfo() override = default; 46 Status Init(const StrategyPtr &strategy) override; SetInputLayout(const TensorLayout & input_layout)47 void SetInputLayout(const TensorLayout &input_layout) { 48 input_layout_ = input_layout; 49 input_layout_set_flag_ = true; 50 } SetOutputLayout(const TensorLayout & output_layout)51 void SetOutputLayout(const TensorLayout &output_layout) { 52 output_layout_ = output_layout; 53 output_layout_set_flag_ = true; 54 } 55 void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy); 56 void SetCostForReshapeWithParameter(); set_pre_operator_name(const std::string & pre_name)57 void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; } set_next_operator_name(const std::string & next_name)58 void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } set_pre_operator_index(int64_t pre_index)59 void set_pre_operator_index(int64_t pre_index) { pre_operator_index_ = pre_index; } set_next_operator_index(int64_t next_index)60 void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; } 61 Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs, 62 const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index, 63 int64_t in_index, bool is_prev_param, bool is_next_reshape); 64 Status InitForCostModel(const StrategyPtr &strategy) override; 65 Status GenerateStrategies(int64_t stage_id) override; 66 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; 67 Status SetCostUnderStrategy(const StrategyPtr &strategy) override; pre_operator_name()68 std::string pre_operator_name() const { return pre_operator_name_; } next_operator_name()69 std::string next_operator_name() const { return next_operator_name_; } pre_operator_index()70 int64_t pre_operator_index() const { return pre_operator_index_; } next_operator_index()71 int64_t next_operator_index() const { return next_operator_index_; } 72 73 protected: 74 Status CheckStrategy(const StrategyPtr &strategy) override; 75 Status InferMirrorOps() override; 76 Status InferForwardCommunication() override; 77 Status InferTensorMap() override; 78 Status InferTensorInfo() override; 79 Status InferDevMatrixShape() override; 80 Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); 81 Status GetAttrs() override; 82 Strategys GetOutputsStrategy(); 83 84 private: 85 Status GetParameterInput(); 86 Status ComputeReplaceOp(); 87 void InferTensorInfoByLayout(); 88 void device_number(); 89 Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); 90 91 int64_t dev_num_; 92 int64_t pre_operator_index_; 93 int64_t next_operator_index_; 94 std::vector<int64_t> parameter_input_v_; 95 std::vector<StrategyPtr> sp_vector_; 96 Dimensions input_strategy_; 97 TensorLayout input_layout_; 98 TensorLayout output_layout_; 99 bool input_layout_set_flag_; 100 bool output_layout_set_flag_; 101 bool is_generating_costs_ = false; 102 bool is_skip_ = false; 103 std::string pre_operator_name_; 104 std::string next_operator_name_; 105 }; 106 } // namespace parallel 107 } // namespace mindspore 108 109 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ 110