• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
18 #define MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
19 #include <memory>
20 #include <atomic>
21 #include <vector>
22 #include <utility>
23 #include <exception>
24 #include <condition_variable>
25 #include <mutex>
26 
27 #include "pybind11/pytypes.h"
28 #include "pybind11/pybind11.h"
29 #include "base/base.h"
30 #include "ir/value.h"
31 #include "ir/tensor.h"
32 #include "mindapi/base/shape_vector.h"
33 #include "abstract/abstract_value.h"
34 #include "mindspore/core/utils/simple_info.h"
35 
36 namespace mindspore {
37 namespace stub {
38 constexpr auto PY_ATTR_STUB = "stub";
39 constexpr auto PY_ATTR_TENSOR = "tensor";
40 constexpr auto PY_ATTR_SYNC = "stub_sync";
41 
42 namespace py = pybind11;
43 class StubNode;
44 using StubNodePtr = std::shared_ptr<StubNode>;
45 using abstract::AbstractBasePtr;
46 
47 class COMMON_EXPORT StubNode : public Value {
48  public:
49   StubNode() = default;
50   virtual ~StubNode() = default;
51   MS_DECLARE_PARENT(StubNode, Value);
52 
53   virtual bool SetAbstract(const AbstractBasePtr &abs);
54   virtual void SetValue(const ValuePtr &val);
55   virtual void SetException(const std::exception_ptr &e_ptr);
56 
57   ValuePtr WaitValue();
58   virtual bool SetValueSimpleInfo(const ValueSimpleInfoPtr &output_value_simple_info);
59   void WaitPipeline();
60 
61   AbstractBasePtr ToAbstract() override;
62   bool operator==(const Value &other) const override { return other.isa<StubNode>() && &other == this; }
63 
64  protected:
65   AbstractBasePtr abstract_;
66   ValueSimpleInfoPtr output_value_simple_info_;
67   ValuePtr value_;
68   std::condition_variable cond_var_;
69   std::mutex mutex_;
70   std::exception_ptr e_ptr_{};
71 };
72 
73 class TensorNode : public StubNode {
74  public:
75   TensorNode() = default;
76   MS_DECLARE_PARENT(TensorNode, StubNode);
77   bool SetAbstract(const AbstractBasePtr &abs) override;
78 
79   py::object GetValue();
80   py::object GetShape();
81   py::object GetDtype();
82 };
83 
84 class SequenceNode : public StubNode {
85  public:
elements_(size)86   explicit SequenceNode(size_t size = 0) : elements_(size), is_elements_build_(size > 0) {}
87   MS_DECLARE_PARENT(SequenceNode, StubNode);
88 
89   py::object GetElements();
90 
91   bool SetAbstract(const AbstractBasePtr &abs) override;
92   bool SetValueSimpleInfo(const ValueSimpleInfoPtr &output_value_simple_info) override;
93   void SetValue(const ValuePtr &val) override;
94   void SetException(const std::exception_ptr &e_ptr) override;
95 
SetElement(size_t i,const StubNodePtr & node)96   void SetElement(size_t i, const StubNodePtr &node) { elements_[i] = node; }
Elements()97   const std::vector<StubNodePtr> &Elements() const { return elements_; }
98 
99  private:
100   std::vector<StubNodePtr> elements_;
101   std::atomic<bool> is_elements_build_{false};
102 };
103 using SequenceNodePtr = std::shared_ptr<SequenceNode>;
104 
105 class AnyTypeNode : public StubNode {
106  public:
107   AnyTypeNode() = default;
108   MS_DECLARE_PARENT(AnyTypeNode, StubNode);
109   bool SetAbstract(const AbstractBasePtr &abs) override;
110   void SetValue(const ValuePtr &val) override;
111   void SetException(const std::exception_ptr &e_ptr) override;
112   py::object GetRealNode();
113 
114  private:
115   StubNodePtr real_node_;
116 };
117 
118 class NoneTypeNode : public StubNode {
119  public:
120   NoneTypeNode() = default;
121   MS_DECLARE_PARENT(NoneTypeNode, StubNode);
122   py::object GetRealValue();
123 };
124 
125 COMMON_EXPORT std::pair<py::object, StubNodePtr> MakeTopNode(const TypePtr &type);
126 COMMON_EXPORT void RegStubNodes(const py::module *m);
127 }  // namespace stub
128 }  // namespace mindspore
129 #endif  // MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
130