• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28 
29 namespace tensorflow {
30 namespace graph_transforms {
31 
32 // Declare here, so we don't need a public header.
33 Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
34                                   const TransformFuncContext& context,
35                                   GraphDef* output_graph_def);
36 struct MinMaxRecord {
37   string name;
38   float min;
39   float max;
40 };
41 Status ExtractMinMaxRecords(const string& log_file_name,
42                             std::vector<MinMaxRecord>* records);
43 
44 class FreezeRequantizationRangesTest : public ::testing::Test {
45  protected:
TestFreezeRequantizationRanges()46   void TestFreezeRequantizationRanges() {
47     auto root = tensorflow::Scope::NewRootScope();
48     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
49 
50     Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
51     test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
52     Output quantized_op = Const(root.WithOpName("quantized_op"),
53                                 Input::Initializer(quantized_tensor));
54 
55     Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
56     test::FillValues<float>(&quantized_min_tensor, {2.0f});
57     Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
58                                     Input::Initializer(quantized_min_tensor));
59 
60     Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
61     test::FillValues<float>(&quantized_max_tensor, {2.0f});
62     Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
63                                     Input::Initializer(quantized_min_tensor));
64 
65     Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
66     test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
67     Output offset_op =
68         Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
69 
70     Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
71     test::FillValues<float>(&offset_min_tensor, {0.0f});
72     Output offset_min_op = Const(root.WithOpName("offset_min_op"),
73                                  Input::Initializer(offset_min_tensor));
74 
75     Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
76     test::FillValues<float>(&offset_max_tensor, {255.0f});
77     Output offset_max_op = Const(root.WithOpName("offset_max_op"),
78                                  Input::Initializer(offset_max_tensor));
79 
80     QuantizedBiasAdd quantized_bias_add_op(
81         root.WithOpName("bias_add_op"), quantized_op, offset_op,
82         quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
83         DT_QINT32);
84 
85     RequantizationRange requantization_range_op(
86         root.WithOpName("requantization_range_op"),
87         quantized_bias_add_op.output, quantized_bias_add_op.min_out,
88         quantized_bias_add_op.max_out);
89 
90     Requantize requantize_op(
91         root.WithOpName("requantize_op"), quantized_bias_add_op.output,
92         quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
93         requantization_range_op.output_min, requantization_range_op.output_max,
94         DT_QUINT8);
95 
96     Output dequantize_op =
97         Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
98                    requantize_op.output_min, requantize_op.output_max);
99 
100     GraphDef graph_def;
101     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
102 
103     const string min_max_log_file_name =
104         io::JoinPath(testing::TmpDir(), "min_max_log_file.txt");
105     {
106       std::unique_ptr<WritableFile> file;
107       TF_ASSERT_OK(
108           Env::Default()->NewWritableFile(min_max_log_file_name, &file));
109       TF_ASSERT_OK(file->Append("Something irrelevant\n"));
110       TF_ASSERT_OK(
111           file->Append("[SomePrefix] "
112                        ";requantization_range_op__print__;__requant_min_max:"
113                        "[-2.4313571][10.584145]\n"));
114       TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
115     }
116 
117     TransformFuncContext context;
118     context.input_names = {};
119     context.output_names = {"dequantize_op"};
120     context.params = {{"min_max_log_file", {min_max_log_file_name}}};
121 
122     GraphDef frozen_graph_def;
123     TF_EXPECT_OK(
124         FreezeRequantizationRanges(graph_def, context, &frozen_graph_def));
125 
126     std::map<string, const NodeDef*> node_map;
127     MapNamesToNodes(frozen_graph_def, &node_map);
128     EXPECT_EQ(0, node_map.count("requantization_range_op"));
129     EXPECT_EQ(1, node_map.count("requantize_op"));
130     const string& min_input =
131         NodeNameFromInput(node_map.at("requantize_op")->input(3));
132     ASSERT_EQ(1, node_map.count(min_input));
133     EXPECT_EQ("Const", node_map.at(min_input)->op());
134     const string& max_input =
135         NodeNameFromInput(node_map.at("requantize_op")->input(4));
136     ASSERT_EQ(1, node_map.count(max_input));
137     EXPECT_EQ("Const", node_map.at(max_input)->op());
138 
139     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
140     TF_ASSERT_OK(original_session->Create(graph_def));
141     std::vector<Tensor> original_outputs;
142     TF_ASSERT_OK(
143         original_session->Run({}, {"dequantize_op"}, {}, &original_outputs));
144 
145     std::unique_ptr<Session> frozen_session(NewSession(SessionOptions()));
146     TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
147     std::vector<Tensor> frozen_outputs;
148     TF_ASSERT_OK(
149         frozen_session->Run({}, {"dequantize_op"}, {}, &frozen_outputs));
150 
151     ASSERT_EQ(original_outputs.size(), frozen_outputs.size());
152     ASSERT_EQ(1, frozen_outputs.size());
153     test::ExpectTensorNear<float>(original_outputs[0], frozen_outputs[0], 0.5);
154   }
155 
TestExtractMinMaxRecords()156   void TestExtractMinMaxRecords() {
157     const string min_max_log_file_name =
158         io::JoinPath(testing::TmpDir(), "min_max_log_file2.txt");
159     {
160       std::unique_ptr<WritableFile> file;
161       TF_ASSERT_OK(
162           Env::Default()->NewWritableFile(min_max_log_file_name, &file));
163       TF_ASSERT_OK(file->Append("Something irrelevant\n"));
164       TF_ASSERT_OK(
165           file->Append("[SomePrefix] "
166                        ";requantization_range_op__print__;__requant_min_max:"
167                        "[-2.4313571][10.584145]\n"));
168       TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
169       TF_ASSERT_OK(file->Append(
170           "[SomeOtherPrefix] "
171           ";other_requantization_range_op__print__;__requant_min_max:"
172           "[-1.0][2.0]\n"));
173       TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
174       TF_ASSERT_OK(
175           file->Append("[SomePrefix] "
176                        ";requantization_range_op__print__;__requant_min_max:"
177                        "[-1.bad][2.0]\n"));
178     }
179     std::vector<MinMaxRecord> records;
180     TF_ASSERT_OK(ExtractMinMaxRecords(min_max_log_file_name, &records));
181     ASSERT_EQ(2, records.size());
182     EXPECT_EQ("requantization_range_op", records[0].name);
183     EXPECT_NEAR(-2.4313571f, records[0].min, 1e-5f);
184     EXPECT_NEAR(10.584145f, records[0].max, 1e-5f);
185     EXPECT_EQ("other_requantization_range_op", records[1].name);
186     EXPECT_NEAR(-1.0f, records[1].min, 1e-5f);
187     EXPECT_NEAR(2.0f, records[1].max, 1e-5f);
188   }
189 };
190 
// Runs the end-to-end transform check implemented on the fixture.
TEST_F(FreezeRequantizationRangesTest, TestFreezeRequantizationRanges) {
  TestFreezeRequantizationRanges();
}
194 
// Runs the log-parsing check implemented on the fixture.
TEST_F(FreezeRequantizationRangesTest, TestExtractMinMaxRecords) {
  TestExtractMinMaxRecords();
}
198 
199 }  // namespace graph_transforms
200 }  // namespace tensorflow
201