1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
19
20 #include <cstdint>
21 #include <map>
22 #include <memory>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 #include "ir/anf.h"
28 #include "utils/hash_map.h"
29 #include "utils/ms_utils.h"
30 #include "base/base.h"
31 #include "frontend/parallel/auto_parallel/costmodel.h"
32 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
33 #include "frontend/parallel/device_manager.h"
34 #include "frontend/parallel/device_matrix.h"
35 #include "frontend/parallel/group_manager.h"
36 #include "frontend/parallel/ops_info/ops_utils.h"
37 #include "frontend/parallel/strategy.h"
38 #include "frontend/parallel/tensor_layout/tensor_info.h"
39 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
40 #include "utils/log_adapter.h"
41 #include "ops/op_utils.h"
42
43 namespace mindspore {
44 namespace parallel {
// Convenience aliases used throughout the parallel operator infrastructure.
using ForwardOp = OperatorVector;
using MirrorOps = std::vector<OperatorVector>;
using Ops = std::vector<OperatorVector>;
using VirtualDivOp = OperatorVector;
using TensorMaps = std::vector<Shape>;
using TensorLayouts = std::vector<TensorLayout>;
using different_type = std::vector<int64_t>::difference_type;
using PrimitiveAttrs = mindspore::HashMap<std::string, ValuePtr>;
// Pairs of (node, input index) to re-wire, plus the root of the replacement subgraph.
using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>;
using TensorRedistributionPtr = std::shared_ptr<TensorRedistribution>;

// Emits an ERROR log only when the condition (x) is false; no-op otherwise.
#define FILTER_LOG(x) (x) ? void(0) : MS_LOG(ERROR)

// How an operator's sharding strategy may be inferred from its peers.
enum InferStrategyMode {
  SAME_MODE = 0,
  BROADCAST_MODE = 1,
  INDEPENDENT_MODE = 2,
  INDIVIDUAL_MODE = 3,
  INVALID_MODE = 4,
};
65
66 class TensorInfoBase {
67 public:
TensorInfoBase(bool is_list)68 explicit TensorInfoBase(bool is_list) { is_list_ = is_list; }
69 virtual ~TensorInfoBase() = default;
is_list()70 bool is_list() const { return is_list_; }
71 virtual std::shared_ptr<TensorInfoBase> GetElement(int64_t idx) = 0;
72 virtual TensorInfo GetValue() = 0;
73 virtual size_t size() = 0;
74
75 private:
76 bool is_list_;
77 };
78
79 using TensorInfoBasePtr = std::shared_ptr<TensorInfoBase>;
80
81 class TensorInfoValue : public TensorInfoBase {
82 public:
TensorInfoValue(TensorInfo l)83 explicit TensorInfoValue(TensorInfo l) : TensorInfoBase(false), _l(std::move(l)) {}
84 ~TensorInfoValue() override = default;
GetElement(int64_t idx)85 std::shared_ptr<TensorInfoBase> GetElement(int64_t idx) override {
86 MS_LOG(WARNING) << "Can not get element from TensorInfoValue, please use GetValue";
87 return std::make_shared<TensorInfoValue>(_l);
88 }
GetValue()89 TensorInfo GetValue() override { return _l; }
size()90 size_t size() override { return 1; }
91
92 private:
93 TensorInfo _l;
94 };
95
96 class TensorInfoList : public TensorInfoBase {
97 public:
TensorInfoList(std::vector<TensorInfoBasePtr> l_list)98 explicit TensorInfoList(std::vector<TensorInfoBasePtr> l_list) : TensorInfoBase(true), _l_list(std::move(l_list)) {}
99 ~TensorInfoList() override = default;
GetElement(int64_t idx)100 TensorInfoBasePtr GetElement(int64_t idx) override {
101 if (idx < 0 || static_cast<size_t>(idx) >= _l_list.size()) {
102 MS_LOG(EXCEPTION) << "Index " << idx << " is out of range";
103 }
104 return _l_list[LongToSize(idx)];
105 }
GetValue()106 TensorInfo GetValue() override { MS_LOG(EXCEPTION) << "Can not get value from TensorInfoList"; }
size()107 size_t size() override { return _l_list.size(); }
108
109 private:
110 std::vector<TensorInfoBasePtr> _l_list;
111 };
112
// Forward declaration: cost-graph edge between two OperatorInfo nodes.
class Edge;

// Defined at the bottom of this header; declared early because the
// OperatorInfo constructor uses it.
inline std::string GetPrimNameFromInfoName(const std::string &info_name);
116
117 class OperatorInfo {
118 public:
OperatorInfo(std::string name,Shapes inputs_shape,Shapes outputs_shape,PrimitiveAttrs attrs,const OperatorCostPtr & cost)119 OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs,
120 const OperatorCostPtr &cost)
121 : name_(std::move(name)),
122 inputs_shape_(std::move(inputs_shape)),
123 outputs_shape_(std::move(outputs_shape)),
124 attrs_(std::move(attrs)),
125 is_alive_(true),
126 operator_cost_(cost),
127 outputs_type_() {
128 std::vector<bool> not_parameteter(inputs_shape_.size(), false);
129 is_parameter_ = not_parameteter;
130 refkey_parameter_name_ = "";
131 stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
132 stage_device_size_ = SizeToLong(stage_device_list_.size());
133 cnode_ = nullptr;
134 prim_name_ = GetPrimNameFromInfoName(this->name_);
135 }
136
137 virtual ~OperatorInfo() = default;
138
set_involved_param_name(std::string name)139 void set_involved_param_name(std::string name) { involved_param_name_ = name; }
get_involved_param_name()140 std::string get_involved_param_name() { return involved_param_name_; }
141 Status set_is_parameter(const std::vector<bool> &is_parameter);
142 Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
143 const std::vector<size_t> &output_lengths);
144 double GetOutputsTotalSize();
145 // Set outputs dtype.
146 // If only one output, outputs_type.size() is 1.
147 // If output is tuple, outputs_type.size() is greater than 1.
148 Status set_outputs_type(const std::vector<TypePtr> &outputs_type);
outputs_type()149 const std::vector<TypePtr> &outputs_type() const { return outputs_type_; }
150 virtual Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
151 const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
152 const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {});
153 // only init the necessary parts
154 virtual Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
155
156 // Given the stage_id (which indicates the number of devices),
157 // generate all strategies for this operator
158 Status GenerateStrategies(int64_t stage_id);
159 virtual std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) = 0;
operator_cost()160 const OperatorCostPtr &operator_cost() const { return operator_cost_; }
set_cost(const OperatorCostPtr & cost)161 void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
162 virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
163
164 virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
165 virtual void ReComputeBatchSplitFlagList();
166 std::shared_ptr<Strategies> GenerateBatchStrategiesWithCheck();
167 void ComputeBatchSplitFlagList();
inputs_shape()168 Shapes inputs_shape() const { return inputs_shape_; }
inputs_shape_new()169 NewShapes inputs_shape_new() const { return inputs_shape_new_; }
outputs_shape()170 Shapes outputs_shape() const { return outputs_shape_; }
outputs_shape_new()171 NewShapes outputs_shape_new() const { return outputs_shape_new_; }
set_inputs_divisor(const Shapes & in_divisor)172 void set_inputs_divisor(const Shapes &in_divisor) { inputs_divisor_ = in_divisor; }
set_outputs_divisor(const Shapes & out_divisor)173 void set_outputs_divisor(const Shapes &out_divisor) { outputs_divisor_ = out_divisor; }
set_dynamic_shape_flag(bool flag)174 void set_dynamic_shape_flag(bool flag) { dynamic_shape_flag_ = flag; }
175
176 double GetForwardMemoryCostFromCNode();
177 // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy
178 // is checked
179 Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
180 CostPtrList GetCostByStrategyPtr(const StrategyPtr &strategy);
GetStrategyCost()181 std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
182 void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost);
183 // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
184 // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
185 // at the end of forward phase.
186 Status CalculateMemoryCost();
187 // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
188 // by the output
189 Status CalculateMemoryCostForInference();
190 virtual int64_t ComputeOpAndPrevEdgeParameterInvolved();
191
forward_op()192 ForwardOp forward_op() const { return forward_op_; }
replace_op()193 ForwardOp replace_op() const { return replace_op_; }
replace_op_info()194 OutPutInfoVector replace_op_info() const { return replace_op_info_; }
replace_graph(const CNodePtr &)195 virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; }
mirror_ops()196 MirrorOps mirror_ops() const { return mirror_ops_; }
sub_ops()197 Ops sub_ops() const { return sub_ops_; }
virtual_div_op()198 VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
dev_matrix_shape()199 Shape dev_matrix_shape() const { return dev_matrix_shape_; }
inputs_tensor_info()200 std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
inputs_tensor_info_new()201 std::vector<TensorInfoBasePtr> inputs_tensor_info_new() const { return inputs_tensor_info_new_; }
set_inputs_tensor_info(const std::vector<TensorInfo> & tensor_info)202 void set_inputs_tensor_info(const std::vector<TensorInfo> &tensor_info) { inputs_tensor_info_ = tensor_info; }
outputs_tensor_info()203 std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
outputs_tensor_info_new()204 std::vector<TensorInfoBasePtr> outputs_tensor_info_new() const { return outputs_tensor_info_new_; }
strategy_cost()205 std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
name()206 const std::string &name() const { return name_; }
set_name(const std::string & name)207 void set_name(const std::string &name) { name_ = name; }
stage_device_list()208 RankList stage_device_list() const { return stage_device_list_; }
209
AddSuccEdge(const std::shared_ptr<Edge> & e)210 void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
AddPrevEdge(const std::shared_ptr<Edge> & e)211 void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
succ_edges()212 std::vector<std::shared_ptr<Edge>> succ_edges() const { return succ_edges_; }
prev_edges()213 std::vector<std::shared_ptr<Edge>> prev_edges() const { return prev_edges_; }
214 std::vector<std::shared_ptr<Edge>> GetAliveSuccEdges();
215 std::vector<std::shared_ptr<Edge>> GetAlivePrevEdges();
216 void ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
217 void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
218 void ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
219 void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
GetOutputTypeLengths()220 std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
SetSelectedStrategyAndCost(const StrategyPtr & s_strategy,const CostPtr & cost)221 void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) {
222 selected_strategy_ = s_strategy;
223 selected_cost_ = cost;
224 }
225 void SetSelectedStrategy(const StrategyPtr &s_strategy, size_t curr_depth);
selected_strategy()226 StrategyPtr selected_strategy() const { return selected_strategy_; }
selected_cost()227 CostPtr selected_cost() const { return selected_cost_; }
228
229 TensorLayout GetInputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t input_index);
230 TensorLayout GetOutputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t output_index);
231 StrategyPtr GetStrategyFromSWCByInputLayout(const TensorLayout &input_layout, size_t input_index);
232 StrategyPtr GetStrategyFromSWCByOutputLayout(const TensorLayout &output_layout, size_t output_index);
233 bool IsReshape() const;
234 bool IsTmpIdentity() const;
235
236 void set_swc_index(int64_t swc, int64_t depth);
swc_index()237 int64_t swc_index() const { return swc_index_; }
238 // Approximate the list of available strategies
239 void ApproximateStrategies();
240 // Make the list of available strategies exact and re-init the related edges incident to this operator
241 void ExactStrategiesAndRelatedEdges();
is_strategy_cost_exact()242 bool is_strategy_cost_exact() const { return is_strategy_cost_exact_; }
SetIsStrategyCostExactTrue()243 void SetIsStrategyCostExactTrue() { is_strategy_cost_exact_ = true; }
ClearStrategyCost()244 void ClearStrategyCost() { strategy_cost_.clear(); }
245 void CheckSelectedStrategy(const StrategyPtr &s_strategy);
InitSelectedStrategy(const StrategyPtr & in_strategy,const StrategyPtr & out_strategy)246 Status InitSelectedStrategy(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
247 set_auto_parallel(false);
248 return Init(in_strategy, out_strategy);
249 }
set_input_value(const std::vector<ValuePtr> & input_value)250 void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
input_value()251 const std::vector<ValuePtr> &input_value() const { return input_value_; }
set_outputs_dtype(const TypePtr & dtype)252 void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
set_cnode(const CNodePtr & cnode)253 void set_cnode(const CNodePtr &cnode) {
254 cnode_ = cnode;
255 cnodes_.push_back(cnode);
256 }
set_new_shape(const std::vector<NewShapes> & shape)257 void set_new_shape(const std::vector<NewShapes> &shape) {
258 inputs_shape_new_ = shape[0];
259 outputs_shape_new_ = shape[1];
260 }
261 std::vector<CNodePtr> cnodes();
cnode()262 CNodePtr cnode() const { return cnode_; }
is_alive()263 bool is_alive() const { return is_alive_; }
SetNotAlive()264 void SetNotAlive() { is_alive_ = false; }
split_flag_list()265 std::vector<bool> split_flag_list() const { return split_flag_list_; }
strategy()266 StrategyPtr strategy() const { return strategy_; }
out_strategy()267 StrategyPtr out_strategy() const { return out_strategy_; }
set_out_strategy(const StrategyPtr & strategy)268 void set_out_strategy(const StrategyPtr &strategy) { out_strategy_ = strategy; }
set_strategy(const StrategyPtr & strategy)269 void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; }
set_refkey_parameter_name(std::string p_name)270 void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
refkey_parameter_name()271 const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
272 // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated
273 // multiple times. This method is to correct this, and makes the cost is calculated only once.
274 Status CorrectMemoryCost(size_t input_index);
is_output_parameter_involve()275 int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; }
is_output_critical()276 int64_t is_output_critical() const { return is_output_critical_; }
mark_output_critical()277 void mark_output_critical() { is_output_critical_ = 1; }
mark_output_not_critical()278 void mark_output_not_critical() { is_output_critical_ = 0; }
used_devices()279 int64_t used_devices() const { return used_devices_; }
280 // needed by rec_parser
set_type(const std::string & type)281 void set_type(const std::string &type) { type_ = type; }
type()282 const std::string &type() const { return type_; }
set_last_node_flag(const bool & is_last_node)283 void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; }
is_last_node()284 const bool &is_last_node() const { return is_last_node_; }
attrs()285 const mindspore::HashMap<std::string, ValuePtr> &attrs() const { return attrs_; }
addAttr(const std::string & name,const ValuePtr & val)286 void addAttr(const std::string &name, const ValuePtr &val) { attrs_[name] = val; }
set_stage_id(int32_t stage_id)287 void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
stage_id()288 int32_t stage_id() const { return stage_id_; }
289 Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
290 Status CreateGroupForOptShard(TensorLayout *tensor_layout, std::vector<Group> *groups);
ReplaceNodeInputOrAttrs()291 virtual void ReplaceNodeInputOrAttrs() {}
set_auto_parallel(bool is_auto_parallel)292 void set_auto_parallel(bool is_auto_parallel) { is_auto_parallel_ = is_auto_parallel; }
set_assigned_parallel(bool is_assigned_parallel)293 void set_assigned_parallel(bool is_assigned_parallel) { is_assigned_parallel_ = is_assigned_parallel; }
repeated_num_in_dev_matrix_right()294 bool repeated_num_in_dev_matrix_right() const { return repeated_num_in_dev_matrix_right_; }
set_repeated_num_in_dev_matrix_right(bool is_right)295 void set_repeated_num_in_dev_matrix_right(bool is_right) { repeated_num_in_dev_matrix_right_ = is_right; }
296
297 TensorRedistributionPtr CreateTensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) {
298 if (this->tensor_redistribution_ != nullptr) {
299 MS_LOG(DEBUG) << "TensorRedistribution re-created.";
300 }
301 this->tensor_redistribution_ = std::make_shared<TensorRedistribution>(construct_op_flag, keep_reshape);
302 return this->tensor_redistribution_;
303 }
304
305 TensorRedistributionPtr CreateReshapeTensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) {
306 if (this->reshape_tensor_redistribution_ != nullptr) {
307 MS_LOG(DEBUG) << "TensorRedistribution re-created.";
308 }
309 this->reshape_tensor_redistribution_ = std::make_shared<TensorRedistribution>(construct_op_flag, keep_reshape);
310 return this->reshape_tensor_redistribution_;
311 }
312
SetTensorRedistribution(const TensorRedistributionPtr & tensor_redistribution)313 void SetTensorRedistribution(const TensorRedistributionPtr &tensor_redistribution) {
314 this->tensor_redistribution_ = tensor_redistribution;
315 }
316
SetReshapeTensorRedistribution(const TensorRedistributionPtr & tensor_redistribution)317 void SetReshapeTensorRedistribution(const TensorRedistributionPtr &tensor_redistribution) {
318 this->reshape_tensor_redistribution_ = tensor_redistribution;
319 }
320
tensor_redistribution()321 TensorRedistributionPtr tensor_redistribution() { return this->tensor_redistribution_; }
322
reshape_tensor_redistribution()323 TensorRedistributionPtr reshape_tensor_redistribution() { return this->reshape_tensor_redistribution_; }
324
325 // Key for user data.
326 constexpr static char key[] = "OpInfo";
327
328 protected:
329 // needed by rec_parser
330 std::string type_;
331 TensorRedistributionPtr tensor_redistribution_;
332 TensorRedistributionPtr reshape_tensor_redistribution_;
333 bool is_last_node_ = false;
334 virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
335 virtual Status InferTensorMap() = 0;
InferOutputTensorMap()336 virtual Status InferOutputTensorMap() { return SUCCESS; }
InferOutputTensorInfo()337 virtual Status InferOutputTensorInfo() { return SUCCESS; }
CheckLayoutConfig()338 virtual Status CheckLayoutConfig() { return SUCCESS; }
339 virtual Status CheckInputLayout();
CheckOutputLayout()340 virtual Status CheckOutputLayout() { return SUCCESS; }
InferForwardCommunicationByLayout()341 virtual Status InferForwardCommunicationByLayout() { return SUCCESS; }
342 virtual Status InferMirrorOpsByLayout();
343 virtual Status InferForwardCommunication() = 0;
344 virtual Status GetAttrs() = 0;
345 virtual Status InferDevMatrixShape() = 0;
346 virtual Status InferMirrorOps();
347 virtual Status InferTensorInfo();
348 virtual Status InferTensorInfoNew();
349
InferReplaceOps()350 virtual void InferReplaceOps() {}
351 virtual void UpdateOutputTensorInfoForInterleaved();
352 virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
CheckStrategyForDynamicShape(const StrategyPtr & strategy)353 virtual Status CheckStrategyForDynamicShape(const StrategyPtr &strategy) { return SUCCESS; }
354 Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);
355 Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
356 void DivisorsReplaceShapes(); // in dynamic shape, using divisors replace to shapes before CheckStrategy and so on
357 void ResumeShapes(); // in dynamic shape, resume shapes after CheckStrategy and so on
358 void DynamicShapeCheckStrategyLog();
359 void SetRepeatedCalcDevMatrix();
360 void ResetTensorMapIfRepeatedCalc();
361 void ResetTupleTensorMapIfRepeatedCalc(NewTensorMaps *tensor_map_new);
362 void ChangeMakeTupleConstant(const CNodePtr &cnode, size_t make_tuple_index);
363 Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
364 Status CreateGroupByDimWithDevMatrix(DeviceMatrix *dev_matrix, size_t axis, std::vector<Group> *group);
365 Status InferAttrs();
366 void ResetQueueMember();
367 Status InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
368 Status InitWithTensorLayout(const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
369 const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts);
370 Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
371 Status InferRepeatedCalcInfo();
372 Status InferVirtualDivOps();
373 Status InferVirtualDivOpsByLayout();
374 bool IsDynamicShape();
375 bool IsDynamicRank();
376 bool IsSelfDefineShard();
377
378 // Calculate the number of repeated calculations for the output by the number of devices and the output tensor map.
379 // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output
380 // is used for grad and overload the function. If the output is a scalar, need to override the function too.
381 virtual Status InferAsLossDivisor();
382 virtual Status InferAsLossDivisorByLayout();
383 void BreakingTiesForPreferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) const;
384 int64_t GetIntAttr(const std::string &attr_name);
385 bool GetBoolAttr(const std::string &attr_name);
386 float GetFloatAttr(const std::string &attr_name);
387 std::string GetStringAttr(const std::string &attr_name);
388 std::vector<int64_t> GetTupleIntAttr(const std::string &attr_name);
ReportError(const std::string & error_msg)389 void ReportError(const std::string &error_msg) const {
390 if (is_auto_parallel_) {
391 MS_LOG(DEBUG) << error_msg;
392 } else {
393 MS_LOG(ERROR) << error_msg;
394 }
395 }
396
397 std::string name_;
398 std::string prim_name_;
399 Shapes inputs_shape_;
400 Shapes outputs_shape_;
401 NewShapes inputs_shape_new_;
402 NewShapes outputs_shape_new_;
403 Shapes inputs_divisor_; // using for dynamic shape, the size is equal to inputs_shape_
404 Shapes outputs_divisor_; // using for dynamic shape, the size is equal to outputs_shape_
405 Shapes inputs_shape_clone_;
406 Shapes outputs_shape_clone_;
407 bool dynamic_shape_flag_ = False; // means this op in the dynamic shape graph
408 mindspore::HashMap<std::string, ValuePtr> attrs_;
409 std::vector<ValuePtr> input_value_;
410 TypePtr outputs_dtype_;
411
412 int32_t stage_id_ = 0;
413 StrategyPtr strategy_;
414 StrategyPtr out_strategy_;
415 std::vector<TensorInfo> inputs_tensor_info_;
416 std::vector<TensorInfo> outputs_tensor_info_;
417 std::vector<TensorInfoBasePtr> inputs_tensor_info_new_;
418 std::vector<TensorInfoBasePtr> outputs_tensor_info_new_;
419 Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num_
420 Shape out_dev_matrix_shape_;
421 int64_t repeated_calc_num_ = 1;
422 int64_t as_loss_divisor_ = 1;
423 TensorMaps inputs_tensor_map_;
424 TensorMaps outputs_tensor_map_;
425 NewTensorMaps inputs_tensor_map_new_;
426 NewTensorMaps outputs_tensor_map_new_;
427 ForwardOp forward_op_;
428 ForwardOp forward_op_interleaved_;
429 Ops sub_ops_;
430 ForwardOp replace_op_;
431 OutPutInfoVector replace_op_info_;
432 ReplaceGraphPtr replace_graph_;
433 MirrorOps mirror_ops_;
434 VirtualDivOp virtual_div_op_;
435 RankList stage_device_list_; // the device list in this stage
436 int64_t stage_device_size_ = 0;
437 bool infer_attrs_completed_ = false;
438 bool is_layout_config_ = false;
439 bool is_dynamic_shape_ = false;
440 bool is_dynamic_rank_ = false;
441 Shapes strategy_from_layout_;
442
443 bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel
444 bool is_assigned_parallel_ = false; // false: origin parallel; true: dynamic_shape parallel
445 // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected.
446 std::vector<size_t> corrected_input_indices_;
447 // Given a parallelization strategy, there is a cost.
448 std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_;
449 std::string involved_param_name_;
450 // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
451 std::vector<bool> is_parameter_;
452 // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
453 // pre-operator that has parameters as input.
454 std::vector<bool> is_parameter_involve_;
455 // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating
456 // peak memory cost in the training phase.
457 // -1: unset; 0: not parameter_involved; 1: parameter_involved
458 int64_t is_output_parameter_involve_ = -1;
459 // Whether this output is critical, which means that this output is included in calculating peak memory cost
460 // in the inference phase.
461 // -1 : unset; 0: not critical; 1: critical
462 int64_t is_output_critical_ = -1;
463 double outputs_total_size_ = 0.0;
464 bool is_calculated_outputs_size_ = false;
465 // for each input and output, the followings record the number of bytes of each element
466 std::vector<size_t> inputs_type_lengths_;
467 std::vector<size_t> outputs_type_lengths_;
468 std::vector<std::shared_ptr<Edge>> prev_edges_;
469 std::vector<std::shared_ptr<Edge>> succ_edges_;
470 StrategyPtr selected_strategy_;
471 int64_t selected_strategy_depth_ = -1;
472 // Used in DP algorithm
473 bool is_alive_;
474 CostPtr selected_cost_;
475 std::vector<bool> split_flag_list_;
476 std::string refkey_parameter_name_;
477 CNodePtr cnode_;
478 std::vector<CNodePtr> cnodes_;
479 int64_t used_devices_ = -1;
480 // the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default
481 bool repeated_num_in_dev_matrix_right_ = true;
482 // Whether the list of available strategies is exact or approximate
483 bool is_strategy_cost_exact_ = true;
484 bool self_define_shard_;
485
486 private:
487 OperatorCostPtr operator_cost_;
488 std::vector<TypePtr> outputs_type_;
489 int64_t swc_index_ = -1;
490 Status GetLayoutConfig();
491 Status GetRepeatedNumInDevMatrixRight();
492 Status CheckLayoutConfigBase();
493 };
494
// ---- Free helpers: slicing, strategy checking, and communication-operator factories. ----
Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool);
Operator CreateVirtualDivOp(int64_t div_num);
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
Operator CreateAllGatherOp(const std::string &group);
Operator CreateCastOp(TypePtr type);
Operator CreateDivOp(float scale);
Operator CreateScalarDivOp(int64_t div_num);
Operator CreateScalarCastOp(TypePtr type);
Operator CreateScalarFloorDivOp(int64_t div_num);
Operator CreateScalarMulOp(int64_t scalar);
void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val);
int32_t AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
Operator CreateMicroStepAllGatherOp(const std::string &group);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
void AddCommOpParamFlag(const CNodePtr &comm_node);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
                                                               const std::vector<bool> &split_flag_list);
