• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_INCLUDE_API_CELL_H
17 #define MINDSPORE_INCLUDE_API_CELL_H
18 #include <string>
19 #include <vector>
20 #include <map>
21 #include <memory>
22 #include "include/api/status.h"
23 #include "include/api/types.h"
24 #include "include/api/graph.h"
25 
26 namespace mindspore {
// Forward declarations. Both types are referred to below before (or without) being
// defined in this header: Context only through std::shared_ptr, InputAndOutput
// through the aliases and std::vector signatures.
class Context;
class InputAndOutput;

// One type, two names: signatures use Input / Output to say which role a value plays,
// but both aliases denote exactly the same class.
using Input = InputAndOutput;
using Output = InputAndOutput;
31 
32 class MS_API CellBase {
33  public:
34   CellBase() = default;
35   virtual ~CellBase() = default;
36   virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
37   virtual std::shared_ptr<CellBase> Clone() const = 0;
38   virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; }
39   std::vector<Output> operator()(const std::vector<Input> &inputs) const;
40 };
41 
42 template <class T>
43 class MS_API Cell : public CellBase {
44  public:
45   virtual ~Cell() = default;
46   std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
47 };
48 
49 class MS_API ParameterCell final : public Cell<ParameterCell> {
50  public:
51   ParameterCell() = default;
52   ~ParameterCell() override = default;
53 
54   ParameterCell(const ParameterCell &);
55   ParameterCell &operator=(const ParameterCell &);
56 
57   ParameterCell(ParameterCell &&);
58   ParameterCell &operator=(ParameterCell &&);
59 
60   explicit ParameterCell(const MSTensor &);
61   ParameterCell &operator=(const MSTensor &);
62 
63   explicit ParameterCell(MSTensor &&);
64   ParameterCell &operator=(MSTensor &&);
65 
66   MSTensor GetTensor() const { return tensor_; }
67 
68  private:
69   MSTensor tensor_;
70 };
71 
72 class MS_API OpCellBase : public CellBase {
73  public:
74   explicit OpCellBase(const std::string &name) : name_(name) {}
75   ~OpCellBase() override = default;
76   const std::string &GetOpType() const { return name_; }
77 
78  protected:
79   std::string name_;
80 };
81 
82 template <class T>
83 class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T> {
84  public:
85   explicit OpCell(const std::string &name) : OpCellBase(name) {}
86   ~OpCell() override = default;
87   std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
88 };
89 
90 class MS_API GraphCell final : public Cell<GraphCell> {
91  public:
92   class GraphImpl;
93 
94   GraphCell() = default;
95   ~GraphCell() override = default;
96 
97   explicit GraphCell(const Graph &);
98   explicit GraphCell(Graph &&);
99   explicit GraphCell(const std::shared_ptr<Graph> &);
100 
101   void SetContext(const std::shared_ptr<Context> &context);
102   const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
103   Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
104   std::vector<MSTensor> GetInputs();
105   std::vector<MSTensor> GetOutputs();
106   Status Load(uint32_t device_id);
107 
108  private:
109   friend class Model;
110 
111   std::shared_ptr<Graph> graph_;
112   std::shared_ptr<GraphImpl> executor_;
113 };
114 
115 class MS_API InputAndOutput {
116  public:
117   InputAndOutput();
118   ~InputAndOutput() = default;
119 
120   // no explicit
121   InputAndOutput(const MSTensor &);  // NOLINT(runtime/explicit)
122   InputAndOutput(MSTensor &&);       // NOLINT(runtime/explicit)
123 
124   InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
125 
126   int32_t GetIndex() const { return index_; }
127   void SetIndex(int32_t index) { index_ = index; }
128 
129  private:
130   std::shared_ptr<CellBase> cell_;
131   std::vector<InputAndOutput> prev_;
132   int32_t index_;
133 };
134 }  // namespace mindspore
135 #endif  // MINDSPORE_INCLUDE_API_CELL_H
136