1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/memory-planner.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
20
21
// Creates a runtime for the subgraph with default settings: no threadpool
// (single-threaded execution) and no extra flags. Delegates to
// xnn_create_runtime_v2, which does the actual work.
enum xnn_status xnn_create_runtime(
  xnn_subgraph_t subgraph,
  xnn_runtime_t* runtime_out)
{
  return xnn_create_runtime_v2(subgraph, /*threadpool=*/NULL, /*flags=*/0, runtime_out);
}
28
29 // Product of all shape dimensions
// Returns the product of all dimensions of the shape, i.e. the total number
// of elements in a tensor of this shape. An empty shape (num_dims == 0)
// yields 1 (the empty product).
static size_t product_all_dims(
  const struct xnn_shape shape[restrict XNN_MIN_ELEMENTS(1)])
{
  const size_t num_dims = shape->num_dims;
  size_t num_elements = 1;
  for (size_t dim_idx = 0; dim_idx < num_dims; dim_idx++) {
    num_elements *= shape->dim[dim_idx];
  }
  return num_elements;
}
39
40 // Product of all shape dimensions, except for the last (channel) one
// Returns the product of all dimensions of the shape except the last
// (channel) one, i.e. the effective batch size for channel-wise operators.
// Shapes with zero or one dimension yield 1 (the empty product).
static size_t product_non_channel_dims(
  const struct xnn_shape shape[restrict XNN_MIN_ELEMENTS(1)])
{
  const size_t num_dims = shape->num_dims;
  size_t num_elements = 1;
  // Multiply dims[0 .. num_dims-2]; the loop index is offset by one so the
  // bound never underflows for num_dims == 0.
  for (size_t dim_idx = 1; dim_idx < num_dims; dim_idx++) {
    num_elements *= shape->dim[dim_idx - 1];
  }
  return num_elements;
}
50
xnn_create_runtime_v2(xnn_subgraph_t subgraph,pthreadpool_t threadpool,uint32_t flags,xnn_runtime_t * runtime_out)51 enum xnn_status xnn_create_runtime_v2(
52 xnn_subgraph_t subgraph,
53 pthreadpool_t threadpool,
54 uint32_t flags,
55 xnn_runtime_t* runtime_out)
56 {
57 struct xnn_runtime* runtime = NULL;
58 enum xnn_status status = xnn_status_uninitialized;
59
60 if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
61 xnn_log_error("failed to create runtime: XNNPACK is not initialized");
62 goto error;
63 }
64
65 xnn_subgraph_optimize(subgraph, flags & XNN_FLAG_SPARSE_INFERENCE);
66
67 status = xnn_status_out_of_memory;
68
69 runtime = xnn_allocate_zero_memory(sizeof(struct xnn_runtime));
70 if (runtime == NULL) {
71 xnn_log_error("failed to allocate %zu bytes for runtime descriptor", sizeof(struct xnn_runtime));
72 goto error;
73 }
74
75 runtime->opdata = xnn_allocate_zero_memory(sizeof(struct xnn_operator_data) * subgraph->num_nodes);
76 if (runtime->opdata == NULL) {
77 xnn_log_error("failed to allocate %zu bytes for opdata descriptors",
78 sizeof(struct xnn_operator_data) * subgraph->num_nodes);
79 goto error;
80 }
81 runtime->num_ops = subgraph->num_nodes;
82
83 struct xnn_value* values = subgraph->values;
84 for (size_t i = 0; i < subgraph->num_nodes; i++) {
85 const struct xnn_node* node = subgraph->nodes + i;
86 switch (node->type) {
87 case xnn_node_type_invalid:
88 // Node was fused
89 continue;
90 case xnn_node_type_abs:
91 status = xnn_create_abs_nc_f32(
92 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
93 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
94 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
95 node->flags,
96 &runtime->opdata[i].operator_object);
97 if (status != xnn_status_success) {
98 goto error;
99 }
100 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
101 runtime->opdata[i].inputs[0] = node->inputs[0];
102 runtime->opdata[i].outputs[0] = node->outputs[0];
103 break;
104 case xnn_node_type_add2:
105 status = xnn_create_add_nd_f32(
106 node->activation.output_min,
107 node->activation.output_max,
108 node->flags,
109 &runtime->opdata[i].operator_object);
110 if (status != xnn_status_success) {
111 goto error;
112 }
113 runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
114 runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
115 if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
116 assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
117 assert(values[node->inputs[1]].layout == xnn_layout_type_nchw);
118 runtime->opdata[i].shape1.dim[0] = values[node->inputs[0]].shape.dim[0];
119 runtime->opdata[i].shape1.dim[1] = values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
120 if (values[node->inputs[0]].shape.num_dims > 2) {
121 memcpy(&runtime->opdata[i].shape1.dim[2], &values[node->inputs[0]].shape.dim[1], (values[node->inputs[0]].shape.num_dims - 2) * sizeof(size_t));
122 }
123 runtime->opdata[i].shape2.dim[0] = values[node->inputs[1]].shape.dim[0];
124 runtime->opdata[i].shape2.dim[1] = values[node->inputs[1]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
125 if (values[node->inputs[0]].shape.num_dims > 2) {
126 memcpy(&runtime->opdata[i].shape2.dim[2], &values[node->inputs[1]].shape.dim[1], (values[node->inputs[1]].shape.num_dims - 2) * sizeof(size_t));
127 }
128 } else {
129 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
130 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
131 assert(values[node->inputs[1]].layout == xnn_layout_type_nhwc);
132 memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
133 memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
134 }
135 runtime->opdata[i].inputs[0] = node->inputs[0];
136 runtime->opdata[i].inputs[1] = node->inputs[1];
137 runtime->opdata[i].outputs[0] = node->outputs[0];
138 break;
139 case xnn_node_type_argmax_pooling_2d:
140 status = xnn_create_argmax_pooling2d_nhwc_f32(
141 node->params.pooling_2d.padding_top,
142 node->params.pooling_2d.padding_right,
143 node->params.pooling_2d.padding_bottom,
144 node->params.pooling_2d.padding_left,
145 node->params.pooling_2d.pooling_height,
146 node->params.pooling_2d.pooling_width,
147 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
148 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
149 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
150 node->flags,
151 &runtime->opdata[i].operator_object);
152 if (status != xnn_status_success) {
153 goto error;
154 }
155 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
156 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
157 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
158 runtime->opdata[i].inputs[0] = node->inputs[0];
159 runtime->opdata[i].outputs[0] = node->outputs[0];
160 runtime->opdata[i].outputs[1] = node->outputs[1];
161 break;
162 case xnn_node_type_average_pooling_2d:
163 status = xnn_create_average_pooling2d_nhwc_f32(
164 node->params.pooling_2d.padding_top,
165 node->params.pooling_2d.padding_right,
166 node->params.pooling_2d.padding_bottom,
167 node->params.pooling_2d.padding_left,
168 node->params.pooling_2d.pooling_height,
169 node->params.pooling_2d.pooling_width,
170 node->params.pooling_2d.stride_height,
171 node->params.pooling_2d.stride_width,
172 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
173 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
174 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
175 node->activation.output_min,
176 node->activation.output_max,
177 node->flags,
178 &runtime->opdata[i].operator_object);
179 if (status != xnn_status_success) {
180 goto error;
181 }
182 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
183 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
184 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
185 runtime->opdata[i].inputs[0] = node->inputs[0];
186 runtime->opdata[i].outputs[0] = node->outputs[0];
187 break;
188 case xnn_node_type_bankers_rounding:
189 status = xnn_create_bankers_rounding_nc_f32(
190 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
191 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
192 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
193 node->flags,
194 &runtime->opdata[i].operator_object);
195 if (status != xnn_status_success) {
196 goto error;
197 }
198 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
199 runtime->opdata[i].inputs[0] = node->inputs[0];
200 runtime->opdata[i].outputs[0] = node->outputs[0];
201 break;
202 case xnn_node_type_ceiling:
203 status = xnn_create_ceiling_nc_f32(
204 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
205 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
206 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
207 node->flags,
208 &runtime->opdata[i].operator_object);
209 if (status != xnn_status_success) {
210 goto error;
211 }
212 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
213 runtime->opdata[i].inputs[0] = node->inputs[0];
214 runtime->opdata[i].outputs[0] = node->outputs[0];
215 break;
216 case xnn_node_type_convolution_2d:
217 assert(values[node->inputs[1]].data != NULL);
218 assert(values[node->inputs[2]].data != NULL);
219 if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
220 status = xnn_create_convolution2d_nchw_f32(
221 node->params.convolution_2d.input_padding_top,
222 node->params.convolution_2d.input_padding_right,
223 node->params.convolution_2d.input_padding_bottom,
224 node->params.convolution_2d.input_padding_left,
225 node->params.convolution_2d.kernel_height,
226 node->params.convolution_2d.kernel_width,
227 node->params.convolution_2d.subsampling_height,
228 node->params.convolution_2d.subsampling_width,
229 node->params.convolution_2d.dilation_height,
230 node->params.convolution_2d.dilation_width,
231 node->params.convolution_2d.groups,
232 node->params.convolution_2d.group_input_channels,
233 node->params.convolution_2d.group_output_channels,
234 node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
235 node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
236 values[node->inputs[1]].data,
237 values[node->inputs[2]].data,
238 node->activation.output_min,
239 node->activation.output_max,
240 node->flags | (values[node->inputs[0]].layout == xnn_layout_type_nhwc ? XNN_FLAG_INPUT_NHWC : 0),
241 &runtime->opdata[i].operator_object);
242 } else {
243 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
244 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
245 status = xnn_create_convolution2d_nhwc_f32(
246 node->params.convolution_2d.input_padding_top,
247 node->params.convolution_2d.input_padding_right,
248 node->params.convolution_2d.input_padding_bottom,
249 node->params.convolution_2d.input_padding_left,
250 node->params.convolution_2d.kernel_height,
251 node->params.convolution_2d.kernel_width,
252 node->params.convolution_2d.subsampling_height,
253 node->params.convolution_2d.subsampling_width,
254 node->params.convolution_2d.dilation_height,
255 node->params.convolution_2d.dilation_width,
256 node->params.convolution_2d.groups,
257 node->params.convolution_2d.group_input_channels,
258 node->params.convolution_2d.group_output_channels,
259 node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
260 node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
261 values[node->inputs[1]].data,
262 values[node->inputs[2]].data,
263 node->activation.output_min,
264 node->activation.output_max,
265 node->flags,
266 &runtime->opdata[i].operator_object);
267 }
268 if (status != xnn_status_success) {
269 goto error;
270 }
271 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
272 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
273 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
274 runtime->opdata[i].inputs[0] = node->inputs[0];
275 runtime->opdata[i].outputs[0] = node->outputs[0];
276 break;
277 case xnn_node_type_clamp:
278 status = xnn_create_clamp_nc_f32(
279 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
280 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
281 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
282 node->activation.output_min,
283 node->activation.output_max,
284 node->flags,
285 &runtime->opdata[i].operator_object);
286 if (status != xnn_status_success) {
287 goto error;
288 }
289 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
290 runtime->opdata[i].inputs[0] = node->inputs[0];
291 runtime->opdata[i].outputs[0] = node->outputs[0];
292 break;
293 case xnn_node_type_deconvolution_2d:
294 assert(values[node->inputs[1]].data != NULL);
295 assert(values[node->inputs[2]].data != NULL);
296 status = xnn_create_deconvolution2d_nhwc_f32(
297 node->params.deconvolution_2d.padding_top,
298 node->params.deconvolution_2d.padding_right,
299 node->params.deconvolution_2d.padding_bottom,
300 node->params.deconvolution_2d.padding_left,
301 node->params.deconvolution_2d.kernel_height,
302 node->params.deconvolution_2d.kernel_width,
303 node->params.deconvolution_2d.upsampling_height,
304 node->params.deconvolution_2d.upsampling_width,
305 node->params.deconvolution_2d.dilation_height,
306 node->params.deconvolution_2d.dilation_width,
307 node->params.deconvolution_2d.groups,
308 node->params.deconvolution_2d.group_input_channels,
309 node->params.deconvolution_2d.group_output_channels,
310 node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */,
311 node->params.deconvolution_2d.group_output_channels * node->params.deconvolution_2d.groups /* output_pixel_stride */,
312 values[node->inputs[1]].data,
313 values[node->inputs[2]].data,
314 node->activation.output_min,
315 node->activation.output_max,
316 node->flags,
317 &runtime->opdata[i].operator_object);
318 if (status != xnn_status_success) {
319 goto error;
320 }
321 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
322 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
323 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
324 runtime->opdata[i].adjustment_height = node->params.deconvolution_2d.adjustment_height;
325 runtime->opdata[i].adjustment_width = node->params.deconvolution_2d.adjustment_width;
326 runtime->opdata[i].inputs[0] = node->inputs[0];
327 runtime->opdata[i].outputs[0] = node->outputs[0];
328 break;
329 case xnn_node_type_depthwise_convolution_2d:
330 assert(values[node->inputs[1]].data != NULL);
331 assert(values[node->inputs[2]].data != NULL);
332 if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
333 assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
334 status = xnn_create_convolution2d_nchw_f32(
335 node->params.depthwise_convolution_2d.input_padding_top,
336 node->params.depthwise_convolution_2d.input_padding_right,
337 node->params.depthwise_convolution_2d.input_padding_bottom,
338 node->params.depthwise_convolution_2d.input_padding_left,
339 node->params.depthwise_convolution_2d.kernel_height,
340 node->params.depthwise_convolution_2d.kernel_width,
341 node->params.depthwise_convolution_2d.subsampling_height,
342 node->params.depthwise_convolution_2d.subsampling_width,
343 node->params.depthwise_convolution_2d.dilation_height,
344 node->params.depthwise_convolution_2d.dilation_width,
345 node->params.depthwise_convolution_2d.input_channels /* groups */,
346 1 /* group_input_channels */,
347 node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
348 node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
349 node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
350 values[node->inputs[1]].data,
351 values[node->inputs[2]].data,
352 node->activation.output_min,
353 node->activation.output_max,
354 node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
355 &runtime->opdata[i].operator_object);
356 } else {
357 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
358 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
359 status = xnn_create_convolution2d_nhwc_f32(
360 node->params.depthwise_convolution_2d.input_padding_top,
361 node->params.depthwise_convolution_2d.input_padding_right,
362 node->params.depthwise_convolution_2d.input_padding_bottom,
363 node->params.depthwise_convolution_2d.input_padding_left,
364 node->params.depthwise_convolution_2d.kernel_height,
365 node->params.depthwise_convolution_2d.kernel_width,
366 node->params.depthwise_convolution_2d.subsampling_height,
367 node->params.depthwise_convolution_2d.subsampling_width,
368 node->params.depthwise_convolution_2d.dilation_height,
369 node->params.depthwise_convolution_2d.dilation_width,
370 node->params.depthwise_convolution_2d.input_channels /* groups */,
371 1 /* group_input_channels */,
372 node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
373 node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
374 node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
375 values[node->inputs[1]].data,
376 values[node->inputs[2]].data,
377 node->activation.output_min,
378 node->activation.output_max,
379 node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
380 &runtime->opdata[i].operator_object);
381 }
382 if (status != xnn_status_success) {
383 goto error;
384 }
385 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
386 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
387 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
388
389 runtime->opdata[i].inputs[0] = node->inputs[0];
390 runtime->opdata[i].outputs[0] = node->outputs[0];
391 break;
392 case xnn_node_type_depth_to_space:
393 status = xnn_status_unsupported_parameter;
394 if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
395 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
396 status = xnn_create_depth_to_space_nchw2nhwc_x32(
397 values[node->outputs[0]].shape.dim[values[node->outputs[0]].shape.num_dims - 1] /* output channels */,
398 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
399 values[node->outputs[0]].shape.dim[values[node->outputs[0]].shape.num_dims - 1] /* output stride */,
400 node->params.depth_to_space.block_size,
401 node->flags,
402 &runtime->opdata[i].operator_object);
403 } else {
404 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
405 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
406 status = xnn_create_depth_to_space_nhwc_x32(
407 values[node->outputs[0]].shape.dim[values[node->outputs[0]].shape.num_dims - 1] /* output channels */,
408 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
409 values[node->outputs[0]].shape.dim[values[node->outputs[0]].shape.num_dims - 1] /* output stride */,
410 node->params.depth_to_space.block_size,
411 node->flags,
412 &runtime->opdata[i].operator_object);
413 }
414 if (status != xnn_status_success) {
415 goto error;
416 }
417 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
418 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
419 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
420 runtime->opdata[i].output_height = values[node->outputs[0]].shape.dim[1];
421 runtime->opdata[i].output_width = values[node->outputs[0]].shape.dim[2];
422 runtime->opdata[i].inputs[0] = node->inputs[0];
423 runtime->opdata[i].outputs[0] = node->outputs[0];
424 break;
425 case xnn_node_type_divide:
426 status = xnn_create_divide_nd_f32(
427 node->activation.output_min,
428 node->activation.output_max,
429 node->flags,
430 &runtime->opdata[i].operator_object);
431 if (status != xnn_status_success) {
432 goto error;
433 }
434 runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
435 runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
436 memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
437 memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
438 runtime->opdata[i].inputs[0] = node->inputs[0];
439 runtime->opdata[i].inputs[1] = node->inputs[1];
440 runtime->opdata[i].outputs[0] = node->outputs[0];
441 break;
442 case xnn_node_type_elu:
443 status = xnn_create_elu_nc_f32(
444 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
445 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
446 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
447 node->params.elu.alpha,
448 node->flags,
449 &runtime->opdata[i].operator_object);
450 if (status != xnn_status_success) {
451 goto error;
452 }
453 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
454 runtime->opdata[i].inputs[0] = node->inputs[0];
455 runtime->opdata[i].outputs[0] = node->outputs[0];
456 break;
457 case xnn_node_type_fully_connected:
458 {
459 const size_t num_input_elements = product_all_dims(&values[node->inputs[0]].shape);
460 const size_t output_channels = values[node->inputs[1]].shape.dim[0];
461 const size_t input_channels = values[node->inputs[1]].shape.dim[1];
462 status = xnn_create_fully_connected_nc_f32(
463 input_channels,
464 output_channels,
465 input_channels /* input stride */,
466 output_channels /* output stride */,
467 values[node->inputs[1]].data,
468 values[node->inputs[2]].data,
469 node->activation.output_min,
470 node->activation.output_max,
471 0 /* flags */,
472 &runtime->opdata[i].operator_object);
473 if (status != xnn_status_success) {
474 goto error;
475 }
476 runtime->opdata[i].batch_size = num_input_elements / input_channels;
477 runtime->opdata[i].inputs[0] = node->inputs[0];
478 runtime->opdata[i].outputs[0] = node->outputs[0];
479 break;
480 }
481 case xnn_node_type_floor:
482 status = xnn_create_floor_nc_f32(
483 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
484 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
485 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
486 node->flags,
487 &runtime->opdata[i].operator_object);
488 if (status != xnn_status_success) {
489 goto error;
490 }
491 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
492 runtime->opdata[i].inputs[0] = node->inputs[0];
493 runtime->opdata[i].outputs[0] = node->outputs[0];
494 break;
495 case xnn_node_type_global_average_pooling_2d:
496 if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
497 status = xnn_create_global_average_pooling_ncw_f32(
498 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
499 node->activation.output_min,
500 node->activation.output_max,
501 node->flags,
502 &runtime->opdata[i].operator_object);
503 } else {
504 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
505 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
506 status = xnn_create_global_average_pooling_nwc_f32(
507 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
508 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
509 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
510 node->activation.output_min,
511 node->activation.output_max,
512 node->flags,
513 &runtime->opdata[i].operator_object);
514 }
515 if (status != xnn_status_success) {
516 goto error;
517 }
518 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
519 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[1] * values[node->inputs[0]].shape.dim[2];
520 runtime->opdata[i].inputs[0] = node->inputs[0];
521 runtime->opdata[i].outputs[0] = node->outputs[0];
522 break;
523 case xnn_node_type_hardswish:
524 status = xnn_create_hardswish_nc_f32(
525 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
526 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
527 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
528 node->flags,
529 &runtime->opdata[i].operator_object);
530 if (status != xnn_status_success) {
531 goto error;
532 }
533 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
534 runtime->opdata[i].inputs[0] = node->inputs[0];
535 runtime->opdata[i].outputs[0] = node->outputs[0];
536 break;
537 case xnn_node_type_leaky_relu:
538 status = xnn_create_leaky_relu_nc_f32(
539 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
540 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
541 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
542 node->params.leaky_relu.negative_slope,
543 node->flags,
544 &runtime->opdata[i].operator_object);
545 if (status != xnn_status_success) {
546 goto error;
547 }
548 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
549 runtime->opdata[i].inputs[0] = node->inputs[0];
550 runtime->opdata[i].outputs[0] = node->outputs[0];
551 break;
552 case xnn_node_type_max_pooling_2d:
553 status = xnn_create_max_pooling2d_nhwc_f32(
554 node->params.pooling_2d.padding_top,
555 node->params.pooling_2d.padding_right,
556 node->params.pooling_2d.padding_bottom,
557 node->params.pooling_2d.padding_left,
558 node->params.pooling_2d.pooling_height,
559 node->params.pooling_2d.pooling_width,
560 node->params.pooling_2d.stride_height,
561 node->params.pooling_2d.stride_width,
562 node->params.pooling_2d.dilation_height,
563 node->params.pooling_2d.dilation_width,
564 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
565 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
566 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
567 node->activation.output_min,
568 node->activation.output_max,
569 node->flags,
570 &runtime->opdata[i].operator_object);
571 if (status != xnn_status_success) {
572 goto error;
573 }
574 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
575 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
576 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
577 runtime->opdata[i].inputs[0] = node->inputs[0];
578 runtime->opdata[i].outputs[0] = node->outputs[0];
579 break;
580 case xnn_node_type_maximum2:
581 status = xnn_create_maximum_nd_f32(
582 node->flags,
583 &runtime->opdata[i].operator_object);
584 if (status != xnn_status_success) {
585 goto error;
586 }
587 runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
588 runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
589 memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
590 memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
591 runtime->opdata[i].inputs[0] = node->inputs[0];
592 runtime->opdata[i].inputs[1] = node->inputs[1];
593 runtime->opdata[i].outputs[0] = node->outputs[0];
594 break;
595 case xnn_node_type_minimum2:
596 status = xnn_create_minimum_nd_f32(
597 node->flags,
598 &runtime->opdata[i].operator_object);
599 if (status != xnn_status_success) {
600 goto error;
601 }
602 runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
603 runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
604 memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
605 memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
606 runtime->opdata[i].inputs[0] = node->inputs[0];
607 runtime->opdata[i].inputs[1] = node->inputs[1];
608 runtime->opdata[i].outputs[0] = node->outputs[0];
609 break;
610 case xnn_node_type_multiply2:
611 status = xnn_create_multiply_nd_f32(
612 node->activation.output_min,
613 node->activation.output_max,
614 node->flags,
615 &runtime->opdata[i].operator_object);
616 if (status != xnn_status_success) {
617 goto error;
618 }
619 runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
620 runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
621 if (values[node->outputs[0]].layout == xnn_layout_type_nchw) {
622 assert(values[node->inputs[0]].layout == xnn_layout_type_nchw);
623 assert(values[node->inputs[1]].layout == xnn_layout_type_nchw);
624 runtime->opdata[i].shape1.dim[0] = values[node->inputs[0]].shape.dim[0];
625 runtime->opdata[i].shape1.dim[1] = values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
626 if (values[node->inputs[0]].shape.num_dims > 2) {
627 memcpy(&runtime->opdata[i].shape1.dim[2], &values[node->inputs[0]].shape.dim[1], (values[node->inputs[0]].shape.num_dims - 2) * sizeof(size_t));
628 }
629 runtime->opdata[i].shape2.dim[0] = values[node->inputs[1]].shape.dim[0];
630 runtime->opdata[i].shape2.dim[1] = values[node->inputs[1]].shape.dim[values[node->inputs[0]].shape.num_dims - 1];
631 if (values[node->inputs[0]].shape.num_dims > 2) {
632 memcpy(&runtime->opdata[i].shape2.dim[2], &values[node->inputs[1]].shape.dim[1], (values[node->inputs[1]].shape.num_dims - 2) * sizeof(size_t));
633 }
634 } else {
635 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
636 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
637 assert(values[node->inputs[1]].layout == xnn_layout_type_nhwc);
638 memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
639 memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
640 }
641 runtime->opdata[i].inputs[0] = node->inputs[0];
642 runtime->opdata[i].inputs[1] = node->inputs[1];
643 runtime->opdata[i].outputs[0] = node->outputs[0];
644 break;
645 case xnn_node_type_negate:
646 status = xnn_create_negate_nc_f32(
647 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
648 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
649 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
650 node->flags,
651 &runtime->opdata[i].operator_object);
652 if (status != xnn_status_success) {
653 goto error;
654 }
655 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
656 runtime->opdata[i].inputs[0] = node->inputs[0];
657 runtime->opdata[i].outputs[0] = node->outputs[0];
658 break;
659 case xnn_node_type_prelu:
660 status = xnn_create_prelu_nc_f32(
661 values[node->inputs[1]].shape.dim[values[node->inputs[1]].shape.num_dims - 1] /* channels */,
662 values[node->inputs[1]].shape.dim[values[node->inputs[1]].shape.num_dims - 1] /* input stride */,
663 values[node->inputs[1]].shape.dim[values[node->inputs[1]].shape.num_dims - 1] /* output stride */,
664 values[node->inputs[1]].data /* negative slope */,
665 node->flags,
666 &runtime->opdata[i].operator_object);
667 if (status != xnn_status_success) {
668 goto error;
669 }
670 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
671 runtime->opdata[i].inputs[0] = node->inputs[0];
672 runtime->opdata[i].outputs[0] = node->outputs[0];
673 break;
674 case xnn_node_type_sigmoid:
675 status = xnn_create_sigmoid_nc_f32(
676 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
677 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
678 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
679 node->flags,
680 &runtime->opdata[i].operator_object);
681 if (status != xnn_status_success) {
682 goto error;
683 }
684 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
685 runtime->opdata[i].inputs[0] = node->inputs[0];
686 runtime->opdata[i].outputs[0] = node->outputs[0];
687 break;
688 case xnn_node_type_softmax:
689 status = xnn_create_softmax_nc_f32(
690 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
691 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
692 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
693 node->flags,
694 &runtime->opdata[i].operator_object);
695 if (status != xnn_status_success) {
696 goto error;
697 }
698 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
699 runtime->opdata[i].inputs[0] = node->inputs[0];
700 runtime->opdata[i].outputs[0] = node->outputs[0];
701 break;
702 case xnn_node_type_static_constant_pad:
703 status = xnn_create_constant_pad_nd_x32(
704 &node->params.static_pad.padding_value,
705 node->flags,
706 &runtime->opdata[i].operator_object);
707 if (status != xnn_status_success) {
708 goto error;
709 }
710 runtime->opdata[i].shape1 = values[node->inputs[0]].shape;
711 memcpy(runtime->opdata[i].pre_paddings, node->params.static_pad.pre_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
712 memcpy(runtime->opdata[i].post_paddings, node->params.static_pad.post_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
713 runtime->opdata[i].inputs[0] = node->inputs[0];
714 runtime->opdata[i].outputs[0] = node->outputs[0];
715 break;
716 case xnn_node_type_static_reshape:
717 status = xnn_create_copy_nc_x32(
718 1 /* channels */,
719 1 /* input stride */,
720 1 /* output stride */,
721 node->flags,
722 &runtime->opdata[i].operator_object);
723 if (status != xnn_status_success) {
724 goto error;
725 }
726 runtime->opdata[i].batch_size = product_all_dims(&values[node->inputs[0]].shape);
727 runtime->opdata[i].inputs[0] = node->inputs[0];
728 runtime->opdata[i].outputs[0] = node->outputs[0];
729 break;
730 case xnn_node_type_static_resize_bilinear_2d:
731 if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
732 status = xnn_create_resize_bilinear2d_nchw_f32(
733 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
734 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
735 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
736 node->flags,
737 &runtime->opdata[i].operator_object);
738 } else {
739 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
740 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
741 status = xnn_create_resize_bilinear2d_nhwc_f32(
742 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
743 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
744 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
745 node->flags,
746 &runtime->opdata[i].operator_object);
747 }
748 if (status != xnn_status_success) {
749 goto error;
750 }
751 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
752 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
753 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
754 runtime->opdata[i].output_height = values[node->outputs[0]].shape.dim[1];
755 runtime->opdata[i].output_width = values[node->outputs[0]].shape.dim[2];
756 runtime->opdata[i].inputs[0] = node->inputs[0];
757 runtime->opdata[i].outputs[0] = node->outputs[0];
758 break;
759 case xnn_node_type_square:
760 status = xnn_create_square_nc_f32(
761 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
762 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
763 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
764 node->flags,
765 &runtime->opdata[i].operator_object);
766 if (status != xnn_status_success) {
767 goto error;
768 }
769 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
770 runtime->opdata[i].inputs[0] = node->inputs[0];
771 runtime->opdata[i].outputs[0] = node->outputs[0];
772 break;
773 case xnn_node_type_square_root:
774 status = xnn_create_square_root_nc_f32(
775 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
776 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
777 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
778 node->flags,
779 &runtime->opdata[i].operator_object);
780 if (status != xnn_status_success) {
781 goto error;
782 }
783 runtime->opdata[i].batch_size = product_non_channel_dims(&values[node->inputs[0]].shape);
784 runtime->opdata[i].inputs[0] = node->inputs[0];
785 runtime->opdata[i].outputs[0] = node->outputs[0];
786 break;
787 case xnn_node_type_squared_difference:
788 status = xnn_create_squared_difference_nd_f32(
789 node->flags,
790 &runtime->opdata[i].operator_object);
791 if (status != xnn_status_success) {
792 goto error;
793 }
794 runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
795 runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
796 memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
797 memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
798 runtime->opdata[i].inputs[0] = node->inputs[0];
799 runtime->opdata[i].inputs[1] = node->inputs[1];
800 runtime->opdata[i].outputs[0] = node->outputs[0];
801 break;
802 case xnn_node_type_subtract:
803 status = xnn_create_subtract_nd_f32(
804 node->activation.output_min,
805 node->activation.output_max,
806 node->flags,
807 &runtime->opdata[i].operator_object);
808 if (status != xnn_status_success) {
809 goto error;
810 }
811 runtime->opdata[i].shape1.num_dims = values[node->inputs[0]].shape.num_dims;
812 runtime->opdata[i].shape2.num_dims = values[node->inputs[1]].shape.num_dims;
813 memcpy(runtime->opdata[i].shape1.dim, values[node->inputs[0]].shape.dim, values[node->inputs[0]].shape.num_dims * sizeof(size_t));
814 memcpy(runtime->opdata[i].shape2.dim, values[node->inputs[1]].shape.dim, values[node->inputs[1]].shape.num_dims * sizeof(size_t));
815 runtime->opdata[i].inputs[0] = node->inputs[0];
816 runtime->opdata[i].inputs[1] = node->inputs[1];
817 runtime->opdata[i].outputs[0] = node->outputs[0];
818 break;
819 case xnn_node_type_unpooling_2d:
820 status = xnn_create_unpooling2d_nhwc_x32(
821 node->params.pooling_2d.padding_top,
822 node->params.pooling_2d.padding_right,
823 node->params.pooling_2d.padding_bottom,
824 node->params.pooling_2d.padding_left,
825 node->params.pooling_2d.pooling_height,
826 node->params.pooling_2d.pooling_width,
827 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* channels */,
828 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* input stride */,
829 values[node->inputs[0]].shape.dim[values[node->inputs[0]].shape.num_dims - 1] /* output stride */,
830 node->flags,
831 &runtime->opdata[i].operator_object);
832 if (status != xnn_status_success) {
833 goto error;
834 }
835 runtime->opdata[i].batch_size = values[node->inputs[0]].shape.dim[0];
836 runtime->opdata[i].input_height = values[node->inputs[0]].shape.dim[1];
837 runtime->opdata[i].input_width = values[node->inputs[0]].shape.dim[2];
838 runtime->opdata[i].inputs[0] = node->inputs[0];
839 runtime->opdata[i].inputs[1] = node->inputs[1];
840 runtime->opdata[i].outputs[0] = node->outputs[0];
841 break;
842 }
843 }
844
845 runtime->blobs = xnn_allocate_zero_memory(sizeof(struct xnn_blob) * subgraph->num_values);
846 if (runtime->blobs == NULL) {
847 xnn_log_error("failed to allocate %zu bytes for blob descriptors",
848 sizeof(struct xnn_blob) * subgraph->num_values);
849 goto error;
850 }
851 runtime->num_blobs = subgraph->num_values;
852
853 struct xnn_value_allocation_tracker mem_alloc_tracker;
854 xnn_init_value_allocation_tracker(&mem_alloc_tracker, subgraph);
855
856 for (uint32_t i = 0; i < subgraph->num_values; i++) {
857 const struct xnn_value* value = &subgraph->values[i];
858 struct xnn_blob* blob = &runtime->blobs[i];
859 if (value->datatype != xnn_datatype_invalid && value->type == xnn_value_type_dense_tensor) {
860 blob->size = xnn_tensor_get_size(subgraph, i);
861 blob->data = (void*) value->data;
862 if (blob->data == NULL) {
863 if ((value->flags & (XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT)) == 0) {
864 // Value is purely internal to the runtime, and must be allocated in its workspace.
865 xnn_add_value_allocation_tracker(&mem_alloc_tracker, i, round_up_po2(blob->size, XNN_EXTRA_BYTES));
866 } else {
867 // Value is non-static and external to the runtime: must be specified via a call to xnn_setup_runtime.
868 blob->external = true;
869 }
870 }
871 }
872 }
873 xnn_plan_value_allocation_tracker(&mem_alloc_tracker);
874
875 if (mem_alloc_tracker.mem_arena_size != 0) {
876 // XNN_EXTRA_BYTES ensures that out-of-bound reads of intermediate values don't segfault.
877 const size_t mem_arena_size = mem_alloc_tracker.mem_arena_size + XNN_EXTRA_BYTES;
878 runtime->workspace = xnn_allocate_simd_memory(mem_arena_size);
879 if (runtime->workspace == NULL) {
880 xnn_log_error("failed to allocate %zu bytes for runtime workspace", mem_arena_size);
881 xnn_release_value_allocation_tracker(&mem_alloc_tracker);
882 goto error;
883 }
884 for (size_t i = 0; i < subgraph->num_values; i++) {
885 const struct xnn_value* value = &subgraph->values[i];
886 struct xnn_blob* blob = &runtime->blobs[i];
887 if (value->datatype != xnn_datatype_invalid && value->type == xnn_value_type_dense_tensor) {
888 if (value->data == NULL && !blob->external) {
889 // Value is purely internal to the runtime, allocate it in the workspace.
890 blob->data = (void*) ((uintptr_t) runtime->workspace + mem_alloc_tracker.usage[i].alloc_offset);
891 }
892 }
893 }
894 }
895 xnn_release_value_allocation_tracker(&mem_alloc_tracker);
896
897 runtime->threadpool = threadpool;
898
899 *runtime_out = runtime;
900 return xnn_status_success;
901
902 error:
903 xnn_delete_runtime(runtime);
904 return status;
905 }
906
// Binds external input/output buffers to the runtime, then calls the
// operator-specific xnn_setup_* function for every live operator so that each
// operator object points at its input/output blobs.
//
// runtime             - runtime previously built by xnn_create_runtime_v2.
// num_external_values - number of entries in external_values.
// external_values     - (value ID, data pointer) pairs for all external
//                       input/output tensors of the subgraph.
//
// Returns xnn_status_invalid_parameter if any external value ID is
// out-of-bounds or refers to a blob that is not external; in that case the
// runtime state is left unchanged. Otherwise propagates the status of the
// first failing operator setup, or xnn_status_success.
enum xnn_status xnn_setup_runtime(
  xnn_runtime_t runtime,
  size_t num_external_values,
  const struct xnn_external_value* external_values)
{
  // Validate inputs without changing internal state.
  // This ensures that runtime stays in consistent state in case validation fails midway.
  for (size_t i = 0; i < num_external_values; i++) {
    const struct xnn_external_value* external_value = &external_values[i];
    const uint32_t value_id = external_value->id;
    if (value_id >= runtime->num_blobs) {
      xnn_log_error("failed to setup runtime: out-of-bounds ID %" PRIu32 " in external value #%zu",
        value_id, i);
      return xnn_status_invalid_parameter;
    }

    const struct xnn_blob* blob = &runtime->blobs[value_id];
    if (!blob->external) {
      xnn_log_error("failed to setup runtime: Value %" PRIu32 " is not external", value_id);
      return xnn_status_invalid_parameter;
    }
  }

  // Apply runtime state changes.
  // All IDs were validated above, so this loop cannot fail partway through.
  for (size_t i = 0; i < num_external_values; i++) {
    const struct xnn_external_value* external_value = &external_values[i];
    const uint32_t value_id = external_value->id;
    struct xnn_blob* blob = &runtime->blobs[value_id];
    blob->data = external_value->data;
  }

  // Wire each operator to its blobs. The per-case argument lists mirror the
  // opdata fields recorded at creation time (batch size, spatial dims, or
  // broadcast shapes, depending on the operator family).
  for (size_t i = 0; i < runtime->num_ops; i++) {
    const struct xnn_operator_data* opdata = &runtime->opdata[i];
    if (opdata->operator_object == NULL) {
      // Operator was removed during optimization
      continue;
    }

    enum xnn_status status = xnn_status_success;
    switch (opdata->operator_object->type) {
      case xnn_operator_type_abs_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_abs_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_add_nd_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_add_nd_f32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->shape2.num_dims,
          opdata->shape2.dim,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_argmax_pooling_nhwc_f32:
        // Two outputs: pooled values and argmax indices.
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[1]].data != NULL);
        status = xnn_setup_argmax_pooling2d_nhwc_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->blobs[opdata->outputs[1]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_average_pooling_nhwc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_average_pooling2d_nhwc_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_bankers_rounding_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_bankers_rounding_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_ceiling_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_ceiling_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_constant_pad_nd_x32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_constant_pad_nd_x32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->pre_paddings,
          opdata->post_paddings,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_convolution_nchw_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_convolution2d_nchw_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_convolution_nhwc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_convolution2d_nhwc_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_copy_nc_x32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_copy_nc_x32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_clamp_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_clamp_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_deconvolution_nhwc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_deconvolution2d_nhwc_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          opdata->adjustment_height,
          opdata->adjustment_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_depth_to_space_nchw2nhwc_x32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_depth_to_space_nchw2nhwc_x32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_depth_to_space_nhwc_x32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_depth_to_space_nhwc_x32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_divide_nd_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_divide_nd_f32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->shape2.num_dims,
          opdata->shape2.dim,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_elu_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_elu_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_fully_connected_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_fully_connected_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_floor_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_floor_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_global_average_pooling_ncw_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_global_average_pooling_ncw_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_global_average_pooling_nwc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_global_average_pooling_nwc_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_hardswish_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_hardswish_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_leaky_relu_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_leaky_relu_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_max_pooling_nhwc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_max_pooling2d_nhwc_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_maximum_nd_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_maximum_nd_f32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->shape2.num_dims,
          opdata->shape2.dim,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_minimum_nd_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_minimum_nd_f32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->shape2.num_dims,
          opdata->shape2.dim,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_multiply_nd_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_multiply_nd_f32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->shape2.num_dims,
          opdata->shape2.dim,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_negate_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_negate_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_prelu_nc_f32:
        // Negative-slope tensor was baked into the operator at creation time;
        // only the data input/output are bound here.
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_prelu_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_resize_bilinear_nchw_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_resize_bilinear2d_nchw_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          opdata->output_height,
          opdata->output_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_resize_bilinear_nhwc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_resize_bilinear2d_nhwc_f32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          opdata->output_height,
          opdata->output_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_sigmoid_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_sigmoid_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_softmax_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_softmax_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_square_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_square_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_square_root_nc_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_square_root_nc_f32(
          opdata->operator_object,
          opdata->batch_size,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_squared_difference_nd_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_squared_difference_nd_f32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->shape2.num_dims,
          opdata->shape2.dim,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_subtract_nd_f32:
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_subtract_nd_f32(
          opdata->operator_object,
          opdata->shape1.num_dims,
          opdata->shape1.dim,
          opdata->shape2.num_dims,
          opdata->shape2.dim,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      case xnn_operator_type_unpooling_nhwc_x32:
        // Second input carries the index tensor produced by argmax pooling.
        assert(runtime->blobs[opdata->inputs[0]].data != NULL);
        assert(runtime->blobs[opdata->inputs[1]].data != NULL);
        assert(runtime->blobs[opdata->outputs[0]].data != NULL);
        status = xnn_setup_unpooling2d_nhwc_x32(
          opdata->operator_object,
          opdata->batch_size,
          opdata->input_height,
          opdata->input_width,
          runtime->blobs[opdata->inputs[0]].data,
          runtime->blobs[opdata->inputs[1]].data,
          runtime->blobs[opdata->outputs[0]].data,
          runtime->threadpool);
        break;
      default:
        // Creation only produces the operator types above, so reaching here
        // indicates internal corruption.
        xnn_log_fatal("unexpected operator type %s in operator #%zu",
          xnn_operator_type_to_string(opdata->operator_object->type), i);
        XNN_UNREACHABLE;
    }
    if (status != xnn_status_success) {
      xnn_log_error("failed to setup runtime: error in operator #%zu", i);
      return status;
    }
  }

  return xnn_status_success;
}
1402
// Executes all live operators of the runtime in creation (topological) order.
// Stops and returns the failing status as soon as any operator errors out;
// returns xnn_status_success when every operator ran cleanly.
enum xnn_status xnn_invoke_runtime(
  xnn_runtime_t runtime)
{
  const size_t num_ops = runtime->num_ops;
  for (size_t op_index = 0; op_index < num_ops; op_index++) {
    xnn_operator_t op = runtime->opdata[op_index].operator_object;
    // Operators removed after fusion leave a NULL slot behind; skip them.
    if (op == NULL) {
      continue;
    }

    const enum xnn_status status = xnn_run_operator(op, runtime->threadpool);
    if (status != xnn_status_success) {
      return status;
    }
  }
  return xnn_status_success;
}
1419
// Destroys a runtime and all resources it owns: per-node operator objects,
// the opdata array, the blob descriptor array, and the workspace arena.
// Safe to call with a NULL runtime and with a partially-constructed runtime
// (as happens on the error path of xnn_create_runtime_v2).
// Always returns xnn_status_success.
enum xnn_status xnn_delete_runtime(
  xnn_runtime_t runtime)
{
  if (runtime != NULL) {
    if (runtime->opdata != NULL) {
      for (size_t i = 0; i < runtime->num_ops; i++) {
        // xnn_delete_operator tolerates the NULL slots left by fused-away nodes.
        xnn_delete_operator(runtime->opdata[i].operator_object);
      }
      xnn_release_memory(runtime->opdata);
    }
    // Release blobs and workspace unconditionally rather than only when opdata
    // was allocated: the release helpers are no-ops on NULL (like free), and
    // tying these releases to opdata would leak them if the allocation order
    // in xnn_create_runtime_v2 ever changes.
    xnn_release_memory(runtime->blobs);
    xnn_release_simd_memory(runtime->workspace);
    xnn_release_memory(runtime);
  }
  return xnn_status_success;
}
1437