• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <math.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <string.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/params.h>
14 #include <xnnpack/subgraph.h>
15 
16 
create_copy_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)17 static enum xnn_status create_copy_operator(
18   const struct xnn_node* node,
19   const struct xnn_value* values,
20   size_t num_values,
21   struct xnn_operator_data* opdata)
22 {
23   assert(node->num_inputs == 1);
24   const uint32_t input_id = node->inputs[0];
25   assert(input_id != XNN_INVALID_VALUE_ID);
26   assert(input_id < num_values);
27 
28   assert(node->num_outputs == 1);
29   const uint32_t output_id = node->outputs[0];
30   assert(output_id != XNN_INVALID_VALUE_ID);
31   assert(output_id < num_values);
32 
33   enum xnn_status status;
34   switch (node->compute_type) {
35 #ifndef XNN_NO_F16_OPERATORS
36     case xnn_compute_type_fp16:
37       status = xnn_create_copy_nc_x16(
38         1 /* channels */, 1 /* input stride */, 1 /* output stride */,
39         node->flags,
40         &opdata->operator_object);
41       break;
42 #endif  // !defined(XNN_NO_F16_OPERATORS)
43     case xnn_compute_type_fp32:
44       status = xnn_create_copy_nc_x32(
45         1 /* channels */, 1 /* input stride */, 1 /* output stride */,
46         node->flags,
47         &opdata->operator_object);
48       break;
49 #ifndef XNN_NO_QS8_OPERATORS
50     case xnn_compute_type_qs8:
51 #endif  // !defined(XNN_NO_QS8_OPERATORS)
52 #ifndef XNN_NO_QU8_OPERATORS
53     case xnn_compute_type_qu8:
54 #endif  // !defined(XNN_NO_QU8_OPERATORS)
55 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
56       status = xnn_create_copy_nc_x8(
57         1 /* channels */, 1 /* input stride */, 1 /* output stride */,
58         node->flags,
59         &opdata->operator_object);
60       break;
61 #endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
62     default:
63       XNN_UNREACHABLE;
64   }
65   if (status == xnn_status_success) {
66     opdata->batch_size = xnn_shape_multiply_all_dims(&values[input_id].shape);
67     opdata->inputs[0] = input_id;
68     opdata->outputs[0] = output_id;
69   }
70   return status;
71 }
72 
// Binds the copy operator created for a Static Reshape node to the runtime
// data buffers of its input and output Values.
//
// The input/output Value IDs and the batch size are read from `opdata`, as
// recorded at creation time. The matching xnn_setup_copy_nc_* entry point is
// chosen from the operator's type, and its status is returned unchanged.
static enum xnn_status setup_copy_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t src_id = opdata->inputs[0];
  assert(src_id != XNN_INVALID_VALUE_ID);
  assert(src_id < num_blobs);

  const uint32_t dst_id = opdata->outputs[0];
  assert(dst_id != XNN_INVALID_VALUE_ID);
  assert(dst_id < num_blobs);

  const void* src_data = blobs[src_id].data;
  assert(src_data != NULL);

  void* dst_data = blobs[dst_id].data;
  assert(dst_data != NULL);

  xnn_operator_t copy_op = opdata->operator_object;
  switch (copy_op->type) {
    case xnn_operator_type_copy_nc_x32:
      return xnn_setup_copy_nc_x32(
        copy_op, opdata->batch_size, src_data, dst_data, threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_copy_nc_x16:
      return xnn_setup_copy_nc_x16(
        copy_op, opdata->batch_size, src_data, dst_data, threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8:
      return xnn_setup_copy_nc_x8(
        copy_op, opdata->batch_size, src_data, dst_data, threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
128 
// Defines a Static Reshape node in `subgraph`. The node reads the dense-tensor
// Value `input_id` and writes the dense-tensor Value `output_id`, using the
// target shape `new_shape[0..num_dims)`. At runtime the node is executed by a
// copy operator created for it.
//
// Parameters:
//   subgraph  - subgraph the node is appended to.
//   num_dims  - number of entries in `new_shape`; at most XNN_MAX_TENSOR_DIMS.
//   new_shape - target dimensions. They are copied into the node, so the caller
//               keeps ownership of the array. May be NULL only if num_dims == 0.
//   input_id  - Value ID of the input; must be a dense tensor of a supported
//               datatype (fp32, and qint8/quint8 when those operators are
//               compiled in).
//   output_id - Value ID of the output; must be a dense tensor with the same
//               datatype as the input. For quantized datatypes, it must also
//               have the same zero point and scale.
//   flags     - stored on the node unchanged.
//
// Returns:
//   xnn_status_success               - the node was added.
//   xnn_status_uninitialized         - XNNPACK has not been initialized.
//   xnn_status_invalid_parameter     - invalid Value ID, non-dense Value,
//                                      unsupported or mismatching datatype,
//                                      mismatching quantization parameters, or
//                                      NULL `new_shape` with num_dims > 0.
//   xnn_status_unsupported_parameter - num_dims exceeds XNN_MAX_TENSOR_DIMS.
//   xnn_status_out_of_memory         - the node could not be allocated.
enum xnn_status xnn_define_static_reshape(
  xnn_subgraph_t subgraph,
  size_t num_dims,
  const size_t* new_shape,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_static_reshape));
    return xnn_status_uninitialized;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_reshape), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_reshape), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  // The output datatype determines the compute type, which selects the copy
  // kernel width when the operator is created.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
  // The node is executed as a plain copy with no requantization, so quantized
  // input and output must share identical quantization parameters.
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

  if (num_dims > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to define %s operator with %zu-dimensional output shape: at most %zu dimensions are supported",
      xnn_node_type_to_string(xnn_node_type_static_reshape), num_dims, (size_t) XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  // Reject a missing shape array up front, before any node is allocated,
  // instead of dereferencing NULL in the copy below.
  if (num_dims != 0 && new_shape == NULL) {
    xnn_log_error(
      "failed to define %s operator with %zu-dimensional output shape: NULL shape array",
      xnn_node_type_to_string(xnn_node_type_static_reshape), num_dims);
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.static_reshape.new_shape.num_dims = num_dims;
  // memcpy requires valid pointers even for a zero-byte copy, and `new_shape`
  // may legitimately be NULL when num_dims == 0.
  if (num_dims != 0) {
    memcpy(&node->params.static_reshape.new_shape.dim, new_shape, num_dims * sizeof(size_t));
  }

  node->type = xnn_node_type_static_reshape;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_copy_operator;
  node->setup = setup_copy_operator;

  return xnn_status_success;
}
272