• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_OPS_OP_DEF_H_
18 #define MINDSPORE_CORE_OPS_OP_DEF_H_
19 #include <string>
20 #include <vector>
21 #include <memory>
22 #include <unordered_map>
23 #include "ir/dtype/type_id.h"
24 #include "ops_func_impl/op_func_impl.h"
25 namespace mindspore::ops {
26 
/// Data-type tags used to describe operator arguments and return values.
///
/// Every enumerator is given its value explicitly; the numbering is exactly the
/// contiguous sequence 0..23 that implicit enumeration would produce. Keep the
/// sequence contiguous when adding kinds. DT_BEGIN / DT_END appear to be range
/// sentinels rather than real kinds (NOTE(review): confirm with users of the enum).
enum OP_DTYPE : int64_t {
  DT_BEGIN = 0,

  // Single-value kinds.
  DT_BOOL = 1,
  DT_INT = 2,
  DT_FLOAT = 3,
  DT_NUMBER = 4,
  DT_TENSOR = 5,
  DT_STR = 6,
  DT_ANY = 7,

  // Tuple-of-X kinds.
  DT_TUPLE_BOOL = 8,
  DT_TUPLE_INT = 9,
  DT_TUPLE_FLOAT = 10,
  DT_TUPLE_NUMBER = 11,
  DT_TUPLE_TENSOR = 12,
  DT_TUPLE_STR = 13,
  DT_TUPLE_ANY = 14,

  // List-of-X kinds.
  DT_LIST_BOOL = 15,
  DT_LIST_INT = 16,
  DT_LIST_FLOAT = 17,
  DT_LIST_NUMBER = 18,
  DT_LIST_TENSOR = 19,
  DT_LIST_STR = 20,
  DT_LIST_ANY = 21,

  DT_TYPE = 22,
  DT_END = 23,
};
53 
54 struct OpInputArg {
55   std::string arg_name_;
56   OP_DTYPE arg_dtype_;
57   bool as_init_arg_;  // true if this is a primitive init arg.
58   std::string arg_handler_;
59   std::vector<OP_DTYPE> cast_dtype_;
60   bool is_optional_;
61 };
62 
63 struct OpOutputArg {
64   std::string arg_name_;
65   OP_DTYPE arg_dtype_;
66   int64_t inplace_input_index_;
67 };
68 
69 struct OpDef {
70   std::string name_;
71   std::vector<OpInputArg> args_;
72   std::vector<OpOutputArg> returns_;
73   std::vector<Signature> signatures_;
74   std::unordered_map<std::string, size_t> indexes_;
75   OpFuncImpl &func_impl_;
76   bool enable_dispatch_;
77   bool is_view_;
78 };
79 
80 using OpDefPtr = OpDef *;
81 
82 MS_CORE_API OpDefPtr GetOpDef(const std::string &op_name);
83 MS_CORE_API void AddOpDef(const std::string &op_name, const OpDefPtr op_def);
84 MS_CORE_API bool IsPrimitiveFunction(const std::string &op_name);
85 
86 class OpDefRegHelper {
87  public:
OpDefRegHelper(const std::string & op_name,const OpDefPtr op_def)88   OpDefRegHelper(const std::string &op_name, const OpDefPtr op_def) { AddOpDef(op_name, op_def); }
89   ~OpDefRegHelper() = default;
90 };
91 
// Defines a static OpDefRegHelper named `op_def_helper_<op_name>`. Its constructor
// calls AddOpDef(op_name, op_def).
//
// `op_name` is used twice:
//   * it is token-pasted into the helper's variable name, so it must be a single
//     identifier;
//   * it is also passed as the name argument, so that identifier must also denote a
//     value convertible to std::string (e.g. a named string constant).
// The expansion already ends with ';'.
#define REGISTER_PRIMITIVE_OP_DEF(op_name, op_def) \
  static OpDefRegHelper op_def_helper_##op_name{op_name, op_def};
94 }  // namespace mindspore::ops
95 #endif
96