/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {
namespace {

class LayoutOptimizerTest : public GrapplerTest {
 protected:
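  // If a physical GPU is present, run against a real SingleMachine cluster;
  // otherwise fall back to a VirtualCluster that advertises one simulated
  // GPU, so the optimizer still believes it is targeting a GPU device.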
  void SetUp() override {
    gpu_available_ = GetNumAvailableGPUs() > 0;

    if (gpu_available_) {
      virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
    } else {
      DeviceProperties device_properties;
      device_properties.set_type("GPU");
      device_properties.mutable_environment()->insert({"architecture", "6"});
      virtual_cluster_.reset(
          new VirtualCluster({{"/GPU:1", device_properties}}));
    }
    TF_CHECK_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

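  // Builds a simple NHWC Conv2D: an iota-filled
  // 8 x input_size x input_size x 3 input convolved with an iota-filled
  // filter_size x filter_size x 3 x 2 filter.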
  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding) {
    return SimpleConv2D(s, input_size, filter_size, padding, "");
  }

  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding, const string& device) {
    int batch_size = 8;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    TensorShape input_shape(
        {batch_size, input_height, input_width, input_depth});
    Tensor input_data(DT_FLOAT, input_shape);
    test::FillIota<float>(&input_data, 1.0f);
    Output input =
        ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Tensor filter_data(DT_FLOAT, filter_shape);
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

    ops::Conv2D::Attrs attrs;
    if (padding == "EXPLICIT") {
      attrs = attrs.ExplicitPaddings({0, 0, 1, 2, 3, 4, 0, 0});
    }

    Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
                              filter, {1, stride, stride, 1}, padding, attrs);
    return conv;
  }

  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding) {
    return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true,
                                     true);
  }

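  // Builds a Conv2DBackpropInput node. Output spatial dimensions and explicit
  // paddings are derived via GetWindowedOutputSizeVerboseV2. When
  // const_input_size is false, input_sizes is routed through an Identity so
  // the optimizer no longer sees it as a constant.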
  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding,
                                   bool const_input_size, bool dilated) {
    int batch_size = 128;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    int dilation = dilated ? 2 : 1;
    int64 padding_top = 1;
    int64 padding_bottom = 2;
    int64 padding_left = 3;
    int64 padding_right = 4;
    int64 output_height;
    int64 output_width;
    Padding padding_enum;
    if (padding == "SAME") {
      padding_enum = SAME;
    } else if (padding == "VALID") {
      padding_enum = VALID;
    } else {
      CHECK_EQ(padding, "EXPLICIT");
      padding_enum = EXPLICIT;
    }
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        input_height, filter_size, dilation, stride, padding_enum,
        &output_height, &padding_top, &padding_bottom));
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        input_width, filter_size, dilation, stride, padding_enum, &output_width,
        &padding_left, &padding_right));
    TensorShape input_sizes_shape({4});
    Tensor input_data(DT_INT32, input_sizes_shape);
    test::FillValues<int>(&input_data,
                          {batch_size, input_height, input_width, input_depth});
    Output input_sizes =
        ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Output filter =
        ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);

    TensorShape output_shape(
        {batch_size, output_height, output_width, filter_count});
    Tensor output_data(DT_FLOAT, output_shape);
    test::FillIota<float>(&output_data, 1.0f);
    Output output =
        ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));

    Output conv_backprop_input;
    Output input_sizes_i =
        ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
    std::vector<int> dilations{1, dilation, dilation, 1};
    std::vector<int> explicit_paddings;
    if (padding == "EXPLICIT") {
      explicit_paddings = {0,
                           0,
                           static_cast<int>(padding_top),
                           static_cast<int>(padding_bottom),
                           static_cast<int>(padding_left),
                           static_cast<int>(padding_right),
                           0,
                           0};
    }
    auto attrs =
        ops::Conv2DBackpropInput::Attrs().Dilations(dilations).ExplicitPaddings(
            explicit_paddings);
    if (const_input_size) {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
          {1, stride, stride, 1}, padding, attrs);
    } else {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
          {1, stride, stride, 1}, padding, attrs);
    }
    return conv_backprop_input;
  }

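  // Reads back the tensor stored in a Const node's "value" attr.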
  Tensor GetAttrValue(const NodeDef& node) {
    Tensor tensor;
    CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
    return tensor;
  }

  TensorShape GetAttrShape(const NodeDef& node) {
    return TensorShape(node.attr().at({"shape"}).shape());
  }

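  // Builds a FusedBatchNormGrad over a 16x8x8x3 NHWC input and returns its
  // x_backprop output.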
  Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
    int batch_size = 16;
    int input_height = 8;
    int input_width = 8;
    int input_channels = 3;
    TensorShape shape({batch_size, input_height, input_width, input_channels});
    Tensor data(DT_FLOAT, shape);
    test::FillIota<float>(&data, 1.0f);
    Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
    Output y_backprop =
        ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));

    TensorShape shape_vector({input_channels});
    Tensor data_vector(DT_FLOAT, shape_vector);
    test::FillIota<float>(&data_vector, 2.0f);
    Output scale =
        ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
    Output reserve1 =
        ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
    Output reserve2 =
        ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));

    ops::FusedBatchNormGrad::Attrs attrs;
    attrs.is_training_ = is_training;
    auto output =
        ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
                                x, scale, reserve1, reserve2, attrs);
    return output.x_backprop;
  }

  std::unique_ptr<Cluster> virtual_cluster_;
  bool gpu_available_;
};

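// The constant input_sizes fed to Conv2DBackpropInput should be replaced by a
// permuted constant: {128, 7, 7, 3} in NHWC order becomes {128, 3, 7, 7} in
// NCHW order.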
TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "EXPLICIT");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
  auto input_sizes_node = node_map.GetNode(input_name);
  CHECK(input_sizes_node);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(input_name, conv2d_backprop_node->input(0));
  auto input_sizes = GetAttrValue(*input_sizes_node);
  Tensor input_sizes_expected(DT_INT32, {4});
  test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
  test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);

  if (gpu_available_) {
    TensorShape filter_shape = GetAttrShape(*node_map.GetNode("Filter"));
    Tensor filter_data = GenerateRandomTensor<DT_FLOAT>(filter_shape);
    std::vector<string> fetch = {"Fetch"};
    auto tensors_expected =
        EvaluateNodes(item.graph, fetch, {{"Filter", filter_data}});
    auto tensors = EvaluateNodes(output, fetch, {{"Filter", filter_data}});
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  }
}

TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false, false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(conv2d_backprop_node->input(0),
            "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto input_sizes_node = node_map.GetNode(
      "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  CHECK(input_sizes_node);
  EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
  EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, ExplicitPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "EXPLICIT");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

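// Pad's paddings input is a 4x2 tensor holding one {before, after} row per
// dimension, so converting from NHWC {N, H, W, C} to NCHW {N, C, H, W}
// reorders entire rows: {1, 2, 3, 4, 5, 6, 7, 8} becomes
// {1, 2, 7, 8, 3, 4, 5, 6}.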
TEST_F(LayoutOptimizerTest, Pad) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
  auto p = ops::Pad(s.WithOpName("p"), conv, c);
  auto o = ops::Identity(s.WithOpName("o"), p);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  auto pad = node_map.GetNode("p");
  EXPECT_EQ(pad->input(0), "Conv2D");

  auto pad_const = node_map.GetNode("p-1-LayoutOptimizer");
  EXPECT_TRUE(pad_const);
  EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(pad_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4, 2});
  test::FillValues<int>(&tensor_expected, {1, 2, 7, 8, 3, 4, 5, 6});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, Connectivity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto i3 = ops::Identity(s.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in the
  // middle are layout agnostic). If the graph is already in topological
  // order, the problem is easier: the layout optimizer only needs to check
  // single-hop connectivity.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_i2 = node_map_original.GetNode("i2");
  node_i2->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto node_i2_output = node_map_output.GetNode("i2");
  // The layout optimizer should process i2, as it detects that i2 is
  // connected to the Conv2D node two hops away. Similarly, i1 is processed as
  // well, since i1 is directly connected to the Conv2D node. The two added
  // transposes between i1 and i2 should cancel each other, so i2 ends up
  // directly connected to i1.
  EXPECT_EQ(node_i2_output->input(0), "i1");
}

TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
  auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
  auto i3 = ops::Identity(s.WithOpName("i3"), sub);
  auto i4 = ops::Identity(s.WithOpName("i4"), i3);
  auto i5 = ops::Identity(s.WithOpName("i5"), i4);
  auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
  auto i6 = ops::Identity(s.WithOpName("i6"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in the
  // middle are layout agnostic). If the graph is already in topological
  // order, the problem is easier: the layout optimizer only needs to check
  // single-hop connectivity.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_mul = node_map_original.GetNode("mul");
  node_mul->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto mul_node = node_map_output.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar_mul");
  EXPECT_EQ(mul_node->input(1), "i5");
}

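// Nodes listed in item.fetch must keep their original layout so that fetched
// tensors keep the shape callers expect; the Conv2D below therefore stays in
// NHWC.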
TEST_F(LayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

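// The next four tests exercise device-dependent behavior: a node with no
// device or an explicit GPU device is converted to NCHW, while a node pinned
// to a CPU device (in either case style) keeps NHWC.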
TEST_F(LayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, GPUDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

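// FusedBatchNormGrad is converted to NCHW only when is_training is true; the
// inference variant (is_training = false) stays in NHWC, presumably because
// an NCHW kernel for it was not available when this test was written.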
TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

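// The Split and Concat tests below check that a constant dimension argument
// is remapped from NHWC to NCHW indices (N 0->0, H 1->2, W 2->3, C 3->1),
// while a non-constant dimension is wrapped in a DataFormatDimMap node.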
TEST_F(LayoutOptimizerTest, SplitDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 3, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, SplitDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 6, 2, "SAME");
  auto c = ops::Const(s.WithOpName("c"), 1, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, SplitDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 2, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, SplitDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto i1 = ops::Identity(s.WithOpName("i1"), c);
  auto split = ops::Split(s.WithOpName("split"), i1, conv, 2);
  auto i2 = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto map_node = node_map.GetNode("split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(map_node->op(), "DataFormatDimMap");
  EXPECT_EQ(map_node->input(0), "i1");
}

TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat =
      ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split:1");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "split:1");
  EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, ConcatDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, ConcatNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto i = ops::Identity(s.WithOpName("i"), axis);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, i);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  auto concat_dim =
      node_map.GetNode("concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
  EXPECT_EQ(concat_dim->input(0), "i");
}

TEST_F(LayoutOptimizerTest, ConcatDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 2);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, ConcatDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 0);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, ConcatDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, Sum) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto reduction_indices =
      ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
  auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
  auto o = ops::Identity(s.WithOpName("o"), sum);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
  // because of the worse performance in some cases.
  /*
  NodeMap node_map(&output);
  auto sum_node = node_map.GetNode("sum");
  EXPECT_EQ(sum_node->input(0), "Conv2D");
  EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
  auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {3});
  test::FillValues<int>(&tensor_expected, {0, 2, 3});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
  */
}

TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
}

TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "scalar");
}

TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto unknown_rank =
      ops::Placeholder(s.WithOpName("unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape()));
  Output c = ops::Const(s.WithOpName("c"), 3.0f, {8, 2, 2, 2});
  Output mul = ops::Mul(s.WithOpName("mul"), conv, unknown_rank);
  auto o = ops::AddN(s.WithOpName("o"), {mul, c});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  // Node mul should not be processed by layout optimizer, because one of its
  // inputs is of unknown rank.
  EXPECT_EQ(mul_node->input(0),
            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "unknown");
}

TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "i");
}

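// A rank-1 operand that broadcasts against the channel dimension of a 4D
// tensor is reshaped to {1, C, 1, 1} so that the broadcast still lines up
// after the surrounding data is converted to NCHW.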
TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
  auto mul_const = node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
  auto mul_const = node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

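// Slice's begin and size inputs are per-dimension vectors, so constants are
// permuted from NHWC to NCHW order: begin {0, 2, 3, 1} -> {0, 1, 2, 3} and
// size {4, 1, 2, 4} -> {4, 4, 1, 2}. Non-constant vectors are instead fed
// through a DataFormatVecPermute node (see SliceNonConst).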
TEST_F(LayoutOptimizerTest, SliceConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");

  auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
  Tensor begin_tensor;
  EXPECT_TRUE(begin_tensor.FromProto(
      begin_const->mutable_attr()->at({"value"}).tensor()));
  Tensor begin_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
  test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);

  auto size_const = node_map.GetNode("slice-2-LayoutOptimizer");
  Tensor size_tensor;
  EXPECT_TRUE(size_tensor.FromProto(
      size_const->mutable_attr()->at({"value"}).tensor()));
  Tensor size_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
  test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
}

TEST_F(LayoutOptimizerTest, SliceNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto isize = ops::Identity(s.WithOpName("isize"), size);
  auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1),
            "slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2),
            "slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto perm1 = node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm1->input(0), "ibegin");
  auto perm2 = node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm2->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm2->input(0), "isize");
}

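// A node name that already carries the "-LayoutOptimizer" suffix signals that
// the optimizer has run before, so a second pass must fail with
// InvalidArgument rather than transform the graph again.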
TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto scalar =
      ops::Const(s.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
}

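// After conversion, ShapeN emits shape vectors in NCHW order, so each 4D
// output that is consumed downstream gets a DataFormatVecPermute node to
// convert the shape back to NHWC. Non-4D inputs are left alone.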
TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0),
            "shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute1 =
      node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute1->input(0), "shapen");
  EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");
  auto vec_permute2 =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute2->input(0), "shapen:1");
  EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {7});
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "vector");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "shapen");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute->input(0), "shapen:1");
  EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAndNoNeedToTransform4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto i1 = ops::Identity(s.WithOpName("i1"), tensor_4d);
  Output i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, i2});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "i2");
}

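// Both outputs of a converted Switch carry NCHW data, so each branch gets its
// own NCHW-to-NHWC transpose before the downstream Identity nodes.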
TEST_F(LayoutOptimizerTest, Switch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  ops::Variable ctrl(s.WithOpName("ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(s.WithOpName("switch"), conv, ctrl);
  auto i1 = ops::Identity(s.WithOpName("i1"), sw.output_true);
  auto i2 = ops::Identity(s.WithOpName("i2"), sw.output_false);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto switch_node = node_map.GetNode("switch");
  EXPECT_EQ(switch_node->input(0), "Conv2D");
  EXPECT_EQ(switch_node->input(1), "ctrl");
  auto i1_node = node_map.GetNode("i1");
  auto i2_node = node_map.GetNode("i2");
  auto trans1 = node_map.GetNode(i1_node->input(0));
  EXPECT_EQ(trans1->input(0), "switch:1");
  auto trans2 = node_map.GetNode(i2_node->input(0));
  EXPECT_EQ(trans2->input(0), "switch");
}

TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto merge = ops::Merge(s.WithOpName("merge"), {conv, i1});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "Conv2D");
  EXPECT_EQ(merge_node->input(1), "i1");
  auto i2_node = node_map.GetNode("i2");
  EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  auto transpose =
      node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(transpose->input(0), "merge");
}

TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto merge = ops::Merge(s.WithOpName("merge"), {tensor_4d, conv});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "tensor_4d");
  EXPECT_EQ(merge_node->input(1),
            "Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

TEST_F(LayoutOptimizerTest, Complex) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
  auto i = ops::Identity(s.WithOpName("i"), comp);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("complex");
  EXPECT_EQ(merge_node->input(0), "Conv2D");
  EXPECT_EQ(merge_node->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
}

TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
  auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto i = node_map.GetNode("identity_n");
  EXPECT_EQ(i->input(0), "vector");
  EXPECT_EQ(i->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->input(0), "identity_n:1");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "identity_n");
  EXPECT_EQ(add_node->input(1),
            "identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

TEST_F(LayoutOptimizerTest, LoopNoLiveLock) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto c = ops::Const(s.WithOpName("const"), 3.0f, {8, 3, 3, 2});
  auto merge = ops::Merge(s.WithOpName("merge"), {c, c});
  auto i0 = ops::Identity(s.WithOpName("i0"), merge.output);
  ops::Variable v_ctrl(s.WithOpName("v_ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(s.WithOpName("switch"), i0, v_ctrl);
  auto next = ops::NextIteration(s.WithOpName("next"), sw.output_true);
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto mul = ops::Mul(s.WithOpName("mul"), conv, sw.output_false);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  NodeMap node_map_original(&item.graph);
  auto merge_node = node_map_original.GetNode("merge");
  // Modify the graph to create a loop
  merge_node->set_input(1, "next");
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->input(0),
            "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer");
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0),
            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
}

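// Shape vectors live in host memory, so the DataFormatVecPermute inserted
// after Shape is expected to be placed with the host kernel ("_kernel" set to
// "host").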
TEST_F(LayoutOptimizerTest, DevicePlacement) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto shape = ops::Shape(s.WithOpName("s"), conv);
  auto i = ops::Identity(s.WithOpName("i"), shape);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  VirtualPlacer virtual_placer(virtual_cluster_.get());
  for (auto& node : *item.graph.mutable_node()) {
    string device = virtual_placer.get_canonical_device_name(node);
    node.set_device(device);
  }
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto vec_permute =
      node_map.GetNode("s-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute->attr().at("_kernel").s(), "host");
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow