1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ 19 20 #include <algorithm> 21 #include <string> 22 #include <unordered_map> 23 #include <utility> 24 #include <vector> 25 26 #include "frontend/parallel/tensor_layout/construct_operator.h" 27 #include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" 28 #include "utils/convert_utils.h" 29 namespace mindspore { 30 namespace parallel { 31 using DeviceArrangement = Shape; 32 using TensorMap = Shape; 33 using TensorShape = Shape; 34 using RedistributionOperatorMap = std::unordered_map<uint64_t, int64_t>; 35 using OperatorR = std::pair<OperatorName, Args>; 36 using OperatorC = std::pair<OperatorR, Shape>; 37 using OperatorList = std::vector<OperatorC>; 38 39 class RedistributionOperatorInfer { 40 public: 41 const int64_t NONE = -1; 42 explicit RedistributionOperatorInfer(bool construct_op_flag = true) construct_op_flag_(construct_op_flag)43 : construct_op_flag_(construct_op_flag), is_cost_model_(false) {} 44 Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, 45 bool is_cost_model = false); 46 ~RedistributionOperatorInfer() = default; operator_list()47 OperatorList operator_list() const { return operator_list_; } operator_vector()48 OperatorVector operator_vector() const { return operator_vector_; } output_info_vector()49 OutPutInfoVector output_info_vector() const { return output_info_vector_; } 50 Status InferRedistributionOperator(); 51 52 private: 53 Status InferSplitByAxis(); 54 Status InferPermuteByAxis(); 55 Status InferConcatByAxis(); 56 Status TransferSplitByAxis(const Args &args); 57 Status TransferPermuteByAxis(const Args &args); 58 Status TransferConcatByAxis(const Args &args); 59 Status InsertOperator(const OperatorName &name, const Args &args); 60 61 OperatorList operator_list_; 62 OperatorVector operator_vector_; 63 OutPutInfoVector output_info_vector_; 64 Arrangement dev_mat_; 65 RedistributionOperatorMap map_; 66 Map in_tensor_map_; 67 Map out_tensor_map_; 68 TensorLayout cur_tensor_layout_; 69 ConstructOperator constructor_; 70 RankList dev_list_; 71 bool construct_op_flag_; 72 bool is_cost_model_; 73 }; 74 } // namespace parallel 75 } // namespace mindspore 76 77 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ 78