1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "mindspore/core/ops/op_enum.h"
17 #include "mindspore/core/ops/op_def.h"
18 #include "mindspore/core/mindapi/base/format.h"
19 #include "include/common/pybind_api/api_register.h"
20 #include "mindapi/base/types.h"
21
22 namespace mindspore::ops {
23
24 namespace {
25 // The pybind11 registers the enum globally. Therefore, if different modules register the same enum type,
26 // conflicts occur.
27 // To avoid conflicts with the mindspore_lite, the FormatEnum added for use in the pybind.
28 enum FormatEnum : int64_t {
29 DEFAULT_FORMAT = Format::DEFAULT_FORMAT,
30 NCHW = Format::NCHW,
31 NHWC = Format::NHWC,
32 NHWC4 = Format::NHWC4,
33 HWKC = Format::HWKC,
34 HWCK = Format::HWCK,
35 KCHW = Format::KCHW,
36 CKHW = Format::CKHW,
37 KHWC = Format::KHWC,
38 CHWK = Format::CHWK,
39 HW = Format::HW,
40 HW4 = Format::HW4,
41 NC = Format::NC,
42 NC4 = Format::NC4,
43 NC4HW4 = Format::NC4HW4,
44 NCDHW = Format::NCDHW,
45 NWC = Format::NWC,
46 NCW = Format::NCW,
47 NDHWC = Format::NDHWC,
48 NC8HW8 = Format::NC8HW8
49 };
50 } // namespace
51
RegOpEnum(py::module * m)52 void RegOpEnum(py::module *m) {
53 auto m_sub = m->def_submodule("op_enum", "submodule for op enum");
54 (void)m_sub.def("str_to_enum", &StringToEnumImpl, "string to enum value");
55 (void)py::enum_<OP_DTYPE>(*m, "OpDtype", py::arithmetic())
56 .value("DT_BEGIN", OP_DTYPE::DT_BEGIN)
57 .value("DT_BOOL", OP_DTYPE::DT_BOOL)
58 .value("DT_INT", OP_DTYPE::DT_INT)
59 .value("DT_FLOAT", OP_DTYPE::DT_FLOAT)
60 .value("DT_NUMBER", OP_DTYPE::DT_NUMBER)
61 .value("DT_TENSOR", OP_DTYPE::DT_TENSOR)
62 .value("DT_STR", OP_DTYPE::DT_STR)
63 .value("DT_ANY", OP_DTYPE::DT_ANY)
64 .value("DT_TUPLE_BOOL", OP_DTYPE::DT_TUPLE_BOOL)
65 .value("DT_TUPLE_INT", OP_DTYPE::DT_TUPLE_INT)
66 .value("DT_TUPLE_FLOAT", OP_DTYPE::DT_TUPLE_FLOAT)
67 .value("DT_TUPLE_NUMBER", OP_DTYPE::DT_TUPLE_NUMBER)
68 .value("DT_TUPLE_TENSOR", OP_DTYPE::DT_TUPLE_TENSOR)
69 .value("DT_TUPLE_STR", OP_DTYPE::DT_TUPLE_STR)
70 .value("DT_TUPLE_ANY", OP_DTYPE::DT_TUPLE_ANY)
71 .value("DT_LIST_BOOL", OP_DTYPE::DT_LIST_BOOL)
72 .value("DT_LIST_INT", OP_DTYPE::DT_LIST_INT)
73 .value("DT_LIST_FLOAT", OP_DTYPE::DT_LIST_FLOAT)
74 .value("DT_LIST_NUMBER", OP_DTYPE::DT_LIST_NUMBER)
75 .value("DT_LIST_TENSOR", OP_DTYPE::DT_LIST_TENSOR)
76 .value("DT_LIST_STR", OP_DTYPE::DT_LIST_STR)
77 .value("DT_LIST_ANY", OP_DTYPE::DT_LIST_ANY)
78 .value("DT_TYPE", OP_DTYPE::DT_TYPE)
79 .value("DT_END", OP_DTYPE::DT_END);
80 // There are currently some deficiencies in format, which will be filled in later.
81 (void)py::enum_<FormatEnum>(*m, "FormatEnum", py::arithmetic())
82 .value("DEFAULT_FORMAT", FormatEnum::DEFAULT_FORMAT)
83 .value("NCHW", FormatEnum::NCHW)
84 .value("NHWC", FormatEnum::NHWC)
85 .value("NHWC4", FormatEnum::NHWC4)
86 .value("HWKC", FormatEnum::HWKC)
87 .value("HWCK", FormatEnum::HWCK)
88 .value("KCHW", FormatEnum::KCHW)
89 .value("CKHW", FormatEnum::CKHW)
90 .value("KHWC", FormatEnum::KHWC)
91 .value("CHWK", FormatEnum::CHWK)
92 .value("HW", FormatEnum::HW)
93 .value("HW4", FormatEnum::HW4)
94 .value("NC", FormatEnum::NC)
95 .value("NC4", FormatEnum::NC4)
96 .value("NC4HW4", FormatEnum::NC4HW4)
97 .value("NCDHW", FormatEnum::NCDHW)
98 .value("NWC", FormatEnum::NWC)
99 .value("NCW", FormatEnum::NCW)
100 .value("NDHWC", FormatEnum::NDHWC)
101 .value("NC8HW8", FormatEnum::NC8HW8);
102 (void)py::enum_<Reduction>(*m, "ReductionEnum", py::arithmetic())
103 .value("SUM", Reduction::REDUCTION_SUM)
104 .value("MEAN", Reduction::MEAN)
105 .value("NONE", Reduction::NONE);
106 }
107 } // namespace mindspore::ops
108