1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
18
19 #include <memory>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/grappler/costs/graph_properties.h"
30 #include "tensorflow/core/grappler/costs/virtual_placer.h"
31 #include "tensorflow/core/grappler/utils.h"
32 #include "tensorflow/core/grappler/utils/frame.h"
33 #include "tensorflow/core/grappler/utils/graph_view.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/status.h"
36
37 namespace tensorflow {
38 namespace grappler {
39
// String constants used by the layout optimizer. Judging by their names and
// values, the first three are node attribute names and the last two are device
// type names.
constexpr char kAttrSrcFormat[] = "src_format";
constexpr char kAttrDstFormat[] = "dst_format";
constexpr char kAttrOutputShape[] = "_output_shapes";
constexpr char kGPU[] = "GPU";
constexpr char kCPU[] = "CPU";
45
46 // TransposeContext owns all data members. Must initialize GraphProperties,
47 // FrameView, GraphDef and MutableGraphView with the same graph. NodeDef
48 // pointers in FrameView, GraphDef and MutableGraphView must point to nodes in
49 // the same GraphDef instance.
50 struct TransposeContext {
51 // Initializes TransposeContext with given GrapplerItem. Because initializing
52 // FrameMap and GraphProperties may return error, we initialize
53 // TransposeContext outside constructor.
54 static Status InitializeTransposeContext(const GrapplerItem& item,
55 const Cluster* cluster,
56 TransposeContext* context);
57
58 // Sets data formats to convert from and to for specified device type.
59 void AssignDeviceAndDataFormats(absl::string_view target_device,
60 absl::string_view src_format,
61 absl::string_view dst_format);
62
63 FrameView frames;
64 GraphDef graph;
65 // Number of nodes in the original graph. As new nodes are appended to the end
66 // of the graph, all new nodes should have a node index greater than or equal
67 // to this.
68 int num_nodes;
69 absl::flat_hash_set<string> nodes_to_preserve;
70 std::unique_ptr<GraphProperties> graph_properties;
71 std::unique_ptr<utils::MutableGraphView> graph_view;
72 std::unique_ptr<const VirtualPlacer> virtual_placer;
73
74 string target_device;
75 string src_format;
76 string dst_format;
77 absl::flat_hash_map<char, int> src_dim_indices;
78 absl::flat_hash_map<char, int> dst_dim_indices;
79 std::vector<int> src_to_dst;
80 std::vector<int> dst_to_src;
81 };
82
83 class Transposer {
84 public:
Transposer()85 explicit Transposer() {}
86
87 Transposer(const Transposer&) = delete;
88 Transposer& operator=(const Transposer&) = delete;
89
~Transposer()90 virtual ~Transposer() {}
91
92 // Returns true iff the node should be processed by this transposer.
93 // NodeProcessors may perform additional oprand specific checks before
94 // processing if necessary.
95 // Following common conditions are checked:
96 // * node's device matches target device
97 // * node's source format matches config's source format
98 // * node has output
99 bool ShouldProcess(const TransposeContext& context,
100 const utils::MutableNodeView& node) const;
101
102 // Transposes given node from src format to dst format. Also perform other
103 // necessary operations to guarantee the graph produce the same result.
104 // Eg. Add Transpose node sets before fanin ports and after fanout ports.
105 virtual Status TransposeNode(TransposeContext* context,
106 utils::MutableNodeView* node) = 0;
107
108 // Creates a Const node for permutation. If node with node_name already exits,
109 // return and reuse it.
110 Status CreateConstPermNode(TransposeContext* context,
111 absl::string_view node_name,
112 absl::string_view device,
113 absl::Span<const int> permutation,
114 absl::string_view control_node_name,
115 utils::MutationNewNode* added_node);
116
117 // Creates a TransposeNode with given properties. If node with node_name
118 // already exits, return and reuse it.
119 // A const perm node is also created and connected to the 2nd fanin.
120 // control_node_name is ignored if it is empty.
121 Status CreateTransposeNode(
122 TransposeContext* context, absl::string_view name_format,
123 const DataType& data_type, absl::string_view device,
124 TensorShapeProto fanin_shape, absl::Span<const int> permutation,
125 absl::string_view control_node_name, utils::MutationNewNode* added_node,
126 string* transpose_node_name);
127
128 // Update all edges between dst_node->fanin[dst_ports] and dst_node by
129 // inserting an op node.
130 Status UpdateFaninEdgesWithOp(TransposeContext* context,
131 absl::Span<const int> dst_ports,
132 utils::MutableNodeView* dst_node,
133 absl::string_view op);
134
135 // Update all edges between src_node:src_ports and nodes take
136 // src_node:src_ports as fanin. Also update attr _output_shape of src_node.
137 Status UpdateFanoutEdgesWithOp(TransposeContext* context,
138 absl::Span<const int> src_ports,
139 utils::MutableNodeView* src_node,
140 absl::string_view op);
141
142 // Creates a DataFromat node with given properties.
143 // DataFromat op is either DataFormatVecPermute or DataFormatDimMap.
144 Status CreateDataFormatNode(TransposeContext* context,
145 absl::string_view node_name, absl::string_view op,
146 absl::string_view device,
147 const DataType& data_type, bool is_fanin_on_host,
148 bool is_src_format_to_dst_format,
149 utils::MutationNewNode* added_node);
150
151 protected:
152 int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
153 bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
154 int n) const;
155 bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
156 absl::Span<const int> ports, int n) const;
157 int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
158 bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
159 int n) const;
160
161 // Checks if fanin at specified port(s) has dimensions `dims` iff fanin is a
162 // Const. If fanin is not a Const, no dimensions will be checked and this will
163 // return true.
164 bool IsFaninPortDimsNIfConst(const utils::MutableNodeView& node, int port,
165 absl::Span<const int> dims) const;
166 bool IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
167 absl::Span<const int> ports,
168 absl::Span<const int> dims) const;
169 bool CanProcessNode(const TransposeContext& context,
170 const utils::MutableNodeView& node) const;
171 // Update all edges between dst_node->fanin[dst_ports] and dst_node.
172 // A node with op is created and inserted between all edges.
173 // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap.
174 Status UpdateEdge(TransposeContext* context, absl::string_view name_format,
175 absl::string_view op, const AttrValue* input_shape,
176 bool is_in_frame, bool is_src_format_to_dst_format,
177 const int src_port, const int dst_port,
178 utils::MutableNodeView* src_node,
179 utils::MutableNodeView* dst_node);
180 string GetFaninNameFormat(absl::string_view node_name, int port,
181 absl::string_view src_format,
182 absl::string_view dst_format);
183 string GetFanoutNameFormat(absl::string_view node_name, int port, int index,
184 absl::string_view src_format,
185 absl::string_view dst_format);
186 string LayoutOptimizerNode(absl::string_view node_name);
187 string GetReshapeNodeNameFormat(absl::string_view node_name, int index,
188 absl::string_view src_format,
189 absl::string_view dst_format);
190 string GetShapeConstNodeNameFormat(absl::string_view node_name, int index);
191 };
192
193 class LayoutSensitiveOpTransposer : public Transposer {
194 public:
LayoutSensitiveOpTransposer()195 explicit LayoutSensitiveOpTransposer() : Transposer() {}
196
197 // Updates attrs data_format, ksize, strides of the given node to dst_format.
198 // _output_shape is updated during UpdateOutputEdges.
199 Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node);
200 };
201
202 // Layout sensitive op transposers.
203
204 class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
205 public:
DefaultLayoutSensitiveOpTransposer()206 explicit DefaultLayoutSensitiveOpTransposer()
207 : LayoutSensitiveOpTransposer() {}
208
209 Status TransposeNode(TransposeContext* context,
210 utils::MutableNodeView* node) override;
211 };
212
213 class BiasAddTransposer : public LayoutSensitiveOpTransposer {
214 public:
BiasAddTransposer()215 explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}
216
217 Status TransposeNode(TransposeContext* context,
218 utils::MutableNodeView* node) override;
219 };
220
221 class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
222 public:
AvgPoolGradTransposer()223 explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
224
225 Status TransposeNode(TransposeContext* context,
226 utils::MutableNodeView* node) override;
227 };
228
229 class BiasAddGradTransposer : public LayoutSensitiveOpTransposer {
230 public:
BiasAddGradTransposer()231 explicit BiasAddGradTransposer() : LayoutSensitiveOpTransposer() {}
232
233 Status TransposeNode(TransposeContext* context,
234 utils::MutableNodeView* node) override;
235 };
236
237 class Conv2DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
238 public:
Conv2DBackpropFilterTransposer()239 explicit Conv2DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
240
241 Status TransposeNode(TransposeContext* context,
242 utils::MutableNodeView* node) override;
243 };
244
245 class Conv2DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
246 public:
Conv2DBackpropInputTransposer()247 explicit Conv2DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
248
249 Status TransposeNode(TransposeContext* context,
250 utils::MutableNodeView* node) override;
251 };
252
253 class Conv3DTransposer : public LayoutSensitiveOpTransposer {
254 public:
Conv3DTransposer()255 explicit Conv3DTransposer() : LayoutSensitiveOpTransposer() {}
256
257 Status TransposeNode(TransposeContext* context,
258 utils::MutableNodeView* node) override;
259 };
260
261 class Conv3DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
262 public:
Conv3DBackpropFilterTransposer()263 explicit Conv3DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
264
265 Status TransposeNode(TransposeContext* context,
266 utils::MutableNodeView* node) override;
267 };
268
269 class Conv3DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
270 public:
Conv3DBackpropInputTransposer()271 explicit Conv3DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
272
273 Status TransposeNode(TransposeContext* context,
274 utils::MutableNodeView* node) override;
275 };
276
277 class FusedBatchNormExTransposer : public LayoutSensitiveOpTransposer {
278 public:
FusedBatchNormExTransposer()279 explicit FusedBatchNormExTransposer() : LayoutSensitiveOpTransposer() {}
280
281 Status TransposeNode(TransposeContext* context,
282 utils::MutableNodeView* node) override;
283 };
284
285 class FusedBatchNormGradTransposer : public LayoutSensitiveOpTransposer {
286 public:
FusedBatchNormGradTransposer()287 explicit FusedBatchNormGradTransposer() : LayoutSensitiveOpTransposer() {}
288
289 Status TransposeNode(TransposeContext* context,
290 utils::MutableNodeView* node) override;
291
292 private:
293 bool IsTraining(const utils::MutableNodeView& node) const;
294 };
295
296 class MaxPoolV2Transposer : public LayoutSensitiveOpTransposer {
297 public:
MaxPoolV2Transposer()298 explicit MaxPoolV2Transposer() : LayoutSensitiveOpTransposer() {}
299
300 Status TransposeNode(TransposeContext* context,
301 utils::MutableNodeView* node) override;
302 };
303
304 class MaxPoolGradTransposer : public LayoutSensitiveOpTransposer {
305 public:
MaxPoolGradTransposer()306 explicit MaxPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
307
308 Status TransposeNode(TransposeContext* context,
309 utils::MutableNodeView* node) override;
310 };
311
312 class MaxPoolGradV2Transposer : public LayoutSensitiveOpTransposer {
313 public:
MaxPoolGradV2Transposer()314 explicit MaxPoolGradV2Transposer() : LayoutSensitiveOpTransposer() {}
315
316 Status TransposeNode(TransposeContext* context,
317 utils::MutableNodeView* node) override;
318 };
319
320 // Layout agnostic op transposers.
321
322 class LayoutAgnosticOpTransposer : public Transposer {
323 public:
LayoutAgnosticOpTransposer()324 explicit LayoutAgnosticOpTransposer() : Transposer() {}
325
326 protected:
327 bool IsAfterDstToSrcTransform(const TransposeContext& context,
328 const utils::MutableNodeView& node) const;
329
330 std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
331 const utils::MutableNodeView& node,
332 int rank) const;
333 };
334
335 class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
336 public:
DefaultLayoutAgnosticOpTransposer()337 explicit DefaultLayoutAgnosticOpTransposer() : LayoutAgnosticOpTransposer() {}
338
339 Status TransposeNode(TransposeContext* context,
340 utils::MutableNodeView* node) override;
341 };
342
343 class AddNTransposer : public LayoutAgnosticOpTransposer {
344 public:
AddNTransposer()345 explicit AddNTransposer() : LayoutAgnosticOpTransposer() {}
346
347 Status TransposeNode(TransposeContext* context,
348 utils::MutableNodeView* node) override;
349 };
350
351 class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
352 public:
BinaryOpTransposer()353 explicit BinaryOpTransposer() : LayoutAgnosticOpTransposer() {}
354
355 Status TransposeNode(TransposeContext* context,
356 utils::MutableNodeView* node) override;
357
358 private:
359 bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
360 bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
361 std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
362 int rank);
363 Status AddNodeShapeConst(utils::Mutation* mutation,
364 absl::string_view node_name,
365 absl::string_view node_device, bool node_in_frame,
366 int num_channels, absl::string_view depended_node,
367 int rank);
368 Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
369 absl::string_view node_device,
370 absl::string_view input_name,
371 absl::string_view shape_const_node_name,
372 const DataType& data_type);
373 Status MaybeReshapeVectorFanin(TransposeContext* context,
374 utils::MutableNodeView* node, int rank);
375 };
376
377 class ConcatOpTransposer : public LayoutAgnosticOpTransposer {
378 public:
ConcatOpTransposer()379 explicit ConcatOpTransposer() : LayoutAgnosticOpTransposer() {}
380
381 Status TransposeNode(TransposeContext* context,
382 utils::MutableNodeView* node) override;
383 };
384
385 class FillOpTransposer : public LayoutAgnosticOpTransposer {
386 public:
FillOpTransposer()387 explicit FillOpTransposer() : LayoutAgnosticOpTransposer() {}
388
389 Status TransposeNode(TransposeContext* context,
390 utils::MutableNodeView* node) override;
391 };
392
393 class IdentityNTransposer : public LayoutAgnosticOpTransposer {
394 public:
IdentityNTransposer()395 explicit IdentityNTransposer() : LayoutAgnosticOpTransposer() {}
396
397 Status TransposeNode(TransposeContext* context,
398 utils::MutableNodeView* node) override;
399 };
400
401 class MergeTransposer : public LayoutAgnosticOpTransposer {
402 public:
MergeTransposer()403 explicit MergeTransposer() : LayoutAgnosticOpTransposer() {}
404
405 Status TransposeNode(TransposeContext* context,
406 utils::MutableNodeView* node) override;
407
408 private:
409 bool IsEveryFaninAfterDstToSrcTransform(
410 const TransposeContext& context,
411 const utils::MutableNodeView& node) const;
412 };
413
414 class PadTransposer : public LayoutAgnosticOpTransposer {
415 public:
PadTransposer()416 explicit PadTransposer() : LayoutAgnosticOpTransposer() {}
417
418 Status TransposeNode(TransposeContext* context,
419 utils::MutableNodeView* node) override;
420 };
421
422 class ReduceTransposer : public LayoutAgnosticOpTransposer {
423 public:
ReduceTransposer()424 explicit ReduceTransposer() : LayoutAgnosticOpTransposer() {}
425
426 Status TransposeNode(TransposeContext* context,
427 utils::MutableNodeView* node) override;
428
429 private:
430 bool KeepDims(const utils::MutableNodeView& node);
431 bool IsAlongAxis(const Tensor& tensor, absl::Span<const int> axis, int rank);
432 bool IsReduceAxisSupported(const TransposeContext& context,
433 const utils::MutableNodeView& node, int rank);
434 };
435
436 class ReverseV2Transposer : public LayoutAgnosticOpTransposer {
437 public:
ReverseV2Transposer()438 explicit ReverseV2Transposer() : LayoutAgnosticOpTransposer() {}
439
440 Status TransposeNode(TransposeContext* context,
441 utils::MutableNodeView* node) override;
442 };
443
444 class SelectTransposer : public LayoutAgnosticOpTransposer {
445 public:
SelectTransposer()446 explicit SelectTransposer() : LayoutAgnosticOpTransposer() {}
447
448 Status TransposeNode(TransposeContext* context,
449 utils::MutableNodeView* node) override;
450
451 protected:
452 bool IsFaninScalarVector4D(const utils::MutableNodeView& fanin, int port);
453 std::vector<int> GetFaninPorts(const utils::MutableNodeView& fanin, int port);
454 };
455
456 class ShapeTransposer : public LayoutAgnosticOpTransposer {
457 public:
ShapeTransposer()458 explicit ShapeTransposer() : LayoutAgnosticOpTransposer() {}
459
460 Status TransposeNode(TransposeContext* context,
461 utils::MutableNodeView* node) override;
462 };
463
464 class ShapeNTransposer : public LayoutAgnosticOpTransposer {
465 public:
ShapeNTransposer()466 explicit ShapeNTransposer() : LayoutAgnosticOpTransposer() {}
467
468 Status TransposeNode(TransposeContext* context,
469 utils::MutableNodeView* node) override;
470 };
471
472 class SliceTransposer : public LayoutAgnosticOpTransposer {
473 public:
SliceTransposer()474 explicit SliceTransposer() : LayoutAgnosticOpTransposer() {}
475
476 Status TransposeNode(TransposeContext* context,
477 utils::MutableNodeView* node) override;
478 };
479
480 class SplitTransposer : public LayoutAgnosticOpTransposer {
481 public:
SplitTransposer()482 explicit SplitTransposer() : LayoutAgnosticOpTransposer() {}
483
484 Status TransposeNode(TransposeContext* context,
485 utils::MutableNodeView* node) override;
486 };
487
488 class SplitVTransposer : public LayoutAgnosticOpTransposer {
489 public:
SplitVTransposer()490 explicit SplitVTransposer() : LayoutAgnosticOpTransposer() {}
491
492 Status TransposeNode(TransposeContext* context,
493 utils::MutableNodeView* node) override;
494 };
495
496 class SqueezeTransposer : public LayoutAgnosticOpTransposer {
497 public:
SqueezeTransposer()498 explicit SqueezeTransposer() : LayoutAgnosticOpTransposer() {}
499
500 Status TransposeNode(TransposeContext* context,
501 utils::MutableNodeView* node) override;
502
503 private:
504 bool IsInputConvertible(const TransposeContext& context,
505 const utils::MutableNodeView& node) const;
506 bool IsAlongAxis(const AttrValue& attr, absl::Span<const int> axis,
507 int rank) const;
508 bool IsDimsSupported(const TransposeContext& context,
509 const utils::MutableNodeView& node) const;
510 Status UpdateSqueezeDims(TransposeContext* context,
511 utils::MutableNodeView* node);
512 };
513
514 class StridedSliceTransposer : public LayoutAgnosticOpTransposer {
515 public:
StridedSliceTransposer()516 explicit StridedSliceTransposer() : LayoutAgnosticOpTransposer() {}
517
518 Status TransposeNode(TransposeContext* context,
519 utils::MutableNodeView* node) override;
520
521 private:
522 bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask);
523 bool HasOnlyBeginEndMask(const utils::MutableNodeView& node);
524 Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node,
525 absl::string_view mask);
526 };
527
528 class SwitchTransposer : public LayoutAgnosticOpTransposer {
529 public:
SwitchTransposer()530 explicit SwitchTransposer() : LayoutAgnosticOpTransposer() {}
531
532 Status TransposeNode(TransposeContext* context,
533 utils::MutableNodeView* node) override;
534 };
535
536 class TernaryOpTransposer : public LayoutAgnosticOpTransposer {
537 public:
TernaryOpTransposer()538 explicit TernaryOpTransposer() : LayoutAgnosticOpTransposer() {}
539
540 Status TransposeNode(TransposeContext* context,
541 utils::MutableNodeView* node) override;
542 };
543
544 class TileTransposer : public LayoutAgnosticOpTransposer {
545 public:
TileTransposer()546 explicit TileTransposer() : LayoutAgnosticOpTransposer() {}
547
548 Status TransposeNode(TransposeContext* context,
549 utils::MutableNodeView* node) override;
550 };
551
552 class UnaryGradTransposer : public LayoutAgnosticOpTransposer {
553 public:
UnaryGradTransposer()554 explicit UnaryGradTransposer() : LayoutAgnosticOpTransposer() {}
555
556 Status TransposeNode(TransposeContext* context,
557 utils::MutableNodeView* node) override;
558 };
559
560 // Utils.
561
562 // Permutes elements according to permutation and replaces the original values.
563 // Permutation and values must have same size.
564 template <typename T>
PermuteSingle(absl::string_view location,absl::Span<const int> permutation,T * values)565 Status PermuteSingle(absl::string_view location,
566 absl::Span<const int> permutation, T* values) {
567 DCHECK(values != nullptr);
568 int permutation_size = permutation.size();
569 if (values->size() != permutation_size) {
570 return Status(tensorflow::error::Code::INVALID_ARGUMENT,
571 absl::StrCat("Size of values ", values->size(),
572 " does not match size of permutation ",
573 permutation_size, " @ ", location));
574 }
575 typedef typename T::value_type V;
576 std::vector<V> elements(values->begin(), values->end());
577 int index = 0;
578 for (V& element : *values) {
579 element = elements[permutation[index++]];
580 }
581 return Status::OK();
582 }
583
584 // Permutes two elements at a time according to permutation and replaces the
585 // original values. Values must be twice the size of permutation.
586 template <typename T>
PermuteDouble(absl::string_view location,absl::Span<const int> permutation,T * values)587 Status PermuteDouble(absl::string_view location,
588 absl::Span<const int> permutation, T* values) {
589 DCHECK(values != nullptr);
590 int permutation_size = permutation.size();
591 if (values->size() != permutation_size * 2) {
592 return Status(tensorflow::error::Code::INVALID_ARGUMENT,
593 absl::StrCat("Size of values ", values->size(),
594 " does not match twice the size of permutation ",
595 permutation_size, " @ ", location));
596 }
597 typedef typename T::value_type V;
598 std::vector<V> elements(values->begin(), values->end());
599 for (int i = 0; i < values->size(); i = i + 2) {
600 const int permutation_index = permutation[i / 2];
601 (*values)[i] = elements[permutation_index * 2];
602 (*values)[i + 1] = elements[permutation_index * 2 + 1];
603 }
604 return Status::OK();
605 }
606
// Returns a device name for `node`. Presumably `virtual_placer` is consulted
// to resolve it; the definition is not part of this header (TODO confirm).
string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node);

