• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_OPERATOR_INFO_REGISTER_H
18 #define MINDSPORE_LITE_TOOLS_OPERATOR_INFO_REGISTER_H
19 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tools/optimizer/parallel/operator_info.h"
25 
26 namespace mindspore {
27 namespace opt {
28 using OperatorInfoCreatorFunc =
29   std::function<std::unique_ptr<opt::OperatorInfo>(const std::string &name, const SplitStrategy &strategy)>;
30 
31 class SplitOpKey {
32  public:
33   SplitOpKey() = delete;
34 
SplitOpKey(int op_type,TypeId data_type,bool is_depth_wise)35   SplitOpKey(int op_type, TypeId data_type, bool is_depth_wise)
36       : op_type_(op_type), data_type_(data_type), is_depth_wise_(is_depth_wise) {}
37 
38   bool operator<(const SplitOpKey &key) const;
39 
40   std::string ToString() const;
41 
42   ~SplitOpKey() = default;
43 
44  private:
45   int op_type_{schema::PrimitiveType_NONE};
46   TypeId data_type_{kTypeUnknown};
47   // Conv && DepthwiseCon has same schema_id, so need this flags
48   bool is_depth_wise_{false};
49 };
50 
51 class OperatorInfoFactory {
52  public:
53   static OperatorInfoFactory *GeInstance();
54 
55   OperatorInfoFactory(const OperatorInfoFactory &) = delete;
56 
57   OperatorInfoFactory &operator=(const OperatorInfoFactory &) = delete;
58 
59   void RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
60                             const OperatorInfoCreatorFunc &creator_func);
61 
62   OperatorInfoCreatorFunc FindOperatorInfo(const SplitOpKey &split_op_key);
63 
64  private:
65   OperatorInfoFactory() = default;
66 
67   virtual ~OperatorInfoFactory() = default;
68 
69  private:
70   // key: op_type -->data_type-->-->is_depth_wise-->name
71   std::map<SplitOpKey, OperatorInfoCreatorFunc> operator_info_map_;
72 };
73 
74 class OperatorInfoRegister {
75  public:
76   OperatorInfoRegister() = delete;
77 
78   OperatorInfoRegister(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
79                        const OperatorInfoCreatorFunc &creator_func);
80 
81   ~OperatorInfoRegister() = default;
82 };
83 
// Defines a static OperatorInfoRegister whose construction registers creator_func
// for (operator_type, type_id, is_depth_wise). Intended for namespace scope in a .cpp file.
// The first three arguments are token-pasted into the variable name, so operator_type and
// type_id must each be a single unqualified identifier (e.g. PrimitiveType_Conv2DFusion,
// kNumberTypeFloat32) visible at the expansion site, and is_depth_wise must be a single token
// (true/false). Registering the same triple twice in one translation unit is a redefinition
// error. The registrar type is fully qualified so the macro can be expanded from any namespace.
#define OPERATOR_INFO_REGISTER(operator_type, type_id, is_depth_wise, creator_func)                    \
  static ::mindspore::opt::OperatorInfoRegister g_name##operator_type##type_id##is_depth_wise##Creator( \
    operator_type, type_id, is_depth_wise, creator_func);
87 }  // namespace opt
88 }  // namespace mindspore
89 
90 #endif  // MINDSPORE_LITE_TOOLS_OPERATOR_INFO_REGISTER_H
91