/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include "pipeline/jit/parse/parse_dynamic.h" #include "mindspore/core/ir/cell.h" namespace mindspore::parse { static std::unordered_set cell_input_args_ = {}; static const std::set ignore_judge_dynamic_cell = { "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"}; static const std::set unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, parse::NAMED_PRIMITIVE_NAMECONSTANT, parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; std::string DynamicParser::ParseNodeName(const std::shared_ptr &ast, const py::object &node, parse::AstMainType type) { MS_EXCEPTION_IF_NULL(ast); if (py::isinstance(node)) { MS_LOG(DEBUG) << "Get none type node!"; return ""; } auto node_type = ast->GetNodeType(node); MS_EXCEPTION_IF_NULL(node_type); // Check node type parse::AstMainType node_main_type = node_type->main_type(); if (node_main_type != type) { MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type; return ""; } std::string node_name = node_type->node_name(); MS_LOG(DEBUG) << "Ast node is " << node_name; return node_name; } void DynamicParser::ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node) { MS_EXCEPTION_IF_NULL(ast); py::list args = ast->GetArgs(fn_node); for (size_t i = 1; i < args.size(); i++) { std::string arg_name = py::cast(args[i].attr("arg")); MS_LOG(DEBUG) << "Input arg name: " << arg_name; (void)cell_input_args_.emplace(arg_name); } } bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse if/while expr"; py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR); if (node_name == parse::NAMED_PRIMITIVE_COMPARE) { py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT); py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS); if (comparators_node.empty()) { MS_LOG(DEBUG) << "Get comparators node failed!"; return false; } auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR); auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR); // while self.a > self.b and changed self.a or self.b if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) { auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); std::string left_variable; if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) { left_variable = py::cast(left_value.attr("id")) + py::cast(left_node.attr("attr")); } auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE); std::string right_variable; if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) { right_variable = py::cast(right_value.attr("id")) + py::cast(comparators_node[0].attr("attr")); } return ParseBodyContext(ast, node, {left_variable, right_variable}); } // if a[0] if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) { py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR); } MS_LOG(DEBUG) << "Left is " << left << " Right is " << right; if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() || unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) { return true; } } // if flag: if (node_name == parse::NAMED_PRIMITIVE_NAME) { std::string id = py::cast(test_node.attr("id")); if (cell_input_args_.find(id) != cell_input_args_.end()) { return true; } } return false; } bool DynamicParser::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse assign expr"; py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE); const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR); if (node_name == parse::NAMED_PRIMITIVE_CALL) { py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC); const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR); if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) { py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE); py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE); if (py::isinstance(value_in_slice_node)) { MS_LOG(DEBUG) << "Parse value node is none!"; return false; } const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR); std::string id; if (py::hasattr(value_in_slice_node, "id")) { id = py::cast(value_in_slice_node.attr("id")); } if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() || (!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) { return true; } } } return false; } bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr &, const py::object &node, const std::vector &compare_prim) { MS_LOG(DEBUG) << "Parse augassign expr"; bool ret = false; if (compare_prim.empty()) { return ret; } py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET); if (py::isinstance(target_node)) { MS_LOG(DEBUG) << "Parse target node is none!"; return ret; } py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE); if (py::isinstance(value_node)) { MS_LOG(DEBUG) << "Parse value node is none!"; return ret; } std::string assign_prim; if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) { assign_prim = py::cast(value_node.attr("id")) + py::cast(target_node.attr("attr")); } auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim); if (iter != compare_prim.end()) { ret = true; } return ret; } bool DynamicParser::ParseForExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse for expr"; py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); if (py::isinstance(body_node)) { MS_LOG(DEBUG) << "Parse body of for expression is none!"; return false; } py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); size_t count = LongToSize(pcount); MS_LOG(DEBUG) << "The for nodes count in body is " << count; for (size_t i = 0; i < count; ++i) { auto it = py::cast(body_node)[i]; const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) { return true; } } return false; } bool DynamicParser::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, const std::vector &compare_prim) { MS_EXCEPTION_IF_NULL(ast); py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY); if (py::isinstance(func_obj)) { MS_LOG(DEBUG) << "Parse body of cell is none!"; return false; } py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); size_t count = IntToSize(pcount); MS_LOG(DEBUG) << "The nodes count in body is " << count; bool ret = false; for (size_t i = 0; i < count; ++i) { auto node = py::cast(func_obj)[i]; const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT); if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) { ret = ParseAssignExprNode(ast, node); } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) { ret = ParseAugAssignExprNode(ast, node, compare_prim); } else if (node_name == parse::NAMED_PRIMITIVE_FOR) { ret = ParseForExprNode(ast, node); } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) { ret = ParseIfWhileExprNode(ast, node); } if (ret) { MS_LOG(INFO) << "Current cell is dynamic!"; break; } } return ret; } std::string DynamicParser::GetCellInfo(const py::object &cell) { if (py::isinstance(cell)) { auto c_cell = py::cast(cell); MS_EXCEPTION_IF_NULL(c_cell); auto cell_info = c_cell->ToString(); return cell_info; } return ""; } bool DynamicParser::IsDynamicCell(const py::object &cell) { std::string cell_info = GetCellInfo(cell); if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) { return false; } // Using ast parse to check whether the construct of cell will be changed auto ast = std::make_shared(cell); bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); if (!success) { MS_LOG(ERROR) << "Parse code to ast tree failed"; return false; } py::object fn_node = ast->GetAstNode(); // get the name of input args as the initialize of dynamic_variables ParseInputArgs(ast, fn_node); // parse body context bool ret = false; ret = ParseBodyContext(ast, fn_node); cell_input_args_.clear(); return ret; } } // namespace mindspore::parse