1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_ 18 #define MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_ 19 #include <memory> 20 #include <atomic> 21 #include <vector> 22 #include <utility> 23 #include <exception> 24 #include <condition_variable> 25 #include <mutex> 26 27 #include "pybind11/pytypes.h" 28 #include "pybind11/pybind11.h" 29 #include "base/base.h" 30 #include "ir/value.h" 31 #include "ir/tensor.h" 32 #include "mindapi/base/shape_vector.h" 33 #include "abstract/abstract_value.h" 34 #include "mindspore/core/utils/simple_info.h" 35 36 namespace mindspore { 37 namespace stub { 38 constexpr auto PY_ATTR_STUB = "stub"; 39 constexpr auto PY_ATTR_TENSOR = "tensor"; 40 constexpr auto PY_ATTR_SYNC = "stub_sync"; 41 42 namespace py = pybind11; 43 class StubNode; 44 using StubNodePtr = std::shared_ptr<StubNode>; 45 using abstract::AbstractBasePtr; 46 47 class COMMON_EXPORT StubNode : public Value { 48 public: 49 StubNode() = default; 50 virtual ~StubNode() = default; 51 MS_DECLARE_PARENT(StubNode, Value); 52 53 virtual bool SetAbstract(const AbstractBasePtr &abs); 54 virtual void SetValue(const ValuePtr &val); 55 virtual void SetException(const std::exception_ptr &e_ptr); 56 57 ValuePtr WaitValue(); 58 virtual bool SetValueSimpleInfo(const ValueSimpleInfoPtr &output_value_simple_info); 59 void WaitPipeline(); 60 61 AbstractBasePtr ToAbstract() override; 62 bool operator==(const Value &other) const override { return other.isa<StubNode>() && &other == this; } 63 64 protected: 65 AbstractBasePtr abstract_; 66 ValueSimpleInfoPtr output_value_simple_info_; 67 ValuePtr value_; 68 std::condition_variable cond_var_; 69 std::mutex mutex_; 70 std::exception_ptr e_ptr_{}; 71 }; 72 73 class TensorNode : public StubNode { 74 public: 75 TensorNode() = default; 76 MS_DECLARE_PARENT(TensorNode, StubNode); 77 bool SetAbstract(const AbstractBasePtr &abs) override; 78 79 py::object GetValue(); 80 py::object GetShape(); 81 py::object GetDtype(); 82 }; 83 84 class SequenceNode : public StubNode { 85 public: elements_(size)86 explicit SequenceNode(size_t size = 0) : elements_(size), is_elements_build_(size > 0) {} 87 MS_DECLARE_PARENT(SequenceNode, StubNode); 88 89 py::object GetElements(); 90 91 bool SetAbstract(const AbstractBasePtr &abs) override; 92 bool SetValueSimpleInfo(const ValueSimpleInfoPtr &output_value_simple_info) override; 93 void SetValue(const ValuePtr &val) override; 94 void SetException(const std::exception_ptr &e_ptr) override; 95 SetElement(size_t i,const StubNodePtr & node)96 void SetElement(size_t i, const StubNodePtr &node) { elements_[i] = node; } Elements()97 const std::vector<StubNodePtr> &Elements() const { return elements_; } 98 99 private: 100 std::vector<StubNodePtr> elements_; 101 std::atomic<bool> is_elements_build_{false}; 102 }; 103 using SequenceNodePtr = std::shared_ptr<SequenceNode>; 104 105 class AnyTypeNode : public StubNode { 106 public: 107 AnyTypeNode() = default; 108 MS_DECLARE_PARENT(AnyTypeNode, StubNode); 109 bool SetAbstract(const AbstractBasePtr &abs) override; 110 void SetValue(const ValuePtr &val) override; 111 void SetException(const std::exception_ptr &e_ptr) override; 112 py::object GetRealNode(); 113 114 private: 115 StubNodePtr real_node_; 116 }; 117 118 class NoneTypeNode : public StubNode { 119 public: 120 NoneTypeNode() = default; 121 MS_DECLARE_PARENT(NoneTypeNode, StubNode); 122 py::object GetRealValue(); 123 }; 124 125 COMMON_EXPORT std::pair<py::object, StubNodePtr> MakeTopNode(const TypePtr &type); 126 COMMON_EXPORT void RegStubNodes(const py::module *m); 127 } // namespace stub 128 } // namespace mindspore 129 #endif // MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_ 130