/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_

#include "ir/value.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
/*
 * Parallel class for the Reshape primitive.
 */
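// A minimal usage sketch (illustrative only, not code from this repository; the
// surrounding objects such as inputs_shape, outputs_shape, attrs, the two layouts
// and strategy are assumed to be prepared by the caller):
//
//   ReshapeInfo reshape("Reshape", inputs_shape, outputs_shape, attrs);
//   reshape.SetInputLayout(in_layout);    // layout coming from the preceding operator
//   reshape.SetOutputLayout(out_layout);  // layout expected by the following operator
//   if (reshape.Init(strategy) != SUCCESS) {
//     // handle the failed initialization
//   }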
class ReshapeInfo : public OperatorInfo {
 public:
  ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
        dev_num_(0),
        pre_operator_index_(0),
        next_operator_index_(0),
        input_layout_set_flag_(false),
        output_layout_set_flag_(false) {}
  ~ReshapeInfo() override = default;
  Status Init(const StrategyPtr &strategy) override;
  void SetInputLayout(const TensorLayout &input_layout) {
    input_layout_ = input_layout;
    input_layout_set_flag_ = true;
  }
  void SetOutputLayout(const TensorLayout &output_layout) {
    output_layout_ = output_layout;
    output_layout_set_flag_ = true;
  }
  void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy);
  void SetCostForReshapeWithParameter();
  void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; }
  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
  void set_pre_operator_index(int64_t pre_index) { pre_operator_index_ = pre_index; }
  void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; }
  Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                               const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
                               int64_t in_index, bool is_prev_param, bool is_next_reshape);
  Status InitForCostModel(const StrategyPtr &strategy) override;
  Status GenerateStrategies(int64_t stage_id) override;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  std::string pre_operator_name() const { return pre_operator_name_; }
  std::string next_operator_name() const { return next_operator_name_; }
  int64_t pre_operator_index() const { return pre_operator_index_; }
  int64_t next_operator_index() const { return next_operator_index_; }

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;
  Status InferForwardCommunication() override;
  Status InferTensorMap() override;
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
  Status GetAttrs() override;
  Strategys GetOutputsStrategy();

 private:
  Status GetParameterInput();
  Status ComputeReplaceOp();
  void InferTensorInfoByLayout();
  void device_number();
  Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout);

  int64_t dev_num_;
  int64_t pre_operator_index_;
  int64_t next_operator_index_;
  std::vector<int64_t> parameter_input_v_;
  std::vector<StrategyPtr> sp_vector_;
  Dimensions input_strategy_;
  TensorLayout input_layout_;
  TensorLayout output_layout_;
  bool input_layout_set_flag_;
  bool output_layout_set_flag_;
  bool is_generating_costs_ = false;
  bool is_skip_ = false;
  std::string pre_operator_name_;
  std::string next_operator_name_;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_