// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>

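// Tests for the 3-input Concatenate subgraph node. Each fixture draws a random
// concatenation axis and random shapes that agree on every dimension except
// that axis, then checks that (a) defining the node records the expected
// metadata and (b) the subgraph API produces bit-identical results to the
// equivalent sequence of operator-API calls.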
template <typename T> class Concatenate3Test : public ::testing::Test {
 protected:
  Concatenate3Test()
  {
    random_device = std::unique_ptr<std::random_device>(new std::random_device());
    rng = std::mt19937((*random_device)());
    shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
    dim_dist = std::uniform_int_distribution<size_t>(1, 9);
    f32dist = std::uniform_real_distribution<float>();
    i8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    u8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
    scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);

    input1_dims = RandomShape();
    axis = RandomAxis(input1_dims);
    input2_dims = RandomShape(input1_dims, axis);
    input3_dims = RandomShape(input1_dims, axis);
    output_dims = input1_dims;
    output_dims[axis] = input1_dims[axis] + input2_dims[axis] + input3_dims[axis];

    input1 = std::vector<T>(NumElements(input1_dims));
    input2 = std::vector<T>(NumElements(input2_dims));
    input3 = std::vector<T>(NumElements(input3_dims));
    operator_output = std::vector<T>(NumElements(output_dims));
    subgraph_output = std::vector<T>(NumElements(output_dims));

    signed_zero_point = i8dist(rng);
    unsigned_zero_point = u8dist(rng);
    scale = scale_dist(rng);

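    // Flatten the tensors around the concatenation axis: batch_size is the
    // product of the dimensions before the axis, and channels_N is the product
    // of the dimensions from the axis onward for input N. In this flattened
    // view the concatenation becomes three strided row copies.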
    batch_size = 1;
    channels_1 = 1;
    channels_2 = 1;
    channels_3 = 1;
    for (size_t i = 0; i < axis; i++) {
      batch_size *= output_dims[i];
    }

    for (size_t i = axis; i < input1_dims.size(); i++) {
      channels_1 *= input1_dims[i];
      channels_2 *= input2_dims[i];
      channels_3 *= input3_dims[i];
    }
    output_stride = channels_1 + channels_2 + channels_3;
  }

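  // Returns a shape with a random rank in [1, XNN_MAX_TENSOR_DIMS] and random
  // dimensions in [1, 9].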
  std::vector<size_t> RandomShape()
  {
    std::vector<size_t> dims(shape_dist(rng));
    std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
    return dims;
  }

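  // Returns a copy of base_dims with a fresh random size along the given axis,
  // so that inputs agree on every dimension except the concatenation axis.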
  std::vector<size_t> RandomShape(const std::vector<size_t>& base_dims, size_t axis)
  {
    auto dims = base_dims;
    dims[axis] = dim_dist(rng);
    return dims;
  }

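  // Picks a uniformly random valid concatenation axis for the given shape.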
  size_t RandomAxis(const std::vector<size_t>& dims)
  {
    return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
  }

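  // Total element count of a shape: the product of its dimensions.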
  size_t NumElements(const std::vector<size_t>& dims)
  {
    return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
  }

  std::unique_ptr<std::random_device> random_device;
  std::mt19937 rng;
  std::uniform_int_distribution<size_t> shape_dist;
  std::uniform_int_distribution<size_t> dim_dist;
  std::uniform_real_distribution<float> f32dist;
  std::uniform_int_distribution<int32_t> i8dist;
  std::uniform_int_distribution<int32_t> u8dist;
  std::uniform_real_distribution<float> scale_dist;

  uint32_t input1_id;
  uint32_t input2_id;
  uint32_t input3_id;
  uint32_t output_id;

  std::vector<size_t> input1_dims;
  std::vector<size_t> input2_dims;
  std::vector<size_t> input3_dims;
  std::vector<size_t> output_dims;

  size_t axis;
  size_t batch_size;
  size_t channels_1;
  size_t channels_2;
  size_t channels_3;
  size_t output_stride;

  int32_t signed_zero_point;
  int32_t unsigned_zero_point;
  float scale;

  std::vector<T> input1;
  std::vector<T> input2;
  std::vector<T> input3;
  std::vector<T> operator_output;
  std::vector<T> subgraph_output;
};

using Concatenate3TestQS8 = Concatenate3Test<int8_t>;
using Concatenate3TestQU8 = Concatenate3Test<uint8_t>;
using Concatenate3TestF32 = Concatenate3Test<float>;

TEST_F(Concatenate3TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));

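  // Defining the node should append exactly one qs8 Concatenate3 node that
  // references the three inputs and the output in order.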
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 3);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate3TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));

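  // As in the qs8 test, but the recorded node should carry the qu8 compute type.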
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 3);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate3TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));

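  // As above, but the recorded node should carry the fp32 compute type.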
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate3);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 3);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate3TestQS8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return i8dist(rng); });
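  // Pre-fill both outputs with the same sentinel byte so any element left
  // untouched by either implementation shows up in the final comparison.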
  std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API.
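  // At the operator level, concatenating along `axis` is three Copy (x8)
  // operators: each copies channels_N elements per batch row into the shared
  // output with stride output_stride, at channel offsets 0, channels_1, and
  // channels_1 + channels_2.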
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

TEST_F(Concatenate3TestQU8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return u8dist(rng); });
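  // Sentinel-fill both outputs so untouched elements are detectable.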
  std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API.
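  // Same x8 Copy decomposition as in the qs8 test; the copy is type-agnostic
  // and depends only on the element size.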
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

TEST_F(Concatenate3TestF32, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return f32dist(rng); });
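  // Use NaN as the sentinel: since NaN never compares equal, any element left
  // untouched by either implementation fails the final equality check.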
  std::fill(operator_output.begin(), operator_output.end(), std::nanf(""));
  std::fill(subgraph_output.begin(), subgraph_output.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;

  // Call operator API.
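  // The x32 Copy variant moves 4-byte elements, so the output offsets below
  // are computed in float units rather than bytes.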
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, input2.data(), (float*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op3, batch_size, input3.data(), (float*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 4> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}