• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/sendrecv_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/public/session.h"
28 #include "tensorflow/tools/graph_transforms/transform_utils.h"
29 
30 namespace tensorflow {
31 namespace graph_transforms {
32 
33 // Declare here, so we don't need a public header.
34 Status FoldOldBatchNorms(const GraphDef& input_graph_def,
35                          const TransformFuncContext& context,
36                          GraphDef* output_graph_def);
37 
38 class FoldOldBatchNormsTest : public ::testing::Test {
39  protected:
TestFoldOldBatchNorms()40   void TestFoldOldBatchNorms() {
41     auto root = tensorflow::Scope::NewRootScope();
42     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
43 
44     Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
45     test::FillValues<float>(
46         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
47                       -5.0f, -3.0f, -6.0f});
48     Output input_op =
49         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
50 
51     Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
52     test::FillValues<float>(&weights_data,
53                             {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
54     Output weights_op =
55         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
56 
57     Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
58                             {1, 1, 1, 1}, "VALID");
59 
60     Tensor mean_data(DT_FLOAT, TensorShape({2}));
61     test::FillValues<float>(&mean_data, {10.0f, 20.0f});
62     Output mean_op =
63         Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
64 
65     Tensor variance_data(DT_FLOAT, TensorShape({2}));
66     test::FillValues<float>(&variance_data, {0.25f, 0.5f});
67     Output variance_op = Const(root.WithOpName("variance_op"),
68                                Input::Initializer(variance_data));
69 
70     Tensor beta_data(DT_FLOAT, TensorShape({2}));
71     test::FillValues<float>(&beta_data, {0.1f, 0.6f});
72     Output beta_op =
73         Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
74 
75     Tensor gamma_data(DT_FLOAT, TensorShape({2}));
76     test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
77     Output gamma_op =
78         Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
79 
80     GraphDef original_graph_def;
81     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
82 
83     // This is needed because we're trying to convert over a deprecated op which
84     // should only be present in older GraphDef files. Without this we see a
85     // deprecation error.
86     // This is justified because we're trying to test a tool that is expected to
87     // run on legacy files, to help users convert over to less problematic
88     // versions.
89     NodeDef batch_norm_node;
90     batch_norm_node.set_op("BatchNormWithGlobalNormalization");
91     batch_norm_node.set_name("output");
92     AddNodeInput("conv_op", &batch_norm_node);
93     AddNodeInput("mean_op", &batch_norm_node);
94     AddNodeInput("variance_op", &batch_norm_node);
95     AddNodeInput("beta_op", &batch_norm_node);
96     AddNodeInput("gamma_op", &batch_norm_node);
97     SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
98     SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node);
99     SetNodeAttr("scale_after_normalization", false, &batch_norm_node);
100     *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
101     original_graph_def.mutable_versions()->set_producer(8);
102 
103     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
104     TF_ASSERT_OK(original_session->Create(original_graph_def));
105     std::vector<Tensor> original_outputs;
106     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
107 
108     GraphDef fused_graph_def;
109     TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
110                                    &fused_graph_def));
111 
112     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
113     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
114     std::vector<Tensor> fused_outputs;
115     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
116 
117     test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
118 
119     for (const NodeDef& node : fused_graph_def.node()) {
120       EXPECT_NE("BatchNormWithGlobalNormalization", node.op());
121     }
122   }
123 
TestFoldOldBatchNormsAfterDepthwiseConv2dNative()124   void TestFoldOldBatchNormsAfterDepthwiseConv2dNative() {
125     auto root = tensorflow::Scope::NewRootScope();
126     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
127 
128     Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
129     test::FillValues<float>(
130         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
131                       -5.0f, -3.0f, -6.0f});
132     Output input_op =
133         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
134 
135     Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
136     test::FillValues<float>(&weights_data,
137                             {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
138     Output weights_op =
139         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
140 
141     Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
142                                            weights_op, {1, 1, 1, 1}, "VALID");
143 
144     Tensor mean_data(DT_FLOAT, TensorShape({4}));
145     test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f});
146     Output mean_op =
147         Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
148 
149     Tensor variance_data(DT_FLOAT, TensorShape({4}));
150     test::FillValues<float>(&variance_data, {0.25f, 0.5f, 0.75f, 1.0f});
151     Output variance_op = Const(root.WithOpName("variance_op"),
152                                Input::Initializer(variance_data));
153 
154     Tensor beta_data(DT_FLOAT, TensorShape({4}));
155     test::FillValues<float>(&beta_data, {0.1f, 0.6f, 1.1f, 1.6f});
156     Output beta_op =
157         Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
158 
159     Tensor gamma_data(DT_FLOAT, TensorShape({4}));
160     test::FillValues<float>(&gamma_data, {1.0f, 2.0f, 3.0f, 4.0f});
161     Output gamma_op =
162         Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
163 
164     GraphDef original_graph_def;
165     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
166 
167     NodeDef batch_norm_node;
168     batch_norm_node.set_op("BatchNormWithGlobalNormalization");
169     batch_norm_node.set_name("output");
170     AddNodeInput("conv_op", &batch_norm_node);
171     AddNodeInput("mean_op", &batch_norm_node);
172     AddNodeInput("variance_op", &batch_norm_node);
173     AddNodeInput("beta_op", &batch_norm_node);
174     AddNodeInput("gamma_op", &batch_norm_node);
175     SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
176     SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node);
177     SetNodeAttr("scale_after_normalization", false, &batch_norm_node);
178     *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
179     original_graph_def.mutable_versions()->set_producer(8);
180 
181     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
182     TF_ASSERT_OK(original_session->Create(original_graph_def));
183     std::vector<Tensor> original_outputs;
184     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
185 
186     GraphDef fused_graph_def;
187     TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
188                                    &fused_graph_def));
189 
190     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
191     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
192     std::vector<Tensor> fused_outputs;
193     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
194 
195     test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
196 
197     for (const NodeDef& node : fused_graph_def.node()) {
198       EXPECT_NE("BatchNormWithGlobalNormalization", node.op());
199     }
200   }
201 
TestFoldFusedBatchNorms()202   void TestFoldFusedBatchNorms() {
203     auto root = tensorflow::Scope::NewRootScope();
204     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
205 
206     Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
207     test::FillValues<float>(
208         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
209                       -5.0f, -3.0f, -6.0f});
210     Output input_op =
211         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
212 
213     Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
214     test::FillValues<float>(&weights_data,
215                             {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
216     Output weights_op =
217         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
218 
219     Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
220                             {1, 1, 1, 1}, "VALID");
221 
222     Tensor mean_data(DT_FLOAT, TensorShape({2}));
223     test::FillValues<float>(&mean_data, {10.0f, 20.0f});
224     Output mean_op =
225         Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
226 
227     Tensor variance_data(DT_FLOAT, TensorShape({2}));
228     test::FillValues<float>(&variance_data, {0.25f, 0.5f});
229     Output variance_op = Const(root.WithOpName("variance_op"),
230                                Input::Initializer(variance_data));
231 
232     Tensor beta_data(DT_FLOAT, TensorShape({2}));
233     test::FillValues<float>(&beta_data, {0.1f, 0.6f});
234     Output beta_op =
235         Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
236 
237     Tensor gamma_data(DT_FLOAT, TensorShape({2}));
238     test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
239     Output gamma_op =
240         Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
241 
242     GraphDef original_graph_def;
243     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
244 
245     NodeDef batch_norm_node;
246     batch_norm_node.set_op("FusedBatchNorm");
247     batch_norm_node.set_name("output");
248     AddNodeInput("conv_op", &batch_norm_node);
249     AddNodeInput("gamma_op", &batch_norm_node);
250     AddNodeInput("beta_op", &batch_norm_node);
251     AddNodeInput("mean_op", &batch_norm_node);
252     AddNodeInput("variance_op", &batch_norm_node);
253     SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
254     SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
255     SetNodeAttr("is_training", false, &batch_norm_node);
256     *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
257 
258     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
259     TF_ASSERT_OK(original_session->Create(original_graph_def));
260     std::vector<Tensor> original_outputs;
261     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
262 
263     GraphDef fused_graph_def;
264     TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
265                                    &fused_graph_def));
266 
267     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
268     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
269     std::vector<Tensor> fused_outputs;
270     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
271 
272     test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
273 
274     for (const NodeDef& node : fused_graph_def.node()) {
275       EXPECT_NE("FusedBatchNorm", node.op());
276     }
277   }
278 
TestFoldFusedBatchNormsAfterDepthwiseConv2dNative()279   void TestFoldFusedBatchNormsAfterDepthwiseConv2dNative() {
280     auto root = tensorflow::Scope::NewRootScope();
281     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
282 
283     Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
284     test::FillValues<float>(
285         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
286                       -5.0f, -3.0f, -6.0f});
287     Output input_op =
288         Const(root.WithOpName("input_op"), Input::Initializer(input_data));
289 
290     Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
291     test::FillValues<float>(&weights_data,
292                             {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
293     Output weights_op =
294         Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
295 
296     Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
297                                            weights_op, {1, 1, 1, 1}, "VALID");
298 
299     Tensor mean_data(DT_FLOAT, TensorShape({4}));
300     test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f});
301     Output mean_op =
302         Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
303 
304     Tensor variance_data(DT_FLOAT, TensorShape({4}));
305     test::FillValues<float>(&variance_data, {0.25f, 0.5f, 0.75f, 1.0f});
306     Output variance_op = Const(root.WithOpName("variance_op"),
307                                Input::Initializer(variance_data));
308 
309     Tensor beta_data(DT_FLOAT, TensorShape({4}));
310     test::FillValues<float>(&beta_data, {0.1f, 0.6f, 1.1f, 1.6f});
311     Output beta_op =
312         Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
313 
314     Tensor gamma_data(DT_FLOAT, TensorShape({4}));
315     test::FillValues<float>(&gamma_data, {1.0f, 2.0f, 3.0f, 4.0f});
316     Output gamma_op =
317         Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
318 
319     GraphDef original_graph_def;
320     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
321 
322     NodeDef batch_norm_node;
323     batch_norm_node.set_op("FusedBatchNorm");
324     batch_norm_node.set_name("output");
325     AddNodeInput("conv_op", &batch_norm_node);
326     AddNodeInput("gamma_op", &batch_norm_node);
327     AddNodeInput("beta_op", &batch_norm_node);
328     AddNodeInput("mean_op", &batch_norm_node);
329     AddNodeInput("variance_op", &batch_norm_node);
330     SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
331     SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
332     SetNodeAttr("is_training", false, &batch_norm_node);
333     *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
334 
335     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
336     TF_ASSERT_OK(original_session->Create(original_graph_def));
337     std::vector<Tensor> original_outputs;
338     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
339 
340     GraphDef fused_graph_def;
341     TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
342                                    &fused_graph_def));
343 
344     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
345     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
346     std::vector<Tensor> fused_outputs;
347     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
348 
349     test::ExpectClose(original_outputs[0], fused_outputs[0], /*atol=*/2e-5,
350                       /*rtol=*/2e-5);
351 
352     for (const NodeDef& node : fused_graph_def.node()) {
353       EXPECT_NE("FusedBatchNorm", node.op());
354     }
355   }
356 
TestFoldFusedBatchNormsWithConcat(const bool split)357   void TestFoldFusedBatchNormsWithConcat(const bool split) {
358     auto root = tensorflow::Scope::NewRootScope();
359     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
360 
361     // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2.
362     auto input_shape =
363         split ? TensorShape({1, 1, 6, 2}) : TensorShape({1, 1, 12, 1});
364     Tensor input_data(DT_FLOAT, input_shape);
365     test::FillValues<float>(
366         &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
367                       -5.0f, -3.0f, -6.0f});
368 
369     Output input0_op =
370         Const(root.WithOpName("input_op0"), Input::Initializer(input_data));
371     // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2.
372     // The final output shape of concat is always {1, 2, 2, 2}.
373     auto weight_shape =
374         split ? TensorShape({1, 2, 2, 1}) : TensorShape({1, 2, 1, 2});
375     Tensor weights0_data(DT_FLOAT, weight_shape);
376     test::FillValues<float>(&weights0_data, {1.0f, 2.0f, 3.0f, 4.0f});
377     Output weights0_op = Const(root.WithOpName("weights1_op"),
378                                Input::Initializer(weights0_data));
379     Output conv0_op = Conv2D(root.WithOpName("conv1_op"), input0_op,
380                              weights0_op, {1, 1, 1, 1}, "VALID");
381 
382     Output input1_op =
383         Const(root.WithOpName("input1_op"), Input::Initializer(input_data));
384     Tensor weights1_data(DT_FLOAT, weight_shape);
385     test::FillValues<float>(&weights1_data, {1.0f, 2.0f, 3.0f, 4.0f});
386     Output weights1_op = Const(root.WithOpName("weights1_op"),
387                                Input::Initializer(weights1_data));
388     Output conv1_op = Conv2D(root.WithOpName("conv1_op"), input1_op,
389                              weights1_op, {1, 1, 1, 1}, "VALID");
390 
391     Tensor shape_tensor(DT_INT32, TensorShape({}));
392     // Concat at dim 3 if split; otherwise, concat at dim 2.
393     int32_t concat_axis = split ? 3 : 2;
394     test::FillValues<int32>(&shape_tensor, {concat_axis});
395     Output shape_op =
396         Const(root.WithOpName("shape_op"), Input::Initializer(shape_tensor));
397     Output concat_op =
398         Concat(root.WithOpName("concat_op"), {conv0_op, conv1_op}, shape_op);
399 
400     Tensor mean_data(DT_FLOAT, TensorShape({2}));
401     test::FillValues<float>(&mean_data, {10.0f, 20.0f});
402     Output mean_op =
403         Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
404 
405     Tensor variance_data(DT_FLOAT, TensorShape({2}));
406     test::FillValues<float>(&variance_data, {0.25f, 0.5f});
407     Output variance_op = Const(root.WithOpName("variance_op"),
408                                Input::Initializer(variance_data));
409 
410     Tensor beta_data(DT_FLOAT, TensorShape({2}));
411     test::FillValues<float>(&beta_data, {0.1f, 0.6f});
412     Output beta_op =
413         Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
414 
415     Tensor gamma_data(DT_FLOAT, TensorShape({2}));
416     test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
417     Output gamma_op =
418         Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
419 
420     GraphDef original_graph_def;
421     TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
422 
423     NodeDef batch_norm_node;
424     batch_norm_node.set_op("FusedBatchNorm");
425     batch_norm_node.set_name("output");
426     AddNodeInput("concat_op", &batch_norm_node);
427     AddNodeInput("gamma_op", &batch_norm_node);
428     AddNodeInput("beta_op", &batch_norm_node);
429     AddNodeInput("mean_op", &batch_norm_node);
430     AddNodeInput("variance_op", &batch_norm_node);
431     SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
432     SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
433     SetNodeAttr("is_training", false, &batch_norm_node);
434     *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
435 
436     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
437     TF_ASSERT_OK(original_session->Create(original_graph_def));
438     std::vector<Tensor> original_outputs;
439     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
440 
441     GraphDef fused_graph_def;
442     TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
443                                    &fused_graph_def));
444 
445     std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
446     TF_ASSERT_OK(fused_session->Create(fused_graph_def));
447     std::vector<Tensor> fused_outputs;
448     TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
449 
450     test::ExpectClose(original_outputs[0], fused_outputs[0]);
451 
452     for (const NodeDef& node : fused_graph_def.node()) {
453       EXPECT_NE("FusedBatchNorm", node.op());
454     }
455   }
456 };
457 
TestFoldFusedBatchNormsWithBatchToSpace()458 void TestFoldFusedBatchNormsWithBatchToSpace() {
459   auto root = tensorflow::Scope::NewRootScope();
460   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
461 
462   Tensor input_data(DT_FLOAT, TensorShape({2, 1, 3, 2}));
463   test::FillValues<float>(
464       &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
465                     -5.0f, -3.0f, -6.0f});
466   Output input_op =
467       Const(root.WithOpName("input_op"), Input::Initializer(input_data));
468 
469   Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
470   test::FillValues<float>(&weights_data,
471                           {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
472   Output weights_op =
473       Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
474 
475   Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
476                           {1, 1, 1, 1}, "VALID");
477 
478   Tensor block_shape_data(DT_INT32, TensorShape({2}));
479   test::FillValues<int32>(&block_shape_data, {1, 2});
480   Output block_shape_op = Const(root.WithOpName("block_shape_op"),
481                                 Input::Initializer(block_shape_data));
482 
483   Tensor crops_data(DT_INT32, TensorShape({2, 2}));
484   test::FillValues<int32>(&crops_data, {0, 0, 0, 1});
485   Output crops_op =
486       Const(root.WithOpName("crops_op"), Input::Initializer(crops_data));
487 
488   Output batch_to_space_op =
489       BatchToSpaceND(root.WithOpName("batch_to_space_op"), conv_op,
490                      block_shape_op, crops_data);
491 
492   Tensor mean_data(DT_FLOAT, TensorShape({2}));
493   test::FillValues<float>(&mean_data, {10.0f, 20.0f});
494   Output mean_op =
495       Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
496 
497   Tensor variance_data(DT_FLOAT, TensorShape({2}));
498   test::FillValues<float>(&variance_data, {0.25f, 0.5f});
499   Output variance_op =
500       Const(root.WithOpName("variance_op"), Input::Initializer(variance_data));
501 
502   Tensor beta_data(DT_FLOAT, TensorShape({2}));
503   test::FillValues<float>(&beta_data, {0.1f, 0.6f});
504   Output beta_op =
505       Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
506 
507   Tensor gamma_data(DT_FLOAT, TensorShape({2}));
508   test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
509   Output gamma_op =
510       Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
511 
512   GraphDef original_graph_def;
513   TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
514 
515   NodeDef batch_norm_node;
516   batch_norm_node.set_op("FusedBatchNorm");
517   batch_norm_node.set_name("output");
518   AddNodeInput("batch_to_space_op", &batch_norm_node);
519   AddNodeInput("gamma_op", &batch_norm_node);
520   AddNodeInput("beta_op", &batch_norm_node);
521   AddNodeInput("mean_op", &batch_norm_node);
522   AddNodeInput("variance_op", &batch_norm_node);
523   SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
524   SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
525   SetNodeAttr("is_training", false, &batch_norm_node);
526   *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
527 
528   std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
529   TF_ASSERT_OK(original_session->Create(original_graph_def));
530   std::vector<Tensor> original_outputs;
531   TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
532 
533   GraphDef fused_graph_def;
534   TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
535                                  &fused_graph_def));
536 
537   std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
538   TF_ASSERT_OK(fused_session->Create(fused_graph_def));
539   std::vector<Tensor> fused_outputs;
540   TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
541 
542   test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
543 
544   for (const NodeDef& node : fused_graph_def.node()) {
545     EXPECT_NE("FusedBatchNormWithBatchToSpace", node.op());
546   }
547 }
548 
// Conv2D followed by the deprecated BatchNormWithGlobalNormalization op.
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) {
  TestFoldOldBatchNorms();
}
552 
// Conv2D followed by an inference-mode FusedBatchNorm op.
TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) {
  TestFoldFusedBatchNorms();
}
556 
TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) {
  // split=true concatenates along axis 3 (concat_axis = split ? 3 : 2), so
  // the BatchNorm weights and offsets must be split between the two input
  // convolutions before folding. The original comments here were swapped
  // relative to the `split` parameter's meaning.
  TestFoldFusedBatchNormsWithConcat(/*split=*/true);
  // split=false concatenates along axis 2, so each conv carries all the
  // channels and the full weights/offsets are folded into each conv's
  // weights.
  TestFoldFusedBatchNormsWithConcat(/*split=*/false);
}
565 
// FusedBatchNorm separated from the Conv2D by a BatchToSpaceND op.
TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithBatchToSpace) {
  TestFoldFusedBatchNormsWithBatchToSpace();
}
569 
// Deprecated batch-norm op following a DepthwiseConv2dNative.
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNormsAfterDepthwiseConv2dNative) {
  TestFoldOldBatchNormsAfterDepthwiseConv2dNative();
}
573 
// FusedBatchNorm following a DepthwiseConv2dNative.
TEST_F(FoldOldBatchNormsTest,
       TestFoldFusedBatchNormsAfterDepthwiseConv2dNative) {
  TestFoldFusedBatchNormsAfterDepthwiseConv2dNative();
}
578 
579 }  // namespace graph_transforms
580 }  // namespace tensorflow
581