• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <tuple>
16 #include <vector>
17 
18 #include <gtest/gtest.h>
19 #include "absl/memory/memory.h"
20 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
21 #include "tensorflow/lite/toco/model.h"
22 #include "tensorflow/lite/toco/tooling_util.h"
23 
24 namespace toco {
25 
26 namespace {
27 
RunIdentifyL2Normalization(const std::vector<float> & input,const std::vector<int> & input_shape,const std::vector<int> & output_shape,const bool div_square=false)28 void RunIdentifyL2Normalization(const std::vector<float>& input,
29                                 const std::vector<int>& input_shape,
30                                 const std::vector<int>& output_shape,
31                                 const bool div_square = false) {
32   Model model;
33   Array& input0 = model.GetOrCreateArray("input0");
34   Array& output = model.GetOrCreateArray("output");
35 
36   *input0.mutable_shape()->mutable_dims() = input_shape;
37   input0.data_type = ArrayDataType::kFloat;
38   input0.GetMutableBuffer<ArrayDataType::kFloat>().data = input;
39 
40   *output.mutable_shape()->mutable_dims() = output_shape;
41 
42   auto sq_op = new TensorFlowSquareOperator;
43   sq_op->inputs = {"input0"};
44   sq_op->outputs = {"output"};
45 
46   Array& sumoutput = model.GetOrCreateArray("Sumoutput");
47   *sumoutput.mutable_shape()->mutable_dims() = output_shape;
48 
49   auto sum_op = new TensorFlowSumOperator;
50   sum_op->inputs = {sq_op->outputs[0]};
51   sum_op->outputs = {"Sumoutput"};
52 
53   if (div_square) {
54     Array& sqrtoutput = model.GetOrCreateArray("squarertoutput");
55     *sqrtoutput.mutable_shape()->mutable_dims() = output_shape;
56 
57     auto sqrt_op = new TensorFlowSqrtOperator;
58     sqrt_op->inputs = {sum_op->outputs[0]};
59     sqrt_op->outputs = {"squarertoutput"};
60 
61     Array& divoutput = model.GetOrCreateArray("Divoutput");
62     *divoutput.mutable_shape()->mutable_dims() = output_shape;
63 
64     auto div_op = new DivOperator;
65     div_op->inputs = {"input0", sqrt_op->outputs[0]};
66     div_op->outputs = {"Divoutput"};
67 
68     /*Stack everything with the model*/
69     model.operators.push_back(std::unique_ptr<Operator>(div_op));
70     model.operators.push_back(std::unique_ptr<Operator>(sqrt_op));
71     model.operators.push_back(std::unique_ptr<Operator>(sum_op));
72     model.operators.push_back(std::unique_ptr<Operator>(sq_op));
73   } else {
74     Array& rsqoutput = model.GetOrCreateArray("Rsquareoutput");
75     *rsqoutput.mutable_shape()->mutable_dims() = output_shape;
76 
77     auto rsqrt_op = new TensorFlowRsqrtOperator;
78     rsqrt_op->inputs = {sum_op->outputs[0]};
79     rsqrt_op->outputs = {"Rsquareoutput"};
80 
81     Array& muloutput = model.GetOrCreateArray("Muloutput");
82     *muloutput.mutable_shape()->mutable_dims() = output_shape;
83 
84     auto mul_op = new MulOperator;
85     mul_op->inputs = {"input0", rsqrt_op->outputs[0]};
86     mul_op->outputs = {"Muloutput"};
87 
88     /*Stack everything with the model*/
89     model.operators.push_back(std::unique_ptr<Operator>(mul_op));
90     model.operators.push_back(std::unique_ptr<Operator>(rsqrt_op));
91     model.operators.push_back(std::unique_ptr<Operator>(sum_op));
92     model.operators.push_back(std::unique_ptr<Operator>(sq_op));
93   }
94 
95   bool modified;
96   ASSERT_TRUE(IdentifyL2Normalization().Run(&model, 0, &modified).ok());
97   for (auto& op_it : model.operators) {
98     Operator* op = op_it.get();
99     // Since the optimization has kicked in we should not find any
100     // Mul, Rsqrt, Add, Sqr  operators
101     if (div_square) {
102       EXPECT_FALSE(op->type == OperatorType::kDiv);
103       EXPECT_FALSE(op->type == OperatorType::kSqrt);
104     } else {
105       EXPECT_FALSE(op->type == OperatorType::kMul);
106       EXPECT_FALSE(op->type == OperatorType::kRsqrt);
107     }
108     EXPECT_FALSE(op->type == OperatorType::kAdd);
109     EXPECT_FALSE(op->type == OperatorType::kSquare);
110   }
111 }
112 
113 // Test for reverse input in Min
TEST(IdentifyL2Normalization,MulRsqrtTest)114 TEST(IdentifyL2Normalization, MulRsqrtTest) {
115   RunIdentifyL2Normalization(
116       // Input data
117       {3, 1, 4, 1, -5, 9, -2, 6, 5, 3, 5, 8},
118 
119       // Input shape
120       {3, 4},
121 
122       {3, 4},
123 
124       false);
125 }
126 
// Div/Sqrt form: input0 / Sqrt(Sum(Square(input0))).
TEST(IdentifyL2Normalization, DivSqrtNormTest) {
  const std::vector<float> input_data = {3.0f,  1.0f, 4.0f, 1.0f,
                                         -5.0f, 9.0f, -2.0f, 6.0f,
                                         5.0f,  3.0f, 5.0f, 8.0f};
  // The same 3x4 dims are used for the input and for all other arrays.
  const std::vector<int> shape = {3, 4};

  RunIdentifyL2Normalization(input_data, shape, shape,
                             /*div_square=*/true);
}
139 
140 }  // namespace
141 }  // namespace toco
142