/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_INCLUDE_API_CELL_H #define MINDSPORE_INCLUDE_API_CELL_H #include #include #include #include #include "include/api/status.h" #include "include/api/types.h" #include "include/api/graph.h" namespace mindspore { class InputAndOutput; class Context; using Input = InputAndOutput; using Output = InputAndOutput; class MS_API CellBase { public: CellBase() = default; virtual ~CellBase() = default; virtual std::vector Construct(const std::vector &inputs) { return {}; } virtual std::shared_ptr Clone() const = 0; virtual Status Run(const std::vector &inputs, std::vector *outputs) { return kSuccess; } std::vector operator()(const std::vector &inputs) const; }; template class MS_API Cell : public CellBase { public: virtual ~Cell() = default; std::shared_ptr Clone() const override { return std::make_shared(static_cast(*this)); } }; class MS_API GraphCell final : public Cell { public: class GraphImpl; GraphCell() = default; ~GraphCell() override = default; explicit GraphCell(const Graph &graph); explicit GraphCell(Graph &&graph); explicit GraphCell(const std::shared_ptr &graph); void SetContext(const std::shared_ptr &context); const std::shared_ptr &GetGraph() const { return graph_; } Status Run(const std::vector &inputs, std::vector *outputs) override; std::vector GetInputs(); std::vector GetOutputs(); Status Load(uint32_t device_id); private: friend class Model; std::shared_ptr graph_; std::shared_ptr executor_; }; class MS_API InputAndOutput { public: InputAndOutput(); ~InputAndOutput() = default; InputAndOutput(const std::shared_ptr &cell, const std::vector &prev, int32_t index); int32_t GetIndex() const { return index_; } void SetIndex(int32_t index) { index_ = index; } private: std::shared_ptr cell_; std::vector prev_; int32_t index_ = 0; }; } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_CELL_H