• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2019 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_GRAPH_INODE_H
25 #define ARM_COMPUTE_GRAPH_INODE_H
26 
#include "arm_compute/core/Error.h"
#include "arm_compute/graph/LayerDescriptors.h"
#include "arm_compute/graph/TensorDescriptor.h"
#include "arm_compute/graph/Types.h"

#include <set>
#include <string>
#include <vector>
33 
34 namespace arm_compute
35 {
36 namespace graph
37 {
// Graph types that INode only names through pointers, references or friendship,
// so a declaration (not a full definition) is sufficient here.
class Edge;
class Graph;
class INodeVisitor;
class Tensor;
43 
44 /** Node interface */
45 class INode
46 {
47 public:
48     /** Constructor */
49     INode();
50     /** Destructor **/
51     virtual ~INode() = default;
52     /** Prevent instances of this class from being copied (As this class contains pointers) */
53     INode(const INode &) = delete;
54     /** Prevent instances of this class from being copy assigned (As this class contains pointers) */
55     INode &operator=(const INode &) = delete;
56     /** Allow instances of this class to be moved */
57     INode(INode &&) = default;
58     /** Allow instances of this class to be move assigned */
59     INode &operator=(INode &&) = default;
60     /** Validate node
61      *
62      * @return Status containing any errors
63      */
64     virtual Status validate() const;
65     /** Returns node's type
66      *
67      * @return Node's type
68      */
69     virtual NodeType type() const = 0;
70     /** Accepts a node visitor
71      *
72      * @param[in] v Visitor to accept
73      */
74     virtual void accept(INodeVisitor &v) = 0;
75     /** Forwards descriptor information to outputs if possible
76      *
77      * @return True if descriptor information could be forwarded otherwise false
78      */
79     virtual bool forward_descriptors() = 0;
80     /** Calculates output configuration
81      *
82      * @param[in] idx Output index to configure
83      *
84      * @return Output descriptor configuration
85      */
86     virtual TensorDescriptor configure_output(size_t idx) const = 0;
87     /** Returns node's name
88      *
89      * @return Node name
90      */
91     std::string name() const;
92     /** Returns node's ID
93      *
94      * @return Node's ID
95      */
96     NodeID id() const;
97     /** Returns node's Graph
98      *
99      * @return Node's graph
100      */
101     const Graph *graph() const;
102     /** Returns node's Graph
103      *
104      * @return Node's graph
105      */
106     Graph *graph();
107     /** Sets the graph that this node is registered to
108      *
109      * @param[in] g Back reference to graph
110      */
111     void set_graph(Graph *g);
112     /** Sets the node id
113      *
114      * @param[in] id Node id
115      */
116     void set_id(NodeID id);
117     /** Sets common node parameters
118      *
119      * @param[in] common_params Common node parameters to set
120      */
121     void set_common_node_parameters(NodeParams common_params);
122     /** Sets target preference
123      *
124      * @note This is not the target that the graph executor might choose, its just an indication
125      *
126      * @param[in] target Target preference
127      */
128     void set_requested_target(Target target);
129     /** Sets the final execution target
130      *
131      * @note GraphManager might change this target
132      *
133      * @param[in] target Final execution target
134      */
135     void set_assigned_target(Target target);
136     /** Sets the output tensor of at a given index
137      *
138      * @note All edges will get updated
139      *
140      * @param[in] tid Tensor ID
141      * @param[in] idx Output index
142      */
143     void set_output_tensor(TensorID tid, size_t idx);
144     /** Returns inputs of the node
145      *
146      * @return Inputs of the node
147      */
148     const std::vector<TensorID> &inputs() const;
149     /** Returns outputs of the node
150      *
151      * @return Outputs of the node
152      */
153     const std::vector<TensorID> &outputs() const;
154     /** Returns input edge set
155      *
156      * @return Set of input edges
157      */
158     const std::vector<EdgeID> &input_edges() const;
159     /** Returns output edge set
160      *
161      * @return Set of output edges
162      */
163     const std::set<EdgeID> &output_edges() const;
164     /** Returns the tensor ID of a given input of the node
165      *
166      * @note Precondition : idx should be a valid input index
167      *
168      * @param[in] idx Index of the node input
169      *
170      * @return TensorID of the requested input
171      */
172     TensorID input_id(size_t idx) const;
173     /** Returns the tensor ID of a given output of the node
174      *
175      * @note Precondition : idx should be a valid output index
176      *
177      * @param[in] idx Index of the node output
178      *
179      * @return TensorID of the requested output
180      */
181     TensorID output_id(size_t idx) const;
182     /** Returns the tensor of a given input of the node
183      *
184      * @note Precondition : idx should be a valid input index
185      *
186      * @param[in] idx Index of the node input
187      *
188      * @return Tensor of the requested input
189      */
190     Tensor *input(size_t idx) const;
191     /** Returns the tensor of a given output of the node
192      *
193      * @note Precondition : idx should be a valid output index
194      *
195      * @param[in] idx Index of the node output
196      *
197      * @return Tensor of the requested output
198      */
199     Tensor *output(size_t idx) const;
200     /** Returns the edge ID of a given input of the node
201      *
202      * @note Precondition : idx should be a valid input index
203      *
204      * @param[in] idx Index of the node input
205      *
206      * @return EdgeID of the requested input
207      */
208     EdgeID input_edge_id(size_t idx) const;
209     /** Returns the edge of a given input of the node
210      *
211      * @note Precondition : idx should be a valid input index
212      *
213      * @param[in] idx Index of the node input
214      *
215      * @return Edge of the requested input
216      */
217     Edge *input_edge(size_t idx) const;
218     /** Returns number of inputs of the node
219      *
220      * @return Number of inputs
221      */
222     size_t num_inputs() const;
223     /** Returns number of outputs of the node
224      *
225      * @return Number of outputs
226      */
227     size_t num_outputs() const;
228     /** Returns common node parameters
229      *
230      * @return Common node parameters
231      */
232     NodeParams common_node_params() const;
233     /** Returns requested target for this node
234      *
235      * @return Requested execution target
236      */
237     Target requested_target() const;
238     /** Returns assigned target for this node
239      *
240      * @return Assigned target of this node
241      */
242     Target assigned_target() const;
243 
244 protected:
245     friend class Graph;
246 
247 protected:
248     Graph                *_graph;           /**< Backward reference to graph owning the node */
249     NodeID                _id;              /**< Node ID */
250     NodeParams            _common_params;   /**< Node common params */
251     std::vector<TensorID> _outputs;         /**< Output of the node */
252     std::vector<EdgeID>   _input_edges;     /**< Inputs edge set */
253     std::set<EdgeID>      _output_edges;    /**< Output edge set */
254     Target                _assigned_target; /**< Assigned target by the Graph executor */
255 };
256 } // namespace graph
257 } // namespace arm_compute
258 #endif /* ARM_COMPUTE_GRAPH_INODE_H */
259