• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/graph_util/generate_graph.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 
23 #include "pipeline/jit/parse/python_adapter.h"
24 #include "utils/convert_utils_py.h"
25 #include "frontend/parallel/graph_util/node_info.h"
26 
27 using mindspore::tensor::Tensor;
28 
29 namespace mindspore {
30 namespace parallel {
GetOpPythonPath(const OperatorName & op_name)31 std::string GetOpPythonPath(const OperatorName &op_name) {
32   // almost all ops are defined in two main paths
33   const std::string ops_module = OP_PATH;
34   const std::string inner_ops_module = INNER_OP_PATH;
35   const std::string functional_op_module = FUNCTIONAL_OP_PATH;
36   py::module mod = py::module::import(common::SafeCStr(ops_module));
37   py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
38   py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
39 
40   if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
41     return inner_ops_module;
42   }
43   if (py::hasattr(mod, common::SafeCStr(op_name))) {
44     return ops_module;
45   }
46   if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
47     MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module
48                       << " don't have op:" << op_name;
49   }
50   return functional_op_module;
51 }
52 
CreatOpInstance(const OperatorAttrs & attrs,const OperatorName & op_name,const std::string & instance_name)53 ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
54   std::string op_path = GetOpPythonPath(op_name);
55   py::module mod = py::module::import(common::SafeCStr(op_path));
56   if (!py::hasattr(mod, common::SafeCStr(op_name))) {
57     MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
58     return nullptr;
59   }
60   std::vector<py::object> arg_list;
61   (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
62                        [](const Attr &attr) { return ValueToPyData(attr.second); });
63   py::object obj =
64     parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
65   ValuePtr op_instance = nullptr;
66   bool succ = parse::ConvertData(obj, &op_instance);
67   if (!succ) {
68     MS_LOG(ERROR) << "Failure:get Python op " << op_path << " from " << op_name << " fail";
69     return nullptr;
70   }
71   return op_instance;
72 }
73 
ValuePtrToAnfNodePtr(const ValuePtr & value_ptr)74 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
75   auto value_node = NewValueNode(value_ptr);
76   MS_EXCEPTION_IF_NULL(value_node);
77   return value_node->cast<AnfNodePtr>();
78 }
79 
80 static std::unordered_map<int64_t, AnfNodePtr> int_tensor_map = {};
CreateInt32Tensor(int64_t value)81 AnfNodePtr CreateInt32Tensor(int64_t value) {
82   auto it = int_tensor_map.find(value);
83   if (it != int_tensor_map.end()) {
84     return it->second;
85   }
86   mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(value, kInt32);
87   ValuePtr value_ptr = MakeValue(tensor_ptr);
88   auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
89   int_tensor_map[value] = anf_node_ptr;
90   return anf_node_ptr;
91 }
92 
CreatTypeInt(int64_t value)93 AnfNodePtr CreatTypeInt(int64_t value) {
94   ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
95   return ValuePtrToAnfNodePtr(value_ptr);
96 }
97 
CreatInt64Imm(int64_t value)98 AnfNodePtr CreatInt64Imm(int64_t value) {
99   ValuePtr value_ptr = MakeValue(std::make_shared<Int64Imm>(value));
100   return ValuePtrToAnfNodePtr(value_ptr);
101 }
102 
CreateTuple(const std::vector<int64_t> & tuple)103 AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple) {
104   std::vector<ValuePtr> value_list;
105   (void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
106                        [](const int64_t value) { return MakeValue(value); });
107   ValueTuplePtr value_tuple_ptr = std::make_shared<ValueTuple>(value_list);
108   return ValuePtrToAnfNodePtr(value_tuple_ptr);
109 }
110 
GetInstanceNameByCNode(const CNodePtr & cnode)111 std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
112   PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
113   if (!prim) {
114     MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
115   }
116   std::string instance_name = prim->instance_name();
117   return HashInstanceName(instance_name);
118 }
119 
HashInstanceName(const std::string & name)120 std::string HashInstanceName(const std::string &name) {
121   auto using_hash_name = common::GetEnv(USING_HASH_NAME);
122   std::string instance_name;
123   if ((using_hash_name.empty()) || (using_hash_name == "on")) {
124     instance_name = HashName(name);
125   } else {
126     instance_name = name;
127   }
128   return instance_name;
129 }
130 
Init(const CNodePtr & cnode)131 Status GenerateGraph::Init(const CNodePtr &cnode) {
132   if (!cnode) {
133     MS_LOG(ERROR) << "Init:cnode is nullptr";
134     return FAILED;
135   }
136   cnode_ = cnode;
137   func_graph_ = cnode->func_graph();
138   if (!func_graph_) {
139     MS_LOG(ERROR) << "Init:func_graph_ is nullptr";
140     return FAILED;
141   }
142   manager_ = func_graph_->manager();
143   if (!manager_) {
144     MS_LOG(ERROR) << "Init:manager_ is nullptr";
145     return FAILED;
146   }
147   scope_ = cnode_->scope();
148   if (!scope_) {
149     MS_LOG(ERROR) << "Init:scope_ is nullptr";
150     return FAILED;
151   }
152   virtual_input_node_ = std::make_shared<AnfNode>(nullptr);
153   virtual_input_node_->set_scope(scope_);
154   instance_name_base_ = GetInstanceNameByCNode(cnode_);
155   name_idx_ = 0;
156   return SUCCESS;
157 }
158 
PushBack(const std::vector<AnfNodePtr> & inputs)159 AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
160   CNodePtr cnode = func_graph_->NewCNode(inputs);  // using NewCNode to create anfnode
161   MS_EXCEPTION_IF_NULL(cnode);
162   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
163   SetUserAttrs(origin_attrs_, prim);
164   cnode->set_scope(scope_);
165   if (inputs.size() < 2) {
166     MS_LOG(EXCEPTION) << "inputs.size() must be more than 1";
167   }
168   (void)manager_->Replace(inputs.at(1), cnode);  // using Replace function to insert cnode after inputs[1]
169   auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
170   MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
171   return new_anf_node_ptr;
172 }
173 
NewOpInst(const OperatorName & op_name,const OperatorAttrs & attrs)174 AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
175   name_idx_++;
176   ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
177   if (pyop_instance == nullptr) {
178     MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
179   }
180   auto value_node = NewValueNode(pyop_instance);
181   return value_node->cast<AnfNodePtr>();
182 }
183 
NewOpInst(const OperatorName & op_name)184 AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
185   name_idx_++;
186   OperatorAttrs attrs;
187   ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
188   if (pyop_instance == nullptr) {
189     MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
190   }
191   auto value_node = NewValueNode(pyop_instance);
192   return value_node->cast<AnfNodePtr>();
193 }
194 }  // namespace parallel
195 }  // namespace mindspore
196