• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
17 #include "tensorflow/cc/ops/standard_ops.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/tensor_testutil.h"
20 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/utils.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/protobuf/device_properties.pb.h"
27 
28 namespace tensorflow {
29 namespace grappler {
30 namespace {
31 
32 class LayoutOptimizerTest : public ::testing::Test {
33  protected:
SetUp()34   void SetUp() override {
35     DeviceProperties device_properties;
36     device_properties.set_type("GPU");
37     device_properties.mutable_environment()->insert({"architecture", "6"});
38     virtual_cluster_.reset(new VirtualCluster({{"/GPU:0", device_properties}}));
39   }
40 
SimpleConv2D(tensorflow::Scope * s,int input_size,int filter_size,const string & padding)41   Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
42                       const string& padding) {
43     return SimpleConv2D(s, input_size, filter_size, padding, "");
44   }
45 
SimpleConv2D(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,const string & device)46   Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
47                       const string& padding, const string& device) {
48     int batch_size = 8;
49     int input_height = input_size;
50     int input_width = input_size;
51     int input_depth = 3;
52     int filter_count = 2;
53     int stride = 1;
54     TensorShape input_shape(
55         {batch_size, input_height, input_width, input_depth});
56     Tensor input_data(DT_FLOAT, input_shape);
57     test::FillIota<float>(&input_data, 1.0f);
58     Output input =
59         ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
60 
61     TensorShape filter_shape(
62         {filter_size, filter_size, input_depth, filter_count});
63     Tensor filter_data(DT_FLOAT, filter_shape);
64     test::FillIota<float>(&filter_data, 1.0f);
65     Output filter =
66         ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
67 
68     Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
69                               filter, {1, stride, stride, 1}, padding);
70     return conv;
71   }
72 
SimpleConv2DBackpropInput(tensorflow::Scope * s,int input_size,int filter_size,const string & padding)73   Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
74                                    int filter_size, const string& padding) {
75     return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true);
76   }
77 
SimpleConv2DBackpropInput(tensorflow::Scope * s,int input_size,int filter_size,const string & padding,bool const_input_size)78   Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
79                                    int filter_size, const string& padding,
80                                    bool const_input_size) {
81     int batch_size = 128;
82     int input_height = input_size;
83     int input_width = input_size;
84     int input_depth = 3;
85     int filter_count = 2;
86     int stride = 1;
87     TensorShape input_sizes_shape({4});
88     Tensor input_data(DT_INT32, input_sizes_shape);
89     test::FillValues<int>(&input_data,
90                           {batch_size, input_height, input_width, input_depth});
91     Output input_sizes =
92         ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
93 
94     TensorShape filter_shape(
95         {filter_size, filter_size, input_depth, filter_count});
96     Tensor filter_data(DT_FLOAT, filter_shape);
97     test::FillIota<float>(&filter_data, 1.0f);
98     Output filter =
99         ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
100 
101     int output_height = input_height;
102     int output_width = input_width;
103     TensorShape output_shape(
104         {batch_size, output_height, output_width, filter_count});
105     Tensor output_data(DT_FLOAT, output_shape);
106     test::FillIota<float>(&output_data, 1.0f);
107     Output output =
108         ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
109 
110     Output conv_backprop_input;
111     Output input_sizes_i =
112         ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
113     if (const_input_size) {
114       conv_backprop_input = ops::Conv2DBackpropInput(
115           s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
116           {1, stride, stride, 1}, padding);
117     } else {
118       conv_backprop_input = ops::Conv2DBackpropInput(
119           s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
120           {1, stride, stride, 1}, padding);
121     }
122     return conv_backprop_input;
123   }
124 
GetAttrValue(const NodeDef & node)125   Tensor GetAttrValue(const NodeDef& node) {
126     Tensor tensor;
127     CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
128     return tensor;
129   }
130 
SimpleFusedBatchNormGrad(tensorflow::Scope * s,bool is_training)131   Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
132     int batch_size = 16;
133     int input_height = 8;
134     int input_width = 8;
135     int input_channels = 3;
136     TensorShape shape({batch_size, input_height, input_width, input_channels});
137     Tensor data(DT_FLOAT, shape);
138     test::FillIota<float>(&data, 1.0f);
139     Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
140     Output y_backprop =
141         ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));
142 
143     TensorShape shape_vector({input_channels});
144     Tensor data_vector(DT_FLOAT, shape_vector);
145     test::FillIota<float>(&data_vector, 2.0f);
146     Output scale =
147         ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
148     Output reserve1 =
149         ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
150     Output reserve2 =
151         ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));
152 
153     ops::FusedBatchNormGrad::Attrs attrs;
154     attrs.is_training_ = is_training;
155     auto output =
156         ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
157                                 x, scale, reserve1, reserve2, attrs);
158     return output.x_backprop;
159   }
160 
161   std::unique_ptr<VirtualCluster> virtual_cluster_;
162 };
163 
// With constant input sizes, the optimized graph must feed input 0 of
// Conv2DBackpropInput from a node named
// "Conv2DBackpropInput-0-LayoutOptimizer" whose value is {128, 3, 7, 7},
// i.e. the original {128, 7, 7, 3} with the last entry moved to position 1.
TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  const string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
  // Fatal gtest assertions instead of CHECK: a missing node fails this test
  // rather than aborting the whole test binary.
  auto* input_sizes_node = node_map.GetNode(input_name);
  ASSERT_TRUE(input_sizes_node != nullptr);
  auto* conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  ASSERT_TRUE(conv2d_backprop_node != nullptr);
  EXPECT_EQ(input_name, conv2d_backprop_node->input(0));
  auto input_sizes = GetAttrValue(*input_sizes_node);
  Tensor input_sizes_expected(DT_INT32, {4});
  test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
  test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
}
186 
// With non-constant input sizes (routed through "InputSizesIdentity"), input 0
// of Conv2DBackpropInput must come from a DataFormatVecPermute node named
// "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer" that itself
// reads from "InputSizesIdentity".
TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  const string permute_name =
      "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer";
  // Fatal gtest assertions instead of CHECK: a missing node fails this test
  // rather than aborting the whole test binary.
  auto* conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  ASSERT_TRUE(conv2d_backprop_node != nullptr);
  EXPECT_EQ(conv2d_backprop_node->input(0), permute_name);
  auto* input_sizes_node = node_map.GetNode(permute_name);
  ASSERT_TRUE(input_sizes_node != nullptr);
  EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
  EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
}
208 
// A 1x1 filter over a 2x2 input with SAME padding: the optimized graph must
// contain the node "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer".
TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode(
                  "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer") != nullptr);
}
221 
// A filter larger than 1x1 (2x2 over a 4x4 input, SAME padding): the optimized
// graph must contain the node "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer".
//
// This test previously passed filter_size == 1, making it a byte-for-byte
// duplicate of FilterSizeIsOne that never exercised the case its name
// describes. The (4, 2, "SAME") configuration is not used by any other
// transpose-presence test in this file.
TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode(
                  "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer") != nullptr);
}
234 
// Filter and input have the same spatial size (2x2) with VALID padding: the
// optimized graph must contain the node
// "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer".
TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode(
                  "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer") != nullptr);
}
247 
// Filter and input have the same spatial size (2x2) with SAME padding: the
// optimized graph must contain the node
// "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer".
TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode(
                  "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer") != nullptr);
}
260 
// Filter (2x2) smaller than the input (4x4) with VALID padding: the optimized
// graph must contain the node "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer".
TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode(
                  "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer") != nullptr);
}
273 
// Pad fed by a Conv2D: after optimization "p" must read directly from
// "Conv2D", and its paddings must come from a new node "p-1-LayoutOptimizer"
// whose rows are the original {1,2},{3,4},{5,6},{7,8} reordered to
// {1,2},{7,8},{3,4},{5,6} (last row moved to position 1).
TEST_F(LayoutOptimizerTest, Pad) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
  auto p = ops::Pad(s.WithOpName("p"), conv, c);
  auto o = ops::Identity(s.WithOpName("o"), p);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);

  auto* pad = node_map.GetNode("p");
  ASSERT_TRUE(pad != nullptr);
  EXPECT_EQ(pad->input(0), "Conv2D");

  // Fatal assertions: the lines that follow dereference pad_const and call
  // at() on its attr map, so a missing node or attr must stop the test here.
  auto* pad_const = node_map.GetNode("p-1-LayoutOptimizer");
  ASSERT_TRUE(pad_const != nullptr);
  ASSERT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
  Tensor tensor;
  // Read-only access; no need for mutable_attr().
  EXPECT_TRUE(tensor.FromProto(pad_const->attr().at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4, 2});
  test::FillValues<int>(&tensor_expected, {1, 2, 7, 8, 3, 4, 5, 6});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}
