1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "utils/ms_exception.h"
17 #include "include/common/utils/convert_utils_py.h"
18 #include "include/common/utils/convert_utils.h"
19 #include "include/common/utils/stub_tensor.h"
20 #include "pybind_api/gil_scoped_long_running.h"
21 #include "include/common/profiler.h"
22 #include "include/common/utils/anfalgo.h"
23
24 namespace mindspore {
25 namespace stub {
26 namespace {
MakeStubNode(const TypePtr & type)27 StubNodePtr MakeStubNode(const TypePtr &type) {
28 if (type->isa<TensorType>()) {
29 return std::make_shared<TensorNode>();
30 } else if (type->isa<Tuple>() || type->isa<List>()) {
31 TypePtrList elements;
32 if (type->isa<Tuple>()) {
33 auto tuple_type = type->cast<TuplePtr>();
34 elements = tuple_type->elements();
35 } else {
36 auto list_type = type->cast<ListPtr>();
37 elements = list_type->elements();
38 }
39 auto node = std::make_shared<SequenceNode>(elements.size());
40 for (size_t i = 0; i < elements.size(); ++i) {
41 auto elem = MakeStubNode(elements[i]);
42 node->SetElement(i, elem);
43 }
44 return node;
45 } else if (type == kTypeAny) {
46 return std::make_shared<AnyTypeNode>();
47 } else if (type == kTypeNone) {
48 return std::make_shared<NoneTypeNode>();
49 } else {
50 MS_LOG(WARNING) << "stub tensor is create for type: " << type->ToString();
51 }
52 return nullptr;
53 }
54
MakeOutput(const StubNodePtr & node)55 py::object MakeOutput(const StubNodePtr &node) {
56 if (node->isa<TensorNode>()) {
57 auto tensor = node->cast<std::shared_ptr<TensorNode>>();
58 return py::cast(tensor);
59 } else if (node->isa<SequenceNode>()) {
60 auto seq = node->cast<std::shared_ptr<SequenceNode>>();
61 MS_EXCEPTION_IF_NULL(seq);
62 auto &elements = seq->Elements();
63 if (elements.empty()) {
64 return py::cast(seq);
65 }
66 py::tuple out(elements.size());
67 for (size_t i = 0; i < elements.size(); ++i) {
68 out[i] = MakeOutput(elements[i]);
69 }
70 return out;
71 } else if (node->isa<AnyTypeNode>()) {
72 auto tensor = node->cast<std::shared_ptr<AnyTypeNode>>();
73 return py::cast(tensor);
74 } else {
75 auto tensor = node->cast<std::shared_ptr<NoneTypeNode>>();
76 return py::cast(tensor);
77 }
78 }
79 } // namespace
80
SetAbstract(const AbstractBasePtr & abs)81 bool StubNode::SetAbstract(const AbstractBasePtr &abs) {
82 std::unique_lock<std::mutex> lock(mutex_);
83 abstract_ = abs;
84 cond_var_.notify_all();
85 return true;
86 }
87
SetValueSimpleInfo(const ValueSimpleInfoPtr & output_value_simple_info)88 bool StubNode::SetValueSimpleInfo(const ValueSimpleInfoPtr &output_value_simple_info) {
89 std::unique_lock<std::mutex> lock(mutex_);
90 output_value_simple_info_ = output_value_simple_info;
91 cond_var_.notify_all();
92 return true;
93 }
SetValue(const ValuePtr & val)94 void StubNode::SetValue(const ValuePtr &val) {
95 std::unique_lock<std::mutex> lock(mutex_);
96 value_ = val;
97 cond_var_.notify_all();
98 }
99
SetException(const std::exception_ptr & e_ptr)100 void StubNode::SetException(const std::exception_ptr &e_ptr) {
101 // cppcheck-suppress unreadVariable
102 std::unique_lock<std::mutex> lock(mutex_);
103 e_ptr_ = e_ptr;
104 cond_var_.notify_all();
105 }
106
WaitValue()107 ValuePtr StubNode::WaitValue() {
108 runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
109 // cppcheck-suppress unreadVariable
110 GilReleaseWithCheck gil_release;
111 std::unique_lock<std::mutex> lock(mutex_);
112 cond_var_.wait(lock, [this] { return value_.get() != nullptr || e_ptr_ != nullptr; });
113 if (e_ptr_ != nullptr) {
114 // Need to clear exception in the instance.
115 MsException::Instance().CheckException();
116 std::rethrow_exception(e_ptr_);
117 }
118 return value_;
119 }
120
WaitPipeline()121 void StubNode::WaitPipeline() {
122 runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kWaitPipeline);
123 // cppcheck-suppress unreadVariable
124 GilReleaseWithCheck gil_release;
125 std::unique_lock<std::mutex> lock(mutex_);
126 cond_var_.wait(lock, [this] {
127 return abstract_.get() != nullptr || output_value_simple_info_.get() != nullptr || e_ptr_ != nullptr;
128 });
129 if (e_ptr_ != nullptr) {
130 // Need to clear exception in the instance.
131 MsException::Instance().CheckException();
132 std::rethrow_exception(e_ptr_);
133 }
134 }
135
ToAbstract()136 AbstractBasePtr StubNode::ToAbstract() {
137 WaitPipeline();
138 if (output_value_simple_info_ == nullptr) {
139 MS_EXCEPTION_IF_NULL(abstract_);
140 return abstract_;
141 }
142 return TransformValueSimpleInfoToAbstract(*output_value_simple_info_);
143 }
144
GetValue()145 py::object TensorNode::GetValue() {
146 auto val = WaitValue();
147 return ValueToPyData(val);
148 }
149
GetShape()150 py::object TensorNode::GetShape() {
151 WaitPipeline();
152 abstract::ShapePtr shape{nullptr};
153 ShapeVector shape_vector;
154 if (output_value_simple_info_ == nullptr) {
155 MS_EXCEPTION_IF_NULL(abstract_);
156 auto base = abstract_->BuildShape();
157 shape = base->cast<abstract::ShapePtr>();
158 if (shape && !shape->IsDynamic()) {
159 shape_vector = shape->shape();
160 } else {
161 auto val = WaitValue();
162 auto tensor = val->cast<tensor::TensorPtr>();
163 MS_EXCEPTION_IF_NULL(tensor);
164 shape_vector = tensor->shape();
165 }
166 } else {
167 if (output_value_simple_info_->size_ != kIndex1) {
168 MS_LOG(EXCEPTION) << "Simple infer size " << output_value_simple_info_->size_ << " is not equal to 1";
169 }
170 shape_vector = output_value_simple_info_->shape_vector_[kIndex0];
171 }
172 auto ret = py::tuple(shape_vector.size());
173 for (size_t i = 0; i < shape_vector.size(); ++i) {
174 ret[i] = shape_vector[i];
175 }
176 return ret;
177 }
178
GetDtype()179 py::object TensorNode::GetDtype() {
180 WaitPipeline();
181 TypePtr base = nullptr;
182 if (output_value_simple_info_ == nullptr) {
183 MS_EXCEPTION_IF_NULL(abstract_);
184 base = abstract_->BuildType();
185 if (base->isa<TensorType>()) {
186 base = base->cast<TensorTypePtr>()->element();
187 }
188 } else {
189 if (output_value_simple_info_->size_ != kIndex1) {
190 MS_LOG(EXCEPTION) << "Simple infer size " << output_value_simple_info_->size_ << " is not equal to 1";
191 }
192 base = output_value_simple_info_->dtype_vector_[kIndex0];
193 }
194 return py::cast(base);
195 }
196
SetAbstract(const AbstractBasePtr & abs)197 bool TensorNode::SetAbstract(const AbstractBasePtr &abs) {
198 if (!abs->isa<abstract::AbstractTensor>() && !abs->isa<abstract::AbstractMapTensor>()) {
199 if (!abs->isa<abstract::AbstractScalar>() || abs->BuildValue() != kValueAny) {
200 return false;
201 }
202 }
203 return StubNode::SetAbstract(abs);
204 }
205
GetElements()206 py::object SequenceNode::GetElements() {
207 if (!is_elements_build_.load()) {
208 (void)WaitPipeline();
209 }
210 py::tuple out(elements_.size());
211 for (size_t i = 0; i < elements_.size(); ++i) {
212 out[i] = MakeOutput(elements_[i]);
213 }
214 return out;
215 }
216
SetAbstract(const AbstractBasePtr & abs)217 bool SequenceNode::SetAbstract(const AbstractBasePtr &abs) {
218 auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
219 if (seq_abs == nullptr) {
220 return false;
221 }
222 auto children = seq_abs->elements();
223 if (!is_elements_build_.load()) {
224 for (const auto &child : children) {
225 (void)elements_.emplace_back(MakeStubNode(child->BuildType()));
226 }
227 }
228 is_elements_build_ = true;
229 if (elements_.size() != children.size()) {
230 return false;
231 }
232 for (size_t i = 0; i < elements_.size(); ++i) {
233 if (!elements_[i]->SetAbstract(children[i])) {
234 return false;
235 }
236 }
237 return StubNode::SetAbstract(abs);
238 }
239
SetValueSimpleInfo(const mindspore::ValueSimpleInfoPtr & output_value_simple_info)240 bool SequenceNode::SetValueSimpleInfo(const mindspore::ValueSimpleInfoPtr &output_value_simple_info) {
241 MS_EXCEPTION_IF_NULL(output_value_simple_info);
242 if (!is_elements_build_.load()) {
243 for (size_t i = 0; i < output_value_simple_info->size_; ++i) {
244 (void)elements_.emplace_back(
245 MakeStubNode(std::make_shared<TensorType>(output_value_simple_info->dtype_vector_[i])));
246 }
247 }
248 is_elements_build_ = true;
249 for (size_t i = 0; i < output_value_simple_info->size_; ++i) {
250 auto elem_simple_info = std::make_shared<mindspore::ValueSimpleInfo>();
251 elem_simple_info->size_ = kIndex1;
252 (void)elem_simple_info->shape_vector_.emplace_back(output_value_simple_info->shape_vector_[i]);
253 (void)elem_simple_info->dtype_vector_.emplace_back(output_value_simple_info->dtype_vector_[i]);
254 MS_EXCEPTION_IF_NULL(elements_[i]);
255 if (!elements_[i]->SetValueSimpleInfo(elem_simple_info)) {
256 return false;
257 }
258 }
259 return StubNode::SetValueSimpleInfo(output_value_simple_info);
260 }
261
SetValue(const ValuePtr & val)262 void SequenceNode::SetValue(const ValuePtr &val) {
263 auto seq_value = val->cast<ValueSequencePtr>();
264 MS_EXCEPTION_IF_NULL(seq_value);
265 auto children = seq_value->value();
266 for (size_t i = 0; i < children.size(); ++i) {
267 elements_[i]->SetValue(children[i]);
268 }
269 StubNode::SetValue(val);
270 }
271
SetException(const std::exception_ptr & e_ptr)272 void SequenceNode::SetException(const std::exception_ptr &e_ptr) {
273 for (auto &element : elements_) {
274 element->SetException(e_ptr);
275 }
276 StubNode::SetException(e_ptr);
277 }
278
SetAbstract(const AbstractBasePtr & abs)279 bool AnyTypeNode::SetAbstract(const AbstractBasePtr &abs) {
280 real_node_ = MakeStubNode(abs->BuildType());
281 auto flag = real_node_->SetAbstract(abs);
282 (void)StubNode::SetAbstract(abs);
283 return flag;
284 }
285
SetValue(const ValuePtr & val)286 void AnyTypeNode::SetValue(const ValuePtr &val) {
287 real_node_->SetValue(val);
288 StubNode::SetValue(val);
289 }
290
GetRealNode()291 py::object AnyTypeNode::GetRealNode() {
292 (void)WaitPipeline();
293 return py::cast(real_node_);
294 }
295
GetRealValue()296 py::object NoneTypeNode::GetRealValue() {
297 auto val = WaitValue();
298 return ValueToPyData(val);
299 }
300
SetException(const std::exception_ptr & e_ptr)301 void AnyTypeNode::SetException(const std::exception_ptr &e_ptr) {
302 StubNode::SetException(e_ptr);
303 if (real_node_ != nullptr) {
304 real_node_->SetException(e_ptr);
305 }
306 }
307
MakeTopNode(const TypePtr & type)308 std::pair<py::object, StubNodePtr> MakeTopNode(const TypePtr &type) {
309 auto top = MakeStubNode(type);
310 auto ret = MakeOutput(top);
311 return std::make_pair(ret, top);
312 }
313
RegStubNodes(const py::module * m)314 void RegStubNodes(const py::module *m) {
315 (void)py::class_<StubNode, std::shared_ptr<StubNode>>(*m, "StubNode");
316 (void)py::class_<TensorNode, StubNode, std::shared_ptr<TensorNode>>(*m, "TensorNode")
317 .def("get_value", &TensorNode::GetValue, "get output value of async stub.")
318 .def("get_shape", &TensorNode::GetShape, "get output shape of async stub.")
319 .def("get_dtype", &TensorNode::GetDtype, "get output dtype of async stub.");
320 (void)py::class_<SequenceNode, StubNode, std::shared_ptr<SequenceNode>>(*m, "SequenceNode")
321 .def("get_elements", &SequenceNode::GetElements, "get the elements of async stub_seq.");
322 (void)py::class_<AnyTypeNode, StubNode, std::shared_ptr<AnyTypeNode>>(*m, "AnyTypeNode")
323 .def("get_real_node", &AnyTypeNode::GetRealNode, "get the real StubNode");
324 (void)py::class_<NoneTypeNode, StubNode, std::shared_ptr<NoneTypeNode>>(*m, "NoneTypeNode")
325 .def("get_real_value", &NoneTypeNode::GetRealValue, "get the real value");
326 }
327 } // namespace stub
328 } // namespace mindspore
329