• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
18 #define MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
19 
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include <unordered_map>
25 #include "tools/optimizer/parallel/split_strategy.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "schema/model_generated.h"
29 #include "include/context.h"
30 #include "include/errorcode.h"
31 
32 namespace mindspore {
33 namespace opt {
/**
 * Follow these steps to make an operator support parallel execution:
 *
 * 1. Add schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST;
 * 2. Add a (type, string name) pair to ParallelPass::type_string;
 * 3. Implement a class XXXInfo derived from OperatorInfo;
 *    3.1. Override CheckStrategy(), InferParallelCNodes(), InferReplaceOp() and CreateOutputsOfSplit();
 * 4. Include the header file of XXXInfo in ops_info_head_files.h;
 * 5. Register XXXInfo in dynamic_creator.cc.
 */
44 using schema::ReduceMode;
45 class OperatorInfo;
46 using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
47 class OperatorInfo {
48  public:
OperatorInfo(const std::string & name,const SplitStrategy & strategy)49   OperatorInfo(const std::string &name, const SplitStrategy &strategy)
50       : name_(std::move(name)),
51         strategy_(std::move(strategy)),
52         replace_op_(nullptr),
53         func_graph_(nullptr),
54         cnode_(nullptr) {}
55   virtual ~OperatorInfo() = default;
name()56   const std::string name() const { return name_; }
set_name(const std::string & name)57   void set_name(const std::string &name) { name_ = name; }
58   void Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type);
59   int DoSplit();
replace_op()60   AnfNodePtr replace_op() const { return replace_op_; }
61 
62  protected:
63   int CheckSplitResult(const AnfNodePtr &anf_node, const std::vector<AnfNodePtr> &split_results, int target_output_num);
64 
65   int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
66 
67   AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
68                                int32_t concat_dim, size_t input_nodes_num);
69   AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
70                               size_t input_nodes_num);
71 
72   std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor() const;
73 
74   virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
75                                           std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
76                                           const std::vector<int64_t> &splits) = 0;
77   virtual int InferReplaceOp() = 0;
78   virtual int InferParallelCNodes() = 0;
79   virtual int CheckStrategy(const SplitStrategy &strategy) = 0;
80 
81  protected:
82   std::string name_;
83   SplitStrategy strategy_;
84   AnfNodePtr replace_op_{nullptr};
85   std::vector<AnfNodePtr> parallel_output_nodes_;
86   FuncGraphPtr func_graph_{nullptr};
87   CNodePtr cnode_{nullptr};
88   int32_t fmk_type_{};
89   TypeId operator_type_id_ = kNumberTypeFloat32;
90 
91  private:
92   int SetCNodeBackend();
93   int CheckStrategyValue();
94 };
95 
96 // a template func for normal op_coder creator
97 template <typename T>
OperatorInfoCreator(const std::string & name,const SplitStrategy & strategy)98 std::unique_ptr<OperatorInfo> OperatorInfoCreator(const std::string &name, const SplitStrategy &strategy) {
99   std::unique_ptr<T> coder = std::make_unique<T>(name, strategy);
100   return coder;
101 }
102 
103 bool is_any_none(const std::vector<int64_t> &split);
104 bool is_any_not_none(const std::vector<int64_t> &split);
105 
106 }  // namespace opt
107 }  // namespace mindspore
108 
109 #endif  // MINDSPORE_LITE_SRC_PASS_PARALLEL_OPERATOR_INFO_H_
110