1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ 19 20 #include <algorithm> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "utils/hash_map.h" 26 #include "frontend/parallel/tensor_layout/construct_operator.h" 27 #include "frontend/parallel/auto_parallel/costmodel.h" 28 #include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" 29 #include "include/common/utils/convert_utils.h" 30 31 namespace mindspore { 32 namespace parallel { 33 using DeviceArrangement = Shape; 34 using TensorMap = Shape; 35 using TensorShape = Shape; 36 using RedistributionOperatorMap = mindspore::HashMap<uint64_t, int64_t>; 37 using OperatorR = std::pair<OperatorName, Args>; 38 using OperatorC = std::pair<OperatorR, Shape>; 39 using OperatorList = std::vector<OperatorC>; 40 41 class RedistributionOperatorInfer { 42 public: 43 const int64_t NONE = -1; 44 explicit RedistributionOperatorInfer(bool construct_op_flag = true) construct_op_flag_(construct_op_flag)45 : construct_op_flag_(construct_op_flag), is_cost_model_(false) {} 46 Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, 47 bool is_cost_model = false, bool is_dynamic_shape = false); 48 ~RedistributionOperatorInfer() = default; operator_list()49 OperatorList operator_list() const { return operator_list_; } operator_vector()50 OperatorVector operator_vector() const { return operator_vector_; } output_info_vector()51 OutPutInfoVector output_info_vector() const { return output_info_vector_; } 52 Status InferRedistributionOperator(); SetVirtualRank(const int64_t virtual_rank)53 void SetVirtualRank(const int64_t virtual_rank) { virtual_rank_ = virtual_rank; } 54 Status MergePartialToFullForReshapeHasMultiDynamicAxis(); 55 Status SegmentFullShapeToPartial(); 56 57 private: 58 Status InferSplitByAxis(); 59 Status InferPermuteByAxis(); 60 Status InferConcatByAxis(); 61 Status TransferSplitByAxis(const Args &args); 62 Status TransferPermuteByAxis(const Args &args); 63 Status TransferConcatByAxis(const Args &args); 64 Status InsertOperator(const OperatorName &name, const Args &args); 65 66 OperatorList operator_list_; 67 OperatorVector operator_vector_; 68 OutPutInfoVector output_info_vector_; 69 Arrangement dev_mat_; 70 RedistributionOperatorMap map_; 71 Map in_tensor_map_; 72 Map out_tensor_map_; 73 TensorLayout cur_tensor_layout_; 74 ConstructOperator constructor_; 75 RankList dev_list_; 76 bool construct_op_flag_; 77 bool is_cost_model_; 78 int64_t virtual_rank_ = -1; 79 }; 80 } // namespace parallel 81 } // namespace mindspore 82 83 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ 84