1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_OPERATOR_INFO_H_
18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_OPERATOR_INFO_H_
19
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include <unordered_map>
25 #include "tools/optimizer/parallel/split_strategy.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "schema/model_generated.h"
29 #include "include/errorcode.h"
30
31 namespace mindspore {
32 namespace opt {
/**
 * Steps to make an operator support parallel splitting:
 *
 * 1. Add schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST;
 * 2. Add a pair of the type and its string name to ParallelPass::type_string;
 * 3. Implement a class XXXInfo that derives from OperatorInfo;
 *    3.1. Override the pure virtual hooks CheckStrategy(), InferParallelCNodes(), InferReplaceOp()
 *         and CreateOutputsOfSplit();
 * 4. Include the header file of XXXInfo in ops_info_head_files.h;
 * 5. Register XXXInfo in dynamic_creator.cc.
 */
43 using schema::ReduceMode;
44 class OperatorInfo;
45 using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
46 class OperatorInfo {
47 public:
OperatorInfo(const std::string & name,const SplitStrategy & strategy)48 OperatorInfo(const std::string &name, const SplitStrategy &strategy)
49 : name_(std::move(name)),
50 strategy_(std::move(strategy)),
51 replace_op_(nullptr),
52 func_graph_(nullptr),
53 cnode_(nullptr) {}
54 virtual ~OperatorInfo() = default;
name()55 const std::string name() const { return name_; }
set_name(const std::string & name)56 void set_name(const std::string &name) { name_ = name; }
57 void Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type);
58 int DoSplit();
replace_op()59 AnfNodePtr replace_op() const { return replace_op_; }
60
61 protected:
62 int CheckSplitResult(const AnfNodePtr &anf_node, const std::vector<AnfNodePtr> &split_results, int target_output_num);
63
64 int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
65
66 AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
67 int32_t concat_dim, size_t input_nodes_num);
68 AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
69 size_t input_nodes_num);
70
71 std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor() const;
72
73 virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
74 std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
75 const std::vector<int64_t> &splits) = 0;
76 virtual int InferReplaceOp() = 0;
77 virtual int InferParallelCNodes() = 0;
78 virtual int CheckStrategy(const SplitStrategy &strategy) = 0;
79
80 protected:
81 std::string name_;
82 SplitStrategy strategy_;
83 AnfNodePtr replace_op_{nullptr};
84 std::vector<AnfNodePtr> parallel_output_nodes_;
85 FuncGraphPtr func_graph_{nullptr};
86 CNodePtr cnode_{nullptr};
87 int32_t fmk_type_{};
88 TypeId operator_type_id_ = kNumberTypeFloat32;
89
90 private:
91 int SetCNodeBackend();
92 int CheckStrategyValue();
93 };
94
95 // a template func for normal op_coder creator
96 template <typename T>
OperatorInfoCreator(const std::string & name,const SplitStrategy & strategy)97 std::unique_ptr<OperatorInfo> OperatorInfoCreator(const std::string &name, const SplitStrategy &strategy) {
98 std::unique_ptr<T> coder = std::make_unique<T>(name, strategy);
99 return coder;
100 }
101
// NOTE(review): the definitions are not in this header. The descriptions below are inferred from
// the function names only; confirm them against the implementation.
// Presumably returns true if any entry of `split_values` is a "none" marker.
bool is_any_none(const std::vector<int64_t> &split_values);
// Presumably returns true if at least one entry of `split_values` is not a "none" marker.
bool is_any_not_none(const std::vector<int64_t> &split_values);
104
105 } // namespace opt
106 } // namespace mindspore
107
108 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_OPERATOR_INFO_H_
109