• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
18 
#include <functional>
#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
36 
37 namespace tensorflow {
38 namespace grappler {
39 
// Node attribute names used by the transposers.
constexpr char kAttrSrcFormat[]{"src_format"};
constexpr char kAttrDstFormat[]{"dst_format"};
constexpr char kAttrOutputShape[]{"_output_shapes"};

// Device type strings (presumably compared against
// TransposeContext::target_device -- TODO confirm).
constexpr char kGPU[]{"GPU"};
constexpr char kCPU[]{"CPU"};
45 
46 // TransposeContext owns all data members. Must initialize GraphProperties,
47 // FrameView, GraphDef and MutableGraphView with the same graph. NodeDef
48 // pointers in FrameView, GraphDef and MutableGraphView must point to nodes in
49 // the same GraphDef instance.
50 struct TransposeContext {
51   // Initializes TransposeContext with given GrapplerItem. Because initializing
52   // FrameMap and GraphProperties may return error, we initialize
53   // TransposeContext outside constructor.
54   static Status InitializeTransposeContext(const GrapplerItem& item,
55                                            const Cluster* cluster,
56                                            TransposeContext* context);
57 
58   // Sets data formats to convert from and to for specified device type.
59   void AssignDeviceAndDataFormats(absl::string_view target_device,
60                                   absl::string_view src_format,
61                                   absl::string_view dst_format);
62 
63   FrameView frames;
64   GraphDef graph;
65   // Number of nodes in the original graph. As new nodes are appended to the end
66   // of the graph, all new nodes should have a node index greater than or equal
67   // to this.
68   int num_nodes;
69   absl::flat_hash_set<string> nodes_to_preserve;
70   std::unique_ptr<GraphProperties> graph_properties;
71   std::unique_ptr<utils::MutableGraphView> graph_view;
72   std::unique_ptr<const VirtualPlacer> virtual_placer;
73 
74   string target_device;
75   string src_format;
76   string dst_format;
77   absl::flat_hash_map<char, int> src_dim_indices;
78   absl::flat_hash_map<char, int> dst_dim_indices;
79   std::vector<int> src_to_dst;
80   std::vector<int> dst_to_src;
81 };
82 
83 class Transposer {
84  public:
Transposer()85   explicit Transposer() {}
86 
87   Transposer(const Transposer&) = delete;
88   Transposer& operator=(const Transposer&) = delete;
89 
~Transposer()90   virtual ~Transposer() {}
91 
92   // Returns true iff the node should be processed by this transposer.
93   // NodeProcessors may perform additional oprand specific checks before
94   // processing if necessary.
95   // Following common conditions are checked:
96   // * node's device matches target device
97   // * node's source format matches config's source format
98   // * node has output
99   bool ShouldProcess(const TransposeContext& context,
100                      const utils::MutableNodeView& node) const;
101 
102   // Transposes given node from src format to dst format. Also perform other
103   // necessary operations to guarantee the graph produce the same result.
104   // Eg. Add Transpose node sets before fanin ports and after fanout ports.
105   virtual Status TransposeNode(TransposeContext* context,
106                                utils::MutableNodeView* node) = 0;
107 
108   // Creates a Const node for permutation. If node with node_name already exits,
109   // return and reuse it.
110   Status CreateConstPermNode(TransposeContext* context,
111                              absl::string_view node_name,
112                              absl::string_view device,
113                              absl::Span<const int> permutation,
114                              absl::string_view control_node_name,
115                              utils::MutationNewNode* added_node);
116 
117   // Creates a TransposeNode with given properties. If node with node_name
118   // already exits, return and reuse it.
119   // A const perm node is also created and connected to the 2nd fanin.
120   // control_node_name is ignored if it is empty.
121   Status CreateTransposeNode(
122       TransposeContext* context, absl::string_view name_format,
123       const DataType& data_type, absl::string_view device,
124       TensorShapeProto fanin_shape, absl::Span<const int> permutation,
125       absl::string_view control_node_name, utils::MutationNewNode* added_node,
126       string* transpose_node_name);
127 
128   // Update all edges between dst_node->fanin[dst_ports] and dst_node by
129   // inserting an op node.
130   Status UpdateFaninEdgesWithOp(TransposeContext* context,
131                                 absl::Span<const int> dst_ports,
132                                 utils::MutableNodeView* dst_node,
133                                 absl::string_view op);
134 
135   // Update all edges between src_node:src_ports and nodes take
136   // src_node:src_ports as fanin. Also update attr _output_shape of src_node.
137   Status UpdateFanoutEdgesWithOp(TransposeContext* context,
138                                  absl::Span<const int> src_ports,
139                                  utils::MutableNodeView* src_node,
140                                  absl::string_view op);
141 
142   // Creates a DataFromat node with given properties.
143   // DataFromat op is either DataFormatVecPermute or DataFormatDimMap.
144   Status CreateDataFormatNode(TransposeContext* context,
145                               absl::string_view node_name, absl::string_view op,
146                               absl::string_view device,
147                               const DataType& data_type, bool is_fanin_on_host,
148                               bool is_src_format_to_dst_format,
149                               utils::MutationNewNode* added_node);
150 
151  protected:
152   int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
153   bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
154                          int n) const;
155   bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
156                           absl::Span<const int> ports, int n) const;
157   int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
158   bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
159                         int n) const;
160 
161   // Checks if fanin at specified port(s) has dimensions `dims` iff fanin is a
162   // Const. If fanin is not a Const, no dimensions will be checked and this will
163   // return true.
164   bool IsFaninPortDimsNIfConst(const utils::MutableNodeView& node, int port,
165                                absl::Span<const int> dims) const;
166   bool IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
167                                 absl::Span<const int> ports,
168                                 absl::Span<const int> dims) const;
169   bool CanProcessNode(const TransposeContext& context,
170                       const utils::MutableNodeView& node) const;
171   // Update all edges between dst_node->fanin[dst_ports] and dst_node.
172   // A node with op is created and inserted between all edges.
173   // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap.
174   Status UpdateEdge(TransposeContext* context, absl::string_view name_format,
175                     absl::string_view op, const AttrValue* input_shape,
176                     bool is_in_frame, bool is_src_format_to_dst_format,
177                     const int src_port, const int dst_port,
178                     utils::MutableNodeView* src_node,
179                     utils::MutableNodeView* dst_node);
180   string GetFaninNameFormat(absl::string_view node_name, int port,
181                             absl::string_view src_format,
182                             absl::string_view dst_format);
183   string GetFanoutNameFormat(absl::string_view node_name, int port, int index,
184                              absl::string_view src_format,
185                              absl::string_view dst_format);
186   string LayoutOptimizerNode(absl::string_view node_name);
187   string GetReshapeNodeNameFormat(absl::string_view node_name, int index,
188                                   absl::string_view src_format,
189                                   absl::string_view dst_format);
190   string GetShapeConstNodeNameFormat(absl::string_view node_name, int index);
191 };
192 
193 class LayoutSensitiveOpTransposer : public Transposer {
194  public:
LayoutSensitiveOpTransposer()195   explicit LayoutSensitiveOpTransposer() : Transposer() {}
196 
197   // Updates attrs data_format, ksize, strides of the given node to dst_format.
198   // _output_shape is updated during UpdateOutputEdges.
199   Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node);
200 };
201 
202 // Layout sensitive op transposers.
203 
204 class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
205  public:
DefaultLayoutSensitiveOpTransposer()206   explicit DefaultLayoutSensitiveOpTransposer()
207       : LayoutSensitiveOpTransposer() {}
208 
209   Status TransposeNode(TransposeContext* context,
210                        utils::MutableNodeView* node) override;
211 };
212 
213 class BiasAddTransposer : public LayoutSensitiveOpTransposer {
214  public:
BiasAddTransposer()215   explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}
216 
217   Status TransposeNode(TransposeContext* context,
218                        utils::MutableNodeView* node) override;
219 };
220 
221 class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
222  public:
AvgPoolGradTransposer()223   explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
224 
225   Status TransposeNode(TransposeContext* context,
226                        utils::MutableNodeView* node) override;
227 };
228 
229 class BiasAddGradTransposer : public LayoutSensitiveOpTransposer {
230  public:
BiasAddGradTransposer()231   explicit BiasAddGradTransposer() : LayoutSensitiveOpTransposer() {}
232 
233   Status TransposeNode(TransposeContext* context,
234                        utils::MutableNodeView* node) override;
235 };
236 
237 class Conv2DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
238  public:
Conv2DBackpropFilterTransposer()239   explicit Conv2DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
240 
241   Status TransposeNode(TransposeContext* context,
242                        utils::MutableNodeView* node) override;
243 };
244 
245 class Conv2DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
246  public:
Conv2DBackpropInputTransposer()247   explicit Conv2DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
248 
249   Status TransposeNode(TransposeContext* context,
250                        utils::MutableNodeView* node) override;
251 };
252 
253 class Conv3DTransposer : public LayoutSensitiveOpTransposer {
254  public:
Conv3DTransposer()255   explicit Conv3DTransposer() : LayoutSensitiveOpTransposer() {}
256 
257   Status TransposeNode(TransposeContext* context,
258                        utils::MutableNodeView* node) override;
259 };
260 
261 class Conv3DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
262  public:
Conv3DBackpropFilterTransposer()263   explicit Conv3DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
264 
265   Status TransposeNode(TransposeContext* context,
266                        utils::MutableNodeView* node) override;
267 };
268 
269 class Conv3DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
270  public:
Conv3DBackpropInputTransposer()271   explicit Conv3DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
272 
273   Status TransposeNode(TransposeContext* context,
274                        utils::MutableNodeView* node) override;
275 };
276 
277 class FusedBatchNormExTransposer : public LayoutSensitiveOpTransposer {
278  public:
FusedBatchNormExTransposer()279   explicit FusedBatchNormExTransposer() : LayoutSensitiveOpTransposer() {}
280 
281   Status TransposeNode(TransposeContext* context,
282                        utils::MutableNodeView* node) override;
283 };
284 
285 class FusedBatchNormGradTransposer : public LayoutSensitiveOpTransposer {
286  public:
FusedBatchNormGradTransposer()287   explicit FusedBatchNormGradTransposer() : LayoutSensitiveOpTransposer() {}
288 
289   Status TransposeNode(TransposeContext* context,
290                        utils::MutableNodeView* node) override;
291 
292  private:
293   bool IsTraining(const utils::MutableNodeView& node) const;
294 };
295 
296 class MaxPoolV2Transposer : public LayoutSensitiveOpTransposer {
297  public:
MaxPoolV2Transposer()298   explicit MaxPoolV2Transposer() : LayoutSensitiveOpTransposer() {}
299 
300   Status TransposeNode(TransposeContext* context,
301                        utils::MutableNodeView* node) override;
302 };
303 
304 class MaxPoolGradTransposer : public LayoutSensitiveOpTransposer {
305  public:
MaxPoolGradTransposer()306   explicit MaxPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
307 
308   Status TransposeNode(TransposeContext* context,
309                        utils::MutableNodeView* node) override;
310 };
311 
312 class MaxPoolGradV2Transposer : public LayoutSensitiveOpTransposer {
313  public:
MaxPoolGradV2Transposer()314   explicit MaxPoolGradV2Transposer() : LayoutSensitiveOpTransposer() {}
315 
316   Status TransposeNode(TransposeContext* context,
317                        utils::MutableNodeView* node) override;
318 };
319 
320 // Layout agnostic op transposers.
321 
322 class LayoutAgnosticOpTransposer : public Transposer {
323  public:
LayoutAgnosticOpTransposer()324   explicit LayoutAgnosticOpTransposer() : Transposer() {}
325 
326  protected:
327   bool IsAfterDstToSrcTransform(const TransposeContext& context,
328                                 const utils::MutableNodeView& node) const;
329 
330   std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
331                                            const utils::MutableNodeView& node,
332                                            int rank) const;
333 };
334 
335 class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
336  public:
DefaultLayoutAgnosticOpTransposer()337   explicit DefaultLayoutAgnosticOpTransposer() : LayoutAgnosticOpTransposer() {}
338 
339   Status TransposeNode(TransposeContext* context,
340                        utils::MutableNodeView* node) override;
341 };
342 
343 class AddNTransposer : public LayoutAgnosticOpTransposer {
344  public:
AddNTransposer()345   explicit AddNTransposer() : LayoutAgnosticOpTransposer() {}
346 
347   Status TransposeNode(TransposeContext* context,
348                        utils::MutableNodeView* node) override;
349 };
350 
351 class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
352  public:
BinaryOpTransposer()353   explicit BinaryOpTransposer() : LayoutAgnosticOpTransposer() {}
354 
355   Status TransposeNode(TransposeContext* context,
356                        utils::MutableNodeView* node) override;
357 
358  private:
359   bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
360   bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
361   std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
362                                        int rank);
363   Status AddNodeShapeConst(utils::Mutation* mutation,
364                            absl::string_view node_name,
365                            absl::string_view node_device, bool node_in_frame,
366                            int num_channels, absl::string_view depended_node,
367                            int rank);
368   Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
369                         absl::string_view node_device,
370                         absl::string_view input_name,
371                         absl::string_view shape_const_node_name,
372                         const DataType& data_type);
373   Status MaybeReshapeVectorFanin(TransposeContext* context,
374                                  utils::MutableNodeView* node, int rank);
375 };
376 
377 class ConcatOpTransposer : public LayoutAgnosticOpTransposer {
378  public:
ConcatOpTransposer()379   explicit ConcatOpTransposer() : LayoutAgnosticOpTransposer() {}
380 
381   Status TransposeNode(TransposeContext* context,
382                        utils::MutableNodeView* node) override;
383 };
384 
385 class FillOpTransposer : public LayoutAgnosticOpTransposer {
386  public:
FillOpTransposer()387   explicit FillOpTransposer() : LayoutAgnosticOpTransposer() {}
388 
389   Status TransposeNode(TransposeContext* context,
390                        utils::MutableNodeView* node) override;
391 };
392 
393 class IdentityNTransposer : public LayoutAgnosticOpTransposer {
394  public:
IdentityNTransposer()395   explicit IdentityNTransposer() : LayoutAgnosticOpTransposer() {}
396 
397   Status TransposeNode(TransposeContext* context,
398                        utils::MutableNodeView* node) override;
399 };
400 
401 class MergeTransposer : public LayoutAgnosticOpTransposer {
402  public:
MergeTransposer()403   explicit MergeTransposer() : LayoutAgnosticOpTransposer() {}
404 
405   Status TransposeNode(TransposeContext* context,
406                        utils::MutableNodeView* node) override;
407 
408  private:
409   bool IsEveryFaninAfterDstToSrcTransform(
410       const TransposeContext& context,
411       const utils::MutableNodeView& node) const;
412 };
413 
414 class PadTransposer : public LayoutAgnosticOpTransposer {
415  public:
PadTransposer()416   explicit PadTransposer() : LayoutAgnosticOpTransposer() {}
417 
418   Status TransposeNode(TransposeContext* context,
419                        utils::MutableNodeView* node) override;
420 };
421 
422 class ReduceTransposer : public LayoutAgnosticOpTransposer {
423  public:
ReduceTransposer()424   explicit ReduceTransposer() : LayoutAgnosticOpTransposer() {}
425 
426   Status TransposeNode(TransposeContext* context,
427                        utils::MutableNodeView* node) override;
428 
429  private:
430   bool KeepDims(const utils::MutableNodeView& node);
431   bool IsAlongAxis(const Tensor& tensor, absl::Span<const int> axis, int rank);
432   bool IsReduceAxisSupported(const TransposeContext& context,
433                              const utils::MutableNodeView& node, int rank);
434 };
435 
436 class ReverseV2Transposer : public LayoutAgnosticOpTransposer {
437  public:
ReverseV2Transposer()438   explicit ReverseV2Transposer() : LayoutAgnosticOpTransposer() {}
439 
440   Status TransposeNode(TransposeContext* context,
441                        utils::MutableNodeView* node) override;
442 };
443 
444 class SelectTransposer : public LayoutAgnosticOpTransposer {
445  public:
SelectTransposer()446   explicit SelectTransposer() : LayoutAgnosticOpTransposer() {}
447 
448   Status TransposeNode(TransposeContext* context,
449                        utils::MutableNodeView* node) override;
450 
451  protected:
452   bool IsFaninScalarVector4D(const utils::MutableNodeView& fanin, int port);
453   std::vector<int> GetFaninPorts(const utils::MutableNodeView& fanin, int port);
454 };
455 
456 class ShapeTransposer : public LayoutAgnosticOpTransposer {
457  public:
ShapeTransposer()458   explicit ShapeTransposer() : LayoutAgnosticOpTransposer() {}
459 
460   Status TransposeNode(TransposeContext* context,
461                        utils::MutableNodeView* node) override;
462 };
463 
464 class ShapeNTransposer : public LayoutAgnosticOpTransposer {
465  public:
ShapeNTransposer()466   explicit ShapeNTransposer() : LayoutAgnosticOpTransposer() {}
467 
468   Status TransposeNode(TransposeContext* context,
469                        utils::MutableNodeView* node) override;
470 };
471 
472 class SliceTransposer : public LayoutAgnosticOpTransposer {
473  public:
SliceTransposer()474   explicit SliceTransposer() : LayoutAgnosticOpTransposer() {}
475 
476   Status TransposeNode(TransposeContext* context,
477                        utils::MutableNodeView* node) override;
478 };
479 
480 class SplitTransposer : public LayoutAgnosticOpTransposer {
481  public:
SplitTransposer()482   explicit SplitTransposer() : LayoutAgnosticOpTransposer() {}
483 
484   Status TransposeNode(TransposeContext* context,
485                        utils::MutableNodeView* node) override;
486 };
487 
488 class SplitVTransposer : public LayoutAgnosticOpTransposer {
489  public:
SplitVTransposer()490   explicit SplitVTransposer() : LayoutAgnosticOpTransposer() {}
491 
492   Status TransposeNode(TransposeContext* context,
493                        utils::MutableNodeView* node) override;
494 };
495 
496 class SqueezeTransposer : public LayoutAgnosticOpTransposer {
497  public:
SqueezeTransposer()498   explicit SqueezeTransposer() : LayoutAgnosticOpTransposer() {}
499 
500   Status TransposeNode(TransposeContext* context,
501                        utils::MutableNodeView* node) override;
502 
503  private:
504   bool IsInputConvertible(const TransposeContext& context,
505                           const utils::MutableNodeView& node) const;
506   bool IsAlongAxis(const AttrValue& attr, absl::Span<const int> axis,
507                    int rank) const;
508   bool IsDimsSupported(const TransposeContext& context,
509                        const utils::MutableNodeView& node) const;
510   Status UpdateSqueezeDims(TransposeContext* context,
511                            utils::MutableNodeView* node);
512 };
513 
514 class StridedSliceTransposer : public LayoutAgnosticOpTransposer {
515  public:
StridedSliceTransposer()516   explicit StridedSliceTransposer() : LayoutAgnosticOpTransposer() {}
517 
518   Status TransposeNode(TransposeContext* context,
519                        utils::MutableNodeView* node) override;
520 
521  private:
522   bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask);
523   bool HasOnlyBeginEndMask(const utils::MutableNodeView& node);
524   Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node,
525                      absl::string_view mask);
526 };
527 
528 class SwitchTransposer : public LayoutAgnosticOpTransposer {
529  public:
SwitchTransposer()530   explicit SwitchTransposer() : LayoutAgnosticOpTransposer() {}
531 
532   Status TransposeNode(TransposeContext* context,
533                        utils::MutableNodeView* node) override;
534 };
535 
536 class TernaryOpTransposer : public LayoutAgnosticOpTransposer {
537  public:
TernaryOpTransposer()538   explicit TernaryOpTransposer() : LayoutAgnosticOpTransposer() {}
539 
540   Status TransposeNode(TransposeContext* context,
541                        utils::MutableNodeView* node) override;
542 };
543 
544 class TileTransposer : public LayoutAgnosticOpTransposer {
545  public:
TileTransposer()546   explicit TileTransposer() : LayoutAgnosticOpTransposer() {}
547 
548   Status TransposeNode(TransposeContext* context,
549                        utils::MutableNodeView* node) override;
550 };
551 
552 class UnaryGradTransposer : public LayoutAgnosticOpTransposer {
553  public:
UnaryGradTransposer()554   explicit UnaryGradTransposer() : LayoutAgnosticOpTransposer() {}
555 
556   Status TransposeNode(TransposeContext* context,
557                        utils::MutableNodeView* node) override;
558 };
559 
560 // Utils.
561 
562 // Permutes elements according to permutation and replaces the original values.
563 // Permutation and values must have same size.
564 template <typename T>
PermuteSingle(absl::string_view location,absl::Span<const int> permutation,T * values)565 Status PermuteSingle(absl::string_view location,
566                      absl::Span<const int> permutation, T* values) {
567   DCHECK(values != nullptr);
568   int permutation_size = permutation.size();
569   if (values->size() != permutation_size) {
570     return Status(tensorflow::error::Code::INVALID_ARGUMENT,
571                   absl::StrCat("Size of values ", values->size(),
572                                " does not match size of permutation ",
573                                permutation_size, " @ ", location));
574   }
575   typedef typename T::value_type V;
576   std::vector<V> elements(values->begin(), values->end());
577   int index = 0;
578   for (V& element : *values) {
579     element = elements[permutation[index++]];
580   }
581   return Status::OK();
582 }
583 
584 // Permutes two elements at a time according to permutation and replaces the
585 // original values. Values must be twice the size of permutation.
586 template <typename T>
PermuteDouble(absl::string_view location,absl::Span<const int> permutation,T * values)587 Status PermuteDouble(absl::string_view location,
588                      absl::Span<const int> permutation, T* values) {
589   DCHECK(values != nullptr);
590   int permutation_size = permutation.size();
591   if (values->size() != permutation_size * 2) {
592     return Status(tensorflow::error::Code::INVALID_ARGUMENT,
593                   absl::StrCat("Size of values ", values->size(),
594                                " does not match twice the size of permutation ",
595                                permutation_size, " @ ", location));
596   }
597   typedef typename T::value_type V;
598   std::vector<V> elements(values->begin(), values->end());
599   for (int i = 0; i < values->size(); i = i + 2) {
600     const int permutation_index = permutation[i / 2];
601     (*values)[i] = elements[permutation_index * 2];
602     (*values)[i + 1] = elements[permutation_index * 2 + 1];
603   }
604   return Status::OK();
605 }
606 
607 string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node);
608 
609 bool IsDefaultLayoutSensitiveOp(const NodeDef& node);
610 
611 bool IsLayoutSensitiveOp(const NodeDef& node);
612 
613 bool IsDefaultLayoutAgnosticOp(const NodeDef& node);
614 
615 bool IsLayoutAgnosticOp(const NodeDef& node);
616 
617 bool IsTernaryOp(const NodeDef& node);
618 
619 bool IsUnaryGrad(const NodeDef& node);
620 
621 bool IsMaxPoolV2(const NodeDef& node);
622 
623 bool IsMaxPoolGradV2(const NodeDef& node);
624 
625 bool IsMaxPoolGradGradV1(const NodeDef& node);
626 
627 bool IsMaxPoolGradGradV2(const NodeDef& node);
628 
629 bool IsBinaryOp(const NodeDef& node);
630 
631 bool IsReduceOp(const NodeDef& node);
632 
633 std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node);
634 
635 std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node);
636 
637 // Returns a value of constant input to the `node` at `index`, iff `predicate`
638 // evaluated to true. Returns true if `tensor` was populated with data.
639 bool GetValueAttrFromConstInputNode(
640     const utils::MutableNodeView& node,
641     const std::function<bool(const NodeDef&)>& predicate, int index,
642     Tensor* tensor);
643 
644 bool IsDataFormatOp(const utils::MutableNodeView& node);
645 
646 absl::flat_hash_map<char, int> GetDimensionIndices(
647     absl::string_view data_format);
648 
649 std::vector<int> GetPermutation(
650     const absl::flat_hash_map<char, int>& src_dim_indices,
651     absl::string_view dst_format);
652 
653 }  // namespace grappler
654 }  // namespace tensorflow
655 
656 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
657