• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/math_ops.h"
19 #include "tensorflow/cc/ops/nn_ops.h"
20 #include "tensorflow/cc/ops/sendrecv_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28 
29 namespace tensorflow {
30 namespace graph_transforms {
31 
32 // Declare here, so we don't need a public header.
33 Status FoldBatchNorms(const GraphDef& input_graph_def,
34                       const TransformFuncContext& context,
35                       GraphDef* output_graph_def);
36 
37 class FoldBatchNormsTest : public ::testing::Test {
38  protected:
TestFoldBatchNormsConv2D()39   void TestFoldBatchNormsConv2D() {
40     auto root = tensorflow::Scope::NewRootScope();
41     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
42 
43     Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
44     test::FillValues<float>(
45         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
46                       -5.0f, -3.0f, -6.0f});
47     Output input_op =
48         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
49 
50     Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
51     test::FillValues<float>(&weights_data,
52                             {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
53     Output weights_op =
54         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
55 
56     Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
57                             {1, 1, 1, 1}, "VALID");
58 
59     Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
60     test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
61     Output mul_values_op = Const(root.WithOpName("mul_values"),
62                                  Input::Initializer(mul_values_data));
63 
64     Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
65 
66     GraphDef original_graph_def;
67     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
68 
69     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
70     TF_ASSERT_OK(original_session->Create(original_graph_def));
71     std::vector<Tensor> original_outputs;
72     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
73 
74     GraphDef fused_graph_def;
75     TF_ASSERT_OK(
76         FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
77 
78     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
79     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
80     std::vector<Tensor> fused_outputs;
81     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
82 
83     test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
84 
85     for (const NodeDef& node : fused_graph_def.node()) {
86       EXPECT_NE("Mul", node.op());
87     }
88   }
89 
TestFoldBatchNormsDepthwiseConv2dNative()90   void TestFoldBatchNormsDepthwiseConv2dNative() {
91     auto root = tensorflow::Scope::NewRootScope();
92     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
93 
94     Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
95     test::FillValues<float>(
96         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
97                       -5.0f, -3.0f, -6.0f});
98     Output input_op =
99         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
100 
101     Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
102     test::FillValues<float>(&weights_data,
103                             {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
104     Output weights_op =
105         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
106 
107     Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
108                                            weights_op, {1, 1, 1, 1}, "VALID");
109 
110     Tensor mul_values_data(DT_FLOAT, TensorShape({4}));
111     test::FillValues<float>(&mul_values_data, {2.0f, 3.0f, 4.0f, 5.0f});
112     Output mul_values_op = Const(root.WithOpName("mul_values"),
113                                  Input::Initializer(mul_values_data));
114 
115     Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
116 
117     GraphDef original_graph_def;
118     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
119 
120     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
121     TF_ASSERT_OK(original_session->Create(original_graph_def));
122     std::vector<Tensor> original_outputs;
123     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
124 
125     GraphDef fused_graph_def;
126     TF_ASSERT_OK(
127         FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
128 
129     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
130     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
131     std::vector<Tensor> fused_outputs;
132     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
133 
134     test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
135 
136     for (const NodeDef& node : fused_graph_def.node()) {
137       EXPECT_NE("Mul", node.op());
138     }
139   }
140 
TestFoldBatchNormsConv2DShared()141   void TestFoldBatchNormsConv2DShared() {
142     auto root = tensorflow::Scope::NewRootScope();
143     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
144 
145     Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
146     test::FillValues<float>(
147         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
148                       -5.0f, -3.0f, -6.0f});
149     Output input_op =
150         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
151 
152     Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
153     test::FillValues<float>(&weights_data,
154                             {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
155     Output weights_op =
156         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
157 
158     Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
159                             {1, 1, 1, 1}, "VALID");
160 
161     Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
162     test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
163     Output mul_values_op = Const(root.WithOpName("mul_values"),
164                                  Input::Initializer(mul_values_data));
165 
166     Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
167 
168     Tensor mul_values_data_2(DT_FLOAT, TensorShape({2}));
169     test::FillValues<float>(&mul_values_data_2, {1.0f, 2.0f});
170     Output mul_values_op_2 = Const(root.WithOpName("mul_values_2"),
171                                    Input::Initializer(mul_values_data));
172 
173     Output mul_op_2 =
174         Mul(root.WithOpName("output_2"), conv_op, mul_values_op_2);
175 
176     GraphDef original_graph_def;
177     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
178 
179     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
180     TF_ASSERT_OK(original_session->Create(original_graph_def));
181     std::vector<Tensor> original_outputs;
182     TF_ASSERT_OK(original_session->Run({}, {"output", "output_2"}, {},
183                                        &original_outputs));
184 
185     GraphDef fused_graph_def;
186     TF_ASSERT_OK(FoldBatchNorms(
187         original_graph_def, {{}, {"output", "output_2"}}, &fused_graph_def));
188 
189     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
190     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
191     std::vector<Tensor> fused_outputs;
192     TF_ASSERT_OK(
193         fused_session->Run({}, {"output", "output_2"}, {}, &fused_outputs));
194 
195     test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
196     test::ExpectTensorNear<float>(original_outputs[1], fused_outputs[1], 1e-5);
197   }
198 
TestFoldBatchNormsMatMul()199   void TestFoldBatchNormsMatMul() {
200     auto root = tensorflow::Scope::NewRootScope();
201     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
202 
203     Tensor input_data(DT_FLOAT, TensorShape({6, 2}));
204     test::FillValues<float>(
205         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
206                       -5.0f, -3.0f, -6.0f});
207     Output input_op =
208         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
209 
210     Tensor weights_data(DT_FLOAT, TensorShape({2, 2}));
211     test::FillValues<float>(&weights_data, {1.0f, 2.0f, 0.3f, 0.4f});
212     Output weights_op =
213         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
214 
215     Output matmul_op =
216         MatMul(root.WithOpName("matmul_op"), input_op, weights_op);
217 
218     Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
219     test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
220     Output mul_values_op = Const(root.WithOpName("mul_values"),
221                                  Input::Initializer(mul_values_data));
222 
223     Output mul_op = Mul(root.WithOpName("output"), matmul_op, mul_values_op);
224 
225     GraphDef original_graph_def;
226     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
227 
228     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
229     TF_ASSERT_OK(original_session->Create(original_graph_def));
230     std::vector<Tensor> original_outputs;
231     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
232 
233     GraphDef fused_graph_def;
234     TF_ASSERT_OK(
235         FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
236 
237     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
238     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
239     std::vector<Tensor> fused_outputs;
240     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
241 
242     test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
243 
244     for (const NodeDef& node : fused_graph_def.node()) {
245       EXPECT_NE("Mul", node.op());
246     }
247   }
248 };
249 
// Registers each fixture scenario as its own gtest case. Every Test* method on
// FoldBatchNormsTest needs an entry here, otherwise it is never executed
// (TestFoldBatchNormsConv2DShared was previously missing).
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2D) {
  TestFoldBatchNormsConv2D();
}
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2DShared) {
  TestFoldBatchNormsConv2DShared();
}
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
  TestFoldBatchNormsMatMul();
}
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsDepthwiseConv2dNative) {
  TestFoldBatchNormsDepthwiseConv2dNative();
}
259 
260 }  // namespace graph_transforms
261 }  // namespace tensorflow
262