1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/tflite/operator.h"

#include <initializer_list>
#include <limits>
#include <map>
#include <memory>

#include "flatbuffers/flexbuffers.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
25
26 namespace toco {
27
28 namespace tflite {
29 namespace {
30
31 class OperatorTest : public ::testing::Test {
32 protected:
33 // Return the operator for the given name and type.
GetOperator(const string & name,OperatorType type)34 const BaseOperator& GetOperator(const string& name, OperatorType type) {
35 using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>;
36 using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
37
38 static auto* by_name = new OpsByName(BuildOperatorByNameMap());
39 static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
40
41 // Make sure the two maps were consitently built.
42 CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
43 BaseOperator* op1 = by_name->at(name).get();
44 CHECK(op1->type() == type) << "while verifying '" << name << "'.";
45
46 CHECK(by_type->count(type))
47 << "No operator for '" << OperatorTypeName(type) << "'.";
48 BaseOperator* op2 = by_type->at(type).get();
49 CHECK(op2->name() == name)
50 << "while verifying '" << OperatorTypeName(type) << "'.";
51
52 return *op1;
53 }
54
55 // Use the given BaseOperator to serialize the tf.mini operator into a set of
56 // TF Lite options. Proceed to deserialize the options back into a new
57 // tf.mini operator, which is then returned. If `options` is given, it will
58 // be populated with the serialized options.
59 template <typename T>
SerializeAndDeserialize(const BaseOperator & op,const T & toco_op,Options * options=nullptr)60 std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
61 const T& toco_op,
62 Options* options = nullptr) {
63 flatbuffers::FlatBufferBuilder builder;
64 Options input_options = op.Serialize(toco_op, &builder);
65
66 if (options) {
67 *options = input_options;
68 }
69
70 builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
71 input_options.builtin, input_options.custom,
72 ::tflite::CustomOptionsFormat_FLEXBUFFERS));
73 auto* output_options =
74 flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
75 auto new_toco_op = op.Deserialize(output_options->builtin_options(),
76 output_options->custom_options());
77
78 CHECK(new_toco_op->type == toco_op.type)
79 << "The type of the serialized and deserialized"
80 << HelpfulOperatorTypeName(*new_toco_op)
81 << " does not match the type of the original "
82 << HelpfulOperatorTypeName(toco_op);
83
84 return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
85 }
86
87 // Verify serialization and deserialization of simple operators (those
88 // that don't have any configuration parameters).
89 template <typename T>
CheckSimpleOperator(const string & name,OperatorType type)90 void CheckSimpleOperator(const string& name, OperatorType type) {
91 Options options;
92 auto output_toco_op =
93 SerializeAndDeserialize(GetOperator(name, type), T(), &options);
94
95 ASSERT_EQ(0, options.builtin.o);
96 ASSERT_EQ(0, options.custom.o);
97 ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
98
99 ASSERT_NE(nullptr, output_toco_op.get());
100 }
101
102 template <typename T>
CheckReducerOperator(const string & name,OperatorType type)103 void CheckReducerOperator(const string& name, OperatorType type) {
104 T op;
105
106 op.keep_dims = false;
107
108 auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
109 EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
110 }
111 };
112
// Every operator below has no options of its own; each one must serialize to
// empty options and deserialize back to an operator of the same type.
TEST_F(OperatorTest, SimpleOperators) {
  // Rounding and activation functions.
  CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
  CheckSimpleOperator<CeilOperator>("CEIL", OperatorType::kCeil);
  CheckSimpleOperator<EluOperator>("ELU", OperatorType::kElu);
  CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
  CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
  CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
  CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
  CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);

  // Element-wise math.
  CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
  CheckSimpleOperator<CosOperator>("COS", OperatorType::kCos);
  CheckSimpleOperator<LogSoftmaxOperator>(
      "LOG_SOFTMAX", OperatorType::kLogSoftmax);

  // Element-wise maximum / minimum (not the reducing variants).
  CheckSimpleOperator<TensorFlowMaximumOperator>(
      "MAXIMUM", OperatorType::kMaximum);
  CheckSimpleOperator<TensorFlowMinimumOperator>(
      "MINIMUM", OperatorType::kMinimum);

  CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess);
  CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
  CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
  CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
  CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
  CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", OperatorType::kEqual);
  CheckSimpleOperator<TensorFlowNotEqualOperator>(
      "NOT_EQUAL", OperatorType::kNotEqual);
  CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
  CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
  CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
  CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);

  // Logical operators.
  CheckSimpleOperator<LogicalOrOperator>(
      "LOGICAL_OR", OperatorType::kLogicalOr);
  CheckSimpleOperator<LogicalAndOperator>(
      "LOGICAL_AND", OperatorType::kLogicalAnd);
  CheckSimpleOperator<LogicalNotOperator>(
      "LOGICAL_NOT", OperatorType::kLogicalNot);

  CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
  CheckSimpleOperator<TensorFlowSquareOperator>(
      "SQUARE", OperatorType::kSquare);
  CheckSimpleOperator<TensorFlowZerosLikeOperator>(
      "ZEROS_LIKE", OperatorType::kZerosLike);
  CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod);

  // Tensor construction and shape manipulation.
  CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange);
  CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
  CheckSimpleOperator<ReverseV2Operator>(
      "REVERSE_V2", OperatorType::kReverseV2);
  CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank);
}
160
// ADD must round-trip its fused activation function.
TEST_F(OperatorTest, BuiltinAdd) {
  const BaseOperator& add = GetOperator("ADD", OperatorType::kAdd);

  AddOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const std::unique_ptr<AddOperator> restored =
      SerializeAndDeserialize(add, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
169
// ADD_N: only checks that the round trip yields an operator.
TEST_F(OperatorTest, BuiltinAddN) {
  const BaseOperator& add_n = GetOperator("ADD_N", OperatorType::kAddN);
  const AddNOperator original;
  const std::unique_ptr<AddNOperator> restored =
      SerializeAndDeserialize(add_n, original);
  ASSERT_NE(restored.get(), nullptr);
}
176
// All reducer operators share the keep_dims round-trip check.
TEST_F(OperatorTest, BuiltinReducerOps) {
  CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
  CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
  CheckReducerOperator<TensorFlowProdOperator>(
      "REDUCE_PROD", OperatorType::kReduceProd);
  CheckReducerOperator<TensorFlowMaxOperator>(
      "REDUCE_MAX", OperatorType::kReduceMax);
  CheckReducerOperator<TensorFlowMinOperator>(
      "REDUCE_MIN", OperatorType::kReduceMin);
  CheckReducerOperator<TensorFlowAnyOperator>(
      "REDUCE_ANY", OperatorType::kAny);
}
188
// CAST must round-trip both its source and destination data types.
TEST_F(OperatorTest, BuiltinCast) {
  const BaseOperator& cast = GetOperator("CAST", OperatorType::kCast);

  CastOperator original;
  original.src_data_type = ArrayDataType::kFloat;
  original.dst_data_type = ArrayDataType::kUint8;

  const std::unique_ptr<CastOperator> restored =
      SerializeAndDeserialize(cast, original);
  EXPECT_EQ(original.src_data_type, restored->src_data_type);
  EXPECT_EQ(original.dst_data_type, restored->dst_data_type);
}
198
// CONCATENATION must round-trip its axis.
TEST_F(OperatorTest, CustomConcatenation) {
  const BaseOperator& concat =
      GetOperator("CONCATENATION", OperatorType::kConcatenation);

  ConcatenationOperator original;
  original.axis = 123;

  const std::unique_ptr<ConcatenationOperator> restored =
      SerializeAndDeserialize(concat, original);
  EXPECT_EQ(original.axis, restored->axis);
}
206
// DEPTH_TO_SPACE must round-trip its block size.
TEST_F(OperatorTest, CustomDepthToSpace) {
  const BaseOperator& depth_to_space =
      GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace);

  DepthToSpaceOperator original;
  original.block_size = 123;

  const std::unique_ptr<DepthToSpaceOperator> restored =
      SerializeAndDeserialize(depth_to_space, original);
  EXPECT_EQ(original.block_size, restored->block_size);
}
214
// FAKE_QUANT must round-trip its min/max range and bit width.
TEST_F(OperatorTest, CustomFakeQuant) {
  FakeQuantOperator op;
  // Hand the MinMax to `op` immediately instead of holding a raw owning
  // pointer in a local.
  op.minmax.reset(new MinMax);
  op.minmax->min = -10;
  op.minmax->max = 200;
  op.num_bits = 16;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);

  // Fail the test cleanly rather than crashing on a null dereference if the
  // deserializer does not produce an operator or does not restore `minmax`.
  ASSERT_NE(nullptr, output_toco_op.get());
  ASSERT_NE(nullptr, output_toco_op->minmax.get());
  EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
  EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
  EXPECT_EQ(op.num_bits, output_toco_op->num_bits);
}
228
// FULLY_CONNECTED must round-trip its fused activation function.
TEST_F(OperatorTest, CustomFullyConnected) {
  const BaseOperator& fully_connected =
      GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected);

  FullyConnectedOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const std::unique_ptr<FullyConnectedOperator> restored =
      SerializeAndDeserialize(fully_connected, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
237
// GATHER: only checks that the round trip yields an operator.
TEST_F(OperatorTest, BuiltinGather) {
  const BaseOperator& gather = GetOperator("GATHER", OperatorType::kGather);
  const GatherOperator original;
  const std::unique_ptr<GatherOperator> restored =
      SerializeAndDeserialize(gather, original);
  ASSERT_NE(nullptr, restored.get());
}
244
// GATHER_ND: only checks that the round trip yields an operator.
TEST_F(OperatorTest, BuiltinGatherNd) {
  const BaseOperator& gather_nd =
      GetOperator("GATHER_ND", OperatorType::kGatherNd);
  const GatherNdOperator original;
  const std::unique_ptr<GatherNdOperator> restored =
      SerializeAndDeserialize(gather_nd, original);
  ASSERT_NE(restored.get(), nullptr);
}
251
// WHERE: only checks that the round trip yields an operator.
TEST_F(OperatorTest, BuiltinWhere) {
  const BaseOperator& where = GetOperator("WHERE", OperatorType::kWhere);
  const WhereOperator original;
  const std::unique_ptr<WhereOperator> restored =
      SerializeAndDeserialize(where, original);
  ASSERT_NE(restored.get(), nullptr);
}
258
// L2_POOL_2D must round-trip strides, padding and kernel size.
TEST_F(OperatorTest, BuiltinL2Pool) {
  const BaseOperator& l2_pool =
      GetOperator("L2_POOL_2D", OperatorType::kL2Pool);

  L2PoolOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;

  const std::unique_ptr<L2PoolOperator> restored =
      SerializeAndDeserialize(l2_pool, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.kwidth, restored->kwidth);
  EXPECT_EQ(original.kheight, restored->kheight);
}
274
// LOCAL_RESPONSE_NORMALIZATION must round-trip range, bias, alpha and beta.
TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
  const BaseOperator& lrn =
      GetOperator("LOCAL_RESPONSE_NORMALIZATION",
                  OperatorType::kLocalResponseNormalization);

  LocalResponseNormalizationOperator original;
  original.range = 123;
  original.bias = 1.23;
  original.alpha = 12.3;
  original.beta = .123;

  const std::unique_ptr<LocalResponseNormalizationOperator> restored =
      SerializeAndDeserialize(lrn, original);
  EXPECT_EQ(original.range, restored->range);
  EXPECT_EQ(original.bias, restored->bias);
  EXPECT_EQ(original.alpha, restored->alpha);
  EXPECT_EQ(original.beta, restored->beta);
}
290
// MAX_POOL_2D must round-trip strides, padding and kernel size.
TEST_F(OperatorTest, BuiltinMaxPool) {
  const BaseOperator& max_pool =
      GetOperator("MAX_POOL_2D", OperatorType::kMaxPool);

  MaxPoolOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;

  const std::unique_ptr<MaxPoolOperator> restored =
      SerializeAndDeserialize(max_pool, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.kwidth, restored->kwidth);
  EXPECT_EQ(original.kheight, restored->kheight);
}
306
// RESHAPE must round-trip its target shape.
TEST_F(OperatorTest, BuiltinReshape) {
  const BaseOperator& reshape = GetOperator("RESHAPE", OperatorType::kReshape);

  TensorFlowReshapeOperator original;
  original.shape = {1, 2, 4, 5, 8};

  const std::unique_ptr<TensorFlowReshapeOperator> restored =
      SerializeAndDeserialize(reshape, original);
  EXPECT_EQ(original.shape, restored->shape);
}
314
// SOFTMAX must round-trip its beta.
TEST_F(OperatorTest, CustomSoftmax) {
  const BaseOperator& softmax = GetOperator("SOFTMAX", OperatorType::kSoftmax);

  SoftmaxOperator original;
  original.beta = 123.1;

  const std::unique_ptr<SoftmaxOperator> restored =
      SerializeAndDeserialize(softmax, original);
  EXPECT_EQ(original.beta, restored->beta);
}
322
// SPACE_TO_DEPTH must round-trip its block size.
TEST_F(OperatorTest, BuiltinSpaceToDepth) {
  const BaseOperator& space_to_depth =
      GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth);

  SpaceToDepthOperator original;
  original.block_size = 123;

  const std::unique_ptr<SpaceToDepthOperator> restored =
      SerializeAndDeserialize(space_to_depth, original);
  EXPECT_EQ(original.block_size, restored->block_size);
}
330
// SPLIT must round-trip its number of splits.
TEST_F(OperatorTest, CustomSplit) {
  const BaseOperator& split = GetOperator("SPLIT", OperatorType::kSplit);

  TensorFlowSplitOperator original;
  original.num_split = 123;

  const std::unique_ptr<TensorFlowSplitOperator> restored =
      SerializeAndDeserialize(split, original);
  EXPECT_EQ(original.num_split, restored->num_split);
}
338
// SPLIT_V must round-trip its number of splits.
TEST_F(OperatorTest, CustomSplitV) {
  const BaseOperator& split_v = GetOperator("SPLIT_V", OperatorType::kSplitV);

  TensorFlowSplitVOperator original;
  original.num_split = 123;

  const std::unique_ptr<TensorFlowSplitVOperator> restored =
      SerializeAndDeserialize(split_v, original);
  EXPECT_EQ(original.num_split, restored->num_split);
}
346
// AVERAGE_POOL_2D must round-trip its activation, strides, padding and kernel.
TEST_F(OperatorTest, BuiltinAveragePool) {
  const BaseOperator& average_pool =
      GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool);

  AveragePoolOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;

  const std::unique_ptr<AveragePoolOperator> restored =
      SerializeAndDeserialize(average_pool, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.kwidth, restored->kwidth);
  EXPECT_EQ(original.kheight, restored->kheight);
}
365
// CONV_2D must round-trip strides, padding and fused activation.
TEST_F(OperatorTest, BuiltinConvolution) {
  const BaseOperator& conv = GetOperator("CONV_2D", OperatorType::kConv);

  ConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const std::unique_ptr<ConvOperator> restored =
      SerializeAndDeserialize(conv, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
380
// DEPTHWISE_CONV_2D must round-trip strides, padding, depth multiplier and
// fused activation.
TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
  const BaseOperator& depthwise_conv =
      GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv);

  DepthwiseConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.depth_multiplier = 6;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const std::unique_ptr<DepthwiseConvOperator> restored =
      SerializeAndDeserialize(depthwise_conv, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.depth_multiplier, restored->depth_multiplier);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
397
// L2_NORMALIZATION must round-trip its fused activation function.
TEST_F(OperatorTest, BuiltinL2Norm) {
  const BaseOperator& l2_norm =
      GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization);

  L2NormalizationOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const std::unique_ptr<L2NormalizationOperator> restored =
      SerializeAndDeserialize(l2_norm, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
406
// MUL must round-trip its fused activation function.
TEST_F(OperatorTest, BuiltinMul) {
  const BaseOperator& mul = GetOperator("MUL", OperatorType::kMul);

  MulOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const std::unique_ptr<MulOperator> restored =
      SerializeAndDeserialize(mul, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
415
// RESIZE_BILINEAR must round-trip align_corners.
TEST_F(OperatorTest, ResizeBilinear) {
  const BaseOperator& resize =
      GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear);

  ResizeBilinearOperator original;
  original.align_corners = true;

  const std::unique_ptr<ResizeBilinearOperator> restored =
      SerializeAndDeserialize(resize, original);
  EXPECT_EQ(original.align_corners, restored->align_corners);
}
423
// RESIZE_NEAREST_NEIGHBOR must round-trip align_corners.
TEST_F(OperatorTest, ResizeNearestNeighbor) {
  const BaseOperator& resize =
      GetOperator("RESIZE_NEAREST_NEIGHBOR",
                  OperatorType::kResizeNearestNeighbor);

  ResizeNearestNeighborOperator original;
  original.align_corners = true;

  const std::unique_ptr<ResizeNearestNeighborOperator> restored =
      SerializeAndDeserialize(resize, original);
  EXPECT_EQ(original.align_corners, restored->align_corners);
}
433
// SVDF must round-trip its fused activation and rank.
TEST_F(OperatorTest, Svdf) {
  const BaseOperator& svdf = GetOperator("SVDF", OperatorType::kSvdf);

  SvdfOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu;
  original.rank = 1;

  const std::unique_ptr<SvdfOperator> restored =
      SerializeAndDeserialize(svdf, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
  EXPECT_EQ(original.rank, restored->rank);
}
444
// SQUEEZE must round-trip its squeeze dimensions, including negative and
// repeated entries.
TEST_F(OperatorTest, Squeeze) {
  const BaseOperator& squeeze = GetOperator("SQUEEZE", OperatorType::kSqueeze);

  SqueezeOperator original;
  original.squeeze_dims = {-2, -3, 4, 1, 4};

  const std::unique_ptr<SqueezeOperator> restored =
      SerializeAndDeserialize(squeeze, original);
  EXPECT_EQ(original.squeeze_dims, restored->squeeze_dims);
}
453
// STRIDED_SLICE must round-trip all five masks (and leave the index vectors,
// which are left empty here, unchanged).
TEST_F(OperatorTest, StridedSlice) {
  StridedSliceOperator op;

  // Give every mask its own bit so that a (de)serializer that mixes up two
  // of the fields cannot go unnoticed.
  op.begin_mask = 1;
  op.end_mask = 2;
  op.ellipsis_mask = 4;
  op.new_axis_mask = 8;
  op.shrink_axis_mask = 16;

  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
  EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
  EXPECT_EQ(op.strides, output_toco_op->strides);
  EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
  EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
  EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
  EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
  EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
}
475
// TOPK_V2: only checks that the round trip yields an operator.
TEST_F(OperatorTest, BuiltinTopKV2) {
  const BaseOperator& top_k = GetOperator("TOPK_V2", OperatorType::kTopK_V2);
  const TopKV2Operator original;
  const std::unique_ptr<TopKV2Operator> restored =
      SerializeAndDeserialize(top_k, original);
  ASSERT_NE(nullptr, restored.get());
}
482
// ARG_MAX must round-trip its output data type.
TEST_F(OperatorTest, BuiltinArgMax) {
  ArgMaxOperator op;
  // Set the output type explicitly so the round trip is checked on a chosen
  // value rather than on whatever the constructor happens to initialize.
  op.output_data_type = ArrayDataType::kInt32;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("ARG_MAX", OperatorType::kArgMax), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
489
// ARG_MIN must round-trip its output data type.
TEST_F(OperatorTest, BuiltinArgMin) {
  ArgMinOperator op;
  // Set the output type explicitly so the round trip is checked on a chosen
  // value rather than on whatever the constructor happens to initialize.
  op.output_data_type = ArrayDataType::kInt32;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("ARG_MIN", OperatorType::kArgMin), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
496
// DEQUANTIZE: the round trip must yield an operator.
TEST_F(OperatorTest, BuiltinDequantize) {
  DequantizeOperator op;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("DEQUANTIZE", OperatorType::kDequantize), op);
  // Nothing option-specific is compared here, but the result must at least
  // exist (previously it was computed and never looked at).
  ASSERT_NE(nullptr, output_toco_op.get());
}
502
// TRANSPOSE_CONV must round-trip strides and padding.
TEST_F(OperatorTest, BuiltinTransposeConv) {
  const BaseOperator& transpose_conv =
      GetOperator("TRANSPOSE_CONV", OperatorType::kTransposeConv);

  TransposeConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;

  const std::unique_ptr<TransposeConvOperator> restored =
      SerializeAndDeserialize(transpose_conv, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
}
514
// SHAPE must round-trip its output data type.
TEST_F(OperatorTest, BuiltinShape) {
  const BaseOperator& shape = GetOperator("SHAPE", OperatorType::kShape);

  TensorFlowShapeOperator original;
  original.output_data_type = ArrayDataType::kInt64;

  const std::unique_ptr<TensorFlowShapeOperator> restored =
      SerializeAndDeserialize(shape, original);
  EXPECT_EQ(original.output_data_type, restored->output_data_type);
}
522
// SPARSE_TO_DENSE must round-trip validate_indices.
TEST_F(OperatorTest, BuiltinSparseToDense) {
  const BaseOperator& sparse_to_dense =
      GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense);

  SparseToDenseOperator original;
  original.validate_indices = false;

  const std::unique_ptr<SparseToDenseOperator> restored =
      SerializeAndDeserialize(sparse_to_dense, original);
  EXPECT_EQ(original.validate_indices, restored->validate_indices);
}
531
// PACK must round-trip its value count and axis.
TEST_F(OperatorTest, BuiltinPack) {
  const BaseOperator& pack = GetOperator("PACK", OperatorType::kPack);

  PackOperator original;
  original.values_count = 3;
  original.axis = 1;

  const std::unique_ptr<PackOperator> restored =
      SerializeAndDeserialize(pack, original);
  EXPECT_EQ(original.values_count, restored->values_count);
  EXPECT_EQ(original.axis, restored->axis);
}
541
// ONE_HOT must round-trip its axis.
TEST_F(OperatorTest, BuiltinOneHot) {
  const BaseOperator& one_hot = GetOperator("ONE_HOT", OperatorType::kOneHot);

  OneHotOperator original;
  original.axis = 2;

  const std::unique_ptr<OneHotOperator> restored =
      SerializeAndDeserialize(one_hot, original);
  EXPECT_EQ(original.axis, restored->axis);
}
549
// UNPACK must round-trip its count and axis.
TEST_F(OperatorTest, BuiltinUnpack) {
  const BaseOperator& unpack = GetOperator("UNPACK", OperatorType::kUnpack);

  UnpackOperator original;
  original.num = 5;
  original.axis = 2;

  const std::unique_ptr<UnpackOperator> restored =
      SerializeAndDeserialize(unpack, original);
  EXPECT_EQ(original.num, restored->num);
  EXPECT_EQ(original.axis, restored->axis);
}
559
// LEAKY_RELU must round-trip its alpha.
TEST_F(OperatorTest, BuiltinLeakyRelu) {
  const BaseOperator& leaky_relu =
      GetOperator("LEAKY_RELU", OperatorType::kLeakyRelu);

  LeakyReluOperator original;
  original.alpha = 3;

  const std::unique_ptr<LeakyReluOperator> restored =
      SerializeAndDeserialize(leaky_relu, original);
  EXPECT_EQ(original.alpha, restored->alpha);
}
567
// SQUARED_DIFFERENCE: only checks that the round trip yields an operator.
TEST_F(OperatorTest, BuiltinSquaredDifference) {
  const BaseOperator& squared_difference =
      GetOperator("SQUARED_DIFFERENCE", OperatorType::kSquaredDifference);
  const SquaredDifferenceOperator original;
  const std::unique_ptr<SquaredDifferenceOperator> restored =
      SerializeAndDeserialize(squared_difference, original);
  ASSERT_NE(nullptr, restored.get());
}
574
// CTC_BEAM_SEARCH_DECODER must round-trip beam width, path count and the
// merge_repeated flag.
TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
  const BaseOperator& decoder = GetOperator(
      "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder);

  CTCBeamSearchDecoderOperator original;
  original.beam_width = 3;
  original.top_paths = 2;
  original.merge_repeated = false;

  const std::unique_ptr<CTCBeamSearchDecoderOperator> restored =
      SerializeAndDeserialize(decoder, original);
  EXPECT_EQ(original.beam_width, restored->beam_width);
  EXPECT_EQ(original.top_paths, restored->top_paths);
  EXPECT_EQ(original.merge_repeated, restored->merge_repeated);
}
588
// TENSORFLOW_UNSUPPORTED must carry the serialized NodeDef through the round
// trip with every attribute kind (scalar and list) intact.
TEST_F(OperatorTest, TensorFlowUnsupported) {
  TensorFlowUnsupportedOperator op;
  op.tensorflow_op = "MyCustomUnsupportedOp";

  ::tensorflow::NodeDef node_def;
  auto* attr = node_def.mutable_attr();
  (*attr)["float_attr"].set_f(2.0);
  (*attr)["str_attr"].set_s("Hello World");
  (*attr)["int_attr"].set_i(17);
  (*attr)["bool_attr"].set_b(true);
  {
    auto* list = (*attr)["list_string_attr"].mutable_list();
    list->add_s("abcde");
    list->add_s("1234");
    list->add_s("");
    list->add_s("zyxwv");
    list->add_s("!-.");
  }
  {
    auto* list = (*attr)["list_float_attr"].mutable_list();
    list->add_f(std::numeric_limits<float>::min());
    list->add_f(2.0);
    list->add_f(-std::numeric_limits<float>::max());
  }
  {
    auto* list = (*attr)["list_int_attr"].mutable_list();
    list->add_i(1);
    list->add_i(20);
    list->add_i(1LL << 40);
    list->add_i(-(1LL << 40));
  }
  // Check the protobuf calls' results so an encoding/decoding failure is
  // reported directly instead of surfacing as missing attributes below.
  ASSERT_TRUE(node_def.SerializeToString(&op.tensorflow_node_def));

  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
  ASSERT_NE(nullptr, output_toco_op.get());

  ::tensorflow::NodeDef output_node_def;
  ASSERT_TRUE(
      output_node_def.ParseFromString(output_toco_op->tensorflow_node_def));
  const auto& output_attr = output_node_def.attr();
  EXPECT_EQ(2.0, output_attr.at("float_attr").f());
  EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
  EXPECT_EQ(17, output_attr.at("int_attr").i());
  EXPECT_TRUE(output_attr.at("bool_attr").b());
  {
    const auto& list = output_attr.at("list_string_attr").list();
    ASSERT_EQ(5, list.s_size());
    EXPECT_EQ("abcde", list.s(0));
    EXPECT_EQ("1234", list.s(1));
    EXPECT_EQ("", list.s(2));
    EXPECT_EQ("zyxwv", list.s(3));
    EXPECT_EQ("!-.", list.s(4));
  }
  {
    const auto& list = output_attr.at("list_float_attr").list();
    ASSERT_EQ(3, list.f_size());
    EXPECT_EQ(std::numeric_limits<float>::min(), list.f(0));
    EXPECT_EQ(2.0, list.f(1));
    EXPECT_EQ(-std::numeric_limits<float>::max(), list.f(2));
  }
  {
    const auto& list = output_attr.at("list_int_attr").list();
    ASSERT_EQ(4, list.i_size());
    EXPECT_EQ(1, list.i(0));
    EXPECT_EQ(20, list.i(1));
    EXPECT_EQ(1LL << 40, list.i(2));
    EXPECT_EQ(-(1LL << 40), list.i(3));
  }
}
657
// An unsupported op with an empty NodeDef must come back with no attributes.
TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
  TensorFlowUnsupportedOperator op;
  op.tensorflow_op = "MyCustomUnsupportedOp";
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
  ASSERT_NE(nullptr, output_toco_op.get());

  ::tensorflow::NodeDef output_node_def;
  // A parse failure would also yield an empty attr map, so check it
  // explicitly rather than let the emptiness check pass by accident.
  ASSERT_TRUE(
      output_node_def.ParseFromString(output_toco_op->tensorflow_node_def));
  EXPECT_TRUE(output_node_def.attr().empty());
}
668
TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
  // With `false` as the first argument (presumably the enable-flex flag --
  // confirm against the declaration), even Conv2D is not exported as flex.
  EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));

  // Known TensorFlow ops are exported as flex ops once enabled...
  EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
  EXPECT_TRUE(ShouldExportAsFlexOp(true, "EluGrad"));
  EXPECT_TRUE(ShouldExportAsFlexOp(true, "RFFT"));

  // ...but an unknown custom op is not.
  EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));

  // While the RandomShuffle op is available on desktop, it is not in the kernel
  // set available on mobile and should be excluded.
  EXPECT_FALSE(ShouldExportAsFlexOp(true, "RandomShuffle"));
}
679
// MIRROR_PAD must round-trip its padding mode.
TEST_F(OperatorTest, BuiltinMirrorPad) {
  const BaseOperator& mirror_pad =
      GetOperator("MIRROR_PAD", OperatorType::kMirrorPad);

  MirrorPadOperator original;
  original.mode = MirrorPadMode::kReflect;

  const std::unique_ptr<MirrorPadOperator> restored =
      SerializeAndDeserialize(mirror_pad, original);
  EXPECT_EQ(original.mode, restored->mode);
}
687
// UNIQUE must round-trip the data type of its index output.
TEST_F(OperatorTest, BuiltinUnique) {
  UniqueOperator op;
  op.idx_out_type = ArrayDataType::kInt64;
  auto output_toco_op =
      SerializeAndDeserialize(GetOperator("UNIQUE", OperatorType::kUnique), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  // Expected value first, matching every other assertion in this file so
  // failure messages read the same way.
  EXPECT_EQ(op.idx_out_type, output_toco_op->idx_out_type);
}
696
// REVERSE_SEQUENCE must round-trip its sequence and batch dimensions.
TEST_F(OperatorTest, BuiltinReverseSequence) {
  const BaseOperator& reverse_sequence =
      GetOperator("REVERSE_SEQUENCE", OperatorType::kReverseSequence);

  ReverseSequenceOperator original;
  original.seq_dim = 3;
  original.batch_dim = 1;

  const std::unique_ptr<ReverseSequenceOperator> restored =
      SerializeAndDeserialize(reverse_sequence, original);
  EXPECT_EQ(original.seq_dim, restored->seq_dim);
  EXPECT_EQ(original.batch_dim, restored->batch_dim);
}
707
708 // Test version for a simple Op with 2 versions and the input type controls the
709 // version.
710 template <typename Op>
SimpleVersioningTest()711 void SimpleVersioningTest() {
712 Op op;
713 op.inputs = {"input1"};
714 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
715 const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
716
717 Model uint8_model;
718 Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
719 uint8_array.data_type = ArrayDataType::kUint8;
720 OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
721 EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
722
723 Model int8_model;
724 Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
725 int8_array.data_type = ArrayDataType::kInt8;
726 OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
727 EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
728 }
729
730 // Test version for a simple Op with 2 versions and the output type controls the
731 // version.
732 template <typename Op>
SimpleOutputVersioningTest()733 void SimpleOutputVersioningTest() {
734 Op op;
735 op.outputs = {"output1"};
736 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
737 const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
738
739 Model uint8_model;
740 Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]);
741 uint8_array.data_type = ArrayDataType::kUint8;
742 OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
743 EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
744
745 Model int8_model;
746 Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]);
747 int8_array.data_type = ArrayDataType::kInt8;
748 OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
749 EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
750 }
751
// EQUAL: version 1 for uint8 inputs, version 2 for int8 inputs.
TEST_F(OperatorTest, VersioningEqualTest) {
  SimpleVersioningTest<TensorFlowEqualOperator>();
}
755
// NOT_EQUAL: version 1 for uint8 inputs, version 2 for int8 inputs.
TEST_F(OperatorTest, VersioningNotEqualTest) {
  SimpleVersioningTest<TensorFlowNotEqualOperator>();
}
759
// LESS: version 1 for uint8 inputs, version 2 for int8 inputs.
TEST_F(OperatorTest, VersioningLessTest) {
  SimpleVersioningTest<TensorFlowLessOperator>();
}
763
// LESS_EQUAL: version 1 for uint8 inputs, version 2 for int8 inputs.
TEST_F(OperatorTest, VersioningLessEqualTest) {
  SimpleVersioningTest<TensorFlowLessEqualOperator>();
}
767
// GREATER: version 1 for uint8 inputs, version 2 for int8 inputs.
TEST_F(OperatorTest, VersioningGreaterTest) {
  SimpleVersioningTest<TensorFlowGreaterOperator>();
}
771
// GreaterEqual: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningGreaterEqualTest) {
  using OpUnderTest = TensorFlowGreaterEqualOperator;
  SimpleVersioningTest<OpUnderTest>();
}
775
// SpaceToBatchND: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
  using OpUnderTest = SpaceToBatchNDOperator;
  SimpleVersioningTest<OpUnderTest>();
}
779
// LogSoftmax: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningLogSoftmaxTest) {
  using OpUnderTest = LogSoftmaxOperator;
  SimpleVersioningTest<OpUnderTest>();
}
783
// Pack: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningPackTest) {
  using OpUnderTest = PackOperator;
  SimpleVersioningTest<OpUnderTest>();
}
787
// BatchToSpaceND: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
  using OpUnderTest = BatchToSpaceNDOperator;
  SimpleVersioningTest<OpUnderTest>();
}
791
// Tanh: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningTanhTest) {
  using OpUnderTest = TanhOperator;
  SimpleVersioningTest<OpUnderTest>();
}
795
// StridedSlice: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningStridedSliceTest) {
  using OpUnderTest = StridedSliceOperator;
  SimpleVersioningTest<OpUnderTest>();
}
799
// SpaceToDepth: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
  using OpUnderTest = SpaceToDepthOperator;
  SimpleVersioningTest<OpUnderTest>();
}
803
// Slice: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningSliceTest) {
  using OpUnderTest = SliceOperator;
  SimpleVersioningTest<OpUnderTest>();
}
807
// Logistic: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningLogisticTest) {
  using OpUnderTest = LogisticOperator;
  SimpleVersioningTest<OpUnderTest>();
}
811
// L2Normalization: uses the output-type-driven helper
// SimpleOutputVersioningTest rather than SimpleVersioningTest.
TEST_F(OperatorTest, VersioningL2NormTest) {
  using OpUnderTest = L2NormalizationOperator;
  SimpleOutputVersioningTest<OpUnderTest>();
}
815
// Maximum: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningMaxTest) {
  using OpUnderTest = TensorFlowMaximumOperator;
  SimpleVersioningTest<OpUnderTest>();
}
819
// Minimum: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningMinTest) {
  using OpUnderTest = TensorFlowMinimumOperator;
  SimpleVersioningTest<OpUnderTest>();
}
823
// Add: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningAddTest) {
  using OpUnderTest = AddOperator;
  SimpleVersioningTest<OpUnderTest>();
}
825
// Sub: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningSubTest) {
  using OpUnderTest = SubOperator;
  SimpleVersioningTest<OpUnderTest>();
}
827
// Mul: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningMulTest) {
  using OpUnderTest = MulOperator;
  SimpleVersioningTest<OpUnderTest>();
}
829
// Pad: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningPadTest) {
  using OpUnderTest = PadOperator;
  SimpleVersioningTest<OpUnderTest>();
}
831
// PadV2: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningPadV2Test) {
  using OpUnderTest = PadV2Operator;
  SimpleVersioningTest<OpUnderTest>();
}
835
// Concatenation: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningConcatenationTest) {
  using OpUnderTest = ConcatenationOperator;
  SimpleVersioningTest<OpUnderTest>();
}
839
// Select: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningSelectTest) {
  using OpUnderTest = SelectOperator;
  SimpleVersioningTest<OpUnderTest>();
}
843
// Relu6: exercised through the shared SimpleVersioningTest helper.
TEST_F(OperatorTest, VersioningRelu6Test) {
  using OpUnderTest = Relu6Operator;
  SimpleVersioningTest<OpUnderTest>();
}
847
// FullyConnected versioning. The input, weights and output arrays are always
// given the same data type. An all-kUint8 model must report version 1 and an
// all-kInt8 model version 4. Because the three types move together, this test
// does not isolate which array drives the version.
TEST_F(OperatorTest, VersioningFullyConnectedTest) {
  FullyConnectedOperator fc_op;
  fc_op.inputs = {"input", "weight"};
  fc_op.outputs = {"output"};
  const auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* const fc_base = ops_by_type.at(fc_op.type).get();

  // Builds a fresh model in which input, weights and output (created in that
  // order) all have `type`, then returns the version the registered operator
  // reports for that signature.
  auto version_with_all = [&](ArrayDataType type) {
    Model model;
    model.GetOrCreateArray(fc_op.inputs[0]).data_type = type;
    model.GetOrCreateArray(fc_op.inputs[1]).data_type = type;
    model.GetOrCreateArray(fc_op.outputs[0]).data_type = type;
    const OperatorSignature signature = {.op = &fc_op, .model = &model};
    return fc_base->GetVersion(signature);
  };

  EXPECT_EQ(version_with_all(ArrayDataType::kUint8), 1);
  EXPECT_EQ(version_with_all(ArrayDataType::kInt8), 4);
}
884
885 } // namespace
886 } // namespace tflite
887
888 } // namespace toco
889