• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_INCLUDE_API_DELEGATE_API_H
18 #define MINDSPORE_INCLUDE_API_DELEGATE_API_H
19 
20 #include <map>
21 #include <vector>
22 #include <memory>
23 #include "include/api/status.h"
24 #include "include/api/types.h"
25 namespace mindspore {
26 class AbstractDelegate {
27  public:
28   AbstractDelegate() = default;
AbstractDelegate(const std::vector<mindspore::MSTensor> & inputs,const std::vector<mindspore::MSTensor> & outputs)29   AbstractDelegate(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs)
30       : inputs_(inputs), outputs_(outputs) {}
31   virtual ~AbstractDelegate() = default;
32   /// \brief Get the input tensors of DelegateModel.
33   ///
34   /// \return The input tensor vector of DelegateModel.
inputs()35   const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }
36 
37   /// \brief Get the output tensors of DelegateModel.
38   ///
39   /// \return The ioutput tensor vector of DelegateModel.
outputs()40   const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }
41 
42  protected:
43   std::vector<mindspore::MSTensor> inputs_;
44   std::vector<mindspore::MSTensor> outputs_;
45 };
46 
47 template <typename Graph, typename Node, typename Kernel>
48 class IDelegate : public AbstractDelegate {
49  public:
50   IDelegate() = default;
IDelegate(const std::vector<mindspore::MSTensor> & inputs,const std::vector<mindspore::MSTensor> & outputs)51   IDelegate(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs)
52       : AbstractDelegate(inputs, outputs) {}
53   virtual ~IDelegate() = default;
54 
55   /// \brief Replace the nodes in model with delegate nodes, delegate will create kernels by its delegate nodes.
56   ///
57   /// \param[in] graph The graph to be built.
58   virtual void ReplaceNodes(const std::shared_ptr<Graph> &graph) = 0;
59 
60   /// \brief Check if this node is belong to this delegate.
61   ///
62   /// \param[in] node The node need to be checked.
63   ///
64   /// \return True if the node is belong to this delegate, otherwise return false.
65   virtual bool IsDelegateNode(const std::shared_ptr<Node> &node) = 0;
66 
67   /// \brief Create a delegate kernel if the node is a delegate node.
68   ///
69   /// \param[in] node Define the delegate model to be built.
70   ///
71   /// \return The delegate kernel, if the node is not a delegate node, return nullptr.
72   virtual std::shared_ptr<Kernel> CreateKernel(const std::shared_ptr<Node> &node) = 0;
73 };
74 }  // namespace mindspore
75 #endif  // MINDSPORE_INCLUDE_API_DELEGATE_API_H
76