/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr char kOptimizedSuffix[] = "LayoutOptimizer";
constexpr char kAttrKSize[] = "ksize";
constexpr char kAttrStrides[] = "strides";
constexpr char kAttrDilations[] = "dilations";
constexpr char kAttrExplicitPaddings[] = "explicit_paddings";
constexpr char kAttrDataFormat[] = "data_format";
constexpr char kAttrIsTraining[] = "is_training";
constexpr char kAttrValue[] = "value";
constexpr char kAttrN[] = "N";
constexpr char kAttrT[] = "T";
constexpr char kAttrNumSplit[] = "num_split";
constexpr char kAttrNumOuts[] = "num_outs";
constexpr char kAttrKeepDims[] = "keep_dims";
constexpr char kAttrSqueezeDims[] = "squeeze_dims";
constexpr char kOpTranspose[] = "Transpose";
constexpr char kOpDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kOpDataFormatDimMap[] = "DataFormatDimMap";
constexpr char kOpConst[] = "Const";
constexpr char kReshape[] = "Reshape";
constexpr char kReshapeConst[] = "ReshapeConst";
constexpr int kRank = 4;
constexpr int kUnknownRank = -1;
constexpr int kInvalidRank = -2;

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format,
                                bool* missing) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s() == src_data_format;
  }
  *missing = true;
  return false;
}

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format) {
  bool missing = false;
  return AttrDataFormatMatch(node, src_data_format, &missing);
}

bool IsNonFloatingConv2D(const utils::MutableNodeView& node) {
  if (IsConv2D(*node.node()) || IsConv2DBackpropInput(*node.node())) {
    const auto* attr = node.GetAttr(kAttrT);
    if (attr != nullptr) {
      return !kDataTypeIsFloating.Contains(attr->type());
    }
  }
  return false;
}

// Utils for layout agnostic transposer.

bool IsComparisonOp(const NodeDef& node) {
  bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
                    IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
                    IsLessEqual(node) || IsNotEqual(node);
  return is_compare;
}

std::vector<int> GetRegularFaninPorts(const utils::MutableNodeView& node) {
  const int num_regular_fanins = node.NumRegularFanins();
  std::vector<int> values(num_regular_fanins);
  std::iota(values.begin(), values.end(), 0);
  return values;
}

std::vector<int> GetConcatDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* n_attr = node.GetAttr(kAttrN);
  const int n = n_attr != nullptr ? n_attr->i() : 0;
  const int start = (node.GetOp() == "Concat") ? 1 : 0;
  const int end = start + n;
  std::vector<int> values(end - start);
  std::iota(values.begin(), values.end(), start);
  return values;
}
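// For illustration: a "Concat" node carries its axis as input 0, so with
// N = 3 this returns {1, 2, 3}; "ConcatV2" carries its axis as the last
// input, so the same N yields {0, 1, 2}.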

struct ComparatorByNodeNameAndIndex {
  bool operator()(const utils::MutableFaninView& node1,
                  const utils::MutableFaninView& node2) const {
    auto* node1_view = node1.node_view();
    auto* node2_view = node2.node_view();
    auto name_compare = node1_view->GetName().compare(node2_view->GetName());
    if (name_compare == 0) {
      return node1.index() < node2.index();
    }
    return name_compare < 0;
  }
};

bool IsHostMemory(const NodeDef& node, int output_port) {
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    DeviceType device_type(parsed_name.type);
    Status s = FindKernelDef(device_type, node, nullptr, nullptr);
    if (s.ok()) {
      tensorflow::MemoryTypeVector in_mtypes;
      tensorflow::MemoryTypeVector out_mtypes;
      s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
                                         node, &in_mtypes, &out_mtypes);
      if (s.ok()) {
        if (out_mtypes[output_port] == HOST_MEMORY) {
          return true;
        }
      }
    } else {
      return true;
    }
  }
  return false;
}

std::vector<int> GetDimensionIndicesFromLabel(
    const absl::flat_hash_map<char, int>& dim_indices,
    absl::Span<const char> labels) {
  std::vector<int> indices;
  indices.reserve(labels.size());
  for (const auto& label : labels) {
    indices.push_back(dim_indices.at(label));
  }
  return indices;
}
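// E.g. with dim_indices = {{'N', 0}, {'H', 1}, {'W', 2}, {'C', 3}} (NHWC),
// labels {'N', 'C'} map to indices {0, 3}.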

// RAII-styled object for keeping track of 4D to 5D data format
// upgrade/conversion. Currently only NHWC -> NDHWC and NCHW -> NCDHW are
// supported.
class ScopedDataFormatUpgrader {
 public:
  ScopedDataFormatUpgrader(TransposeContext* context, int rank)
      : context_(context) {
    if (rank == 5 && IsSupportedDataFormat(context_->src_format) &&
        IsSupportedDataFormat(context_->dst_format)) {
      old_src_format_ = context_->src_format;
      old_dst_format_ = context_->dst_format;
      std::string new_src_format = GetUpgradedDataFormat(context_->src_format);
      std::string new_dst_format = GetUpgradedDataFormat(context_->dst_format);
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           new_src_format, new_dst_format);
      upgraded_ = true;
    }
  }

  ScopedDataFormatUpgrader(const ScopedDataFormatUpgrader&) = delete;
  ScopedDataFormatUpgrader& operator=(const ScopedDataFormatUpgrader&) = delete;

  ~ScopedDataFormatUpgrader() {
    if (upgraded_) {
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           old_src_format_, old_dst_format_);
    }
  }

 private:
  bool IsSupportedDataFormat(absl::string_view data_format) {
    return data_format == "NHWC" || data_format == "NCHW";
  }

  std::string GetUpgradedDataFormat(absl::string_view data_format) {
    if (data_format == "NHWC") {
      return "NDHWC";
    }

    DCHECK_EQ(data_format, "NCHW");
    return "NCDHW";
  }

  TransposeContext* context_ = nullptr;
  bool upgraded_ = false;
  std::string old_src_format_;
  std::string old_dst_format_;
};
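// A minimal usage sketch (illustrative only): for a rank-5 node, a context
// created for "NHWC" -> "NCHW" is temporarily rewritten, then restored when
// the upgrader leaves scope.
//
//   {
//     ScopedDataFormatUpgrader upgrader(context, /*rank=*/5);
//     // context->src_format == "NDHWC", context->dst_format == "NCDHW"
//   }
//   // context->src_format == "NHWC", context->dst_format == "NCHW" again.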

}  // namespace

// TransposeContext.

Status TransposeContext::InitializeTransposeContext(const GrapplerItem& item,
                                                    const Cluster* cluster,
                                                    TransposeContext* context) {
  DCHECK(context != nullptr);
  context->graph_properties = absl::make_unique<GraphProperties>(item);
  TF_RETURN_IF_ERROR(context->graph_properties->InferStatically(false));
  TF_RETURN_IF_ERROR(context->graph_properties->AnnotateOutputShapes(
      &context->graph, /*allow_symbolic_shapes=*/true));
  Status status;
  context->graph_view =
      absl::make_unique<utils::MutableGraphView>(&context->graph, &status);
  TF_RETURN_IF_ERROR(status);
  context->num_nodes = context->graph.node_size();
  const auto& nodes_to_preserve = item.NodesToPreserve();
  context->nodes_to_preserve = absl::flat_hash_set<string>(
      nodes_to_preserve.begin(), nodes_to_preserve.end());
  TF_RETURN_IF_ERROR(context->frames.InferFromGraph(context->graph));
  if (cluster != nullptr) {
    context->virtual_placer =
        absl::make_unique<const VirtualPlacer>(cluster->GetDevices());
  }
  return Status::OK();
}

// Sets data formats to convert from and to for specified device type.
void TransposeContext::AssignDeviceAndDataFormats(
    absl::string_view target_device, absl::string_view src_format,
    absl::string_view dst_format) {
  this->target_device = string(target_device);
  this->src_format = string(src_format);
  this->dst_format = string(dst_format);
  this->src_dim_indices = GetDimensionIndices(src_format);
  this->dst_dim_indices = GetDimensionIndices(dst_format);
  this->src_to_dst = GetPermutation(this->src_dim_indices, dst_format);
  this->dst_to_src = GetPermutation(this->dst_dim_indices, src_format);
}
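// For example (illustrative), with src_format = "NHWC" and dst_format =
// "NCHW": src_dim_indices = {N:0, H:1, W:2, C:3}, src_to_dst = {0, 3, 1, 2}
// and dst_to_src = {0, 2, 3, 1}, i.e. the `perm` vectors used by the
// Transpose nodes inserted in each direction.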

// Transposer.

bool Transposer::ShouldProcess(const TransposeContext& context,
                               const utils::MutableNodeView& node) const {
  const auto* node_def = node.node();
  const string& device_name =
      GetDeviceName(context.virtual_placer.get(), *node_def);
  string device;
  string task;
  const bool is_on_target_device =
      DeviceNameUtils::SplitDeviceName(device_name, &task, &device) &&
      absl::StrContains(absl::AsciiStrToLower(device),
                        absl::AsciiStrToLower(context.target_device));

  // Only check the data format for layout-sensitive ops.
  const bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
                                 AttrDataFormatMatch(node, context.src_format);

  // Only transpose floating-point nodes; skip integer Conv2D.
  const bool is_integer_conv2d = IsNonFloatingConv2D(node);

  return is_on_target_device && data_format_match && !is_integer_conv2d &&
         !context.nodes_to_preserve.contains(node_def->name()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

Status Transposer::CreateConstPermNode(TransposeContext* context,
                                       absl::string_view node_name,
                                       absl::string_view device,
                                       absl::Span<const int> permutation,
                                       absl::string_view control_node_name,
                                       utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  NodeDef node;
  node.set_name(string(node_name));
  node.set_op(kOpConst);
  node.set_device(string(device));

  if (!control_node_name.empty()) {
    node.add_input(string(control_node_name));
  }

  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32, TensorShape({(long long)permutation.size()}));
  for (int i = 0, end = permutation.size(); i < end; i++) {
    tensor.flat<int>()(i) = permutation[i];
  }
  tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
  node.mutable_attr()->insert({"value", attr_tensor});

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::CreateTransposeNode(
    TransposeContext* context, absl::string_view name_format,
    const DataType& data_type, absl::string_view device,
    TensorShapeProto fanin_shape, absl::Span<const int> permutation,
    absl::string_view control_node_name, utils::MutationNewNode* added_node,
    string* transpose_node_name) {
  const string node_name = absl::Substitute(name_format, kOpTranspose);
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));
  *transpose_node_name = node_name;

  NodeDef node;
  node.set_name(node_name);
  node.set_op(kOpTranspose);
  node.set_device(string(device));

  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  AttrValue attr_data_type_perm;
  attr_data_type_perm.set_type(DT_INT32);
  node.mutable_attr()->insert({"Tperm", attr_data_type_perm});

  if (!fanin_shape.unknown_rank()) {
    TF_RETURN_IF_ERROR(
        PermuteSingle(absl::StrCat("fanin shape in ", node.name()), permutation,
                      fanin_shape.mutable_dim()));
    AttrValue attr_output_shape;
    *attr_output_shape.mutable_list()->add_shape() = fanin_shape;
    node.mutable_attr()->insert({kAttrOutputShape, attr_output_shape});
  }

  // Create the permutation Const node.
  utils::MutationNewNode const_perm_added_node;
  const string const_perm_node_name =
      absl::Substitute(name_format, "PermConst");
  TF_RETURN_IF_ERROR(CreateConstPermNode(context, const_perm_node_name, device,
                                         permutation, control_node_name,
                                         &const_perm_added_node));
  // Add placeholder for 1st input.
  node.add_input("");
  // Connect const_perm_node to 2nd input of transpose_node.
  node.add_input(const_perm_node_name);

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context,
                                          absl::Span<const int> dst_ports,
                                          utils::MutableNodeView* dst_node,
                                          absl::string_view op) {
  const bool is_in_frame = context->frames.IsInFrame(*dst_node->node());
  for (int dst_port : dst_ports) {
    auto& fanin_port = dst_node->GetRegularFanin(dst_port);
    auto* fanin_node_view = fanin_port.node_view();

    TF_RETURN_IF_ERROR(
        UpdateEdge(context,
                   GetFaninNameFormat(dst_node->GetName(), dst_port,
                                      context->src_format, context->dst_format),
                   op, /*input_shape=*/nullptr, /*is_in_frame=*/is_in_frame,
                   /*is_src_format_to_dst_format=*/true, fanin_port.index(),
                   dst_port, fanin_node_view, dst_node));
  }
  return Status::OK();
}

Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context,
                                           absl::Span<const int> src_ports,
                                           utils::MutableNodeView* src_node,
                                           absl::string_view op) {
  // Update attr _output_shapes for output ports.
  const auto* output_shape_attr = src_node->GetAttr(kAttrOutputShape);
  AttrValue shape_attr_copy;
  if (op == kOpTranspose && output_shape_attr != nullptr) {
    shape_attr_copy = *output_shape_attr;
    for (int port : src_ports) {
      auto* shape = shape_attr_copy.mutable_list()->mutable_shape(port);
      if (shape->unknown_rank()) continue;
      TF_RETURN_IF_ERROR(
          PermuteSingle(absl::StrCat("output shape attribute at port ", port,
                                     " in ", src_node->GetName()),
                        context->src_to_dst, shape->mutable_dim()));
    }
    context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
        src_node, kAttrOutputShape, shape_attr_copy);
  }

  const bool is_in_frame = context->frames.IsInFrame(*src_node->node());
  // We might modify the fanout set in the loop below, so make a copy first.
  // Sort the copied fanouts by node name (and fanin index) so that the
  // generated transposer node names are deterministic.
  for (int src_port : src_ports) {
    const auto& fanouts_src_port = src_node->GetRegularFanout(src_port);
    std::vector<utils::MutableFaninView> sorted_fanouts(
        fanouts_src_port.begin(), fanouts_src_port.end());
    std::sort(sorted_fanouts.begin(), sorted_fanouts.end(),
              ComparatorByNodeNameAndIndex());
    int num_downstream_transposers = 0;
    for (const auto& fanout : sorted_fanouts) {
      TF_RETURN_IF_ERROR(UpdateEdge(
          context,
          GetFanoutNameFormat(src_node->GetName(), src_port,
                              num_downstream_transposers++, context->src_format,
                              context->dst_format),
          op, &shape_attr_copy, /*is_in_frame=*/is_in_frame,
          /*is_src_format_to_dst_format=*/false, src_port, fanout.index(),
          src_node, fanout.node_view()));
    }
  }
  return Status::OK();
}

Status Transposer::CreateDataFormatNode(
    TransposeContext* context, absl::string_view node_name,
    absl::string_view op, absl::string_view device, const DataType& data_type,
    bool is_fanin_on_host, bool is_src_format_to_dst_format,
    utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  // Create the node.
  NodeDef node;
  node.set_name(string(node_name));

  // Set up parameters of node.
  node.set_op(string(op));
  node.set_device(string(device));
  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  // The inputs of a DataFormat op could be in host memory for ops such as
  // Reshape. In such cases, run the kernel on the host too.
  if (is_fanin_on_host) {
    AttrValue attr_kernel;
    attr_kernel.set_s("host");
    node.mutable_attr()->insert({"_kernel", attr_kernel});
  }

  AttrValue src_format;
  src_format.set_s(is_src_format_to_dst_format ? context->src_format
                                               : context->dst_format);
  node.mutable_attr()->insert({kAttrSrcFormat, src_format});
  AttrValue dst_format;
  dst_format.set_s(is_src_format_to_dst_format ? context->dst_format
                                               : context->src_format);
  node.mutable_attr()->insert({kAttrDstFormat, dst_format});

  // Add placeholder for 1st input field.
  node.add_input("");

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::UpdateEdge(
    TransposeContext* context, absl::string_view name_format,
    absl::string_view op, const AttrValue* input_shape, bool is_in_frame,
    bool is_src_format_to_dst_format, const int src_port, const int dst_port,
    utils::MutableNodeView* src_node, utils::MutableNodeView* dst_node) {
  DCHECK(src_node != nullptr);
  DCHECK(dst_node != nullptr);
  auto* src_node_def = src_node->node();
  auto* dst_node_def = dst_node->node();

  // TODO(lyandy): Minimize device parsing/fetching.
  const string device = GetDeviceName(
      context->virtual_placer.get(),
      is_src_format_to_dst_format ? *dst_node_def : *src_node_def);
  DataType data_type =
      is_src_format_to_dst_format
          ? context->graph_properties
                ->GetInputProperties(dst_node->GetName())[dst_port]
                .dtype()
          : context->graph_properties
                ->GetOutputProperties(src_node->GetName())[src_port]
                .dtype();

  utils::MutationNewNode added_node;
  string added_node_name;
  if (op == kOpTranspose) {
    TensorShapeProto input_shape_proto;
    input_shape_proto.set_unknown_rank(true);
    if (input_shape != nullptr) {
      input_shape_proto = input_shape->list().shape(src_port);
    } else {
      const auto* src_node_shape_attr = src_node->GetAttr(kAttrOutputShape);
      if (src_node_shape_attr != nullptr) {
        input_shape_proto = src_node_shape_attr->list().shape(src_port);
      }
    }
    const string control_node_name =
        is_in_frame ? AsControlDependency(src_node_def->name()) : "";
    const std::vector<int>& permutation =
        is_src_format_to_dst_format ? context->src_to_dst : context->dst_to_src;
    TF_RETURN_IF_ERROR(CreateTransposeNode(
        context, name_format, data_type, device, input_shape_proto, permutation,
        control_node_name, &added_node, &added_node_name));
  } else if (op == kOpDataFormatVecPermute || op == kOpDataFormatDimMap) {
    DeviceNameUtils::ParsedName parsed_name;
    bool is_fanin_on_host =
        DeviceNameUtils::ParseFullName(
            GetDeviceName(context->virtual_placer.get(), *src_node_def),
            &parsed_name) &&
        parsed_name.type != "CPU" && IsHostMemory(*src_node_def, src_port);
    const string node_name = absl::Substitute(name_format, op);
    TF_RETURN_IF_ERROR(CreateDataFormatNode(
        context, node_name, op, device, data_type, is_fanin_on_host,
        is_src_format_to_dst_format, &added_node));
    added_node_name = node_name;
  } else {
    return Status(error::INVALID_ARGUMENT,
                  absl::StrCat("Unsupported op \"", op,
                               "\". Supported ops are Transpose, "
                               "DataFormatVecPermute, DataFormatDimMap."));
  }

  // Connect src_node to 1st input of added_node.
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  mutation->AddOrUpdateRegularFanin(added_node, 0,
                                    {src_node->GetName(), src_port});

  // Connect output of added_node to dst_node:dst_port.
  mutation->AddOrUpdateRegularFanin(dst_node, dst_port, {added_node_name, 0});

  return Status::OK();
}
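// Net effect of UpdateEdge (sketch): an existing edge
//   src_node:src_port -> dst_node:dst_port
// becomes
//   src_node:src_port -> added_node:0 -> dst_node:dst_port,
// where added_node is the Transpose or DataFormat* node created above.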

int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
                                  int port) const {
  const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr ||
      output_shape_attr->list().shape_size() <= port) {
    return kInvalidRank;
  }
  const auto& shape = output_shape_attr->list().shape(port);
  if (shape.unknown_rank()) {
    return kUnknownRank;
  }
  return shape.dim_size();
}

bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
                                   int n) const {
  return GetFanoutPortRank(node, port) == n;
}

bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
                                    absl::Span<const int> ports, int n) const {
  for (const auto& port : ports) {
    if (!IsFanoutPortRankN(node, port, n)) {
      return false;
    }
  }
  return true;
}

int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
                                 int port) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
  }
  return kInvalidRank;
}

bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
                                  int n) const {
  return GetFaninPortRank(node, port) == n;
}

bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
                                         int port,
                                         absl::Span<const int> dims) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    const auto* fanin_node_view = regular_fanin.node_view();
    if (!IsConstant(*fanin_node_view->node())) {
      return true;
    }
    // If fanin is a Const, check tensor to see if dimensions match.
    const auto* value_attr = fanin_node_view->GetAttr(kAttrValue);
    if (value_attr == nullptr) {
      return false;
    }
    Tensor tensor;
    if (!tensor.FromProto(value_attr->tensor())) {
      return false;
    }
    const int dims_size = dims.size();
    if (tensor.dims() != dims_size) {
      return false;
    }
    for (int i = 0; i < dims_size; ++i) {
      if (tensor.dim_size(i) != dims[i]) {
        return false;
      }
    }
    return true;
  }
  return false;
}

bool Transposer::IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
                                          absl::Span<const int> ports,
                                          absl::Span<const int> dims) const {
  for (const auto& port : ports) {
    if (!IsFaninPortDimsNIfConst(node, port, dims)) {
      return false;
    }
  }
  return true;
}

bool Transposer::CanProcessNode(const TransposeContext& context,
                                const utils::MutableNodeView& node) const {
  return !context.nodes_to_preserve.contains(node.GetName()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

string Transposer::GetFaninNameFormat(absl::string_view node_name, int port,
                                      absl::string_view src_format,
                                      absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-$0", src_format, "To", dst_format,
                      "-", kOptimizedSuffix);
}
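// E.g. GetFaninNameFormat("conv2d", 0, "NHWC", "NCHW") produces
// "conv2d-0-$0NHWCToNCHW-LayoutOptimizer"; "$0" is later substituted with the
// op name, e.g. "Transpose" in CreateTransposeNode.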

string Transposer::GetFanoutNameFormat(absl::string_view node_name, int port,
                                       int index, absl::string_view src_format,
                                       absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-", index, "-$0", dst_format, "To",
                      src_format, "-", kOptimizedSuffix);
}

string Transposer::LayoutOptimizerNode(absl::string_view node_name) {
  return absl::StrCat(node_name, "-", kOptimizedSuffix);
}

string Transposer::GetReshapeNodeNameFormat(absl::string_view node_name,
                                            int index,
                                            absl::string_view src_format,
                                            absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", index, "-", kReshape, src_format, "To",
                      dst_format);
}

string Transposer::GetShapeConstNodeNameFormat(absl::string_view node_name,
                                               int index) {
  return absl::StrCat(node_name, "-", index, "-", kReshapeConst);
}

// Layout sensitive transposer.

inline string GetLayoutSensitiveNodeDataFormat(
    const utils::MutableNodeView& node) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s();
  }
  return "";
}

Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
                                               utils::MutableNodeView* node) {
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  AttrValue data_format_attr;
  data_format_attr.set_s(context->dst_format);
  mutation->AddOrUpdateNodeAttr(node, kAttrDataFormat, data_format_attr);

  auto permute_attr = [&context, &node,
                       &mutation](absl::string_view attr_name) {
    const auto* attr = node->GetAttr(attr_name);
    if (attr != nullptr) {
      AttrValue attr_copy(*attr);
      TF_RETURN_IF_ERROR(PermuteSingle(
          absl::StrCat(attr_name, " attribute in ", node->GetName()),
          context->src_to_dst, attr_copy.mutable_list()->mutable_i()));
      mutation->AddOrUpdateNodeAttr(node, attr_name, attr_copy);
    }
    return Status::OK();
  };

  // Update attrs.
  TF_RETURN_IF_ERROR(permute_attr(kAttrStrides));
  TF_RETURN_IF_ERROR(permute_attr(kAttrKSize));
  TF_RETURN_IF_ERROR(permute_attr(kAttrDilations));

  const auto* explicit_paddings_attr = node->GetAttr(kAttrExplicitPaddings);
  if (explicit_paddings_attr != nullptr && explicit_paddings_attr->has_list() &&
      explicit_paddings_attr->list().i_size() > 0) {
    AttrValue explicit_paddings_attr_copy(*explicit_paddings_attr);
    TF_RETURN_IF_ERROR(PermuteDouble(
        absl::StrCat("explicit_paddings attribute in ", node->GetName()),
        context->src_to_dst,
        explicit_paddings_attr_copy.mutable_list()->mutable_i()));
    mutation->AddOrUpdateNodeAttr(node, kAttrExplicitPaddings,
                                  explicit_paddings_attr_copy);
  }

  return Status::OK();
}

Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsAvgPoolGrad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 1, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  // This TransposeNode handles BiasAdd but not BiasAddV1, since only BiasAdd
  // supports the data_format attribute.
  DCHECK(IsBiasAddV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAdd itself only needs NCHW/NHWC to determine whether the C dim is the
  // second or the last dim. Therefore, we use the original 4D data format in
  // the context to update the node. For the input/output tensor, the
  // corresponding 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsBiasAddGrad(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAddGrad itself only needs NCHW/NHWC to determine whether the C dim is
  // the second or the last dim. Therefore, we use the original 4D data format
  // in the context to update the node. For the input tensor, the corresponding
  // 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  // No need to update output shape, as it is always of shape 1-D with size the
  // feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is
  // used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropFilter(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropFilter(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update output shape, as it is always of shape
  // [filter_height, filter_width, in_channels, out_channels], regardless of
  // whether NCHW or NHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropInput(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropInput(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }

  const auto& fanin = node->GetRegularFanin(0);
  auto* fanin_node = fanin.node_view();
  const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr) {
    VLOG(3) << "Cannot compute the shape of " << fanin_node->GetName()
            << " because it is missing attribute " << kAttrOutputShape;
    return Status::OK();
  }
  TensorShapeProto fanin_shape = output_shape_attr->list().shape(fanin.index());
  if (fanin_shape.dim_size() != 1) {
    VLOG(3) << fanin_node->GetName() << " is not a vector.";
    return Status::OK();
  }

  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsConv3D(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropFilterV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update output shape, as it is always of shape
  // [filter_depth, filter_height, filter_width, in_channels, out_channels],
  // regardless of whether NCDHW or NDHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropInputV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status FusedBatchNormExTransposer::TransposeNode(TransposeContext* context,
                                                 utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormEx(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  if (node->NumRegularFanins() == 6) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0, 5}, node, kOpTranspose));
  } else {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool FusedBatchNormGradTransposer::IsTraining(
    const utils::MutableNodeView& node) const {
  const auto* is_training_attr = node.GetAttr(kAttrIsTraining);
  if (is_training_attr != nullptr) {
    return is_training_attr->b();
  }
  return false;
}

Status FusedBatchNormGradTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolV2(*node->node()));
  // We check data_input's shape instead, because the shape inference of
  // MaxPoolV2 is not able to infer the shape when ksize or strides is not
  // constant.
  const auto& data_fanin = node->GetRegularFanin(0);
  auto* data_fanin_node = data_fanin.node_view();
  if (!ShouldProcess(*context, *node) ||
      !IsFanoutPortRankN(*data_fanin_node, data_fanin.index(), 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGrad(*node->node()) || IsMaxPoolGradGradV1(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradV2Transposer::TransposeNode(TransposeContext* context,
                                              utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGradV2(*node->node()) || IsMaxPoolGradGradV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {3, 4}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Layout agnostic transposer.

inline bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node,
                                          absl::Span<const int> permutation) {
  Tensor tensor;
  if (!GetValueAttrFromConstInputNode(node, IsTranspose, 1, &tensor)) {
    return false;
  }
  const int permutation_size = permutation.size();
  if (tensor.NumElements() != permutation_size) {
    return false;
  }

  const auto& tensor_data = tensor.unaligned_flat<int32>();
  for (int i = 0; i < permutation_size; i++) {
    if (permutation[i] != tensor_data(i)) {
      return false;
    }
  }
  return true;
}

inline bool IsValidDataFormatNode(const utils::MutableNodeView& node,
                                  absl::string_view src_format,
                                  absl::string_view dst_format) {
  if (!IsDataFormatOp(node)) {
    return false;
  }
  const auto* src_format_attr = node.GetAttr(kAttrSrcFormat);
  if (src_format_attr == nullptr || src_format_attr->s() != src_format) {
    return false;
  }
  const auto* dst_format_attr = node.GetAttr(kAttrDstFormat);
  if (dst_format_attr == nullptr || dst_format_attr->s() != dst_format) {
    return false;
  }
  return true;
}

inline bool IsLayoutOptimizerAddedDstToSrcTranspose(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         IsValidConstPermTransposeNode(node, context.dst_to_src);
}

inline bool IsLayoutOptimizerAddedDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         (IsValidConstPermTransposeNode(node, context.dst_to_src) ||
          IsValidDataFormatNode(node, context.dst_format, context.src_format));
}

bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  std::deque<utils::MutableNodeView*> queue;
  absl::flat_hash_set<utils::MutableNodeView*> visited_nodes;
  auto data_node_pos = GetDataFaninPorts(node);
  for (const int pos : data_node_pos) {
    const auto& fanin = node.GetRegularFanin(pos);
    auto* fanin_node = fanin.node_view();
    queue.push_back(fanin_node);
    visited_nodes.insert(fanin_node);
  }
  // The code will exit this while loop in one iteration in most cases, as the
  // graph is already topologically sorted.
  while (!queue.empty()) {
    utils::MutableNodeView* current_node = queue.front();
    queue.pop_front();
    if (IsLayoutOptimizerAddedDstToSrcTransform(context, *current_node)) {
      return true;
    }
    // We only continue searching if the path is connected through
    // format-agnostic nodes.
    if (IsLayoutAgnosticOp(*current_node->node())) {
      auto current_node_pos = GetDataFaninPorts(*current_node);
      for (const auto& pos : current_node_pos) {
        const auto& fanin = current_node->GetRegularFanin(pos);
        auto* fanin_node = fanin.node_view();
        if (visited_nodes.insert(fanin_node).second) {
          queue.push_back(fanin_node);
        }
      }
    }
  }
  return false;
}

std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
    const TransposeContext& context, const utils::MutableNodeView& node,
    int rank) const {
  std::vector<int> ports;
  const int num_regular_fanins = node.NumRegularFanins();
  ports.reserve(num_regular_fanins);
  for (int i = 0; i < num_regular_fanins; ++i) {
    const auto& regular_fanin = node.GetRegularFanin(i);
    auto* regular_fanin_node = regular_fanin.node_view();
    int regular_fanin_port = regular_fanin.index();
    if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      ports.push_back(i);
    }
  }
  return ports;
}

Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AddNTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsAddN(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
                                           int n, int m) {
  return IsFaninPortRankN(node, 0, n) && IsFaninPortRankN(node, 1, m);
}

bool BinaryOpTransposer::IsFaninShapeSupported(
    const utils::MutableNodeView& node, int rank) {
  return (IsNDOperateWithMD(node, rank, 0) ||
          IsNDOperateWithMD(node, rank, 1) ||
          IsNDOperateWithMD(node, rank, rank) ||
          IsNDOperateWithMD(node, 0, rank) || IsNDOperateWithMD(node, 1, rank));
}
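// E.g. for rank = 4, the supported operand-rank pairs are (4,0), (4,1),
// (4,4), (0,4) and (1,4): a 4-D tensor combined with a scalar, a vector, or
// another 4-D tensor.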

std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
    const utils::MutableNodeView& node, int rank) {
  std::vector<int> values;
  if (IsFaninPortRankN(node, 0, rank)) {
    values.push_back(0);
  }
  if (IsFaninPortRankN(node, 1, rank)) {
    values.push_back(1);
  }
  return values;
}

Status BinaryOpTransposer::AddNodeReshape(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, absl::string_view input_name,
    absl::string_view shape_const_node_name, const DataType& data_type) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.add_input(string(input_name));
  new_node.add_input(string(shape_const_node_name));
  new_node.set_op(kReshape);
  new_node.set_device(string(node_device));

  AttrValue attr_type_indices;
  attr_type_indices.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"Tshape", attr_type_indices});

  AttrValue attr_type_params;
  attr_type_params.set_type(data_type);
  new_node.mutable_attr()->insert({"T", attr_type_params});

  Status status;
  mutation->AddNode(std::move(new_node), &status);
  return status;
}

Status BinaryOpTransposer::AddNodeShapeConst(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, bool node_in_frame, int num_channels,
    absl::string_view depended_node, int rank) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.set_op(kOpConst);
  new_node.set_device(string(node_device));
  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32, TensorShape({rank}));
  std::vector<int> shape(rank, 1);
  shape[1] = num_channels;
  for (int i = 0; i < static_cast<int>(shape.size()); i++) {
    tensor.flat<int>()(i) = shape[i];
  }
  tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
  new_node.mutable_attr()->insert({"value", attr_tensor});
  if (node_in_frame) {
    // This is to ensure the transpose node and the const node are in the same
    // frame.
    // TODO(halehri): Add Test that exercises this condition.
    new_node.add_input(AsControlDependency(string(depended_node)));
  }

  Status status;
  mutation->AddNode(std::move(new_node), &status);
  return status;
}
1277
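// If exactly one operand of the binary op is a vector, it broadcasts along
// the last (channel) dimension in a channels-last source format, but would
// broadcast incorrectly once the other operand is transposed to a
// channels-first destination format. Insert a Reshape to [1, C, 1, ...] so
// the broadcast stays on the channel dimension.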
Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
                                                   utils::MutableNodeView* node,
                                                   int rank) {
  int vector_index = -1;
  if (IsNDOperateWithMD(*node, rank, 1)) {
    vector_index = 1;
  } else if (IsNDOperateWithMD(*node, 1, rank)) {
    vector_index = 0;
  }
  if (vector_index != -1) {
    const string& node_name = node->GetName();
    const string& node_device = node->GetDevice();
    string reshape_node_name = LayoutOptimizerNode(GetReshapeNodeNameFormat(
        node_name, vector_index, context->src_format, context->dst_format));
    string shape_const_node_name = LayoutOptimizerNode(
        GetShapeConstNodeNameFormat(node_name, vector_index));
    const auto& fanin = node->GetRegularFanin(vector_index);
    auto* fanin_node = fanin.node_view();
    const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
    if (output_shape_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrOutputShape);
    }
    int vector_size =
        output_shape_attr->list().shape(fanin.index()).dim(0).size();
    utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
    TF_RETURN_IF_ERROR(
        AddNodeShapeConst(mutation, shape_const_node_name, node_device,
                          context->frames.IsInFrame(*node->node()), vector_size,
                          fanin_node->GetName(), rank));
    const auto* t_attr = node->GetAttr(kAttrT);
    if (t_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrT);
    }
    TF_RETURN_IF_ERROR(
        AddNodeReshape(mutation, reshape_node_name, node_device,
                       TensorIdToString({fanin_node->GetName(), fanin.index()}),
                       shape_const_node_name, t_attr->type()));
    mutation->AddOrUpdateRegularFanin(node, vector_index,
                                      {reshape_node_name, 0});
  }
  return Status::OK();
}

Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsBinaryOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetNDDataFaninPorts(*node, rank), node, kOpTranspose));
  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsConcat(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetConcatDataFaninPorts(*node), node, kOpTranspose));
  int axis_node = 0;
  if (node->GetOp() == "ConcatV2") {
    const auto* n_attr = node->GetAttr(kAttrN);
    if (n_attr != nullptr) {
      axis_node = n_attr->i();
    }
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status FillOpTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsFill(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 0, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status IdentityNTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsIdentityN(*node->node()));
  const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);

  // Temporarily upgrade the context to obtain the number of 5D fanin ports.
  std::vector<int> ports_5d;
  {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
  }

  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }

  if (!ports_4d.empty()) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
  }

  if (!ports_5d.empty()) {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  for (const auto& regular_fanin : node.GetRegularFanins()) {
    auto* regular_fanin_node = regular_fanin.node_view();
    if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 4) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      continue;
    }
    return false;
  }
  return true;
}

Status MergeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsMerge(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsEveryFaninAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status PadTransposer::TransposeNode(TransposeContext* context,
                                    utils::MutableNodeView* node) {
  DCHECK(IsMirrorPad(*node->node()) || IsMirrorPadGrad(*node->node()) ||
         IsPad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4, 2}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool ReduceTransposer::KeepDims(const utils::MutableNodeView& node) {
  const auto* keep_dims_attr = node.GetAttr(kAttrKeepDims);
  if (keep_dims_attr != nullptr) {
    return keep_dims_attr->b();
  }
  return false;
}

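// Returns true if the 1-D `tensor` of reduction indices has the same length
// as `axis` and every entry, after normalizing negative values by adding
// `rank`, is a member of `axis`.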
bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
                                   absl::Span<const int> axis, int rank) {
  const int axis_size = axis.size();
  if (tensor.dims() != 1 || tensor.dim_size(0) != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = 0;
    if (tensor.dtype() == DT_INT32) {
      local_axis = tensor.flat<int32>()(i);
    } else {
      local_axis = tensor.flat<int64>()(i);
    }
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

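// A reduction is safe to transpose when keep_dims is set, since the output
// then keeps the full rank and can itself be transposed, or when the constant
// reduction axes form one of the dimension sets below, for which the output
// is the same in either layout.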
bool ReduceTransposer::IsReduceAxisSupported(const TransposeContext& context,
                                             const utils::MutableNodeView& node,
                                             int rank) {
  if (KeepDims(node)) {
    return true;
  }
  const auto& regular_fanin_1 = node.GetRegularFanin(1);
  auto* axis_node = regular_fanin_1.node_view();
  if (!IsConstant(*axis_node->node())) {
    return false;
  }
  const auto* value_attr = axis_node->GetAttr(kAttrValue);
  if (value_attr == nullptr) {
    return false;
  }
  Tensor tensor;
  if (!tensor.FromProto(value_attr->tensor())) {
    LOG(ERROR) << "Failed to parse TensorProto.";
    return false;
  }
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  if (rank == 5) {
    return IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'C'}), 5);
  }
  DCHECK_EQ(rank, 4);
  return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'N', 'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'C'}), 4);
}

Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsReduceOp(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsReduceAxisSupported(*context, *node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  if (KeepDims(*node)) {
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ReverseV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsReverseV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

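// Returns true if the given output port of `fanin` produces a scalar, a
// vector, or a 4-D tensor; these are the predicate shapes that Select
// accepts for its condition input.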
bool SelectTransposer::IsFaninScalarVector4D(
    const utils::MutableNodeView& fanin, int port) {
  return IsFanoutPortRankN(fanin, port, 0) ||
         IsFanoutPortRankN(fanin, port, 1) || IsFanoutPortRankN(fanin, port, 4);
}

std::vector<int> SelectTransposer::GetFaninPorts(
    const utils::MutableNodeView& fanin, int port) {
  // Input 0 may be a scalar, a vector whose size matches the first dimension
  // of inputs 1 and 2, or a tensor with the same shape as inputs 1 and 2.
  // Only transpose it in the last case, where it is 4-D.
  if (IsFanoutPortRankN(fanin, port, 4)) {
    return {0, 1, 2};
  }
  return {1, 2};
}

Status SelectTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSelect(*node->node()));
  const auto& regular_fanin_0 = node->GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninScalarVector4D(*regular_fanin_0_node, regular_fanin_0.index()) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetFaninPorts(*regular_fanin_0_node, regular_fanin_0.index()),
      node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ShapeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsShape(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ShapeNTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsShapeN(*node->node()));
  // ShapeN requires all input tensors to have the same dimensions. Therefore,
  // we simply use the 0th fanin port.
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
  if (!ShouldProcess(*context, *node) || ports.empty()) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SliceTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSlice(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SplitTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSplit(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SplitVTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSplitV(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {2}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

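// The Squeeze input is convertible if it is 4-D and its H and W dimensions
// (in the source format) both have size 1, i.e. squeezing removes only
// spatial dimensions.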
bool SqueezeTransposer::IsInputConvertible(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  const auto& regular_fanin_0 = node.GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  const auto* output_shape_attr =
      regular_fanin_0_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr != nullptr) {
    auto& shape = output_shape_attr->list().shape(regular_fanin_0.index());
    if (shape.dim_size() != kRank) {
      return false;
    }
    const int height_dim = context.src_dim_indices.at('H');
    const int width_dim = context.src_dim_indices.at('W');
    if (shape.dim(height_dim).size() == 1 && shape.dim(width_dim).size() == 1) {
      return true;
    }
  }
  return false;
}

bool SqueezeTransposer::IsAlongAxis(const AttrValue& attr,
                                    absl::Span<const int> axis,
                                    int rank) const {
  const auto& list = attr.list();
  // If the list is empty, the Squeeze op will squeeze all dimensions of
  // size 1.
  int axis_size = axis.size();
  if (list.i_size() == 0) {
    return true;
  } else if (list.i_size() != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = list.i(i);
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

bool SqueezeTransposer::IsDimsSupported(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  const auto* squeeze_dims_attr = node.GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return false;
  }
  return (IsFanoutPortRankN(node, 0, 2) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'H', 'W'}), kRank)) ||
         (IsFanoutPortRankN(node, 0, 1) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'N', 'H', 'W'}), kRank));
}

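// Remaps the squeeze_dims attribute from the source to the destination
// format; e.g. squeezing H and W, [1, 2] in NHWC, becomes [2, 3] in NCHW.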
Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  const auto* squeeze_dims_attr = node->GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return errors::InvalidArgument("Missing attribute ", kAttrSqueezeDims);
  }
  const int num_input_dims = context->src_format.length();
  const int min_squeeze_dim = -num_input_dims;
  std::vector<int> squeeze_dims_mapped;
  const int squeeze_dims_size = squeeze_dims_attr->list().i_size();
  squeeze_dims_mapped.reserve(squeeze_dims_size);
  for (int i = 0; i < squeeze_dims_size; ++i) {
    int dim = squeeze_dims_attr->list().i(i);
    if (dim < min_squeeze_dim || dim >= num_input_dims) {
      return errors::InvalidArgument(
          "Attribute '", kAttrSqueezeDims, "' contains out of range index '",
          dim, "', index must be between [", min_squeeze_dim, ", ",
          num_input_dims, ")");
    }
    if (dim < 0) {
      dim += num_input_dims;
    }
    squeeze_dims_mapped.push_back(context->dst_to_src[dim]);
  }
  std::sort(squeeze_dims_mapped.begin(), squeeze_dims_mapped.end());
  AttrValue squeeze_dims;
  squeeze_dims.mutable_list()->mutable_i()->Reserve(squeeze_dims_size);
  for (const auto& dim : squeeze_dims_mapped) {
    squeeze_dims.mutable_list()->mutable_i()->Add(dim);
  }
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
      node, kAttrSqueezeDims, squeeze_dims);
  return Status::OK();
}

Status SqueezeTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  DCHECK(IsSqueeze(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsDimsSupported(*context, *node) ||
      !IsInputConvertible(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateSqueezeDims(context, node));
  return context->graph_view->GetMutationBuilder()->Apply();
}

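// StridedSlice is only rewritten when ellipsis_mask, new_axis_mask, and
// shrink_axis_mask are all zero. begin_mask and end_mask can be permuted by
// PermuteMask below, while the other masks would change how slice indices
// map to dimensions.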
bool StridedSliceTransposer::IsMaskZero(const utils::MutableNodeView& node,
                                        absl::string_view mask) {
  const auto* mask_attr = node.GetAttr(mask);
  if (mask_attr != nullptr) {
    return mask_attr->i() == 0;
  }
  return true;
}

bool StridedSliceTransposer::HasOnlyBeginEndMask(
    const utils::MutableNodeView& node) {
  return IsMaskZero(node, "ellipsis_mask") &&
         IsMaskZero(node, "new_axis_mask") &&
         IsMaskZero(node, "shrink_axis_mask");
}

Status StridedSliceTransposer::PermuteMask(TransposeContext* context,
                                           utils::MutableNodeView* node,
                                           absl::string_view mask) {
  // Computes the permutation of the mask based on the src and dst formats.
  // For example:
  //   src_format = NHWC
  //   dst_format = NCHW
  //   src_to_dst permutation = [0, 3, 1, 2]
  //   mask   : 0010 (note that bit positions correspond to dimension indices,
  //            so the bits read in reverse order of the src format, i.e. CWHN)
  //   result : 0100 (WHCN)
  const auto* mask_attr = node->GetAttr(mask);
  const int mask_i = mask_attr != nullptr ? mask_attr->i() : 0;
  if (mask_i < 0 || mask_i > 15) {
    return errors::InvalidArgument("invalid mask value: ", mask_i);
  }
  int result = 0;
  for (int i = 0, end = context->src_to_dst.size(); i < end; i++) {
    const int final_pos = context->src_to_dst[i];
    const int position_mask = 1 << final_pos;
    const int bit_i = (mask_i & position_mask) >> final_pos;
    result |= bit_i << i;
  }
  AttrValue new_mask_attr;
  new_mask_attr.set_i(result);
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(node, mask,
                                                                 new_mask_attr);
  return Status::OK();
}

Status StridedSliceTransposer::TransposeNode(TransposeContext* context,
                                             utils::MutableNodeView* node) {
  DCHECK(IsStridedSlice(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2, 3}, {4}) ||
      !HasOnlyBeginEndMask(*node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "begin_mask"));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "end_mask"));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node,
                                            kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SwitchTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSwitch(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, GetDataFanoutPorts(*node),
                                             node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TernaryOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsTernaryOp(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TileTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsTile(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status UnaryGradTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsUnaryGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Utils.

string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) {
  return (node.device().empty() && virtual_placer != nullptr)
             ? virtual_placer->get_canonical_device_name(node)
             : node.device();
}

bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* default_layout_sensitive_ops =
      new absl::flat_hash_set<string>(
          {"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
           "FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
           "FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
  return default_layout_sensitive_ops->find(node.op()) !=
         default_layout_sensitive_ops->end();
}

bool IsLayoutSensitiveOp(const NodeDef& node) {
  return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
         IsBiasAddV2(node) || IsBiasAddGrad(node) ||
         IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
         IsDepthwiseConv2dNativeBackpropFilter(node) ||
         IsDepthwiseConv2dNativeBackpropInput(node) ||
         IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||
         IsMaxPoolV2(node) || IsMaxPoolGrad(node) || IsMaxPoolGradV2(node) ||
         IsMaxPoolGradGradV1(node) || IsMaxPoolGradGradV2(node) ||
         IsConv3D(node) || IsConv3DBackpropInputV2(node) ||
         IsConv3DBackpropFilterV2(node);
}

bool IsDefaultLayoutAgnosticOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* agnostic_nodes =
      new absl::flat_hash_set<string>({"Abs",
                                       "Acos",
                                       "Acosh",
                                       "Angle",
                                       "Asin",
                                       "Asinh",
                                       "Atan",
                                       "Atanh",
                                       "Bitcast",
                                       "Cast",
                                       "Ceil",
                                       "CheckNumerics",
                                       "ComplexAbs",
                                       "Conj",
                                       "Cos",
                                       "Cosh",
                                       "Digamma",
                                       "Elu",
                                       "Enter",
                                       "Erf",
                                       "Erfc",
                                       "Exit",
                                       "Exp",
                                       "Expm1",
                                       "FakeQuantWithMinMaxVars",
                                       "FakeQuantWithMinMaxArgs",
                                       "Floor",
                                       "GuaranteeConst",
                                       "Identity",
                                       "Imag",
                                       "Inv",
                                       "IsFinite",
                                       "IsInf",
                                       "IsNan",
                                       "LeakyRelu",
                                       "Lgamma",
                                       "Log",
                                       "LogicalNot",
                                       "Log1p",
                                       "Neg",
                                       "NextIteration",
                                       "OnesLike",
                                       "PreventGradient",
                                       "QuantizeAndDequantizeV2",
                                       "QuantizeAndDequantizeV3",
                                       "Real",
                                       "Reciprocal",
                                       "Relu",
                                       "Relu6",
                                       "Rint",
                                       "Selu",
                                       "Sigmoid",
                                       "Sign",
                                       "Sin",
                                       "Sinh",
                                       "Snapshot",
                                       "Softplus",
                                       "Round",
                                       "Rsqrt",
                                       "Sqrt",
                                       "Square",
                                       "StopGradient",
                                       "Tan",
                                       "Tanh",
                                       "ZerosLike"});
  return agnostic_nodes->find(node.op()) != agnostic_nodes->end();
}

bool IsLayoutAgnosticOp(const NodeDef& node) {
  return IsDefaultLayoutAgnosticOp(node) || IsAddN(node) || IsBinaryOp(node) ||
         IsIdentityN(node) || IsMerge(node) || IsMirrorPad(node) ||
         IsMirrorPadGrad(node) || IsPad(node) || IsSelect(node) ||
         IsSwitch(node) || IsTernaryOp(node) || IsUnaryGrad(node) ||
         IsConcat(node) || IsReverseV2(node) || IsTile(node) || IsShape(node) ||
         IsShapeN(node) || IsFill(node) || IsSlice(node) || IsSplit(node) ||
         IsSqueeze(node) || IsSplitV(node) || IsStridedSlice(node) ||
         IsReduceOp(node);
}

bool IsTernaryOp(const NodeDef& node) { return IsBetainc(node); }

bool IsUnaryGrad(const NodeDef& node) {
  bool is_unary_grad =
      IsEluGrad(node) || IsInvGrad(node) || IsLeakyReluGrad(node) ||
      IsReciprocalGrad(node) || IsRelu6Grad(node) || IsReluGrad(node) ||
      IsRsqrtGrad(node) || IsSeluGrad(node) || IsSigmoidGrad(node) ||
      IsSoftplusGrad(node) || IsSoftsignGrad(node) || IsSqrtGrad(node) ||
      IsTanhGrad(node);
  return is_unary_grad;
}

bool IsMaxPoolV2(const NodeDef& node) { return node.op() == "MaxPoolV2"; }

bool IsMaxPoolGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradV2";
}

bool IsMaxPoolGradGradV1(const NodeDef& node) {
  return node.op() == "MaxPoolGradGrad";
}

bool IsMaxPoolGradGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradGradV2";
}

bool IsBinaryOp(const NodeDef& node) {
  bool is_binary =
      IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
      IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
      IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
      IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
      IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
      IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
  return is_binary;
}

bool IsReduceOp(const NodeDef& node) {
  return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
         IsMin(node) || IsAll(node) || IsAny(node);
}

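// Returns the fanin ports that carry data tensors, as opposed to shape, axis,
// or size arguments; e.g. Split takes its split axis on port 0 and its data
// on port 1, while SplitV takes its data on port 0.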
std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsAvgPoolGrad(*node_def) || IsSplit(*node_def)) {
    return {1};
  }
  if (IsStridedSliceGrad(*node_def)) {
    return {4};
  }
  if (IsBinaryOp(*node_def) || IsUnaryGrad(*node_def)) {
    return {0, 1};
  }
  if (IsTernaryOp(*node_def) || IsSelect(*node_def) ||
      IsMaxPoolGrad(*node_def) || IsMaxPoolGradV2(*node_def) ||
      IsMaxPoolGradGradV1(*node_def) || IsMaxPoolGradGradV2(*node_def)) {
    return {0, 1, 2};
  }
  if (IsShapeN(*node_def) || IsIdentityN(*node_def) || IsAddN(*node_def) ||
      IsMerge(*node_def)) {
    return GetRegularFaninPorts(node);
  }
  if (IsConcat(*node_def)) {
    return GetConcatDataFaninPorts(node);
  }
  if (node.NumRegularFanins() > 0) {
    return {0};
  }
  return {};
}

std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsIdentityN(*node_def) || IsShape(*node_def) || IsShapeN(*node_def)) {
    return GetDataFaninPorts(node);
  }
  if (IsSplit(*node_def) || IsSplitV(*node_def)) {
    const auto* num_split_attr = node.GetAttr(kAttrNumSplit);
    if (num_split_attr == nullptr) {
      return {0};
    }
    std::vector<int> values(num_split_attr->i());
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  if (IsSwitch(*node_def)) {
    const auto* num_outs_attr = node.GetAttr(kAttrNumOuts);
    const int num_outs = num_outs_attr != nullptr ? num_outs_attr->i() : 2;
    std::vector<int> values(num_outs);
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  return {0};
}

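// If `node` satisfies `predicate` and its fanin at `index` is a Const node
// holding DT_INT32 data, parses that constant's value attribute into `tensor`
// and returns true; returns false otherwise.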
bool GetValueAttrFromConstInputNode(
    const utils::MutableNodeView& node,
    const std::function<bool(const NodeDef&)>& predicate, int index,
    Tensor* tensor) {
  if (!predicate(*node.node())) {
    return false;
  }
  const auto& regular_fanin = node.GetRegularFanin(index);
  auto* regular_fanin_node = regular_fanin.node_view();
  if (!IsConstant(*regular_fanin_node->node())) {
    return false;
  }
  const auto* value_attr = regular_fanin_node->GetAttr(kAttrValue);
  if (value_attr == nullptr || value_attr->tensor().dtype() != DT_INT32) {
    return false;
  }
  if (!tensor->FromProto(value_attr->tensor())) {
    return false;
  }

  return true;
}

bool IsDataFormatOp(const utils::MutableNodeView& node) {
  const string& op = node.GetOp();
  return op == kOpDataFormatDimMap || op == kOpDataFormatVecPermute;
}

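// Maps each dimension label of `data_format` to its index; e.g. "NHWC"
// yields {'N': 0, 'H': 1, 'W': 2, 'C': 3}.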
absl::flat_hash_map<char, int> GetDimensionIndices(
    absl::string_view data_format) {
  const int size = data_format.size();
  absl::flat_hash_map<char, int> index;
  index.reserve(size);
  for (int i = 0; i < size; i++) {
    index[data_format[i]] = i;
  }
  return index;
}

std::vector<int> GetPermutation(
    const absl::flat_hash_map<char, int>& src_dim_indices,
    absl::string_view dst_format) {
  // Generate permutation for transformation between src and dst format.
  // Example:
  //   src = NWHC, dst = NCWH
  //   index = { N:0 W:1 H:2 C:3 }
  //   permutation = [0, 3, 1, 2]
  DCHECK(src_dim_indices.size() == dst_format.size());
  std::vector<int> permutation;
  const int size = dst_format.size();
  permutation.reserve(size);
  for (int i = 0; i < size; i++) {
    permutation.push_back(src_dim_indices.at(dst_format[i]));
  }
  return permutation;
}

}  // namespace grappler
}  // namespace tensorflow