• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_OPERATOR_INFO_H_
18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_OPERATOR_INFO_H_
19 
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include <unordered_map>
25 #include "tools/optimizer/parallel/split_strategy.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "schema/model_generated.h"
29 #include "include/errorcode.h"
30 
31 namespace mindspore {
32 namespace opt {
/**
 * Do the following steps to make an operator support parallel splitting:
 *
 * 1. Add schema::PrimitiveType_XXX to ParallelPass::PARALLEL_LIST;
 * 2. Add a pair of the type and its string name to ParallelPass::type_string;
 * 3. Implement a class XXXInfo that derives from OperatorInfo;
 *    3.1. Override CheckStrategy(), InferParallelCNodes(), InferReplaceOp() and CreateOutputsOfSplit();
 * 4. Include the header file of XXXInfo in ops_info_head_files.h;
 * 5. Register XXXInfo in dynamic_creator.cc.
 */
43 using schema::ReduceMode;
44 class OperatorInfo;
45 using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
46 class OperatorInfo {
47  public:
OperatorInfo(const std::string & name,const SplitStrategy & strategy)48   OperatorInfo(const std::string &name, const SplitStrategy &strategy)
49       : name_(std::move(name)),
50         strategy_(std::move(strategy)),
51         replace_op_(nullptr),
52         func_graph_(nullptr),
53         cnode_(nullptr) {}
54   virtual ~OperatorInfo() = default;
name()55   const std::string name() const { return name_; }
set_name(const std::string & name)56   void set_name(const std::string &name) { name_ = name; }
57   void Init(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int32_t fmk_type);
58   int DoSplit();
replace_op()59   AnfNodePtr replace_op() const { return replace_op_; }
60 
61  protected:
62   int CheckSplitResult(const AnfNodePtr &anf_node, const std::vector<AnfNodePtr> &split_results, int target_output_num);
63 
64   int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
65 
66   AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
67                                int32_t concat_dim, size_t input_nodes_num);
68   AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
69                               size_t input_nodes_num);
70 
71   std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor() const;
72 
73   virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
74                                           std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,
75                                           const std::vector<int64_t> &splits) = 0;
76   virtual int InferReplaceOp() = 0;
77   virtual int InferParallelCNodes() = 0;
78   virtual int CheckStrategy(const SplitStrategy &strategy) = 0;
79 
80  protected:
81   std::string name_;
82   SplitStrategy strategy_;
83   AnfNodePtr replace_op_{nullptr};
84   std::vector<AnfNodePtr> parallel_output_nodes_;
85   FuncGraphPtr func_graph_{nullptr};
86   CNodePtr cnode_{nullptr};
87   int32_t fmk_type_{};
88   TypeId operator_type_id_ = kNumberTypeFloat32;
89 
90  private:
91   int SetCNodeBackend();
92   int CheckStrategyValue();
93 };
94 
95 // a template func for normal op_coder creator
96 template <typename T>
OperatorInfoCreator(const std::string & name,const SplitStrategy & strategy)97 std::unique_ptr<OperatorInfo> OperatorInfoCreator(const std::string &name, const SplitStrategy &strategy) {
98   std::unique_ptr<T> coder = std::make_unique<T>(name, strategy);
99   return coder;
100 }
101 
102 bool is_any_none(const std::vector<int64_t> &split);
103 bool is_any_not_none(const std::vector<int64_t> &split);
104 
105 }  // namespace opt
106 }  // namespace mindspore
107 
108 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_PARALLEL_OPERATOR_INFO_H_
109