• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2019,2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_GRAPH_INODE_H
25 #define ARM_COMPUTE_GRAPH_INODE_H
26 
#include "arm_compute/core/Error.h"
#include "arm_compute/graph/LayerDescriptors.h"
#include "arm_compute/graph/TensorDescriptor.h"
#include "arm_compute/graph/Types.h"

#include <list>
#include <memory>
#include <set>
#include <string>
#include <vector>
34 
35 namespace arm_compute
36 {
37 namespace graph
38 {
// Forward declarations: these types are only used through pointers, references
// or a friend declaration in this header, so their definitions are not needed.
class Edge;
class Graph;
class INodeVisitor;
class Tensor;
44 
45 /** Node interface */
46 class INode
47 {
48 public:
49     /** Constructor */
50     INode();
51     /** Destructor **/
52     virtual ~INode() = default;
53     /** Prevent instances of this class from being copied (As this class contains pointers) */
54     INode(const INode &) = delete;
55     /** Prevent instances of this class from being copy assigned (As this class contains pointers) */
56     INode &operator=(const INode &) = delete;
57     /** Allow instances of this class to be moved */
58     INode(INode &&) = default;
59     /** Allow instances of this class to be move assigned */
60     INode &operator=(INode &&) = default;
61     /** Validate node
62      *
63      * @return Status containing any errors
64      */
65     virtual Status validate() const;
66     /** Returns node's type
67      *
68      * @return Node's type
69      */
70     virtual NodeType type() const = 0;
71     /** Accepts a node visitor
72      *
73      * @param[in] v Visitor to accept
74      */
75     virtual void accept(INodeVisitor &v) = 0;
76     /** Forwards descriptor information to outputs if possible
77      *
78      * @return True if descriptor information could be forwarded otherwise false
79      */
80     virtual bool forward_descriptors() = 0;
81     /** Calculates output configuration
82      *
83      * @param[in] idx Output index to configure
84      *
85      * @return Output descriptor configuration
86      */
87     virtual TensorDescriptor configure_output(size_t idx) const = 0;
88     /** Returns node's name
89      *
90      * @return Node name
91      */
92     std::string name() const;
93     /** Returns node's ID
94      *
95      * @return Node's ID
96      */
97     NodeID id() const;
98     /** Returns node's Graph
99      *
100      * @return Node's graph
101      */
102     const Graph *graph() const;
103     /** Returns node's Graph
104      *
105      * @return Node's graph
106      */
107     Graph *graph();
108     /** Sets the graph that this node is registered to
109      *
110      * @param[in] g Back reference to graph
111      */
112     void set_graph(Graph *g);
113     /** Sets the node id
114      *
115      * @param[in] id Node id
116      */
117     void set_id(NodeID id);
118     /** Sets common node parameters
119      *
120      * @param[in] common_params Common node parameters to set
121      */
122     void set_common_node_parameters(NodeParams common_params);
123     /** Sets target preference
124      *
125      * @note This is not the target that the graph executor might choose, its just an indication
126      *
127      * @param[in] target Target preference
128      */
129     void set_requested_target(Target target);
130     /** Sets the final execution target
131      *
132      * @note GraphManager might change this target
133      *
134      * @param[in] target Final execution target
135      */
136     void set_assigned_target(Target target);
137     /** Sets the output tensor of at a given index
138      *
139      * @note All edges will get updated
140      *
141      * @param[in] tid Tensor ID
142      * @param[in] idx Output index
143      */
144     void set_output_tensor(TensorID tid, size_t idx);
145     /** Returns inputs of the node
146      *
147      * @return Inputs of the node
148      */
149     const std::vector<TensorID> &inputs() const;
150     /** Returns outputs of the node
151      *
152      * @return Outputs of the node
153      */
154     const std::vector<TensorID> &outputs() const;
155     /** Returns input edge set
156      *
157      * @return Set of input edges
158      */
159     const std::vector<EdgeID> &input_edges() const;
160     /** Returns output edge set
161      *
162      * @return Set of output edges
163      */
164     const std::set<EdgeID> &output_edges() const;
165     /** Returns the tensor ID of a given input of the node
166      *
167      * @note Precondition : idx should be a valid input index
168      *
169      * @param[in] idx Index of the node input
170      *
171      * @return TensorID of the requested input
172      */
173     TensorID input_id(size_t idx) const;
174     /** Returns the tensor ID of a given output of the node
175      *
176      * @note Precondition : idx should be a valid output index
177      *
178      * @param[in] idx Index of the node output
179      *
180      * @return TensorID of the requested output
181      */
182     TensorID output_id(size_t idx) const;
183     /** Returns the tensor of a given input of the node
184      *
185      * @note Precondition : idx should be a valid input index
186      *
187      * @param[in] idx Index of the node input
188      *
189      * @return Tensor of the requested input
190      */
191     Tensor *input(size_t idx) const;
192     /** Returns the tensor of a given output of the node
193      *
194      * @note Precondition : idx should be a valid output index
195      *
196      * @param[in] idx Index of the node output
197      *
198      * @return Tensor of the requested output
199      */
200     Tensor *output(size_t idx) const;
201     /** Returns the edge ID of a given input of the node
202      *
203      * @note Precondition : idx should be a valid input index
204      *
205      * @param[in] idx Index of the node input
206      *
207      * @return EdgeID of the requested input
208      */
209     EdgeID input_edge_id(size_t idx) const;
210     /** Returns the edge of a given input of the node
211      *
212      * @note Precondition : idx should be a valid input index
213      *
214      * @param[in] idx Index of the node input
215      *
216      * @return Edge of the requested input
217      */
218     Edge *input_edge(size_t idx) const;
219     /** Returns number of inputs of the node
220      *
221      * @return Number of inputs
222      */
223     size_t num_inputs() const;
224     /** Returns number of outputs of the node
225      *
226      * @return Number of outputs
227      */
228     size_t num_outputs() const;
229     /** Returns common node parameters
230      *
231      * @return Common node parameters
232      */
233     NodeParams common_node_params() const;
234     /** Returns requested target for this node
235      *
236      * @return Requested execution target
237      */
238     Target requested_target() const;
239     /** Returns assigned target for this node
240      *
241      * @return Assigned target of this node
242      */
243     Target assigned_target() const;
244     /** Post operator info list
245      *
246      * @return Post operator info list
247      */
248     const std::list<std::unique_ptr<ConvPostOpInfo>> &post_op_info_list() const;
249     /** Post operator info list
250      *
251      * @return Post operator info list
252      */
253     std::list<std::unique_ptr<ConvPostOpInfo>> &post_op_info_list();
254 
255 protected:
256     friend class Graph;
257 
258 protected:
259     Graph                                     *_graph;             /**< Backward reference to graph owning the node */
260     NodeID                                     _id;                /**< Node ID */
261     NodeParams                                 _common_params;     /**< Node common params */
262     std::vector<TensorID>                      _outputs;           /**< Output of the node */
263     std::vector<EdgeID>                        _input_edges;       /**< Inputs edge set */
264     std::set<EdgeID>                           _output_edges;      /**< Output edge set */
265     Target                                     _assigned_target;   /**< Assigned target by the Graph executor */
266     std::list<std::unique_ptr<ConvPostOpInfo>> _post_op_info_list; /**< Post operator info list */
267 };
268 } // namespace graph
269 } // namespace arm_compute
270 #endif /* ARM_COMPUTE_GRAPH_INODE_H */
271