• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/tflite/types.h"

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/string_util.h"
20 
21 namespace toco {
22 
23 namespace tflite {
24 
25 namespace {
26 
CopyStringToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)27 DataBuffer::FlatBufferOffset CopyStringToBuffer(
28     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
29   const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
30   ::tflite::DynamicBuffer dyn_buffer;
31   for (const std::string& str : src_data) {
32     dyn_buffer.AddString(str.c_str(), str.length());
33   }
34   char* tensor_buffer;
35   int bytes = dyn_buffer.WriteToBuffer(&tensor_buffer);
36   std::vector<uint8_t> dst_data(bytes);
37   memcpy(dst_data.data(), tensor_buffer, bytes);
38   free(tensor_buffer);
39   return builder->CreateVector(dst_data.data(), bytes);
40 }
41 
42 // vector<bool> may be implemented using a bit-set, so we can't just
43 // reinterpret_cast, accessing its data as vector<bool> and let flatbuffer
44 // CreateVector handle it.
45 // Background: https://isocpp.org/blog/2012/11/on-vectorbool
CopyBoolToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)46 DataBuffer::FlatBufferOffset CopyBoolToBuffer(
47     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
48   const auto& src_data = array.GetBuffer<ArrayDataType::kBool>().data;
49   return builder->CreateVector(src_data);
50 }
51 
52 template <ArrayDataType T>
CopyBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)53 DataBuffer::FlatBufferOffset CopyBuffer(
54     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
55   using NativeT = ::toco::DataType<T>;
56   const auto& src_data = array.GetBuffer<T>().data;
57   const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
58   auto size = src_data.size() * sizeof(NativeT);
59   return builder->CreateVector(dst_data, size);
60 }
61 
CopyStringFromBuffer(const::tflite::Buffer & buffer,Array * array)62 void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
63   auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
64   std::vector<std::string>* dst_data =
65       &array->GetMutableBuffer<ArrayDataType::kString>().data;
66   int32_t num_strings = ::tflite::GetStringCount(src_data);
67   for (int i = 0; i < num_strings; i++) {
68     ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
69     std::string this_str(str_ref.str, str_ref.len);
70     dst_data->push_back(this_str);
71   }
72 }
73 
74 template <ArrayDataType T>
CopyBuffer(const::tflite::Buffer & buffer,Array * array)75 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
76   using NativeT = ::toco::DataType<T>;
77   auto* src_buffer = buffer.data();
78   const NativeT* src_data =
79       reinterpret_cast<const NativeT*>(src_buffer->data());
80   int num_items = src_buffer->size() / sizeof(NativeT);
81 
82   std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
83   for (int i = 0; i < num_items; ++i) {
84     dst_data->push_back(*src_data);
85     ++src_data;
86   }
87 }
88 }  // namespace
89 
Serialize(ArrayDataType array_data_type)90 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
91   switch (array_data_type) {
92     case ArrayDataType::kFloat:
93       return ::tflite::TensorType_FLOAT32;
94     case ArrayDataType::kInt16:
95       return ::tflite::TensorType_INT16;
96     case ArrayDataType::kInt32:
97       return ::tflite::TensorType_INT32;
98     case ArrayDataType::kUint32:
99       return ::tflite::TensorType_UINT32;
100     case ArrayDataType::kInt64:
101       return ::tflite::TensorType_INT64;
102     case ArrayDataType::kUint8:
103       return ::tflite::TensorType_UINT8;
104     case ArrayDataType::kUint16:
105       return ::tflite::TensorType_UINT16;
106     case ArrayDataType::kString:
107       return ::tflite::TensorType_STRING;
108     case ArrayDataType::kBool:
109       return ::tflite::TensorType_BOOL;
110     case ArrayDataType::kComplex64:
111       return ::tflite::TensorType_COMPLEX64;
112     default:
113       // FLOAT32 is filled for unknown data types.
114       // TODO(ycling): Implement type inference in TF Lite interpreter.
115       return ::tflite::TensorType_FLOAT32;
116   }
117 }
118 
Deserialize(int tensor_type)119 ArrayDataType DataType::Deserialize(int tensor_type) {
120   switch (::tflite::TensorType(tensor_type)) {
121     case ::tflite::TensorType_FLOAT32:
122       return ArrayDataType::kFloat;
123     case ::tflite::TensorType_INT16:
124       return ArrayDataType::kInt16;
125     case ::tflite::TensorType_INT32:
126       return ArrayDataType::kInt32;
127     case ::tflite::TensorType_UINT32:
128       return ArrayDataType::kUint32;
129     case ::tflite::TensorType_INT64:
130       return ArrayDataType::kInt64;
131     case ::tflite::TensorType_STRING:
132       return ArrayDataType::kString;
133     case ::tflite::TensorType_UINT8:
134       return ArrayDataType::kUint8;
135     case ::tflite::TensorType_UINT16:
136       return ArrayDataType::kUint16;
137     case ::tflite::TensorType_BOOL:
138       return ArrayDataType::kBool;
139     case ::tflite::TensorType_COMPLEX64:
140       return ArrayDataType::kComplex64;
141     default:
142       LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
143   }
144 }
145 
Serialize(const Array & array,flatbuffers::FlatBufferBuilder * builder)146 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
147     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
148   if (!array.buffer) return 0;  // an empty buffer, usually an output.
149 
150   switch (array.data_type) {
151     case ArrayDataType::kFloat:
152       return CopyBuffer<ArrayDataType::kFloat>(array, builder);
153     case ArrayDataType::kInt16:
154       return CopyBuffer<ArrayDataType::kInt16>(array, builder);
155     case ArrayDataType::kInt32:
156       return CopyBuffer<ArrayDataType::kInt32>(array, builder);
157     case ArrayDataType::kUint32:
158       return CopyBuffer<ArrayDataType::kUint32>(array, builder);
159     case ArrayDataType::kInt64:
160       return CopyBuffer<ArrayDataType::kInt64>(array, builder);
161     case ArrayDataType::kString:
162       return CopyStringToBuffer(array, builder);
163     case ArrayDataType::kUint8:
164       return CopyBuffer<ArrayDataType::kUint8>(array, builder);
165     case ArrayDataType::kBool:
166       return CopyBoolToBuffer(array, builder);
167     case ArrayDataType::kComplex64:
168       return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
169     default:
170       LOG(FATAL) << "Unhandled array data type.";
171   }
172 }
173 
Deserialize(const::tflite::Tensor & tensor,const::tflite::Buffer & buffer,Array * array)174 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
175                              const ::tflite::Buffer& buffer, Array* array) {
176   if (tensor.buffer() == 0) return;      // an empty buffer, usually an output.
177   if (buffer.data() == nullptr) return;  // a non-defined buffer.
178 
179   switch (tensor.type()) {
180     case ::tflite::TensorType_FLOAT32:
181       return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
182     case ::tflite::TensorType_INT16:
183       return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
184     case ::tflite::TensorType_INT32:
185       return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
186     case ::tflite::TensorType_UINT32:
187       return CopyBuffer<ArrayDataType::kUint32>(buffer, array);
188     case ::tflite::TensorType_INT64:
189       return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
190     case ::tflite::TensorType_STRING:
191       return CopyStringFromBuffer(buffer, array);
192     case ::tflite::TensorType_UINT8:
193       return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
194     case ::tflite::TensorType_BOOL:
195       return CopyBuffer<ArrayDataType::kBool>(buffer, array);
196     case ::tflite::TensorType_COMPLEX64:
197       return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
198     default:
199       LOG(FATAL) << "Unhandled tensor type.";
200   }
201 }
202 
Serialize(PaddingType padding_type)203 ::tflite::Padding Padding::Serialize(PaddingType padding_type) {
204   switch (padding_type) {
205     case PaddingType::kSame:
206       return ::tflite::Padding_SAME;
207     case PaddingType::kValid:
208       return ::tflite::Padding_VALID;
209     default:
210       LOG(FATAL) << "Unhandled padding type.";
211   }
212 }
213 
Deserialize(int padding)214 PaddingType Padding::Deserialize(int padding) {
215   switch (::tflite::Padding(padding)) {
216     case ::tflite::Padding_SAME:
217       return PaddingType::kSame;
218     case ::tflite::Padding_VALID:
219       return PaddingType::kValid;
220     default:
221       LOG(FATAL) << "Unhandled padding.";
222   }
223 }
224 
Serialize(FusedActivationFunctionType faf_type)225 ::tflite::ActivationFunctionType ActivationFunction::Serialize(
226     FusedActivationFunctionType faf_type) {
227   switch (faf_type) {
228     case FusedActivationFunctionType::kNone:
229       return ::tflite::ActivationFunctionType_NONE;
230     case FusedActivationFunctionType::kRelu:
231       return ::tflite::ActivationFunctionType_RELU;
232     case FusedActivationFunctionType::kRelu6:
233       return ::tflite::ActivationFunctionType_RELU6;
234     case FusedActivationFunctionType::kRelu1:
235       return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
236     default:
237       LOG(FATAL) << "Unhandled fused activation function type.";
238   }
239 }
240 
Deserialize(int activation_function)241 FusedActivationFunctionType ActivationFunction::Deserialize(
242     int activation_function) {
243   switch (::tflite::ActivationFunctionType(activation_function)) {
244     case ::tflite::ActivationFunctionType_NONE:
245       return FusedActivationFunctionType::kNone;
246     case ::tflite::ActivationFunctionType_RELU:
247       return FusedActivationFunctionType::kRelu;
248     case ::tflite::ActivationFunctionType_RELU6:
249       return FusedActivationFunctionType::kRelu6;
250     case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
251       return FusedActivationFunctionType::kRelu1;
252     default:
253       LOG(FATAL) << "Unhandled fused activation function type.";
254   }
255 }
256 
257 }  // namespace tflite
258 
259 }  // namespace toco
260