/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

namespace {

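// Builds a model containing the L2-normalization pattern
//   Square -> Sum -> Rsqrt -> Mul   (div_square == false), or
//   Square -> Sum -> Sqrt  -> Div   (div_square == true),
// runs the IdentifyL2Normalization transformation on it, and verifies that
// the individual pattern operators are no longer present afterwards.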
void RunIdentifyL2Normalization(const std::vector<float>& input,
                                const std::vector<int>& input_shape,
                                const std::vector<int>& output_shape,
                                const bool div_square = false) {
  Model model;
  Array& input0 = model.GetOrCreateArray("input0");
  Array& output = model.GetOrCreateArray("output");

  *input0.mutable_shape()->mutable_dims() = input_shape;
  input0.data_type = ArrayDataType::kFloat;
  input0.GetMutableBuffer<ArrayDataType::kFloat>().data = input;

  *output.mutable_shape()->mutable_dims() = output_shape;

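  // Element-wise square of the input: x^2.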
  auto sq_op = new TensorFlowSquareOperator;
  sq_op->inputs = {"input0"};
  sq_op->outputs = {"output"};

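  // Reduce the squared values: sum(x^2).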
  Array& sumoutput = model.GetOrCreateArray("Sumoutput");
  *sumoutput.mutable_shape()->mutable_dims() = output_shape;

  auto sum_op = new TensorFlowSumOperator;
  sum_op->inputs = {sq_op->outputs[0]};
  sum_op->outputs = {"Sumoutput"};

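  // Complete the pattern with either x / sqrt(sum(x^2)) or x * rsqrt(sum(x^2)).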
  if (div_square) {
    Array& sqrtoutput = model.GetOrCreateArray("squarertoutput");
    *sqrtoutput.mutable_shape()->mutable_dims() = output_shape;

    auto sqrt_op = new TensorFlowSqrtOperator;
    sqrt_op->inputs = {sum_op->outputs[0]};
    sqrt_op->outputs = {"squarertoutput"};

    Array& divoutput = model.GetOrCreateArray("Divoutput");
    *divoutput.mutable_shape()->mutable_dims() = output_shape;

    auto div_op = new DivOperator;
    div_op->inputs = {"input0", sqrt_op->outputs[0]};
    div_op->outputs = {"Divoutput"};

    // Stack everything with the model.
    model.operators.push_back(std::unique_ptr<Operator>(div_op));
    model.operators.push_back(std::unique_ptr<Operator>(sqrt_op));
    model.operators.push_back(std::unique_ptr<Operator>(sum_op));
    model.operators.push_back(std::unique_ptr<Operator>(sq_op));
  } else {
    Array& rsqoutput = model.GetOrCreateArray("Rsquareoutput");
    *rsqoutput.mutable_shape()->mutable_dims() = output_shape;

    auto rsqrt_op = new TensorFlowRsqrtOperator;
    rsqrt_op->inputs = {sum_op->outputs[0]};
    rsqrt_op->outputs = {"Rsquareoutput"};

    Array& muloutput = model.GetOrCreateArray("Muloutput");
    *muloutput.mutable_shape()->mutable_dims() = output_shape;

    auto mul_op = new MulOperator;
    mul_op->inputs = {"input0", rsqrt_op->outputs[0]};
    mul_op->outputs = {"Muloutput"};

    // Stack everything with the model.
    model.operators.push_back(std::unique_ptr<Operator>(mul_op));
    model.operators.push_back(std::unique_ptr<Operator>(rsqrt_op));
    model.operators.push_back(std::unique_ptr<Operator>(sum_op));
    model.operators.push_back(std::unique_ptr<Operator>(sq_op));
  }

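  // Run the transformation; it should detect the L2-normalization pattern and
  // remove the individual pattern operators (checked below).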
  bool modified;
  ASSERT_TRUE(IdentifyL2Normalization().Run(&model, 0, &modified).ok());
  for (auto& op_it : model.operators) {
    Operator* op = op_it.get();
    // Since the optimization has kicked in, we should not find any of the
    // pattern operators (Mul/Rsqrt or Div/Sqrt), nor Add or Square.
    if (div_square) {
      EXPECT_FALSE(op->type == OperatorType::kDiv);
      EXPECT_FALSE(op->type == OperatorType::kSqrt);
    } else {
      EXPECT_FALSE(op->type == OperatorType::kMul);
      EXPECT_FALSE(op->type == OperatorType::kRsqrt);
    }
    EXPECT_FALSE(op->type == OperatorType::kAdd);
    EXPECT_FALSE(op->type == OperatorType::kSquare);
  }
}

// Tests recognition of the x * Rsqrt(Sum(Square(x))) form of L2 normalization.
TEST(IdentifyL2Normalization, MulRsqrtTest) {
  RunIdentifyL2Normalization(
      // Input data
      {3, 1, 4, 1, -5, 9, -2, 6, 5, 3, 5, 8},

      // Input shape
      {3, 4},

      // Output shape
      {3, 4},

      false);
}

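// Tests recognition of the x / Sqrt(Sum(Square(x))) form of L2 normalization.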
TEST(IdentifyL2Normalization, DivSqrtNormTest) {
  RunIdentifyL2Normalization(
      // Input data
      {3, 1, 4, 1, -5, 9, -2, 6, 5, 3, 5, 8},

      // Input shape
      {3, 4},

      // Output shape
      {3, 4},

      true);
}

}  // namespace
}  // namespace toco