/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_

#include "ir/value.h"

#include <memory>
#include <string>
#include <vector>
#include <utility>

#include "utils/hash_map.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
/*
 * Parallel operator info for the Reshape primitive.
 */
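/*
 * A minimal usage sketch (illustrative only; the real call sites live in the
 * parallel passes, and the variables below are assumed to exist there):
 *
 *   auto reshape = std::make_shared<ReshapeInfo>(name, inputs_shape, outputs_shape, attrs);
 *   reshape->SetInputLayout(in_layout);    // layout produced by the preceding operator
 *   reshape->SetOutputLayout(out_layout);  // layout required by the following operator
 *   if (reshape->Init(nullptr, nullptr) != SUCCESS) {
 *     // the fixed layouts cannot be bridged by this reshape
 *   }
 */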
class ReshapeInfo : public OperatorInfo {
 public:
  ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
        dev_num_(0),
        pre_operator_index_(0),
        next_operator_index_(0),
        input_layout_set_flag_(false),
        output_layout_set_flag_(false) {}
  ~ReshapeInfo() override = default;
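  // Initializes the operator's parallel information from the given strategies and/or
  // explicit tensor layouts (which may also be preset via the setters below).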
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
              const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
              const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {}) override;
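  // Layout setters: record a layout fixed by the neighboring operators; the flags mark
  // the layout as externally set rather than inferred.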
  void SetInputLayout(const TensorLayout &input_layout) {
    input_layout_ = input_layout;
    input_layout_set_flag_ = true;
  }
  void SetOutputLayout(const TensorLayout &output_layout) {
    output_layout_ = output_layout;
    output_layout_set_flag_ = true;
  }
  void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy);
  void SetCostForReshapeWithParameter();
  void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; }
  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
  void set_pre_operator_index(int64_t pre_index) { pre_operator_index_ = pre_index; }
  void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; }
  StrategyPtr get_input_shard_strategy();
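  // Builds this reshape's strategy-with-cost candidates by bridging the strategies of the
  // preceding operator (pre_stra_costs) and the following operators (next_costs_index).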
  Status GenerateStrategyCosts(
    const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
    std::vector<std::pair<std::vector<std::shared_ptr<StrategyWithCost>>, int64_t>> next_costs_index, int64_t out_index,
    bool is_prev_param, bool is_next_reshape);
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  std::string pre_operator_name() const { return pre_operator_name_; }
  std::string next_operator_name() const { return next_operator_name_; }
  int64_t pre_operator_index() const { return pre_operator_index_; }
  int64_t next_operator_index() const { return next_operator_index_; }

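  // SWC = strategy-with-cost. The lookups below locate the candidate whose input/output
  // layout matches the given layout, preferring zero and then minimal redistribution
  // communication; the consistency checks verify a chosen index against a layout.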
  int64_t GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &output_layout);
  int64_t GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &output_layout);
  int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout);
  int64_t GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &input_layout);
  bool CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) const;
  bool CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) const;

  TensorLayout GetInputLayoutBySWCIndex(int64_t swc_index) const;
  TensorLayout GetOutputLayoutBySWCIndex(int64_t swc_index) const;

  bool InterleavedParallel() const { return interleaved_parallel_; }
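  // The tensor redistribution between input_layout_ and output_layout_.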
  TensorRedistributionPtr ReshapeRedistribution();

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;
  Status InferForwardCommunication() override;
  Status InferTensorMap() override;
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
  // Reshape carries no primitive attributes that affect parallelization.
  Status GetAttrs() override { return SUCCESS; }

 private:
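  // Computes the replacement operators that implement this reshape once layouts are
  // fixed, with a separate path for dynamic shapes taken from a shape-input node.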
  Status ComputeReplaceOp();
  Status ComputeReplaceOpForDynamicShape();
  void InferTensorInfoByLayout();
  void device_number();
  Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout);
  std::vector<int64_t> GetInputShape(const AnfNodePtr &shape_input_node);
  void ChangeDynamicDstShapeForSkipRedistribution(const AnfNodePtr &shape_input_node);
  void ChangeDstShape();

  int64_t dev_num_;
  int64_t pre_operator_index_;
  int64_t next_operator_index_;
  std::vector<int64_t> parameter_input_v_;
  std::vector<StrategyPtr> sp_vector_;
  Dimensions input_strategy_;
  TensorLayout input_layout_;
  TensorLayout output_layout_;
  bool input_layout_set_flag_;
  bool output_layout_set_flag_;
  bool is_generating_costs_ = false;
  bool is_skip_ = false;
  bool interleaved_parallel_ = false;
  std::string pre_operator_name_;
  std::string next_operator_name_;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RESHAPE_INFO_H_