• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_OPS_BASE_OPERATOR_
18 #define MINDSPORE_CORE_OPS_BASE_OPERATOR_
19 
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "mindapi/ir/primitive.h"
26 
namespace mindspore {
namespace abstract {
// Forward declarations only: these aliases can be named by code including this header
// without pulling in the full AnalysisEngine / AbstractBase definitions.
class AnalysisEngine;
using AnalysisEnginePtr = std::shared_ptr<AnalysisEngine>;

class AbstractBase;
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
}  // namespace abstract
}  // namespace mindspore
36 
namespace mindspore {
// Forward declaration of the core (non-api) Primitive, so PrimitivePtr can be named here
// without including its full definition.
class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>;
}  // namespace mindspore
41 
42 namespace mindspore {
43 namespace ops {
44 using PrimitiveCPtr = PrimitivePtr;
45 class MIND_API BaseOperator : public api::Primitive {
46  public:
47   MIND_API_BASE_MEMBER(BaseOperator);
48   explicit BaseOperator(const std::string &name);
49   PrimitiveCPtr GetPrim();
50 
51   void set_batch_rank(int64_t batch_rank);
52   int64_t get_batch_rank() const;
53 
54  protected:
55   void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
56 };
57 
58 using OperatorDefineFunc = std::function<std::shared_ptr<BaseOperator>(const std::shared_ptr<mindspore::Base> &)>;
59 class MIND_API OperatorRegister {
60  public:
~OperatorRegister()61   ~OperatorRegister() {}
62 
63   static OperatorRegister &GetInstance();
64 
65   const std::map<std::string, OperatorDefineFunc> &GetOperatorMap() const;
66 
67   void SetOperatorMap(const std::string &kname, const OperatorDefineFunc &fn);
68 
69  private:
OperatorRegister()70   OperatorRegister() {}
71   std::map<std::string, OperatorDefineFunc> operator_fns_;
72 };
73 
74 class MIND_API OperatorRegisterHelper {
75  public:
OperatorRegisterHelper(const std::string & kname,const OperatorDefineFunc & fn)76   OperatorRegisterHelper(const std::string &kname, const OperatorDefineFunc &fn) {
77     OperatorRegister::GetInstance().SetOperatorMap(kname, fn);
78     // (void)id_;  // make compiler happy on macos
79   }
80 
81   ~OperatorRegisterHelper() = default;
82 
83 //  private:
84 //   int id_{0};
85 };
86 
// OPERATOR_CREATOR_REG(K_NAME, OP_CLASS) defines:
//  - a factory function `GetDefaultBaseOperator<OP_CLASS>` returning
//    `std::make_shared<OP_CLASS>(impl)` as a `std::shared_ptr<BaseOperator>`, and
//  - an OperatorRegisterHelper object `operator_gen_<OP_CLASS>` whose constructor registers
//    that factory under K_NAME.
// Usage notes:
//  - The expansion has no trailing semicolon; write `OPERATOR_CREATOR_REG(...);`.
//  - BaseOperator and OperatorRegisterHelper are named unqualified, so expand it where those
//    names resolve (e.g. inside namespace mindspore::ops), at namespace scope.
//  - Neither generated entity is `static`, so unless expanded inside an unnamed namespace both
//    have external linkage: register each OP_CLASS in exactly one translation unit.
// Comments are kept outside the macro bodies because a `//` comment on a `\`-continued line
// would swallow the following lines.
#define OPERATOR_CREATOR_REG(K_NAME, OP_CLASS)                                                                   \
  std::shared_ptr<BaseOperator> GetDefaultBaseOperator##OP_CLASS(const std::shared_ptr<mindspore::Base> &impl) { \
    return std::make_shared<OP_CLASS>(impl);                                                                     \
  }                                                                                                              \
  OperatorRegisterHelper operator_gen_##OP_CLASS(K_NAME, GetDefaultBaseOperator##OP_CLASS)

// Implements the mindapi boilerplate for ClassName via the project macro MIND_API_BASE_IMPL
// (passing PrimitiveC as its second argument), then registers ClassName under its own
// stringized name; e.g. MIND_API_OPERATOR_IMPL(Add, BaseOperator) registers "Add".
#define MIND_API_OPERATOR_IMPL(ClassName, ParentClassName)    \
  MIND_API_BASE_IMPL(ClassName, PrimitiveC, ParentClassName); \
  OPERATOR_CREATOR_REG(#ClassName, ClassName)

// Same as MIND_API_OPERATOR_IMPL, for operators whose registered name differs from their class
// name. OpName is passed on as K_NAME, so it must be convertible to std::string.
#define MIND_API_OPERATOR_NAME_IMPL(ClassName, OpName, ParentClassName) \
  MIND_API_BASE_IMPL(ClassName, PrimitiveC, ParentClassName);           \
  OPERATOR_CREATOR_REG(OpName, ClassName)
101 }  // namespace ops
102 }  // namespace mindspore
103 
104 #endif  // MINDSPORE_CORE_OPS_BASE_OPERATOR_
105