1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
18 #define MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
19
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include <unordered_map>
25 #include "tools/optimizer/parallel/split_strategy.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "schema/model_generated.h"
29 #include "include/context.h"
30 #include "include/errorcode.h"
31
32 namespace mindspore {
33 namespace opt {
34 /**
35 * Do following steps to make a operator support parallel:
36 *
37 * 1.Add the schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST;
38 * 2.Add a pair of type and string name to ParallelPass::type_string;
39 * 3.Implement a class XXXInfo whose parent is OperatorInfo;
40 * 3.1.Override CheckStrategy(), InferParallelCNodes() and InferReplaceOp()
41 * 4.include header file of XXXInfo in ops_info_head_files.h
42 * 5.REGISTER XXXInfo in dynamic_creator.cc
43 */
44 using schema::ReduceMode;
45 class OperatorInfo;
46 using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
47 class OperatorInfo {
48 public:
OperatorInfo(const std::string & name,const SplitStrategy & strategy)49 OperatorInfo(const std::string &name, const SplitStrategy &strategy)
50 : name_(std::move(name)),
51 strategy_(std::move(strategy)),
52 replace_op_(nullptr),
53 func_graph_(nullptr),
54 cnode_(nullptr) {}
55 virtual ~OperatorInfo() = default;
name()56 const std::string name() const { return name_; }
set_name(const std::string & name)57 void set_name(const std::string &name) { name_ = name; }
58 void Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type);
59 int DoSplit();
replace_op()60 AnfNodePtr replace_op() const { return replace_op_; }
61
62 protected:
63 int CheckSplitResult(const AnfNodePtr &anf_node, const std::vector<AnfNodePtr> &split_results, int target_output_num);
64
65 int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
66
67 AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
68 int32_t concat_dim, size_t input_nodes_num);
69 AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
70 size_t input_nodes_num);
71
72 std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor() const;
73
74 virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
75 std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
76 const std::vector<int64_t> &splits) = 0;
77 virtual int InferReplaceOp() = 0;
78 virtual int InferParallelCNodes() = 0;
79 virtual int CheckStrategy(const SplitStrategy &strategy) = 0;
80
81 protected:
82 std::string name_;
83 SplitStrategy strategy_;
84 AnfNodePtr replace_op_{nullptr};
85 std::vector<AnfNodePtr> parallel_output_nodes_;
86 FuncGraphPtr func_graph_{nullptr};
87 CNodePtr cnode_{nullptr};
88 int32_t fmk_type_{};
89 TypeId operator_type_id_ = kNumberTypeFloat32;
90
91 private:
92 int SetCNodeBackend();
93 int CheckStrategyValue();
94 };
95
96 // a template func for normal op_coder creator
97 template <typename T>
OperatorInfoCreator(const std::string & name,const SplitStrategy & strategy)98 std::unique_ptr<OperatorInfo> OperatorInfoCreator(const std::string &name, const SplitStrategy &strategy) {
99 std::unique_ptr<T> coder = std::make_unique<T>(name, strategy);
100 return coder;
101 }
102
103 bool is_any_none(const std::vector<int64_t> &split);
104 bool is_any_not_none(const std::vector<int64_t> &split);
105
106 } // namespace opt
107 } // namespace mindspore
108
109 #endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
110