300 
// Chain Conv2D -> i1 -> i2 -> i3 with the NodeDefs of i1 and i2 swapped in the
// GraphDef, so the graph is no longer stored in topological order. After
// optimization i2 must still read directly from i1.
TEST_F(LayoutOptimizerTest, Connectivity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto i3 = ops::Identity(s.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in the
  // middle are layout agnostic). If the graph is already in topological order,
  // the problem is easier, where layout optimizer only needs to check
  // single-hop connectivity.
  NodeMap node_map_original(&item.graph);
  auto* node_i1 = node_map_original.GetNode("i1");
  auto* node_i2 = node_map_original.GetNode("i2");
  ASSERT_TRUE(node_i1 != nullptr);
  ASSERT_TRUE(node_i2 != nullptr);
  node_i2->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map_output(&output);
  auto* node_i2_output = node_map_output.GetNode("i2");
  ASSERT_TRUE(node_i2_output != nullptr);
  // Layout optimizer should process i2, as it detects i2 is connected with the
  // Conv2D node two hops away. Similarly i1 is processed as well, as i1 is
  // directly connected to the Conv2D node. The two added transposes between
  // i1 and i2 should cancel each other, and as a result i2 is directly
  // connected to i1.
  EXPECT_EQ(node_i2_output->input(0), "i1");
}
330 
// Chain Conv2D -> i1 -> i2 -> sub(scalar, .) -> i3 -> i4 -> i5 ->
// mul(scalar, .) -> i6, with the NodeDefs of "mul" and "i1" swapped so the
// GraphDef is not in topological order. After optimization "mul" must still
// read from "scalar_mul" and "i5", in that order.
TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
  auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
  auto i3 = ops::Identity(s.WithOpName("i3"), sub);
  auto i4 = ops::Identity(s.WithOpName("i4"), i3);
  auto i5 = ops::Identity(s.WithOpName("i5"), i4);
  auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
  auto i6 = ops::Identity(s.WithOpName("i6"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in the
  // middle are layout agnostic). If the graph is already in topological order,
  // the problem is easier, where layout optimizer only needs to check
  // single-hop connectivity.
  NodeMap node_map_original(&item.graph);
  auto* node_i1 = node_map_original.GetNode("i1");
  auto* node_mul = node_map_original.GetNode("mul");
  ASSERT_TRUE(node_i1 != nullptr);
  ASSERT_TRUE(node_mul != nullptr);
  node_mul->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map_output(&output);
  auto* mul_node = node_map_output.GetNode("mul");
  ASSERT_TRUE(mul_node != nullptr);
  EXPECT_EQ(mul_node->input(0), "scalar_mul");
  EXPECT_EQ(mul_node->input(1), "i5");
}
363 
// When "Conv2D" is listed in item.fetch, its data_format attr must still be
// "NHWC" after optimization.
TEST_F(LayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // Require success so that an unchanged layout cannot be the by-product of an
  // optimizer error whose Status was previously dropped.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* conv_node = node_map.GetNode("Conv2D");
  ASSERT_TRUE(conv_node != nullptr);
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}
378 
// A Conv2D built with an empty device string must end up with
// data_format "NCHW".
TEST_F(LayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* conv_node = node_map.GetNode("Conv2D");
  ASSERT_TRUE(conv_node != nullptr);
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}
392 
// A Conv2D placed on "/job:w/replica:0/task:0/device:gpu:0" must end up with
// data_format "NCHW".
TEST_F(LayoutOptimizerTest, GPUDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* conv_node = node_map.GetNode("Conv2D");
  ASSERT_TRUE(conv_node != nullptr);
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}
407 
// A Conv2D placed on "/job:w/replica:0/task:0/device:cpu:0" (lowercase "cpu")
// must keep data_format "NHWC".
TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // Require success so that an unchanged layout cannot be the by-product of an
  // optimizer error whose Status was previously dropped.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* conv_node = node_map.GetNode("Conv2D");
  ASSERT_TRUE(conv_node != nullptr);
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}
422 
// A Conv2D placed on "/CPU:0" (uppercase "CPU") must keep data_format "NHWC".
TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // Require success so that an unchanged layout cannot be the by-product of an
  // optimizer error whose Status was previously dropped.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* conv_node = node_map.GetNode("Conv2D");
  ASSERT_TRUE(conv_node != nullptr);
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}
436 
// FusedBatchNormGrad with is_training == true must end up with
// data_format "NCHW".
TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  // Named for what it is; the old local was called conv_node.
  auto* bn_grad_node = node_map.GetNode("FusedBatchNormGrad");
  ASSERT_TRUE(bn_grad_node != nullptr);
  EXPECT_EQ(bn_grad_node->attr().at({"data_format"}).s(), "NCHW");
}
450 
// FusedBatchNormGrad with is_training == false must keep data_format "NHWC".
TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // Require success so that an unchanged layout cannot be the by-product of an
  // optimizer error whose Status was previously dropped.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  // Named for what it is; the old local was called conv_node.
  auto* bn_grad_node = node_map.GetNode("FusedBatchNormGrad");
  ASSERT_TRUE(bn_grad_node != nullptr);
  EXPECT_EQ(bn_grad_node->attr().at({"data_format"}).s(), "NHWC");
}
464 
// Split along constant axis 3: after optimization the split must take its
// axis from a new Const "split-0-LayoutOptimizer" with value 1 and its data
// directly from "Conv2D".
TEST_F(LayoutOptimizerTest, SplitDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 3, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* split_node = node_map.GetNode("split");
  ASSERT_TRUE(split_node != nullptr);
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto* split_const = node_map.GetNode("split-0-LayoutOptimizer");
  ASSERT_TRUE(split_const != nullptr);
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
}
484 
// Split along constant axis 1: after optimization the split must take its
// axis from a new Const "split-0-LayoutOptimizer" with value 2 and its data
// directly from "Conv2D".
TEST_F(LayoutOptimizerTest, SplitDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 6, 2, "SAME");
  auto c = ops::Const(s.WithOpName("c"), 1, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* split_node = node_map.GetNode("split");
  ASSERT_TRUE(split_node != nullptr);
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto* split_const = node_map.GetNode("split-0-LayoutOptimizer");
  ASSERT_TRUE(split_const != nullptr);
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
}
504 
// Split along constant axis 2: after optimization the split must take its
// axis from a new Const "split-0-LayoutOptimizer" with value 3 and its data
// directly from "Conv2D".
TEST_F(LayoutOptimizerTest, SplitDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 2, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* split_node = node_map.GetNode("split");
  ASSERT_TRUE(split_node != nullptr);
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto* split_const = node_map.GetNode("split-0-LayoutOptimizer");
  ASSERT_TRUE(split_const != nullptr);
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
}
524 
// Split along constant axis 0: after optimization the split must take its
// axis from a new Const "split-0-LayoutOptimizer" whose value is still 0 and
// its data directly from "Conv2D".
TEST_F(LayoutOptimizerTest, SplitDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* split_node = node_map.GetNode("split");
  ASSERT_TRUE(split_node != nullptr);
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto* split_const = node_map.GetNode("split-0-LayoutOptimizer");
  ASSERT_TRUE(split_const != nullptr);
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
}
544 
// Split whose axis comes from an Identity ("i1") rather than a Const: after
// optimization the axis input must be a DataFormatDimMap node named
// "split-0-DimMapNHWCToNCHW-LayoutOptimizer" that reads from "i1".
TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto i1 = ops::Identity(s.WithOpName("i1"), c);
  auto split = ops::Split(s.WithOpName("split"), i1, conv, 2);
  auto i2 = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  const string dim_map_name = "split-0-DimMapNHWCToNCHW-LayoutOptimizer";
  auto* split_node = node_map.GetNode("split");
  ASSERT_TRUE(split_node != nullptr);
  EXPECT_EQ(split_node->input(0), dim_map_name);
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto* map_node = node_map.GetNode(dim_map_name);
  ASSERT_TRUE(map_node != nullptr);
  EXPECT_EQ(map_node->op(), "DataFormatDimMap");
  EXPECT_EQ(map_node->input(0), "i1");
}
565 
// The same split output port ("split:1") feeds three inputs of one Concat.
// After optimization all three data inputs must still be "split:1", and the
// axis input must be a new node "concat-3-LayoutOptimizer" with value 1
// (the original axis constant is 3).
TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat =
      ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* concat_node = node_map.GetNode("concat");
  ASSERT_TRUE(concat_node != nullptr);
  EXPECT_EQ(concat_node->input(0), "split:1");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "split:1");
  EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
  auto* concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
  ASSERT_TRUE(concat_dim != nullptr);
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}
588 
// Concat along constant axis 1: after optimization the data inputs must be
// "split" and "split:1", and the axis input must be a new node
// "concat-2-LayoutOptimizer" with value 2.
TEST_F(LayoutOptimizerTest, ConcatDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* concat_node = node_map.GetNode("concat");
  ASSERT_TRUE(concat_node != nullptr);
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto* concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  ASSERT_TRUE(concat_dim != nullptr);
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
}
609 
// Concat whose axis comes from an Identity ("i") rather than a Const: after
// optimization the axis input must be a DataFormatDimMap node named
// "concat-2-DimMapNHWCToNCHW-LayoutOptimizer" that reads from "i".
TEST_F(LayoutOptimizerTest, ConcatNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto i = ops::Identity(s.WithOpName("i"), axis);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, i);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  const string dim_map_name = "concat-2-DimMapNHWCToNCHW-LayoutOptimizer";
  auto* concat_node = node_map.GetNode("concat");
  ASSERT_TRUE(concat_node != nullptr);
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), dim_map_name);
  auto* concat_dim = node_map.GetNode(dim_map_name);
  ASSERT_TRUE(concat_dim != nullptr);
  EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
  EXPECT_EQ(concat_dim->input(0), "i");
}
633 
// Concat along constant axis 2: after optimization the data inputs must be
// "split" and "split:1", and the axis input must be a new node
// "concat-2-LayoutOptimizer" with value 3.
TEST_F(LayoutOptimizerTest, ConcatDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 2);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* concat_node = node_map.GetNode("concat");
  ASSERT_TRUE(concat_node != nullptr);
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto* concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  ASSERT_TRUE(concat_dim != nullptr);
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
}
654 
// Concat along constant axis 0: after optimization the data inputs must be
// "split" and "split:1", and the axis input must be a new node
// "concat-2-LayoutOptimizer" whose value is still 0.
TEST_F(LayoutOptimizerTest, ConcatDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 0);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  // The returned Status used to be dropped; fail the test if Optimize errors.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  NodeMap node_map(&output);
  auto* concat_node = node_map.GetNode("concat");
  ASSERT_TRUE(concat_node != nullptr);
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto* concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  ASSERT_TRUE(concat_dim != nullptr);
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
}
675 
// Conv2D -> Split -> Concat, both using the constant axis 3 (C in NHWC).
// After optimization the axis input of "concat" is expected to be the node
// "concat-2-LayoutOptimizer", whose value is 1 (C in NCHW).
TEST_F(LayoutOptimizerTest, ConcatDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* concat_node = node_map.GetNode("concat");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(concat_node != nullptr);
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");

  const NodeDef* concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  ASSERT_TRUE(concat_dim != nullptr);
  EXPECT_EQ(concat_dim->attr().at("value").tensor().int_val(0), 1);
}
696 
// Builds Conv2D -> Sum over the reduction indices {0, 1, 2} and runs the
// layout optimizer on it. Nothing is asserted about the optimized graph for
// now; the intended checks are kept below, disabled, per the TODO.
TEST_F(LayoutOptimizerTest, Sum) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto conv_out = SimpleConv2D(&root, 4, 2, "VALID");
  auto axes = ops::Const(root.WithOpName("reduction_indices"), {0, 1, 2}, {3});
  auto reduced = ops::Sum(root.WithOpName("sum"), conv_out, axes);
  auto sink = ops::Identity(root.WithOpName("o"), reduced);

  GrapplerItem graph_item;
  TF_CHECK_OK(root.ToGraphDef(&graph_item.graph));

  LayoutOptimizer layout_optimizer;
  GraphDef optimized;
  Status status =
      layout_optimizer.Optimize(virtual_cluster_.get(), graph_item, &optimized);

  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
  // because of the worse performance in some cases.
  //
  // NodeMap node_map(&output);
  // auto sum_node = node_map.GetNode("sum");
  // EXPECT_EQ(sum_node->input(0), "Conv2D");
  // EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
  // auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
  // Tensor tensor;
  // EXPECT_TRUE(
  //     tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
  // Tensor tensor_expected(DT_INT32, {3});
  // test::FillValues<int>(&tensor_expected, {0, 2, 3});
  // test::ExpectTensorEqual<int>(tensor_expected, tensor);
}
725 
// Mul(scalar, conv): after optimization "mul" is expected to still read the
// scalar constant as input 0 and "Conv2D" directly as input 1.
TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* mul_node = node_map.GetNode("mul");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(mul_node != nullptr);
  EXPECT_EQ(mul_node->input(0), "scalar");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
}
742 
// Mul(conv, scalar): after optimization "mul" is expected to read "Conv2D"
// directly as input 0 and the scalar constant as input 1.
TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* mul_node = node_map.GetNode("mul");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(mul_node != nullptr);
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "scalar");
}
759 
// Mul(conv, placeholder of unknown rank): "mul" is expected to be left in the
// original layout, i.e. to read Conv2D through an NCHW-to-NHWC transpose and
// the placeholder unchanged.
TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto unknown_rank =
      ops::Placeholder(s.WithOpName("unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape()));
  Output c = ops::Const(s.WithOpName("c"), 3.0f, {8, 2, 2, 2});
  Output mul = ops::Mul(s.WithOpName("mul"), conv, unknown_rank);
  auto o = ops::AddN(s.WithOpName("o"), {mul, c});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* mul_node = node_map.GetNode("mul");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(mul_node != nullptr);
  // Node mul should not be processed by layout optimizer, because one of its
  // inputs is of unknown rank.
  EXPECT_EQ(mul_node->input(0),
            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "unknown");
}
782 
// Mul(conv, Identity(conv)): after optimization "mul" is expected to read
// "Conv2D" and "i" directly, with no node inserted in between.
TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* mul_node = node_map.GetNode("mul");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(mul_node != nullptr);
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "i");
}
799 
// Mul(conv, vector of 2 elements): the vector operand (input 1) is expected to
// be routed through "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer", and the
// accompanying shape constant is expected to hold {1, 2, 1, 1}.
TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* mul_node = node_map.GetNode("mul");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(mul_node != nullptr);
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");

  const NodeDef* mul_const =
      node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
  ASSERT_TRUE(mul_const != nullptr);
  Tensor tensor;
  // Read-only access: attr() is sufficient, nothing here mutates the node.
  EXPECT_TRUE(tensor.FromProto(mul_const->attr().at("value").tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}
823 
// Mul(vector of 2 elements, conv): the vector operand (input 0) is expected to
// be routed through "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer", and the
// accompanying shape constant is expected to hold {1, 2, 1, 1}.
TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* mul_node = node_map.GetNode("mul");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(mul_node != nullptr);
  EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "Conv2D");

  const NodeDef* mul_const =
      node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
  ASSERT_TRUE(mul_const != nullptr);
  Tensor tensor;
  // Read-only access: attr() is sufficient, nothing here mutates the node.
  EXPECT_TRUE(tensor.FromProto(mul_const->attr().at("value").tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}
847 
// Slice(conv, const begin, const size): the begin/size inputs are expected to
// be replaced by new constants whose values are the originals permuted from
// NHWC order to NCHW order ({0,2,3,1} -> {0,1,2,3}, {4,1,2,4} -> {4,4,1,2}).
TEST_F(LayoutOptimizerTest, SliceConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* slice_node = node_map.GetNode("slice");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(slice_node != nullptr);
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");

  const NodeDef* begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
  ASSERT_TRUE(begin_const != nullptr);
  Tensor begin_tensor;
  // Read-only access: attr() is sufficient, nothing here mutates the node.
  EXPECT_TRUE(
      begin_tensor.FromProto(begin_const->attr().at("value").tensor()));
  Tensor begin_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
  test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);

  const NodeDef* size_const = node_map.GetNode("slice-2-LayoutOptimizer");
  ASSERT_TRUE(size_const != nullptr);
  Tensor size_tensor;
  EXPECT_TRUE(size_tensor.FromProto(size_const->attr().at("value").tensor()));
  Tensor size_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
  test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
}
882 
// Slice(conv, Identity(begin), Identity(size)): because begin/size are not
// direct constants, each is expected to be routed through a
// DataFormatVecPermute node that reads the corresponding Identity.
TEST_F(LayoutOptimizerTest, SliceNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto isize = ops::Identity(s.WithOpName("isize"), size);
  auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* slice_node = node_map.GetNode("slice");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(slice_node != nullptr);
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1),
            "slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2),
            "slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");

  const NodeDef* perm1 =
      node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_TRUE(perm1 != nullptr);
  EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm1->input(0), "ibegin");

  const NodeDef* perm2 =
      node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  ASSERT_TRUE(perm2 != nullptr);
  // Check perm2's own op type; this previously re-checked perm1 by mistake,
  // leaving the op of the size permutation unverified.
  EXPECT_EQ(perm2->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm2->input(0), "isize");
}
911 
// The graph contains a node named "AlreadyApplied-LayoutOptimizer"; running
// the optimizer on it is expected to fail with an InvalidArgument status.
TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto marked_const =
      ops::Const(root.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
  auto product = ops::Mul(root.WithOpName("mul"), marked_const, marked_const);
  auto sink = ops::Identity(root.WithOpName("o"), product);

  GrapplerItem graph_item;
  TF_CHECK_OK(root.ToGraphDef(&graph_item.graph));

  LayoutOptimizer layout_optimizer;
  GraphDef optimized;
  const Status result =
      layout_optimizer.Optimize(virtual_cluster_.get(), graph_item, &optimized);
  EXPECT_TRUE(errors::IsInvalidArgument(result));
}
925 
// ShapeN(conv, conv) feeding Add: "shapen" is expected to read "Conv2D"
// directly on both inputs, and each of its outputs is expected to reach "add"
// through its own DataFormatVecPermute node.
TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* shapen_node = node_map.GetNode("shapen");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(shapen_node != nullptr);
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");

  const NodeDef* add_node = node_map.GetNode("add");
  ASSERT_TRUE(add_node != nullptr);
  EXPECT_EQ(add_node->input(0),
            "shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");

  const NodeDef* vec_permute1 =
      node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  ASSERT_TRUE(vec_permute1 != nullptr);
  EXPECT_EQ(vec_permute1->input(0), "shapen");
  EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");

  const NodeDef* vec_permute2 =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  ASSERT_TRUE(vec_permute2 != nullptr);
  EXPECT_EQ(vec_permute2->input(0), "shapen:1");
  EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
}
954 
// ShapeN(vector, conv) feeding Add: only the output that corresponds to the
// conv input (output 1) is expected to pass through a DataFormatVecPermute
// node; output 0 is expected to reach "add" directly.
TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {7});
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* shapen_node = node_map.GetNode("shapen");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(shapen_node != nullptr);
  EXPECT_EQ(shapen_node->input(0), "vector");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");

  const NodeDef* add_node = node_map.GetNode("add");
  ASSERT_TRUE(add_node != nullptr);
  EXPECT_EQ(add_node->input(0), "shapen");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");

  const NodeDef* vec_permute =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  ASSERT_TRUE(vec_permute != nullptr);
  EXPECT_EQ(vec_permute->input(0), "shapen:1");
  EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
}
979 
// ShapeN(conv, 4-D constant passed through two Identity nodes): "shapen" is
// expected to read "Conv2D" and "i2" directly after optimization.
TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAndNoNeedToTransform4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto i1 = ops::Identity(s.WithOpName("i1"), tensor_4d);
  Output i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, i2});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* shapen_node = node_map.GetNode("shapen");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(shapen_node != nullptr);
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "i2");
}
998 
// Switch(conv, ctrl) with an Identity on each output: "switch" is expected to
// read "Conv2D" directly, and each Identity is expected to read a node whose
// first input is the matching switch output ("switch:1" for the true branch,
// "switch" for the false branch).
TEST_F(LayoutOptimizerTest, Switch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  ops::Variable ctrl(s.WithOpName("ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(s.WithOpName("switch"), conv, ctrl);
  auto i1 = ops::Identity(s.WithOpName("i1"), sw.output_true);
  auto i2 = ops::Identity(s.WithOpName("i2"), sw.output_false);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* switch_node = node_map.GetNode("switch");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(switch_node != nullptr);
  EXPECT_EQ(switch_node->input(0), "Conv2D");
  EXPECT_EQ(switch_node->input(1), "ctrl");

  const NodeDef* i1_node = node_map.GetNode("i1");
  ASSERT_TRUE(i1_node != nullptr);
  const NodeDef* i2_node = node_map.GetNode("i2");
  ASSERT_TRUE(i2_node != nullptr);

  // The nodes feeding i1/i2 are looked up by whatever name the optimizer gave
  // them, so these lookups need the same guard.
  const NodeDef* trans1 = node_map.GetNode(i1_node->input(0));
  ASSERT_TRUE(trans1 != nullptr);
  EXPECT_EQ(trans1->input(0), "switch:1");
  const NodeDef* trans2 = node_map.GetNode(i2_node->input(0));
  ASSERT_TRUE(trans2 != nullptr);
  EXPECT_EQ(trans2->input(0), "switch");
}
1022 
// Merge(conv, Identity(conv)): "merge" is expected to read "Conv2D" and "i1"
// directly, with a single NCHW-to-NHWC transpose placed after the merge and
// feeding "i2".
TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto merge = ops::Merge(s.WithOpName("merge"), {conv, i1});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* merge_node = node_map.GetNode("merge");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(merge_node != nullptr);
  EXPECT_EQ(merge_node->input(0), "Conv2D");
  EXPECT_EQ(merge_node->input(1), "i1");

  const NodeDef* i2_node = node_map.GetNode("i2");
  ASSERT_TRUE(i2_node != nullptr);
  EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");

  const NodeDef* transpose =
      node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_TRUE(transpose != nullptr);
  EXPECT_EQ(transpose->input(0), "merge");
}
1044 
// Merge(4-D constant, conv): "merge" is expected to keep reading the constant
// directly and to read Conv2D through an NCHW-to-NHWC transpose.
TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto merge = ops::Merge(s.WithOpName("merge"), {tensor_4d, conv});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* merge_node = node_map.GetNode("merge");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(merge_node != nullptr);
  EXPECT_EQ(merge_node->input(0), "tensor_4d");
  EXPECT_EQ(merge_node->input(1),
            "Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}
1062 
// Complex(conv, conv): "complex" is expected to read "Conv2D" directly on both
// inputs, and the transpose placed after it is expected to have its "T"
// attribute set to DT_COMPLEX64.
TEST_F(LayoutOptimizerTest, Complex) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
  auto i = ops::Identity(s.WithOpName("i"), comp);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* complex_node = node_map.GetNode("complex");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(complex_node != nullptr);
  EXPECT_EQ(complex_node->input(0), "Conv2D");
  EXPECT_EQ(complex_node->input(1), "Conv2D");

  const NodeDef* trans =
      node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_TRUE(trans != nullptr);
  EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
}
1081 
// IdentityN(vector, conv) feeding Add: "identity_n" is expected to read
// "vector" and "Conv2D" directly; only its output 1 is expected to pass
// through an NCHW-to-NHWC transpose before reaching "add".
TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
  auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  const NodeDef* i = node_map.GetNode("identity_n");
  // A failed lookup must end the test with a reported failure instead of
  // crashing the test binary on the dereferences below.
  ASSERT_TRUE(i != nullptr);
  EXPECT_EQ(i->input(0), "vector");
  EXPECT_EQ(i->input(1), "Conv2D");

  const NodeDef* trans =
      node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
  ASSERT_TRUE(trans != nullptr);
  EXPECT_EQ(trans->input(0), "identity_n:1");

  const NodeDef* add_node = node_map.GetNode("add");
  ASSERT_TRUE(add_node != nullptr);
  EXPECT_EQ(add_node->input(0), "identity_n");
  EXPECT_EQ(add_node->input(1),
            "identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}
1105 }  // namespace
1106 }  // namespace grappler
1107 }  // namespace tensorflow
1108