/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/util.h"

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <complex>
#include <cstring>
#include <initializer_list>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {

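// Placeholder invoke handler installed for custom ops that were never
// resolved to an implementation; it reports an error and fails the invocation.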
TfLiteStatus UnresolvedOpInvoke(TfLiteContext* context, TfLiteNode* node) {
  context->ReportError(context,
                       "Encountered an unresolved custom op. Did you miss "
                       "a custom op or delegate?");
  return kTfLiteError;
}

}  // namespace

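// Flex (TF select) ops are serialized as custom ops whose custom code starts
// with kFlexCustomCodePrefix, so a prefix comparison is sufficient here.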
bool IsFlexOp(const char* custom_name) {
  return custom_name && strncmp(custom_name, kFlexCustomCodePrefix,
                                strlen(kFlexCustomCodePrefix)) == 0;
}

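// Copies `data` into a freshly allocated TfLiteIntArray and returns it wrapped
// in a unique_ptr, so the caller does not have to free it manually.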
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
    const std::vector<int>& data) {
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
      TfLiteIntArrayCreate(data.size()));
  std::copy(data.begin(), data.end(), result->data);
  return result;
}

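// Converts a std::vector<int> into a newly allocated TfLiteIntArray; the
// caller takes ownership of the returned array.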
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
  return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()),
                                      input.data());
}

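// Copies `rank` dimensions from a plain C array into a newly allocated
// TfLiteIntArray owned by the caller.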
TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims) {
  TfLiteIntArray* output = TfLiteIntArrayCreate(rank);
  for (int i = 0; i < rank; i++) {
    output->data[i] = dims[i];
  }
  return output;
}

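// Element-wise comparison of a TfLiteIntArray against a plain C array; a null
// TfLiteIntArray never compares equal.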
bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
                                 const int* b) {
  if (!a) return false;
  if (a->size != b_size) return false;
  for (int i = 0; i < a->size; ++i) {
    if (a->data[i] != b[i]) return false;
  }
  return true;
}

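// Folds a list of hashes into a single value. Each step mixes one hash into
// the running result as
//   result ^= hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4);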
size_t CombineHashes(std::initializer_list<size_t> hashes) {
  size_t result = 0;
  // Hash combiner used by TensorFlow core.
  for (size_t hash : hashes) {
    result = result ^
             (hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4));
  }
  return result;
}

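// Writes the per-element byte size of `type` into `bytes`. Types without a
// fixed element size (and unknown types) report an error through `context`,
// when one is provided, and return kTfLiteError.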
TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
                           size_t* bytes) {
  // TODO(levp): remove the default case so that new types produce a
  // compilation error.
  switch (type) {
    case kTfLiteFloat32:
      *bytes = sizeof(float);
      break;
    case kTfLiteInt32:
      *bytes = sizeof(int32_t);
      break;
    case kTfLiteUInt32:
      *bytes = sizeof(uint32_t);
      break;
    case kTfLiteUInt8:
      *bytes = sizeof(uint8_t);
      break;
    case kTfLiteInt64:
      *bytes = sizeof(int64_t);
      break;
    case kTfLiteUInt64:
      *bytes = sizeof(uint64_t);
      break;
    case kTfLiteBool:
      *bytes = sizeof(bool);
      break;
    case kTfLiteComplex64:
      *bytes = sizeof(std::complex<float>);
      break;
    case kTfLiteComplex128:
      *bytes = sizeof(std::complex<double>);
      break;
    case kTfLiteInt16:
      *bytes = sizeof(int16_t);
      break;
    case kTfLiteInt8:
      *bytes = sizeof(int8_t);
      break;
    case kTfLiteFloat16:
      *bytes = sizeof(TfLiteFloat16);
      break;
    case kTfLiteFloat64:
      *bytes = sizeof(double);
      break;
    default:
      if (context) {
        context->ReportError(
            context,
            "Type %d is unsupported. Only float16, float32, float64, int8, "
            "int16, int32, int64, uint8, uint32, uint64, bool, complex64 and "
            "complex128 are supported currently.",
            type);
      }
      return kTfLiteError;
  }
  return kTfLiteOk;
}

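// Builds a stub registration for a custom op that could not be resolved: only
// the invoke entry, builtin code, custom name and version are populated, and
// invoking it always fails with a descriptive error.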
TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name) {
  return TfLiteRegistration{nullptr,
                            nullptr,
                            nullptr,
                            /*invoke*/ &UnresolvedOpInvoke,
                            nullptr,
                            BuiltinOperator_CUSTOM,
                            custom_op_name,
                            1};
}

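// A registration produced by CreateUnresolvedCustomOp is recognized by its
// CUSTOM builtin code paired with the placeholder invoke handler above.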
bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) {
  return registration.builtin_code == tflite::BuiltinOperator_CUSTOM &&
         registration.invoke == &UnresolvedOpInvoke;
}

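// Returns a readable op name: the builtin enum name, with the custom name
// appended for custom and delegate ops.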
std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
  auto op = registration.builtin_code;
  std::string result =
      EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
  if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) &&
      registration.custom_name) {
    result += " " + std::string(registration.custom_name);
  }
  return result;
}

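// Validation subgraphs are identified solely by their name starting with
// kValidationSubgraphNamePrefix.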
bool IsValidationSubgraph(const char* name) {
  // NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
  return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
}

}  // namespace tflite