1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/optimizer/pattern.h"
17 #include "pybind_api/api_register.h"
18
19 namespace mindspore {
20 namespace opt {
21 namespace python_pass {
// Out-of-class definition of the static data member Pattern::g_id_, shared by
// every Pattern instance and initialised to 0.
// NOTE(review): presumably a counter used to hand out unique pattern ids —
// confirm against its uses in pattern.h.
int64_t Pattern::g_id_ = 0;
23
match(const AnfNodePtr & node)24 MatchResultPtr Prim::match(const AnfNodePtr &node) {
25 if (!IsValueNode<Primitive>(node)) {
26 return nullptr;
27 }
28 MatchResultPtr res = std::make_shared<MatchResult>();
29 // iterate over all primitives
30 for (auto &iter : primitives_) {
31 if (IsPrimitive(node, iter) || iter->name() == "*") {
32 matched_prim_ = iter;
33 res->add_entry(shared_from_base<Prim>(), node);
34 return res;
35 }
36 }
37 return nullptr;
38 }
39
match(const AnfNodePtr & node)40 MatchResultPtr Call::match(const AnfNodePtr &node) {
41 if (!IsPrimitiveCNode(node)) {
42 return nullptr;
43 }
44 MatchResultPtr res = std::make_shared<MatchResult>();
45 // IsPrimitiveCNode
46 auto cnode = node->cast<CNodePtr>();
47 MS_EXCEPTION_IF_NULL(cnode);
48 // Check Primitive ValueNode
49 if (prim_pattern_ != nullptr) {
50 // Passed in prim_pattern
51 auto prim_value_res = prim_pattern_->match(cnode->input(0));
52 if (prim_value_res == nullptr) {
53 return nullptr;
54 }
55 res->merge(prim_value_res);
56 } else if (prim_ != nullptr) {
57 // Passed in primitive/primitive str
58 if (!IsPrimitive(cnode->input(0), prim_)) {
59 return nullptr;
60 }
61 } else {
62 MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern.";
63 }
64 // Check inputs
65 auto p_inputs_size = inputs_.size();
66 auto node_inputs_size = cnode->size() - 1;
67 if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) {
68 return nullptr;
69 }
70 // If inputs is not specified, add node without looking into its inputs
71 if (p_inputs_size == 0) {
72 res->add_entry(shared_from_base<Call>(), cnode->input(0));
73 return res;
74 }
75 bool failed = false;
76 for (std::size_t i = 0; i < node_inputs_size; i++) {
77 auto pattern = inputs_[i];
78 auto input = cnode->input(i + 1);
79 auto input_match_result = pattern->match(input);
80 if (input_match_result == nullptr) {
81 failed = true;
82 break;
83 }
84 res->merge(input_match_result);
85 }
86 if (!failed) {
87 res->add_entry(shared_from_base<Call>(), cnode->input(0));
88 return res;
89 }
90 return nullptr;
91 }
92
match(const AnfNodePtr & node)93 MatchResultPtr OneOf::match(const AnfNodePtr &node) {
94 for (auto &iter : patterns_) {
95 auto res = iter->match(node);
96 if (res != nullptr) {
97 res->add_entry(shared_from_base<OneOf>(), node);
98 return res;
99 }
100 }
101 return nullptr;
102 }
103
match(const AnfNodePtr & node)104 MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
105 for (auto &iter : patterns_) {
106 auto match_res = iter->match(node);
107 if (match_res != nullptr) {
108 return nullptr;
109 }
110 }
111 auto res = std::make_shared<MatchResult>();
112 res->add_entry(shared_from_base<NoneOf>(), node);
113 return res;
114 }
115
match(const AnfNodePtr & node)116 MatchResultPtr Any::match(const AnfNodePtr &node) {
117 MatchResultPtr res = std::make_shared<MatchResult>();
118 res->add_entry(shared_from_base<Any>(), node);
119 return res;
120 }
121
match(const AnfNodePtr & node)122 MatchResultPtr Imm::match(const AnfNodePtr &node) {
123 if (!IsValueNode<Int32Imm>(node)) {
124 return nullptr;
125 }
126 // Check value
127 auto value_node = node->cast<ValueNodePtr>();
128 MS_EXCEPTION_IF_NULL(value_node);
129 auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
130 MS_EXCEPTION_IF_NULL(value_ptr);
131 if ((int32_t)value_ptr->value() == value_) {
132 MatchResultPtr res = std::make_shared<MatchResult>();
133 res->add_entry(shared_from_base<Imm>(), node);
134 return res;
135 }
136 return nullptr;
137 }
138
get_node(const PatternPtr & pattern)139 AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
140 auto entry = match_result_.find(pattern);
141 if (entry == match_result_.end()) {
142 return nullptr;
143 }
144 return entry->second;
145 }
146
merge(const MatchResultPtr & other_result)147 void MatchResult::merge(const MatchResultPtr &other_result) {
148 auto other_result_map = other_result->result();
149 // add/update entries in other_result
150 for (auto &iter : other_result_map) {
151 match_result_[iter.first] = iter.second;
152 }
153 }
154
// Exposes the pattern classes to Python through pybind11 on module `m`.
// Pattern must be registered first: pybind11 requires a base class to be
// registered before any class that names it as a base, and every other class
// here derives from Pattern.
// NOTE(review): several Python names carry a trailing underscore ("OneOf_",
// "Prim_", "Call_", "NoneOf_", "NewTensor_", "NewParameter_") while "Any" and
// "Imm" do not; presumably Python-side wrappers provide the public names —
// confirm before renaming anything.
REGISTER_PYBIND_DEFINE(
  Pattern, ([](const py::module *m) {
    (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
    (void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
    // py::dynamic_attr() lets Python code attach arbitrary attributes to Prim_ objects.
    (void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
      .def(py::init<vector<py::object>, string>());
    // Two constructors: the callee is given either as a Pattern or as a
    // py::object (a primitive or primitive name, per Call::match's comments),
    // followed by the argument patterns.
    (void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
      .def(py::init<PatternPtr, vector<PatternPtr>>())
      .def(py::init<py::object, vector<PatternPtr>>());
    (void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
    (void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
    (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
      .def(py::init<tensor::TensorPtr>());
    (void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
      .def(py::init<string, tensor::TensorPtr, bool, bool>());
    (void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int64_t>());
  }));
172 } // namespace python_pass
173 } // namespace opt
174 } // namespace mindspore
175