• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
14 
15 
create_prelu_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)16 static enum xnn_status create_prelu_operator(
17   const struct xnn_node* node,
18   const struct xnn_value* values,
19   size_t num_values,
20   struct xnn_operator_data* opdata)
21 {
22   assert(node->num_inputs == 2);
23   const uint32_t input_id = node->inputs[0];
24   assert(input_id != XNN_INVALID_VALUE_ID);
25   assert(input_id < num_values);
26   const uint32_t slope_id = node->inputs[1];
27   assert(slope_id != XNN_INVALID_VALUE_ID);
28   assert(slope_id < num_values);
29 
30   assert(node->num_outputs == 1);
31   const uint32_t output_id = node->outputs[0];
32   assert(output_id != XNN_INVALID_VALUE_ID);
33   assert(output_id < num_values);
34 
35   const size_t num_input_dims = values[input_id].shape.num_dims;
36   const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
37 
38   enum xnn_status status;
39   switch (node->compute_type) {
40 #ifndef XNN_NO_F16_OPERATORS
41     case xnn_compute_type_fp16:
42       status = xnn_create_prelu_nc_f16(
43         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
44         values[slope_id].data /* negative slope */,
45         node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
46         &opdata->operator_object);
47       break;
48 #endif  // XNN_NO_F16_OPERATORS
49     case xnn_compute_type_fp32:
50       status = xnn_create_prelu_nc_f32(
51         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
52         values[slope_id].data /* negative slope */,
53         node->flags,
54         &opdata->operator_object);
55       break;
56     default:
57       XNN_UNREACHABLE;
58   }
59   if (status == xnn_status_success) {
60     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
61     opdata->inputs[0] = input_id;
62     opdata->outputs[0] = output_id;
63   }
64   return status;
65 }
66 
// Binds the runtime input/output buffers of a PReLU Node to its operator.
//
// opdata:     operator data populated by create_prelu_operator (operator
//             object, batch size, input/output Value IDs).
// blobs:      runtime blobs indexed by Value ID; the input and output blobs
//             are asserted to have non-NULL data.
// num_blobs:  number of entries in blobs (only checked by assertions).
// threadpool: forwarded unchanged to the type-specific setup call.
//
// Returns the status of xnn_setup_prelu_nc_f32 / xnn_setup_prelu_nc_f16,
// selected by the operator's type.
static enum xnn_status setup_prelu_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t in_id = opdata->inputs[0];
  assert(in_id != XNN_INVALID_VALUE_ID);
  assert(in_id < num_blobs);

  const uint32_t out_id = opdata->outputs[0];
  assert(out_id != XNN_INVALID_VALUE_ID);
  assert(out_id < num_blobs);

  const void* src = blobs[in_id].data;
  assert(src != NULL);

  void* dst = blobs[out_id].data;
  assert(dst != NULL);

  const size_t batch = opdata->batch_size;
  switch (opdata->operator_object->type) {
    case xnn_operator_type_prelu_nc_f32:
      return xnn_setup_prelu_nc_f32(opdata->operator_object, batch, src, dst, threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_prelu_nc_f16:
      return xnn_setup_prelu_nc_f16(opdata->operator_object, batch, src, dst, threadpool);
#endif  // XNN_NO_F16_OPERATORS
    default:
      XNN_UNREACHABLE;
  }

}
111 
// Defines a PReLU Node in a subgraph.
//
// subgraph:  subgraph that receives the new Node.
// input_id:  ID of the input Value; must be an FP32 dense tensor.
// slope_id:  ID of the slope Value; must be a *static* FP32 dense tensor.
//            Its data pointer is passed to the operator as the negative slope
//            when the Node is instantiated, so it must already hold data.
// output_id: ID of the output Value; must be an FP32 dense tensor.
// flags:     stored on the Node; forwarded to the operator at creation time.
//
// Returns:
//   xnn_status_uninitialized      if XNNPACK has not been initialized;
//   xnn_status_invalid_parameter  if any Value ID is out of range, a Value is
//                                 not a dense tensor, a datatype is not FP32,
//                                 or the slope Value is not static;
//   xnn_status_out_of_memory      if the Node could not be allocated;
//   xnn_status_success            otherwise.
enum xnn_status xnn_define_prelu(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t slope_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_prelu));
    return xnn_status_uninitialized;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_prelu), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_prelu), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_prelu), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (slope_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with slope ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_prelu), slope_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* slope_value = &subgraph->values[slope_id];
  if (slope_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with slope ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_prelu), slope_id, slope_value->type);
    return xnn_status_invalid_parameter;
  }

  // The slope's data is consumed directly as the operator's negative slope
  // when the operator is created, so a Value without static data (NULL)
  // cannot be supported; reject it here instead of failing later.
  if (slope_value->data == NULL) {
    xnn_log_error(
      "failed to define %s operator with slope ID #%" PRIu32 ": non-static Value",
      xnn_node_type_to_string(xnn_node_type_prelu), slope_id);
    return xnn_status_invalid_parameter;
  }

  switch (slope_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with slope ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_prelu), slope_id,
        xnn_datatype_to_string(slope_value->datatype), slope_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_prelu), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_prelu), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_prelu), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_prelu;
  node->compute_type = xnn_compute_type_fp32;
  node->num_inputs = 2;
  node->inputs[0] = input_id;
  node->inputs[1] = slope_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_prelu_operator;
  node->setup = setup_prelu_operator;

  return xnn_status_success;
}
222