• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_TRANSFORM_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_TRANSFORM_H_
19 
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
28 
29 namespace mindspore {
30 namespace parallel {
31 using TransformFunc = std::function<std::pair<std::string, std::vector<int64_t>>(const Operator &)>;
32 using InferShapeFunc = std::function<Shape(const Shape &, const std::vector<int64_t> &)>;
33 class TensorTransform {
34  public:
35   static std::shared_ptr<TensorTransform> GetInstance();
36   ~TensorTransform() = default;
37   TensorTransform(const TensorTransform &) = delete;
38   TensorTransform &operator=(const TensorTransform &) = delete;
39   void InitTransforOperator();
40   std::vector<std::pair<std::string, std::vector<int64_t>>> TransformOperators(const Shapes &from, const Shapes &to,
41                                                                                const RankList &dev_list,
42                                                                                int64_t rank_id);
43   RedistributionOpListPtr OptimizeTensorRedistributionOperatorList(
44     const RedistributionOpListPtr &redistribution_op_list, const Shape &input_shape);
45 
46  private:
47   TensorTransform();
48   std::unordered_map<string, TransformFunc> transform_operator_;
49   std::unordered_map<string, InferShapeFunc> infer_shape_operator_;
50   bool inited_function_ = false;
51   std::pair<std::string, std::vector<int64_t>> ExtractReshapeOp(const Operator &reshape_op_pair) const;
52   std::pair<std::string, std::vector<int64_t>> ExtractAllGatherOp(const Operator &allgather_op_pair) const;
53   std::pair<std::string, std::vector<int64_t>> ExtractSplitOp(const Operator &split_op_pair) const;
54   std::pair<std::string, std::vector<int64_t>> ExtractConcatOp(const Operator &concat_op_pair) const;
55   std::pair<std::string, std::vector<int64_t>> ExtractStridedSliceOp(const Operator &slice_op_pair) const;
56   Shape InferReshapeOp(const Shape &ori_shape, const std::vector<int64_t> &op) const;
57   Shape InferAllGatherOp(const Shape &ori_shape, const std::vector<int64_t> &op) const;
58   Shape InferStridedSliceOp(const Shape &ori_shape, const std::vector<int64_t> &op) const;
59   std::vector<Shape> GetRedistributionOpShape(
60     const Shape &ori_shape, const std::vector<std::pair<std::string, std::vector<int64_t>>> &transform_op_list);
61   void OptimizeAllConcat(std::vector<std::pair<std::string, std::vector<int64_t>>> *transform_op_list);
62   TensorRedistribution tensor_redistribution_;
63 };
64 }  // namespace parallel
65 }  // namespace mindspore
66 
67 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_TRANSFORM_H_
68