• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "mindspore/core/ops/op_enum.h"
17 #include "mindspore/core/ops/op_def.h"
18 #include "mindspore/core/mindapi/base/format.h"
19 #include "include/common/pybind_api/api_register.h"
20 #include "mindapi/base/types.h"
21 
22 namespace mindspore::ops {
23 
24 namespace {
25 // The pybind11 registers the enum globally. Therefore, if different modules register the same enum type,
26 // conflicts occur.
27 // To avoid conflicts with the mindspore_lite, the FormatEnum added for use in the pybind.
28 enum FormatEnum : int64_t {
29   DEFAULT_FORMAT = Format::DEFAULT_FORMAT,
30   NCHW = Format::NCHW,
31   NHWC = Format::NHWC,
32   NHWC4 = Format::NHWC4,
33   HWKC = Format::HWKC,
34   HWCK = Format::HWCK,
35   KCHW = Format::KCHW,
36   CKHW = Format::CKHW,
37   KHWC = Format::KHWC,
38   CHWK = Format::CHWK,
39   HW = Format::HW,
40   HW4 = Format::HW4,
41   NC = Format::NC,
42   NC4 = Format::NC4,
43   NC4HW4 = Format::NC4HW4,
44   NCDHW = Format::NCDHW,
45   NWC = Format::NWC,
46   NCW = Format::NCW,
47   NDHWC = Format::NDHWC,
48   NC8HW8 = Format::NC8HW8
49 };
50 }  // namespace
51 
RegOpEnum(py::module * m)52 void RegOpEnum(py::module *m) {
53   auto m_sub = m->def_submodule("op_enum", "submodule for op enum");
54   (void)m_sub.def("str_to_enum", &StringToEnumImpl, "string to enum value");
55   (void)py::enum_<OP_DTYPE>(*m, "OpDtype", py::arithmetic())
56     .value("DT_BEGIN", OP_DTYPE::DT_BEGIN)
57     .value("DT_BOOL", OP_DTYPE::DT_BOOL)
58     .value("DT_INT", OP_DTYPE::DT_INT)
59     .value("DT_FLOAT", OP_DTYPE::DT_FLOAT)
60     .value("DT_NUMBER", OP_DTYPE::DT_NUMBER)
61     .value("DT_TENSOR", OP_DTYPE::DT_TENSOR)
62     .value("DT_STR", OP_DTYPE::DT_STR)
63     .value("DT_ANY", OP_DTYPE::DT_ANY)
64     .value("DT_TUPLE_BOOL", OP_DTYPE::DT_TUPLE_BOOL)
65     .value("DT_TUPLE_INT", OP_DTYPE::DT_TUPLE_INT)
66     .value("DT_TUPLE_FLOAT", OP_DTYPE::DT_TUPLE_FLOAT)
67     .value("DT_TUPLE_NUMBER", OP_DTYPE::DT_TUPLE_NUMBER)
68     .value("DT_TUPLE_TENSOR", OP_DTYPE::DT_TUPLE_TENSOR)
69     .value("DT_TUPLE_STR", OP_DTYPE::DT_TUPLE_STR)
70     .value("DT_TUPLE_ANY", OP_DTYPE::DT_TUPLE_ANY)
71     .value("DT_LIST_BOOL", OP_DTYPE::DT_LIST_BOOL)
72     .value("DT_LIST_INT", OP_DTYPE::DT_LIST_INT)
73     .value("DT_LIST_FLOAT", OP_DTYPE::DT_LIST_FLOAT)
74     .value("DT_LIST_NUMBER", OP_DTYPE::DT_LIST_NUMBER)
75     .value("DT_LIST_TENSOR", OP_DTYPE::DT_LIST_TENSOR)
76     .value("DT_LIST_STR", OP_DTYPE::DT_LIST_STR)
77     .value("DT_LIST_ANY", OP_DTYPE::DT_LIST_ANY)
78     .value("DT_TYPE", OP_DTYPE::DT_TYPE)
79     .value("DT_END", OP_DTYPE::DT_END);
80   // There are currently some deficiencies in format, which will be filled in later.
81   (void)py::enum_<FormatEnum>(*m, "FormatEnum", py::arithmetic())
82     .value("DEFAULT_FORMAT", FormatEnum::DEFAULT_FORMAT)
83     .value("NCHW", FormatEnum::NCHW)
84     .value("NHWC", FormatEnum::NHWC)
85     .value("NHWC4", FormatEnum::NHWC4)
86     .value("HWKC", FormatEnum::HWKC)
87     .value("HWCK", FormatEnum::HWCK)
88     .value("KCHW", FormatEnum::KCHW)
89     .value("CKHW", FormatEnum::CKHW)
90     .value("KHWC", FormatEnum::KHWC)
91     .value("CHWK", FormatEnum::CHWK)
92     .value("HW", FormatEnum::HW)
93     .value("HW4", FormatEnum::HW4)
94     .value("NC", FormatEnum::NC)
95     .value("NC4", FormatEnum::NC4)
96     .value("NC4HW4", FormatEnum::NC4HW4)
97     .value("NCDHW", FormatEnum::NCDHW)
98     .value("NWC", FormatEnum::NWC)
99     .value("NCW", FormatEnum::NCW)
100     .value("NDHWC", FormatEnum::NDHWC)
101     .value("NC8HW8", FormatEnum::NC8HW8);
102   (void)py::enum_<Reduction>(*m, "ReductionEnum", py::arithmetic())
103     .value("SUM", Reduction::REDUCTION_SUM)
104     .value("MEAN", Reduction::MEAN)
105     .value("NONE", Reduction::NONE);
106 }
107 }  // namespace mindspore::ops
108