1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>
#include "pipeline/jit/parse/parse_dynamic.h"
#include "mindspore/core/ir/cell.h"
26
27 namespace mindspore::parse {
28 static std::unordered_set<std::string> cell_input_args_ = {};
29 static const std::set<std::string> ignore_judge_dynamic_cell = {
30 "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
31 "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
32 static const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
33 parse::NAMED_PRIMITIVE_NAMECONSTANT,
34 parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
35
ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node,parse::AstMainType type)36 std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
37 parse::AstMainType type) {
38 MS_EXCEPTION_IF_NULL(ast);
39 if (py::isinstance<py::none>(node)) {
40 MS_LOG(DEBUG) << "Get none type node!";
41 return "";
42 }
43 auto node_type = ast->GetNodeType(node);
44 MS_EXCEPTION_IF_NULL(node_type);
45 // Check node type
46 parse::AstMainType node_main_type = node_type->main_type();
47 if (node_main_type != type) {
48 MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
49 return "";
50 }
51 std::string node_name = node_type->node_name();
52 MS_LOG(DEBUG) << "Ast node is " << node_name;
53 return node_name;
54 }
55
ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & fn_node)56 void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node) {
57 MS_EXCEPTION_IF_NULL(ast);
58 py::list args = ast->GetArgs(fn_node);
59 for (size_t i = 1; i < args.size(); i++) {
60 std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
61 MS_LOG(DEBUG) << "Input arg name: " << arg_name;
62 (void)cell_input_args_.emplace(arg_name);
63 }
64 }
65
ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node)66 bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
67 MS_LOG(DEBUG) << "Parse if/while expr";
68 py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
69 const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
70 if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
71 py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
72 py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
73 if (comparators_node.empty()) {
74 MS_LOG(DEBUG) << "Get comparators node failed!";
75 return false;
76 }
77 auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
78 auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
79 // while self.a > self.b and changed self.a or self.b
80 if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
81 auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
82 std::string left_variable;
83 if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
84 left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
85 }
86 auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
87 std::string right_variable;
88 if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
89 right_variable =
90 py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
91 }
92 return ParseBodyContext(ast, node, {left_variable, right_variable});
93 }
94 // if a[0]
95 if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
96 py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
97 left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
98 }
99 MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
100 if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
101 unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
102 return true;
103 }
104 }
105 // if flag:
106 if (node_name == parse::NAMED_PRIMITIVE_NAME) {
107 std::string id = py::cast<std::string>(test_node.attr("id"));
108 if (cell_input_args_.find(id) != cell_input_args_.end()) {
109 return true;
110 }
111 }
112 return false;
113 }
114
ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node)115 bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
116 MS_LOG(DEBUG) << "Parse assign expr";
117 py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
118 const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
119 if (node_name == parse::NAMED_PRIMITIVE_CALL) {
120 py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
121 const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
122 if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
123 py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
124 py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
125 if (py::isinstance<py::none>(value_in_slice_node)) {
126 MS_LOG(DEBUG) << "Parse value node is none!";
127 return false;
128 }
129 const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
130 std::string id;
131 if (py::hasattr(value_in_slice_node, "id")) {
132 id = py::cast<std::string>(value_in_slice_node.attr("id"));
133 }
134 if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() ||
135 (!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) {
136 return true;
137 }
138 }
139 }
140 return false;
141 }
142
ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &,const py::object & node,const std::vector<std::string> & compare_prim)143 bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &, const py::object &node,
144 const std::vector<std::string> &compare_prim) {
145 MS_LOG(DEBUG) << "Parse augassign expr";
146 bool ret = false;
147 if (compare_prim.empty()) {
148 return ret;
149 }
150 py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
151 if (py::isinstance<py::none>(target_node)) {
152 MS_LOG(DEBUG) << "Parse target node is none!";
153 return ret;
154 }
155 py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
156 if (py::isinstance<py::none>(value_node)) {
157 MS_LOG(DEBUG) << "Parse value node is none!";
158 return ret;
159 }
160 std::string assign_prim;
161 if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
162 assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
163 }
164 auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
165 if (iter != compare_prim.end()) {
166 ret = true;
167 }
168 return ret;
169 }
170
ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node)171 bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
172 MS_LOG(DEBUG) << "Parse for expr";
173 py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
174 if (py::isinstance<py::none>(body_node)) {
175 MS_LOG(DEBUG) << "Parse body of for expression is none!";
176 return false;
177 }
178 py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
179 size_t count = LongToSize(pcount);
180 MS_LOG(DEBUG) << "The for nodes count in body is " << count;
181 for (size_t i = 0; i < count; ++i) {
182 auto it = py::cast<py::list>(body_node)[i];
183 const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
184 if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
185 return true;
186 }
187 }
188 return false;
189 }
190
ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & fn_node,const std::vector<std::string> & compare_prim)191 bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node,
192 const std::vector<std::string> &compare_prim) {
193 MS_EXCEPTION_IF_NULL(ast);
194 py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
195 if (py::isinstance<py::none>(func_obj)) {
196 MS_LOG(DEBUG) << "Parse body of cell is none!";
197 return false;
198 }
199 py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
200 size_t count = IntToSize(pcount);
201 MS_LOG(DEBUG) << "The nodes count in body is " << count;
202 bool ret = false;
203 for (size_t i = 0; i < count; ++i) {
204 auto node = py::cast<py::list>(func_obj)[i];
205 const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
206 if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
207 ret = ParseAssignExprNode(ast, node);
208 } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
209 ret = ParseAugAssignExprNode(ast, node, compare_prim);
210 } else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
211 ret = ParseForExprNode(ast, node);
212 } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
213 ret = ParseIfWhileExprNode(ast, node);
214 }
215 if (ret) {
216 MS_LOG(INFO) << "Current cell is dynamic!";
217 break;
218 }
219 }
220 return ret;
221 }
222
GetCellInfo(const py::object & cell)223 std::string DynamicParser::GetCellInfo(const py::object &cell) {
224 if (py::isinstance<Cell>(cell)) {
225 auto c_cell = py::cast<CellPtr>(cell);
226 MS_EXCEPTION_IF_NULL(c_cell);
227 auto cell_info = c_cell->ToString();
228 return cell_info;
229 }
230 return "";
231 }
232
IsDynamicCell(const py::object & cell)233 bool DynamicParser::IsDynamicCell(const py::object &cell) {
234 std::string cell_info = GetCellInfo(cell);
235 if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
236 return false;
237 }
238 // Using ast parse to check whether the construct of cell will be changed
239 auto ast = std::make_shared<parse::ParseFunctionAst>(cell);
240 bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
241 if (!success) {
242 MS_LOG(ERROR) << "Parse code to ast tree failed";
243 return false;
244 }
245 py::object fn_node = ast->GetAstNode();
246 // get the name of input args as the initialize of dynamic_variables
247 ParseInputArgs(ast, fn_node);
248 // parse body context
249 bool ret = false;
250 ret = ParseBodyContext(ast, fn_node);
251 cell_input_args_.clear();
252 return ret;
253 }
254 } // namespace mindspore::parse
255