/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {
namespace {

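// Tests for the Grappler layout optimizer, which rewrites NHWC graphs into
// NCHW form for GPU execution and inserts the transposes, dimension maps, and
// vector permutations needed to preserve graph semantics.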
class LayoutOptimizerTest : public GrapplerTest {
 protected:
  void SetUp() override {
    gpu_available_ = GetNumAvailableGPUs() > 0;

    if (gpu_available_) {
      virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
    } else {
      DeviceProperties device_properties;
      device_properties.set_type("GPU");
      device_properties.mutable_environment()->insert({"architecture", "6"});
      virtual_cluster_.reset(
          new VirtualCluster({{"/GPU:1", device_properties}}));
    }
    TF_CHECK_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

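  // Builds a simple NHWC Conv2D with iota-filled constant input and filter.
  // The overload without a device argument leaves placement unspecified.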
  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding) {
    return SimpleConv2D(s, input_size, filter_size, padding, "");
  }

  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding, const string& device) {
    int batch_size = 8;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    TensorShape input_shape(
        {batch_size, input_height, input_width, input_depth});
    Tensor input_data(DT_FLOAT, input_shape);
    test::FillIota<float>(&input_data, 1.0f);
    Output input =
        ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Tensor filter_data(DT_FLOAT, filter_shape);
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

    ops::Conv2D::Attrs attrs;
    if (padding == "EXPLICIT") {
      attrs = attrs.ExplicitPaddings({0, 0, 1, 2, 3, 4, 0, 0});
    }

    Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
                              filter, {1, stride, stride, 1}, padding, attrs);
    return conv;
  }

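  // Builds an NHWC Conv2DBackpropInput. The input sizes are fed either as a
  // Const or through an Identity, so tests can exercise both the constant
  // rewrite path and the DataFormatVecPermute path in the optimizer.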
  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding) {
    return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true,
                                     true);
  }

  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding,
                                   bool const_input_size, bool dilated) {
    int batch_size = 128;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    int dilation = dilated ? 2 : 1;
    int64 padding_top = 1;
    int64 padding_bottom = 2;
    int64 padding_left = 3;
    int64 padding_right = 4;
    int64 output_height;
    int64 output_width;
    Padding padding_enum;
    if (padding == "SAME") {
      padding_enum = SAME;
    } else if (padding == "VALID") {
      padding_enum = VALID;
    } else {
      CHECK_EQ(padding, "EXPLICIT");
      padding_enum = EXPLICIT;
    }
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        input_height, filter_size, dilation, stride, padding_enum,
        &output_height, &padding_top, &padding_bottom));
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        input_width, filter_size, dilation, stride, padding_enum, &output_width,
        &padding_left, &padding_right));
    TensorShape input_sizes_shape({4});
    Tensor input_data(DT_INT32, input_sizes_shape);
    test::FillValues<int>(&input_data,
                          {batch_size, input_height, input_width, input_depth});
    Output input_sizes =
        ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Output filter =
        ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);

    TensorShape output_shape(
        {batch_size, output_height, output_width, filter_count});
    Tensor output_data(DT_FLOAT, output_shape);
    test::FillIota<float>(&output_data, 1.0f);
    Output output =
        ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));

    Output conv_backprop_input;
    Output input_sizes_i =
        ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
    std::vector<int> dilations{1, dilation, dilation, 1};
    std::vector<int> explicit_paddings;
    if (padding == "EXPLICIT") {
      explicit_paddings = {0,
                           0,
                           static_cast<int>(padding_top),
                           static_cast<int>(padding_bottom),
                           static_cast<int>(padding_left),
                           static_cast<int>(padding_right),
                           0,
                           0};
    }
    auto attrs =
        ops::Conv2DBackpropInput::Attrs().Dilations(dilations).ExplicitPaddings(
            explicit_paddings);
    if (const_input_size) {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
          {1, stride, stride, 1}, padding, attrs);
    } else {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
          {1, stride, stride, 1}, padding, attrs);
    }
    return conv_backprop_input;
  }

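  // Extracts the tensor stored in a Const node's "value" attribute.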
  Tensor GetAttrValue(const NodeDef& node) {
    Tensor tensor;
    CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
    return tensor;
  }

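  // Reads the TensorShape stored in a node's "shape" attribute.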
  TensorShape GetAttrShape(const NodeDef& node) {
    return TensorShape(node.attr().at({"shape"}).shape());
  }

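  // Builds a FusedBatchNormGrad over NHWC inputs and returns its x_backprop
  // output; is_training toggles the attribute under test.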
  Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
    int batch_size = 16;
    int input_height = 8;
    int input_width = 8;
    int input_channels = 3;
    TensorShape shape({batch_size, input_height, input_width, input_channels});
    Tensor data(DT_FLOAT, shape);
    test::FillIota<float>(&data, 1.0f);
    Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
    Output y_backprop =
        ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));

    TensorShape shape_vector({input_channels});
    Tensor data_vector(DT_FLOAT, shape_vector);
    test::FillIota<float>(&data_vector, 2.0f);
    Output scale =
        ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
    Output reserve1 =
        ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
    Output reserve2 =
        ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));

    ops::FusedBatchNormGrad::Attrs attrs;
    attrs.is_training_ = is_training;
    auto output =
        ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
                                x, scale, reserve1, reserve2, attrs);
    return output.x_backprop;
  }

  std::unique_ptr<Cluster> virtual_cluster_;
  bool gpu_available_;
};

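// Nodes added by the optimizer carry the "-LayoutOptimizer" suffix, with an
// infix that names the rewrite (e.g. Transpose, VecPermute, DimMap); the
// expectations below rely on this naming scheme.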
TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "EXPLICIT");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
  auto input_sizes_node = node_map.GetNode(input_name);
  CHECK(input_sizes_node);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(input_name, conv2d_backprop_node->input(0));
  auto input_sizes = GetAttrValue(*input_sizes_node);
  Tensor input_sizes_expected(DT_INT32, {4});
  test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
  test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);

  if (gpu_available_) {
    TensorShape filter_shape = GetAttrShape(*node_map.GetNode("Filter"));
    Tensor filter_data = GenerateRandomTensor<DT_FLOAT>(filter_shape);
    std::vector<string> fetch = {"Fetch"};
    auto tensors_expected =
        EvaluateNodes(item.graph, fetch, {{"Filter", filter_data}});
    auto tensors = EvaluateNodes(output, fetch, {{"Filter", filter_data}});
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
  }
}

TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false, false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(conv2d_backprop_node->input(0),
            "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto input_sizes_node = node_map.GetNode(
      "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  CHECK(input_sizes_node);
  EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
  EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, ExplicitPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "EXPLICIT");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

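// Pad's paddings input is a {4, 2} tensor of {before, after} pairs, one row
// per NHWC dimension; after conversion the optimizer reorders the rows into
// NCHW order, which is what the expected tensor below encodes.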
TEST_F(LayoutOptimizerTest, Pad) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
  auto p = ops::Pad(s.WithOpName("p"), conv, c);
  auto o = ops::Identity(s.WithOpName("o"), p);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  auto pad = node_map.GetNode("p");
  EXPECT_EQ(pad->input(0), "Conv2D");

  auto pad_const = node_map.GetNode("p-1-LayoutOptimizer");
  EXPECT_TRUE(pad_const);
  EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(pad_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4, 2});
  test::FillValues<int>(&tensor_expected, {1, 2, 7, 8, 3, 4, 5, 6});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, Connectivity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto i3 = ops::Identity(s.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in between
  // are layout agnostic). If the graph is already in topological order, the
  // problem is easier: the layout optimizer only needs to check single-hop
  // connectivity.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_i2 = node_map_original.GetNode("i2");
  node_i2->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto node_i2_output = node_map_output.GetNode("i2");
  // The layout optimizer should process i2, as it detects that i2 is
  // connected to the Conv2D node two hops away. Similarly, i1 is processed as
  // well, as it is directly connected to the Conv2D node. The two added
  // transposes between i1 and i2 should cancel each other, and as a result i2
  // is directly connected to i1.
  EXPECT_EQ(node_i2_output->input(0), "i1");
}

TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
  auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
  auto i3 = ops::Identity(s.WithOpName("i3"), sub);
  auto i4 = ops::Identity(s.WithOpName("i4"), i3);
  auto i5 = ops::Identity(s.WithOpName("i5"), i4);
  auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
  auto i6 = ops::Identity(s.WithOpName("i6"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Make the graph not in topological order to test the handling of multi-hop
  // connectivity (here we say two nodes are connected if all nodes in between
  // are layout agnostic). If the graph is already in topological order, the
  // problem is easier: the layout optimizer only needs to check single-hop
  // connectivity.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_mul = node_map_original.GetNode("mul");
  node_mul->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto mul_node = node_map_output.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar_mul");
  EXPECT_EQ(mul_node->input(1), "i5");
}

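// Fetch nodes must keep their original NHWC layout so that callers observe
// unchanged outputs.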
TEST_F(LayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

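// Device tests: nodes with no device or a GPU device are converted to NCHW;
// CPU-placed nodes keep NHWC regardless of device-string capitalization.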
TEST_F(LayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, GPUDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

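// FusedBatchNormGrad is only converted to NCHW when is_training is true; the
// inference-mode variant is left in NHWC.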
TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

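// Split/Concat axis constants must be remapped from NHWC indices to NCHW
// indices (N:0->0, H:1->2, W:2->3, C:3->1). A minimal sketch of the mapping
// the tests below expect (illustrative only; the optimizer emits a rewritten
// Const or a DataFormatDimMap node rather than calling a helper like this):
//   int ToNCHWAxis(int nhwc_axis) {
//     static const int kMap[4] = {0, 2, 3, 1};
//     return kMap[nhwc_axis];
//   }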
TEST_F(LayoutOptimizerTest, SplitDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 3, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, SplitDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 6, 2, "SAME");
  auto c = ops::Const(s.WithOpName("c"), 1, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, SplitDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 2, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, SplitDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto i1 = ops::Identity(s.WithOpName("i1"), c);
  auto split = ops::Split(s.WithOpName("split"), i1, conv, 2);
  auto i2 = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto map_node = node_map.GetNode("split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(map_node->op(), "DataFormatDimMap");
  EXPECT_EQ(map_node->input(0), "i1");
}

TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat =
      ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split:1");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "split:1");
  EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, ConcatDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, ConcatNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto i = ops::Identity(s.WithOpName("i"), axis);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, i);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  auto concat_dim =
      node_map.GetNode("concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
  EXPECT_EQ(concat_dim->input(0), "i");
}

TEST_F(LayoutOptimizerTest, ConcatDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 2);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, ConcatDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 0);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, ConcatDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, Sum) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto reduction_indices =
      ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
  auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
  auto o = ops::Identity(s.WithOpName("o"), sum);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
  // because of the worse performance in some cases.
  /*
  NodeMap node_map(&output);
  auto sum_node = node_map.GetNode("sum");
  EXPECT_EQ(sum_node->input(0), "Conv2D");
  EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
  auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {3});
  test::FillValues<int>(&tensor_expected, {0, 2, 3});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
  */
}

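// Binary-op broadcast handling: scalar operands are left untouched, vector
// operands are reshaped to {1, C, 1, 1} so they broadcast correctly in NCHW,
// and an operand of unknown rank prevents the op from being converted.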
TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
}

TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "scalar");
}

TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto unknown_rank =
      ops::Placeholder(s.WithOpName("unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape()));
  Output c = ops::Const(s.WithOpName("c"), 3.0f, {8, 2, 2, 2});
  Output mul = ops::Mul(s.WithOpName("mul"), conv, unknown_rank);
  auto o = ops::AddN(s.WithOpName("o"), {mul, c});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  // Node mul should not be processed by the layout optimizer, because one of
  // its inputs is of unknown rank.
  EXPECT_EQ(mul_node->input(0),
            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "unknown");
}

TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "i");
}

TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
  auto mul_const = node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
  auto mul_const = node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, SliceConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");

  auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
  Tensor begin_tensor;
  EXPECT_TRUE(begin_tensor.FromProto(
      begin_const->mutable_attr()->at({"value"}).tensor()));
  Tensor begin_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
  test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);

  auto size_const = node_map.GetNode("slice-2-LayoutOptimizer");
  Tensor size_tensor;
  EXPECT_TRUE(size_tensor.FromProto(
      size_const->mutable_attr()->at({"value"}).tensor()));
  Tensor size_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
  test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
}

TEST_F(LayoutOptimizerTest, SliceNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto isize = ops::Identity(s.WithOpName("isize"), size);
  auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1),
            "slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2),
            "slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto perm1 = node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm1->input(0), "ibegin");
  auto perm2 = node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm2->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm2->input(0), "isize");
}

TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto scalar =
      ops::Const(s.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
}

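// ShapeN emits shape vectors in the converted layout, so outputs that
// correspond to converted 4-D inputs are permuted back with
// DataFormatVecPermute; outputs for inputs that were not converted are left
// alone.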
TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0),
            "shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute1 =
      node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute1->input(0), "shapen");
  EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");
  auto vec_permute2 =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute2->input(0), "shapen:1");
  EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {7});
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "vector");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "shapen");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute->input(0), "shapen:1");
  EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAndNoNeedToTransform4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto i1 = ops::Identity(s.WithOpName("i1"), tensor_4d);
  Output i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, i2});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "i2");
}

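// Control-flow tests: Switch, Merge, and NextIteration are layout agnostic,
// so the optimizer pushes transposes across them and must terminate even when
// the graph contains a loop.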
TEST_F(LayoutOptimizerTest, Switch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  ops::Variable ctrl(s.WithOpName("ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(s.WithOpName("switch"), conv, ctrl);
  auto i1 = ops::Identity(s.WithOpName("i1"), sw.output_true);
  auto i2 = ops::Identity(s.WithOpName("i2"), sw.output_false);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto switch_node = node_map.GetNode("switch");
  EXPECT_EQ(switch_node->input(0), "Conv2D");
  EXPECT_EQ(switch_node->input(1), "ctrl");
  auto i1_node = node_map.GetNode("i1");
  auto i2_node = node_map.GetNode("i2");
  auto trans1 = node_map.GetNode(i1_node->input(0));
  EXPECT_EQ(trans1->input(0), "switch:1");
  auto trans2 = node_map.GetNode(i2_node->input(0));
  EXPECT_EQ(trans2->input(0), "switch");
}

TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto merge = ops::Merge(s.WithOpName("merge"), {conv, i1});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "Conv2D");
  EXPECT_EQ(merge_node->input(1), "i1");
  auto i2_node = node_map.GetNode("i2");
  EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  auto transpose =
      node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(transpose->input(0), "merge");
}

TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto merge = ops::Merge(s.WithOpName("merge"), {tensor_4d, conv});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "tensor_4d");
  EXPECT_EQ(merge_node->input(1),
            "Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

TEST_F(LayoutOptimizerTest, Complex) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
  auto i = ops::Identity(s.WithOpName("i"), comp);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("complex");
  EXPECT_EQ(merge_node->input(0), "Conv2D");
  EXPECT_EQ(merge_node->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
}

TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
  auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto i = node_map.GetNode("identity_n");
  EXPECT_EQ(i->input(0), "vector");
  EXPECT_EQ(i->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->input(0), "identity_n:1");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "identity_n");
  EXPECT_EQ(add_node->input(1),
            "identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

TEST_F(LayoutOptimizerTest, LoopNoLiveLock) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto c = ops::Const(s.WithOpName("const"), 3.0f, {8, 3, 3, 2});
  auto merge = ops::Merge(s.WithOpName("merge"), {c, c});
  auto i0 = ops::Identity(s.WithOpName("i0"), merge.output);
  ops::Variable v_ctrl(s.WithOpName("v_ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(s.WithOpName("switch"), i0, v_ctrl);
  auto next = ops::NextIteration(s.WithOpName("next"), sw.output_true);
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto mul = ops::Mul(s.WithOpName("mul"), conv, sw.output_false);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  NodeMap node_map_original(&item.graph);
  auto merge_node = node_map_original.GetNode("merge");
  // Modify the graph to create a loop.
  merge_node->set_input(1, "next");
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->input(0),
            "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer");
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0),
            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
}

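// Shape data produced on GPU is consumed on the host, so the inserted vector
// permutation is expected to be pinned to the host kernel via "_kernel".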
TEST_F(LayoutOptimizerTest, DevicePlacement) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto shape = ops::Shape(s.WithOpName("s"), conv);
  auto i = ops::Identity(s.WithOpName("i"), shape);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  VirtualPlacer virtual_placer(virtual_cluster_.get());
  for (auto& node : *item.graph.mutable_node()) {
    string device = virtual_placer.get_canonical_device_name(node);
    node.set_device(device);
  }
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto vec_permute =
      node_map.GetNode("s-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute->attr().at("_kernel").s(), "host");
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow