1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_INCLUDE_API_DELEGATE_API_H 18 #define MINDSPORE_INCLUDE_API_DELEGATE_API_H 19 20 #include <map> 21 #include <vector> 22 #include <memory> 23 #include "include/api/status.h" 24 #include "include/api/types.h" 25 namespace mindspore { 26 class AbstractDelegate { 27 public: 28 AbstractDelegate() = default; AbstractDelegate(const std::vector<mindspore::MSTensor> & inputs,const std::vector<mindspore::MSTensor> & outputs)29 AbstractDelegate(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs) 30 : inputs_(inputs), outputs_(outputs) {} 31 virtual ~AbstractDelegate() = default; 32 /// \brief Get the input tensors of DelegateModel. 33 /// 34 /// \return The input tensor vector of DelegateModel. inputs()35 const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; } 36 37 /// \brief Get the output tensors of DelegateModel. 38 /// 39 /// \return The ioutput tensor vector of DelegateModel. outputs()40 const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; } 41 42 protected: 43 std::vector<mindspore::MSTensor> inputs_; 44 std::vector<mindspore::MSTensor> outputs_; 45 }; 46 47 template <typename Graph, typename Node, typename Kernel> 48 class IDelegate : public AbstractDelegate { 49 public: 50 IDelegate() = default; IDelegate(const std::vector<mindspore::MSTensor> & inputs,const std::vector<mindspore::MSTensor> & outputs)51 IDelegate(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs) 52 : AbstractDelegate(inputs, outputs) {} 53 virtual ~IDelegate() = default; 54 55 /// \brief Replace the nodes in model with delegate nodes, delegate will create kernels by its delegate nodes. 56 /// 57 /// \param[in] graph The graph to be built. 58 virtual void ReplaceNodes(const std::shared_ptr<Graph> &graph) = 0; 59 60 /// \brief Check if this node is belong to this delegate. 61 /// 62 /// \param[in] node The node need to be checked. 63 /// 64 /// \return True if the node is belong to this delegate, otherwise return false. 65 virtual bool IsDelegateNode(const std::shared_ptr<Node> &node) = 0; 66 67 /// \brief Create a delegate kernel if the node is a delegate node. 68 /// 69 /// \param[in] node Define the delegate model to be built. 70 /// 71 /// \return The delegate kernel, if the node is not a delegate node, return nullptr. 72 virtual std::shared_ptr<Kernel> CreateKernel(const std::shared_ptr<Node> &node) = 0; 73 }; 74 } // namespace mindspore 75 #endif // MINDSPORE_INCLUDE_API_DELEGATE_API_H 76