/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
19 
20 #include <cstdint>
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 
28 #include "utils/ms_utils.h"
29 #include "base/base.h"
30 #include "frontend/parallel/auto_parallel/costmodel.h"
31 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
32 #include "frontend/parallel/device_manager.h"
33 #include "frontend/parallel/device_matrix.h"
34 #include "frontend/parallel/group_manager.h"
35 #include "frontend/parallel/ops_info/ops_utils.h"
36 #include "frontend/parallel/strategy.h"
37 #include "frontend/parallel/tensor_layout/tensor_info.h"
38 #include "utils/log_adapter.h"
39 #include "base/core_ops.h"
40 
41 namespace mindspore {
42 namespace parallel {
43 using ForwardOp = OperatorVector;
44 using MirrorOps = std::vector<OperatorVector>;
45 using Ops = std::vector<OperatorVector>;
46 using VirtualDivOp = OperatorVector;
47 using TensorMaps = std::vector<Shape>;
48 using TensorLayouts = std::vector<TensorLayout>;
49 using different_type = std::vector<int64_t>::difference_type;
50 using PrimitiveAttrs = std::unordered_map<std::string, ValuePtr>;
51 using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>;
52 
53 class Edge;
54 
55 class OperatorInfo {
56  public:
OperatorInfo(std::string name,Shapes inputs_shape,Shapes outputs_shape,PrimitiveAttrs attrs,OperatorCostPtr cost)57   OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs, OperatorCostPtr cost)
58       : name_(std::move(name)),
59         inputs_shape_(std::move(inputs_shape)),
60         outputs_shape_(std::move(outputs_shape)),
61         attrs_(std::move(attrs)),
62         is_alive_(true),
63         operator_cost_(cost),
64         outputs_type_() {
65     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
66     is_parameter_ = not_parameteter;
67     refkey_parameter_name_ = "";
68     stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
69     stage_device_size_ = SizeToLong(stage_device_list_.size());
70     cnode_ = nullptr;
71   }
72 
73   virtual ~OperatorInfo() = default;
74 
75   Status set_is_parameter(const std::vector<bool> &is_parameter);
76   Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
77                                      const std::vector<size_t> &output_lengths);
78   double GetOutputsTotalSize();
79   // Set outputs dtype.
80   // If only one output, outputs_type.size() is 1.
81   // If output is tuple, outputs_type.size() is greater than 1.
82   Status set_outputs_type(const std::vector<TypePtr> &outputs_type);
outputs_type()83   const std::vector<TypePtr> &outputs_type() const { return outputs_type_; }
84   virtual Status Init(const StrategyPtr &strategy) = 0;
85   virtual Status InitForCostModel(const StrategyPtr &strategy) = 0;  // only init the necessary parts
86 
87   // Given the stage_id (which indicates the number of devices),
88   // generate all strategies for this operator
89   virtual Status GenerateStrategies(int64_t stage_id);
90   virtual std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) = 0;
operator_cost()91   const OperatorCostPtr &operator_cost() const { return operator_cost_; }
set_cost(const OperatorCostPtr & cost)92   void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
93   virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
94 
95   virtual std::shared_ptr<Strategys> GenerateBatchStrategies();
96   virtual void ReComputeBatchSplitFlagList();
97   void ComputeBatchSplitFlagList();
98 
99   double GetForwardMemoryCostFromCNode();
100   // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy
101   // is checked
102   Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
GetStrategyCost()103   std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
104   void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &);
105   // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
106   // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
107   // at the end of forward phase.
108   Status CalculateMemoryCost();
109   // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
110   // by the output
111   Status CalculateMemoryCostForInference();
112   int64_t ComputeOpAndPrevEdgeParameterInvolved();
113 
forward_op()114   ForwardOp forward_op() const { return forward_op_; }
replace_op()115   ForwardOp replace_op() const { return replace_op_; }
replace_op_info()116   OutPutInfoVector replace_op_info() const { return replace_op_info_; }
replace_graph(const CNodePtr &)117   virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; }
mirror_ops()118   MirrorOps mirror_ops() const { return mirror_ops_; }
sub_ops()119   Ops sub_ops() const { return sub_ops_; }
virtual_div_op()120   VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
dev_matrix_shape()121   Shape dev_matrix_shape() const { return dev_matrix_shape_; }
inputs_tensor_info()122   std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
outputs_tensor_info()123   std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
strategy_cost()124   std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
name()125   const std::string &name() const { return name_; }
set_name(const std::string & name)126   void set_name(const std::string &name) { name_ = name; }
stage_device_list()127   RankList stage_device_list() const { return stage_device_list_; }
128 
AddSuccEdge(const std::shared_ptr<Edge> & e)129   void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
AddPrevEdge(const std::shared_ptr<Edge> & e)130   void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
succ_edges()131   std::vector<std::shared_ptr<Edge>> succ_edges() const { return succ_edges_; }
prev_edges()132   std::vector<std::shared_ptr<Edge>> prev_edges() const { return prev_edges_; }
133   std::vector<std::shared_ptr<Edge>> GetAliveSuccEdges();
134   std::vector<std::shared_ptr<Edge>> GetAlivePrevEdges();
135   void ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
136   void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
137   void ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
138   void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
GetOutputTypeLengths()139   std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
SetSelectedStrategyAndCost(const StrategyPtr & s_strategy,const CostPtr & cost)140   void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) {
141     selected_strategy_ = s_strategy;
142     selected_cost_ = cost;
143   }
144   void SetSelectedStrategy(const StrategyPtr &s_strategy, size_t);
selected_strategy()145   StrategyPtr selected_strategy() const { return selected_strategy_; }
selected_cost()146   CostPtr selected_cost() const { return selected_cost_; }
147   // Approximate the list of available strategies
148   void ApproximateStrategies();
149   // Make the list of available strategies exact and re-init the related edges incident to this operator
150   void ExactStrategiesAndRelatedEdges();
is_strategy_cost_exact()151   bool is_strategy_cost_exact() { return is_strategy_cost_exact_; }
SetIsStrategyCostExactTrue()152   void SetIsStrategyCostExactTrue() { is_strategy_cost_exact_ = true; }
ClearStrategyCost()153   void ClearStrategyCost() { strategy_cost_.clear(); }
154   void CheckSelectedStrategy(const StrategyPtr &);
InitSelectedStrategy(const StrategyPtr & s_strategy)155   Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
set_input_value(const std::vector<ValuePtr> & input_value)156   void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
input_value()157   const std::vector<ValuePtr> &input_value() const { return input_value_; }
set_outputs_dtype(const TypePtr & dtype)158   void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
set_cnode(const CNodePtr & cnode)159   void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
160   CNodePtr cnode();
is_alive()161   bool is_alive() const { return is_alive_; }
SetNotAlive()162   void SetNotAlive() { is_alive_ = false; }
strategy()163   StrategyPtr strategy() const { return strategy_; }
set_strategy(const StrategyPtr & strategy)164   void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; }
set_refkey_parameter_name(std::string p_name)165   void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
refkey_parameter_name()166   const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
167   // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated
168   // multiple times. This method is to correct this, and makes the cost is calculated only once.
169   Status CorrectMemoryCost(size_t input_index);
is_output_parameter_involve()170   int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; }
is_output_critical()171   int64_t is_output_critical() const { return is_output_critical_; }
mark_output_critical()172   void mark_output_critical() { is_output_critical_ = 1; }
mark_output_not_critical()173   void mark_output_not_critical() { is_output_critical_ = 0; }
used_devices()174   int64_t used_devices() const { return used_devices_; }
175   // needed by rec_parser
set_type(const std::string & type)176   void set_type(const std::string &type) { type_ = type; }
type()177   const std::string &type() const { return type_; }
set_last_node_flag(const bool & is_last_node)178   void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; }
is_last_node()179   const bool &is_last_node() const { return is_last_node_; }
attrs()180   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
set_stage_id(int32_t stage_id)181   void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
stage_id()182   int32_t stage_id() const { return stage_id_; }
183   Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
184   Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
ReplaceNodeInputOrAttrs()185   virtual void ReplaceNodeInputOrAttrs() {}
186 
187   // Key for user data.
188   constexpr static char key[] = "OpInfo";
189 
190  protected:
191   // needed by rec_parser
192   std::string type_;
193   bool is_last_node_ = false;
194   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
195   virtual Status InferTensorMap() = 0;
196   virtual Status InferForwardCommunication() = 0;
197   virtual Status GetAttrs() = 0;
198   virtual Status InferDevMatrixShape() = 0;
199   virtual Status InferMirrorOps();
200   virtual Status InferTensorInfo();
201   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
202   void SetRepeatedCalcDevMatrix();
203   void ResetTensorMapIfRepeatedCalc();
204   Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
205   Status InferAttrs();
206   void ResetQueueMember();
207   Status InitWithAutoRepeatCalc(const StrategyPtr &strategy);
208   Status InitWithManualRepeatCalc(const StrategyPtr &strategy);
209   Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy);
210   Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy);
211   Status InferRepeatedCalcInfo();
212   Status InferVirtualDivOps();
213 
214   // Calculate the number of repeated calculations for the output by the number of devices and the output tensor map.
215   // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output
216   // is used for grad and overload the function. If the output is a scalar, need to override the function too.
217   virtual Status InferAsLossDivisor();
218   Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
219                          Shapes *inputs_slice_shape, Shapes *outputs_slice_shape);
220   void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &);
221   int64_t GetIntAttr(const std::string &attr_name);
222   bool GetBoolAttr(const std::string &attr_name);
223   float GetFloatAttr(const std::string &attr_name);
224   std::string GetStringAttr(const std::string &attr_name);
225   std::vector<int64_t> GetTupleIntAttr(const std::string &attr_name);
226 
227   std::string name_;
228   Shapes inputs_shape_;
229   Shapes outputs_shape_;
230   std::unordered_map<std::string, ValuePtr> attrs_;
231   std::vector<ValuePtr> input_value_;
232   TypePtr outputs_dtype_;
233 
234   int32_t stage_id_ = 0;
235   StrategyPtr strategy_;
236   std::vector<TensorInfo> inputs_tensor_info_;
237   std::vector<TensorInfo> outputs_tensor_info_;
238   Shape dev_matrix_shape_;  // if repeated calculation, it contains the repeated_calc_num_
239   Shape out_dev_matrix_shape_;
240   int64_t repeated_calc_num_ = 1;
241   int64_t as_loss_divisor_ = 1;
242   TensorMaps inputs_tensor_map_;
243   TensorMaps outputs_tensor_map_;
244   ForwardOp forward_op_;
245   Ops sub_ops_;
246   ForwardOp replace_op_;
247   OutPutInfoVector replace_op_info_;
248   ReplaceGraphPtr replace_graph_;
249   MirrorOps mirror_ops_;
250   VirtualDivOp virtual_div_op_;
251   RankList stage_device_list_;  // the device list in this stage
252   int64_t stage_device_size_ = 0;
253   bool infer_attrs_completed_ = false;
254 
255   bool is_auto_parallel_ = false;  // false: semi_auto_parallel; true: auto_parallel
256   // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected.
257   std::vector<size_t> corrected_input_indices_;
258   // Given a parallelization strategy, there is a cost.
259   std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_;
260   // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
261   std::vector<bool> is_parameter_;
262   // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
263   // pre-operator that has parameters as input.
264   std::vector<bool> is_parameter_involve_;
265   // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating
266   // peak memory cost in the training phase.
267   // -1: unset; 0: not parameter_involved; 1: parameter_involved
268   int64_t is_output_parameter_involve_ = -1;
269   // Whether this output is critical, which means that this output is included in calculating peak memory cost
270   // in the inference phase.
271   // -1 : unset; 0: not critical; 1: critical
272   int64_t is_output_critical_ = -1;
273   double outputs_total_size_ = 0.0;
274   bool is_calculated_outputs_size_ = false;
275   // for each input and output, the followings record the number of bytes of each element
276   std::vector<size_t> inputs_type_lengths_;
277   std::vector<size_t> outputs_type_lengths_;
278   std::vector<std::shared_ptr<Edge>> prev_edges_;
279   std::vector<std::shared_ptr<Edge>> succ_edges_;
280   StrategyPtr selected_strategy_;
281   int64_t selected_strategy_depth_ = -1;
282   // Used in DP algorithm
283   bool is_alive_;
284   CostPtr selected_cost_;
285   std::vector<bool> split_flag_list_;
286   std::string refkey_parameter_name_;
287   CNodePtr cnode_;
288   int64_t used_devices_ = -1;
289   // the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default
290   bool repeated_num_in_dev_matrix_right_ = true;
291   // Whether the list of available strategies is exact or approximate
292   bool is_strategy_cost_exact_ = true;
293 
294  private:
295   OperatorCostPtr operator_cost_;
296   std::vector<TypePtr> outputs_type_;
297 };
298 
299 Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
300 Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool);
301 Operator CreateVirtualDivOp(int64_t div_num);
302 Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
303 Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
304 Operator CreateAllGatherOp(const std::string &group);
305 Operator CreateMiniStepAllGatherOp(const std::string &group);
306 void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
307 Operator CreateMicroStepAllGatherOp(const std::string &group);
308 void AddCommOpMeanFlag(const CNodePtr &comm_node);
309 void AddCommOpParamFlag(const CNodePtr &comm_node);
310 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
311 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
312 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
313 std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
314                                                               const std::vector<bool> &split_flag_list);
315 std::string StrategyToString(const Strategys &strategy);
316 void PrintStrategy(const StrategyPtr &strategy);
317 // generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d])
318 Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
319                                               const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector);
320 // generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
321 // and the corresponding dimensions that are not broadcast are all relevant dimensions
322 // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
323 // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
324 // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
325 Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
326                                        std::vector<StrategyPtr> *sp_vector);
327 
328 Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
329 std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue);
330 }  // namespace parallel
331 }  // namespace mindspore
332 
333 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
334