/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_
21 
22 #include <vector>
23 #include <string>
24 #include <map>
25 #include <set>
26 #include <memory>
27 #include <utility>
28 #include <tuple>
29 
30 #include "utils/hash_map.h"
31 #include "ir/meta_func_graph.h"
32 #include "pipeline/jit/ps/parse/parse_base.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "utils/log_adapter.h"
35 #include "utils/ordered_set.h"
36 
37 namespace mindspore {
38 namespace parse {
39 class Parser;
40 class NameSpace;
41 class Symbol;
42 class Script;
43 class FunctionBlock;
44 using FunctionBlockPtr = std::shared_ptr<FunctionBlock>;
45 
46 // A function block is a straight-line code sequence with no branches, every block has one one exit point
47 // which is return. When parsing function, loop or branch , we use function block to track the structure of
48 // the original source code.
49 class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
50  public:
51   explicit FunctionBlock(const Parser &parser);
52   virtual ~FunctionBlock() = default;
53 
func_graph()54   FuncGraphPtr func_graph() { return func_graph_; }
ToString()55   std::string ToString() const { return func_graph_->ToString(); }
56   void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
57   AnfNodePtr ReadVariable(const std::string &var_name);
58   AnfNodePtr ReadLocalVariable(const std::string &var_name);
59   bool CheckHasVariable(const std::string &var_name);
60   void AddPrevBlock(const FunctionBlockPtr &block);
61   void SetPhiArgument(const ParameterPtr &phi);
62   void CollectRemovablePhi(const ParameterPtr &phi);
63   // A block is matured if all its predecessors is generated
64   void Mature();
65   CNodePtr ForceToCondNode(const AnfNodePtr &cond, bool is_while_cond = false);
66   void Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args);
67   std::set<AnfNodePtr> SearchAllArgsOfPhiNode(const std::string &var, const ParameterPtr &phi);
68   CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
69                            const AnfNodePtr &false_block_call);
70   CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
71                            const FunctionBlockPtr &false_block);
72   // Create cnode for the assign statement like self.target = source.
73   void SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source);
AddGlobalVar(const std::string & var_name)74   void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); }
IsGlobalVar(const std::string & var_name)75   bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); }
76 
77   py::tuple GetAstOpNameSpace(const py::object &op);
78   AnfNodePtr MakeResolveAstOpNameSpace(const py::tuple &namespace_var);
79   AnfNodePtr MakeResolveClassObject();
80   AnfNodePtr MakeResolveClassMember(const std::string &attr_or_self);
81   AnfNodePtr MakeResolveSymbol(const std::string &value);
82   AnfNodePtr MakeResolveOperation(const std::string &value);
83   AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
84   AnfNodePtr DoResolve(const AnfNodePtr &node, const std::shared_ptr<NameSpace> &name_space,
85                        const std::shared_ptr<Symbol> &resolve_symbol);
86   AnfNodePtr HandleNamespaceSymbol(const std::string &var_name);
87   AnfNodePtr MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
88                            const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node);
phi_args()89   const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_args() const { return phi_args_; }
90   void FindIsolatedNodes();
91   void ConvertUnusedNodesToIsolated(const std::pair<std::string, std::pair<AnfNodePtr, bool>> var);
92   void AddIsolatedNode(const AnfNodePtr &target);
93   void AttachIsolatedNodesBeforeReturn();
prev_blocks()94   const std::vector<FunctionBlock *> &prev_blocks() const { return prev_blocks_; }
is_dead_block()95   bool is_dead_block() const { return is_dead_block_; }
96   void SetAsDeadBlock();
97   CNodePtr GetJumpNode(FunctionBlock *target_block);
98 
is_return_statement_inside()99   bool is_return_statement_inside() const { return is_return_statement_inside_; }
100   void set_is_return_statement_inside();
is_break_continue_statement_inside()101   bool is_break_continue_statement_inside() const { return is_break_continue_statement_inside_; }
102   void set_break_continue_statement_inside();
103 
block_name()104   const std::string block_name() const { return block_name_; }
set_block_name(const std::string & block_name)105   void set_block_name(const std::string &block_name) { block_name_ = block_name; }
106   void CheckUndefinedSymbol(const std::string &var, const AnfNodePtr &node) const;
107   void CheckVariableNotDefined(const std::pair<std::string, AnfNodePtr> &not_defined_branch, const std::string &var);
global_py_params()108   const py::dict &global_py_params() const { return global_py_params_; }
set_global_py_params(const py::dict & symbols)109   void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
HasGlobalPyParam(const std::string & name)110   bool HasGlobalPyParam(const std::string &name) const { return global_py_params_.contains(py::str(name)); }
AddGlobalPyParam(const std::string & name,const py::object & obj)111   void AddGlobalPyParam(const std::string &name, const py::object &obj) {
112     MS_LOG(DEBUG) << "Add global param '" << name << "', " << py::str(obj) << " for the block:" << ToString();
113     global_py_params_[py::str(name)] = obj;
114   }
UpdateGlobalPyParam(const py::dict & symbols)115   void UpdateGlobalPyParam(const py::dict &symbols) {
116     for (auto &param : symbols) {
117       if (!global_py_params_.contains(param.first)) {
118         MS_LOG(DEBUG) << "Update global param '" << param.first << "', " << py::str(param.second)
119                       << " for the block:" << ToString();
120         global_py_params_[param.first] = param.second;
121       }
122     }
123   }
124 
local_py_params()125   std::tuple<std::map<std::string, AnfNodePtr>, std::map<std::string, AnfNodePtr>> local_py_params() {
126     return {local_py_params_keys_, local_py_params_values_};
127   }
128 
129   // Call this method to update or add a variable.
UpdateLocalPyParam(const std::string & name,const AnfNodePtr & node)130   void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
131     MS_EXCEPTION_IF_NULL(node);
132     const auto key_iter = local_py_params_keys_.find(name);
133     if (key_iter == local_py_params_keys_.end()) {
134       MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString();
135       (void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name)));
136       (void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(name, node));
137     } else {
138       // Find the same position in 'values', and update the node.
139       MS_LOG(DEBUG) << "Update '" << name << "', " << local_py_params_values_[name]->DebugString() << " -> "
140                     << node->DebugString();
141       local_py_params_values_[name] = node;
142     }
143   }
144 
145   // Update local parameters from previous block.
UpdateLocalPyParam(const std::map<std::string,AnfNodePtr> & keys,std::map<std::string,AnfNodePtr> values)146   void UpdateLocalPyParam(const std::map<std::string, AnfNodePtr> &keys, std::map<std::string, AnfNodePtr> values) {
147     if (keys.size() != values.size()) {
148       MS_LOG(INTERNAL_EXCEPTION) << "keys size should be equal to values size.";
149     }
150     for (auto iter = keys.begin(); iter != keys.end(); ++iter) {
151       const std::string &cur_key_name = iter->first;
152       const auto key_iter = local_py_params_keys_.find(cur_key_name);
153       if (key_iter == local_py_params_keys_.end()) {
154         (void)local_py_params_keys_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second));
155         (void)local_py_params_values_.emplace(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name]));
156         MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
157       } else {
158         // The local variable is already in the current block. This means the current block has multiples previous
159         // blocks. If this local variable is used in the current block, it should be converted to phi node. So we erase
160         // it from local_py_params.
161         (void)local_py_params_keys_.erase(key_iter);
162         (void)local_py_params_values_.erase(cur_key_name);
163         MS_LOG(DEBUG) << "Erase '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString();
164       }
165     }
166     if (local_py_params_keys_.size() != local_py_params_values_.size()) {
167       MS_LOG(INTERNAL_EXCEPTION)
168         << "The size of local_py_params_keys_ should be equal to local_py_params_values_ size.";
169     }
170   }
171 
172   // Isolated nodes.
isolated_nodes()173   const OrderedSet<AnfNodePtr> isolated_nodes() const { return isolated_nodes_; }
174 
175  private:
176   // Block graph
177   FuncGraphPtr func_graph_;
178 
179   // Block parser
180   const Parser &parser_;
181 
182   // A block is matured if all its prev_blocks is processed
183   bool matured_;
184 
185   // Store the nest-level block.
186   // Refer to comments in Parser::func_block_list_;
187   std::vector<FunctionBlock *> prev_blocks_;
188 
189   // Store args and variable's node, use a bool flag to indicate if the variable is used.
190   std::map<std::string, std::pair<AnfNodePtr, bool>> assigned_vars_;
191 
192   // Store the attribute that has been changed.
193   std::map<std::string, std::pair<AnfNodePtr, bool>> changed_non_param_attrs_;
194 
195   // Map the parameter node to variable, it can be resolved if the block's predecessors are processed
196   std::map<ParameterPtr, std::string> phi_nodes_;
197 
198   // Jumps map the successor block and the function call that perform jump
199   // Refer to comments in Parser::func_block_list_ that how to break the cyclic reference
200   std::map<FunctionBlock *, CNodePtr> jumps_;
201 
202   // Keep all removable phis which will be removed in one pass.
203   std::map<ParameterPtr, std::set<AnfNodePtr>> phi_args_;
204 
205   // Hold declared global variables in function
206   std::set<std::string> global_vars_;
207 
208   // Keep new made resolve symbol for the variable not found in vars_.
209   mindspore::HashMap<std::string, AnfNodePtr> var_to_resolve_;
210 
211   // Collect all python symbols in the block.
212   // We treat both global symbols and local symbols declared previously as global symbols.
213   py::dict global_py_params_;
214   std::map<std::string, AnfNodePtr> local_py_params_keys_;
215   std::map<std::string, AnfNodePtr> local_py_params_values_;
216 
217   // Isolated nodes.
218   OrderedSet<AnfNodePtr> isolated_nodes_;
219 
220   // If a block can never be executed, it's prev blocks will be empty, so this block is a dead block.
221   // while x > 5:
222   //    x = x - 2
223   //    if x > 7 :
224   //         break
225   //    else :
226   //         break
227   //    x = x - 1   #This after block is a dead block
228   bool is_dead_block_{false};
229 
230   std::pair<AnfNodePtr, bool> FindPredInterpretNode(const std::string &var_name);
231   // Flags help for determine if parallel-if transformation can be performed or not.
232   // If inside this block include all inner block there is a return statement.
233   // This flag will propagate beyond outer if/else or while/for loop, but not if-by-if;
234   bool is_return_statement_inside_{false};
235   // If inside this block there is a break/continue statement.
236   // This flag will propagate beyond outer if/else but not while/for loop, if-by-if;
237   bool is_break_continue_statement_inside_{false};
238   // Set block name for control flow block.
239   std::string block_name_{""};
240 };
241 }  // namespace parse
242 }  // namespace mindspore
243 
244 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_FUNCTION_BLOCK_H_
245