1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_INCLUDE_API_CELL_H 17 #define MINDSPORE_INCLUDE_API_CELL_H 18 #include <string> 19 #include <vector> 20 #include <map> 21 #include <memory> 22 #include "include/api/status.h" 23 #include "include/api/types.h" 24 #include "include/api/graph.h" 25 26 namespace mindspore { 27 class InputAndOutput; 28 class Context; 29 using Input = InputAndOutput; 30 using Output = InputAndOutput; 31 32 class MS_API CellBase { 33 public: 34 CellBase() = default; 35 virtual ~CellBase() = default; 36 virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; } 37 virtual std::shared_ptr<CellBase> Clone() const = 0; 38 virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; } 39 std::vector<Output> operator()(const std::vector<Input> &inputs) const; 40 }; 41 42 template <class T> 43 class MS_API Cell : public CellBase { 44 public: 45 virtual ~Cell() = default; 46 std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); } 47 }; 48 49 class MS_API ParameterCell final : public Cell<ParameterCell> { 50 public: 51 ParameterCell() = default; 52 ~ParameterCell() override = default; 53 54 ParameterCell(const ParameterCell &); 55 ParameterCell &operator=(const ParameterCell &); 56 57 ParameterCell(ParameterCell &&); 58 ParameterCell &operator=(ParameterCell &&); 59 60 explicit ParameterCell(const MSTensor &); 61 ParameterCell &operator=(const MSTensor &); 62 63 explicit ParameterCell(MSTensor &&); 64 ParameterCell &operator=(MSTensor &&); 65 66 MSTensor GetTensor() const { return tensor_; } 67 68 private: 69 MSTensor tensor_; 70 }; 71 72 class MS_API OpCellBase : public CellBase { 73 public: 74 explicit OpCellBase(const std::string &name) : name_(name) {} 75 ~OpCellBase() override = default; 76 const std::string &GetOpType() const { return name_; } 77 78 protected: 79 std::string name_; 80 }; 81 82 template <class T> 83 class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T> { 84 public: 85 explicit OpCell(const std::string &name) : OpCellBase(name) {} 86 ~OpCell() override = default; 87 std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); } 88 }; 89 90 class MS_API GraphCell final : public Cell<GraphCell> { 91 public: 92 class GraphImpl; 93 94 GraphCell() = default; 95 ~GraphCell() override = default; 96 97 explicit GraphCell(const Graph &); 98 explicit GraphCell(Graph &&); 99 explicit GraphCell(const std::shared_ptr<Graph> &); 100 101 void SetContext(const std::shared_ptr<Context> &context); 102 const std::shared_ptr<Graph> &GetGraph() const { return graph_; } 103 Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; 104 std::vector<MSTensor> GetInputs(); 105 std::vector<MSTensor> GetOutputs(); 106 Status Load(uint32_t device_id); 107 108 private: 109 friend class Model; 110 111 std::shared_ptr<Graph> graph_; 112 std::shared_ptr<GraphImpl> executor_; 113 }; 114 115 class MS_API InputAndOutput { 116 public: 117 InputAndOutput(); 118 ~InputAndOutput() = default; 119 120 // no explicit 121 InputAndOutput(const MSTensor &); // NOLINT(runtime/explicit) 122 InputAndOutput(MSTensor &&); // NOLINT(runtime/explicit) 123 124 InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index); 125 126 int32_t GetIndex() const { return index_; } 127 void SetIndex(int32_t index) { index_ = index; } 128 129 private: 130 std::shared_ptr<CellBase> cell_; 131 std::vector<InputAndOutput> prev_; 132 int32_t index_; 133 }; 134 } // namespace mindspore 135 #endif // MINDSPORE_INCLUDE_API_CELL_H 136