1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28
29 namespace tensorflow {
30 namespace graph_transforms {
31
32 // Declare here, so we don't need a public header.
33 Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
34 const TransformFuncContext& context,
35 GraphDef* output_graph_def);
// One min/max range parsed from a requantization log line by
// ExtractMinMaxRecords.
// NOTE(review): this presumably mirrors the definition used by
// ExtractMinMaxRecords' implementation (declared here to avoid a public
// header). If so, the one-definition rule requires both definitions to be
// token-for-token identical, so change only comments here unless the other
// definition is updated too.
struct MinMaxRecord {
  string name;  // Name taken from the log line, e.g. "requantization_range_op".
  float min;    // Logged lower bound of the range.
  float max;    // Logged upper bound of the range.
};
41 Status ExtractMinMaxRecords(const string& log_file_name,
42 std::vector<MinMaxRecord>* records);
43
44 class FreezeRequantizationRangesTest : public ::testing::Test {
45 protected:
TestFreezeRequantizationRanges()46 void TestFreezeRequantizationRanges() {
47 auto root = tensorflow::Scope::NewRootScope();
48 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
49
50 Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
51 test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
52 Output quantized_op = Const(root.WithOpName("quantized_op"),
53 Input::Initializer(quantized_tensor));
54
55 Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
56 test::FillValues<float>(&quantized_min_tensor, {2.0f});
57 Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
58 Input::Initializer(quantized_min_tensor));
59
60 Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
61 test::FillValues<float>(&quantized_max_tensor, {2.0f});
62 Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
63 Input::Initializer(quantized_min_tensor));
64
65 Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
66 test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
67 Output offset_op =
68 Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
69
70 Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
71 test::FillValues<float>(&offset_min_tensor, {0.0f});
72 Output offset_min_op = Const(root.WithOpName("offset_min_op"),
73 Input::Initializer(offset_min_tensor));
74
75 Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
76 test::FillValues<float>(&offset_max_tensor, {255.0f});
77 Output offset_max_op = Const(root.WithOpName("offset_max_op"),
78 Input::Initializer(offset_max_tensor));
79
80 QuantizedBiasAdd quantized_bias_add_op(
81 root.WithOpName("bias_add_op"), quantized_op, offset_op,
82 quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
83 DT_QINT32);
84
85 RequantizationRange requantization_range_op(
86 root.WithOpName("requantization_range_op"),
87 quantized_bias_add_op.output, quantized_bias_add_op.min_out,
88 quantized_bias_add_op.max_out);
89
90 Requantize requantize_op(
91 root.WithOpName("requantize_op"), quantized_bias_add_op.output,
92 quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
93 requantization_range_op.output_min, requantization_range_op.output_max,
94 DT_QUINT8);
95
96 Output dequantize_op =
97 Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
98 requantize_op.output_min, requantize_op.output_max);
99
100 GraphDef graph_def;
101 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
102
103 const string min_max_log_file_name =
104 io::JoinPath(testing::TmpDir(), "min_max_log_file.txt");
105 {
106 std::unique_ptr<WritableFile> file;
107 TF_ASSERT_OK(
108 Env::Default()->NewWritableFile(min_max_log_file_name, &file));
109 TF_ASSERT_OK(file->Append("Something irrelevant\n"));
110 TF_ASSERT_OK(
111 file->Append("[SomePrefix] "
112 ";requantization_range_op__print__;__requant_min_max:"
113 "[-2.4313571][10.584145]\n"));
114 TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
115 }
116
117 TransformFuncContext context;
118 context.input_names = {};
119 context.output_names = {"dequantize_op"};
120 context.params = {{"min_max_log_file", {min_max_log_file_name}}};
121
122 GraphDef frozen_graph_def;
123 TF_EXPECT_OK(
124 FreezeRequantizationRanges(graph_def, context, &frozen_graph_def));
125
126 std::map<string, const NodeDef*> node_map;
127 MapNamesToNodes(frozen_graph_def, &node_map);
128 EXPECT_EQ(0, node_map.count("requantization_range_op"));
129 EXPECT_EQ(1, node_map.count("requantize_op"));
130 const string& min_input =
131 NodeNameFromInput(node_map.at("requantize_op")->input(3));
132 ASSERT_EQ(1, node_map.count(min_input));
133 EXPECT_EQ("Const", node_map.at(min_input)->op());
134 const string& max_input =
135 NodeNameFromInput(node_map.at("requantize_op")->input(4));
136 ASSERT_EQ(1, node_map.count(max_input));
137 EXPECT_EQ("Const", node_map.at(max_input)->op());
138
139 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
140 TF_ASSERT_OK(original_session->Create(graph_def));
141 std::vector<Tensor> original_outputs;
142 TF_ASSERT_OK(
143 original_session->Run({}, {"dequantize_op"}, {}, &original_outputs));
144
145 std::unique_ptr<Session> frozen_session(NewSession(SessionOptions()));
146 TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
147 std::vector<Tensor> frozen_outputs;
148 TF_ASSERT_OK(
149 frozen_session->Run({}, {"dequantize_op"}, {}, &frozen_outputs));
150
151 ASSERT_EQ(original_outputs.size(), frozen_outputs.size());
152 ASSERT_EQ(1, frozen_outputs.size());
153 test::ExpectTensorNear<float>(original_outputs[0], frozen_outputs[0], 0.5);
154 }
155
TestExtractMinMaxRecords()156 void TestExtractMinMaxRecords() {
157 const string min_max_log_file_name =
158 io::JoinPath(testing::TmpDir(), "min_max_log_file2.txt");
159 {
160 std::unique_ptr<WritableFile> file;
161 TF_ASSERT_OK(
162 Env::Default()->NewWritableFile(min_max_log_file_name, &file));
163 TF_ASSERT_OK(file->Append("Something irrelevant\n"));
164 TF_ASSERT_OK(
165 file->Append("[SomePrefix] "
166 ";requantization_range_op__print__;__requant_min_max:"
167 "[-2.4313571][10.584145]\n"));
168 TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
169 TF_ASSERT_OK(file->Append(
170 "[SomeOtherPrefix] "
171 ";other_requantization_range_op__print__;__requant_min_max:"
172 "[-1.0][2.0]\n"));
173 TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
174 TF_ASSERT_OK(
175 file->Append("[SomePrefix] "
176 ";requantization_range_op__print__;__requant_min_max:"
177 "[-1.bad][2.0]\n"));
178 }
179 std::vector<MinMaxRecord> records;
180 TF_ASSERT_OK(ExtractMinMaxRecords(min_max_log_file_name, &records));
181 ASSERT_EQ(2, records.size());
182 EXPECT_EQ("requantization_range_op", records[0].name);
183 EXPECT_NEAR(-2.4313571f, records[0].min, 1e-5f);
184 EXPECT_NEAR(10.584145f, records[0].max, 1e-5f);
185 EXPECT_EQ("other_requantization_range_op", records[1].name);
186 EXPECT_NEAR(-1.0f, records[1].min, 1e-5f);
187 EXPECT_NEAR(2.0f, records[1].max, 1e-5f);
188 }
189 };
190
// End-to-end check: freeze the logged range into the graph and compare the
// outputs of the original and frozen graphs.
TEST_F(FreezeRequantizationRangesTest, TestFreezeRequantizationRanges) {
  this->TestFreezeRequantizationRanges();
}
194
// Parser check: only well-formed min/max log lines become records.
TEST_F(FreezeRequantizationRangesTest, TestExtractMinMaxRecords) {
  this->TestExtractMinMaxRecords();
}
198
199 } // namespace graph_transforms
200 } // namespace tensorflow
201