1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/export.h"
16
17 #include <algorithm>
18 #include <initializer_list>
19 #include <memory>
20 #include <string>
21
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/lite/schema/schema_generated.h"
28 #include "tensorflow/lite/schema/schema_utils.h"
29 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
30 #include "tensorflow/lite/toco/tflite/operator.h"
31 #include "tensorflow/lite/toco/tflite/types.h"
32
33 namespace toco {
34 namespace tflite {
35 namespace {
36
37 using ::testing::ElementsAre;
38 using ::testing::HasSubstr;
39
40 class ExportTest : public ::testing::Test {
41 protected:
ResetOperators()42 void ResetOperators() { input_model_.operators.clear(); }
AddTensorsByName(std::initializer_list<std::string> names)43 void AddTensorsByName(std::initializer_list<std::string> names) {
44 for (const std::string& name : names) {
45 input_model_.GetOrCreateArray(name);
46 }
47 }
AddOperatorsByName(std::initializer_list<std::string> names)48 void AddOperatorsByName(std::initializer_list<std::string> names) {
49 for (const std::string& name : names) {
50 if (name == "Conv") {
51 auto* op = new ConvOperator;
52 op->padding.type = PaddingType::kSame;
53 op->inputs = {"input", "filter"};
54 op->outputs = {"output"};
55 Array& input_array = input_model_.GetOrCreateArray(op->inputs[0]);
56 Array& filter_array = input_model_.GetOrCreateArray(op->inputs[1]);
57 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
58 input_array.data_type = ArrayDataType::kFloat;
59 filter_array.data_type = ArrayDataType::kFloat;
60 output_array.data_type = ArrayDataType::kFloat;
61 input_model_.operators.emplace_back(op);
62 } else if (name == "Add") {
63 auto* op = new AddOperator;
64 op->inputs = {"input1", "input2"};
65 op->outputs = {"output"};
66 Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
67 Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
68 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
69 input1_array.data_type = ArrayDataType::kFloat;
70 input2_array.data_type = ArrayDataType::kFloat;
71 output_array.data_type = ArrayDataType::kFloat;
72 input_model_.operators.emplace_back(op);
73 } else if (name == "Sub") {
74 auto* op = new SubOperator;
75 op->inputs = {"input1", "input2"};
76 op->outputs = {"output"};
77 Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
78 Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
79 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
80 input1_array.data_type = ArrayDataType::kFloat;
81 input2_array.data_type = ArrayDataType::kFloat;
82 output_array.data_type = ArrayDataType::kFloat;
83 input1_array.copy_shape({1, 2, 2, 2});
84 input2_array.copy_shape({1, 2, 2, 2});
85 output_array.copy_shape({1, 2, 2, 2});
86 input_model_.operators.emplace_back(op);
87 } else if (name == "Assert") {
88 auto* op = new TensorFlowAssertOperator;
89
90 // Even though assert is known to TOCO, it doesn't have a tflite
91 // serializer, so it has to be exported as a custom op. If we attach a
92 // NodeDef to it, however, it will be exported as a flex op instead.
93 ::tensorflow::NodeDef node_def;
94 node_def.set_name("Assert");
95 node_def.set_op("Assert");
96 node_def.SerializeToString(&op->tensorflow_node_def);
97
98 input_model_.operators.emplace_back(op);
99 } else {
100 auto* op = new TensorFlowUnsupportedOperator;
101 op->tensorflow_op = name;
102 input_model_.operators.emplace_back(op);
103 }
104 }
105 }
106
BuildQuantizableTestModel()107 void BuildQuantizableTestModel() {
108 input_model_.GetOrCreateArray("inputs");
109 Array& weight_array = input_model_.GetOrCreateArray("weights");
110
111 // Make the buffer large enough for QuantizeWeights transformation to take
112 // effect.
113 int buf_size = 1296;
114 auto weight_buf = std::make_unique<float[]>(buf_size);
115 for (int i = 0; i < buf_size; i++) {
116 // Fill the array with some garbage values.
117 weight_buf[i] = static_cast<float>(i % 128);
118 }
119
120 weight_array.data_type = ArrayDataType::kFloat;
121
122 // Initialize shape for the input array.
123 Shape* weight_array_shape = weight_array.mutable_shape();
124 std::vector<int>* weight_array_shape_dim =
125 weight_array_shape->mutable_dims();
126 weight_array_shape_dim->resize(4, 6);
127 auto& weight_array_buffer =
128 weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
129 weight_array_buffer.data.resize(buf_size);
130 float* buf_ptr =
131 weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
132 std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);
133
134 {
135 auto* op = new ConvOperator;
136 op->padding.type = PaddingType::kSame;
137 op->inputs = {"inputs", "weights"};
138 op->outputs = {"output"};
139 Array& input_array = input_model_.GetArray(op->inputs[0]);
140 Array& filter_array = input_model_.GetArray(op->inputs[1]);
141 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
142 input_array.data_type = ArrayDataType::kFloat;
143 filter_array.data_type = ArrayDataType::kFloat;
144 output_array.data_type = ArrayDataType::kFloat;
145 input_model_.operators.emplace_back(op);
146 }
147 {
148 auto* op = new AddOperator;
149 op->inputs = {"input1", "input2"};
150 op->outputs = {"output"};
151 Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
152 Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
153 Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
154 input1_array.data_type = ArrayDataType::kFloat;
155 input2_array.data_type = ArrayDataType::kFloat;
156 output_array.data_type = ArrayDataType::kFloat;
157 input_model_.operators.emplace_back(op);
158 }
159 }
160
ExportAndReturnStatus(const ExportParams & params)161 tensorflow::Status ExportAndReturnStatus(const ExportParams& params) {
162 std::string result;
163 return Export(input_model_, &result, params);
164 }
165
ExportAndSummarizeOperators(const ExportParams & params)166 std::vector<std::string> ExportAndSummarizeOperators(
167 const ExportParams& params) {
168 std::vector<std::string> names;
169
170 std::string result;
171 auto status = Export(input_model_, &result, params);
172 if (!status.ok()) {
173 LOG(INFO) << status.error_message();
174 return names;
175 }
176
177 auto* model = ::tflite::GetModel(result.data());
178
179 for (const ::tflite::OperatorCode* opcode : *model->operator_codes()) {
180 auto builtin_code = GetBuiltinCode(opcode);
181 if (builtin_code != ::tflite::BuiltinOperator_CUSTOM) {
182 names.push_back(std::string("builtin:") +
183 ::tflite::EnumNameBuiltinOperator(builtin_code));
184 } else {
185 names.push_back(std::string("custom:") +
186 opcode->custom_code()->c_str());
187 }
188 }
189
190 return names;
191 }
192
ExportAndGetOperatorIndices(const ExportParams & params)193 std::vector<uint32_t> ExportAndGetOperatorIndices(
194 const ExportParams& params) {
195 std::vector<uint32_t> indices;
196
197 std::string result;
198 if (!Export(input_model_, &result, params).ok()) return indices;
199 auto* model = ::tflite::GetModel(result.data());
200
201 auto operators = (*model->subgraphs())[0]->operators();
202 for (const auto* op : *operators) {
203 indices.push_back(op->opcode_index());
204 }
205 return indices;
206 }
207
208 Model input_model_;
209 };
210
// LoadTensorsMap should map "tensor_one" to 0 and "tensor_two" to 1.
TEST_F(ExportTest, LoadTensorsMap) {
  AddTensorsByName({"tensor_one", "tensor_two"});

  details::TensorsMap tensor_indices;
  details::LoadTensorsMap(input_model_, &tensor_indices);

  const auto first_index = tensor_indices["tensor_one"];
  const auto second_index = tensor_indices["tensor_two"];
  EXPECT_EQ(0, first_index);
  EXPECT_EQ(1, second_index);
}
219
// Checks the indices LoadOperatorsMap assigns to a mix of builtin and custom
// operators: ADD -> 0, CONV_2D -> 1, custom "MyCrazyOp" -> 2, SUB -> 3.
TEST_F(ExportTest, LoadOperatorsMap) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  const auto ops_by_type = BuildOperatorByTypeMap();
  details::OperatorsMap operator_indices;
  details::LoadOperatorsMap(input_model_, &operator_indices, ops_by_type,
                            false);

  const details::OperatorKey add_key(::tflite::BuiltinOperator_ADD, "", 1);
  const details::OperatorKey conv_key(::tflite::BuiltinOperator_CONV_2D, "", 1);
  const details::OperatorKey custom_key(::tflite::BuiltinOperator_CUSTOM,
                                        "MyCrazyOp", 1);
  const details::OperatorKey sub_key(::tflite::BuiltinOperator_SUB, "", 1);

  EXPECT_EQ(0, operator_indices[add_key]);
  EXPECT_EQ(1, operator_indices[conv_key]);
  EXPECT_EQ(2, operator_indices[custom_key]);
  EXPECT_EQ(3, operator_indices[sub_key]);
}
235
// Exporting with allow_dynamic_tensors disabled should be rejected as
// UNIMPLEMENTED, with a message naming the unsupported flag.
TEST_F(ExportTest, UnsupportedFunctionality) {
  AddOperatorsByName({"Conv"});

  ExportParams export_params;
  export_params.allow_dynamic_tensors = false;

  const auto export_status = ExportAndReturnStatus(export_params);
  EXPECT_EQ(export_status.code(), ::tensorflow::error::UNIMPLEMENTED);
  EXPECT_THAT(export_status.error_message(),
              HasSubstr("Unsupported flag: allow_dynamic_tensors."));
}
246
// Exports a model mixing builtin and custom ops and checks both the emitted
// operator codes and the opcode index recorded for each operator.
TEST_F(ExportTest, Export) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  ExportParams export_params;
  export_params.allow_custom_ops = true;
  export_params.enable_select_tf_ops = false;
  export_params.quantize_weights = QuantizedBufferType::NONE;

  const std::vector<std::string> operator_summary =
      ExportAndSummarizeOperators(export_params);
  EXPECT_THAT(operator_summary,
              ElementsAre("builtin:ADD", "builtin:CONV_2D", "custom:MyCrazyOp",
                          "builtin:SUB"));

  const std::vector<uint32_t> opcode_indices =
      ExportAndGetOperatorIndices(export_params);
  EXPECT_THAT(opcode_indices, ElementsAre(1, 0, 2, 3));
}
260
// A model using only Conv, Add and Sub should be exported with a single
// "min_runtime_version" metadata entry whose buffer reads "1.6.0".
TEST_F(ExportTest, ExportMinRuntime) {
  AddOperatorsByName({"Conv", "Add", "Sub"});

  ExportParams params;
  params.allow_custom_ops = true;
  params.enable_select_tf_ops = false;
  params.quantize_weights = QuantizedBufferType::NONE;

  std::string output;
  const auto status = Export(input_model_, &output, params);
  // Parsing `output` is only meaningful when the export succeeded.
  ASSERT_TRUE(status.ok()) << status.error_message();

  const auto* model = ::tflite::GetModel(output.data());
  // Fatal checks: the code below dereferences entry 0 and its buffer.
  ASSERT_TRUE(model->metadata() != nullptr);
  ASSERT_EQ(model->metadata()->size(), 1);
  EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");

  const auto buf = model->metadata()->Get(0)->buffer();
  const auto* buffer = (*model->buffers())[buf];
  const auto* array = buffer->data();
  ASSERT_TRUE(array != nullptr);
  EXPECT_EQ(reinterpret_cast<const char*>(array->data()), std::string("1.6.0"));
}
279
// A model made of "Switch", "MyCustomOp" and "Assert" should still carry a
// "min_runtime_version" metadata entry, but its buffer reads as "".
TEST_F(ExportTest, ExportEmptyMinRuntime) {
  AddOperatorsByName({"Switch", "MyCustomOp", "Assert"});

  ExportParams params;
  params.allow_custom_ops = true;

  std::string output;
  const auto status = Export(input_model_, &output, params);
  // Parsing `output` is only meaningful when the export succeeded.
  ASSERT_TRUE(status.ok()) << status.error_message();

  const auto* model = ::tflite::GetModel(output.data());
  // Fatal checks: the code below dereferences entry 0 and its buffer.
  ASSERT_TRUE(model->metadata() != nullptr);
  ASSERT_EQ(model->metadata()->size(), 1);
  EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");

  const auto buf = model->metadata()->Get(0)->buffer();
  const auto* buffer = (*model->buffers())[buf];
  const auto* array = buffer->data();
  ASSERT_TRUE(array != nullptr);
  EXPECT_EQ(reinterpret_cast<const char*>(array->data()), std::string(""));
}
296
// The model contains control flow ops which are not convertible, so the
// export is expected to fail with a message listing them.
TEST_F(ExportTest, UnsupportedControlFlowErrors) {
  static constexpr char kExpectedError[] =
      "We are continually in the process of adding support to TensorFlow "
      "Lite for more ops. It would be helpful if you could inform us of "
      "how this conversion went by opening a github issue at "
      "https://github.com/tensorflow/tensorflow/issues/"
      "new?template=40-tflite-op-request.md\n and pasting the "
      "following:\n\nTensorFlow Lite currently doesn't support control "
      "flow ops: Merge, Switch. We are working on supporting control "
      "flow ops, please see github issue at "
      "https://github.com/tensorflow/tensorflow/issues/28485.";

  AddOperatorsByName({"Conv", "Add", "Switch", "Merge"});

  ExportParams export_params;
  export_params.allow_custom_ops = false;

  std::string serialized;
  const auto ops_by_type = BuildOperatorByTypeMap();
  const auto export_status =
      Export(input_model_, &serialized, export_params, ops_by_type);
  EXPECT_EQ(export_status.error_message(), kExpectedError);
}
320
// With neither custom ops nor select TF ops allowed, an op without a builtin
// serializer should make the export fail with a message that suggests both
// escape hatches.
TEST_F(ExportTest, UnsupportedOpsAndNeedEnableFlex) {
  static constexpr char kExpectedError[] =
      "We are continually in the process of adding support to TensorFlow Lite "
      "for more ops. It would be helpful if you could inform us of how this "
      "conversion went by opening a github issue at "
      "https://github.com/tensorflow/tensorflow/issues/"
      "new?template=40-tflite-op-request.md\n and pasting the "
      "following:\n\nSome of the operators in the model are not supported by "
      "the standard TensorFlow Lite runtime. If those are native TensorFlow "
      "operators, you might be able to use the extended runtime by passing "
      "--enable_select_tf_ops, or by setting "
      "target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling "
      "tf.lite.TFLiteConverter(). Otherwise, if you have a custom "
      "implementation for them you can disable this error with "
      "--allow_custom_ops, or by setting allow_custom_ops=True when calling "
      "tf.lite.TFLiteConverter(). Here is a list of builtin operators you are "
      "using: ADD, CONV_2D. Here is a list of operators for which you will "
      "need custom implementations: BatchNormWithGlobalNormalization.";

  AddOperatorsByName({"Conv", "Add", "BatchNormWithGlobalNormalization"});

  ExportParams export_params;
  export_params.allow_custom_ops = false;
  export_params.enable_select_tf_ops = false;

  std::string serialized;
  const auto ops_by_type = BuildOperatorByTypeMap();
  const auto export_status =
      Export(input_model_, &serialized, export_params, ops_by_type);
  EXPECT_EQ(export_status.error_message(), kExpectedError);
}
350
// With select TF ops enabled but custom ops disallowed, ops named
// "MyCustomOp1"/"MyCustomOp2" should make the export fail with a message
// asking for custom implementations.
TEST_F(ExportTest, UnsupportedOpsNeedCustomImplementation) {
  static constexpr char kExpectedError[] =
      "We are continually in the process of adding support to TensorFlow Lite "
      "for more ops. It would be helpful if you could inform us of how this "
      "conversion went by opening a github issue at "
      "https://github.com/tensorflow/tensorflow/issues/"
      "new?template=40-tflite-op-request.md\n and pasting the "
      "following:\n\nSome of the operators in the model are not supported by "
      "the standard TensorFlow Lite runtime and are not recognized by "
      "TensorFlow. If you have a custom implementation for them you can "
      "disable this error with --allow_custom_ops, or by setting "
      "allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a "
      "list of builtin operators you are using: ADD, CONV_2D. Here is a list "
      "of operators for which you will need custom implementations: "
      "MyCustomOp1, MyCustomOp2.";

  AddOperatorsByName({"Conv", "Add", "MyCustomOp1", "MyCustomOp2"});

  ExportParams export_params;
  export_params.allow_custom_ops = false;
  export_params.enable_select_tf_ops = true;

  std::string serialized;
  const auto ops_by_type = BuildOperatorByTypeMap();
  const auto export_status =
      Export(input_model_, &serialized, export_params, ops_by_type);
  EXPECT_EQ(export_status.error_message(), kExpectedError);
}
377
// The model contains both control flow ops, which are not convertible, and
// custom ops; the returned error message should report the two groups.
TEST_F(ExportTest, UnsupportedControlFlowAndCustomOpsErrors) {
  static constexpr char kExpectedError[] =
      "We are continually in the process of adding support to TensorFlow Lite "
      "for more ops. It would be helpful if you could inform us of how this "
      "conversion went by opening a github issue at "
      "https://github.com/tensorflow/tensorflow/issues/"
      "new?template=40-tflite-op-request.md\n and pasting the "
      "following:\n\nTensorFlow Lite currently doesn't support control flow "
      "ops: Merge, Switch. We are working on supporting control flow ops, "
      "please see github issue at "
      "https://github.com/tensorflow/tensorflow/issues/28485. Some of the "
      "operators in the model are not supported by the standard TensorFlow "
      "Lite runtime. If those are native TensorFlow operators, you might be "
      "able to use the extended runtime by passing --enable_select_tf_ops, or "
      "by setting target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling "
      "tf.lite.TFLiteConverter(). Otherwise, if you have a custom "
      "implementation for them you can disable this error with "
      "--allow_custom_ops, or by setting allow_custom_ops=True when calling "
      "tf.lite.TFLiteConverter(). Here is a list of builtin operators you are "
      "using: ADD, CONV_2D. Here is a list of operators for which you will "
      "need custom implementations: MyCustomOp1, MyCustomOp2.";

  AddOperatorsByName(
      {"Conv", "Add", "Switch", "Merge", "MyCustomOp1", "MyCustomOp2"});

  ExportParams export_params;
  export_params.allow_custom_ops = false;

  std::string serialized;
  const auto ops_by_type = BuildOperatorByTypeMap();
  const auto export_status =
      Export(input_model_, &serialized, export_params, ops_by_type);
  EXPECT_EQ(export_status.error_message(), kExpectedError);
}
413
// Sanity check for the quantize_weights parameter: exporting the same model
// with weight quantization enabled should yield a smaller flatbuffer.
TEST_F(ExportTest, QuantizeWeights) {
  // Build the model once and export it twice. Building it a second time would
  // append duplicate Conv/Add operators to `input_model_`, so the two exports
  // would no longer describe the same graph.
  BuildQuantizableTestModel();

  std::string unquantized_result;
  Export(input_model_, /*allow_custom_ops=*/true, /*quantize_weights=*/false,
         &unquantized_result);

  std::string quantized_result;
  Export(input_model_, /*allow_custom_ops=*/true, /*quantize_weights=*/true,
         &quantized_result);

  // The quantized models should be smaller.
  EXPECT_LT(quantized_result.size(), unquantized_result.size());
}
427
428 class OpSetsTest : public ExportTest {
429 public:
430 enum OpSet { kTfLiteBuiltins, kSelectTfOps, kCustomOps };
431
SetAllowedOpSets(std::initializer_list<OpSet> sets)432 void SetAllowedOpSets(std::initializer_list<OpSet> sets) {
433 import_all_ops_as_unsupported_ = true;
434 params_.allow_custom_ops = false;
435 params_.enable_select_tf_ops = false;
436 params_.quantize_weights = QuantizedBufferType::NONE;
437
438 for (const OpSet& i : sets) {
439 switch (i) {
440 case kTfLiteBuiltins:
441 import_all_ops_as_unsupported_ = false;
442 break;
443 case kSelectTfOps:
444 params_.enable_select_tf_ops = true;
445 break;
446 case kCustomOps:
447 params_.allow_custom_ops = true;
448 break;
449 }
450 }
451 }
452
ImportExport(std::initializer_list<std::string> op_names)453 std::vector<std::string> ImportExport(
454 std::initializer_list<std::string> op_names) {
455 ResetOperators();
456 if (!import_all_ops_as_unsupported_) {
457 AddOperatorsByName(op_names);
458 } else {
459 for (const std::string& name : op_names) {
460 auto* op = new TensorFlowUnsupportedOperator;
461 op->tensorflow_op = name;
462 input_model_.operators.emplace_back(op);
463 }
464 }
465 return ExportAndSummarizeOperators(params_);
466 }
467
468 private:
469 bool import_all_ops_as_unsupported_;
470 ExportParams params_;
471 };
472
// Exercises exports where only TFLite builtins are targeted, without and
// with custom ops allowed.
TEST_F(OpSetsTest, BuiltinsOnly) {
  // --target_op_set=TFLITE_BUILTINS
  SetAllowedOpSets({kTfLiteBuiltins});
  const std::vector<std::string> mixed_without_custom =
      ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"});
  EXPECT_THAT(mixed_without_custom, ElementsAre());
  const std::vector<std::string> add_only = ImportExport({"Add"});
  EXPECT_THAT(add_only, ElementsAre("builtin:ADD"));

  // --target_op_set=TFLITE_BUILTINS --allow_custom_ops
  SetAllowedOpSets({kTfLiteBuiltins, kCustomOps});
  const std::vector<std::string> mixed_with_custom =
      ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"});
  EXPECT_THAT(mixed_with_custom,
              ElementsAre("builtin:ADD", "custom:AdjustHue", "custom:Assert",
                          "custom:UnrollAndFold"));
}
486
// Exercises exports where only select TF ops are targeted, without and with
// custom ops allowed.
TEST_F(OpSetsTest, TfSelectOnly) {
  // --target_op_set=SELECT_TF_OPS
  SetAllowedOpSets({kSelectTfOps});
  const std::vector<std::string> mixed_without_custom = ImportExport(
      {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"});
  EXPECT_THAT(mixed_without_custom, ElementsAre());
  const std::vector<std::string> add_only = ImportExport({"Add"});
  EXPECT_THAT(add_only, ElementsAre("custom:FlexAdd"));

  // --target_op_set=SELECT_TF_OPS --allow_custom_ops
  SetAllowedOpSets({kSelectTfOps, kCustomOps});
  const std::vector<std::string> mixed_with_custom = ImportExport(
      {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"});
  EXPECT_THAT(
      mixed_with_custom,
      ElementsAre("custom:FlexAdd", "custom:FlexAdjustHue", "custom:FlexAssert",
                  "custom:FlexRandomUniform", "custom:UnrollAndFold"));
}
503
// Exercises exports where both TFLite builtins and select TF ops are
// targeted, without and with custom ops allowed.
TEST_F(OpSetsTest, BuiltinsAndTfSelect) {
  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps});
  const std::vector<std::string> mixed_without_custom =
      ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"});
  EXPECT_THAT(mixed_without_custom, ElementsAre());
  const std::vector<std::string> add_and_random =
      ImportExport({"Add", "RandomUniform"});
  EXPECT_THAT(add_and_random,
              ElementsAre("builtin:ADD", "custom:FlexRandomUniform"));

  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS --allow_custom_ops
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps, kCustomOps});
  const std::vector<std::string> mixed_with_custom = ImportExport(
      {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"});
  EXPECT_THAT(
      mixed_with_custom,
      ElementsAre("builtin:ADD", "custom:FlexAdjustHue", "custom:FlexAssert",
                  "custom:FlexRandomUniform", "custom:UnrollAndFold"));
}
520
521 // This test is based on a hypothetical scenario that dilation is supported
522 // only in Conv version 2. So Toco populates version=1 when dilation parameters
523 // are all 1, and version=2 otherwise.
524 class FakeConvolutionOperator
525 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
526 ::tflite::BuiltinOptions_Conv2DOptions> {
527 public:
FakeConvolutionOperator()528 FakeConvolutionOperator()
529 : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D,
530 OperatorType::kConv) {}
531
532 // Returning the op version according to the op parameters.
GetVersion(const OperatorSignature & op_signature) const533 int GetVersion(const OperatorSignature& op_signature) const override {
534 const TocoOperator& conv_op =
535 static_cast<const TocoOperator&>(*op_signature.op);
536 if (conv_op.dilation_width_factor != 1 ||
537 conv_op.dilation_height_factor != 1) {
538 // Version 2 if dilation is used.
539 return 2;
540 }
541 return 1;
542 }
543
544 // Note: The read / write code doesn't need to be changed if we stick with
545 // the restrictions:
546 // * Only adding parameters at the bottom of the Flatbuffer tables.
547 // * When the default value of parameters are used, the op works consistently
548 // with the previous version.
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const549 flatbuffers::Offset<TfLiteOptions> WriteOptions(
550 const TocoOperator& op,
551 flatbuffers::FlatBufferBuilder* builder) const override {
552 auto padding = Padding::Serialize(op.padding.type);
553 auto activation_function =
554 ActivationFunction::Serialize(op.fused_activation_function);
555 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
556 op.stride_height, activation_function,
557 op.dilation_width_factor,
558 op.dilation_height_factor);
559 }
560
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const561 void ReadOptions(const TfLiteOptions& options,
562 TocoOperator* op) const override {
563 op->padding.type = Padding::Deserialize(options.padding());
564 op->stride_width = options.stride_w();
565 op->stride_height = options.stride_h();
566 op->dilation_width_factor = options.dilation_w_factor();
567 op->dilation_height_factor = options.dilation_h_factor();
568 op->fused_activation_function =
569 ActivationFunction::Deserialize(options.fused_activation_function());
570 }
571 };
572
573 class VersionedOpExportTest : public ::testing::Test {
574 protected:
SetUp()575 void SetUp() override {
576 input_model_.GetOrCreateArray("input");
577 input_model_.GetOrCreateArray("filter");
578 input_model_.GetOrCreateArray("output");
579 }
AddConvOp(bool use_dilation)580 void AddConvOp(bool use_dilation) {
581 {
582 auto* op = new ConvOperator;
583 op->inputs.push_back("input");
584 op->inputs.push_back("filter");
585 op->outputs.push_back("output");
586
587 op->padding.type = PaddingType::kSame;
588 op->stride_width = 1;
589 op->stride_height = 1;
590 if (use_dilation) {
591 op->dilation_width_factor = 2;
592 op->dilation_height_factor = 2;
593 } else {
594 op->dilation_width_factor = 1;
595 op->dilation_height_factor = 1;
596 }
597 input_model_.operators.emplace_back(op);
598 }
599 }
600
601 std::map<OperatorType, std::unique_ptr<BaseOperator>>
BuildFakeOperatorByTypeMap()602 BuildFakeOperatorByTypeMap() {
603 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
604 result[OperatorType::kConv] =
605 std::unique_ptr<BaseOperator>(new FakeConvolutionOperator);
606 return result;
607 }
608
609 Model input_model_;
610 };
611
// A single Conv without dilation should yield exactly one CONV_2D entry with
// version 1, at index 0.
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
  AddConvOp(false);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operator_indices;
  details::LoadOperatorsMap(input_model_, &operator_indices, ops_by_type,
                            false);

  const details::OperatorKey conv_v1(::tflite::BuiltinOperator_CONV_2D, "", 1);
  EXPECT_EQ(1, operator_indices.size());
  EXPECT_EQ(0, operator_indices.at(conv_v1));
}
623
// A single Conv with dilation should yield exactly one CONV_2D entry with
// version 2, at index 0.
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
  AddConvOp(true);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operator_indices;
  details::LoadOperatorsMap(input_model_, &operator_indices, ops_by_type,
                            false);

  const details::OperatorKey conv_v2(::tflite::BuiltinOperator_CONV_2D, "", 2);
  EXPECT_EQ(1, operator_indices.size());
  EXPECT_EQ(0, operator_indices.at(conv_v2));
}
635
// One Conv without and one with dilation should yield two distinct CONV_2D
// entries: version 1 at index 0 and version 2 at index 1.
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
  AddConvOp(false);
  AddConvOp(true);

  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::OperatorsMap operator_indices;
  details::LoadOperatorsMap(input_model_, &operator_indices, ops_by_type,
                            false);

  const details::OperatorKey conv_v1(::tflite::BuiltinOperator_CONV_2D, "", 1);
  const details::OperatorKey conv_v2(::tflite::BuiltinOperator_CONV_2D, "", 2);
  EXPECT_EQ(2, operator_indices.size());
  EXPECT_EQ(0, operator_indices.at(conv_v1));
  EXPECT_EQ(1, operator_indices.at(conv_v2));
}
650
// Exports a model holding a version-1 and a version-2 Conv and checks that
// both the operator codes and the per-operator opcode indices reflect that.
TEST_F(VersionedOpExportTest, Export) {
  AddConvOp(false);
  AddConvOp(true);

  std::string result;
  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  Export(input_model_, true, false, &result, ops_by_type);

  const auto* model = ::tflite::GetModel(result.data());
  const auto* operator_codes = model->operator_codes();

  // Verify that 2 operator codes are populated. Both are CONV_2D but with
  // different versions. The size check is fatal because the elements are
  // indexed unconditionally right below.
  ASSERT_EQ(2, operator_codes->size());
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
            GetBuiltinCode((*operator_codes)[0]));
  EXPECT_EQ(1, (*operator_codes)[0]->version());
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
            GetBuiltinCode((*operator_codes)[1]));
  EXPECT_EQ(2, (*operator_codes)[1]->version());

  // Verify that the 2 operators point to the correct indices of the operator
  // codes. Again the size check is fatal to guard the indexing.
  const auto* operators = (*model->subgraphs())[0]->operators();
  ASSERT_EQ(2, operators->size());
  EXPECT_EQ(0, (*operators)[0]->opcode_index());
  EXPECT_EQ(1, (*operators)[1]->opcode_index());
}
679
// A plain float Conv should produce a CONV_2D key with no custom code and
// version 1.
TEST(OperatorKeyTest, TestBuiltinOp) {
  Model model;
  auto conv = std::make_unique<ConvOperator>();

  // Test a normal float operation.
  conv->inputs = {"input", "filter"};
  conv->outputs = {"output"};
  model.GetOrCreateArray(conv->inputs[0]).data_type = ArrayDataType::kFloat;
  model.GetOrCreateArray(conv->inputs[1]).data_type = ArrayDataType::kFloat;
  model.GetOrCreateArray(conv->outputs[0]).data_type = ArrayDataType::kFloat;

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {conv.get(), &model};
  const auto conv_key = details::OperatorKey(signature, ops_by_type, false);

  EXPECT_EQ(conv_key.type(), ::tflite::BuiltinOperator_CONV_2D);
  EXPECT_EQ(conv_key.custom_code(), "");
  EXPECT_EQ(conv_key.version(), 1);
}
702
// A Dequantize op with an int8 input and float output should produce a
// DEQUANTIZE key with version 2.
TEST(OperatorKeyTest, TestBuiltinOpWithVersionedInputTypes) {
  Model model;
  auto dequantize = std::make_unique<DequantizeOperator>();

  dequantize->inputs = {"input"};
  dequantize->outputs = {"output"};
  model.GetOrCreateArray(dequantize->inputs[0]).data_type =
      ArrayDataType::kInt8;
  model.GetOrCreateArray(dequantize->outputs[0]).data_type =
      ArrayDataType::kFloat;

  const auto ops_by_type = BuildOperatorByTypeMap();

  // Test a signed int8 dequantize operation.
  const toco::OperatorSignature signature = {dequantize.get(), &model};
  const auto dequantize_key =
      details::OperatorKey(signature, ops_by_type, false);

  EXPECT_EQ(dequantize_key.type(), ::tflite::BuiltinOperator_DEQUANTIZE);
  EXPECT_EQ(dequantize_key.custom_code(), "");
  EXPECT_EQ(dequantize_key.version(), 2);
}
724
// An unsupported TensorFlow op should produce a CUSTOM key whose custom code
// is the TensorFlow op name, with version 1.
TEST(OperatorKeyTest, TestCustomOp) {
  Model model;
  auto unsupported_op = std::make_unique<TensorFlowUnsupportedOperator>();
  unsupported_op->tensorflow_op = "MyCrazyCustomOp";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {unsupported_op.get(), &model};
  const auto custom_key = details::OperatorKey(signature, ops_by_type, false);

  EXPECT_EQ(custom_key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(custom_key.custom_code(), "MyCrazyCustomOp");
  EXPECT_EQ(custom_key.version(), 1);
}
738
// Builds keys for "BatchMatMul" with the flex flag off and on, and checks
// that only the latter is turned into a Flex op.
TEST(OperatorKeyTest, TestFlexOp) {
  Model model;
  auto batch_matmul = std::make_unique<TensorFlowUnsupportedOperator>();
  batch_matmul->tensorflow_op = "BatchMatMul";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {batch_matmul.get(), &model};

  // It shouldn't be converted to Flex op if `allow_flex_op` is false.
  const auto plain_key = details::OperatorKey(signature, ops_by_type, false);
  EXPECT_EQ(plain_key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(plain_key.custom_code(), "BatchMatMul");
  EXPECT_EQ(plain_key.version(), 1);
  EXPECT_TRUE(plain_key.is_custom_op());
  EXPECT_FALSE(plain_key.is_flex_op());

  // Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
  // is true.
  const auto flex_key = details::OperatorKey(signature, ops_by_type, true);
  EXPECT_EQ(flex_key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(flex_key.custom_code(), "FlexBatchMatMul");
  EXPECT_EQ(flex_key.version(), 1);
  EXPECT_FALSE(flex_key.is_custom_op());
  EXPECT_TRUE(flex_key.is_flex_op());
}
768
// With the flex flag on, "Merge" should become "FlexMerge" and be flagged as
// an unsupported flex op.
TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
  Model model;
  auto merge_op = std::make_unique<TensorFlowUnsupportedOperator>();
  merge_op->tensorflow_op = "Merge";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {merge_op.get(), &model};
  const auto merge_key = details::OperatorKey(signature, ops_by_type, true);

  EXPECT_EQ(merge_key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(merge_key.custom_code(), "FlexMerge");
  EXPECT_EQ(merge_key.version(), 1);
  EXPECT_FALSE(merge_key.is_custom_op());
  EXPECT_TRUE(merge_key.is_flex_op());
  // The control flow ops should be marked as unsupported.
  EXPECT_TRUE(merge_key.is_unsupported_flex_op());
}
786
// With the flex flag on, an op named "UnsupportedOp" should keep its plain
// custom code and be neither a flex op nor an unsupported flex op.
TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
  Model model;
  auto unknown_op = std::make_unique<TensorFlowUnsupportedOperator>();
  unknown_op->tensorflow_op = "UnsupportedOp";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {unknown_op.get(), &model};
  const auto unknown_key = details::OperatorKey(signature, ops_by_type, true);

  EXPECT_EQ(unknown_key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(unknown_key.custom_code(), "UnsupportedOp");
  EXPECT_EQ(unknown_key.version(), 1);
  EXPECT_FALSE(unknown_key.is_flex_op());
  EXPECT_FALSE(unknown_key.is_unsupported_flex_op());
}
802
// Test Toco-supported/TFLite-unsupported operators: the key for an Assert op
// is computed once without and once with a retained NodeDef.
TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
  Model model;
  // TODO(ycling): The test will be broken if TensorFlowAssert is implemented in
  // TFLite. Find a more robust way to test the fallback logic.
  auto assert_op = std::make_unique<TensorFlowAssertOperator>();

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature signature = {assert_op.get(), &model};

  // If NodeDef isn't retained in the Toco op, a regular custom op
  // will be exported.
  const auto key_without_node_def =
      details::OperatorKey(signature, ops_by_type, true);
  EXPECT_EQ(key_without_node_def.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key_without_node_def.custom_code(), "Assert");
  EXPECT_EQ(key_without_node_def.version(), 1);
  EXPECT_TRUE(key_without_node_def.is_custom_op());
  EXPECT_FALSE(key_without_node_def.is_flex_op());

  ::tensorflow::NodeDef node_def;
  node_def.set_name("TensorFlowAssert");
  node_def.set_op("TensorFlowAssert");
  node_def.SerializeToString(&assert_op->tensorflow_node_def);

  // If NodeDef is retained in the Toco op, a Flex op will be exported.
  const auto key_with_node_def =
      details::OperatorKey(signature, ops_by_type, true);
  EXPECT_EQ(key_with_node_def.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key_with_node_def.custom_code(), "FlexAssert");
  EXPECT_EQ(key_with_node_def.version(), 1);
  EXPECT_FALSE(key_with_node_def.is_custom_op());
  EXPECT_TRUE(key_with_node_def.is_flex_op());
}
840
841 // TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
842
843 } // namespace
844 } // namespace tflite
845 } // namespace toco
846