• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
15 
16 
create_multiply_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)17 static enum xnn_status create_multiply_operator(
18   const struct xnn_node* node,
19   const struct xnn_value* values,
20   size_t num_values,
21   struct xnn_operator_data* opdata)
22 {
23   assert(node->num_inputs == 2);
24   const uint32_t input1_id = node->inputs[0];
25   assert(input1_id != XNN_INVALID_VALUE_ID);
26   assert(input1_id < num_values);
27   const uint32_t input2_id = node->inputs[1];
28   assert(input2_id != XNN_INVALID_VALUE_ID);
29   assert(input2_id < num_values);
30 
31   assert(node->num_outputs == 1);
32   const uint32_t output_id = node->outputs[0];
33   assert(output_id != XNN_INVALID_VALUE_ID);
34   assert(output_id < num_values);
35 
36   enum xnn_status status;
37   switch (node->compute_type) {
38     case xnn_compute_type_fp32:
39       status = xnn_create_multiply_nd_f32(
40         node->activation.output_min,
41         node->activation.output_max,
42         node->flags,
43         &opdata->operator_object);
44       break;
45 #ifndef XNN_NO_QS8_OPERATORS
46     case xnn_compute_type_qs8:
47     {
48       const float output_scale = values[output_id].quantization.scale;
49       const int32_t output_zero_point = values[output_id].quantization.zero_point;
50       const int8_t output_min =
51         (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
52       const int8_t output_max =
53         (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
54       status = xnn_create_multiply_nd_qs8(
55         (int8_t) values[input1_id].quantization.zero_point,
56         values[input1_id].quantization.scale,
57         (int8_t) values[input2_id].quantization.zero_point,
58         values[input2_id].quantization.scale,
59         (int8_t) output_zero_point,
60         output_scale, output_min, output_max, node->flags,
61         &opdata->operator_object);
62       break;
63     }
64 #endif  // !defined(XNN_NO_QS8_OPERATORS)
65 #ifndef XNN_NO_QU8_OPERATORS
66     case xnn_compute_type_qu8:
67     {
68       const float output_scale = values[output_id].quantization.scale;
69       const int32_t output_zero_point = values[output_id].quantization.zero_point;
70       const uint8_t output_min =
71         (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
72       const uint8_t output_max =
73         (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
74       status = xnn_create_multiply_nd_qu8(
75         (uint8_t) values[input1_id].quantization.zero_point,
76         values[input1_id].quantization.scale,
77         (uint8_t) values[input2_id].quantization.zero_point,
78         values[input2_id].quantization.scale,
79         (uint8_t) output_zero_point,
80         output_scale, output_min, output_max, node->flags,
81         &opdata->operator_object);
82       break;
83     }
84 #endif  // !defined(XNN_NO_QU8_OPERATORS)
85     default:
86       XNN_UNREACHABLE;
87   }
88   if (status == xnn_status_success) {
89     opdata->shape1.num_dims = values[input1_id].shape.num_dims;
90     opdata->shape2.num_dims = values[input2_id].shape.num_dims;
91     if (values[output_id].layout == xnn_layout_type_nchw) {
92       assert(values[input1_id].layout == xnn_layout_type_nchw);
93       assert(values[input2_id].layout == xnn_layout_type_nchw);
94       opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
95       opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
96       if (values[input1_id].shape.num_dims > 2) {
97         memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
98       }
99       opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
100       opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
101       if (values[input1_id].shape.num_dims > 2) {
102         memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
103       }
104     } else {
105       assert(values[output_id].layout == xnn_layout_type_nhwc);
106       assert(values[input1_id].layout == xnn_layout_type_nhwc);
107       assert(values[input2_id].layout == xnn_layout_type_nhwc);
108       memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
109       memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
110     }
111     opdata->inputs[0] = input1_id;
112     opdata->inputs[1] = input2_id;
113     opdata->outputs[0] = output_id;
114   }
115   return status;
116 }
117 
// Binds the runtime blobs (input/output tensor data) to a multiply operator
// that was previously created by create_multiply_operator, dispatching on the
// operator's concrete type.
//
// opdata     - operator object plus the input shapes and value IDs captured at
//              creation time.
// blobs      - the runtime's blob table, indexed by value ID.
// num_blobs  - number of entries in blobs (used only for assertions).
// threadpool - threadpool to parallelize the eventual computation on.
//
// Returns the status of the underlying xnn_setup_multiply_nd_* call.
static enum xnn_status setup_multiply_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input1_id = opdata->inputs[0];
  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_blobs);

  const uint32_t input2_id = opdata->inputs[1];
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input1_blob = blobs + input1_id;
  const void* input1_data = input1_blob->data;
  assert(input1_data != NULL);

  const struct xnn_blob* input2_blob = blobs + input2_id;
  const void* input2_data = input2_blob->data;
  assert(input2_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  // Note: unreachable `break` statements after each `return` were removed.
  switch (opdata->operator_object->type) {
    case xnn_operator_type_multiply_nd_f32:
      return xnn_setup_multiply_nd_f32(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_multiply_nd_qs8:
      return xnn_setup_multiply_nd_qs8(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_multiply_nd_qu8:
      return xnn_setup_multiply_nd_qu8(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
187 
// Public subgraph API: defines a 2-input elementwise multiply node.
//
// Validates the clamping range and the three Value IDs (both inputs and the
// output must be dense tensors of matching, supported datatypes), then
// appends a new node to the subgraph wired to create_multiply_operator /
// setup_multiply_operator.
//
// output_min/output_max - real-valued clamping bounds applied to the product;
//                         must be non-NaN with output_min < output_max.
// Returns xnn_status_success, or a specific error status on invalid input.
enum xnn_status xnn_define_multiply2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  // XNNPACK must be initialized before any subgraph can be modified.
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_multiply2));
    return xnn_status_uninitialized;
  }

  // Validate the clamping range: bounds must be non-NaN and strictly ordered.
  // Note that equal bounds (min == max) are rejected as well.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to define %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_multiply2));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to define %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_multiply2));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to define %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_node_type_to_string(xnn_node_type_multiply2), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Validate the first input: in-range ID, dense tensor, supported datatype.
  if (input1_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with the first input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_multiply2), input1_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input1_value = &subgraph->values[input1_id];
  if (input1_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_multiply2), input1_id, input1_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input1_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input1_id,
        xnn_datatype_to_string(input1_value->datatype), input1_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Validate the second input with the same checks.
  if (input2_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with the second input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_multiply2), input2_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input2_value = &subgraph->values[input2_id];
  if (input2_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_multiply2), input2_id, input2_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input2_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input2_id,
        xnn_datatype_to_string(input2_value->datatype), input2_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Validate the output, and derive the node's compute type from its datatype.
  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_multiply2), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_multiply2), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // All three Values must share a single datatype; mixed-type multiply is not
  // supported.
  if (input1_value->datatype != input2_value->datatype ||
      input1_value->datatype != output_value->datatype)
  {
    xnn_log_error(
      "failed to define %s operator with input IDs #%" PRIu32 " and #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across the first input (%s), the second input (%s), and output (%s)",
      xnn_node_type_to_string(xnn_node_type_multiply2), input1_id, input2_id, output_id,
      xnn_datatype_to_string(input1_value->datatype),
      xnn_datatype_to_string(input2_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

  // All checks passed: append the node and record its configuration.
  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_multiply2;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 2;
  node->inputs[0] = input1_id;
  node->inputs[1] = input2_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  // Deferred callbacks invoked by the runtime when the subgraph is realized.
  node->create = create_multiply_operator;
  node->setup = setup_multiply_operator;

  return xnn_status_success;
}
360