/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"

#include <algorithm>
#include <numeric>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr char kOptimizedSuffix[] = "LayoutOptimizer";
constexpr char kAttrKSize[] = "ksize";
constexpr char kAttrStrides[] = "strides";
constexpr char kAttrDilations[] = "dilations";
constexpr char kAttrExplicitPaddings[] = "explicit_paddings";
constexpr char kAttrDataFormat[] = "data_format";
constexpr char kAttrIsTraining[] = "is_training";
constexpr char kAttrValue[] = "value";
constexpr char kAttrN[] = "N";
constexpr char kAttrT[] = "T";
constexpr char kAttrNumSplit[] = "num_split";
constexpr char kAttrNumOuts[] = "num_outs";
constexpr char kAttrKeepDims[] = "keep_dims";
constexpr char kAttrSqueezeDims[] = "squeeze_dims";
constexpr char kOpTranspose[] = "Transpose";
constexpr char kOpDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kOpDataFormatDimMap[] = "DataFormatDimMap";
constexpr char kOpConst[] = "Const";
constexpr char kReshape[] = "Reshape";
constexpr char kReshapeConst[] = "ReshapeConst";
constexpr int kRank = 4;
constexpr int kUnknownRank = -1;
constexpr int kInvalidRank = -2;

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format,
                                bool* missing) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s() == src_data_format;
  }
  *missing = true;
  return false;
}

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format) {
  bool missing = false;
  return AttrDataFormatMatch(node, src_data_format, &missing);
}

bool IsNonFloatingConv2D(const utils::MutableNodeView& node) {
  if (IsConv2D(*node.node()) || IsConv2DBackpropInput(*node.node())) {
    const auto* attr = node.GetAttr(kAttrT);
    if (attr != nullptr) {
      return !kDataTypeIsFloating.Contains(attr->type());
    }
  }
  return false;
}

// Utils for layout agnostic transposer.

bool IsComparisonOp(const NodeDef& node) {
  bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
                    IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
                    IsLessEqual(node) || IsNotEqual(node);
  return is_compare;
}

std::vector<int> GetRegularFaninPorts(const utils::MutableNodeView& node) {
  const int num_regular_fanins = node.NumRegularFanins();
  std::vector<int> values(num_regular_fanins);
  std::iota(values.begin(), values.end(), 0);
  return values;
}

std::vector<int> GetConcatDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* n_attr = node.GetAttr(kAttrN);
  const int n = n_attr != nullptr ? n_attr->i() : 0;
  const int start = (node.GetOp() == "Concat") ? 1 : 0;
  const int end = start + n;
  std::vector<int> values(end - start);
  std::iota(values.begin(), values.end(), start);
  return values;
}
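
// Example (illustrative, derived from the logic above): "Concat" carries its
// axis as input 0 and its data tensors as inputs 1..N, while "ConcatV2" puts
// the data tensors first and the axis last. With N = 3, this returns
// {1, 2, 3} for Concat and {0, 1, 2} for ConcatV2.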

struct ComparatorByNodeNameAndIndex {
  bool operator()(const utils::MutableFaninView& node1,
                  const utils::MutableFaninView& node2) const {
    auto* node1_view = node1.node_view();
    auto* node2_view = node2.node_view();
    auto name_compare = node1_view->GetName().compare(node2_view->GetName());
    if (name_compare == 0) {
      return node1.index() < node2.index();
    }
    return name_compare < 0;
  }
};

bool IsHostMemory(const NodeDef& node, int output_port) {
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    DeviceType device_type(parsed_name.type);
    Status s = FindKernelDef(device_type, node, nullptr, nullptr);
    if (s.ok()) {
      tensorflow::MemoryTypeVector in_mtypes;
      tensorflow::MemoryTypeVector out_mtypes;
      s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
                                         node, &in_mtypes, &out_mtypes);
      if (s.ok()) {
        if (out_mtypes[output_port] == HOST_MEMORY) {
          return true;
        }
      }
    } else {
      return true;
    }
  }
  return false;
}

std::vector<int> GetDimensionIndicesFromLabel(
    const absl::flat_hash_map<char, int>& dim_indices,
    absl::Span<const char> labels) {
  std::vector<int> indices;
  indices.reserve(labels.size());
  for (const auto& label : labels) {
    indices.push_back(dim_indices.at(label));
  }
  return indices;
}
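
// Example (illustrative): with dim_indices built from "NHWC"
// ({'N', 0}, {'H', 1}, {'W', 2}, {'C', 3}), labels {'H', 'W'} map to {1, 2}.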

// RAII-styled object for keeping track of 4D to 5D data format
// upgrade/conversion. Currently only NHWC -> NDHWC and NCHW -> NCDHW are
// supported.
class ScopedDataFormatUpgrader {
 public:
  ScopedDataFormatUpgrader(TransposeContext* context, int rank)
      : context_(context) {
    if (rank == 5 && IsSupportedDataFormat(context_->src_format) &&
        IsSupportedDataFormat(context_->dst_format)) {
      old_src_format_ = context_->src_format;
      old_dst_format_ = context_->dst_format;
      std::string new_src_format = GetUpgradedDataFormat(context_->src_format);
      std::string new_dst_format = GetUpgradedDataFormat(context_->dst_format);
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           new_src_format, new_dst_format);
      upgraded_ = true;
    }
  }

  ScopedDataFormatUpgrader(const ScopedDataFormatUpgrader&) = delete;
  ScopedDataFormatUpgrader& operator=(const ScopedDataFormatUpgrader&) = delete;

  ~ScopedDataFormatUpgrader() {
    if (upgraded_) {
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           old_src_format_, old_dst_format_);
    }
  }

 private:
  bool IsSupportedDataFormat(absl::string_view data_format) {
    return data_format == "NHWC" || data_format == "NCHW";
  }

  std::string GetUpgradedDataFormat(absl::string_view data_format) {
    if (data_format == "NHWC") {
      return "NDHWC";
    }

    DCHECK_EQ(data_format, "NCHW");
    return "NCDHW";
  }

  TransposeContext* context_ = nullptr;
  bool upgraded_ = false;
  std::string old_src_format_;
  std::string old_dst_format_;
};
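
// Usage sketch (a minimal illustration, mirroring how the rank-5 transposers
// below use this class): inside the scope, the context temporarily reports
// the upgraded 5D formats.
//
//   // Assuming context->src_format == "NHWC" and context->dst_format ==
//   // "NCHW":
//   {
//     ScopedDataFormatUpgrader upgrader(context, /*rank=*/5);
//     // Here context->src_format == "NDHWC", context->dst_format == "NCDHW".
//   }
//   // The original 4D formats are restored on scope exit.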

}  // namespace

// TransposeContext.

Status TransposeContext::InitializeTransposeContext(const GrapplerItem& item,
                                                    const Cluster* cluster,
                                                    TransposeContext* context) {
  DCHECK(context != nullptr);
  context->graph_properties = absl::make_unique<GraphProperties>(item);
  TF_RETURN_IF_ERROR(context->graph_properties->InferStatically(false));
  TF_RETURN_IF_ERROR(context->graph_properties->AnnotateOutputShapes(
      &context->graph, /*allow_symbolic_shapes=*/true));
  Status status;
  context->graph_view =
      absl::make_unique<utils::MutableGraphView>(&context->graph, &status);
  TF_RETURN_IF_ERROR(status);
  context->num_nodes = context->graph.node_size();
  const auto& nodes_to_preserve = item.NodesToPreserve();
  context->nodes_to_preserve = absl::flat_hash_set<string>(
      nodes_to_preserve.begin(), nodes_to_preserve.end());
  TF_RETURN_IF_ERROR(context->frames.InferFromGraph(context->graph));
  if (cluster != nullptr) {
    context->virtual_placer =
        absl::make_unique<const VirtualPlacer>(cluster->GetDevices());
  }
  return Status::OK();
}
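
// Caller-side sketch (a hypothetical call site, not part of this file): a
// context is initialized from a GrapplerItem and then bound to a target
// device and a pair of data formats before any transposer runs. The "GPU"
// string here is illustrative only.
//
//   TransposeContext context;
//   TF_RETURN_IF_ERROR(
//       TransposeContext::InitializeTransposeContext(item, cluster, &context));
//   context.AssignDeviceAndDataFormats("GPU", "NHWC", "NCHW");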

// Sets data formats to convert from and to for the specified device type.
void TransposeContext::AssignDeviceAndDataFormats(
    absl::string_view target_device, absl::string_view src_format,
    absl::string_view dst_format) {
  this->target_device = string(target_device);
  this->src_format = string(src_format);
  this->dst_format = string(dst_format);
  this->src_dim_indices = GetDimensionIndices(src_format);
  this->dst_dim_indices = GetDimensionIndices(dst_format);
  this->src_to_dst = GetPermutation(this->src_dim_indices, dst_format);
  this->dst_to_src = GetPermutation(this->dst_dim_indices, src_format);
}
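
// Worked example (illustrative): for src_format "NHWC" and dst_format "NCHW",
// src_dim_indices is {N:0, H:1, W:2, C:3}, so src_to_dst = {0, 3, 1, 2} (the
// permutation that rearranges an NHWC shape into NCHW order) and
// dst_to_src = {0, 2, 3, 1}.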

// Transposer.

bool Transposer::ShouldProcess(const TransposeContext& context,
                               const utils::MutableNodeView& node) const {
  const auto* node_def = node.node();
  const string& device_name =
      GetDeviceName(context.virtual_placer.get(), *node_def);
  string device;
  string task;
  const bool is_on_target_device =
      DeviceNameUtils::SplitDeviceName(device_name, &task, &device) &&
      absl::StrContains(absl::AsciiStrToLower(device),
                        absl::AsciiStrToLower(context.target_device));

  // Only check the data format for layout-sensitive ops.
  const bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
                                 AttrDataFormatMatch(node, context.src_format);

  // Only transpose floating-point nodes.
  const bool is_integer_conv2d = IsNonFloatingConv2D(node);

  return is_on_target_device && data_format_match && !is_integer_conv2d &&
         !context.nodes_to_preserve.contains(node_def->name()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

Status Transposer::CreateConstPermNode(TransposeContext* context,
                                       absl::string_view node_name,
                                       absl::string_view device,
                                       absl::Span<const int> permutation,
                                       absl::string_view control_node_name,
                                       utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  NodeDef node;
  node.set_name(string(node_name));
  node.set_op(kOpConst);
  node.set_device(string(device));

  if (!control_node_name.empty()) {
    node.add_input(string(control_node_name));
  }

  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32, TensorShape({(long long)permutation.size()}));
  for (int i = 0, end = permutation.size(); i < end; i++) {
    tensor.flat<int>()(i) = permutation[i];
  }
  tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
  node.mutable_attr()->insert({"value", attr_tensor});

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::CreateTransposeNode(
    TransposeContext* context, absl::string_view name_format,
    const DataType& data_type, absl::string_view device,
    TensorShapeProto fanin_shape, absl::Span<const int> permutation,
    absl::string_view control_node_name, utils::MutationNewNode* added_node,
    string* transpose_node_name) {
  const string node_name = absl::Substitute(name_format, kOpTranspose);
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));
  *transpose_node_name = node_name;

  NodeDef node;
  node.set_name(node_name);
  node.set_op(kOpTranspose);
  node.set_device(string(device));

  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  AttrValue attr_data_type_perm;
  attr_data_type_perm.set_type(DT_INT32);
  node.mutable_attr()->insert({"Tperm", attr_data_type_perm});

  if (!fanin_shape.unknown_rank()) {
    TF_RETURN_IF_ERROR(
        PermuteSingle(absl::StrCat("fanin shape in ", node.name()), permutation,
                      fanin_shape.mutable_dim()));
    AttrValue attr_output_shape;
    *attr_output_shape.mutable_list()->add_shape() = fanin_shape;
    node.mutable_attr()->insert({kAttrOutputShape, attr_output_shape});
  }

  // Create the permutation Const node.
  utils::MutationNewNode const_perm_added_node;
  const string const_perm_node_name =
      absl::Substitute(name_format, "PermConst");
  TF_RETURN_IF_ERROR(CreateConstPermNode(context, const_perm_node_name, device,
                                         permutation, control_node_name,
                                         &const_perm_added_node));
  // Add placeholder for 1st input.
  node.add_input("");
  // Connect const_perm_node to 2nd input of transpose_node.
  node.add_input(const_perm_node_name);

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context,
                                          absl::Span<const int> dst_ports,
                                          utils::MutableNodeView* dst_node,
                                          absl::string_view op) {
  const bool is_in_frame = context->frames.IsInFrame(*dst_node->node());
  for (int dst_port : dst_ports) {
    auto& fanin_port = dst_node->GetRegularFanin(dst_port);
    auto* fanin_node_view = fanin_port.node_view();

    TF_RETURN_IF_ERROR(
        UpdateEdge(context,
                   GetFaninNameFormat(dst_node->GetName(), dst_port,
                                      context->src_format, context->dst_format),
                   op, /*input_shape=*/nullptr, /*is_in_frame=*/is_in_frame,
                   /*is_src_format_to_dst_format=*/true, fanin_port.index(),
                   dst_port, fanin_node_view, dst_node));
  }
  return Status::OK();
}

Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context,
                                           absl::Span<const int> src_ports,
                                           utils::MutableNodeView* src_node,
                                           absl::string_view op) {
  // Update attr _output_shapes for output ports.
  const auto* output_shape_attr = src_node->GetAttr(kAttrOutputShape);
  AttrValue shape_attr_copy;
  if (op == kOpTranspose && output_shape_attr != nullptr) {
    shape_attr_copy = *output_shape_attr;
    for (int port : src_ports) {
      auto* shape = shape_attr_copy.mutable_list()->mutable_shape(port);
      if (shape->unknown_rank()) continue;
      TF_RETURN_IF_ERROR(
          PermuteSingle(absl::StrCat("output shape attribute at port ", port,
                                     " in ", src_node->GetName()),
                        context->src_to_dst, shape->mutable_dim()));
    }
    context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
        src_node, kAttrOutputShape, shape_attr_copy);
  }

  const bool is_in_frame = context->frames.IsInFrame(*src_node->node());
  // We might modify the fanout set inside the loop, so make a copy first, and
  // sort the copy by node name and index so that generated transposer names
  // stay deterministic.
  for (int src_port : src_ports) {
    const auto& fanouts_src_port = src_node->GetRegularFanout(src_port);
    std::vector<utils::MutableFaninView> sorted_fanouts(
        fanouts_src_port.begin(), fanouts_src_port.end());
    std::sort(sorted_fanouts.begin(), sorted_fanouts.end(),
              ComparatorByNodeNameAndIndex());
    int num_downstream_transposers = 0;
    for (const auto& fanout : sorted_fanouts) {
      TF_RETURN_IF_ERROR(UpdateEdge(
          context,
          GetFanoutNameFormat(src_node->GetName(), src_port,
                              num_downstream_transposers++, context->src_format,
                              context->dst_format),
          op, &shape_attr_copy, /*is_in_frame=*/is_in_frame,
          /*is_src_format_to_dst_format=*/false, src_port, fanout.index(),
          src_node, fanout.node_view()));
    }
  }
  return Status::OK();
}

Status Transposer::CreateDataFormatNode(
    TransposeContext* context, absl::string_view node_name,
    absl::string_view op, absl::string_view device, const DataType& data_type,
    bool is_fanin_on_host, bool is_src_format_to_dst_format,
    utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  // Create the node.
  NodeDef node;
  node.set_name(string(node_name));

  // Set up parameters of node.
  node.set_op(string(op));
  node.set_device(string(device));
  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  // The inputs of a DataFormat op could be in host memory for ops such as
  // Reshape. In such cases, run the kernel on the host too.
  if (is_fanin_on_host) {
    AttrValue attr_kernel;
    attr_kernel.set_s("host");
    node.mutable_attr()->insert({"_kernel", attr_kernel});
  }

  AttrValue src_format;
  src_format.set_s(is_src_format_to_dst_format ? context->src_format
                                               : context->dst_format);
  node.mutable_attr()->insert({kAttrSrcFormat, src_format});
  AttrValue dst_format;
  dst_format.set_s(is_src_format_to_dst_format ? context->dst_format
                                               : context->src_format);
  node.mutable_attr()->insert({kAttrDstFormat, dst_format});

  // Add placeholder for 1st input field.
  node.add_input("");

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::UpdateEdge(
    TransposeContext* context, absl::string_view name_format,
    absl::string_view op, const AttrValue* input_shape, bool is_in_frame,
    bool is_src_format_to_dst_format, const int src_port, const int dst_port,
    utils::MutableNodeView* src_node, utils::MutableNodeView* dst_node) {
  DCHECK(src_node != nullptr);
  DCHECK(dst_node != nullptr);
  auto* src_node_def = src_node->node();
  auto* dst_node_def = dst_node->node();

  // TODO(lyandy): Minimize device parsing/fetching.
  const string device = GetDeviceName(
      context->virtual_placer.get(),
      is_src_format_to_dst_format ? *dst_node_def : *src_node_def);
  DataType data_type =
      is_src_format_to_dst_format
          ? context->graph_properties
                ->GetInputProperties(dst_node->GetName())[dst_port]
                .dtype()
          : context->graph_properties
                ->GetOutputProperties(src_node->GetName())[src_port]
                .dtype();

  utils::MutationNewNode added_node;
  string added_node_name;
  if (op == kOpTranspose) {
    TensorShapeProto input_shape_proto;
    input_shape_proto.set_unknown_rank(true);
    if (input_shape != nullptr) {
      input_shape_proto = input_shape->list().shape(src_port);
    } else {
      const auto* src_node_shape_attr = src_node->GetAttr(kAttrOutputShape);
      if (src_node_shape_attr != nullptr) {
        input_shape_proto = src_node_shape_attr->list().shape(src_port);
      }
    }
    const string control_node_name =
        is_in_frame ? AsControlDependency(src_node_def->name()) : "";
    const std::vector<int>& permutation =
        is_src_format_to_dst_format ? context->src_to_dst : context->dst_to_src;
    TF_RETURN_IF_ERROR(CreateTransposeNode(
        context, name_format, data_type, device, input_shape_proto, permutation,
        control_node_name, &added_node, &added_node_name));
  } else if (op == kOpDataFormatVecPermute || op == kOpDataFormatDimMap) {
    DeviceNameUtils::ParsedName parsed_name;
    bool is_fanin_on_host =
        DeviceNameUtils::ParseFullName(
            GetDeviceName(context->virtual_placer.get(), *src_node_def),
            &parsed_name) &&
        parsed_name.type != "CPU" && IsHostMemory(*src_node_def, src_port);
    const string node_name = absl::Substitute(name_format, op);
    TF_RETURN_IF_ERROR(CreateDataFormatNode(
        context, node_name, op, device, data_type, is_fanin_on_host,
        is_src_format_to_dst_format, &added_node));
    added_node_name = node_name;
  } else {
    return Status(error::INVALID_ARGUMENT,
                  absl::StrCat("Unsupported op \"", op,
                               "\". Supported ops are Transpose, "
                               "DataFormatVecPermute, DataFormatDimMap."));
  }

  // Connect src_node to 1st input of added_node.
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  mutation->AddOrUpdateRegularFanin(added_node, 0,
                                    {src_node->GetName(), src_port});

  // Connect output of added_node to dst_node:dst_port.
  mutation->AddOrUpdateRegularFanin(dst_node, dst_port, {added_node_name, 0});

  return Status::OK();
}

int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
                                  int port) const {
  const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr ||
      output_shape_attr->list().shape_size() <= port) {
    return kInvalidRank;
  }
  const auto& shape = output_shape_attr->list().shape(port);
  if (shape.unknown_rank()) {
    return kUnknownRank;
  }
  return shape.dim_size();
}

bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
                                   int n) const {
  return GetFanoutPortRank(node, port) == n;
}

bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
                                    absl::Span<const int> ports, int n) const {
  for (const auto& port : ports) {
    if (!IsFanoutPortRankN(node, port, n)) {
      return false;
    }
  }
  return true;
}

int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
                                 int port) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
  }
  return kInvalidRank;
}

bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
                                  int n) const {
  return GetFaninPortRank(node, port) == n;
}

bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
                                         int port,
                                         absl::Span<const int> dims) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    const auto* fanin_node_view = regular_fanin.node_view();
    if (!IsConstant(*fanin_node_view->node())) {
      return true;
    }
    // If fanin is a Const, check tensor to see if dimensions match.
    const auto* value_attr = fanin_node_view->GetAttr(kAttrValue);
    if (value_attr == nullptr) {
      return false;
    }
    Tensor tensor;
    if (!tensor.FromProto(value_attr->tensor())) {
      return false;
    }
    const int dims_size = dims.size();
    if (tensor.dims() != dims_size) {
      return false;
    }
    for (int i = 0; i < dims_size; ++i) {
      if (tensor.dim_size(i) != dims[i]) {
        return false;
      }
    }
    return true;
  }
  return false;
}

bool Transposer::IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
                                          absl::Span<const int> ports,
                                          absl::Span<const int> dims) const {
  for (const auto& port : ports) {
    if (!IsFaninPortDimsNIfConst(node, port, dims)) {
      return false;
    }
  }
  return true;
}

bool Transposer::CanProcessNode(const TransposeContext& context,
                                const utils::MutableNodeView& node) const {
  return !context.nodes_to_preserve.contains(node.GetName()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

string Transposer::GetFaninNameFormat(absl::string_view node_name, int port,
                                      absl::string_view src_format,
                                      absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-$0", src_format, "To", dst_format,
                      "-", kOptimizedSuffix);
}

string Transposer::GetFanoutNameFormat(absl::string_view node_name, int port,
                                       int index, absl::string_view src_format,
                                       absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-", index, "-$0", dst_format, "To",
                      src_format, "-", kOptimizedSuffix);
}
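
// Example (illustrative): for node "conv2d", port 0, and formats
// NHWC -> NCHW, GetFaninNameFormat returns the format string
// "conv2d-0-$0NHWCToNCHW-LayoutOptimizer"; after absl::Substitute with
// kOpTranspose, the added node is named
// "conv2d-0-TransposeNHWCToNCHW-LayoutOptimizer". GetFanoutNameFormat swaps
// the formats (dst first) and adds a per-fanout index.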

string Transposer::LayoutOptimizerNode(absl::string_view node_name) {
  return absl::StrCat(node_name, "-", kOptimizedSuffix);
}

string Transposer::GetReshapeNodeNameFormat(absl::string_view node_name,
                                            int index,
                                            absl::string_view src_format,
                                            absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", index, "-", kReshape, src_format, "To",
                      dst_format);
}

string Transposer::GetShapeConstNodeNameFormat(absl::string_view node_name,
                                               int index) {
  return absl::StrCat(node_name, "-", index, "-", kReshapeConst);
}

// Layout sensitive transposer.

inline string GetLayoutSensitiveNodeDataFormat(
    const utils::MutableNodeView& node) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s();
  }
  return "";
}

Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
                                               utils::MutableNodeView* node) {
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  AttrValue data_format_attr;
  data_format_attr.set_s(context->dst_format);
  mutation->AddOrUpdateNodeAttr(node, kAttrDataFormat, data_format_attr);

  auto permute_attr = [&context, &node,
                       &mutation](absl::string_view attr_name) {
    const auto* attr = node->GetAttr(attr_name);
    if (attr != nullptr) {
      AttrValue attr_copy(*attr);
      TF_RETURN_IF_ERROR(PermuteSingle(
          absl::StrCat(attr_name, " attribute in ", node->GetName()),
          context->src_to_dst, attr_copy.mutable_list()->mutable_i()));
      mutation->AddOrUpdateNodeAttr(node, attr_name, attr_copy);
    }
    return Status::OK();
  };

  // Update attrs.
  TF_RETURN_IF_ERROR(permute_attr(kAttrStrides));
  TF_RETURN_IF_ERROR(permute_attr(kAttrKSize));
  TF_RETURN_IF_ERROR(permute_attr(kAttrDilations));

  const auto* explicit_paddings_attr = node->GetAttr(kAttrExplicitPaddings);
  if (explicit_paddings_attr != nullptr && explicit_paddings_attr->has_list() &&
      explicit_paddings_attr->list().i_size() > 0) {
    AttrValue explicit_paddings_attr_copy(*explicit_paddings_attr);
    TF_RETURN_IF_ERROR(PermuteDouble(
        absl::StrCat("explicit_paddings attribute in ", node->GetName()),
        context->src_to_dst,
        explicit_paddings_attr_copy.mutable_list()->mutable_i()));
    mutation->AddOrUpdateNodeAttr(node, kAttrExplicitPaddings,
                                  explicit_paddings_attr_copy);
  }

  return Status::OK();
}

Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsAvgPoolGrad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 1, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  // This TransposeNode handles BiasAdd but not BiasAddV1, since only BiasAdd
  // supports a data_format attribute.
  DCHECK(IsBiasAddV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAdd itself only needs NCHW/NHWC to determine whether the C dim is the
  // second or the last dim. Therefore, we use the original 4D data format in
  // the context to update the node. For the input/output tensor, the
  // corresponding 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsBiasAddGrad(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAddGrad itself only needs NCHW/NHWC to determine whether the C dim is
  // the second or the last dim. Therefore, we use the original 4D data format
  // in the context to update the node. For the input tensor, the corresponding
  // 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  // No need to update the output shape, as it is always 1-D with size equal to
  // the feature dimension of `out_backprop`, regardless of whether NCHW or
  // NHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropFilter(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropFilter(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update the output shape, as it is always of shape
  // [filter_height, filter_width, in_channels, out_channels], regardless of
  // whether NCHW or NHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropInput(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropInput(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }

  const auto& fanin = node->GetRegularFanin(0);
  auto* fanin_node = fanin.node_view();
  const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr) {
    VLOG(3) << "Cannot compute the shape of " << fanin_node->GetName()
            << " because it is missing attribute " << kAttrOutputShape;
    return Status::OK();
  }
  TensorShapeProto fanin_shape = output_shape_attr->list().shape(fanin.index());
  if (fanin_shape.dim_size() != 1) {
    VLOG(3) << fanin_node->GetName() << " is not a vector.";
    return Status::OK();
  }

  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsConv3D(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropFilterV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update the output shape, as it is always of shape
  // [filter_depth, filter_height, filter_width, in_channels, out_channels],
  // regardless of whether NCDHW or NDHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropInputV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status FusedBatchNormExTransposer::TransposeNode(TransposeContext* context,
                                                 utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormEx(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  if (node->NumRegularFanins() == 6) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0, 5}, node, kOpTranspose));
  } else {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool FusedBatchNormGradTransposer::IsTraining(
    const utils::MutableNodeView& node) const {
  const auto* is_training_attr = node.GetAttr(kAttrIsTraining);
  if (is_training_attr != nullptr) {
    return is_training_attr->b();
  }
  return false;
}

Status FusedBatchNormGradTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolV2(*node->node()));
  // We check the data input's shape instead, because the shape inference of
  // MaxPoolV2 is not able to infer the shape when ksize or strides is not
  // constant.
  const auto& data_fanin = node->GetRegularFanin(0);
  auto* data_fanin_node = data_fanin.node_view();
  if (!ShouldProcess(*context, *node) ||
      !IsFanoutPortRankN(*data_fanin_node, data_fanin.index(), 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGrad(*node->node()) || IsMaxPoolGradGradV1(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradV2Transposer::TransposeNode(TransposeContext* context,
                                              utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGradV2(*node->node()) || IsMaxPoolGradGradV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {3, 4}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Layout agnostic transposer.

inline bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node,
                                          absl::Span<const int> permutation) {
  Tensor tensor;
  if (!GetValueAttrFromConstInputNode(node, IsTranspose, 1, &tensor)) {
    return false;
  }
  const int permutation_size = permutation.size();
  if (tensor.NumElements() != permutation_size) {
    return false;
  }

  const auto& tensor_data = tensor.unaligned_flat<int32>();
  for (int i = 0; i < permutation_size; i++) {
    if (permutation[i] != tensor_data(i)) {
      return false;
    }
  }
  return true;
}
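
// Example (illustrative): with NHWC -> NCHW in the context, dst_to_src is
// {0, 2, 3, 1}, so a Transpose node whose second (perm) input is a Const
// holding exactly {0, 2, 3, 1} passes this check.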

inline bool IsValidDataFormatNode(const utils::MutableNodeView& node,
                                  absl::string_view src_format,
                                  absl::string_view dst_format) {
  if (!IsDataFormatOp(node)) {
    return false;
  }
  const auto* src_format_attr = node.GetAttr(kAttrSrcFormat);
  if (src_format_attr == nullptr || src_format_attr->s() != src_format) {
    return false;
  }
  const auto* dst_format_attr = node.GetAttr(kAttrDstFormat);
  if (dst_format_attr == nullptr || dst_format_attr->s() != dst_format) {
    return false;
  }
  return true;
}

inline bool IsLayoutOptimizerAddedDstToSrcTranspose(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         IsValidConstPermTransposeNode(node, context.dst_to_src);
}

inline bool IsLayoutOptimizerAddedDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         (IsValidConstPermTransposeNode(node, context.dst_to_src) ||
          IsValidDataFormatNode(node, context.dst_format, context.src_format));
}

bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  std::deque<utils::MutableNodeView*> queue;
  absl::flat_hash_set<utils::MutableNodeView*> visited_nodes;
  auto data_node_pos = GetDataFaninPorts(node);
  for (const int pos : data_node_pos) {
    const auto& fanin = node.GetRegularFanin(pos);
    auto* fanin_node = fanin.node_view();
    queue.push_back(fanin_node);
    visited_nodes.insert(fanin_node);
  }
  // In most cases this loop exits after the first iteration, as the graph is
  // already topologically sorted.
  while (!queue.empty()) {
    utils::MutableNodeView* current_node = queue.front();
    queue.pop_front();
    if (IsLayoutOptimizerAddedDstToSrcTransform(context, *current_node)) {
      return true;
    }
    // We only continue searching if the path is connected through
    // format-agnostic nodes.
    if (IsLayoutAgnosticOp(*current_node->node())) {
      auto current_node_pos = GetDataFaninPorts(*current_node);
      for (const auto& pos : current_node_pos) {
        const auto& fanin = current_node->GetRegularFanin(pos);
        auto* fanin_node = fanin.node_view();
        if (visited_nodes.insert(fanin_node).second) {
          queue.push_back(fanin_node);
        }
      }
    }
  }
  return false;
}

std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
    const TransposeContext& context, const utils::MutableNodeView& node,
    int rank) const {
  std::vector<int> ports;
  const int num_regular_fanins = node.NumRegularFanins();
  ports.reserve(num_regular_fanins);
  for (int i = 0; i < num_regular_fanins; ++i) {
    const auto& regular_fanin = node.GetRegularFanin(i);
    auto* regular_fanin_node = regular_fanin.node_view();
    int regular_fanin_port = regular_fanin.index();
    if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      ports.push_back(i);
    }
  }
  return ports;
}
1154 
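// The canonical layout-agnostic rewrite: make the op compute in the
// destination format by inserting a src-to-dst Transpose on the listed inputs
// (UpdateFaninEdgesWithOp) and a dst-to-src Transpose on the listed outputs
// (UpdateFanoutEdgesWithOp). For example, assuming src NHWC and dst NCHW, a
// Relu becomes:
//   x -> Transpose(NHWC->NCHW) -> Relu -> Transpose(NCHW->NHWC) -> consumers
// Back-to-back inverse Transposes produced this way are expected to cancel in
// a later stage of the optimizer.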
Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AddNTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsAddN(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

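// Binary ops are only rewritten for broadcast patterns the transposer can
// handle: rank-N with rank-N, rank-N with a scalar (rank 0), and rank-N with
// a vector (rank 1). The vector case needs extra care (see
// MaybeReshapeVectorFanin below): a rank-1 operand broadcasts against the
// trailing axis, and transposing the other operand changes which axis that
// is.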
bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
                                           int n, int m) {
  return IsFaninPortRankN(node, 0, n) && IsFaninPortRankN(node, 1, m);
}

bool BinaryOpTransposer::IsFaninShapeSupported(
    const utils::MutableNodeView& node, int rank) {
  return (IsNDOperateWithMD(node, rank, 0) ||
          IsNDOperateWithMD(node, rank, 1) ||
          IsNDOperateWithMD(node, rank, rank) ||
          IsNDOperateWithMD(node, 0, rank) || IsNDOperateWithMD(node, 1, rank));
}

std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
    const utils::MutableNodeView& node, int rank) {
  std::vector<int> values;
  if (IsFaninPortRankN(node, 0, rank)) {
    values.push_back(0);
  }
  if (IsFaninPortRankN(node, 1, rank)) {
    values.push_back(1);
  }
  return values;
}

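// Helpers for the vector-broadcast fix-up. AddNodeShapeConst emits a 1-D
// int32 Const with `rank` elements holding [1, C, 1, ...] (num_channels at
// index 1, which assumes channels sit in dimension 1 of the destination
// format, as in NCHW/NCDHW), and AddNodeReshape emits the Reshape that
// applies it to the vector operand. When the op sits in a frame, a control
// dependency keeps the new Const in the same frame.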
Status BinaryOpTransposer::AddNodeReshape(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, absl::string_view input_name,
    absl::string_view shape_const_node_name, const DataType& data_type) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.add_input(string(input_name));
  new_node.add_input(string(shape_const_node_name));
  new_node.set_op(kReshape);
  new_node.set_device(string(node_device));

  AttrValue attr_type_indices;
  attr_type_indices.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"Tshape", attr_type_indices});

  AttrValue attr_type_params;
  attr_type_params.set_type(data_type);
  new_node.mutable_attr()->insert({"T", attr_type_params});

  Status status;
  mutation->AddNode(std::move(new_node), &status);
  return status;
}

Status BinaryOpTransposer::AddNodeShapeConst(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, bool node_in_frame, int num_channels,
    absl::string_view depended_node, int rank) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.set_op(kOpConst);
  new_node.set_device(string(node_device));
  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32, TensorShape({rank}));
  std::vector<int> shape(rank, 1);
  shape[1] = num_channels;
  for (int i = 0; i < static_cast<int>(shape.size()); i++) {
    tensor.flat<int>()(i) = shape[i];
  }
  tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
  new_node.mutable_attr()->insert({"value", attr_tensor});
  if (node_in_frame) {
    // This is to ensure the transpose node and the const node are in the same
    // frame.
    // TODO(halehri): Add Test that exercises this condition.
    new_node.add_input(AsControlDependency(string(depended_node)));
  }

  Status status;
  mutation->AddNode(std::move(new_node), &status);
  return status;
}

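// If one operand is a rank-1 vector, rewires it through a Reshape so that it
// still broadcasts against the channel axis once the other operand has been
// transposed. For example, assuming src NHWC and dst NCHW, for a hypothetical
// add(x: [N,H,W,C], v: [C]) the vector fanin becomes
//   v: [C] -> Reshape(shape=[1, C, 1, 1]) -> add
// The vector size C is read from the fanin's kAttrOutputShape attribute.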
Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
                                                   utils::MutableNodeView* node,
                                                   int rank) {
  int vector_index = -1;
  if (IsNDOperateWithMD(*node, rank, 1)) {
    vector_index = 1;
  } else if (IsNDOperateWithMD(*node, 1, rank)) {
    vector_index = 0;
  }
  if (vector_index != -1) {
    const string& node_name = node->GetName();
    const string& node_device = node->GetDevice();
    string reshape_node_name = LayoutOptimizerNode(GetReshapeNodeNameFormat(
        node_name, vector_index, context->src_format, context->dst_format));
    string shape_const_node_name = LayoutOptimizerNode(
        GetShapeConstNodeNameFormat(node_name, vector_index));
    const auto& fanin = node->GetRegularFanin(vector_index);
    auto* fanin_node = fanin.node_view();
    const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
    if (output_shape_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrOutputShape);
    }
    int vector_size =
        output_shape_attr->list().shape(fanin.index()).dim(0).size();
    utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
    TF_RETURN_IF_ERROR(
        AddNodeShapeConst(mutation, shape_const_node_name, node_device,
                          context->frames.IsInFrame(*node->node()), vector_size,
                          fanin_node->GetName(), rank));
    const auto* t_attr = node->GetAttr(kAttrT);
    if (t_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrT);
    }
    TF_RETURN_IF_ERROR(
        AddNodeReshape(mutation, reshape_node_name, node_device,
                       TensorIdToString({fanin_node->GetName(), fanin.index()}),
                       shape_const_node_name, t_attr->type()));
    mutation->AddOrUpdateRegularFanin(node, vector_index,
                                      {reshape_node_name, 0});
  }
  return Status::OK();
}

Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsBinaryOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetNDDataFaninPorts(*node, rank), node, kOpTranspose));
  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

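// Concat carries a scalar axis input: input 0 for Concat, input N (after the
// N data inputs) for ConcatV2. Data inputs get Transposes, while the axis is
// rewritten with DataFormatDimMap, which remaps a dimension index between
// formats (e.g. axis 3, the C axis of NHWC, maps to 1 in NCHW).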
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsConcat(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetConcatDataFaninPorts(*node), node, kOpTranspose));
  int axis_node = 0;
  if (node->GetOp() == "ConcatV2") {
    const auto* n_attr = node->GetAttr(kAttrN);
    if (n_attr != nullptr) {
      axis_node = n_attr->i();
    }
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

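// Fill's input 0 is a 1-D shape vector rather than a tensor to transpose, so
// when it is a 4-element constant it is rewritten with DataFormatVecPermute,
// which reorders the vector's elements (e.g. NHWC values -> NCHW values).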
Status FillOpTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsFill(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 0, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

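// IdentityN may mix tensors of different ranks, so 4-D and 5-D ports are
// collected and rewritten separately; the 5-D ports are handled under a
// temporarily upgraded data format context (ScopedDataFormatUpgrader widens
// the 4-D formats to their 5-D counterparts, e.g. NHWC to NDHWC).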
Status IdentityNTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsIdentityN(*node->node()));
  const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);

  // Temporarily upgrade the context to obtain the number of 5D fanin ports.
  std::vector<int> ports_5d;
  {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
  }

  if (!ShouldProcess(*context, *node)) {
    return Status::OK();
  }

  if (!ports_4d.empty()) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
  }

  if (!ports_5d.empty()) {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  for (const auto& regular_fanin : node.GetRegularFanins()) {
    auto* regular_fanin_node = regular_fanin.node_view();
    if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 4) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      continue;
    }
    return false;
  }
  return true;
}

Status MergeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsMerge(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsEveryFaninAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

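// Pad-like ops take a [4, 2] paddings matrix as input 1; its rows are
// reordered with DataFormatVecPermute so they follow the transposed data
// input.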
Status PadTransposer::TransposeNode(TransposeContext* context,
                                    utils::MutableNodeView* node) {
  DCHECK(IsMirrorPad(*node->node()) || IsMirrorPadGrad(*node->node()) ||
         IsPad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4, 2}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

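// A reduction is safe to rewrite in the destination format when either (a)
// keep_dims is set, so the rank-N output can simply be transposed back, or
// (b) the constant reduction axes form one of the sets below, for which the
// surviving dimensions keep the same relative order in both formats. For
// example, reducing NHWC over {H, W} leaves [N, C], exactly what reducing
// NCHW over its {H, W} leaves as well.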
bool ReduceTransposer::KeepDims(const utils::MutableNodeView& node) {
  const auto* keep_dims_attr = node.GetAttr(kAttrKeepDims);
  if (keep_dims_attr != nullptr) {
    return keep_dims_attr->b();
  }
  return false;
}

bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
                                   absl::Span<const int> axis, int rank) {
  const int axis_size = axis.size();
  if (tensor.dims() != 1 || tensor.dim_size(0) != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = 0;
    if (tensor.dtype() == DT_INT32) {
      local_axis = tensor.flat<int32>()(i);
    } else {
      local_axis = tensor.flat<int64>()(i);
    }
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

bool ReduceTransposer::IsReduceAxisSupported(const TransposeContext& context,
                                             const utils::MutableNodeView& node,
                                             int rank) {
  if (KeepDims(node)) {
    return true;
  }
  const auto& regular_fanin_1 = node.GetRegularFanin(1);
  auto* axis_node = regular_fanin_1.node_view();
  if (!IsConstant(*axis_node->node())) {
    return false;
  }
  const auto* value_attr = axis_node->GetAttr(kAttrValue);
  if (value_attr == nullptr) {
    return false;
  }
  Tensor tensor;
  if (!tensor.FromProto(value_attr->tensor())) {
    LOG(ERROR) << "Failed to parse TensorProto.";
    return false;
  }
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  if (rank == 5) {
    return IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'C'}), 5);
  }
  DCHECK_EQ(rank, 4);
  return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'N', 'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'C'}), 4);
}

Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsReduceOp(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsReduceAxisSupported(*context, *node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  if (KeepDims(*node)) {
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ReverseV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsReverseV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool SelectTransposer::IsFaninScalarVector4D(
    const utils::MutableNodeView& fanin, int port) {
  return IsFanoutPortRankN(fanin, port, 0) ||
         IsFanoutPortRankN(fanin, port, 1) || IsFanoutPortRankN(fanin, port, 4);
}

std::vector<int> SelectTransposer::GetFaninPorts(
    const utils::MutableNodeView& fanin, int port) {
  // Input 0 may be a scalar, a vector whose size matches the first dimension
  // of inputs 1 and 2, or a tensor with the same shape as inputs 1 and 2.
  // Only in the last case does it need to be transposed along with them.
  if (IsFanoutPortRankN(fanin, port, 4)) {
    return {0, 1, 2};
  }
  return {1, 2};
}

Status SelectTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSelect(*node->node()));
  const auto& regular_fanin_0 = node->GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninScalarVector4D(*regular_fanin_0_node, regular_fanin_0.index()) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetFaninPorts(*regular_fanin_0_node, regular_fanin_0.index()),
      node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

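// Shape's output is a 1-D shape vector, not a data tensor, so the fanout edge
// is rewritten with DataFormatVecPermute instead of a Transpose.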
Status ShapeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsShape(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ShapeNTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsShapeN(*node->node()));
  // ShapeN requires all input tensors to have the same dimensions. Therefore,
  // we simply use the 0th fanin port.
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
  if (!ShouldProcess(*context, *node) || ports.empty()) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SliceTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSlice(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

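// Split takes its axis as input 0 and data as input 1, while SplitV takes
// data as input 0, size_splits as input 1, and the axis as input 2. In both
// cases the scalar axis is rewritten with DataFormatDimMap and every data
// output gets a dst-to-src Transpose.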
Status SplitTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSplit(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SplitVTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSplitV(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {2}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool SqueezeTransposer::IsInputConvertible(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  const auto& regular_fanin_0 = node.GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  const auto* output_shape_attr =
      regular_fanin_0_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr != nullptr) {
    auto& shape = output_shape_attr->list().shape(regular_fanin_0.index());
    if (shape.dim_size() != kRank) {
      return false;
    }
    const int height_dim = context.src_dim_indices.at('H');
    const int width_dim = context.src_dim_indices.at('W');
    if (shape.dim(height_dim).size() == 1 && shape.dim(width_dim).size() == 1) {
      return true;
    }
  }
  return false;
}

bool SqueezeTransposer::IsAlongAxis(const AttrValue& attr,
                                    absl::Span<const int> axis,
                                    int rank) const {
  const auto& list = attr.list();
  // If list is empty, Squeeze op will squeeze all dimensions of size 1.
  int axis_size = axis.size();
  if (list.i_size() == 0) {
    return true;
  } else if (list.i_size() != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = list.i(i);
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

bool SqueezeTransposer::IsDimsSupported(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  const auto* squeeze_dims_attr = node.GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return false;
  }
  return (IsFanoutPortRankN(node, 0, 2) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'H', 'W'}), kRank)) ||
         (IsFanoutPortRankN(node, 0, 1) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'N', 'H', 'W'}), kRank));
}

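// Rewrites the squeeze_dims attribute so it refers to the transposed input.
// Each dim, given in source-format coordinates, is remapped to that
// dimension's position in the destination format, then the list is sorted.
// E.g. under an assumed NHWC->NCHW conversion, squeezing {1, 2} (H, W in
// NHWC) becomes {2, 3} (H, W in NCHW).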
Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  const auto* squeeze_dims_attr = node->GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return errors::InvalidArgument("Missing attribute ", kAttrSqueezeDims);
  }
  const int num_input_dims = context->src_format.length();
  const int min_squeeze_dim = -num_input_dims;
  std::vector<int> squeeze_dims_mapped;
  const int squeeze_dims_size = squeeze_dims_attr->list().i_size();
  squeeze_dims_mapped.reserve(squeeze_dims_size);
  for (int i = 0; i < squeeze_dims_size; ++i) {
    int dim = squeeze_dims_attr->list().i(i);
    if (dim < min_squeeze_dim || dim >= num_input_dims) {
      return errors::InvalidArgument(
          "Attribute '", kAttrSqueezeDims, "' contains out of range index '",
          dim, "', index must be between [", min_squeeze_dim, ", ",
          num_input_dims, ")");
    }
    if (dim < 0) {
      dim += num_input_dims;
    }
    squeeze_dims_mapped.push_back(context->dst_to_src[dim]);
  }
  std::sort(squeeze_dims_mapped.begin(), squeeze_dims_mapped.end());
  AttrValue squeeze_dims;
  squeeze_dims.mutable_list()->mutable_i()->Reserve(squeeze_dims_size);
  for (const auto& dim : squeeze_dims_mapped) {
    squeeze_dims.mutable_list()->mutable_i()->Add(dim);
  }
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
      node, kAttrSqueezeDims, squeeze_dims);
  return Status::OK();
}

Status SqueezeTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  DCHECK(IsSqueeze(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsDimsSupported(*context, *node) ||
      !IsInputConvertible(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateSqueezeDims(context, node));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool StridedSliceTransposer::IsMaskZero(const utils::MutableNodeView& node,
                                        absl::string_view mask) {
  const auto* mask_attr = node.GetAttr(mask);
  if (mask_attr != nullptr) {
    return mask_attr->i() == 0;
  }
  return true;
}

bool StridedSliceTransposer::HasOnlyBeginEndMask(
    const utils::MutableNodeView& node) {
  return IsMaskZero(node, "ellipsis_mask") &&
         IsMaskZero(node, "new_axis_mask") &&
         IsMaskZero(node, "shrink_axis_mask");
}

Status StridedSliceTransposer::PermuteMask(TransposeContext* context,
                                           utils::MutableNodeView* node,
                                           absl::string_view mask) {
  // Computes the permutation of the masks based on the src and dst format.
  // For example:
  // src_format = NHWC
  // dst_format = NCHW
  // src_to_dst permutation = [0, 3, 1, 2].
  // mask   : 0010 [Note: bit positions correspond to dimension indices, i.e.
  //          this reads in reverse order of the src format (CWHN)]
  // result : 0100 (WHCN)
  const auto* mask_attr = node->GetAttr(mask);
  const int mask_i = mask_attr != nullptr ? mask_attr->i() : 0;
  if (mask_i < 0 || mask_i > 15) {
    return errors::InvalidArgument("invalid mask value: ", mask_i);
  }
  int result = 0;
  for (int i = 0, end = context->src_to_dst.size(); i < end; i++) {
    const int final_pos = context->src_to_dst[i];
    const int position_mask = 1 << final_pos;
    const int bit_i = (mask_i & position_mask) >> final_pos;
    result |= bit_i << i;
  }
  AttrValue new_mask_attr;
  new_mask_attr.set_i(result);
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(node, mask,
                                                                 new_mask_attr);
  return Status::OK();
}

Status StridedSliceTransposer::TransposeNode(TransposeContext* context,
                                             utils::MutableNodeView* node) {
  DCHECK(IsStridedSlice(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2, 3}, {4}) ||
      !HasOnlyBeginEndMask(*node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "begin_mask"));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "end_mask"));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node,
                                            kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SwitchTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSwitch(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, GetDataFanoutPorts(*node),
                                             node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TernaryOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsTernaryOp(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TileTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsTile(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status UnaryGradTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsUnaryGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return Status::OK();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return Status::OK();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Utils.

string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) {
  return (node.device().empty() && virtual_placer != nullptr)
             ? virtual_placer->get_canonical_device_name(node)
             : node.device();
}

bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* default_layout_sensitive_ops =
      new absl::flat_hash_set<std::string>(
          {"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
           "FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
           "FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
  return default_layout_sensitive_ops->find(node.op()) !=
         default_layout_sensitive_ops->end();
}

bool IsLayoutSensitiveOp(const NodeDef& node) {
  return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
         IsBiasAddV2(node) || IsBiasAddGrad(node) ||
         IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
         IsDepthwiseConv2dNativeBackpropFilter(node) ||
         IsDepthwiseConv2dNativeBackpropInput(node) ||
         IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||
         IsMaxPoolV2(node) || IsMaxPoolGrad(node) || IsMaxPoolGradV2(node) ||
         IsMaxPoolGradGradV1(node) || IsMaxPoolGradGradV2(node) ||
         IsConv3D(node) || IsConv3DBackpropInputV2(node) ||
         IsConv3DBackpropFilterV2(node);
}

bool IsDefaultLayoutAgnosticOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* agnostic_nodes =
      new absl::flat_hash_set<std::string>({"Abs",
                                            "Acos",
                                            "Acosh",
                                            "Angle",
                                            "Asin",
                                            "Asinh",
                                            "Atan",
                                            "Atanh",
                                            "Bitcast",
                                            "Cast",
                                            "Ceil",
                                            "CheckNumerics",
                                            "ComplexAbs",
                                            "Conj",
                                            "Cos",
                                            "Cosh",
                                            "Digamma",
                                            "Elu",
                                            "Enter",
                                            "Erf",
                                            "Erfc",
                                            "Exit",
                                            "Exp",
                                            "Expm1",
                                            "FakeQuantWithMinMaxVars",
                                            "FakeQuantWithMinMaxArgs",
                                            "Floor",
                                            "GuaranteeConst",
                                            "Identity",
                                            "Imag",
                                            "Inv",
                                            "IsFinite",
                                            "IsInf",
                                            "IsNan",
                                            "LeakyRelu",
                                            "Lgamma",
                                            "Log",
                                            "LogicalNot",
                                            "Log1p",
                                            "Neg",
                                            "NextIteration",
                                            "OnesLike",
                                            "PreventGradient",
                                            "QuantizeAndDequantizeV2",
                                            "QuantizeAndDequantizeV3",
                                            "Real",
                                            "Reciprocal",
                                            "Relu",
                                            "Relu6",
                                            "Rint",
                                            "Selu",
                                            "Sigmoid",
                                            "Sign",
                                            "Sin",
                                            "Sinh",
                                            "Snapshot",
                                            "Softplus",
                                            "Round",
                                            "Rsqrt",
                                            "Sqrt",
                                            "Square",
                                            "StopGradient",
                                            "Tan",
                                            "Tanh",
                                            "ZerosLike"});
  return agnostic_nodes->find(node.op()) != agnostic_nodes->end();
}

bool IsLayoutAgnosticOp(const NodeDef& node) {
  return IsDefaultLayoutAgnosticOp(node) || IsAddN(node) || IsBinaryOp(node) ||
         IsIdentityN(node) || IsMerge(node) || IsMirrorPad(node) ||
         IsMirrorPadGrad(node) || IsPad(node) || IsSelect(node) ||
         IsSwitch(node) || IsTernaryOp(node) || IsUnaryGrad(node) ||
         IsConcat(node) || IsReverseV2(node) || IsTile(node) || IsShape(node) ||
         IsShapeN(node) || IsFill(node) || IsSlice(node) || IsSplit(node) ||
         IsSqueeze(node) || IsSplitV(node) || IsStridedSlice(node) ||
         IsReduceOp(node);
}

bool IsTernaryOp(const NodeDef& node) { return IsBetainc(node); }

bool IsUnaryGrad(const NodeDef& node) {
  bool is_unary_grad =
      IsEluGrad(node) || IsInvGrad(node) || IsLeakyReluGrad(node) ||
      IsReciprocalGrad(node) || IsRelu6Grad(node) || IsReluGrad(node) ||
      IsRsqrtGrad(node) || IsSeluGrad(node) || IsSigmoidGrad(node) ||
      IsSoftplusGrad(node) || IsSoftsignGrad(node) || IsSqrtGrad(node) ||
      IsTanhGrad(node);
  return is_unary_grad;
}

bool IsMaxPoolV2(const NodeDef& node) { return node.op() == "MaxPoolV2"; }

bool IsMaxPoolGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradV2";
}

bool IsMaxPoolGradGradV1(const NodeDef& node) {
  return node.op() == "MaxPoolGradGrad";
}

bool IsMaxPoolGradGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradGradV2";
}

bool IsBinaryOp(const NodeDef& node) {
  bool is_binary =
      IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
      IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
      IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
      IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
      IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
      IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
  return is_binary;
}

bool IsReduceOp(const NodeDef& node) {
  return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
         IsMin(node) || IsAll(node) || IsAny(node);
}

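// Maps an op to the fanin ports that carry layout-sensitive data, as opposed
// to scalars such as axes, sizes, or shapes. For example, AvgPoolGrad's
// gradient arrives on port 1 (port 0 is the original input shape), and
// StridedSliceGrad's on port 4.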
std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsAvgPoolGrad(*node_def) || IsSplit(*node_def)) {
    return {1};
  }
  if (IsStridedSliceGrad(*node_def)) {
    return {4};
  }
  if (IsBinaryOp(*node_def) || IsUnaryGrad(*node_def)) {
    return {0, 1};
  }
  if (IsTernaryOp(*node_def) || IsSelect(*node_def) ||
      IsMaxPoolGrad(*node_def) || IsMaxPoolGradV2(*node_def) ||
      IsMaxPoolGradGradV1(*node_def) || IsMaxPoolGradGradV2(*node_def)) {
    return {0, 1, 2};
  }
  if (IsShapeN(*node_def) || IsIdentityN(*node_def) || IsAddN(*node_def) ||
      IsMerge(*node_def)) {
    return GetRegularFaninPorts(node);
  }
  if (IsConcat(*node_def)) {
    return GetConcatDataFaninPorts(node);
  }
  if (node.NumRegularFanins() > 0) {
    return {0};
  }
  return {};
}

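// Maps an op to the output ports that carry layout-sensitive data. Variadic
// ops enumerate all of their outputs: Split/SplitV read the num_split
// attribute, and Switch defaults to its two branch outputs when no num_outs
// attribute is present.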
std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsIdentityN(*node_def) || IsShape(*node_def) || IsShapeN(*node_def)) {
    return GetDataFaninPorts(node);
  }
  if (IsSplit(*node_def) || IsSplitV(*node_def)) {
    const auto* num_split_attr = node.GetAttr(kAttrNumSplit);
    if (num_split_attr == nullptr) {
      return {0};
    }
    std::vector<int> values(num_split_attr->i());
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  if (IsSwitch(*node_def)) {
    const auto* num_outs_attr = node.GetAttr(kAttrNumOuts);
    const int num_outs = num_outs_attr != nullptr ? num_outs_attr->i() : 2;
    std::vector<int> values(num_outs);
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  return {0};
}

bool GetValueAttrFromConstInputNode(
    const utils::MutableNodeView& node,
    const std::function<bool(const NodeDef&)>& predicate, int index,
    Tensor* tensor) {
  if (!predicate(*node.node())) {
    return false;
  }
  const auto& regular_fanin = node.GetRegularFanin(index);
  auto* regular_fanin_node = regular_fanin.node_view();
  if (!IsConstant(*regular_fanin_node->node())) {
    return false;
  }
  const auto* value_attr = regular_fanin_node->GetAttr(kAttrValue);
  if (value_attr == nullptr || value_attr->tensor().dtype() != DT_INT32) {
    return false;
  }
  if (!tensor->FromProto(value_attr->tensor())) {
    return false;
  }

  return true;
}

bool IsDataFormatOp(const utils::MutableNodeView& node) {
  const string& op = node.GetOp();
  return op == kOpDataFormatDimMap || op == kOpDataFormatVecPermute;
}

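// Builds the dimension-label index used throughout this file, e.g. "NHWC"
// yields { N:0, H:1, W:2, C:3 }.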
absl::flat_hash_map<char, int> GetDimensionIndices(
    absl::string_view data_format) {
  const int size = data_format.size();
  absl::flat_hash_map<char, int> index;
  index.reserve(size);
  for (int i = 0; i < size; i++) {
    index[data_format[i]] = i;
  }
  return index;
}

std::vector<int> GetPermutation(
    const absl::flat_hash_map<char, int>& src_dim_indices,
    absl::string_view dst_format) {
  // Generate permutation for transformation between src and dst format.
  // Example:
  // src = NWHC, dst = NCWH
  // index = { N:0 W:1 H:2 C:3 }
  // permutation = [0, 3, 1, 2]
  DCHECK(src_dim_indices.size() == dst_format.size());
  std::vector<int> permutation;
  const int size = dst_format.size();
  permutation.reserve(size);
  for (int i = 0; i < size; i++) {
    permutation.push_back(src_dim_indices.at(dst_format[i]));
  }
  return permutation;
}

}  // namespace grappler
}  // namespace tensorflow