• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "utils/ms_exception.h"
17 #include "include/common/utils/convert_utils_py.h"
18 #include "include/common/utils/convert_utils.h"
19 #include "include/common/utils/stub_tensor.h"
20 #include "pybind_api/gil_scoped_long_running.h"
21 #include "include/common/profiler.h"
22 #include "include/common/utils/anfalgo.h"
23 
24 namespace mindspore {
25 namespace stub {
26 namespace {
MakeStubNode(const TypePtr & type)27 StubNodePtr MakeStubNode(const TypePtr &type) {
28   if (type->isa<TensorType>()) {
29     return std::make_shared<TensorNode>();
30   } else if (type->isa<Tuple>() || type->isa<List>()) {
31     TypePtrList elements;
32     if (type->isa<Tuple>()) {
33       auto tuple_type = type->cast<TuplePtr>();
34       elements = tuple_type->elements();
35     } else {
36       auto list_type = type->cast<ListPtr>();
37       elements = list_type->elements();
38     }
39     auto node = std::make_shared<SequenceNode>(elements.size());
40     for (size_t i = 0; i < elements.size(); ++i) {
41       auto elem = MakeStubNode(elements[i]);
42       node->SetElement(i, elem);
43     }
44     return node;
45   } else if (type == kTypeAny) {
46     return std::make_shared<AnyTypeNode>();
47   } else if (type == kTypeNone) {
48     return std::make_shared<NoneTypeNode>();
49   } else {
50     MS_LOG(WARNING) << "stub tensor is create for type: " << type->ToString();
51   }
52   return nullptr;
53 }
54 
MakeOutput(const StubNodePtr & node)55 py::object MakeOutput(const StubNodePtr &node) {
56   if (node->isa<TensorNode>()) {
57     auto tensor = node->cast<std::shared_ptr<TensorNode>>();
58     return py::cast(tensor);
59   } else if (node->isa<SequenceNode>()) {
60     auto seq = node->cast<std::shared_ptr<SequenceNode>>();
61     MS_EXCEPTION_IF_NULL(seq);
62     auto &elements = seq->Elements();
63     if (elements.empty()) {
64       return py::cast(seq);
65     }
66     py::tuple out(elements.size());
67     for (size_t i = 0; i < elements.size(); ++i) {
68       out[i] = MakeOutput(elements[i]);
69     }
70     return out;
71   } else if (node->isa<AnyTypeNode>()) {
72     auto tensor = node->cast<std::shared_ptr<AnyTypeNode>>();
73     return py::cast(tensor);
74   } else {
75     auto tensor = node->cast<std::shared_ptr<NoneTypeNode>>();
76     return py::cast(tensor);
77   }
78 }
79 }  // namespace
80 
SetAbstract(const AbstractBasePtr & abs)81 bool StubNode::SetAbstract(const AbstractBasePtr &abs) {
82   std::unique_lock<std::mutex> lock(mutex_);
83   abstract_ = abs;
84   cond_var_.notify_all();
85   return true;
86 }
87 
SetValueSimpleInfo(const ValueSimpleInfoPtr & output_value_simple_info)88 bool StubNode::SetValueSimpleInfo(const ValueSimpleInfoPtr &output_value_simple_info) {
89   std::unique_lock<std::mutex> lock(mutex_);
90   output_value_simple_info_ = output_value_simple_info;
91   cond_var_.notify_all();
92   return true;
93 }
SetValue(const ValuePtr & val)94 void StubNode::SetValue(const ValuePtr &val) {
95   std::unique_lock<std::mutex> lock(mutex_);
96   value_ = val;
97   cond_var_.notify_all();
98 }
99 
SetException(const std::exception_ptr & e_ptr)100 void StubNode::SetException(const std::exception_ptr &e_ptr) {
101   // cppcheck-suppress unreadVariable
102   std::unique_lock<std::mutex> lock(mutex_);
103   e_ptr_ = e_ptr;
104   cond_var_.notify_all();
105 }
106 
WaitValue()107 ValuePtr StubNode::WaitValue() {
108   runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
109   // cppcheck-suppress unreadVariable
110   GilReleaseWithCheck gil_release;
111   std::unique_lock<std::mutex> lock(mutex_);
112   cond_var_.wait(lock, [this] { return value_.get() != nullptr || e_ptr_ != nullptr; });
113   if (e_ptr_ != nullptr) {
114     // Need to clear exception in the instance.
115     MsException::Instance().CheckException();
116     std::rethrow_exception(e_ptr_);
117   }
118   return value_;
119 }
120 
WaitPipeline()121 void StubNode::WaitPipeline() {
122   runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
123   // cppcheck-suppress unreadVariable
124   GilReleaseWithCheck gil_release;
125   std::unique_lock<std::mutex> lock(mutex_);
126   cond_var_.wait(lock, [this] {
127     return abstract_.get() != nullptr || output_value_simple_info_.get() != nullptr || e_ptr_ != nullptr;
128   });
129   if (e_ptr_ != nullptr) {
130     // Need to clear exception in the instance.
131     MsException::Instance().CheckException();
132     std::rethrow_exception(e_ptr_);
133   }
134 }
135 
ToAbstract()136 AbstractBasePtr StubNode::ToAbstract() {
137   WaitPipeline();
138   if (output_value_simple_info_ == nullptr) {
139     MS_EXCEPTION_IF_NULL(abstract_);
140     return abstract_;
141   }
142   return TransformValueSimpleInfoToAbstract(*output_value_simple_info_);
143 }
144 
GetValue()145 py::object TensorNode::GetValue() {
146   auto val = WaitValue();
147   return ValueToPyData(val);
148 }
149 
GetShape()150 py::object TensorNode::GetShape() {
151   WaitPipeline();
152   abstract::ShapePtr shape{nullptr};
153   ShapeVector shape_vector;
154   if (output_value_simple_info_ == nullptr) {
155     MS_EXCEPTION_IF_NULL(abstract_);
156     auto base = abstract_->BuildShape();
157     shape = base->cast<abstract::ShapePtr>();
158     if (shape && !shape->IsDynamic()) {
159       shape_vector = shape->shape();
160     } else {
161       auto val = WaitValue();
162       auto tensor = val->cast<tensor::TensorPtr>();
163       MS_EXCEPTION_IF_NULL(tensor);
164       shape_vector = tensor->shape();
165     }
166   } else {
167     if (output_value_simple_info_->size_ != kIndex1) {
168       MS_LOG(EXCEPTION) << "Simple infer size " << output_value_simple_info_->size_ << " is not equal to 1";
169     }
170     shape_vector = output_value_simple_info_->shape_vector_[kIndex0];
171   }
172   auto ret = py::tuple(shape_vector.size());
173   for (size_t i = 0; i < shape_vector.size(); ++i) {
174     ret[i] = shape_vector[i];
175   }
176   return ret;
177 }
178 
GetDtype()179 py::object TensorNode::GetDtype() {
180   WaitPipeline();
181   TypePtr base = nullptr;
182   if (output_value_simple_info_ == nullptr) {
183     MS_EXCEPTION_IF_NULL(abstract_);
184     base = abstract_->BuildType();
185     if (base->isa<TensorType>()) {
186       base = base->cast<TensorTypePtr>()->element();
187     }
188   } else {
189     if (output_value_simple_info_->size_ != kIndex1) {
190       MS_LOG(EXCEPTION) << "Simple infer size " << output_value_simple_info_->size_ << " is not equal to 1";
191     }
192     base = output_value_simple_info_->dtype_vector_[kIndex0];
193   }
194   return py::cast(base);
195 }
196 
SetAbstract(const AbstractBasePtr & abs)197 bool TensorNode::SetAbstract(const AbstractBasePtr &abs) {
198   if (!abs->isa<abstract::AbstractTensor>() && !abs->isa<abstract::AbstractMapTensor>()) {
199     if (!abs->isa<abstract::AbstractScalar>() || abs->BuildValue() != kValueAny) {
200       return false;
201     }
202   }
203   return StubNode::SetAbstract(abs);
204 }
205 
GetElements()206 py::object SequenceNode::GetElements() {
207   if (!is_elements_build_.load()) {
208     (void)WaitPipeline();
209   }
210   py::tuple out(elements_.size());
211   for (size_t i = 0; i < elements_.size(); ++i) {
212     out[i] = MakeOutput(elements_[i]);
213   }
214   return out;
215 }
216 
SetAbstract(const AbstractBasePtr & abs)217 bool SequenceNode::SetAbstract(const AbstractBasePtr &abs) {
218   auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
219   if (seq_abs == nullptr) {
220     return false;
221   }
222   auto children = seq_abs->elements();
223   if (!is_elements_build_.load()) {
224     for (const auto &child : children) {
225       (void)elements_.emplace_back(MakeStubNode(child->BuildType()));
226     }
227   }
228   is_elements_build_ = true;
229   if (elements_.size() != children.size()) {
230     return false;
231   }
232   for (size_t i = 0; i < elements_.size(); ++i) {
233     if (!elements_[i]->SetAbstract(children[i])) {
234       return false;
235     }
236   }
237   return StubNode::SetAbstract(abs);
238 }
239 
SetValueSimpleInfo(const mindspore::ValueSimpleInfoPtr & output_value_simple_info)240 bool SequenceNode::SetValueSimpleInfo(const mindspore::ValueSimpleInfoPtr &output_value_simple_info) {
241   MS_EXCEPTION_IF_NULL(output_value_simple_info);
242   if (!is_elements_build_.load()) {
243     for (size_t i = 0; i < output_value_simple_info->size_; ++i) {
244       (void)elements_.emplace_back(
245         MakeStubNode(std::make_shared<TensorType>(output_value_simple_info->dtype_vector_[i])));
246     }
247   }
248   is_elements_build_ = true;
249   for (size_t i = 0; i < output_value_simple_info->size_; ++i) {
250     auto elem_simple_info = std::make_shared<mindspore::ValueSimpleInfo>();
251     elem_simple_info->size_ = kIndex1;
252     (void)elem_simple_info->shape_vector_.emplace_back(output_value_simple_info->shape_vector_[i]);
253     (void)elem_simple_info->dtype_vector_.emplace_back(output_value_simple_info->dtype_vector_[i]);
254     MS_EXCEPTION_IF_NULL(elements_[i]);
255     if (!elements_[i]->SetValueSimpleInfo(elem_simple_info)) {
256       return false;
257     }
258   }
259   return StubNode::SetValueSimpleInfo(output_value_simple_info);
260 }
261 
SetValue(const ValuePtr & val)262 void SequenceNode::SetValue(const ValuePtr &val) {
263   auto seq_value = val->cast<ValueSequencePtr>();
264   MS_EXCEPTION_IF_NULL(seq_value);
265   auto children = seq_value->value();
266   for (size_t i = 0; i < children.size(); ++i) {
267     elements_[i]->SetValue(children[i]);
268   }
269   StubNode::SetValue(val);
270 }
271 
SetException(const std::exception_ptr & e_ptr)272 void SequenceNode::SetException(const std::exception_ptr &e_ptr) {
273   for (auto &element : elements_) {
274     element->SetException(e_ptr);
275   }
276   StubNode::SetException(e_ptr);
277 }
278 
SetAbstract(const AbstractBasePtr & abs)279 bool AnyTypeNode::SetAbstract(const AbstractBasePtr &abs) {
280   real_node_ = MakeStubNode(abs->BuildType());
281   auto flag = real_node_->SetAbstract(abs);
282   (void)StubNode::SetAbstract(abs);
283   return flag;
284 }
285 
SetValue(const ValuePtr & val)286 void AnyTypeNode::SetValue(const ValuePtr &val) {
287   real_node_->SetValue(val);
288   StubNode::SetValue(val);
289 }
290 
GetRealNode()291 py::object AnyTypeNode::GetRealNode() {
292   (void)WaitPipeline();
293   return py::cast(real_node_);
294 }
295 
GetRealValue()296 py::object NoneTypeNode::GetRealValue() {
297   auto val = WaitValue();
298   return ValueToPyData(val);
299 }
300 
SetException(const std::exception_ptr & e_ptr)301 void AnyTypeNode::SetException(const std::exception_ptr &e_ptr) {
302   StubNode::SetException(e_ptr);
303   if (real_node_ != nullptr) {
304     real_node_->SetException(e_ptr);
305   }
306 }
307 
MakeTopNode(const TypePtr & type)308 std::pair<py::object, StubNodePtr> MakeTopNode(const TypePtr &type) {
309   auto top = MakeStubNode(type);
310   auto ret = MakeOutput(top);
311   return std::make_pair(ret, top);
312 }
313 
RegStubNodes(const py::module * m)314 void RegStubNodes(const py::module *m) {
315   (void)py::class_<StubNode, std::shared_ptr<StubNode>>(*m, "StubNode");
316   (void)py::class_<TensorNode, StubNode, std::shared_ptr<TensorNode>>(*m, "TensorNode")
317     .def("get_value", &TensorNode::GetValue, "get output value of async stub.")
318     .def("get_shape", &TensorNode::GetShape, "get output shape of async stub.")
319     .def("get_dtype", &TensorNode::GetDtype, "get output dtype of async stub.");
320   (void)py::class_<SequenceNode, StubNode, std::shared_ptr<SequenceNode>>(*m, "SequenceNode")
321     .def("get_elements", &SequenceNode::GetElements, "get the elements of async stub_seq.");
322   (void)py::class_<AnyTypeNode, StubNode, std::shared_ptr<AnyTypeNode>>(*m, "AnyTypeNode")
323     .def("get_real_node", &AnyTypeNode::GetRealNode, "get the real StubNode");
324   (void)py::class_<NoneTypeNode, StubNode, std::shared_ptr<NoneTypeNode>>(*m, "NoneTypeNode")
325     .def("get_real_value", &NoneTypeNode::GetRealValue, "get the real value");
326 }
327 }  // namespace stub
328 }  // namespace mindspore
329