std::string StrategyToString(const Strategies &strategy);
// ---- Strategy-generation helpers for the auto-parallel searcher. ----
Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_shape,
                                                  const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector);
// generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d])
Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
                                              const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector);
// generate strategies for that inputs' dimension maybe dependent
Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inputs_shape,
                                            const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp);
// generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
// and the corresponding dimensions that are not broadcast are all relevant dimensions
// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
                                       std::vector<StrategyPtr> *sp_vector);
std::vector<ValuePtr> GetValueSequence(const ValuePtr &sequence);
ValuePtr MakeListValue(const std::vector<int64_t> &v);
ValuePtr MakeTupleListValue(const Shapes &v);
AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple);

ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype);
Operator CreateDivOpWithType(float divisor, const TypePtr &dtype);
std::vector<int64_t> GetTensorValue(const ValuePtr &ori_value);
542
GetPrimNameFromInfoName(const std::string & info_name)543 inline std::string GetPrimNameFromInfoName(const std::string &info_name) {
544 auto prim_name = info_name;
545 if (auto pos = info_name.rfind("Info"); pos != std::string::npos) {
546 prim_name = info_name.substr(0, pos);
547 }
548 return prim_name;
549 }
550
551 template <typename T>
GetScalarValueFromInputs(const std::vector<ValuePtr> & input_value,size_t idx)552 std::optional<T> GetScalarValueFromInputs(const std::vector<ValuePtr> &input_value, size_t idx) {
553 if (idx == SIZE_MAX) {
554 MS_EXCEPTION(ValueError) << "Index is the size max, target value maybe wrong!";
555 }
556
557 if (input_value.size() <= idx || input_value[idx] == nullptr) {
558 return std::nullopt;
559 }
560 return ops::GetScalarValue<T>(input_value[idx]);
561 }
562
563 template <typename T>
GetInputValueFromCNode(const CNodePtr & cnode,size_t index)564 T GetInputValueFromCNode(const CNodePtr &cnode, size_t index) {
565 MS_EXCEPTION_IF_NULL(cnode);
566 auto inputs = cnode->inputs();
567 if (index >= inputs.size()) {
568 MS_LOG(EXCEPTION) << "The input index (" << index << ") is exceed of inputs size (" << inputs.size() << ").";
569 }
570 auto input_node = inputs[index];
571 MS_EXCEPTION_IF_NULL(input_node);
572 if (!input_node->isa<ValueNode>()) {
573 MS_LOG(EXCEPTION) << "The input index (" << index << ") is not a value node.";
574 }
575 auto value = input_node->cast<ValueNodePtr>()->value();
576 MS_EXCEPTION_IF_NULL(value);
577 return GetValue<T>(value);
578 }
579
580 template <typename T>
SetValueInputToCNode(const CNodePtr & cnode,size_t index,T value)581 void SetValueInputToCNode(const CNodePtr &cnode, size_t index, T value) {
582 MS_EXCEPTION_IF_NULL(cnode);
583 auto inputs = cnode->inputs();
584 if (index >= inputs.size()) {
585 MS_LOG(EXCEPTION) << "The input index (" << index << ") is exceed of inputs size (" << inputs.size() << ").";
586 }
587 auto func_graph = cnode->func_graph();
588 MS_EXCEPTION_IF_NULL(func_graph);
589 auto manager = func_graph->manager();
590 auto value_node = NewValueNode(MakeValue(value));
591 MS_EXCEPTION_IF_NULL(value_node);
592 manager->SetEdge(cnode, index, value_node);
593 }
594
595 template <typename T>
GetScalarValueFromInputs(const std::vector<ValuePtr> & input_value,const std::string & op_name,const std::string & attr_name)596 std::optional<T> GetScalarValueFromInputs(const std::vector<ValuePtr> &input_value, const std::string &op_name,
597 const std::string &attr_name) {
598 auto prim_name = GetPrimNameFromInfoName(op_name);
599 auto idx = ops::GetInputIndexByName(prim_name, attr_name);
600 return GetScalarValueFromInputs<T>(input_value, idx);
601 }
602
603 template <typename T>
GetArrayValueFromInputs(const std::vector<ValuePtr> & input_value,size_t idx)604 std::optional<std::vector<T>> GetArrayValueFromInputs(const std::vector<ValuePtr> &input_value, size_t idx) {
605 if (idx == SIZE_MAX) {
606 MS_EXCEPTION(ValueError) << "Index is the size max, target value maybe wrong!";
607 }
608
609 if (input_value.size() <= idx || input_value[idx] == nullptr) {
610 return std::nullopt;
611 }
612 auto array_opt = ops::GetArrayValue<T>(input_value[idx]);
613 if (!array_opt.has_value() || array_opt.value().HasUnknownValue()) {
614 return std::nullopt;
615 }
616 return array_opt.value().ToVector();
617 }
618
619 template <typename T>
GetArrayValueFromInputs(const std::vector<ValuePtr> & input_value,const std::string & op_name,const std::string & attr_name)620 std::optional<std::vector<T>> GetArrayValueFromInputs(const std::vector<ValuePtr> &input_value,
621 const std::string &op_name, const std::string &attr_name) {
622 auto prim_name = GetPrimNameFromInfoName(op_name);
623 auto idx = ops::GetInputIndexByName(prim_name, attr_name);
624 return GetArrayValueFromInputs<T>(input_value, idx);
625 }
626
627 template <typename T>
GetArrayValueFromInputsWithCheck(const std::vector<ValuePtr> & input_value,const std::string & op_name,const std::string & attr_name)628 std::optional<std::vector<T>> GetArrayValueFromInputsWithCheck(const std::vector<ValuePtr> &input_value,
629 const std::string &op_name,
630 const std::string &attr_name) {
631 auto attr_opt = GetArrayValueFromInputs<T>(input_value, op_name, attr_name);
632 if (!attr_opt.has_value()) {
633 MS_LOG(ERROR) << op_name << ": Don't have attribution " << attr_name;
634 return std::nullopt;
635 }
636 return attr_opt;
637 }
638
639 template <typename T>
GetScalarValueFromInputsWithCheck(const std::vector<ValuePtr> & input_value,const std::string & op_name,const std::string & attr_name)640 std::optional<T> GetScalarValueFromInputsWithCheck(const std::vector<ValuePtr> &input_value, const std::string &op_name,
641 const std::string &attr_name) {
642 auto attr_opt = GetScalarValueFromInputs<T>(input_value, op_name, attr_name);
643 if (!attr_opt.has_value()) {
644 MS_LOG(ERROR) << op_name << ": Don't have attribution " << attr_name;
645 return std::nullopt;
646 }
647 return attr_opt;
648 }
649
650 } // namespace parallel
651 } // namespace mindspore
652
653 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
654