1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_INCLUDE_API_CELL_H 17 #define MINDSPORE_INCLUDE_API_CELL_H 18 #include <string> 19 #include <vector> 20 #include <map> 21 #include <memory> 22 #include "include/api/status.h" 23 #include "include/api/types.h" 24 #include "include/api/graph.h" 25 26 namespace mindspore { 27 class InputAndOutput; 28 class Context; 29 using Input = InputAndOutput; 30 using Output = InputAndOutput; 31 32 class MS_API CellBase { 33 public: 34 CellBase() = default; 35 virtual ~CellBase() = default; Construct(const std::vector<Input> & inputs)36 virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; } 37 virtual std::shared_ptr<CellBase> Clone() const = 0; Run(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs)38 virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; } 39 std::vector<Output> operator()(const std::vector<Input> &inputs) const; 40 }; 41 42 template <class T> 43 class MS_API Cell : public CellBase { 44 public: 45 virtual ~Cell() = default; Clone()46 std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); } 47 }; 48 49 class MS_API GraphCell final : public Cell<GraphCell> { 50 public: 51 class GraphImpl; 52 53 GraphCell() = default; 54 ~GraphCell() override = default; 55 56 explicit GraphCell(const Graph &graph); 57 explicit GraphCell(Graph &&graph); 58 explicit GraphCell(const std::shared_ptr<Graph> &graph); 59 60 void SetContext(const std::shared_ptr<Context> &context); GetGraph()61 const std::shared_ptr<Graph> &GetGraph() const { return graph_; } 62 Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; 63 std::vector<MSTensor> GetInputs(); 64 std::vector<MSTensor> GetOutputs(); 65 Status Load(uint32_t device_id); 66 67 private: 68 friend class Model; 69 70 std::shared_ptr<Graph> graph_; 71 std::shared_ptr<GraphImpl> executor_; 72 }; 73 74 class MS_API InputAndOutput { 75 public: 76 InputAndOutput(); 77 ~InputAndOutput() = default; 78 79 InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev, int32_t index); 80 GetIndex()81 int32_t GetIndex() const { return index_; } SetIndex(int32_t index)82 void SetIndex(int32_t index) { index_ = index; } 83 84 private: 85 std::shared_ptr<CellBase> cell_; 86 std::vector<InputAndOutput> prev_; 87 int32_t index_ = 0; 88 }; 89 } // namespace mindspore 90 #endif // MINDSPORE_INCLUDE_API_CELL_H 91