• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
14 
15 
16 
create_convert_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)17 static enum xnn_status create_convert_operator(
18   const struct xnn_node* node,
19   const struct xnn_value* values,
20   size_t num_values,
21   struct xnn_operator_data* opdata)
22 {
23   assert(node->num_inputs == 1);
24   const uint32_t input_id = node->inputs[0];
25   assert(input_id != XNN_INVALID_VALUE_ID);
26   assert(input_id < num_values);
27 
28   assert(node->num_outputs == 1);
29   const uint32_t output_id = node->outputs[0];
30   assert(output_id != XNN_INVALID_VALUE_ID);
31   assert(output_id < num_values);
32 
33   const size_t num_input_dims = values[input_id].shape.num_dims;
34   const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
35 
36   enum xnn_status status = xnn_status_uninitialized;
37   switch (node->compute_type) {
38     case xnn_compute_type_fp32_to_fp16:
39       status = xnn_create_convert_nc_f32_f16(
40         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
41         node->flags,
42         &opdata->operator_object);
43       break;
44     case xnn_compute_type_fp32_to_qs8:
45       status = xnn_create_convert_nc_f32_qs8(
46         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
47         values[output_id].quantization.scale,
48         (int8_t) values[output_id].quantization.zero_point,
49         INT8_MIN, INT8_MAX,
50         node->flags,
51         &opdata->operator_object);
52       break;
53     case xnn_compute_type_fp32_to_qu8:
54       status = xnn_create_convert_nc_f32_qu8(
55         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
56         values[output_id].quantization.scale,
57         (uint8_t) values[output_id].quantization.zero_point,
58         0, UINT8_MAX,
59         node->flags,
60         &opdata->operator_object);
61       break;
62     case xnn_compute_type_fp16_to_fp32:
63       status = xnn_create_convert_nc_f16_f32(
64         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
65         node->flags,
66         &opdata->operator_object);
67       break;
68     case xnn_compute_type_qs8_to_fp32:
69       status = xnn_create_convert_nc_qs8_f32(
70         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
71         values[input_id].quantization.scale,
72         (int8_t) values[input_id].quantization.zero_point,
73         node->flags,
74         &opdata->operator_object);
75       break;
76     case xnn_compute_type_qu8_to_fp32:
77       status = xnn_create_convert_nc_qu8_f32(
78         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
79         values[input_id].quantization.scale,
80         (uint8_t) values[input_id].quantization.zero_point,
81         node->flags,
82         &opdata->operator_object);
83       break;
84     default:
85       XNN_UNREACHABLE;
86   }
87   if (status == xnn_status_success) {
88     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
89     opdata->inputs[0] = input_id;
90     opdata->outputs[0] = output_id;
91   }
92   return status;
93 }
94 
// Bind runtime tensor data to a previously created Convert operator and
// dispatch to the setup routine matching its concrete operator type.
static enum xnn_status setup_convert_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  // Resolve the tensor IDs recorded at creation time into runtime blobs.
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data = blobs[input_id].data;
  assert(input_data != NULL);

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  switch (opdata->operator_object->type) {
    case xnn_operator_type_convert_nc_f32_f16:
      return xnn_setup_convert_nc_f32_f16(
        opdata->operator_object, opdata->batch_size,
        input_data, output_data, threadpool);
    case xnn_operator_type_convert_nc_f32_qs8:
      return xnn_setup_convert_nc_f32_qs8(
        opdata->operator_object, opdata->batch_size,
        input_data, output_data, threadpool);
    case xnn_operator_type_convert_nc_f32_qu8:
      return xnn_setup_convert_nc_f32_qu8(
        opdata->operator_object, opdata->batch_size,
        input_data, output_data, threadpool);
    case xnn_operator_type_convert_nc_f16_f32:
      return xnn_setup_convert_nc_f16_f32(
        opdata->operator_object, opdata->batch_size,
        input_data, output_data, threadpool);
    case xnn_operator_type_convert_nc_qs8_f32:
      return xnn_setup_convert_nc_qs8_f32(
        opdata->operator_object, opdata->batch_size,
        input_data, output_data, threadpool);
    case xnn_operator_type_convert_nc_qu8_f32:
      return xnn_setup_convert_nc_qu8_f32(
        opdata->operator_object, opdata->batch_size,
        input_data, output_data, threadpool);
    default:
      // Only the operator types created by create_convert_operator can occur.
      XNN_UNREACHABLE;
  }
}
164 
validate_datatypes(enum xnn_datatype input_datatype,enum xnn_datatype output_datatype)165 static inline enum xnn_compute_type validate_datatypes(
166   enum xnn_datatype input_datatype,
167   enum xnn_datatype output_datatype)
168 {
169   switch (input_datatype) {
170     case xnn_datatype_fp32:
171       switch (output_datatype) {
172         case xnn_datatype_fp16:
173           return xnn_compute_type_fp32_to_fp16;
174         case xnn_datatype_qint8:
175           return xnn_compute_type_fp32_to_qs8;
176         case xnn_datatype_quint8:
177           return xnn_compute_type_fp32_to_qu8;
178         default:
179           break;
180       }
181       break;
182     case xnn_datatype_fp16:
183       if (output_datatype == xnn_datatype_fp32) {
184         return xnn_compute_type_fp16_to_fp32;
185       }
186       break;
187     case xnn_datatype_qint8:
188       if (output_datatype == xnn_datatype_fp32) {
189         return xnn_compute_type_qs8_to_fp32;
190       }
191       break;
192     case xnn_datatype_quint8:
193       if (output_datatype == xnn_datatype_fp32) {
194         return xnn_compute_type_qu8_to_fp32;
195       }
196       break;
197     default:
198       XNN_UNREACHABLE;
199   }
200   return xnn_compute_type_invalid;
201 }
202 
xnn_init_convert_node(struct xnn_node * node,enum xnn_compute_type compute_type,uint32_t input_id,uint32_t output_id,uint32_t flags)203 void xnn_init_convert_node(
204   struct xnn_node* node,
205   enum xnn_compute_type compute_type,
206   uint32_t input_id,
207   uint32_t output_id,
208   uint32_t flags)
209 {
210   node->type = xnn_node_type_convert;
211   node->compute_type = compute_type;
212   node->num_inputs = 1;
213   node->inputs[0] = input_id;
214   node->num_outputs = 1;
215   node->outputs[0] = output_id;
216   node->flags = flags;
217 
218   node->create = create_convert_operator;
219   node->setup = setup_convert_operator;
220 }
221 
// Public API: append a Convert node to the subgraph after validating both
// tensor IDs, their value types, and that the datatype pair is a supported
// conversion. Returns xnn_status_success on success.
enum xnn_status xnn_define_convert(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_convert));
    return xnn_status_uninitialized;
  }

  // Validate the input value: it must exist, be a dense tensor, and carry a
  // datatype that Convert can read.
  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_convert), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_convert), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  const enum xnn_datatype input_datatype = input_value->datatype;
  if (input_datatype != xnn_datatype_fp16 && input_datatype != xnn_datatype_fp32 &&
      input_datatype != xnn_datatype_qint8 && input_datatype != xnn_datatype_quint8)
  {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_convert), input_id,
      xnn_datatype_to_string(input_datatype), input_datatype);
    return xnn_status_invalid_parameter;
  }

  // Validate the output value with the same checks.
  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_convert), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_convert), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  const enum xnn_datatype output_datatype = output_value->datatype;
  if (output_datatype != xnn_datatype_fp16 && output_datatype != xnn_datatype_fp32 &&
      output_datatype != xnn_datatype_qint8 && output_datatype != xnn_datatype_quint8)
  {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_convert), output_id,
      xnn_datatype_to_string(output_datatype), output_datatype);
    return xnn_status_invalid_parameter;
  }

  // Both datatypes are individually supported; reject pairs that do not form
  // a supported conversion (e.g. qs8 -> qu8).
  const enum xnn_compute_type compute_type = validate_datatypes(input_datatype, output_datatype);
  if (compute_type == xnn_compute_type_invalid) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_convert), input_id, output_id,
      xnn_datatype_to_string(input_datatype),
      xnn_datatype_to_string(output_datatype));
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  xnn_init_convert_node(node, compute_type, input_id, output_id, flags);
  return xnn_status_success;
}
311