• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/optimizer/pattern.h"
17 #include "pybind_api/api_register.h"
18 
19 namespace mindspore {
20 namespace opt {
21 namespace python_pass {
// Out-of-class definition of the static data member Pattern::g_id_, initialised to 0.
// NOTE(review): presumably a running counter used to hand out pattern ids; its
// readers/writers are not visible in this file -- confirm against pattern.h.
22 int64_t Pattern::g_id_ = 0;
23 
match(const AnfNodePtr & node)24 MatchResultPtr Prim::match(const AnfNodePtr &node) {
25   if (!IsValueNode<Primitive>(node)) {
26     return nullptr;
27   }
28   MatchResultPtr res = std::make_shared<MatchResult>();
29   // iterate over all primitives
30   for (auto &iter : primitives_) {
31     if (IsPrimitive(node, iter) || iter->name() == "*") {
32       matched_prim_ = iter;
33       res->add_entry(shared_from_base<Prim>(), node);
34       return res;
35     }
36   }
37   return nullptr;
38 }
39 
match(const AnfNodePtr & node)40 MatchResultPtr Call::match(const AnfNodePtr &node) {
41   if (!IsPrimitiveCNode(node)) {
42     return nullptr;
43   }
44   MatchResultPtr res = std::make_shared<MatchResult>();
45   // IsPrimitiveCNode
46   auto cnode = node->cast<CNodePtr>();
47   MS_EXCEPTION_IF_NULL(cnode);
48   // Check Primitive ValueNode
49   if (prim_pattern_ != nullptr) {
50     // Passed in prim_pattern
51     auto prim_value_res = prim_pattern_->match(cnode->input(0));
52     if (prim_value_res == nullptr) {
53       return nullptr;
54     }
55     res->merge(prim_value_res);
56   } else if (prim_ != nullptr) {
57     // Passed in primitive/primitive str
58     if (!IsPrimitive(cnode->input(0), prim_)) {
59       return nullptr;
60     }
61   } else {
62     MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern.";
63   }
64   // Check inputs
65   auto p_inputs_size = inputs_.size();
66   auto node_inputs_size = cnode->size() - 1;
67   if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) {
68     return nullptr;
69   }
70   // If inputs is not specified, add node without looking into its inputs
71   if (p_inputs_size == 0) {
72     res->add_entry(shared_from_base<Call>(), cnode->input(0));
73     return res;
74   }
75   bool failed = false;
76   for (std::size_t i = 0; i < node_inputs_size; i++) {
77     auto pattern = inputs_[i];
78     auto input = cnode->input(i + 1);
79     auto input_match_result = pattern->match(input);
80     if (input_match_result == nullptr) {
81       failed = true;
82       break;
83     }
84     res->merge(input_match_result);
85   }
86   if (!failed) {
87     res->add_entry(shared_from_base<Call>(), cnode->input(0));
88     return res;
89   }
90   return nullptr;
91 }
92 
match(const AnfNodePtr & node)93 MatchResultPtr OneOf::match(const AnfNodePtr &node) {
94   for (auto &iter : patterns_) {
95     auto res = iter->match(node);
96     if (res != nullptr) {
97       res->add_entry(shared_from_base<OneOf>(), node);
98       return res;
99     }
100   }
101   return nullptr;
102 }
103 
match(const AnfNodePtr & node)104 MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
105   for (auto &iter : patterns_) {
106     auto match_res = iter->match(node);
107     if (match_res != nullptr) {
108       return nullptr;
109     }
110   }
111   auto res = std::make_shared<MatchResult>();
112   res->add_entry(shared_from_base<NoneOf>(), node);
113   return res;
114 }
115 
match(const AnfNodePtr & node)116 MatchResultPtr Any::match(const AnfNodePtr &node) {
117   MatchResultPtr res = std::make_shared<MatchResult>();
118   res->add_entry(shared_from_base<Any>(), node);
119   return res;
120 }
121 
match(const AnfNodePtr & node)122 MatchResultPtr Imm::match(const AnfNodePtr &node) {
123   if (!IsValueNode<Int32Imm>(node)) {
124     return nullptr;
125   }
126   // Check value
127   auto value_node = node->cast<ValueNodePtr>();
128   MS_EXCEPTION_IF_NULL(value_node);
129   auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
130   MS_EXCEPTION_IF_NULL(value_ptr);
131   if ((int32_t)value_ptr->value() == value_) {
132     MatchResultPtr res = std::make_shared<MatchResult>();
133     res->add_entry(shared_from_base<Imm>(), node);
134     return res;
135   }
136   return nullptr;
137 }
138 
get_node(const PatternPtr & pattern)139 AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
140   auto entry = match_result_.find(pattern);
141   if (entry == match_result_.end()) {
142     return nullptr;
143   }
144   return entry->second;
145 }
146 
merge(const MatchResultPtr & other_result)147 void MatchResult::merge(const MatchResultPtr &other_result) {
148   auto other_result_map = other_result->result();
149   // add/update entries in other_result
150   for (auto &iter : other_result_map) {
151     match_result_[iter.first] = iter.second;
152   }
153 }
154 
// Registers the Python bindings for the pattern classes on the module passed
// to the lambda as `m`. Every class uses std::shared_ptr as its holder type
// and, apart from Pattern itself, names Pattern as its pybind11 base class.
// Pattern is registered first, ahead of the classes that name it as a base.
155 REGISTER_PYBIND_DEFINE(
__anonaef73f340102(const py::module *m) 156   Pattern, ([](const py::module *m) {
    // Base class, default-constructible from Python.
157     (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
158     (void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
    // Prim_ is the only class registered with py::dynamic_attr().
159     (void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
160       .def(py::init<vector<py::object>, string>());
    // Call_ exposes two constructors: the first argument is either a PatternPtr
    // or a py::object; the second is a vector<PatternPtr> in both cases.
161     (void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
162       .def(py::init<PatternPtr, vector<PatternPtr>>())
163       .def(py::init<py::object, vector<PatternPtr>>());
164     (void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
165     (void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
166     (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
167       .def(py::init<tensor::TensorPtr>());
168     (void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
169       .def(py::init<string, tensor::TensorPtr, bool, bool>());
    // NOTE(review): the constructor takes int64_t while Imm::match in this file
    // only accepts Int32Imm value nodes -- confirm the intended width.
170     (void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int64_t>());
171   }));
172 }  // namespace python_pass
173 }  // namespace opt
174 }  // namespace mindspore
175