• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include <unordered_set>
20 #include <set>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include "pipeline/jit/parse/parse_dynamic.h"
25 #include "mindspore/core/ir/cell.h"
26 
27 namespace mindspore::parse {
28 static std::unordered_set<std::string> cell_input_args_ = {};
29 static const std::set<std::string> ignore_judge_dynamic_cell = {
30   "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
31   "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
32 static const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
33                                                                 parse::NAMED_PRIMITIVE_NAMECONSTANT,
34                                                                 parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
35 
ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node,parse::AstMainType type)36 std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
37                                          parse::AstMainType type) {
38   MS_EXCEPTION_IF_NULL(ast);
39   if (py::isinstance<py::none>(node)) {
40     MS_LOG(DEBUG) << "Get none type node!";
41     return "";
42   }
43   auto node_type = ast->GetNodeType(node);
44   MS_EXCEPTION_IF_NULL(node_type);
45   // Check node type
46   parse::AstMainType node_main_type = node_type->main_type();
47   if (node_main_type != type) {
48     MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
49     return "";
50   }
51   std::string node_name = node_type->node_name();
52   MS_LOG(DEBUG) << "Ast node is " << node_name;
53   return node_name;
54 }
55 
ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & fn_node)56 void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node) {
57   MS_EXCEPTION_IF_NULL(ast);
58   py::list args = ast->GetArgs(fn_node);
59   for (size_t i = 1; i < args.size(); i++) {
60     std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
61     MS_LOG(DEBUG) << "Input arg name: " << arg_name;
62     (void)cell_input_args_.emplace(arg_name);
63   }
64 }
65 
ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node)66 bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
67   MS_LOG(DEBUG) << "Parse if/while expr";
68   py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
69   const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
70   if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
71     py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
72     py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
73     if (comparators_node.empty()) {
74       MS_LOG(DEBUG) << "Get comparators node failed!";
75       return false;
76     }
77     auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
78     auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
79     // while self.a > self.b and changed self.a or self.b
80     if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
81       auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
82       std::string left_variable;
83       if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
84         left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
85       }
86       auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
87       std::string right_variable;
88       if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
89         right_variable =
90           py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
91       }
92       return ParseBodyContext(ast, node, {left_variable, right_variable});
93     }
94     // if a[0]
95     if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
96       py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
97       left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
98     }
99     MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
100     if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
101         unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
102       return true;
103     }
104   }
105   // if flag:
106   if (node_name == parse::NAMED_PRIMITIVE_NAME) {
107     std::string id = py::cast<std::string>(test_node.attr("id"));
108     if (cell_input_args_.find(id) != cell_input_args_.end()) {
109       return true;
110     }
111   }
112   return false;
113 }
114 
ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node)115 bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
116   MS_LOG(DEBUG) << "Parse assign expr";
117   py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
118   const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
119   if (node_name == parse::NAMED_PRIMITIVE_CALL) {
120     py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
121     const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
122     if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
123       py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
124       py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
125       if (py::isinstance<py::none>(value_in_slice_node)) {
126         MS_LOG(DEBUG) << "Parse value node is none!";
127         return false;
128       }
129       const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
130       std::string id;
131       if (py::hasattr(value_in_slice_node, "id")) {
132         id = py::cast<std::string>(value_in_slice_node.attr("id"));
133       }
134       if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() ||
135           (!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) {
136         return true;
137       }
138     }
139   }
140   return false;
141 }
142 
ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &,const py::object & node,const std::vector<std::string> & compare_prim)143 bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &, const py::object &node,
144                                            const std::vector<std::string> &compare_prim) {
145   MS_LOG(DEBUG) << "Parse augassign expr";
146   bool ret = false;
147   if (compare_prim.empty()) {
148     return ret;
149   }
150   py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
151   if (py::isinstance<py::none>(target_node)) {
152     MS_LOG(DEBUG) << "Parse target node is none!";
153     return ret;
154   }
155   py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
156   if (py::isinstance<py::none>(value_node)) {
157     MS_LOG(DEBUG) << "Parse value node is none!";
158     return ret;
159   }
160   std::string assign_prim;
161   if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
162     assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
163   }
164   auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
165   if (iter != compare_prim.end()) {
166     ret = true;
167   }
168   return ret;
169 }
170 
ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & node)171 bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
172   MS_LOG(DEBUG) << "Parse for expr";
173   py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
174   if (py::isinstance<py::none>(body_node)) {
175     MS_LOG(DEBUG) << "Parse body of for expression is none!";
176     return false;
177   }
178   py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
179   size_t count = LongToSize(pcount);
180   MS_LOG(DEBUG) << "The for nodes count in body is " << count;
181   for (size_t i = 0; i < count; ++i) {
182     auto it = py::cast<py::list>(body_node)[i];
183     const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
184     if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
185       return true;
186     }
187   }
188   return false;
189 }
190 
ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> & ast,const py::object & fn_node,const std::vector<std::string> & compare_prim)191 bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node,
192                                      const std::vector<std::string> &compare_prim) {
193   MS_EXCEPTION_IF_NULL(ast);
194   py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
195   if (py::isinstance<py::none>(func_obj)) {
196     MS_LOG(DEBUG) << "Parse body of cell is none!";
197     return false;
198   }
199   py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
200   size_t count = IntToSize(pcount);
201   MS_LOG(DEBUG) << "The nodes count in body is " << count;
202   bool ret = false;
203   for (size_t i = 0; i < count; ++i) {
204     auto node = py::cast<py::list>(func_obj)[i];
205     const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
206     if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
207       ret = ParseAssignExprNode(ast, node);
208     } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
209       ret = ParseAugAssignExprNode(ast, node, compare_prim);
210     } else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
211       ret = ParseForExprNode(ast, node);
212     } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
213       ret = ParseIfWhileExprNode(ast, node);
214     }
215     if (ret) {
216       MS_LOG(INFO) << "Current cell is dynamic!";
217       break;
218     }
219   }
220   return ret;
221 }
222 
GetCellInfo(const py::object & cell)223 std::string DynamicParser::GetCellInfo(const py::object &cell) {
224   if (py::isinstance<Cell>(cell)) {
225     auto c_cell = py::cast<CellPtr>(cell);
226     MS_EXCEPTION_IF_NULL(c_cell);
227     auto cell_info = c_cell->ToString();
228     return cell_info;
229   }
230   return "";
231 }
232 
IsDynamicCell(const py::object & cell)233 bool DynamicParser::IsDynamicCell(const py::object &cell) {
234   std::string cell_info = GetCellInfo(cell);
235   if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
236     return false;
237   }
238   // Using ast parse to check whether the construct of cell will be changed
239   auto ast = std::make_shared<parse::ParseFunctionAst>(cell);
240   bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
241   if (!success) {
242     MS_LOG(ERROR) << "Parse code to ast tree failed";
243     return false;
244   }
245   py::object fn_node = ast->GetAstNode();
246   // get the name of input args as the initialize of dynamic_variables
247   ParseInputArgs(ast, fn_node);
248   // parse body context
249   bool ret = false;
250   ret = ParseBodyContext(ast, fn_node);
251   cell_input_args_.clear();
252   return ret;
253 }
254 }  // namespace mindspore::parse
255