// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>


enum xnn_status xnn_create_subgraph(
    uint32_t external_value_ids,
    uint32_t flags,
    xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values", external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  xnn_delete_subgraph(subgraph);
  return status;
}
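
// Example usage (illustrative sketch, not part of the original file; the
// external value count and flags below are hypothetical):
//
//   xnn_subgraph_t subgraph = NULL;
//   enum xnn_status status = xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph);
//   if (status != xnn_status_success) {
//     // handle the error
//   }
//   // ... define Values and Nodes on the subgraph, then run inference ...
//   xnn_delete_subgraph(subgraph);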


struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return values;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}
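
// Growth policy, worked through (informative note, not in the original file):
// min(capacity * 2, capacity + 512) doubles small arrays and grows large ones
// by 512 entries, while the outer max(..., capacity + 64) guarantees at least
// 64 new slots:
//
//   capacity = 0    -> new_capacity = max(min(0, 512), 64)       = 64
//   capacity = 100  -> new_capacity = max(min(200, 612), 164)    = 200
//   capacity = 1024 -> new_capacity = max(min(2048, 1536), 1088) = 1536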

void xnn_node_clear(struct xnn_node* node) {
  assert(node != NULL);
  assert(node->type != xnn_node_type_invalid);
  memset(node, 0, sizeof(struct xnn_node));
}

void xnn_value_clear(struct xnn_value* value) {
  assert(value != NULL);
  assert(value->type != xnn_value_type_invalid);
  memset(value, 0, sizeof(struct xnn_value));
}

struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return nodes;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}

#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW 1
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8
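
// How these flags combine (informative note, not in the original file): a
// Node may report more than one layout role at once. For example, Global
// Average Pooling 2D below returns
// XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC,
// i.e. it can either stay inside an NCHW cluster or terminate one by
// converting its result back to NHWC.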

uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
  switch (node->type) {
    case xnn_node_type_convolution_2d:
      // Supported cases:
      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
      if (node->params.convolution_2d.groups != 1) {
        return 0;
      }
      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1)
        {
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
          return 0;
        }
        if (node->params.convolution_2d.group_input_channels != 3) {
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
      }
      return 0;
    case xnn_node_type_depthwise_convolution_2d:
      // Supported cases:
      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
        return 0;
      }
      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
        return 0;
      }
      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.subsampling_height) {
        case 1:
        case 2:
          break;
        default:
          return 0;
      }
      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.kernel_height) {
        case 3:
          return node->params.depthwise_convolution_2d.input_padding_top == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        case 5:
          return node->params.depthwise_convolution_2d.input_padding_top == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_right == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
                 node->params.depthwise_convolution_2d.input_padding_left == 2 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
        default:
          return 0;
      }
    case xnn_node_type_depth_to_space:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_global_average_pooling_2d:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_add2:
    case xnn_node_type_multiply2:
      assert(node->num_inputs == 2);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
          subgraph->values[node->inputs[1]].shape.num_dims != 4)
      {
        return 0;
      }

      if (subgraph->values[node->inputs[0]].data != NULL) {
        // Check that the first input is representable as either a scalar or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      if (subgraph->values[node->inputs[1]].data != NULL) {
        // Check that the second input is representable as either a scalar or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
    case xnn_node_type_static_resize_bilinear_2d:
      return subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
             subgraph->values[node->inputs[0]].shape.dim[2] > 1 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    case xnn_node_type_abs:
    case xnn_node_type_bankers_rounding:
    case xnn_node_type_ceiling:
    case xnn_node_type_clamp:
    case xnn_node_type_elu:
    case xnn_node_type_floor:
    case xnn_node_type_hardswish:
    case xnn_node_type_leaky_relu:
    case xnn_node_type_negate:
    case xnn_node_type_sigmoid:
    case xnn_node_type_square:
      assert(node->num_inputs == 1);
      assert(node->num_outputs == 1);
      return subgraph->values[node->inputs[0]].shape.num_dims == 4 ? XNN_LAYOUT_FLAG_COMPATIBLE_NCHW : 0;
    default:
      return 0;
  }
}
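
// Example (informative note, not in the original file): a pointwise
// Convolution with a 1x1 kernel, stride 1, no padding, no dilation, and
// groups == 1 returns XNN_LAYOUT_FLAG_COMPATIBLE_NCHW, while a 3x3 stride-2
// Convolution with padding 1 on each side and 3 input channels returns
// XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW (it can consume NHWC input and produce
// NCHW output). Any other configuration returns 0 and blocks NCHW rewriting.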

void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
{
  // Convert parts of the subgraph to NCHW for sparse inference
  // Step 1: detect NCHW-compatible Nodes
  // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
  // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
  // Step 4: switch Values' layout to NCHW
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
      n, xnn_node_type_to_string(node->type),
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
  }

  // Run the Shiloach-Vishkin connected components algorithm: find all
  // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC Nodes and propagate them as cluster
  // leaders to their producer Nodes.
  bool update = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // If no NCHW2NHWC-compatible Nodes were found, the graph rewriting cannot
  // happen, so bail out early.
  if (!update) {
    return;
  }
  // Propagate the cluster leaders through the graph until the assignment
  // converges, i.e. until a full pass makes no further updates.
  while (update) {
    update = false;
    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
      struct xnn_node* node = &subgraph->nodes[n];
      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
        continue;
      }

      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
        continue;
      }

      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) != 0) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }
  // Check that all Values consumed by an NCHW-compatible cluster don't have NCHW-incompatible consumers
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      value->num_nchw_compatible_consumers += 1;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }
  // Evaluate whether it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights
  // - Disable sparse rewriting for clusters without 1x1 Convolutions (num_params == 0)
  //   or with fewer than 2/3 of the 1x1 Convolution filter weights being zero
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0);
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
      }
    }
  }
}
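
// Profitability threshold, worked through (informative note, not in the
// original file): the rewrite requires num_zeroes * 3 > num_params * 2, i.e.
// more than 2/3 of all 1x1 Convolution weights in the cluster must be zero.
// E.g. with num_params = 9000, at least 6001 zero weights are needed; at
// exactly 6000 (num_zeroes * 3 == 18000 == num_params * 2) sparse inference
// stays disabled.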

enum xnn_status xnn_subgraph_optimize(
    xnn_subgraph_t subgraph,
    uint32_t flags)
{
  // Initialize producer/consumer fields to safe defaults.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    value->producer = XNN_INVALID_NODE_ID;
    value->first_consumer = XNN_INVALID_NODE_ID;
    value->num_consumers = 0;
  }

  // Analyse Nodes' inputs and outputs and update Values' producer/consumer fields
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t input_id = node->inputs[i];
      assert(input_id < subgraph->num_values);

      if (subgraph->values[input_id].num_consumers++ == 0) {
        assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID);
        subgraph->values[input_id].first_consumer = n;
      }
    }

    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t output_id = node->outputs[o];
      assert(output_id < subgraph->num_values);

      assert(subgraph->values[output_id].producer == XNN_INVALID_NODE_ID);
      subgraph->values[output_id].producer = n;
    }
  }

  // Count an extra consumer for Values which are external outputs.
  // Remove unreferenced Values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if (value->flags & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
      value->num_consumers += 1;
    }
    if ((value->flags & XNN_VALUE_FLAG_EXTERNAL_INPUT) == 0 && value->num_consumers == 0) {
      xnn_value_clear(value);
    }
  }

  // Fuse Nodes where possible
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%"PRIu32" into upstream Node #%"PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad Node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
        const bool is_spatial_2d_zero_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0 &&
          producer->params.static_pad.padding_value == 0;
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32,
                consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.convolution_2d.input_padding_top += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left += producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32,
                consumer_id, producer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }
    }
  }

#if XNN_ENABLE_SPARSE
  if ((flags & XNN_FLAG_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
    xnn_subgraph_rewrite_for_nchw(subgraph);
  }
#endif

  return xnn_status_success;
}
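
// Example usage (illustrative sketch, not part of the original file; assumes
// the library was built with XNN_ENABLE_SPARSE and initialized with CHW
// optimizations enabled):
//
//   enum xnn_status status = xnn_subgraph_optimize(subgraph, XNN_FLAG_SPARSE_INFERENCE);
//   if (status != xnn_status_success) {
//     // handle the error
//   }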

enum xnn_status xnn_delete_subgraph(
    xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
    xnn_release_memory(subgraph->nodes);

    memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
    xnn_release_memory(subgraph->values);

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}