/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/tensor_layout/construct_operator.h"
#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"

namespace mindspore {
namespace parallel {
constexpr double ALLTOALL_SCALE_FACTOR = 2.0;
constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5;
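// TensorRedistribution infers the operator list needed to transform a tensor
// from one TensorLayout to another over a given device list, and models the
// communication, computation, and peak-memory costs of that transformation
// (see the cost accessors below).
//
// A minimal usage sketch (hypothetical caller code; from_layout, to_layout,
// and dev_list are assumed to be constructed elsewhere, and Init is assumed
// to report success via the Status enum from frontend/parallel/status.h):
//
//   TensorRedistribution tensor_redistribution;
//   if (tensor_redistribution.Init(from_layout, to_layout, dev_list) == SUCCESS) {
//     RedistributionOpListPtr op_list = tensor_redistribution.InferTensorRedistributionOperatorList();
//     (void)tensor_redistribution.ComputeCost();
//   }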
class TensorRedistribution {
 public:
  explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false)
      : reshape_flag_(false),
        comm_cost_(0.0),
        forward_comm_cost_(0.0),
        backward_comm_cost_(0.0),
        computation_cost_(0.0),
        memory_cost_(0.0),
        construct_op_flag_(construct_op_flag),
        keep_reshape_(keep_reshape) {}
  Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list);
  ~TensorRedistribution() = default;
  RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false);
  OperatorList operator_list() const { return operator_list_; }
  bool reshape_flag() const { return reshape_flag_; }
  Status ComputeCost();
  double comm_cost() const { return comm_cost_; }
  double computation_cost() const { return computation_cost_; }
  double forward_comm_cost() const { return forward_comm_cost_; }
  double backward_comm_cost() const { return backward_comm_cost_; }
  double memory_cost() const { return memory_cost_; }

 private:
  Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
                      OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector);
  Status InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
                             OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector,
                             bool is_cost_model);
  Status ComputeConcatCost(double input_size, const Shape &attrs);
  Status ComputePermuteCost(double input_size, const Shape &attrs);
  RedistributionOpListPtr InferTensorRedistributionOperatorListUnExpand(bool is_cost_model = false);
  TensorLayout from_origin_;
  TensorLayout to_origin_;
  TensorLayout from_;
  TensorLayout to_;
  RankList dev_list_;
  OperatorList operator_list_;
  bool reshape_flag_;
  // communication cost, which is the sum of the forward and backward communication costs
  double comm_cost_;
  // forward communication cost
  double forward_comm_cost_;
  // backward communication cost
  double backward_comm_cost_;
  // computation_cost models the time spent on computing in this tensor redistribution, which is calculated from the
  // inputs. It is calculated ONLY for the forward phase.
  double computation_cost_;
  // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which
  // is calculated from the outputs.
  double memory_cost_;
  bool construct_op_flag_;
  bool keep_reshape_;
  bool expand_able_ = true;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_