1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <math.h>
7 #include <stddef.h>
8 #include <stdint.h>
9
10 #include <xnnpack.h>
11 #include <xnnpack/log.h>
12 #include <xnnpack/params.h>
13 #include <xnnpack/subgraph.h>
14
15
create_leaky_relu_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)16 static enum xnn_status create_leaky_relu_operator(
17 const struct xnn_node* node,
18 const struct xnn_value* values,
19 size_t num_values,
20 struct xnn_operator_data* opdata)
21 {
22 assert(node->compute_type == xnn_compute_type_fp32);
23
24 assert(node->num_inputs == 1);
25 const uint32_t input_id = node->inputs[0];
26 assert(input_id != XNN_INVALID_VALUE_ID);
27 assert(input_id < num_values);
28
29 assert(node->num_outputs == 1);
30 const uint32_t output_id = node->outputs[0];
31 assert(output_id != XNN_INVALID_VALUE_ID);
32 assert(output_id < num_values);
33
34 const size_t num_input_dims = values[input_id].shape.num_dims;
35 const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
36
37 enum xnn_status status = xnn_create_leaky_relu_nc_f32(
38 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
39 node->params.leaky_relu.negative_slope,
40 node->flags,
41 &opdata->operator_object);
42 if (status == xnn_status_success) {
43 opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
44 opdata->inputs[0] = input_id;
45 opdata->outputs[0] = output_id;
46 }
47 return status;
48 }
49
// Node "setup" callback: binds the input and output buffers to the previously
// created Leaky ReLU operator.
//
// opdata     - operator data holding the operator object, batch size and
//              input/output IDs.
// blobs      - array of blobs, indexed by the IDs stored in opdata.
// num_blobs  - number of entries in blobs (only used for sanity checks).
// threadpool - forwarded unchanged to xnn_setup_leaky_relu_nc_f32.
//
// Returns the status of xnn_setup_leaky_relu_nc_f32.
static enum xnn_status setup_leaky_relu_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t src_id = opdata->inputs[0];
  assert(src_id != XNN_INVALID_VALUE_ID);
  assert(src_id < num_blobs);

  const uint32_t dst_id = opdata->outputs[0];
  assert(dst_id != XNN_INVALID_VALUE_ID);
  assert(dst_id < num_blobs);

  // Both blobs must already have data pointers.
  const void* src = blobs[src_id].data;
  assert(src != NULL);

  void* dst = blobs[dst_id].data;
  assert(dst != NULL);

  return xnn_setup_leaky_relu_nc_f32(
    opdata->operator_object, opdata->batch_size, src, dst, threadpool);
}
79
// Defines a Leaky ReLU node in a subgraph.
//
// subgraph       - subgraph that receives the new node.
// negative_slope - slope stored in the node's Leaky ReLU parameters; must be
//                  finite.
// input_id       - Value ID of the input; must refer to an FP32 dense tensor
//                  of the subgraph.
// output_id      - Value ID of the output; must refer to an FP32 dense tensor
//                  of the subgraph.
// flags          - stored on the node and later passed to operator creation.
//
// Returns:
//   xnn_status_uninitialized     - XNNPACK has not been initialized.
//   xnn_status_invalid_parameter - negative_slope is not finite, a Value ID is
//                                  out of range, or a value is not an FP32
//                                  dense tensor.
//   xnn_status_out_of_memory     - xnn_subgraph_new_node returned NULL.
//   xnn_status_success           - the node was added to the subgraph.
enum xnn_status xnn_define_leaky_relu(
  xnn_subgraph_t subgraph,
  float negative_slope,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_leaky_relu));
    return xnn_status_uninitialized;
  }

  // Reject NaN and +/-infinity slopes. The message says "define" (not
  // "create") to match every other diagnostic emitted by this function.
  if (!isfinite(negative_slope)) {
    xnn_log_error(
      "failed to define %s operator with %f negative slope: finite number expected",
      xnn_node_type_to_string(xnn_node_type_leaky_relu),
      negative_slope);
    return xnn_status_invalid_parameter;
  }

  // Validate the input: ID in range, dense tensor, FP32 datatype.
  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_leaky_relu), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_leaky_relu), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_leaky_relu), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Validate the output with the same three checks.
  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_leaky_relu), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_leaky_relu), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_leaky_relu), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // All arguments are valid: append a node and fill it in.
  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_leaky_relu;
  node->compute_type = xnn_compute_type_fp32;
  node->params.leaky_relu.negative_slope = negative_slope;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  // Callbacks used to create and set up the operator for this node.
  node->create = create_leaky_relu_operator;
  node->setup = setup_leaky_relu_operator;

  return xnn_status_success;
}
172