1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/types.h"
16 #include "tensorflow/lite/string_util.h"
17
18 namespace toco {
19
20 namespace tflite {
21
22 namespace {
23
CopyStringToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)24 DataBuffer::FlatBufferOffset CopyStringToBuffer(
25 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
26 const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
27 ::tflite::DynamicBuffer dyn_buffer;
28 for (const std::string& str : src_data) {
29 dyn_buffer.AddString(str.c_str(), str.length());
30 }
31 char* tensor_buffer;
32 int bytes = dyn_buffer.WriteToBuffer(&tensor_buffer);
33 std::vector<uint8_t> dst_data(bytes);
34 memcpy(dst_data.data(), tensor_buffer, bytes);
35 free(tensor_buffer);
36 return builder->CreateVector(dst_data.data(), bytes);
37 }
38
// std::vector<bool> may be implemented as a packed bit-set, so its storage
// cannot simply be reinterpret_cast to bytes (as CopyBuffer does for other
// types). Instead, pass the vector<bool> itself to flatbuffers' CreateVector
// and let it handle the conversion.
42 // Background: https://isocpp.org/blog/2012/11/on-vectorbool
CopyBoolToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)43 DataBuffer::FlatBufferOffset CopyBoolToBuffer(
44 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
45 const auto& src_data = array.GetBuffer<ArrayDataType::kBool>().data;
46 return builder->CreateVector(src_data);
47 }
48
49 template <ArrayDataType T>
CopyBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)50 DataBuffer::FlatBufferOffset CopyBuffer(
51 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
52 using NativeT = ::toco::DataType<T>;
53 const auto& src_data = array.GetBuffer<T>().data;
54 const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
55 auto size = src_data.size() * sizeof(NativeT);
56 return builder->CreateVector(dst_data, size);
57 }
58
CopyStringFromBuffer(const::tflite::Buffer & buffer,Array * array)59 void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
60 auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
61 std::vector<std::string>* dst_data =
62 &array->GetMutableBuffer<ArrayDataType::kString>().data;
63 int32_t num_strings = ::tflite::GetStringCount(src_data);
64 for (int i = 0; i < num_strings; i++) {
65 ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
66 std::string this_str(str_ref.str, str_ref.len);
67 dst_data->push_back(this_str);
68 }
69 }
70
71 template <ArrayDataType T>
CopyBuffer(const::tflite::Buffer & buffer,Array * array)72 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
73 using NativeT = ::toco::DataType<T>;
74 auto* src_buffer = buffer.data();
75 const NativeT* src_data =
76 reinterpret_cast<const NativeT*>(src_buffer->data());
77 int num_items = src_buffer->size() / sizeof(NativeT);
78
79 std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
80 for (int i = 0; i < num_items; ++i) {
81 dst_data->push_back(*src_data);
82 ++src_data;
83 }
84 }
85 } // namespace
86
Serialize(ArrayDataType array_data_type)87 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
88 switch (array_data_type) {
89 case ArrayDataType::kFloat:
90 return ::tflite::TensorType_FLOAT32;
91 case ArrayDataType::kInt16:
92 return ::tflite::TensorType_INT16;
93 case ArrayDataType::kInt32:
94 return ::tflite::TensorType_INT32;
95 case ArrayDataType::kUint32:
96 return ::tflite::TensorType_UINT32;
97 case ArrayDataType::kInt64:
98 return ::tflite::TensorType_INT64;
99 case ArrayDataType::kUint8:
100 return ::tflite::TensorType_UINT8;
101 case ArrayDataType::kString:
102 return ::tflite::TensorType_STRING;
103 case ArrayDataType::kBool:
104 return ::tflite::TensorType_BOOL;
105 case ArrayDataType::kComplex64:
106 return ::tflite::TensorType_COMPLEX64;
107 default:
108 // FLOAT32 is filled for unknown data types.
109 // TODO(ycling): Implement type inference in TF Lite interpreter.
110 return ::tflite::TensorType_FLOAT32;
111 }
112 }
113
Deserialize(int tensor_type)114 ArrayDataType DataType::Deserialize(int tensor_type) {
115 switch (::tflite::TensorType(tensor_type)) {
116 case ::tflite::TensorType_FLOAT32:
117 return ArrayDataType::kFloat;
118 case ::tflite::TensorType_INT16:
119 return ArrayDataType::kInt16;
120 case ::tflite::TensorType_INT32:
121 return ArrayDataType::kInt32;
122 case ::tflite::TensorType_UINT32:
123 return ArrayDataType::kUint32;
124 case ::tflite::TensorType_INT64:
125 return ArrayDataType::kInt64;
126 case ::tflite::TensorType_STRING:
127 return ArrayDataType::kString;
128 case ::tflite::TensorType_UINT8:
129 return ArrayDataType::kUint8;
130 case ::tflite::TensorType_BOOL:
131 return ArrayDataType::kBool;
132 case ::tflite::TensorType_COMPLEX64:
133 return ArrayDataType::kComplex64;
134 default:
135 LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
136 }
137 }
138
Serialize(const Array & array,flatbuffers::FlatBufferBuilder * builder)139 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
140 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
141 if (!array.buffer) return 0; // an empty buffer, usually an output.
142
143 switch (array.data_type) {
144 case ArrayDataType::kFloat:
145 return CopyBuffer<ArrayDataType::kFloat>(array, builder);
146 case ArrayDataType::kInt16:
147 return CopyBuffer<ArrayDataType::kInt16>(array, builder);
148 case ArrayDataType::kInt32:
149 return CopyBuffer<ArrayDataType::kInt32>(array, builder);
150 case ArrayDataType::kUint32:
151 return CopyBuffer<ArrayDataType::kUint32>(array, builder);
152 case ArrayDataType::kInt64:
153 return CopyBuffer<ArrayDataType::kInt64>(array, builder);
154 case ArrayDataType::kString:
155 return CopyStringToBuffer(array, builder);
156 case ArrayDataType::kUint8:
157 return CopyBuffer<ArrayDataType::kUint8>(array, builder);
158 case ArrayDataType::kBool:
159 return CopyBoolToBuffer(array, builder);
160 case ArrayDataType::kComplex64:
161 return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
162 default:
163 LOG(FATAL) << "Unhandled array data type.";
164 }
165 }
166
Deserialize(const::tflite::Tensor & tensor,const::tflite::Buffer & buffer,Array * array)167 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
168 const ::tflite::Buffer& buffer, Array* array) {
169 if (tensor.buffer() == 0) return; // an empty buffer, usually an output.
170 if (buffer.data() == nullptr) return; // a non-defined buffer.
171
172 switch (tensor.type()) {
173 case ::tflite::TensorType_FLOAT32:
174 return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
175 case ::tflite::TensorType_INT16:
176 return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
177 case ::tflite::TensorType_INT32:
178 return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
179 case ::tflite::TensorType_UINT32:
180 return CopyBuffer<ArrayDataType::kUint32>(buffer, array);
181 case ::tflite::TensorType_INT64:
182 return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
183 case ::tflite::TensorType_STRING:
184 return CopyStringFromBuffer(buffer, array);
185 case ::tflite::TensorType_UINT8:
186 return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
187 case ::tflite::TensorType_BOOL:
188 return CopyBuffer<ArrayDataType::kBool>(buffer, array);
189 case ::tflite::TensorType_COMPLEX64:
190 return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
191 default:
192 LOG(FATAL) << "Unhandled tensor type.";
193 }
194 }
195
Serialize(PaddingType padding_type)196 ::tflite::Padding Padding::Serialize(PaddingType padding_type) {
197 switch (padding_type) {
198 case PaddingType::kSame:
199 return ::tflite::Padding_SAME;
200 case PaddingType::kValid:
201 return ::tflite::Padding_VALID;
202 default:
203 LOG(FATAL) << "Unhandled padding type.";
204 }
205 }
206
Deserialize(int padding)207 PaddingType Padding::Deserialize(int padding) {
208 switch (::tflite::Padding(padding)) {
209 case ::tflite::Padding_SAME:
210 return PaddingType::kSame;
211 case ::tflite::Padding_VALID:
212 return PaddingType::kValid;
213 default:
214 LOG(FATAL) << "Unhandled padding.";
215 }
216 }
217
Serialize(FusedActivationFunctionType faf_type)218 ::tflite::ActivationFunctionType ActivationFunction::Serialize(
219 FusedActivationFunctionType faf_type) {
220 switch (faf_type) {
221 case FusedActivationFunctionType::kNone:
222 return ::tflite::ActivationFunctionType_NONE;
223 case FusedActivationFunctionType::kRelu:
224 return ::tflite::ActivationFunctionType_RELU;
225 case FusedActivationFunctionType::kRelu6:
226 return ::tflite::ActivationFunctionType_RELU6;
227 case FusedActivationFunctionType::kRelu1:
228 return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
229 default:
230 LOG(FATAL) << "Unhandled fused activation function type.";
231 }
232 }
233
Deserialize(int activation_function)234 FusedActivationFunctionType ActivationFunction::Deserialize(
235 int activation_function) {
236 switch (::tflite::ActivationFunctionType(activation_function)) {
237 case ::tflite::ActivationFunctionType_NONE:
238 return FusedActivationFunctionType::kNone;
239 case ::tflite::ActivationFunctionType_RELU:
240 return FusedActivationFunctionType::kRelu;
241 case ::tflite::ActivationFunctionType_RELU6:
242 return FusedActivationFunctionType::kRelu6;
243 case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
244 return FusedActivationFunctionType::kRelu1;
245 default:
246 LOG(FATAL) << "Unhandled fused activation function type.";
247 }
248 }
249
250 } // namespace tflite
251
252 } // namespace toco
253