1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_OPS_OP_DEF_H_ 18 #define MINDSPORE_CORE_OPS_OP_DEF_H_ 19 #include <string> 20 #include <vector> 21 #include <memory> 22 #include <unordered_map> 23 #include "ir/dtype/type_id.h" 24 #include "ops_func_impl/op_func_impl.h" 25 namespace mindspore::ops { 26 27 enum OP_DTYPE : int64_t { 28 DT_BEGIN = 0, 29 DT_BOOL, 30 DT_INT, 31 DT_FLOAT, 32 DT_NUMBER, 33 DT_TENSOR, 34 DT_STR, 35 DT_ANY, 36 DT_TUPLE_BOOL, 37 DT_TUPLE_INT, 38 DT_TUPLE_FLOAT, 39 DT_TUPLE_NUMBER, 40 DT_TUPLE_TENSOR, 41 DT_TUPLE_STR, 42 DT_TUPLE_ANY, 43 DT_LIST_BOOL, 44 DT_LIST_INT, 45 DT_LIST_FLOAT, 46 DT_LIST_NUMBER, 47 DT_LIST_TENSOR, 48 DT_LIST_STR, 49 DT_LIST_ANY, 50 DT_TYPE, 51 DT_END, 52 }; 53 54 struct OpInputArg { 55 std::string arg_name_; 56 OP_DTYPE arg_dtype_; 57 bool as_init_arg_; // true if this is a primitive init arg. 58 std::string arg_handler_; 59 std::vector<OP_DTYPE> cast_dtype_; 60 bool is_optional_; 61 }; 62 63 struct OpOutputArg { 64 std::string arg_name_; 65 OP_DTYPE arg_dtype_; 66 int64_t inplace_input_index_; 67 }; 68 69 struct OpDef { 70 std::string name_; 71 std::vector<OpInputArg> args_; 72 std::vector<OpOutputArg> returns_; 73 std::vector<Signature> signatures_; 74 std::unordered_map<std::string, size_t> indexes_; 75 OpFuncImpl &func_impl_; 76 bool enable_dispatch_; 77 bool is_view_; 78 }; 79 80 using OpDefPtr = OpDef *; 81 82 MS_CORE_API OpDefPtr GetOpDef(const std::string &op_name); 83 MS_CORE_API void AddOpDef(const std::string &op_name, const OpDefPtr op_def); 84 MS_CORE_API bool IsPrimitiveFunction(const std::string &op_name); 85 86 class OpDefRegHelper { 87 public: OpDefRegHelper(const std::string & op_name,const OpDefPtr op_def)88 OpDefRegHelper(const std::string &op_name, const OpDefPtr op_def) { AddOpDef(op_name, op_def); } 89 ~OpDefRegHelper() = default; 90 }; 91 92 #define REGISTER_PRIMITIVE_OP_DEF(op_name, op_def) \ 93 static auto op_def_helper_##op_name = OpDefRegHelper(op_name, op_def); 94 } // namespace mindspore::ops 95 #endif 96