/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_

#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "utils/hash_map.h"
#include "utils/ms_utils.h"
#include "base/base.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/group_manager.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "utils/log_adapter.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace parallel {
using ForwardOp = OperatorVector;
using MirrorOps = std::vector<OperatorVector>;
using Ops = std::vector<OperatorVector>;
using VirtualDivOp = OperatorVector;
using TensorMaps = std::vector<Shape>;
using TensorLayouts = std::vector<TensorLayout>;
using different_type = std::vector<int64_t>::difference_type;
using PrimitiveAttrs = mindspore::HashMap<std::string, ValuePtr>;
using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>;
using TensorRedistributionPtr = std::shared_ptr<TensorRedistribution>;

#define FILTER_LOG(x) (x) ? void(0) : MS_LOG(ERROR)
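// Reading of FILTER_LOG above (usage sketch, not an additional API): a statement such as
//   FILTER_LOG(strategy_is_valid) << "The strategy is invalid.";
// emits the ERROR log only when the condition is false; when it holds, the whole
// expression collapses to void(0). (`strategy_is_valid` is a hypothetical variable.)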

enum InferStrategyMode {
  SAME_MODE = 0,
  BROADCAST_MODE = 1,
  INDEPENDENT_MODE = 2,
  INDIVIDUAL_MODE = 3,
  INVALID_MODE = 4,
};

class TensorInfoBase {
 public:
  explicit TensorInfoBase(bool is_list) { is_list_ = is_list; }
  virtual ~TensorInfoBase() = default;
  bool is_list() const { return is_list_; }
  virtual std::shared_ptr<TensorInfoBase> GetElement(int64_t idx) = 0;
  virtual TensorInfo GetValue() = 0;
  virtual size_t size() = 0;

 private:
  bool is_list_;
};

using TensorInfoBasePtr = std::shared_ptr<TensorInfoBase>;

class TensorInfoValue : public TensorInfoBase {
 public:
  explicit TensorInfoValue(TensorInfo l) : TensorInfoBase(false), _l(std::move(l)) {}
  ~TensorInfoValue() override = default;
  std::shared_ptr<TensorInfoBase> GetElement(int64_t idx) override {
    MS_LOG(WARNING) << "Can not get element from TensorInfoValue, please use GetValue";
    return std::make_shared<TensorInfoValue>(_l);
  }
  TensorInfo GetValue() override { return _l; }
  size_t size() override { return 1; }

 private:
  TensorInfo _l;
};

class TensorInfoList : public TensorInfoBase {
 public:
  explicit TensorInfoList(std::vector<TensorInfoBasePtr> l_list) : TensorInfoBase(true), _l_list(std::move(l_list)) {}
  ~TensorInfoList() override = default;
  TensorInfoBasePtr GetElement(int64_t idx) override {
    if (idx < 0 || static_cast<size_t>(idx) >= _l_list.size()) {
      MS_LOG(EXCEPTION) << "Index " << idx << " is out of range";
    }
    return _l_list[LongToSize(idx)];
  }
  TensorInfo GetValue() override { MS_LOG(EXCEPTION) << "Can not get value from TensorInfoList"; }
  size_t size() override { return _l_list.size(); }

 private:
  std::vector<TensorInfoBasePtr> _l_list;
};
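// Sketch of the intended use of TensorInfoValue / TensorInfoList (hypothetical objects:
// `info_a` and `info_b` are assumed TensorInfo values produced by layout inference):
//   TensorInfoBasePtr single = std::make_shared<TensorInfoValue>(info_a);
//   TensorInfoBasePtr tuple = std::make_shared<TensorInfoList>(
//       std::vector<TensorInfoBasePtr>{single, std::make_shared<TensorInfoValue>(info_b)});
//   if (tuple->is_list()) {
//     TensorInfo first = tuple->GetElement(0)->GetValue();  // first element of the tuple
//   }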

class Edge;

inline std::string GetPrimNameFromInfoName(const std::string &info_name);

class OperatorInfo {
 public:
  OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs,
               const OperatorCostPtr &cost)
      : name_(std::move(name)),
        inputs_shape_(std::move(inputs_shape)),
        outputs_shape_(std::move(outputs_shape)),
        attrs_(std::move(attrs)),
        is_alive_(true),
        operator_cost_(cost),
        outputs_type_() {
    std::vector<bool> not_parameter(inputs_shape_.size(), false);
    is_parameter_ = not_parameter;
    refkey_parameter_name_ = "";
    stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
    stage_device_size_ = SizeToLong(stage_device_list_.size());
    cnode_ = nullptr;
    prim_name_ = GetPrimNameFromInfoName(this->name_);
  }

  virtual ~OperatorInfo() = default;

  void set_involved_param_name(std::string name) { involved_param_name_ = name; }
  std::string get_involved_param_name() { return involved_param_name_; }
  Status set_is_parameter(const std::vector<bool> &is_parameter);
  Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
                                     const std::vector<size_t> &output_lengths);
  double GetOutputsTotalSize();
  // Set outputs dtype.
  // If only one output, outputs_type.size() is 1.
  // If output is tuple, outputs_type.size() is greater than 1.
  Status set_outputs_type(const std::vector<TypePtr> &outputs_type);
  const std::vector<TypePtr> &outputs_type() const { return outputs_type_; }
  virtual Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                      const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
                      const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {});
  // only init the necessary parts
  virtual Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);

  // Given the stage_id (which indicates the number of devices),
  // generate all strategies for this operator
  Status GenerateStrategies(int64_t stage_id);
  virtual std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) = 0;
  const OperatorCostPtr &operator_cost() const { return operator_cost_; }
  void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
  virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;

  virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
  virtual void ReComputeBatchSplitFlagList();
  std::shared_ptr<Strategies> GenerateBatchStrategiesWithCheck();
  void ComputeBatchSplitFlagList();
  Shapes inputs_shape() const { return inputs_shape_; }
  NewShapes inputs_shape_new() const { return inputs_shape_new_; }
  Shapes outputs_shape() const { return outputs_shape_; }
  NewShapes outputs_shape_new() const { return outputs_shape_new_; }
  void set_inputs_divisor(const Shapes &in_divisor) { inputs_divisor_ = in_divisor; }
  void set_outputs_divisor(const Shapes &out_divisor) { outputs_divisor_ = out_divisor; }
  void set_dynamic_shape_flag(bool flag) { dynamic_shape_flag_ = flag; }

  double GetForwardMemoryCostFromCNode();
  // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy
  // is checked
  Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
  CostPtrList GetCostByStrategyPtr(const StrategyPtr &strategy);
  std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
  void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost);
  // In the training phase, when an input of an operator is a WEIGHT, or an output of other operators that involve a
  // WEIGHT, this input should stay in memory until it is used in the backward phase, so it is kept in memory at the
  // end of the forward phase.
  Status CalculateMemoryCost();
  // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
  // by the output.
  Status CalculateMemoryCostForInference();
  virtual int64_t ComputeOpAndPrevEdgeParameterInvolved();

  ForwardOp forward_op() const { return forward_op_; }
  ForwardOp replace_op() const { return replace_op_; }
  OutPutInfoVector replace_op_info() const { return replace_op_info_; }
  virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; }
  MirrorOps mirror_ops() const { return mirror_ops_; }
  Ops sub_ops() const { return sub_ops_; }
  VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
  Shape dev_matrix_shape() const { return dev_matrix_shape_; }
  std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
  std::vector<TensorInfoBasePtr> inputs_tensor_info_new() const { return inputs_tensor_info_new_; }
  void set_inputs_tensor_info(const std::vector<TensorInfo> &tensor_info) { inputs_tensor_info_ = tensor_info; }
  std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
  std::vector<TensorInfoBasePtr> outputs_tensor_info_new() const { return outputs_tensor_info_new_; }
  std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }
  RankList stage_device_list() const { return stage_device_list_; }

  void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
  void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
  std::vector<std::shared_ptr<Edge>> succ_edges() const { return succ_edges_; }
  std::vector<std::shared_ptr<Edge>> prev_edges() const { return prev_edges_; }
  std::vector<std::shared_ptr<Edge>> GetAliveSuccEdges();
  std::vector<std::shared_ptr<Edge>> GetAlivePrevEdges();
  void ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
  void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
  void ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
  void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
  std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
  void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) {
    selected_strategy_ = s_strategy;
    selected_cost_ = cost;
  }
  void SetSelectedStrategy(const StrategyPtr &s_strategy, size_t curr_depth);
  StrategyPtr selected_strategy() const { return selected_strategy_; }
  CostPtr selected_cost() const { return selected_cost_; }

  TensorLayout GetInputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t input_index);
  TensorLayout GetOutputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t output_index);
  StrategyPtr GetStrategyFromSWCByInputLayout(const TensorLayout &input_layout, size_t input_index);
  StrategyPtr GetStrategyFromSWCByOutputLayout(const TensorLayout &output_layout, size_t output_index);
  bool IsReshape() const;
  bool IsTmpIdentity() const;

  void set_swc_index(int64_t swc, int64_t depth);
  int64_t swc_index() const { return swc_index_; }
  // Approximate the list of available strategies
  void ApproximateStrategies();
  // Make the list of available strategies exact and re-init the related edges incident to this operator
  void ExactStrategiesAndRelatedEdges();
  bool is_strategy_cost_exact() const { return is_strategy_cost_exact_; }
  void SetIsStrategyCostExactTrue() { is_strategy_cost_exact_ = true; }
  void ClearStrategyCost() { strategy_cost_.clear(); }
  void CheckSelectedStrategy(const StrategyPtr &s_strategy);
  Status InitSelectedStrategy(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
    set_auto_parallel(false);
    return Init(in_strategy, out_strategy);
  }
  void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
  const std::vector<ValuePtr> &input_value() const { return input_value_; }
  void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
  void set_cnode(const CNodePtr &cnode) {
    cnode_ = cnode;
    cnodes_.push_back(cnode);
  }
  void set_new_shape(const std::vector<NewShapes> &shape) {
    inputs_shape_new_ = shape[0];
    outputs_shape_new_ = shape[1];
  }
  std::vector<CNodePtr> cnodes();
  CNodePtr cnode() const { return cnode_; }
  bool is_alive() const { return is_alive_; }
  void SetNotAlive() { is_alive_ = false; }
  std::vector<bool> split_flag_list() const { return split_flag_list_; }
  StrategyPtr strategy() const { return strategy_; }
  StrategyPtr out_strategy() const { return out_strategy_; }
  void set_out_strategy(const StrategyPtr &strategy) { out_strategy_ = strategy; }
  void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; }
  void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
  const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
  // When the output of a Parameter (require_grad) is used by multiple operators, the Parameter's cost is calculated
  // multiple times. This method corrects this so that the cost is calculated only once.
  Status CorrectMemoryCost(size_t input_index);
  int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; }
  int64_t is_output_critical() const { return is_output_critical_; }
  void mark_output_critical() { is_output_critical_ = 1; }
  void mark_output_not_critical() { is_output_critical_ = 0; }
  int64_t used_devices() const { return used_devices_; }
  // needed by rec_parser
  void set_type(const std::string &type) { type_ = type; }
  const std::string &type() const { return type_; }
  void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; }
  const bool &is_last_node() const { return is_last_node_; }
  const mindspore::HashMap<std::string, ValuePtr> &attrs() const { return attrs_; }
  void addAttr(const std::string &name, const ValuePtr &val) { attrs_[name] = val; }
  void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
  int32_t stage_id() const { return stage_id_; }
  Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
  Status CreateGroupForOptShard(TensorLayout *tensor_layout, std::vector<Group> *groups);
  virtual void ReplaceNodeInputOrAttrs() {}
  void set_auto_parallel(bool is_auto_parallel) { is_auto_parallel_ = is_auto_parallel; }
  void set_assigned_parallel(bool is_assigned_parallel) { is_assigned_parallel_ = is_assigned_parallel; }
  bool repeated_num_in_dev_matrix_right() const { return repeated_num_in_dev_matrix_right_; }
  void set_repeated_num_in_dev_matrix_right(bool is_right) { repeated_num_in_dev_matrix_right_ = is_right; }

  TensorRedistributionPtr CreateTensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) {
    if (this->tensor_redistribution_ != nullptr) {
      MS_LOG(DEBUG) << "TensorRedistribution re-created.";
    }
    this->tensor_redistribution_ = std::make_shared<TensorRedistribution>(construct_op_flag, keep_reshape);
    return this->tensor_redistribution_;
  }

  TensorRedistributionPtr CreateReshapeTensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) {
    if (this->reshape_tensor_redistribution_ != nullptr) {
      MS_LOG(DEBUG) << "TensorRedistribution re-created.";
    }
    this->reshape_tensor_redistribution_ = std::make_shared<TensorRedistribution>(construct_op_flag, keep_reshape);
    return this->reshape_tensor_redistribution_;
  }

  void SetTensorRedistribution(const TensorRedistributionPtr &tensor_redistribution) {
    this->tensor_redistribution_ = tensor_redistribution;
  }

  void SetReshapeTensorRedistribution(const TensorRedistributionPtr &tensor_redistribution) {
    this->reshape_tensor_redistribution_ = tensor_redistribution;
  }

  TensorRedistributionPtr tensor_redistribution() { return this->tensor_redistribution_; }

  TensorRedistributionPtr reshape_tensor_redistribution() { return this->reshape_tensor_redistribution_; }

  // Key for user data.
  constexpr static char key[] = "OpInfo";

 protected:
  // needed by rec_parser
  std::string type_;
  TensorRedistributionPtr tensor_redistribution_;
  TensorRedistributionPtr reshape_tensor_redistribution_;
  bool is_last_node_ = false;
  virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
  virtual Status InferTensorMap() = 0;
  virtual Status InferOutputTensorMap() { return SUCCESS; }
  virtual Status InferOutputTensorInfo() { return SUCCESS; }
  virtual Status CheckLayoutConfig() { return SUCCESS; }
  virtual Status CheckInputLayout();
  virtual Status CheckOutputLayout() { return SUCCESS; }
  virtual Status InferForwardCommunicationByLayout() { return SUCCESS; }
  virtual Status InferMirrorOpsByLayout();
  virtual Status InferForwardCommunication() = 0;
  virtual Status GetAttrs() = 0;
  virtual Status InferDevMatrixShape() = 0;
  virtual Status InferMirrorOps();
  virtual Status InferTensorInfo();
  virtual Status InferTensorInfoNew();

  virtual void InferReplaceOps() {}
  virtual void UpdateOutputTensorInfoForInterleaved();
  virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
  virtual Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) { return SUCCESS; }
  Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);
  Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
  void DivisorsReplaceShapes();  // in dynamic shape, use divisors to replace the shapes before CheckStrategy and so on
  void ResumeShapes();           // in dynamic shape, restore the shapes after CheckStrategy and so on
  void DynamicShapeCheckStrategyLog();
  void SetRepeatedCalcDevMatrix();
  void ResetTensorMapIfRepeatedCalc();
  void ResetTupleTensorMapIfRepeatedCalc(NewTensorMaps *tensor_map_new);
  void ChangeMakeTupleConstant(const CNodePtr &cnode, size_t make_tuple_index);
  Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
  Status CreateGroupByDimWithDevMatrix(DeviceMatrix *dev_matrix, size_t axis, std::vector<Group> *group);
  Status InferAttrs();
  void ResetQueueMember();
  Status InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
  Status InitWithTensorLayout(const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                              const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts);
  Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
  Status InferRepeatedCalcInfo();
  Status InferVirtualDivOps();
  Status InferVirtualDivOpsByLayout();
  bool IsDynamicShape();
  bool IsDynamicRank();
  bool IsSelfDefineShard();

  // Calculate the number of repeated calculations for the output by the number of devices and the output tensor map.
  // The tensor map of Outputs[0] is used by default. If there are multiple outputs, identify which output is used
  // for grad and override the function. If the output is a scalar, the function also needs to be overridden.
  virtual Status InferAsLossDivisor();
  virtual Status InferAsLossDivisorByLayout();
  void BreakingTiesForPreferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) const;
  int64_t GetIntAttr(const std::string &attr_name);
  bool GetBoolAttr(const std::string &attr_name);
  float GetFloatAttr(const std::string &attr_name);
  std::string GetStringAttr(const std::string &attr_name);
  std::vector<int64_t> GetTupleIntAttr(const std::string &attr_name);
  void ReportError(const std::string &error_msg) const {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << error_msg;
    } else {
      MS_LOG(ERROR) << error_msg;
    }
  }

  std::string name_;
  std::string prim_name_;
  Shapes inputs_shape_;
  Shapes outputs_shape_;
  NewShapes inputs_shape_new_;
  NewShapes outputs_shape_new_;
  Shapes inputs_divisor_;   // used for dynamic shape; the size is equal to inputs_shape_
  Shapes outputs_divisor_;  // used for dynamic shape; the size is equal to outputs_shape_
  Shapes inputs_shape_clone_;
  Shapes outputs_shape_clone_;
  bool dynamic_shape_flag_ = false;  // means this op is in a dynamic shape graph
  mindspore::HashMap<std::string, ValuePtr> attrs_;
  std::vector<ValuePtr> input_value_;
  TypePtr outputs_dtype_;

  int32_t stage_id_ = 0;
  StrategyPtr strategy_;
  StrategyPtr out_strategy_;
  std::vector<TensorInfo> inputs_tensor_info_;
  std::vector<TensorInfo> outputs_tensor_info_;
  std::vector<TensorInfoBasePtr> inputs_tensor_info_new_;
  std::vector<TensorInfoBasePtr> outputs_tensor_info_new_;
  Shape dev_matrix_shape_;  // if repeated calculation, it contains the repeated_calc_num_
  Shape out_dev_matrix_shape_;
  int64_t repeated_calc_num_ = 1;
  int64_t as_loss_divisor_ = 1;
  TensorMaps inputs_tensor_map_;
  TensorMaps outputs_tensor_map_;
  NewTensorMaps inputs_tensor_map_new_;
  NewTensorMaps outputs_tensor_map_new_;
  ForwardOp forward_op_;
  ForwardOp forward_op_interleaved_;
  Ops sub_ops_;
  ForwardOp replace_op_;
  OutPutInfoVector replace_op_info_;
  ReplaceGraphPtr replace_graph_;
  MirrorOps mirror_ops_;
  VirtualDivOp virtual_div_op_;
  RankList stage_device_list_;  // the device list in this stage
  int64_t stage_device_size_ = 0;
  bool infer_attrs_completed_ = false;
  bool is_layout_config_ = false;
  bool is_dynamic_shape_ = false;
  bool is_dynamic_rank_ = false;
  Shapes strategy_from_layout_;

  bool is_auto_parallel_ = false;      // false: semi_auto_parallel; true: auto_parallel
  bool is_assigned_parallel_ = false;  // false: origin parallel; true: dynamic_shape parallel
  // 'corrected_input_indices_' stores the indices of inputs that have ALREADY been corrected.
  std::vector<size_t> corrected_input_indices_;
  // Given a parallelization strategy, there is a cost.
  std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_;
  std::string involved_param_name_;
  // For each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter.
  std::vector<bool> is_parameter_;
  // For each input in 'inputs_', the bool variable is true if the corresponding one is a parameter or an output of a
  // pre-operator that has parameters as input.
  std::vector<bool> is_parameter_involve_;
  // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating
  // peak memory cost in the training phase.
  // -1: unset; 0: not parameter_involved; 1: parameter_involved
  int64_t is_output_parameter_involve_ = -1;
  // Whether this output is critical, which means that this output is included in calculating peak memory cost
  // in the inference phase.
  // -1 : unset; 0: not critical; 1: critical
  int64_t is_output_critical_ = -1;
  double outputs_total_size_ = 0.0;
  bool is_calculated_outputs_size_ = false;
  // for each input and output, the following vectors record the number of bytes of each element
  std::vector<size_t> inputs_type_lengths_;
  std::vector<size_t> outputs_type_lengths_;
  std::vector<std::shared_ptr<Edge>> prev_edges_;
  std::vector<std::shared_ptr<Edge>> succ_edges_;
  StrategyPtr selected_strategy_;
  int64_t selected_strategy_depth_ = -1;
  // Used in DP algorithm
  bool is_alive_;
  CostPtr selected_cost_;
  std::vector<bool> split_flag_list_;
  std::string refkey_parameter_name_;
  CNodePtr cnode_;
  std::vector<CNodePtr> cnodes_;
  int64_t used_devices_ = -1;
  // the repeated_calc_num_ will be inserted into the last dimension of the dev matrix by default
  bool repeated_num_in_dev_matrix_right_ = true;
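  // Illustrative note (consistent with the comments above, not normative): with 8 devices in the
  // stage and a strategy whose splits multiply to 4, repeated_calc_num_ is 2, and a dev matrix of
  // [2, 2] becomes [2, 2, 2], the trailing 2 being the repeated dimension when
  // repeated_num_in_dev_matrix_right_ is true.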
  // Whether the list of available strategies is exact or approximate
  bool is_strategy_cost_exact_ = true;
  bool self_define_shard_ = false;

 private:
  OperatorCostPtr operator_cost_;
  std::vector<TypePtr> outputs_type_;
  int64_t swc_index_ = -1;
  Status GetLayoutConfig();
  Status GetRepeatedNumInDevMatrixRight();
  Status CheckLayoutConfigBase();
};

Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
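// Presumed behaviour of GetSliceShape above (illustrative only): the slice shape is the tensor
// shape divided elementwise by the strategy, e.g. GetSliceShape({64, 32}, {4, 2}) would yield {16, 16}.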
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool);
Operator CreateVirtualDivOp(int64_t div_num);
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
Operator CreateAllGatherOp(const std::string &group);
Operator CreateCastOp(TypePtr type);
Operator CreateDivOp(float scale);
Operator CreateScalarDivOp(int64_t div_num);
Operator CreateScalarCastOp(TypePtr type);
Operator CreateScalarFloorDivOp(int64_t div_num);
Operator CreateScalarMulOp(int64_t scalar);
void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val);
int32_t AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
Operator CreateMicroStepAllGatherOp(const std::string &group);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
void AddCommOpParamFlag(const CNodePtr &comm_node);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
                                                               const std::vector<bool> &split_flag_list);
std::string StrategyToString(const Strategies &strategy);
Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_shape,
                                                  const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector);
// generate strategies for the case where all inputs' dimensions are independent, such as: ([a, b, c, d])
Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
                                              const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector);
// generate strategies for the case where the inputs' dimensions may be dependent
Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inputs_shape,
                                            const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp);
// generate strategies for ops that have two inputs, where input0 or input1 may be broadcast,
// and the corresponding dimensions that are not broadcast are all relevant dimensions,
// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
                                       std::vector<StrategyPtr> *sp_vector);
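// Usage sketch for GenerateStrategiesWithBroadcast (hypothetical shapes; splittable_inputs is
// assumed to mark splittable dimensions with 1):
//   Shapes inputs_shape = {{64, 32}, {1, 32}};   // the first dimension of input1 is broadcast
//   Shapes splittable_inputs = {{1, 1}, {1, 1}};
//   std::vector<StrategyPtr> sp_vector;
//   (void)GenerateStrategiesWithBroadcast(stage_id, inputs_shape, splittable_inputs, &sp_vector);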
std::vector<ValuePtr> GetValueSequence(const ValuePtr &sequence);
ValuePtr MakeListValue(const std::vector<int64_t> &v);
ValuePtr MakeTupleListValue(const Shapes &v);
AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple);

ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype);
Operator CreateDivOpWithType(float divisor, const TypePtr &dtype);
std::vector<int64_t> GetTensorValue(const ValuePtr &ori_value);

inline std::string GetPrimNameFromInfoName(const std::string &info_name) {
  auto prim_name = info_name;
  if (auto pos = info_name.rfind("Info"); pos != std::string::npos) {
    prim_name = info_name.substr(0, pos);
  }
  return prim_name;
}
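// For example, GetPrimNameFromInfoName("MatMulInfo") returns "MatMul"; a name without the
// "Info" suffix is returned unchanged.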

template <typename T>
std::optional<T> GetScalarValueFromInputs(const std::vector<ValuePtr> &input_value, size_t idx) {
  if (idx == SIZE_MAX) {
    MS_EXCEPTION(ValueError) << "Index is SIZE_MAX, the target value may be wrong!";
  }

  if (input_value.size() <= idx || input_value[idx] == nullptr) {
    return std::nullopt;
  }
  return ops::GetScalarValue<T>(input_value[idx]);
}
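// Usage sketch for GetScalarValueFromInputs (hypothetical call site): read the operator's third
// input as an int64_t scalar.
//   auto axis_opt = GetScalarValueFromInputs<int64_t>(input_value_, 2);
//   if (!axis_opt.has_value()) { /* the input is not a known constant */ }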

template <typename T>
T GetInputValueFromCNode(const CNodePtr &cnode, size_t index) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto inputs = cnode->inputs();
  if (index >= inputs.size()) {
    MS_LOG(EXCEPTION) << "The input index (" << index << ") exceeds the inputs size (" << inputs.size() << ").";
  }
  auto input_node = inputs[index];
  MS_EXCEPTION_IF_NULL(input_node);
  if (!input_node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "The input index (" << index << ") is not a value node.";
  }
  auto value = input_node->cast<ValueNodePtr>()->value();
  MS_EXCEPTION_IF_NULL(value);
  return GetValue<T>(value);
}
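// Usage sketch for GetInputValueFromCNode (hypothetical call site): input 0 of a CNode is the
// primitive, so the first operand sits at index 1.
//   auto group_name = GetInputValueFromCNode<std::string>(cnode, 1);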

template <typename T>
void SetValueInputToCNode(const CNodePtr &cnode, size_t index, T value) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto inputs = cnode->inputs();
  if (index >= inputs.size()) {
    MS_LOG(EXCEPTION) << "The input index (" << index << ") exceeds the inputs size (" << inputs.size() << ").";
  }
  auto func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto value_node = NewValueNode(MakeValue(value));
  MS_EXCEPTION_IF_NULL(value_node);
  manager->SetEdge(cnode, index, value_node);
}

template <typename T>
std::optional<T> GetScalarValueFromInputs(const std::vector<ValuePtr> &input_value, const std::string &op_name,
                                          const std::string &attr_name) {
  auto prim_name = GetPrimNameFromInfoName(op_name);
  auto idx = ops::GetInputIndexByName(prim_name, attr_name);
  return GetScalarValueFromInputs<T>(input_value, idx);
}

template <typename T>
std::optional<std::vector<T>> GetArrayValueFromInputs(const std::vector<ValuePtr> &input_value, size_t idx) {
  if (idx == SIZE_MAX) {
    MS_EXCEPTION(ValueError) << "Index is SIZE_MAX, the target value may be wrong!";
  }

  if (input_value.size() <= idx || input_value[idx] == nullptr) {
    return std::nullopt;
  }
  auto array_opt = ops::GetArrayValue<T>(input_value[idx]);
  if (!array_opt.has_value() || array_opt.value().HasUnknownValue()) {
    return std::nullopt;
  }
  return array_opt.value().ToVector();
}

template <typename T>
std::optional<std::vector<T>> GetArrayValueFromInputs(const std::vector<ValuePtr> &input_value,
                                                      const std::string &op_name, const std::string &attr_name) {
  auto prim_name = GetPrimNameFromInfoName(op_name);
  auto idx = ops::GetInputIndexByName(prim_name, attr_name);
  return GetArrayValueFromInputs<T>(input_value, idx);
}

template <typename T>
std::optional<std::vector<T>> GetArrayValueFromInputsWithCheck(const std::vector<ValuePtr> &input_value,
                                                               const std::string &op_name,
                                                               const std::string &attr_name) {
  auto attr_opt = GetArrayValueFromInputs<T>(input_value, op_name, attr_name);
  if (!attr_opt.has_value()) {
    MS_LOG(ERROR) << op_name << ": Does not have attribute " << attr_name;
    return std::nullopt;
  }
  return attr_opt;
}

template <typename T>
std::optional<T> GetScalarValueFromInputsWithCheck(const std::vector<ValuePtr> &input_value, const std::string &op_name,
                                                   const std::string &attr_name) {
  auto attr_opt = GetScalarValueFromInputs<T>(input_value, op_name, attr_name);
  if (!attr_opt.has_value()) {
    MS_LOG(ERROR) << op_name << ": Does not have attribute " << attr_name;
    return std::nullopt;
  }
  return attr_opt;
}

}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_