// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>


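// A static reshape never changes element values or their order in memory, so
// the node is lowered to a plain copy operator; only the element width
// (8, 16, or 32 bits) depends on the node's compute type.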
static enum xnn_status create_copy_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  enum xnn_status status;
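  // The copy operator is created with a single channel and unit strides; the
  // element count is supplied later as the batch size.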
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16:
      status = xnn_create_copy_nc_x16(
        1 /* channels */, 1 /* input stride */, 1 /* output stride */,
        node->flags,
        &opdata->operator_object);
      break;
#endif // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32:
      status = xnn_create_copy_nc_x32(
        1 /* channels */, 1 /* input stride */, 1 /* output stride */,
        node->flags,
        &opdata->operator_object);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_compute_type_qs8:
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_compute_type_qu8:
#endif // !defined(XNN_NO_QU8_OPERATORS)
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
      status = xnn_create_copy_nc_x8(
        1 /* channels */, 1 /* input stride */, 1 /* output stride */,
        node->flags,
        &opdata->operator_object);
      break;
#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
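  // On success, the batch size is the total element count of the input tensor,
  // i.e. the copy runs over the flattened input regardless of the new shape.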
  if (status == xnn_status_success) {
    opdata->batch_size = xnn_shape_multiply_all_dims(&values[input_id].shape);
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

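// Binds the runtime input and output buffers to the copy operator created in
// create_copy_operator and hands it off to the threadpool-aware setup call.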
static enum xnn_status setup_copy_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

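  // Dispatch on the concrete operator type that create_copy_operator produced
  // for this node's compute type.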
  switch (opdata->operator_object->type) {
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8:
      return xnn_setup_copy_nc_x8(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_copy_nc_x16:
      return xnn_setup_copy_nc_x16(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_copy_nc_x32:
      return xnn_setup_copy_nc_x32(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    default:
      XNN_UNREACHABLE;
  }
}

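// Defines a static reshape node in the subgraph: validates the input/output
// Values (dense tensors of matching datatype and, for quantized types,
// matching quantization parameters), records the requested output shape, and
// registers the create/setup callbacks above.
//
// A minimal usage sketch (assuming a subgraph created via the public
// xnnpack.h API with two external Value IDs; the shapes are illustrative):
//
//   const size_t in_dims[4] = {1, 2, 3, 4};
//   const size_t new_shape[2] = {1, 24};
//   uint32_t input_id = XNN_INVALID_VALUE_ID;
//   uint32_t output_id = XNN_INVALID_VALUE_ID;
//   xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 4, in_dims, NULL,
//     /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
//   xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, new_shape, NULL,
//     /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
//   xnn_define_static_reshape(subgraph, 2, new_shape, input_id, output_id,
//     /*flags=*/0);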
enum xnn_status xnn_define_static_reshape(
  xnn_subgraph_t subgraph,
  size_t num_dims,
  const size_t* new_shape,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_static_reshape));
    return xnn_status_uninitialized;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_reshape), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_reshape), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%" PRId32 ") and output (%" PRId32 ")",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
236 ": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

  if (num_dims > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to define %s operator with %zu-dimensional output shape: at most %zu dimensions are supported",
      xnn_node_type_to_string(xnn_node_type_static_reshape), num_dims, (size_t) XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

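  // Record the requested output shape in the node's static_reshape parameters.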
  node->params.static_reshape.new_shape.num_dims = num_dims;
  memcpy(&node->params.static_reshape.new_shape.dim, new_shape, num_dims * sizeof(size_t));

  node->type = xnn_node_type_static_reshape;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_copy_operator;
  node->setup = setup_copy_operator;

  return xnn_status_success;
}