1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <math.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <string.h>
10
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/params.h>
14 #include <xnnpack/subgraph.h>
15
16
create_multiply_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)17 static enum xnn_status create_multiply_operator(
18 const struct xnn_node* node,
19 const struct xnn_value* values,
20 size_t num_values,
21 struct xnn_operator_data* opdata)
22 {
23 assert(node->num_inputs == 2);
24 const uint32_t input1_id = node->inputs[0];
25 assert(input1_id != XNN_INVALID_VALUE_ID);
26 assert(input1_id < num_values);
27 const uint32_t input2_id = node->inputs[1];
28 assert(input2_id != XNN_INVALID_VALUE_ID);
29 assert(input2_id < num_values);
30
31 assert(node->num_outputs == 1);
32 const uint32_t output_id = node->outputs[0];
33 assert(output_id != XNN_INVALID_VALUE_ID);
34 assert(output_id < num_values);
35
36 enum xnn_status status;
37 switch (node->compute_type) {
38 case xnn_compute_type_fp32:
39 status = xnn_create_multiply_nd_f32(
40 node->activation.output_min,
41 node->activation.output_max,
42 node->flags,
43 &opdata->operator_object);
44 break;
45 #ifndef XNN_NO_QS8_OPERATORS
46 case xnn_compute_type_qs8:
47 {
48 const float output_scale = values[output_id].quantization.scale;
49 const int32_t output_zero_point = values[output_id].quantization.zero_point;
50 const int8_t output_min =
51 (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
52 const int8_t output_max =
53 (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
54 status = xnn_create_multiply_nd_qs8(
55 (int8_t) values[input1_id].quantization.zero_point,
56 values[input1_id].quantization.scale,
57 (int8_t) values[input2_id].quantization.zero_point,
58 values[input2_id].quantization.scale,
59 (int8_t) output_zero_point,
60 output_scale, output_min, output_max, node->flags,
61 &opdata->operator_object);
62 break;
63 }
64 #endif // !defined(XNN_NO_QS8_OPERATORS)
65 #ifndef XNN_NO_QU8_OPERATORS
66 case xnn_compute_type_qu8:
67 {
68 const float output_scale = values[output_id].quantization.scale;
69 const int32_t output_zero_point = values[output_id].quantization.zero_point;
70 const uint8_t output_min =
71 (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
72 const uint8_t output_max =
73 (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
74 status = xnn_create_multiply_nd_qu8(
75 (uint8_t) values[input1_id].quantization.zero_point,
76 values[input1_id].quantization.scale,
77 (uint8_t) values[input2_id].quantization.zero_point,
78 values[input2_id].quantization.scale,
79 (uint8_t) output_zero_point,
80 output_scale, output_min, output_max, node->flags,
81 &opdata->operator_object);
82 break;
83 }
84 #endif // !defined(XNN_NO_QU8_OPERATORS)
85 default:
86 XNN_UNREACHABLE;
87 }
88 if (status == xnn_status_success) {
89 opdata->shape1.num_dims = values[input1_id].shape.num_dims;
90 opdata->shape2.num_dims = values[input2_id].shape.num_dims;
91 if (values[output_id].layout == xnn_layout_type_nchw) {
92 assert(values[input1_id].layout == xnn_layout_type_nchw);
93 assert(values[input2_id].layout == xnn_layout_type_nchw);
94 opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
95 opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
96 if (values[input1_id].shape.num_dims > 2) {
97 memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
98 }
99 opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
100 opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
101 if (values[input1_id].shape.num_dims > 2) {
102 memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
103 }
104 } else {
105 assert(values[output_id].layout == xnn_layout_type_nhwc);
106 assert(values[input1_id].layout == xnn_layout_type_nhwc);
107 assert(values[input2_id].layout == xnn_layout_type_nhwc);
108 memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
109 memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
110 }
111 opdata->inputs[0] = input1_id;
112 opdata->inputs[1] = input2_id;
113 opdata->outputs[0] = output_id;
114 }
115 return status;
116 }
117
// Binds runtime tensor data (blobs) to the multiply operator created earlier
// and dispatches to the setup routine matching the operator's type.
//
// Returns the status of the underlying xnn_setup_multiply_nd_* call.
static enum xnn_status setup_multiply_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  // Value IDs were recorded into opdata at creation time.
  const uint32_t id_a = opdata->inputs[0];
  assert(id_a != XNN_INVALID_VALUE_ID);
  assert(id_a < num_blobs);

  const uint32_t id_b = opdata->inputs[1];
  assert(id_b != XNN_INVALID_VALUE_ID);
  assert(id_b < num_blobs);

  const uint32_t id_out = opdata->outputs[0];
  assert(id_out != XNN_INVALID_VALUE_ID);
  assert(id_out < num_blobs);

  const void* data_a = blobs[id_a].data;
  assert(data_a != NULL);

  const void* data_b = blobs[id_b].data;
  assert(data_b != NULL);

  void* data_out = blobs[id_out].data;
  assert(data_out != NULL);

  switch (opdata->operator_object->type) {
    case xnn_operator_type_multiply_nd_f32:
      return xnn_setup_multiply_nd_f32(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        data_a, data_b, data_out,
        threadpool);
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_multiply_nd_qs8:
      return xnn_setup_multiply_nd_qs8(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        data_a, data_b, data_out,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_multiply_nd_qu8:
      return xnn_setup_multiply_nd_qu8(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        data_a, data_b, data_out,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
187
// Defines a 2-input multiply node in the subgraph.
//
// Validates every argument with an early-return guard (producing a specific
// log message per failure), then appends a new node configured with the
// create/setup callbacks defined above.
//
// subgraph   - subgraph to add the node to (XNNPACK must be initialized).
// output_min - lower output clamp; must be non-NaN and < output_max.
// output_max - upper output clamp; must be non-NaN.
// input1_id / input2_id / output_id - dense-tensor Value IDs in the subgraph;
//              all three must share the same (supported) datatype.
// flags      - operator flags forwarded to the created operator.
//
// Returns xnn_status_success, or xnn_status_uninitialized /
// xnn_status_invalid_parameter / xnn_status_out_of_memory on failure.
enum xnn_status xnn_define_multiply2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_multiply2));
    return xnn_status_uninitialized;
  }

  // Validate the clamping range before touching any Value.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to define %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_multiply2));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to define %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_multiply2));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to define %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_node_type_to_string(xnn_node_type_multiply2), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // First input: ID in range, dense tensor, supported datatype.
  if (input1_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with the first input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_multiply2), input1_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input1_value = &subgraph->values[input1_id];
  if (input1_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_multiply2), input1_id, input1_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input1_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input1_id,
        xnn_datatype_to_string(input1_value->datatype), input1_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Second input: same checks as the first input.
  if (input2_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with the second input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_multiply2), input2_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input2_value = &subgraph->values[input2_id];
  if (input2_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_multiply2), input2_id, input2_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input2_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input2_id,
        xnn_datatype_to_string(input2_value->datatype), input2_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Output: ID in range, dense tensor; its datatype picks the compute type.
  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_multiply2), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_multiply2), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Mixed-datatype multiplication is not supported: all three must agree.
  if (input1_value->datatype != input2_value->datatype ||
      input1_value->datatype != output_value->datatype)
  {
    xnn_log_error(
      "failed to define %s operator with input IDs #%" PRIu32 " and #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across the first input (%s), the second input (%s), and output (%s)",
      xnn_node_type_to_string(xnn_node_type_multiply2), input1_id, input2_id, output_id,
      xnn_datatype_to_string(input1_value->datatype),
      xnn_datatype_to_string(input2_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_multiply2;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 2;
  node->inputs[0] = input1_id;
  node->inputs[1] = input2_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  // Deferred callbacks: operator creation and data binding happen when the
  // runtime is built, not at definition time.
  node->create = create_multiply_operator;
  node->setup = setup_multiply_operator;

  return xnn_status_success;
}
360