// Op classification predicates on a NodeDef. They are only declared here; by
// their names each reports whether `node` belongs to the named op category.
// The exact op sets are given by their definitions elsewhere.
bool IsDefaultLayoutSensitiveOp(const NodeDef& node);

bool IsLayoutSensitiveOp(const NodeDef& node);

bool IsDefaultLayoutAgnosticOp(const NodeDef& node);

bool IsLayoutAgnosticOp(const NodeDef& node);

bool IsTernaryOp(const NodeDef& node);

bool IsUnaryGrad(const NodeDef& node);

bool IsMaxPoolV2(const NodeDef& node);

bool IsMaxPoolGradV2(const NodeDef& node);

bool IsMaxPoolGradGradV1(const NodeDef& node);

bool IsMaxPoolGradGradV2(const NodeDef& node);

bool IsBinaryOp(const NodeDef& node);

bool IsReduceOp(const NodeDef& node);
632
// Presumably returns the indices of the data (non-control) fanin ports of
// `node` -- TODO confirm against the definition.
std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node);

// Presumably returns the indices of the data (non-control) fanout ports of
// `node` -- TODO confirm against the definition.
std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node);

// Returns a value of constant input to the `node` at `index`, iff `predicate`
// evaluated to true. Returns true if `tensor` was populated with data.
bool GetValueAttrFromConstInputNode(
    const utils::MutableNodeView& node,
    const std::function<bool(const NodeDef&)>& predicate, int index,
    Tensor* tensor);

// By its name, reports whether `node` is one of the data format ops; the exact
// op set is given by the definition elsewhere.
bool IsDataFormatOp(const utils::MutableNodeView& node);
645
// Builds a char -> int map from `data_format`. Presumably each dimension
// character maps to its position in the format string -- TODO confirm against
// the definition.
absl::flat_hash_map<char, int> GetDimensionIndices(
    absl::string_view data_format);

// Computes a permutation from `src_dim_indices` and `dst_format`. Presumably
// entry i is the source index of the i-th character of `dst_format` -- TODO
// confirm against the definition.
std::vector<int> GetPermutation(
    const absl::flat_hash_map<char, int>& src_dim_indices,
    absl::string_view dst_format);
652
653 } // namespace grappler
654 } // namespace tensorflow
655
656 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
657