• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>
21 
22 template <typename T> class Concatenate3Test : public ::testing::Test {
23 protected:
Concatenate3Test()24   Concatenate3Test()
25   {
26     random_device = std::unique_ptr<std::random_device>(new std::random_device());
27     rng = std::mt19937((*random_device)());
28     shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
29     dim_dist = std::uniform_int_distribution<size_t>(1, 9);
30     f32dist = std::uniform_real_distribution<float>();
31     i8dist =
32       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
33     u8dist =
34       std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
35     scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);
36 
37     input1_dims = RandomShape();
38     axis = RandomAxis(input1_dims);
39     input2_dims = RandomShape(input1_dims, axis);
40     input3_dims = RandomShape(input1_dims, axis);
41     output_dims = input1_dims;
42     output_dims[axis] = input1_dims[axis] + input2_dims[axis] + input3_dims[axis];
43 
44     input1 = std::vector<T>(NumElements(input1_dims));
45     input2 = std::vector<T>(NumElements(input2_dims));
46     input3 = std::vector<T>(NumElements(input3_dims));
47     operator_output = std::vector<T>(NumElements(output_dims));
48     subgraph_output = std::vector<T>(NumElements(output_dims));
49 
50     signed_zero_point = i8dist(rng);
51     unsigned_zero_point = u8dist(rng);
52     scale = scale_dist(rng);
53 
54     batch_size = 1;
55     channels_1 = 1;
56     channels_2 = 1;
57     channels_3 = 1;
58     for (size_t i = 0; i < axis; i++) {
59       batch_size *= output_dims[i];
60     }
61 
62     for (size_t i = axis; i < input1_dims.size(); i++) {
63       channels_1 *= input1_dims[i];
64       channels_2 *= input2_dims[i];
65       channels_3 *= input3_dims[i];
66     }
67     output_stride = channels_1 + channels_2 + channels_3;
68   }
69 
RandomShape()70   std::vector<size_t> RandomShape()
71   {
72     std::vector<size_t> dims(shape_dist(rng));
73     std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
74     return dims;
75   }
76 
RandomShape(const std::vector<size_t> base_dims,size_t axis)77   std::vector<size_t> RandomShape(const std::vector<size_t> base_dims, size_t axis)
78   {
79     auto dims = base_dims;
80     dims[axis] = dim_dist(rng);
81     return dims;
82   }
83 
RandomAxis(const std::vector<size_t> & dims)84   size_t RandomAxis(const std::vector<size_t>& dims)
85   {
86     return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
87   }
88 
NumElements(const std::vector<size_t> & dims)89   size_t NumElements(const std::vector<size_t>& dims)
90   {
91     return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
92   }
93 
94   std::unique_ptr<std::random_device> random_device;
95   std::mt19937 rng;
96   std::uniform_int_distribution<size_t> shape_dist;
97   std::uniform_int_distribution<size_t> dim_dist;
98   std::uniform_real_distribution<float> f32dist;
99   std::uniform_int_distribution<int32_t> i8dist;
100   std::uniform_int_distribution<int32_t> u8dist;
101   std::uniform_real_distribution<float> scale_dist;
102 
103   uint32_t input1_id;
104   uint32_t input2_id;
105   uint32_t input3_id;
106   uint32_t output_id;
107 
108   std::vector<size_t> input1_dims;
109   std::vector<size_t> input2_dims;
110   std::vector<size_t> input3_dims;
111   std::vector<size_t> output_dims;
112 
113   size_t axis;
114   size_t batch_size;
115   size_t channels_1;
116   size_t channels_2;
117   size_t channels_3;
118   size_t output_stride;
119 
120   int32_t signed_zero_point;
121   int32_t unsigned_zero_point;
122   float scale;
123 
124   std::vector<T> input1;
125   std::vector<T> input2;
126   std::vector<T> input3;
127   std::vector<T> operator_output;
128   std::vector<T> subgraph_output;
129 };
130 
// Concrete fixture instantiations for the three element types exercised by
// XNNPACK's Concatenate3 node.
using Concatenate3TestQS8 = Concatenate3Test<int8_t>;   // signed 8-bit quantized
using Concatenate3TestQU8 = Concatenate3Test<uint8_t>;  // unsigned 8-bit quantized
using Concatenate3TestF32 = Concatenate3Test<float>;    // single-precision float
134 
TEST_F(Concatenate3TestQS8,define)135 TEST_F(Concatenate3TestQS8, define)
136 {
137   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
138 
139   xnn_subgraph_t subgraph = nullptr;
140   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
141   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
142 
143   input1_id = XNN_INVALID_NODE_ID;
144   ASSERT_EQ(
145     xnn_status_success,
146     xnn_define_quantized_tensor_value(
147       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
148       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
149   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
150 
151   input2_id = XNN_INVALID_NODE_ID;
152   ASSERT_EQ(
153     xnn_status_success,
154     xnn_define_quantized_tensor_value(
155       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
156       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
157   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
158 
159   input3_id = XNN_INVALID_NODE_ID;
160   ASSERT_EQ(
161     xnn_status_success,
162     xnn_define_quantized_tensor_value(
163       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
164       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
165   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
166 
167   output_id = XNN_INVALID_NODE_ID;
168   ASSERT_EQ(
169     xnn_status_success,
170     xnn_define_quantized_tensor_value(
171       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
172       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
173   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
174 
175   ASSERT_EQ(
176     xnn_status_success,
177     xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));
178 
179   ASSERT_EQ(subgraph->num_nodes, 1);
180   const struct xnn_node* node = &subgraph->nodes[0];
181   ASSERT_EQ(node->type, xnn_node_type_concatenate3);
182   ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
183   ASSERT_EQ(node->params.concatenate.axis, axis);
184   ASSERT_EQ(node->num_inputs, 3);
185   ASSERT_EQ(node->inputs[0], input1_id);
186   ASSERT_EQ(node->inputs[1], input2_id);
187   ASSERT_EQ(node->inputs[2], input3_id);
188   ASSERT_EQ(node->num_outputs, 1);
189   ASSERT_EQ(node->outputs[0], output_id);
190   ASSERT_EQ(node->flags, 0);
191 }
192 
TEST_F(Concatenate3TestQU8,define)193 TEST_F(Concatenate3TestQU8, define)
194 {
195   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
196 
197   xnn_subgraph_t subgraph = nullptr;
198   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
199   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
200 
201   input1_id = XNN_INVALID_NODE_ID;
202   ASSERT_EQ(
203     xnn_status_success,
204     xnn_define_quantized_tensor_value(
205       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
206       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
207   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
208 
209   input2_id = XNN_INVALID_NODE_ID;
210   ASSERT_EQ(
211     xnn_status_success,
212     xnn_define_quantized_tensor_value(
213       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
214       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
215   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
216 
217   input3_id = XNN_INVALID_NODE_ID;
218   ASSERT_EQ(
219     xnn_status_success,
220     xnn_define_quantized_tensor_value(
221       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
222       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
223   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
224 
225   output_id = XNN_INVALID_NODE_ID;
226   ASSERT_EQ(
227     xnn_status_success,
228     xnn_define_quantized_tensor_value(
229       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
230       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
231   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
232 
233   ASSERT_EQ(
234     xnn_status_success,
235     xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));
236 
237   ASSERT_EQ(subgraph->num_nodes, 1);
238   const struct xnn_node* node = &subgraph->nodes[0];
239   ASSERT_EQ(node->type, xnn_node_type_concatenate3);
240   ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
241   ASSERT_EQ(node->params.concatenate.axis, axis);
242   ASSERT_EQ(node->num_inputs, 3);
243   ASSERT_EQ(node->inputs[0], input1_id);
244   ASSERT_EQ(node->inputs[1], input2_id);
245   ASSERT_EQ(node->inputs[2], input3_id);
246   ASSERT_EQ(node->num_outputs, 1);
247   ASSERT_EQ(node->outputs[0], output_id);
248   ASSERT_EQ(node->flags, 0);
249 }
250 
TEST_F(Concatenate3TestF32, define)
{
  // Defining a Concatenate3 node over FP32 tensors must record the axis, the
  // FP32 compute type, the three input ids, and the single output id.
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  // External FP32 input tensors occupy external ids 0..2.
  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  // The external FP32 output tensor occupies external id 3.
  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_tensor_value(
      subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));

  // Exactly one node should have been appended; verify every recorded field.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* concat_node = &subgraph->nodes[0];
  ASSERT_EQ(concat_node->type, xnn_node_type_concatenate3);
  ASSERT_EQ(concat_node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(concat_node->params.concatenate.axis, axis);
  ASSERT_EQ(concat_node->num_inputs, 3);
  ASSERT_EQ(concat_node->inputs[0], input1_id);
  ASSERT_EQ(concat_node->inputs[1], input2_id);
  ASSERT_EQ(concat_node->inputs[2], input3_id);
  ASSERT_EQ(concat_node->num_outputs, 1);
  ASSERT_EQ(concat_node->outputs[0], output_id);
  ASSERT_EQ(concat_node->flags, 0);
}
304 
TEST_F(Concatenate3TestQS8,matches_operator_api)305 TEST_F(Concatenate3TestQS8, matches_operator_api)
306 {
307   std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
308   std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
309   std::generate(input3.begin(), input3.end(), [&]() { return i8dist(rng); });
310   std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
311   std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
312 
313   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
314 
315   xnn_operator_t op1 = nullptr;
316   xnn_operator_t op2 = nullptr;
317   xnn_operator_t op3 = nullptr;
318 
319   // Call operator API.
320   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
321   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
322   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
323   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
324   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
325   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
326 
327   ASSERT_EQ(
328     xnn_status_success,
329     xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
330   ASSERT_EQ(
331     xnn_status_success,
332     xnn_setup_copy_nc_x8(
333       op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
334   ASSERT_EQ(
335     xnn_status_success,
336     xnn_setup_copy_nc_x8(
337       op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
338       nullptr /* thread pool */));
339 
340   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
341   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
342   ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
343 
344   // Call subgraph API.
345   xnn_subgraph_t subgraph = nullptr;
346   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
347   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
348 
349   input1_id = XNN_INVALID_NODE_ID;
350   ASSERT_EQ(
351     xnn_status_success,
352     xnn_define_quantized_tensor_value(
353       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
354       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
355   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
356 
357   input2_id = XNN_INVALID_NODE_ID;
358   ASSERT_EQ(
359     xnn_status_success,
360     xnn_define_quantized_tensor_value(
361       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
362       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
363   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
364 
365   input3_id = XNN_INVALID_NODE_ID;
366   ASSERT_EQ(
367     xnn_status_success,
368     xnn_define_quantized_tensor_value(
369       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
370       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
371   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
372 
373   output_id = XNN_INVALID_NODE_ID;
374   ASSERT_EQ(
375     xnn_status_success,
376     xnn_define_quantized_tensor_value(
377       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
378       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
379   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
380 
381   ASSERT_EQ(
382     xnn_status_success,
383     xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));
384 
385   xnn_runtime_t runtime = nullptr;
386   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
387   ASSERT_NE(nullptr, runtime);
388   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
389   std::array<xnn_external_value, 4> external = {
390     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
391     xnn_external_value{input3_id, input3.data()}, xnn_external_value{output_id, subgraph_output.data()}};
392   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
393   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
394 
395   // Check outputs match.
396   ASSERT_EQ(subgraph_output, operator_output);
397 }
398 
TEST_F(Concatenate3TestQU8,matches_operator_api)399 TEST_F(Concatenate3TestQU8, matches_operator_api)
400 {
401   std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
402   std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
403   std::generate(input3.begin(), input3.end(), [&]() { return u8dist(rng); });
404   std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
405   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
406 
407   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
408 
409   xnn_operator_t op1 = nullptr;
410   xnn_operator_t op2 = nullptr;
411   xnn_operator_t op3 = nullptr;
412 
413   // Call operator API.
414   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
415   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
416   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
417   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
418   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
419   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
420 
421   ASSERT_EQ(
422     xnn_status_success,
423     xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
424   ASSERT_EQ(
425     xnn_status_success,
426     xnn_setup_copy_nc_x8(
427       op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
428   ASSERT_EQ(
429     xnn_status_success,
430     xnn_setup_copy_nc_x8(
431       op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
432       nullptr /* thread pool */));
433 
434   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
435   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
436   ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
437 
438   // Call subgraph API.
439   xnn_subgraph_t subgraph = nullptr;
440   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
441   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
442 
443   input1_id = XNN_INVALID_NODE_ID;
444   ASSERT_EQ(
445     xnn_status_success,
446     xnn_define_quantized_tensor_value(
447       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
448       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
449   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
450 
451   input2_id = XNN_INVALID_NODE_ID;
452   ASSERT_EQ(
453     xnn_status_success,
454     xnn_define_quantized_tensor_value(
455       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
456       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
457   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
458 
459   input3_id = XNN_INVALID_NODE_ID;
460   ASSERT_EQ(
461     xnn_status_success,
462     xnn_define_quantized_tensor_value(
463       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
464       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
465   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
466 
467   output_id = XNN_INVALID_NODE_ID;
468   ASSERT_EQ(
469     xnn_status_success,
470     xnn_define_quantized_tensor_value(
471       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 3,
472       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
473   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
474 
475   ASSERT_EQ(
476     xnn_status_success,
477     xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));
478 
479   xnn_runtime_t runtime = nullptr;
480   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
481   ASSERT_NE(nullptr, runtime);
482   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
483   std::array<xnn_external_value, 4> external = {
484     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
485     xnn_external_value{input3_id, input3.data()}, xnn_external_value{output_id, subgraph_output.data()}};
486   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
487   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
488 
489   // Check outputs match.
490   ASSERT_EQ(subgraph_output, operator_output);
491 }
492 
TEST_F(Concatenate3TestF32,matches_operator_api)493 TEST_F(Concatenate3TestF32, matches_operator_api)
494 {
495   std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
496   std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
497   std::generate(input3.begin(), input3.end(), [&]() { return f32dist(rng); });
498   std::fill(operator_output.begin(), operator_output.end(), std::nanf(""));
499   std::fill(subgraph_output.begin(), subgraph_output.end(), std::nanf(""));
500 
501   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
502 
503   xnn_operator_t op1 = nullptr;
504   xnn_operator_t op2 = nullptr;
505   xnn_operator_t op3 = nullptr;
506 
507   // Call operator API.
508   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
509   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
510   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
511   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
512   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
513   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
514 
515   ASSERT_EQ(
516     xnn_status_success,
517     xnn_setup_copy_nc_x32(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
518   ASSERT_EQ(
519     xnn_status_success,
520     xnn_setup_copy_nc_x32(
521       op2, batch_size, input2.data(), (float*) operator_output.data() + op1->channels, nullptr /* thread pool */));
522   ASSERT_EQ(
523     xnn_status_success, xnn_setup_copy_nc_x32(
524                           op3, batch_size, input3.data(),
525                           (float*) operator_output.data() + op1->channels + op2->channels, nullptr /* thread pool */));
526 
527   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
528   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
529   ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
530 
531   // Call subgraph API.
532   xnn_subgraph_t subgraph = nullptr;
533   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
534   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
535 
536   input1_id = XNN_INVALID_NODE_ID;
537   ASSERT_EQ(
538     xnn_status_success, xnn_define_tensor_value(
539                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
540                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
541   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
542 
543   input2_id = XNN_INVALID_NODE_ID;
544   ASSERT_EQ(
545     xnn_status_success, xnn_define_tensor_value(
546                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
547                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
548   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
549 
550   input3_id = XNN_INVALID_NODE_ID;
551   ASSERT_EQ(
552     xnn_status_success, xnn_define_tensor_value(
553                           subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
554                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
555   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
556 
557   output_id = XNN_INVALID_NODE_ID;
558   ASSERT_EQ(
559     xnn_status_success, xnn_define_tensor_value(
560                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 3,
561                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
562   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
563 
564   ASSERT_EQ(
565     xnn_status_success,
566     xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0));
567 
568   xnn_runtime_t runtime = nullptr;
569   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
570   ASSERT_NE(nullptr, runtime);
571   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
572   std::array<xnn_external_value, 4> external = {
573     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
574     xnn_external_value{input3_id, input3.data()}, xnn_external_value{output_id, subgraph_output.data()}};
575   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
576   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
577 
578   // Check outputs match.
579   ASSERT_EQ(subgraph_output, operator_output);
580 }
581