• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_
19 
20 #include <algorithm>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "utils/hash_map.h"
26 #include "frontend/parallel/tensor_layout/construct_operator.h"
27 #include "frontend/parallel/auto_parallel/costmodel.h"
28 #include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h"
29 #include "include/common/utils/convert_utils.h"
30 
31 namespace mindspore {
32 namespace parallel {
33 using DeviceArrangement = Shape;
34 using TensorMap = Shape;
35 using TensorShape = Shape;
36 using RedistributionOperatorMap = mindspore::HashMap<uint64_t, int64_t>;
37 using OperatorR = std::pair<OperatorName, Args>;
38 using OperatorC = std::pair<OperatorR, Shape>;
39 using OperatorList = std::vector<OperatorC>;
40 
41 class RedistributionOperatorInfer {
42  public:
43   const int64_t NONE = -1;
44   explicit RedistributionOperatorInfer(bool construct_op_flag = true)
construct_op_flag_(construct_op_flag)45       : construct_op_flag_(construct_op_flag), is_cost_model_(false) {}
46   Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list,
47               bool is_cost_model = false, bool is_dynamic_shape = false);
48   ~RedistributionOperatorInfer() = default;
operator_list()49   OperatorList operator_list() const { return operator_list_; }
operator_vector()50   OperatorVector operator_vector() const { return operator_vector_; }
output_info_vector()51   OutPutInfoVector output_info_vector() const { return output_info_vector_; }
52   Status InferRedistributionOperator();
SetVirtualRank(const int64_t virtual_rank)53   void SetVirtualRank(const int64_t virtual_rank) { virtual_rank_ = virtual_rank; }
54   Status MergePartialToFullForReshapeHasMultiDynamicAxis();
55   Status SegmentFullShapeToPartial();
56 
57  private:
58   Status InferSplitByAxis();
59   Status InferPermuteByAxis();
60   Status InferConcatByAxis();
61   Status TransferSplitByAxis(const Args &args);
62   Status TransferPermuteByAxis(const Args &args);
63   Status TransferConcatByAxis(const Args &args);
64   Status InsertOperator(const OperatorName &name, const Args &args);
65 
66   OperatorList operator_list_;
67   OperatorVector operator_vector_;
68   OutPutInfoVector output_info_vector_;
69   Arrangement dev_mat_;
70   RedistributionOperatorMap map_;
71   Map in_tensor_map_;
72   Map out_tensor_map_;
73   TensorLayout cur_tensor_layout_;
74   ConstructOperator constructor_;
75   RankList dev_list_;
76   bool construct_op_flag_;
77   bool is_cost_model_;
78   int64_t virtual_rank_ = -1;
79 };
80 }  // namespace parallel
81 }  // namespace mindspore
82 
83 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_
84