/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_

#include <vector>
#include <string>
#include <map>
#include <set>
#include <memory>
#include <utility>
#include <tuple>

#include "utils/hash_map.h"
#include "ir/meta_func_graph.h"
#include "pipeline/jit/ps/parse/parse_base.h"
#include "pipeline/jit/ps/parse/resolve.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"

namespace mindspore {
namespace parse {
// Forward declarations of types that FunctionBlock names only through references or smart pointers.
// NOTE(review): Script does not appear to be referenced in this header; confirm whether it is still needed.
class Parser;
class NameSpace;
class Symbol;
class Script;
class FunctionBlock;
// Shared-ownership handle for a FunctionBlock.
using FunctionBlockPtr = std::shared_ptr<FunctionBlock>;

// A function block is a straight-line code sequence with no branches; every block has exactly one exit point,
// which is a return. When parsing a function, loop or branch, we use function blocks to track the structure of
// the original source code.
49 class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { 50 public: 51 explicit FunctionBlock(const Parser &parser); 52 virtual ~FunctionBlock() = default; 53 func_graph()54 FuncGraphPtr func_graph() { return func_graph_; } ToString()55 std::string ToString() const { return func_graph_->ToString(); } 56 void WriteVariable(const std::string &var_name, const AnfNodePtr &node); 57 AnfNodePtr ReadVariable(const std::string &var_name); 58 AnfNodePtr ReadLocalVariable(const std::string &var_name); 59 bool CheckHasVariable(const std::string &var_name); 60 void AddPrevBlock(const FunctionBlockPtr &block); 61 void SetPhiArgument(const ParameterPtr &phi); 62 void CollectRemovablePhi(const ParameterPtr &phi); 63 // A block is matured if all its predecessors is generated 64 void Mature(); 65 CNodePtr ForceToCondNode(const AnfNodePtr &cond, bool is_while_cond = false); 66 void Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args); 67 std::set<AnfNodePtr> SearchAllArgsOfPhiNode(const std::string &var, const ParameterPtr &phi); 68 CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call, 69 const AnfNodePtr &false_block_call); 70 CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block, 71 const FunctionBlockPtr &false_block); 72 // Create cnode for the assign statement like self.target = source. 
73 void SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source); AddGlobalVar(const std::string & var_name)74 void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } IsGlobalVar(const std::string & var_name)75 bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } 76 77 py::tuple GetAstOpNameSpace(const py::object &op); 78 AnfNodePtr MakeResolveAstOpNameSpace(const py::tuple &namespace_var); 79 AnfNodePtr MakeResolveClassObject(); 80 AnfNodePtr MakeResolveClassMember(const std::string &attr_or_self); 81 AnfNodePtr MakeResolveSymbol(const std::string &value); 82 AnfNodePtr MakeResolveOperation(const std::string &value); 83 AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol); 84 AnfNodePtr DoResolve(const AnfNodePtr &node, const std::shared_ptr<NameSpace> &name_space, 85 const std::shared_ptr<Symbol> &resolve_symbol); 86 AnfNodePtr HandleNamespaceSymbol(const std::string &var_name); 87 AnfNodePtr MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node, 88 const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node); phi_args()89 const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_args() const { return phi_args_; } 90 void FindIsolatedNodes(); 91 void ConvertUnusedNodesToIsolated(const std::pair<std::string, std::pair<AnfNodePtr, bool>> var); 92 void AddIsolatedNode(const AnfNodePtr &target); 93 void AttachIsolatedNodesBeforeReturn(); prev_blocks()94 const std::vector<FunctionBlock *> &prev_blocks() const { return prev_blocks_; } is_dead_block()95 bool is_dead_block() const { return is_dead_block_; } 96 void SetAsDeadBlock(); 97 CNodePtr GetJumpNode(FunctionBlock *target_block); 98 is_return_statement_inside()99 bool is_return_statement_inside() const { return is_return_statement_inside_; } 100 void set_is_return_statement_inside(); is_break_continue_statement_inside()101 bool 
is_break_continue_statement_inside() const { return is_break_continue_statement_inside_; } 102 void set_break_continue_statement_inside(); 103 block_name()104 const std::string block_name() const { return block_name_; } set_block_name(const std::string & block_name)105 void set_block_name(const std::string &block_name) { block_name_ = block_name; } 106 void CheckUndefinedSymbol(const std::string &var, const AnfNodePtr &node) const; 107 void CheckVariableNotDefined(const std::pair<std::string, AnfNodePtr> ¬_defined_branch, const std::string &var); global_py_params()108 const py::dict &global_py_params() const { return global_py_params_; } set_global_py_params(const py::dict & symbols)109 void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; } HasGlobalPyParam(const std::string & name)110 bool HasGlobalPyParam(const std::string &name) const { return global_py_params_.contains(py::str(name)); } AddGlobalPyParam(const std::string & name,const py::object & obj)111 void AddGlobalPyParam(const std::string &name, const py::object &obj) { 112 MS_LOG(DEBUG) << "Add global param '" << name << "', " << py::str(obj) << " for the block:" << ToString(); 113 global_py_params_[py::str(name)] = obj; 114 } UpdateGlobalPyParam(const py::dict & symbols)115 void UpdateGlobalPyParam(const py::dict &symbols) { 116 for (auto ¶m : symbols) { 117 if (!global_py_params_.contains(param.first)) { 118 MS_LOG(DEBUG) << "Update global param '" << param.first << "', " << py::str(param.second) 119 << " for the block:" << ToString(); 120 global_py_params_[param.first] = param.second; 121 } 122 } 123 } 124 local_py_params()125 std::tuple<std::map<std::string, AnfNodePtr>, std::map<std::string, AnfNodePtr>> local_py_params() { 126 return {local_py_params_keys_, local_py_params_values_}; 127 } 128 129 // Call this method to update or add a variable. 
UpdateLocalPyParam(const std::string & name,const AnfNodePtr & node)130 void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) { 131 MS_EXCEPTION_IF_NULL(node); 132 const auto key_iter = local_py_params_keys_.find(name); 133 if (key_iter == local_py_params_keys_.end()) { 134 MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString(); 135 (void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name))); 136 (void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(name, node)); 137 } else { 138 // Find the same position in 'values', and update the node. 139 MS_LOG(DEBUG) << "Update '" << name << "', " << local_py_params_values_[name]->DebugString() << " -> " 140 << node->DebugString(); 141 local_py_params_values_[name] = node; 142 } 143 } 144 145 // Update local parameters from previous block. UpdateLocalPyParam(const std::map<std::string,AnfNodePtr> & keys,std::map<std::string,AnfNodePtr> values)146 void UpdateLocalPyParam(const std::map<std::string, AnfNodePtr> &keys, std::map<std::string, AnfNodePtr> values) { 147 if (keys.size() != values.size()) { 148 MS_LOG(INTERNAL_EXCEPTION) << "keys size should be equal to values size."; 149 } 150 for (auto iter = keys.begin(); iter != keys.end(); ++iter) { 151 const std::string &cur_key_name = iter->first; 152 const auto key_iter = local_py_params_keys_.find(cur_key_name); 153 if (key_iter == local_py_params_keys_.end()) { 154 (void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second)); 155 (void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name])); 156 MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString(); 157 } else { 158 // The local variable is already in the current block. This means the current block has multiples previous 159 // blocks. 
If this local variable is used in the current block, it should be converted to phi node. So we erase 160 // it from local_py_params. 161 (void)local_py_params_keys_.erase(key_iter); 162 (void)local_py_params_values_.erase(cur_key_name); 163 MS_LOG(DEBUG) << "Erase '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString(); 164 } 165 } 166 if (local_py_params_keys_.size() != local_py_params_values_.size()) { 167 MS_LOG(INTERNAL_EXCEPTION) 168 << "The size of local_py_params_keys_ should be equal to local_py_params_values_ size."; 169 } 170 } 171 172 // Isolated nodes. isolated_nodes()173 const OrderedSet<AnfNodePtr> isolated_nodes() const { return isolated_nodes_; } 174 175 private: 176 // Block graph 177 FuncGraphPtr func_graph_; 178 179 // Block parser 180 const Parser &parser_; 181 182 // A block is matured if all its prev_blocks is processed 183 bool matured_; 184 185 // Store the nest-level block. 186 // Refer to comments in Parser::func_block_list_; 187 std::vector<FunctionBlock *> prev_blocks_; 188 189 // Store args and variable's node, use a bool flag to indicate if the variable is used. 190 std::map<std::string, std::pair<AnfNodePtr, bool>> assigned_vars_; 191 192 // Store the attribute that has been changed. 193 std::map<std::string, std::pair<AnfNodePtr, bool>> changed_non_param_attrs_; 194 195 // Map the parameter node to variable, it can be resolved if the block's predecessors are processed 196 std::map<ParameterPtr, std::string> phi_nodes_; 197 198 // Jumps map the successor block and the function call that perform jump 199 // Refer to comments in Parser::func_block_list_ that how to break the cyclic reference 200 std::map<FunctionBlock *, CNodePtr> jumps_; 201 202 // Keep all removable phis which will be removed in one pass. 
203 std::map<ParameterPtr, std::set<AnfNodePtr>> phi_args_; 204 205 // Hold declared global variables in function 206 std::set<std::string> global_vars_; 207 208 // Keep new made resolve symbol for the variable not found in vars_. 209 mindspore::HashMap<std::string, AnfNodePtr> var_to_resolve_; 210 211 // Collect all python symbols in the block. 212 // We treat both global symbols and local symbols declared previously as global symbols. 213 py::dict global_py_params_; 214 std::map<std::string, AnfNodePtr> local_py_params_keys_; 215 std::map<std::string, AnfNodePtr> local_py_params_values_; 216 217 // Isolated nodes. 218 OrderedSet<AnfNodePtr> isolated_nodes_; 219 220 // If a block can never be executed, it's prev blocks will be empty, so this block is a dead block. 221 // while x > 5: 222 // x = x - 2 223 // if x > 7 : 224 // break 225 // else : 226 // break 227 // x = x - 1 #This after block is a dead block 228 bool is_dead_block_{false}; 229 230 std::pair<AnfNodePtr, bool> FindPredInterpretNode(const std::string &var_name); 231 // Flags help for determine if parallel-if transformation can be performed or not. 232 // If inside this block include all inner block there is a return statement. 233 // This flag will propagate beyond outer if/else or while/for loop, but not if-by-if; 234 bool is_return_statement_inside_{false}; 235 // If inside this block there is a break/continue statement. 236 // This flag will propagate beyond outer if/else but not while/for loop, if-by-if; 237 bool is_break_continue_statement_inside_{false}; 238 // Set block name for control flow block. 239 std::string block_name_{""}; 240 }; 241 } // namespace parse 242 } // namespace mindspore 243 244 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_ 245