// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6 #include <math.h>
7 #include <stddef.h>
8 #include <stdint.h>
9
10 #include <xnnpack.h>
11 #include <xnnpack/log.h>
12 #include <xnnpack/params.h>
13 #include <xnnpack/subgraph.h>
14
15
16
create_convert_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)17 static enum xnn_status create_convert_operator(
18 const struct xnn_node* node,
19 const struct xnn_value* values,
20 size_t num_values,
21 struct xnn_operator_data* opdata)
22 {
23 assert(node->num_inputs == 1);
24 const uint32_t input_id = node->inputs[0];
25 assert(input_id != XNN_INVALID_VALUE_ID);
26 assert(input_id < num_values);
27
28 assert(node->num_outputs == 1);
29 const uint32_t output_id = node->outputs[0];
30 assert(output_id != XNN_INVALID_VALUE_ID);
31 assert(output_id < num_values);
32
33 const size_t num_input_dims = values[input_id].shape.num_dims;
34 const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
35
36 enum xnn_status status = xnn_status_uninitialized;
37 switch (node->compute_type) {
38 case xnn_compute_type_fp32_to_fp16:
39 status = xnn_create_convert_nc_f32_f16(
40 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
41 node->flags,
42 &opdata->operator_object);
43 break;
44 case xnn_compute_type_fp32_to_qs8:
45 status = xnn_create_convert_nc_f32_qs8(
46 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
47 values[output_id].quantization.scale,
48 (int8_t) values[output_id].quantization.zero_point,
49 INT8_MIN, INT8_MAX,
50 node->flags,
51 &opdata->operator_object);
52 break;
53 case xnn_compute_type_fp32_to_qu8:
54 status = xnn_create_convert_nc_f32_qu8(
55 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
56 values[output_id].quantization.scale,
57 (uint8_t) values[output_id].quantization.zero_point,
58 0, UINT8_MAX,
59 node->flags,
60 &opdata->operator_object);
61 break;
62 case xnn_compute_type_fp16_to_fp32:
63 status = xnn_create_convert_nc_f16_f32(
64 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
65 node->flags,
66 &opdata->operator_object);
67 break;
68 case xnn_compute_type_qs8_to_fp32:
69 status = xnn_create_convert_nc_qs8_f32(
70 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
71 values[input_id].quantization.scale,
72 (int8_t) values[input_id].quantization.zero_point,
73 node->flags,
74 &opdata->operator_object);
75 break;
76 case xnn_compute_type_qu8_to_fp32:
77 status = xnn_create_convert_nc_qu8_f32(
78 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
79 values[input_id].quantization.scale,
80 (uint8_t) values[input_id].quantization.zero_point,
81 node->flags,
82 &opdata->operator_object);
83 break;
84 default:
85 XNN_UNREACHABLE;
86 }
87 if (status == xnn_status_success) {
88 opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
89 opdata->inputs[0] = input_id;
90 opdata->outputs[0] = output_id;
91 }
92 return status;
93 }
94
// Binds the input and output blob data pointers to the previously created
// Convert operator by calling the xnn_setup_convert_nc_* function that matches
// the operator's type.
//
// num_blobs is only used to bounds-check the stored Value IDs in assertions.
// Returns the status produced by the selected setup call.
static enum xnn_status setup_convert_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input = blobs[input_id].data;
  assert(input != NULL);

  void* output = blobs[output_id].data;
  assert(output != NULL);

  // Every supported conversion takes the same argument list; only the entry
  // point differs.
  switch (opdata->operator_object->type) {
    case xnn_operator_type_convert_nc_f32_f16:
      return xnn_setup_convert_nc_f32_f16(
        opdata->operator_object, opdata->batch_size, input, output, threadpool);
    case xnn_operator_type_convert_nc_f32_qs8:
      return xnn_setup_convert_nc_f32_qs8(
        opdata->operator_object, opdata->batch_size, input, output, threadpool);
    case xnn_operator_type_convert_nc_f32_qu8:
      return xnn_setup_convert_nc_f32_qu8(
        opdata->operator_object, opdata->batch_size, input, output, threadpool);
    case xnn_operator_type_convert_nc_f16_f32:
      return xnn_setup_convert_nc_f16_f32(
        opdata->operator_object, opdata->batch_size, input, output, threadpool);
    case xnn_operator_type_convert_nc_qs8_f32:
      return xnn_setup_convert_nc_qs8_f32(
        opdata->operator_object, opdata->batch_size, input, output, threadpool);
    case xnn_operator_type_convert_nc_qu8_f32:
      return xnn_setup_convert_nc_qu8_f32(
        opdata->operator_object, opdata->batch_size, input, output, threadpool);
    default:
      XNN_UNREACHABLE;
  }
}
164
validate_datatypes(enum xnn_datatype input_datatype,enum xnn_datatype output_datatype)165 static inline enum xnn_compute_type validate_datatypes(
166 enum xnn_datatype input_datatype,
167 enum xnn_datatype output_datatype)
168 {
169 switch (input_datatype) {
170 case xnn_datatype_fp32:
171 switch (output_datatype) {
172 case xnn_datatype_fp16:
173 return xnn_compute_type_fp32_to_fp16;
174 case xnn_datatype_qint8:
175 return xnn_compute_type_fp32_to_qs8;
176 case xnn_datatype_quint8:
177 return xnn_compute_type_fp32_to_qu8;
178 default:
179 break;
180 }
181 break;
182 case xnn_datatype_fp16:
183 if (output_datatype == xnn_datatype_fp32) {
184 return xnn_compute_type_fp16_to_fp32;
185 }
186 break;
187 case xnn_datatype_qint8:
188 if (output_datatype == xnn_datatype_fp32) {
189 return xnn_compute_type_qs8_to_fp32;
190 }
191 break;
192 case xnn_datatype_quint8:
193 if (output_datatype == xnn_datatype_fp32) {
194 return xnn_compute_type_qu8_to_fp32;
195 }
196 break;
197 default:
198 XNN_UNREACHABLE;
199 }
200 return xnn_compute_type_invalid;
201 }
202
xnn_init_convert_node(struct xnn_node * node,enum xnn_compute_type compute_type,uint32_t input_id,uint32_t output_id,uint32_t flags)203 void xnn_init_convert_node(
204 struct xnn_node* node,
205 enum xnn_compute_type compute_type,
206 uint32_t input_id,
207 uint32_t output_id,
208 uint32_t flags)
209 {
210 node->type = xnn_node_type_convert;
211 node->compute_type = compute_type;
212 node->num_inputs = 1;
213 node->inputs[0] = input_id;
214 node->num_outputs = 1;
215 node->outputs[0] = output_id;
216 node->flags = flags;
217
218 node->create = create_convert_operator;
219 node->setup = setup_convert_operator;
220 }
221
// Adds a Convert node to `subgraph` that converts Value `input_id` into Value
// `output_id`.
//
// Checks performed, in order:
//   1. XNNPACK has been initialized (otherwise xnn_status_uninitialized).
//   2. The input ID is in range, refers to a dense tensor, and has datatype
//      FP16, FP32, QINT8 or QUINT8.
//   3. The same three checks for the output ID.
//   4. The input/output datatype pair maps to a supported conversion.
// A failure in steps 2-4 logs an error and returns
// xnn_status_invalid_parameter. If a new node cannot be obtained from the
// subgraph, xnn_status_out_of_memory is returned. Otherwise the node is
// initialized and xnn_status_success is returned.
enum xnn_status xnn_define_convert(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_convert));
    return xnn_status_uninitialized;
  }

  // ---- Input Value ----
  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_convert), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* src = subgraph->values + input_id;
  if (src->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_convert), input_id, src->type);
    return xnn_status_invalid_parameter;
  }

  switch (src->datatype) {
    case xnn_datatype_fp16:
    case xnn_datatype_fp32:
    case xnn_datatype_qint8:
    case xnn_datatype_quint8:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convert), input_id,
        xnn_datatype_to_string(src->datatype), src->datatype);
      return xnn_status_invalid_parameter;
  }

  // ---- Output Value ----
  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_convert), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* dst = subgraph->values + output_id;
  if (dst->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_convert), output_id, dst->type);
    return xnn_status_invalid_parameter;
  }

  switch (dst->datatype) {
    case xnn_datatype_fp16:
    case xnn_datatype_fp32:
    case xnn_datatype_qint8:
    case xnn_datatype_quint8:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convert), output_id,
        xnn_datatype_to_string(dst->datatype), dst->datatype);
      return xnn_status_invalid_parameter;
  }

  // ---- Datatype pairing ----
  // Both datatypes are individually supported at this point; now make sure
  // the combination corresponds to an actual conversion.
  const enum xnn_compute_type compute_type = validate_datatypes(src->datatype, dst->datatype);
  if (compute_type == xnn_compute_type_invalid) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_convert), input_id, output_id,
      xnn_datatype_to_string(src->datatype),
      xnn_datatype_to_string(dst->datatype));
    return xnn_status_invalid_parameter;
  }

  // ---- Node creation ----
  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  xnn_init_convert_node(node, compute_type, input_id, output_id, flags);
  return xnn_status_success;
}
311