• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_
18 #define MINDSPORE_LITE_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_
19 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "src/lite_kernel.h"
#include "include/model.h"
#include "coder/config.h"
27 namespace mindspore::lite::micro {
28 class OperatorCoder;
29 using CoderCreatorFunc = std::function<std::unique_ptr<OperatorCoder>(
30   const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, const Model::Node *node,
31   size_t node_index, Target target, int schema_version)>;
32 
33 class CoderKey {
34  public:
35   CoderKey() = delete;
36 
CoderKey(Target target,TypeId data_type,int op_type)37   CoderKey(Target target, TypeId data_type, int op_type) : target_(target), data_type_(data_type), op_type_(op_type) {}
38 
AllKey()39   CoderKey AllKey() const {
40     CoderKey key(kAllTargets, data_type_, op_type_);
41     return key;
42   }
43 
44   bool operator<(CoderKey rhs) const;
45   std::string ToString() const;
46 
47   ~CoderKey() = default;
48 
49  private:
50   Target target_ = kTargetUnknown;
51   TypeId data_type_ = kTypeUnknown;
52   int op_type_ = schema::PrimitiveType_NONE;
53 };
54 
55 class OpCoderFactory {
56  public:
57   OpCoderFactory() = default;
58 
59   static OpCoderFactory *GetInstance();
60 
61   int RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type,
62                     const CoderCreatorFunc &creator_func);
63 
64   CoderCreatorFunc FindOpCoder(const CoderKey &key);
65 
~OpCoderFactory()66   ~OpCoderFactory() { opcoder_sets_.clear(); }
67 
68  private:
69   // target || data type || primitive type
70   std::map<CoderKey, CoderCreatorFunc> opcoder_sets_;
71 };
72 
73 class OpCoderRegister {
74  public:
75   OpCoderRegister() = delete;
76 
77   OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type,
78                   const CoderCreatorFunc &creator_func);
79 
80   ~OpCoderRegister() = default;
81 };
// Defines a translation-unit-local (static) OpCoderRegister whose name is token-pasted from
// the target, data type and operator type, so each combination gets its own object.
// Because of the ## pasting, those three arguments must be plain identifiers (no `::`
// qualification). The expansion already ends with ';'.
#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \
  static OpCoderRegister g_##target##data_type##operator_type##Creator(   \
    target, data_type, operator_type, creator_func);
84 }  // namespace mindspore::lite::micro
85 
86 #endif  // MINDSPORE_LITE_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_
87