/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/tflite/export.h"

#include <algorithm>
#include <initializer_list>
#include <memory>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/toco/tflite/types.h"

namespace toco {
namespace tflite {
namespace {

using ::testing::ElementsAre;
using ::testing::HasSubstr;

class ExportTest : public ::testing::Test {
 protected:
  void ResetOperators() { input_model_.operators.clear(); }
  void AddTensorsByName(std::initializer_list<std::string> names) {
    for (const std::string& name : names) {
      input_model_.GetOrCreateArray(name);
    }
  }
  void AddOperatorsByName(std::initializer_list<std::string> names) {
    for (const std::string& name : names) {
      if (name == "Conv") {
        auto* op = new ConvOperator;
        op->padding.type = PaddingType::kSame;
        op->inputs = {"input", "filter"};
        op->outputs = {"output"};
        Array& input_array = input_model_.GetOrCreateArray(op->inputs[0]);
        Array& filter_array = input_model_.GetOrCreateArray(op->inputs[1]);
        Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
        input_array.data_type = ArrayDataType::kFloat;
        filter_array.data_type = ArrayDataType::kFloat;
        output_array.data_type = ArrayDataType::kFloat;
        input_model_.operators.emplace_back(op);
      } else if (name == "Add") {
        auto* op = new AddOperator;
        op->inputs = {"input1", "input2"};
        op->outputs = {"output"};
        Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
        Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
        Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
        input1_array.data_type = ArrayDataType::kFloat;
        input2_array.data_type = ArrayDataType::kFloat;
        output_array.data_type = ArrayDataType::kFloat;
        input_model_.operators.emplace_back(op);
      } else if (name == "Sub") {
        auto* op = new SubOperator;
        op->inputs = {"input1", "input2"};
        op->outputs = {"output"};
        Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
        Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
        Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
        input1_array.data_type = ArrayDataType::kFloat;
        input2_array.data_type = ArrayDataType::kFloat;
        output_array.data_type = ArrayDataType::kFloat;
        input1_array.copy_shape({1, 2, 2, 2});
        input2_array.copy_shape({1, 2, 2, 2});
        output_array.copy_shape({1, 2, 2, 2});
        input_model_.operators.emplace_back(op);
      } else if (name == "Assert") {
        auto* op = new TensorFlowAssertOperator;

        // Even though assert is known to TOCO, it doesn't have a tflite
        // serializer, so it has to be exported as a custom op. If we attach a
        // NodeDef to it, however, it will be exported as a flex op instead.
        ::tensorflow::NodeDef node_def;
        node_def.set_name("Assert");
        node_def.set_op("Assert");
        node_def.SerializeToString(&op->tensorflow_node_def);

        input_model_.operators.emplace_back(op);
      } else {
        auto* op = new TensorFlowUnsupportedOperator;
        op->tensorflow_op = name;
        input_model_.operators.emplace_back(op);
      }
    }
  }

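  // Builds a model containing a Conv op whose "weights" input is a float
  // buffer large enough for weight quantization to take effect, plus an Add
  // op.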
  void BuildQuantizableTestModel() {
    input_model_.GetOrCreateArray("inputs");
    Array& weight_array = input_model_.GetOrCreateArray("weights");

    // Make the buffer large enough for the QuantizeWeights transformation to
    // take effect.
    int buf_size = 1296;
    auto weight_buf = std::make_unique<float[]>(buf_size);
    for (int i = 0; i < buf_size; i++) {
      // Fill the array with some garbage values.
      weight_buf[i] = static_cast<float>(i % 128);
    }

    weight_array.data_type = ArrayDataType::kFloat;

    // Initialize the shape and buffer for the weights array.
    Shape* weight_array_shape = weight_array.mutable_shape();
    std::vector<int>* weight_array_shape_dim =
        weight_array_shape->mutable_dims();
    weight_array_shape_dim->resize(4, 6);
    auto& weight_array_buffer =
        weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
    weight_array_buffer.data.resize(buf_size);
    float* buf_ptr =
        weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
    std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);

    {
      auto* op = new ConvOperator;
      op->padding.type = PaddingType::kSame;
      op->inputs = {"inputs", "weights"};
      op->outputs = {"output"};
      Array& input_array = input_model_.GetArray(op->inputs[0]);
      Array& filter_array = input_model_.GetArray(op->inputs[1]);
      Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
      input_array.data_type = ArrayDataType::kFloat;
      filter_array.data_type = ArrayDataType::kFloat;
      output_array.data_type = ArrayDataType::kFloat;
      input_model_.operators.emplace_back(op);
    }
    {
      auto* op = new AddOperator;
      op->inputs = {"input1", "input2"};
      op->outputs = {"output"};
      Array& input1_array = input_model_.GetOrCreateArray(op->inputs[0]);
      Array& input2_array = input_model_.GetOrCreateArray(op->inputs[1]);
      Array& output_array = input_model_.GetOrCreateArray(op->outputs[0]);
      input1_array.data_type = ArrayDataType::kFloat;
      input2_array.data_type = ArrayDataType::kFloat;
      output_array.data_type = ArrayDataType::kFloat;
      input_model_.operators.emplace_back(op);
    }
  }

  tensorflow::Status ExportAndReturnStatus(const ExportParams& params) {
    std::string result;
    return Export(input_model_, &result, params);
  }

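  // Exports the model and returns one human-readable name per operator code,
  // prefixed with "builtin:" or "custom:". Returns an empty list on failure.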
  std::vector<std::string> ExportAndSummarizeOperators(
      const ExportParams& params) {
    std::vector<std::string> names;

    std::string result;
    auto status = Export(input_model_, &result, params);
    if (!status.ok()) {
      LOG(INFO) << status.error_message();
      return names;
    }

    auto* model = ::tflite::GetModel(result.data());

    for (const ::tflite::OperatorCode* opcode : *model->operator_codes()) {
      auto builtin_code = GetBuiltinCode(opcode);
      if (builtin_code != ::tflite::BuiltinOperator_CUSTOM) {
        names.push_back(std::string("builtin:") +
                        ::tflite::EnumNameBuiltinOperator(builtin_code));
      } else {
        names.push_back(std::string("custom:") +
                        opcode->custom_code()->c_str());
      }
    }

    return names;
  }

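  // Exports the model and returns the opcode_index of each operator in the
  // first subgraph. Returns an empty list on failure.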
  std::vector<uint32_t> ExportAndGetOperatorIndices(
      const ExportParams& params) {
    std::vector<uint32_t> indices;

    std::string result;
    if (!Export(input_model_, &result, params).ok()) return indices;
    auto* model = ::tflite::GetModel(result.data());

    auto operators = (*model->subgraphs())[0]->operators();
    for (const auto* op : *operators) {
      indices.push_back(op->opcode_index());
    }
    return indices;
  }

  Model input_model_;
};

TEST_F(ExportTest, LoadTensorsMap) {
  AddTensorsByName({"tensor_one", "tensor_two"});

  details::TensorsMap tensors;
  details::LoadTensorsMap(input_model_, &tensors);
  EXPECT_EQ(0, tensors["tensor_one"]);
  EXPECT_EQ(1, tensors["tensor_two"]);
}

TEST_F(ExportTest, LoadOperatorsMap) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  details::OperatorsMap operators;
  const auto ops_by_type = BuildOperatorByTypeMap();
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
  EXPECT_EQ(
      0, operators[details::OperatorKey(::tflite::BuiltinOperator_ADD, "", 1)]);
  EXPECT_EQ(1, operators[details::OperatorKey(::tflite::BuiltinOperator_CONV_2D,
                                              "", 1)]);
  EXPECT_EQ(2, operators[details::OperatorKey(::tflite::BuiltinOperator_CUSTOM,
                                              "MyCrazyOp", 1)]);
  EXPECT_EQ(
      3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
}

TEST_F(ExportTest, UnsupportedFunctionality) {
  AddOperatorsByName({"Conv"});

  ExportParams params;
  params.allow_dynamic_tensors = false;
  auto status = ExportAndReturnStatus(params);
  EXPECT_EQ(status.code(), ::tensorflow::error::UNIMPLEMENTED);
  EXPECT_THAT(status.error_message(),
              HasSubstr("Unsupported flag: allow_dynamic_tensors."));
}

TEST_F(ExportTest, Export) {
  AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

  ExportParams params;
  params.allow_custom_ops = true;
  params.enable_select_tf_ops = false;
  params.quantize_weights = QuantizedBufferType::NONE;

  EXPECT_THAT(ExportAndSummarizeOperators(params),
              ElementsAre("builtin:ADD", "builtin:CONV_2D", "custom:MyCrazyOp",
                          "builtin:SUB"));
  EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
}

TEST_F(ExportTest, ExportMinRuntime) {
  AddOperatorsByName({"Conv", "Add", "Sub"});

  ExportParams params;
  params.allow_custom_ops = true;
  params.enable_select_tf_ops = false;
  params.quantize_weights = QuantizedBufferType::NONE;

  std::string output;
  auto status = Export(input_model_, &output, params);
  auto* model = ::tflite::GetModel(output.data());
  EXPECT_EQ(model->metadata()->size(), 1);
  EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
  auto buf = model->metadata()->Get(0)->buffer();
  auto* buffer = (*model->buffers())[buf];
  auto* array = buffer->data();
  EXPECT_EQ(reinterpret_cast<const char*>(array->data()), std::string("1.6.0"));
}

TEST_F(ExportTest, ExportEmptyMinRuntime) {
  AddOperatorsByName({"Switch", "MyCustomOp", "Assert"});

  ExportParams params;
  params.allow_custom_ops = true;

  std::string output;
  auto status = Export(input_model_, &output, params);
  auto* model = ::tflite::GetModel(output.data());
  EXPECT_EQ(model->metadata()->size(), 1);
  EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
  auto buf = model->metadata()->Get(0)->buffer();
  auto* buffer = (*model->buffers())[buf];
  auto* array = buffer->data();
  EXPECT_EQ(reinterpret_cast<const char*>(array->data()), std::string(""));
}

TEST_F(ExportTest, UnsupportedControlFlowErrors) {
  AddOperatorsByName({"Conv", "Add", "Switch", "Merge"});

  ExportParams params;
  params.allow_custom_ops = false;

  // The model contains control flow ops which are not convertible, so we
  // should check the returned error message.

  std::string output;
  const auto ops_by_type = BuildOperatorByTypeMap();
  auto status = Export(input_model_, &output, params, ops_by_type);
  EXPECT_EQ(status.error_message(),
            "We are continually in the process of adding support to TensorFlow "
            "Lite for more ops. It would be helpful if you could inform us of "
            "how this conversion went by opening a github issue at "
            "https://github.com/tensorflow/tensorflow/issues/"
            "new?template=40-tflite-op-request.md\n and pasting the "
            "following:\n\nTensorFlow Lite currently doesn't support control "
            "flow ops: Merge, Switch. We are working on supporting control "
            "flow ops, please see github issue at "
            "https://github.com/tensorflow/tensorflow/issues/28485.");
}

TEST_F(ExportTest, UnsupportedOpsAndNeedEnableFlex) {
  AddOperatorsByName({"Conv", "Add", "BatchNormWithGlobalNormalization"});

  ExportParams params;
  params.allow_custom_ops = false;
  params.enable_select_tf_ops = false;

  std::string output;
  const auto ops_by_type = BuildOperatorByTypeMap();
  auto status = Export(input_model_, &output, params, ops_by_type);
  EXPECT_EQ(
      status.error_message(),
      "We are continually in the process of adding support to TensorFlow Lite "
      "for more ops. It would be helpful if you could inform us of how this "
      "conversion went by opening a github issue at "
      "https://github.com/tensorflow/tensorflow/issues/"
      "new?template=40-tflite-op-request.md\n and pasting the "
      "following:\n\nSome of the operators in the model are not supported by "
      "the standard TensorFlow Lite runtime. If those are native TensorFlow "
      "operators, you might be able to use the extended runtime by passing "
      "--enable_select_tf_ops, or by setting "
      "target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling "
      "tf.lite.TFLiteConverter(). Otherwise, if you have a custom "
      "implementation for them you can disable this error with "
      "--allow_custom_ops, or by setting allow_custom_ops=True when calling "
      "tf.lite.TFLiteConverter(). Here is a list of builtin operators you are "
      "using: ADD, CONV_2D. Here is a list of operators for which you will "
      "need custom implementations: BatchNormWithGlobalNormalization.");
}

TEST_F(ExportTest, UnsupportedOpsNeedCustomImplementation) {
  AddOperatorsByName({"Conv", "Add", "MyCustomOp1", "MyCustomOp2"});

  ExportParams params;
  params.allow_custom_ops = false;
  params.enable_select_tf_ops = true;

  std::string output;
  const auto ops_by_type = BuildOperatorByTypeMap();
  auto status = Export(input_model_, &output, params, ops_by_type);
  EXPECT_EQ(
      status.error_message(),
      "We are continually in the process of adding support to TensorFlow Lite "
      "for more ops. It would be helpful if you could inform us of how this "
      "conversion went by opening a github issue at "
      "https://github.com/tensorflow/tensorflow/issues/"
      "new?template=40-tflite-op-request.md\n and pasting the "
      "following:\n\nSome of the operators in the model are not supported by "
      "the standard TensorFlow Lite runtime and are not recognized by "
      "TensorFlow. If you have a custom implementation for them you can "
      "disable this error with --allow_custom_ops, or by setting "
      "allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a "
      "list of builtin operators you are using: ADD, CONV_2D. Here is a list "
      "of operators for which you will need custom implementations: "
      "MyCustomOp1, MyCustomOp2.");
}

TEST_F(ExportTest, UnsupportedControlFlowAndCustomOpsErrors) {
  AddOperatorsByName(
      {"Conv", "Add", "Switch", "Merge", "MyCustomOp1", "MyCustomOp2"});

  ExportParams params;
  params.allow_custom_ops = false;

  // The model contains control flow ops which are not convertible, so we
  // should check the returned error message.

  std::string output;
  const auto ops_by_type = BuildOperatorByTypeMap();
  auto status = Export(input_model_, &output, params, ops_by_type);
  EXPECT_EQ(
      status.error_message(),
      "We are continually in the process of adding support to TensorFlow Lite "
      "for more ops. It would be helpful if you could inform us of how this "
      "conversion went by opening a github issue at "
      "https://github.com/tensorflow/tensorflow/issues/"
      "new?template=40-tflite-op-request.md\n and pasting the "
      "following:\n\nTensorFlow Lite currently doesn't support control flow "
      "ops: Merge, Switch. We are working on supporting control flow ops, "
      "please see github issue at "
      "https://github.com/tensorflow/tensorflow/issues/28485. Some of the "
      "operators in the model are not supported by the standard TensorFlow "
      "Lite runtime. If those are native TensorFlow operators, you might be "
      "able to use the extended runtime by passing --enable_select_tf_ops, or "
      "by setting target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling "
      "tf.lite.TFLiteConverter(). Otherwise, if you have a custom "
      "implementation for them you can disable this error with "
      "--allow_custom_ops, or by setting allow_custom_ops=True when calling "
      "tf.lite.TFLiteConverter(). Here is a list of builtin operators you are "
      "using: ADD, CONV_2D. Here is a list of operators for which you will "
      "need custom implementations: MyCustomOp1, MyCustomOp2.");
}

TEST_F(ExportTest, QuantizeWeights) {
  // Sanity check for the quantize_weights parameter.
  BuildQuantizableTestModel();
  std::string unquantized_result;
  Export(input_model_, true, /*quantize_weights*/ false, &unquantized_result);

  BuildQuantizableTestModel();
  std::string quantized_result;
  Export(input_model_, true, /*quantize_weights*/ true, &quantized_result);

  // The quantized model should be smaller.
  EXPECT_LT(quantized_result.size(), unquantized_result.size());
}

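// Exercises the different op sets that can be targeted during export:
// TFLite builtins, Select TF ops (Flex), and custom ops.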
class OpSetsTest : public ExportTest {
 public:
  enum OpSet { kTfLiteBuiltins, kSelectTfOps, kCustomOps };

  void SetAllowedOpSets(std::initializer_list<OpSet> sets) {
    import_all_ops_as_unsupported_ = true;
    params_.allow_custom_ops = false;
    params_.enable_select_tf_ops = false;
    params_.quantize_weights = QuantizedBufferType::NONE;

    for (const OpSet& i : sets) {
      switch (i) {
        case kTfLiteBuiltins:
          import_all_ops_as_unsupported_ = false;
          break;
        case kSelectTfOps:
          params_.enable_select_tf_ops = true;
          break;
        case kCustomOps:
          params_.allow_custom_ops = true;
          break;
      }
    }
  }

  std::vector<std::string> ImportExport(
      std::initializer_list<std::string> op_names) {
    ResetOperators();
    if (!import_all_ops_as_unsupported_) {
      AddOperatorsByName(op_names);
    } else {
      for (const std::string& name : op_names) {
        auto* op = new TensorFlowUnsupportedOperator;
        op->tensorflow_op = name;
        input_model_.operators.emplace_back(op);
      }
    }
    return ExportAndSummarizeOperators(params_);
  }

 private:
  bool import_all_ops_as_unsupported_;
  ExportParams params_;
};

TEST_F(OpSetsTest, BuiltinsOnly) {
  // --target_op_set=TFLITE_BUILTINS
  SetAllowedOpSets({kTfLiteBuiltins});
  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"}),
              ElementsAre());
  EXPECT_THAT(ImportExport({"Add"}), ElementsAre("builtin:ADD"));

  // --target_op_set=TFLITE_BUILTINS --allow_custom_ops
  SetAllowedOpSets({kTfLiteBuiltins, kCustomOps});
  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"}),
              ElementsAre("builtin:ADD", "custom:AdjustHue", "custom:Assert",
                          "custom:UnrollAndFold"));
}

TEST_F(OpSetsTest, TfSelectOnly) {
  // --target_op_set=SELECT_TF_OPS
  SetAllowedOpSets({kSelectTfOps});
  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "RandomUniform",
                            "UnrollAndFold", "Assert"}),
              ElementsAre());
  EXPECT_THAT(ImportExport({"Add"}), ElementsAre("custom:FlexAdd"));

  // --target_op_set=SELECT_TF_OPS --allow_custom_ops
  SetAllowedOpSets({kSelectTfOps, kCustomOps});
  EXPECT_THAT(
      ImportExport(
          {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"}),
      ElementsAre("custom:FlexAdd", "custom:FlexAdjustHue", "custom:FlexAssert",
                  "custom:FlexRandomUniform", "custom:UnrollAndFold"));
}

TEST_F(OpSetsTest, BuiltinsAndTfSelect) {
  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps});
  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold", "Assert"}),
              ElementsAre());
  EXPECT_THAT(ImportExport({"Add", "RandomUniform"}),
              ElementsAre("builtin:ADD", "custom:FlexRandomUniform"));

  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS --allow_custom_ops
  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps, kCustomOps});
  EXPECT_THAT(
      ImportExport(
          {"Add", "AdjustHue", "RandomUniform", "UnrollAndFold", "Assert"}),
      ElementsAre("builtin:ADD", "custom:FlexAdjustHue", "custom:FlexAssert",
                  "custom:FlexRandomUniform", "custom:UnrollAndFold"));
}

// This test is based on a hypothetical scenario in which dilation is supported
// only in Conv version 2, so Toco populates version=1 when the dilation
// parameters are all 1, and version=2 otherwise.
class FakeConvolutionOperator
    : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
                             ::tflite::BuiltinOptions_Conv2DOptions> {
 public:
  FakeConvolutionOperator()
      : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D,
                        OperatorType::kConv) {}

  // Returns the op version according to the op parameters.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const TocoOperator& conv_op =
        static_cast<const TocoOperator&>(*op_signature.op);
    if (conv_op.dilation_width_factor != 1 ||
        conv_op.dilation_height_factor != 1) {
      // Version 2 if dilation is used.
      return 2;
    }
    return 1;
  }

  // Note: The read / write code doesn't need to be changed if we stick with
  // these restrictions:
  // * Only add parameters at the bottom of the Flatbuffer tables.
  // * When the default values of parameters are used, the op works
  //   consistently with the previous version.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    auto padding = Padding::Serialize(op.padding.type);
    auto activation_function =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
                                         op.stride_height, activation_function,
                                         op.dilation_width_factor,
                                         op.dilation_height_factor);
  }

  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->padding.type = Padding::Deserialize(options.padding());
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
    op->dilation_width_factor = options.dilation_w_factor();
    op->dilation_height_factor = options.dilation_h_factor();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
  }
};

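// Fixture for checking that the versions reported by FakeConvolutionOperator
// are reflected in the exported operator codes.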
class VersionedOpExportTest : public ::testing::Test {
 protected:
  void SetUp() override {
    input_model_.GetOrCreateArray("input");
    input_model_.GetOrCreateArray("filter");
    input_model_.GetOrCreateArray("output");
  }
  void AddConvOp(bool use_dilation) {
    {
      auto* op = new ConvOperator;
      op->inputs.push_back("input");
      op->inputs.push_back("filter");
      op->outputs.push_back("output");

      op->padding.type = PaddingType::kSame;
      op->stride_width = 1;
      op->stride_height = 1;
      if (use_dilation) {
        op->dilation_width_factor = 2;
        op->dilation_height_factor = 2;
      } else {
        op->dilation_width_factor = 1;
        op->dilation_height_factor = 1;
      }
      input_model_.operators.emplace_back(op);
    }
  }

  std::map<OperatorType, std::unique_ptr<BaseOperator>>
  BuildFakeOperatorByTypeMap() {
    std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
    result[OperatorType::kConv] =
        std::unique_ptr<BaseOperator>(new FakeConvolutionOperator);
    return result;
  }

  Model input_model_;
};

TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
  AddConvOp(false);

  details::OperatorsMap operators;
  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  EXPECT_EQ(1, operators.size());
  EXPECT_EQ(0, operators.at(details::OperatorKey(
                   ::tflite::BuiltinOperator_CONV_2D, "", 1)));
}

TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
  AddConvOp(true);

  details::OperatorsMap operators;
  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  EXPECT_EQ(1, operators.size());
  EXPECT_EQ(0, operators.at(details::OperatorKey(
                   ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}

TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
  AddConvOp(false);
  AddConvOp(true);

  details::OperatorsMap operators;
  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);

  EXPECT_EQ(2, operators.size());
  EXPECT_EQ(0, operators.at(details::OperatorKey(
                   ::tflite::BuiltinOperator_CONV_2D, "", 1)));
  EXPECT_EQ(1, operators.at(details::OperatorKey(
                   ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}

TEST_F(VersionedOpExportTest, Export) {
  AddConvOp(false);
  AddConvOp(true);

  std::string result;
  const auto ops_by_type = BuildFakeOperatorByTypeMap();
  Export(input_model_, true, false, &result, ops_by_type);

  auto* model = ::tflite::GetModel(result.data());
  auto operator_codes = model->operator_codes();

  // Verify that 2 operator codes are populated. Both are CONV_2D but with
  // different versions.
  EXPECT_EQ(2, operator_codes->size());
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
            GetBuiltinCode((*operator_codes)[0]));
  EXPECT_EQ(1, (*operator_codes)[0]->version());
  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
            GetBuiltinCode((*operator_codes)[1]));
  EXPECT_EQ(2, (*operator_codes)[1]->version());

  // Verify that the 2 operators point to the correct indices of the operator
  // codes.
  auto operators = (*model->subgraphs())[0]->operators();
  EXPECT_EQ(2, operators->size());
  EXPECT_EQ(0, (*operators)[0]->opcode_index());
  EXPECT_EQ(1, (*operators)[1]->opcode_index());
}

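// Tests for details::OperatorKey, which records an operator's builtin type,
// custom code, and version.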
TEST(OperatorKeyTest, TestBuiltinOp) {
  Model model;
  auto op = std::make_unique<ConvOperator>();

  // Test a normal float operation.
  op->inputs = {"input", "filter"};
  op->outputs = {"output"};
  Array& input_array = model.GetOrCreateArray(op->inputs[0]);
  Array& filter_array = model.GetOrCreateArray(op->inputs[1]);
  Array& output_array = model.GetOrCreateArray(op->outputs[0]);
  input_array.data_type = ArrayDataType::kFloat;
  filter_array.data_type = ArrayDataType::kFloat;
  output_array.data_type = ArrayDataType::kFloat;

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, false);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CONV_2D);
  EXPECT_EQ(key.custom_code(), "");
  EXPECT_EQ(key.version(), 1);
}

TEST(OperatorKeyTest, TestBuiltinOpWithVersionedInputTypes) {
  Model model;
  auto op = std::make_unique<DequantizeOperator>();

  op->inputs = {"input"};
  op->outputs = {"output"};
  Array& input_array = model.GetOrCreateArray(op->inputs[0]);
  Array& output_array = model.GetOrCreateArray(op->outputs[0]);
  input_array.data_type = ArrayDataType::kInt8;
  output_array.data_type = ArrayDataType::kFloat;

  const auto ops_by_type = BuildOperatorByTypeMap();

  // Test a signed int8 dequantize operation.
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, false);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_DEQUANTIZE);
  EXPECT_EQ(key.custom_code(), "");
  EXPECT_EQ(key.version(), 2);
}

TEST(OperatorKeyTest, TestCustomOp) {
  Model model;
  auto op = std::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "MyCrazyCustomOp";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, false);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key.custom_code(), "MyCrazyCustomOp");
  EXPECT_EQ(key.version(), 1);
}

TEST(OperatorKeyTest, TestFlexOp) {
  Model model;
  auto op = std::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "BatchMatMul";

  const auto ops_by_type = BuildOperatorByTypeMap();
  {
    const toco::OperatorSignature op_signature = {op.get(), &model};
    const auto key = details::OperatorKey(op_signature, ops_by_type, false);
    // It shouldn't be converted to Flex op if `allow_flex_op` is false.
    EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(key.custom_code(), "BatchMatMul");
    EXPECT_EQ(key.version(), 1);
    EXPECT_TRUE(key.is_custom_op());
    EXPECT_FALSE(key.is_flex_op());
  }

  {
    // Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
    // is true.
    const toco::OperatorSignature op_signature = {op.get(), &model};
    const auto key = details::OperatorKey(op_signature, ops_by_type, true);
    EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(key.custom_code(), "FlexBatchMatMul");
    EXPECT_EQ(key.version(), 1);
    EXPECT_FALSE(key.is_custom_op());
    EXPECT_TRUE(key.is_flex_op());
  }
}

TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
  Model model;
  auto op = std::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "Merge";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, true);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key.custom_code(), "FlexMerge");
  EXPECT_EQ(key.version(), 1);
  EXPECT_FALSE(key.is_custom_op());
  EXPECT_TRUE(key.is_flex_op());
  // The control flow ops should be marked as unsupported.
  EXPECT_TRUE(key.is_unsupported_flex_op());
}

TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
  Model model;
  auto op = std::make_unique<TensorFlowUnsupportedOperator>();
  op->tensorflow_op = "UnsupportedOp";

  const auto ops_by_type = BuildOperatorByTypeMap();
  const toco::OperatorSignature op_signature = {op.get(), &model};
  const auto key = details::OperatorKey(op_signature, ops_by_type, true);

  EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
  EXPECT_EQ(key.custom_code(), "UnsupportedOp");
  EXPECT_EQ(key.version(), 1);
  EXPECT_FALSE(key.is_flex_op());
  EXPECT_FALSE(key.is_unsupported_flex_op());
}

TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
  // Test Toco-supported/TFLite-unsupported operators.
  Model model;
  // TODO(ycling): The test will be broken if TensorFlowAssert is implemented
  // in TFLite. Find a more robust way to test the fallback logic.
  auto op = std::make_unique<TensorFlowAssertOperator>();

  const auto ops_by_type = BuildOperatorByTypeMap();

  {
    // If NodeDef isn't retained in the Toco op, a regular custom op
    // will be exported.
    const toco::OperatorSignature op_signature = {op.get(), &model};
    const auto key = details::OperatorKey(op_signature, ops_by_type, true);
    EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(key.custom_code(), "Assert");
    EXPECT_EQ(key.version(), 1);
    EXPECT_TRUE(key.is_custom_op());
    EXPECT_FALSE(key.is_flex_op());
  }

  ::tensorflow::NodeDef node_def;
  node_def.set_name("TensorFlowAssert");
  node_def.set_op("TensorFlowAssert");
  node_def.SerializeToString(&op->tensorflow_node_def);

  {
    // If NodeDef is retained in the Toco op, a Flex op will be exported.
    const toco::OperatorSignature op_signature = {op.get(), &model};
    const auto key = details::OperatorKey(op_signature, ops_by_type, true);
    EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
    EXPECT_EQ(key.custom_code(), "FlexAssert");
    EXPECT_EQ(key.version(), 1);
    EXPECT_FALSE(key.is_custom_op());
    EXPECT_TRUE(key.is_flex_op());
  }
}

// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.

}  // namespace
}  // namespace tflite
}  // namespace toco