1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <math.h>
8
9 #include <xnnpack.h>
10 #include <xnnpack/log.h>
11 #include <xnnpack/params.h>
12 #include <xnnpack/subgraph.h>
13 #include <xnnpack/subgraph-validation.h>
14
xnn_subgraph_check_xnnpack_initialized(enum xnn_node_type node_type)15 enum xnn_status xnn_subgraph_check_xnnpack_initialized(enum xnn_node_type node_type)
16 {
17 if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
18 xnn_log_error("failed to define %s operator: XNNPACK is not initialized", xnn_node_type_to_string(node_type));
19 return xnn_status_uninitialized;
20 }
21 return xnn_status_success;
22 }
23
xnn_subgraph_check_input_node_id(enum xnn_node_type node_type,uint32_t input_id,size_t num_values)24 enum xnn_status xnn_subgraph_check_input_node_id(enum xnn_node_type node_type, uint32_t input_id, size_t num_values)
25 {
26 if (input_id >= num_values) {
27 xnn_log_error(
28 "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
29 xnn_node_type_to_string(node_type), input_id);
30 return xnn_status_invalid_parameter;
31 }
32 return xnn_status_success;
33 }
34
xnn_subgraph_check_nth_input_node_id(enum xnn_node_type node_type,uint32_t input_id,size_t num_values,size_t nth)35 enum xnn_status xnn_subgraph_check_nth_input_node_id(
36 enum xnn_node_type node_type,
37 uint32_t input_id,
38 size_t num_values,
39 size_t nth)
40 {
41 if (input_id >= num_values) {
42 xnn_log_error(
43 "failed to define %s operator with the input %zu ID #%" PRIu32 ": invalid Value ID",
44 xnn_node_type_to_string(node_type), nth, input_id);
45 return xnn_status_invalid_parameter;
46 }
47 return xnn_status_success;
48 }
49
xnn_subgraph_check_input_type_dense(enum xnn_node_type node_type,uint32_t input_id,const struct xnn_value * input_value)50 enum xnn_status xnn_subgraph_check_input_type_dense(
51 enum xnn_node_type node_type,
52 uint32_t input_id,
53 const struct xnn_value* input_value)
54 {
55 if (input_value->type != xnn_value_type_dense_tensor) {
56 xnn_log_error(
57 "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
58 xnn_node_type_to_string(node_type), input_id, input_value->type);
59 return xnn_status_invalid_parameter;
60 }
61 return xnn_status_success;
62 }
63
xnn_subgraph_check_nth_input_type_dense(enum xnn_node_type node_type,uint32_t input_id,const struct xnn_value * input_value,size_t nth)64 enum xnn_status xnn_subgraph_check_nth_input_type_dense(
65 enum xnn_node_type node_type,
66 uint32_t input_id,
67 const struct xnn_value* input_value,
68 size_t nth)
69 {
70 if (input_value->type != xnn_value_type_dense_tensor) {
71 xnn_log_error(
72 "failed to define %s operator with %zu input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
73 xnn_node_type_to_string(node_type), nth, input_id, input_value->type);
74 return xnn_status_invalid_parameter;
75 }
76 return xnn_status_success;
77 }
78
xnn_subgraph_check_output_node_id(enum xnn_node_type node_type,uint32_t output_id,size_t num_values)79 enum xnn_status xnn_subgraph_check_output_node_id(enum xnn_node_type node_type, uint32_t output_id, size_t num_values)
80 {
81 if (output_id >= num_values) {
82 xnn_log_error(
83 "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
84 xnn_node_type_to_string(node_type), output_id);
85 return xnn_status_invalid_parameter;
86 }
87 return xnn_status_success;
88 }
89
xnn_subgraph_check_output_type_dense(enum xnn_node_type node_type,uint32_t output_id,const struct xnn_value * output_value)90 enum xnn_status xnn_subgraph_check_output_type_dense(
91 enum xnn_node_type node_type,
92 uint32_t output_id,
93 const struct xnn_value* output_value)
94 {
95 if (output_value->type != xnn_value_type_dense_tensor) {
96 xnn_log_error(
97 "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
98 xnn_node_type_to_string(node_type), output_id, output_value->type);
99 return xnn_status_invalid_parameter;
100 }
101 return xnn_status_success;
102 }
103
xnn_subgraph_check_datatype_matches(enum xnn_node_type node_type,uint32_t input_id,const struct xnn_value * input_value,uint32_t output_id,const struct xnn_value * output_value)104 enum xnn_status xnn_subgraph_check_datatype_matches(
105 enum xnn_node_type node_type,
106 uint32_t input_id,
107 const struct xnn_value* input_value,
108 uint32_t output_id,
109 const struct xnn_value* output_value)
110 {
111 assert(input_value->datatype != xnn_datatype_invalid);
112 assert(output_value->datatype != xnn_datatype_invalid);
113 if (input_value->datatype != output_value->datatype) {
114 xnn_log_error(
115 "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
116 ": mismatching datatypes across the input (%s) and output (%s)",
117 xnn_node_type_to_string(node_type), input_id, output_id,
118 xnn_datatype_to_string(input_value->datatype),
119 xnn_datatype_to_string(output_value->datatype));
120 return xnn_status_invalid_parameter;
121 }
122 return xnn_status_success;
123 }
124
xnn_subgraph_check_datatype_matches_two_inputs(enum xnn_node_type node_type,uint32_t input1_id,const struct xnn_value * input1_value,uint32_t input2_id,const struct xnn_value * input2_value,uint32_t output_id,const struct xnn_value * output_value)125 enum xnn_status xnn_subgraph_check_datatype_matches_two_inputs(
126 enum xnn_node_type node_type,
127 uint32_t input1_id,
128 const struct xnn_value* input1_value,
129 uint32_t input2_id,
130 const struct xnn_value* input2_value,
131 uint32_t output_id,
132 const struct xnn_value* output_value)
133 {
134 assert(input1_value->datatype != xnn_datatype_invalid);
135 assert(input2_value->datatype != xnn_datatype_invalid);
136 assert(output_value->datatype != xnn_datatype_invalid);
137 if (input1_value->datatype != input2_value->datatype ||
138 input1_value->datatype != output_value->datatype)
139 {
140 xnn_log_error(
141 "failed to define %s operator with input IDs #%" PRIu32 " and #%" PRIu32 " and output ID #%" PRIu32
142 ": mismatching datatypes across the first input (%s), the second input (%s), and output (%s)",
143 xnn_node_type_to_string(node_type), input1_id, input2_id, output_id,
144 xnn_datatype_to_string(input1_value->datatype),
145 xnn_datatype_to_string(input2_value->datatype),
146 xnn_datatype_to_string(output_value->datatype));
147 return xnn_status_invalid_parameter;
148 }
149 return xnn_status_success;
150 }
151
152
xnn_subgraph_check_output_min_max(enum xnn_node_type node_type,float output_min,float output_max)153 enum xnn_status xnn_subgraph_check_output_min_max(enum xnn_node_type node_type, float output_min, float output_max)
154 {
155 if (isnan(output_min)) {
156 xnn_log_error(
157 "failed to define %s operator with NaN output lower bound: lower bound must be non-NaN",
158 xnn_node_type_to_string(node_type));
159 return xnn_status_invalid_parameter;
160 }
161
162 if (isnan(output_max)) {
163 xnn_log_error(
164 "failed to define %s operator with NaN output upper bound: upper bound must be non-NaN",
165 xnn_node_type_to_string(node_type));
166 return xnn_status_invalid_parameter;
167 }
168
169 if (output_min >= output_max) {
170 xnn_log_error(
171 "failed to define %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
172 xnn_node_type_to_string(node_type), output_min, output_max);
173 return xnn_status_invalid_parameter;
174 }
175 return xnn_status_success;
176 }
177