1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/tflite/types.h"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/string_util.h"
20
21 namespace toco {
22
23 namespace tflite {
24
25 namespace {
26
CopyStringToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)27 DataBuffer::FlatBufferOffset CopyStringToBuffer(
28 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
29 const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
30 ::tflite::DynamicBuffer dyn_buffer;
31 for (const std::string& str : src_data) {
32 dyn_buffer.AddString(str.c_str(), str.length());
33 }
34 char* tensor_buffer;
35 int bytes = dyn_buffer.WriteToBuffer(&tensor_buffer);
36 std::vector<uint8_t> dst_data(bytes);
37 memcpy(dst_data.data(), tensor_buffer, bytes);
38 free(tensor_buffer);
39 return builder->CreateVector(dst_data.data(), bytes);
40 }
41
42 // vector<bool> may be implemented using a bit-set, so we can't just
43 // reinterpret_cast, accessing its data as vector<bool> and let flatbuffer
44 // CreateVector handle it.
45 // Background: https://isocpp.org/blog/2012/11/on-vectorbool
CopyBoolToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)46 DataBuffer::FlatBufferOffset CopyBoolToBuffer(
47 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
48 const auto& src_data = array.GetBuffer<ArrayDataType::kBool>().data;
49 return builder->CreateVector(src_data);
50 }
51
52 template <ArrayDataType T>
CopyBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)53 DataBuffer::FlatBufferOffset CopyBuffer(
54 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
55 using NativeT = ::toco::DataType<T>;
56 const auto& src_data = array.GetBuffer<T>().data;
57 const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
58 auto size = src_data.size() * sizeof(NativeT);
59 return builder->CreateVector(dst_data, size);
60 }
61
CopyStringFromBuffer(const::tflite::Buffer & buffer,Array * array)62 void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
63 auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
64 std::vector<std::string>* dst_data =
65 &array->GetMutableBuffer<ArrayDataType::kString>().data;
66 int32_t num_strings = ::tflite::GetStringCount(src_data);
67 for (int i = 0; i < num_strings; i++) {
68 ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
69 std::string this_str(str_ref.str, str_ref.len);
70 dst_data->push_back(this_str);
71 }
72 }
73
74 template <ArrayDataType T>
CopyBuffer(const::tflite::Buffer & buffer,Array * array)75 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
76 using NativeT = ::toco::DataType<T>;
77 auto* src_buffer = buffer.data();
78 const NativeT* src_data =
79 reinterpret_cast<const NativeT*>(src_buffer->data());
80 int num_items = src_buffer->size() / sizeof(NativeT);
81
82 std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
83 for (int i = 0; i < num_items; ++i) {
84 dst_data->push_back(*src_data);
85 ++src_data;
86 }
87 }
88 } // namespace
89
Serialize(ArrayDataType array_data_type)90 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
91 switch (array_data_type) {
92 case ArrayDataType::kFloat:
93 return ::tflite::TensorType_FLOAT32;
94 case ArrayDataType::kInt16:
95 return ::tflite::TensorType_INT16;
96 case ArrayDataType::kInt32:
97 return ::tflite::TensorType_INT32;
98 case ArrayDataType::kUint32:
99 return ::tflite::TensorType_UINT32;
100 case ArrayDataType::kInt64:
101 return ::tflite::TensorType_INT64;
102 case ArrayDataType::kUint8:
103 return ::tflite::TensorType_UINT8;
104 case ArrayDataType::kUint16:
105 return ::tflite::TensorType_UINT16;
106 case ArrayDataType::kString:
107 return ::tflite::TensorType_STRING;
108 case ArrayDataType::kBool:
109 return ::tflite::TensorType_BOOL;
110 case ArrayDataType::kComplex64:
111 return ::tflite::TensorType_COMPLEX64;
112 default:
113 // FLOAT32 is filled for unknown data types.
114 // TODO(ycling): Implement type inference in TF Lite interpreter.
115 return ::tflite::TensorType_FLOAT32;
116 }
117 }
118
Deserialize(int tensor_type)119 ArrayDataType DataType::Deserialize(int tensor_type) {
120 switch (::tflite::TensorType(tensor_type)) {
121 case ::tflite::TensorType_FLOAT32:
122 return ArrayDataType::kFloat;
123 case ::tflite::TensorType_INT16:
124 return ArrayDataType::kInt16;
125 case ::tflite::TensorType_INT32:
126 return ArrayDataType::kInt32;
127 case ::tflite::TensorType_UINT32:
128 return ArrayDataType::kUint32;
129 case ::tflite::TensorType_INT64:
130 return ArrayDataType::kInt64;
131 case ::tflite::TensorType_STRING:
132 return ArrayDataType::kString;
133 case ::tflite::TensorType_UINT8:
134 return ArrayDataType::kUint8;
135 case ::tflite::TensorType_UINT16:
136 return ArrayDataType::kUint16;
137 case ::tflite::TensorType_BOOL:
138 return ArrayDataType::kBool;
139 case ::tflite::TensorType_COMPLEX64:
140 return ArrayDataType::kComplex64;
141 default:
142 LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
143 }
144 }
145
Serialize(const Array & array,flatbuffers::FlatBufferBuilder * builder)146 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
147 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
148 if (!array.buffer) return 0; // an empty buffer, usually an output.
149
150 switch (array.data_type) {
151 case ArrayDataType::kFloat:
152 return CopyBuffer<ArrayDataType::kFloat>(array, builder);
153 case ArrayDataType::kInt16:
154 return CopyBuffer<ArrayDataType::kInt16>(array, builder);
155 case ArrayDataType::kInt32:
156 return CopyBuffer<ArrayDataType::kInt32>(array, builder);
157 case ArrayDataType::kUint32:
158 return CopyBuffer<ArrayDataType::kUint32>(array, builder);
159 case ArrayDataType::kInt64:
160 return CopyBuffer<ArrayDataType::kInt64>(array, builder);
161 case ArrayDataType::kString:
162 return CopyStringToBuffer(array, builder);
163 case ArrayDataType::kUint8:
164 return CopyBuffer<ArrayDataType::kUint8>(array, builder);
165 case ArrayDataType::kBool:
166 return CopyBoolToBuffer(array, builder);
167 case ArrayDataType::kComplex64:
168 return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
169 default:
170 LOG(FATAL) << "Unhandled array data type.";
171 }
172 }
173
Deserialize(const::tflite::Tensor & tensor,const::tflite::Buffer & buffer,Array * array)174 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
175 const ::tflite::Buffer& buffer, Array* array) {
176 if (tensor.buffer() == 0) return; // an empty buffer, usually an output.
177 if (buffer.data() == nullptr) return; // a non-defined buffer.
178
179 switch (tensor.type()) {
180 case ::tflite::TensorType_FLOAT32:
181 return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
182 case ::tflite::TensorType_INT16:
183 return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
184 case ::tflite::TensorType_INT32:
185 return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
186 case ::tflite::TensorType_UINT32:
187 return CopyBuffer<ArrayDataType::kUint32>(buffer, array);
188 case ::tflite::TensorType_INT64:
189 return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
190 case ::tflite::TensorType_STRING:
191 return CopyStringFromBuffer(buffer, array);
192 case ::tflite::TensorType_UINT8:
193 return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
194 case ::tflite::TensorType_BOOL:
195 return CopyBuffer<ArrayDataType::kBool>(buffer, array);
196 case ::tflite::TensorType_COMPLEX64:
197 return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
198 default:
199 LOG(FATAL) << "Unhandled tensor type.";
200 }
201 }
202
Serialize(PaddingType padding_type)203 ::tflite::Padding Padding::Serialize(PaddingType padding_type) {
204 switch (padding_type) {
205 case PaddingType::kSame:
206 return ::tflite::Padding_SAME;
207 case PaddingType::kValid:
208 return ::tflite::Padding_VALID;
209 default:
210 LOG(FATAL) << "Unhandled padding type.";
211 }
212 }
213
Deserialize(int padding)214 PaddingType Padding::Deserialize(int padding) {
215 switch (::tflite::Padding(padding)) {
216 case ::tflite::Padding_SAME:
217 return PaddingType::kSame;
218 case ::tflite::Padding_VALID:
219 return PaddingType::kValid;
220 default:
221 LOG(FATAL) << "Unhandled padding.";
222 }
223 }
224
Serialize(FusedActivationFunctionType faf_type)225 ::tflite::ActivationFunctionType ActivationFunction::Serialize(
226 FusedActivationFunctionType faf_type) {
227 switch (faf_type) {
228 case FusedActivationFunctionType::kNone:
229 return ::tflite::ActivationFunctionType_NONE;
230 case FusedActivationFunctionType::kRelu:
231 return ::tflite::ActivationFunctionType_RELU;
232 case FusedActivationFunctionType::kRelu6:
233 return ::tflite::ActivationFunctionType_RELU6;
234 case FusedActivationFunctionType::kRelu1:
235 return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
236 default:
237 LOG(FATAL) << "Unhandled fused activation function type.";
238 }
239 }
240
Deserialize(int activation_function)241 FusedActivationFunctionType ActivationFunction::Deserialize(
242 int activation_function) {
243 switch (::tflite::ActivationFunctionType(activation_function)) {
244 case ::tflite::ActivationFunctionType_NONE:
245 return FusedActivationFunctionType::kNone;
246 case ::tflite::ActivationFunctionType_RELU:
247 return FusedActivationFunctionType::kRelu;
248 case ::tflite::ActivationFunctionType_RELU6:
249 return FusedActivationFunctionType::kRelu6;
250 case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
251 return FusedActivationFunctionType::kRelu1;
252 default:
253 LOG(FATAL) << "Unhandled fused activation function type.";
254 }
255 }
256
257 } // namespace tflite
258
259 } // namespace toco
260