• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_
19 
20 #include <algorithm>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "frontend/parallel/tensor_layout/construct_operator.h"
27 #include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h"
28 #include "utils/convert_utils.h"
29 namespace mindspore {
30 namespace parallel {
31 using DeviceArrangement = Shape;
32 using TensorMap = Shape;
33 using TensorShape = Shape;
34 using RedistributionOperatorMap = std::unordered_map<uint64_t, int64_t>;
35 using OperatorR = std::pair<OperatorName, Args>;
36 using OperatorC = std::pair<OperatorR, Shape>;
37 using OperatorList = std::vector<OperatorC>;
38 
39 class RedistributionOperatorInfer {
40  public:
41   const int64_t NONE = -1;
42   explicit RedistributionOperatorInfer(bool construct_op_flag = true)
construct_op_flag_(construct_op_flag)43       : construct_op_flag_(construct_op_flag), is_cost_model_(false) {}
44   Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list,
45               bool is_cost_model = false);
46   ~RedistributionOperatorInfer() = default;
operator_list()47   OperatorList operator_list() const { return operator_list_; }
operator_vector()48   OperatorVector operator_vector() const { return operator_vector_; }
output_info_vector()49   OutPutInfoVector output_info_vector() const { return output_info_vector_; }
50   Status InferRedistributionOperator();
51 
52  private:
53   Status InferSplitByAxis();
54   Status InferPermuteByAxis();
55   Status InferConcatByAxis();
56   Status TransferSplitByAxis(const Args &args);
57   Status TransferPermuteByAxis(const Args &args);
58   Status TransferConcatByAxis(const Args &args);
59   Status InsertOperator(const OperatorName &name, const Args &args);
60 
61   OperatorList operator_list_;
62   OperatorVector operator_vector_;
63   OutPutInfoVector output_info_vector_;
64   Arrangement dev_mat_;
65   RedistributionOperatorMap map_;
66   Map in_tensor_map_;
67   Map out_tensor_map_;
68   TensorLayout cur_tensor_layout_;
69   ConstructOperator constructor_;
70   RankList dev_list_;
71   bool construct_op_flag_;
72   bool is_cost_model_;
73 };
74 }  // namespace parallel
75 }  // namespace mindspore
76 
77 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_
78