• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/util.h"
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #include <algorithm>
21 #include <complex>
22 #include <cstring>
23 #include <initializer_list>
24 #include <memory>
25 #include <string>
26 #include <vector>
27 
28 #include "tensorflow/lite/builtin_ops.h"
29 #include "tensorflow/lite/c/common.h"
30 #include "tensorflow/lite/schema/schema_generated.h"
31 
32 namespace tflite {
33 namespace {
34 
UnresolvedOpInvoke(TfLiteContext * context,TfLiteNode * node)35 TfLiteStatus UnresolvedOpInvoke(TfLiteContext* context, TfLiteNode* node) {
36   context->ReportError(context,
37                        "Encountered an unresolved custom op. Did you miss "
38                        "a custom op or delegate?");
39   return kTfLiteError;
40 }
41 
42 }  // namespace
43 
IsFlexOp(const char * custom_name)44 bool IsFlexOp(const char* custom_name) {
45   return custom_name && strncmp(custom_name, kFlexCustomCodePrefix,
46                                 strlen(kFlexCustomCodePrefix)) == 0;
47 }
48 
BuildTfLiteIntArray(const std::vector<int> & data)49 std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
50     const std::vector<int>& data) {
51   std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
52       TfLiteIntArrayCreate(data.size()));
53   std::copy(data.begin(), data.end(), result->data);
54   return result;
55 }
56 
ConvertVectorToTfLiteIntArray(const std::vector<int> & input)57 TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
58   return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()),
59                                       input.data());
60 }
61 
ConvertArrayToTfLiteIntArray(const int rank,const int * dims)62 TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims) {
63   TfLiteIntArray* output = TfLiteIntArrayCreate(rank);
64   for (size_t i = 0; i < rank; i++) {
65     output->data[i] = dims[i];
66   }
67   return output;
68 }
69 
EqualArrayAndTfLiteIntArray(const TfLiteIntArray * a,const int b_size,const int * b)70 bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
71                                  const int* b) {
72   if (!a) return false;
73   if (a->size != b_size) return false;
74   for (int i = 0; i < a->size; ++i) {
75     if (a->data[i] != b[i]) return false;
76   }
77   return true;
78 }
79 
// Folds `hashes` left to right into a single order-dependent hash value.
// Returns 0 for an empty list.
size_t CombineHashes(std::initializer_list<size_t> hashes) {
  // Mixing step of the hash combiner used by TensorFlow core. Changing this
  // constant would change every combined hash value.
  constexpr unsigned long long kMixConstant = 0x9e3779b97f4a7800ULL;
  size_t seed = 0;
  for (const size_t value : hashes) {
    seed ^= value + kMixConstant + (seed << 10) + (seed >> 4);
  }
  return seed;
}
89 
GetSizeOfType(TfLiteContext * context,const TfLiteType type,size_t * bytes)90 TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
91                            size_t* bytes) {
92   // TODO(levp): remove the default case so that new types produce compilation
93   // error.
94   switch (type) {
95     case kTfLiteFloat32:
96       *bytes = sizeof(float);
97       break;
98     case kTfLiteInt32:
99       *bytes = sizeof(int32_t);
100       break;
101     case kTfLiteUInt32:
102       *bytes = sizeof(uint32_t);
103       break;
104     case kTfLiteUInt8:
105       *bytes = sizeof(uint8_t);
106       break;
107     case kTfLiteInt64:
108       *bytes = sizeof(int64_t);
109       break;
110     case kTfLiteUInt64:
111       *bytes = sizeof(uint64_t);
112       break;
113     case kTfLiteBool:
114       *bytes = sizeof(bool);
115       break;
116     case kTfLiteComplex64:
117       *bytes = sizeof(std::complex<float>);
118       break;
119     case kTfLiteComplex128:
120       *bytes = sizeof(std::complex<double>);
121       break;
122     case kTfLiteInt16:
123       *bytes = sizeof(int16_t);
124       break;
125     case kTfLiteInt8:
126       *bytes = sizeof(int8_t);
127       break;
128     case kTfLiteFloat16:
129       *bytes = sizeof(TfLiteFloat16);
130       break;
131     case kTfLiteFloat64:
132       *bytes = sizeof(double);
133       break;
134     default:
135       if (context) {
136         context->ReportError(
137             context,
138             "Type %d is unsupported. Only float16, float32, float64, int8, "
139             "int16, int32, int64, uint8, uint64, bool, complex64 and "
140             "complex128 supported currently.",
141             type);
142       }
143       return kTfLiteError;
144   }
145   return kTfLiteOk;
146 }
147 
CreateUnresolvedCustomOp(const char * custom_op_name)148 TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name) {
149   return TfLiteRegistration{nullptr,
150                             nullptr,
151                             nullptr,
152                             /*invoke*/ &UnresolvedOpInvoke,
153                             nullptr,
154                             BuiltinOperator_CUSTOM,
155                             custom_op_name,
156                             1};
157 }
158 
IsUnresolvedCustomOp(const TfLiteRegistration & registration)159 bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) {
160   return registration.builtin_code == tflite::BuiltinOperator_CUSTOM &&
161          registration.invoke == &UnresolvedOpInvoke;
162 }
163 
GetOpNameByRegistration(const TfLiteRegistration & registration)164 std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
165   auto op = registration.builtin_code;
166   std::string result =
167       EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
168   if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) &&
169       registration.custom_name) {
170     result += " " + std::string(registration.custom_name);
171   }
172   return result;
173 }
174 
IsValidationSubgraph(const char * name)175 bool IsValidationSubgraph(const char* name) {
176   // NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
177   return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
178 }
179 }  // namespace tflite
180