• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/types.h"
16 #include "tensorflow/lite/string_util.h"
17 
18 namespace toco {
19 
20 namespace tflite {
21 
22 namespace {
23 
CopyStringToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)24 DataBuffer::FlatBufferOffset CopyStringToBuffer(
25     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
26   const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
27   ::tflite::DynamicBuffer dyn_buffer;
28   for (const std::string& str : src_data) {
29     dyn_buffer.AddString(str.c_str(), str.length());
30   }
31   char* tensor_buffer;
32   int bytes = dyn_buffer.WriteToBuffer(&tensor_buffer);
33   std::vector<uint8_t> dst_data(bytes);
34   memcpy(dst_data.data(), tensor_buffer, bytes);
35   free(tensor_buffer);
36   return builder->CreateVector(dst_data.data(), bytes);
37 }
38 
39 // vector<bool> may be implemented using a bit-set, so we can't just
40 // reinterpret_cast, accessing its data as vector<bool> and let flatbuffer
41 // CreateVector handle it.
42 // Background: https://isocpp.org/blog/2012/11/on-vectorbool
CopyBoolToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)43 DataBuffer::FlatBufferOffset CopyBoolToBuffer(
44     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
45   const auto& src_data = array.GetBuffer<ArrayDataType::kBool>().data;
46   return builder->CreateVector(src_data);
47 }
48 
49 template <ArrayDataType T>
CopyBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)50 DataBuffer::FlatBufferOffset CopyBuffer(
51     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
52   using NativeT = ::toco::DataType<T>;
53   const auto& src_data = array.GetBuffer<T>().data;
54   const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
55   auto size = src_data.size() * sizeof(NativeT);
56   return builder->CreateVector(dst_data, size);
57 }
58 
CopyStringFromBuffer(const::tflite::Buffer & buffer,Array * array)59 void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
60   auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
61   std::vector<std::string>* dst_data =
62       &array->GetMutableBuffer<ArrayDataType::kString>().data;
63   int32_t num_strings = ::tflite::GetStringCount(src_data);
64   for (int i = 0; i < num_strings; i++) {
65     ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
66     std::string this_str(str_ref.str, str_ref.len);
67     dst_data->push_back(this_str);
68   }
69 }
70 
71 template <ArrayDataType T>
CopyBuffer(const::tflite::Buffer & buffer,Array * array)72 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
73   using NativeT = ::toco::DataType<T>;
74   auto* src_buffer = buffer.data();
75   const NativeT* src_data =
76       reinterpret_cast<const NativeT*>(src_buffer->data());
77   int num_items = src_buffer->size() / sizeof(NativeT);
78 
79   std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
80   for (int i = 0; i < num_items; ++i) {
81     dst_data->push_back(*src_data);
82     ++src_data;
83   }
84 }
85 }  // namespace
86 
Serialize(ArrayDataType array_data_type)87 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
88   switch (array_data_type) {
89     case ArrayDataType::kFloat:
90       return ::tflite::TensorType_FLOAT32;
91     case ArrayDataType::kInt16:
92       return ::tflite::TensorType_INT16;
93     case ArrayDataType::kInt32:
94       return ::tflite::TensorType_INT32;
95     case ArrayDataType::kUint32:
96       return ::tflite::TensorType_UINT32;
97     case ArrayDataType::kInt64:
98       return ::tflite::TensorType_INT64;
99     case ArrayDataType::kUint8:
100       return ::tflite::TensorType_UINT8;
101     case ArrayDataType::kString:
102       return ::tflite::TensorType_STRING;
103     case ArrayDataType::kBool:
104       return ::tflite::TensorType_BOOL;
105     case ArrayDataType::kComplex64:
106       return ::tflite::TensorType_COMPLEX64;
107     default:
108       // FLOAT32 is filled for unknown data types.
109       // TODO(ycling): Implement type inference in TF Lite interpreter.
110       return ::tflite::TensorType_FLOAT32;
111   }
112 }
113 
Deserialize(int tensor_type)114 ArrayDataType DataType::Deserialize(int tensor_type) {
115   switch (::tflite::TensorType(tensor_type)) {
116     case ::tflite::TensorType_FLOAT32:
117       return ArrayDataType::kFloat;
118     case ::tflite::TensorType_INT16:
119       return ArrayDataType::kInt16;
120     case ::tflite::TensorType_INT32:
121       return ArrayDataType::kInt32;
122     case ::tflite::TensorType_UINT32:
123       return ArrayDataType::kUint32;
124     case ::tflite::TensorType_INT64:
125       return ArrayDataType::kInt64;
126     case ::tflite::TensorType_STRING:
127       return ArrayDataType::kString;
128     case ::tflite::TensorType_UINT8:
129       return ArrayDataType::kUint8;
130     case ::tflite::TensorType_BOOL:
131       return ArrayDataType::kBool;
132     case ::tflite::TensorType_COMPLEX64:
133       return ArrayDataType::kComplex64;
134     default:
135       LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
136   }
137 }
138 
Serialize(const Array & array,flatbuffers::FlatBufferBuilder * builder)139 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
140     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
141   if (!array.buffer) return 0;  // an empty buffer, usually an output.
142 
143   switch (array.data_type) {
144     case ArrayDataType::kFloat:
145       return CopyBuffer<ArrayDataType::kFloat>(array, builder);
146     case ArrayDataType::kInt16:
147       return CopyBuffer<ArrayDataType::kInt16>(array, builder);
148     case ArrayDataType::kInt32:
149       return CopyBuffer<ArrayDataType::kInt32>(array, builder);
150     case ArrayDataType::kUint32:
151       return CopyBuffer<ArrayDataType::kUint32>(array, builder);
152     case ArrayDataType::kInt64:
153       return CopyBuffer<ArrayDataType::kInt64>(array, builder);
154     case ArrayDataType::kString:
155       return CopyStringToBuffer(array, builder);
156     case ArrayDataType::kUint8:
157       return CopyBuffer<ArrayDataType::kUint8>(array, builder);
158     case ArrayDataType::kBool:
159       return CopyBoolToBuffer(array, builder);
160     case ArrayDataType::kComplex64:
161       return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
162     default:
163       LOG(FATAL) << "Unhandled array data type.";
164   }
165 }
166 
Deserialize(const::tflite::Tensor & tensor,const::tflite::Buffer & buffer,Array * array)167 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
168                              const ::tflite::Buffer& buffer, Array* array) {
169   if (tensor.buffer() == 0) return;      // an empty buffer, usually an output.
170   if (buffer.data() == nullptr) return;  // a non-defined buffer.
171 
172   switch (tensor.type()) {
173     case ::tflite::TensorType_FLOAT32:
174       return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
175     case ::tflite::TensorType_INT16:
176       return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
177     case ::tflite::TensorType_INT32:
178       return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
179     case ::tflite::TensorType_UINT32:
180       return CopyBuffer<ArrayDataType::kUint32>(buffer, array);
181     case ::tflite::TensorType_INT64:
182       return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
183     case ::tflite::TensorType_STRING:
184       return CopyStringFromBuffer(buffer, array);
185     case ::tflite::TensorType_UINT8:
186       return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
187     case ::tflite::TensorType_BOOL:
188       return CopyBuffer<ArrayDataType::kBool>(buffer, array);
189     case ::tflite::TensorType_COMPLEX64:
190       return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
191     default:
192       LOG(FATAL) << "Unhandled tensor type.";
193   }
194 }
195 
Serialize(PaddingType padding_type)196 ::tflite::Padding Padding::Serialize(PaddingType padding_type) {
197   switch (padding_type) {
198     case PaddingType::kSame:
199       return ::tflite::Padding_SAME;
200     case PaddingType::kValid:
201       return ::tflite::Padding_VALID;
202     default:
203       LOG(FATAL) << "Unhandled padding type.";
204   }
205 }
206 
Deserialize(int padding)207 PaddingType Padding::Deserialize(int padding) {
208   switch (::tflite::Padding(padding)) {
209     case ::tflite::Padding_SAME:
210       return PaddingType::kSame;
211     case ::tflite::Padding_VALID:
212       return PaddingType::kValid;
213     default:
214       LOG(FATAL) << "Unhandled padding.";
215   }
216 }
217 
Serialize(FusedActivationFunctionType faf_type)218 ::tflite::ActivationFunctionType ActivationFunction::Serialize(
219     FusedActivationFunctionType faf_type) {
220   switch (faf_type) {
221     case FusedActivationFunctionType::kNone:
222       return ::tflite::ActivationFunctionType_NONE;
223     case FusedActivationFunctionType::kRelu:
224       return ::tflite::ActivationFunctionType_RELU;
225     case FusedActivationFunctionType::kRelu6:
226       return ::tflite::ActivationFunctionType_RELU6;
227     case FusedActivationFunctionType::kRelu1:
228       return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
229     default:
230       LOG(FATAL) << "Unhandled fused activation function type.";
231   }
232 }
233 
Deserialize(int activation_function)234 FusedActivationFunctionType ActivationFunction::Deserialize(
235     int activation_function) {
236   switch (::tflite::ActivationFunctionType(activation_function)) {
237     case ::tflite::ActivationFunctionType_NONE:
238       return FusedActivationFunctionType::kNone;
239     case ::tflite::ActivationFunctionType_RELU:
240       return FusedActivationFunctionType::kRelu;
241     case ::tflite::ActivationFunctionType_RELU6:
242       return FusedActivationFunctionType::kRelu6;
243     case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
244       return FusedActivationFunctionType::kRelu1;
245     default:
246       LOG(FATAL) << "Unhandled fused activation function type.";
247   }
248 }
249 
250 }  // namespace tflite
251 
252 }  // namespace toco
253