/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

#define DEBUG_TYPE "tf-layout-optimization"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"

// Helper method that returns an op from 'transpose_ops' that matches the
// criteria for an 'operand' and 'permutation'.
TransposeOp ReuseExistingTranspose(const OpOperand* operand,
                                   const SmallVector<int64_t, 4>& permutation,
                                   Operation* op, ConstOp permutation_op,
                                   SmallVector<TransposeOp, 2>* transpose_ops) {
  for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) {
    auto transpose_op = *it;
    for (auto transpose_operand : transpose_op.getOperands()) {
      auto ranked_transpose_type =
          transpose_operand.getType().dyn_cast_or_null<RankedTensorType>();
      if (!ranked_transpose_type) continue;
      if (ranked_transpose_type.getRank() == permutation.size() &&
          operand->get().getType() ==
              ShuffleRankedTensorType(ranked_transpose_type, permutation)) {
        TransposeOp transpose = transpose_op;
        transpose.getOperation()->moveBefore(op);
        transpose.setOperand(0, operand->get());
        transpose.setOperand(1, permutation_op);
        transpose_ops->erase(it);
        return transpose;
      }
    }
  }
  return nullptr;
}

// LayoutAssignmentPass assigns optimal data layout (data format) for all
// layout sensitive operations.
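//
// For example (an illustrative sketch, assuming NCHW is chosen as the optimal
// layout for a tf.Conv2D currently running in NHWC), the pass rewrites
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC"}
//
// into
//
//   %perm = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
//   %in = "tf.Transpose"(%input, %perm)
//   %1 = "tf.Conv2D"(%in, %filter) {data_format = "NCHW"}
//   %perm_back = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
//   %0 = "tf.Transpose"(%1, %perm_back)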
class LayoutAssignmentPass
    : public PassWrapper<LayoutAssignmentPass, FunctionPass> {
 public:
  LayoutAssignmentPass() = default;
  explicit LayoutAssignmentPass(const std::string& force_data_format) {
    force_data_format_ = force_data_format;
  }

  LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}

  StringRef getArgument() const final { return "tf-layout-assignment"; }

  StringRef getDescription() const final { return "Layout assignment pass"; }

  void runOnFunction() final;

 private:
  // Force a specified data format for all layout sensitive operations.
  Option<std::string> force_data_format_{
      *this, "force-data-format",
      llvm::cl::desc("Force data format for all layout sensitive ops")};
};

// MoveTransposesPass moves all Transpose ops to the beginning or to the end of
// the basic block where they are defined. This allows the canonicalizer to
// delete redundant transposes.
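//
// For example (illustrative), moving a transpose past a layout agnostic op
// such as tf.Relu exposes a pair of inverse transposes:
//
//   %1 = "tf.Transpose"(%0, %perm)
//   %2 = "tf.Relu"(%1)
//   %3 = "tf.Transpose"(%2, %perm_inv)
//
// After the move, %3 becomes a transpose of a transpose with inverse
// permutations, which the canonicalizer folds away.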
class MoveTransposesPass
    : public PassWrapper<MoveTransposesPass, FunctionPass> {
 public:
  enum class Direction { kBegin, kEnd };

  MoveTransposesPass() = default;
  explicit MoveTransposesPass(Direction direction,
                              bool fold_transpose_in_ops) {
    direction_ = direction;
    fold_transpose_in_ops_ = fold_transpose_in_ops;
  }
  MoveTransposesPass(const MoveTransposesPass& pass) {}

  StringRef getArgument() const final { return "tf-move-transposes"; }
  StringRef getDescription() const final { return "Move transposes pass"; }

  void runOnFunction() final;

 private:
  Option<bool> fold_transpose_in_ops_{
      *this, "fold-transpose-in-ops",
      llvm::cl::desc(
          "Whether to fold transposes in ops which can support folding."),
      llvm::cl::init(true)};

  Option<Direction> direction_{
      *this, "direction",
      llvm::cl::desc("Move transposes to the beginning or the end of the block "
                     "where they are defined."),
      llvm::cl::values(
          clEnumValN(Direction::kBegin, "begin", "beginning of the block"),
          clEnumValN(Direction::kEnd, "end", "end of the block"))};
};

using Permutation = SmallVector<int64_t, 4>;

void LayoutAssignmentPass::runOnFunction() {
  FuncOp func = getFunction();

  // Get runtime devices information from the closest parent module.
  RuntimeDevices devices;
  if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
                                            &devices)))
    return signalPassFailure();

  // If there is no runtime device information and the data format is not
  // explicitly forced, there is nothing to do.
  if (devices.NumDevices() == 0 && force_data_format_.empty()) return;

  func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
    // Get desired op data format.
    StringRef target_data_format = force_data_format_;
    if (target_data_format.empty()) {
      target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
    }

    // Skip ops that already use target data format.
    auto data_format = layout_sensitive_interface.data_format();
    if (data_format == target_data_format) return;

    // Transpose arguments into the target data format.
    Permutation args_permutation =
        GetDataFormatPermutation(data_format, target_data_format);
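    // For example (illustrative), GetDataFormatPermutation("NHWC", "NCHW")
    // yields [0, 3, 1, 2], and the reverse direction yields [0, 2, 3, 1].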

    // Transpose results back to the original data format.
    Permutation res_permutation =
        GetDataFormatPermutation(target_data_format, data_format);

    if (args_permutation.empty() || res_permutation.empty()) return;

    mlir::Operation* op = layout_sensitive_interface.getOperation();
    Location loc = op->getLoc();
    OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());

    auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
      auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64));
      return DenseIntElementsAttr::get(perm_ty, permutation);
    };

    // Change operation data format.
    if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
      return;

    // Permute arguments into the target data format.
    builder.setInsertionPoint(op);
    auto arg_perm = builder.create<ConstOp>(loc, perm_attr(args_permutation));

    for (int64_t arg : layout_sensitive_interface.GetLayoutDependentArgs()) {
      op->setOperand(
          arg, builder.create<TransposeOp>(loc, op->getOperand(arg), arg_perm));
    }

    // Permute results back to the original data format.
    builder.setInsertionPointAfter(op);
    auto res_perm = builder.create<ConstOp>(loc, perm_attr(res_permutation));

    for (int64_t res : layout_sensitive_interface.GetLayoutDependentResults()) {
      OpResult result = op->getResult(res);

      auto transposed_res = builder.create<TransposeOp>(loc, result, res_perm);
      result.replaceAllUsesWith(transposed_res);
      transposed_res.setOperand(0, result);
    }
  });
}

// Move Transpose operations that permute `op` results before the `op`.
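//
// For example (illustrative, with a layout agnostic op):
//
//   %0 = "tf.Relu"(%arg)
//   %1 = "tf.Transpose"(%0, %perm)
//
// becomes
//
//   %0 = "tf.Transpose"(%arg, %perm)
//   %1 = "tf.Relu"(%0)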
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
  // TODO(ezhulenev): Move transpose across layout sensitive operations.
  if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;

  // Transpose operations that use operation results.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for result transposes.
  ConstOp permutation_op;

  // All operation results must be used by transpose operations with the same
  // permutation indices.
  for (OpResult result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // Result user must be a transpose operation.
      TransposeOp transpose = dyn_cast<TransposeOp>(user);
      if (!transpose) return;

      // With permutation defined by constant operation.
      ConstOp perm =
          dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
      if (!perm) return;

      // With the same permutation indices.
      auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
      if (!dense_elem_attr) return;

      if (!permutation_op) permutation_op = perm;

      // Check that permutation matches for all result transposes.
      if (perm.value() != permutation_op.value()) return;

      // Add a transpose operation for later reuse.
      transpose_ops.push_back(transpose);
    }
  }

  // Nothing to do here.
  if (!permutation_op || transpose_ops.empty()) return;
  SmallVector<int64_t, 4> permutation;
  auto perm_attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : perm_attr.getIntValues())
    permutation.push_back(value.getSExtValue());

  // We want to make sure the shape of the operand equals the transposed shape.
  // A mismatch can happen if `op` supports broadcasting and the operands have
  // different ranks.
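  // For example (an illustrative case), a broadcasting add of a rank-4 tensor
  // and a rank-1 tensor would fail the rank check below, so the transpose is
  // left in place.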
  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>()) {
    auto transpose_op = *transpose_ops.begin();
    auto result_type =
        transpose_op.getResult().getType().dyn_cast_or_null<ShapedType>();
    auto is_valid_move =
        llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool {
          auto operand_type = operand.getType().dyn_cast_or_null<ShapedType>();
          return result_type && operand_type && result_type.hasRank() &&
                 operand_type.hasRank() &&
                 result_type.getRank() == operand_type.getRank();
        });
    if (!is_valid_move) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // before `op` and bypass all result transposes.
  Location loc = op->getLoc();

  // Move constant op defining result permutation to the beginning of the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for all results.
  for (OpResult result : op->getResults()) {
    result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
    for (Operation* transpose : result.getUsers()) {
      transpose->getResult(0).replaceAllUsesWith(result);
    }
  }

  // Maybe add a Transpose node for all operands (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (OpOperand& operand : op->getOpOperands()) {
    // Try to push transpose further up.
    if (Operation* operand_op = operand.get().getDefiningOp())
      work_list->push_back(operand_op);

    // Try to reuse result transposes.
    TransposeOp transpose = ReuseExistingTranspose(
        &operand, permutation, op, permutation_op, &transpose_ops);
    // If no transpose is available for reuse, create a new one.
    if (!transpose)
      transpose =
          builder.create<TransposeOp>(loc, operand.get(), permutation_op);

    operand.set(transpose);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

// Reverts the permutation previously applied to the shape of `type`.
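//
// For example (illustrative), with permutation [0, 3, 1, 2] a permuted type
// tensor<1x64x28x28xf32> is restored to tensor<1x28x28x64xf32>.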
static mlir::ShapedType ReversePermuteShapedType(
    mlir::ShapedType type, ArrayRef<int64_t> permutation) {
  if (!type.hasRank()) return type;

  auto shape = type.getShape();
  SmallVector<int64_t, 4> new_shape(shape.size());

  for (int i = 0; i < permutation.size(); ++i) {
    int64_t index = permutation[i];
    assert(index < shape.size());
    new_shape[index] = shape[i];
  }

  return type.clone(new_shape);
}

// Move Transpose operations that permute `op` operands after the `op`.
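//
// For example (illustrative, with a layout agnostic op):
//
//   %0 = "tf.Transpose"(%arg, %perm)
//   %1 = "tf.Relu"(%0)
//
// becomes
//
//   %0 = "tf.Relu"(%arg)
//   %1 = "tf.Transpose"(%0, %perm)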
void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
                        bool fold_transpose_in_ops) {
  // Indices of operands and results that depend on data layout.
  SmallVector<unsigned, 4> layout_dependent_operands;
  SmallVector<unsigned, 4> layout_dependent_results;

  auto fold_operands = dyn_cast<FoldOperandsTransposeInterface>(op);
  bool layout_agnostic = op->hasTrait<OpTrait::TF::LayoutAgnostic>();

  if (fold_operands && fold_transpose_in_ops) {
    layout_dependent_operands = fold_operands.GetLayoutDependentArgs();
    layout_dependent_results = fold_operands.GetLayoutDependentResults();

  } else if (layout_agnostic) {
    // For layout agnostic operations (e.g. element wise operations) all
    // operands and results must have the same data layout.
    for (unsigned i = 0; i < op->getNumOperands(); ++i)
      layout_dependent_operands.push_back(i);
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      layout_dependent_results.push_back(i);
  }

  // Transpose operations that are operands of the `op`.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for operand transposes.
  ConstOp permutation_op;

  // Layout dependent operands must be transpose operations with the same
  // permutation indices.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);

    // Operand must be defined by a transpose op.
    TransposeOp transpose =
        dyn_cast_or_null<TransposeOp>(operand.get().getDefiningOp());
    if (!transpose) return;

    // With permutation defined by constant operation.
    ConstOp perm =
        dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
    if (!perm) return;

    // With the same permutation indices.
    auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
    if (!dense_elem_attr) return;

    if (!permutation_op) permutation_op = perm;

    // Check that the permutation matches for all operand transposes.
    if (perm.value() != permutation_op.value()) return;

    // Add a transpose operation for later reuse only if it's used once.
    if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose);
  }

  // Nothing to do here.
  if (!permutation_op) return;

  // All results after transpose must preserve the original result type.
  SmallVector<Type, 4> original_type(op->getNumResults());
  for (unsigned idx : layout_dependent_results)
    original_type[idx] = op->getResult(idx).getType();

  SmallVector<int64_t, 8> permutation;

  auto attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : attr.getIntValues())
    permutation.push_back(value.getSExtValue());

  // Check if we can fold transpose into the operation.
  if (fold_operands && fold_transpose_in_ops) {
    if (failed(fold_operands.FoldOperandsPermutation(permutation))) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // after `op`, bypass all operand transposes, and transpose the op results.
  Location loc = op->getLoc();

  // Move constant op defining result permutation to the beginning of the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for layout dependent operands.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);
    TransposeOp transpose =
        dyn_cast<TransposeOp>(operand.get().getDefiningOp());
    operand.set(transpose.getOperand(0));
  }

  // Maybe add Transpose nodes for layout dependent results
  // (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPointAfter(op);

  for (unsigned idx : layout_dependent_results) {
    OpResult result = op->getResult(idx);

    // If the op is layout agnostic, the new result type can be generated by
    // reverting `permutation`. Otherwise, operations with custom folding will
    // update the result type in `FoldOperandsPermutation`.
    if (layout_agnostic)
      result.setType(ReversePermuteShapedType(
          result.getType().cast<ShapedType>(), permutation));

    // Try to push transpose further down.
    for (Operation* user : result.getUsers()) {
      if (!llvm::isa<TransposeOp>(user)) work_list->push_back(user);
    }

    // Try to reuse operand transposes.
    TransposeOp transpose;
    if (!transpose_ops.empty()) {
      transpose = transpose_ops.pop_back_val();
      transpose.getOperation()->moveBefore(op->getNextNode());
      transpose.setOperand(0, result);
      transpose.setOperand(1, permutation_op);
      transpose.getResult().setType(original_type[idx]);
    } else {
      transpose = builder.create<TransposeOp>(loc, result, permutation_op);
    }

    // Forward all users to the transpose operation.
    result.replaceAllUsesWith(transpose);
    transpose.setOperand(0, result);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

void MoveTransposesPass::runOnFunction() {
  FuncOp func = getFunction();

  SmallVector<Operation*, 8> work_list;

  func.walk([&](TransposeOp transpose) {
    if (direction_ == Direction::kBegin) {
      // Try to push transpose before the operand operation.
      for (auto operand : transpose.getOperands()) {
        if (auto op = operand.getDefiningOp()) work_list.push_back(op);
      }
    } else {
      // Try to push transpose after the user operation.
      for (Operation* user : transpose.y().getUsers()) {
        if (!llvm::isa<TransposeOp>(user)) work_list.push_back(user);
      }
    }
  });

  while (!work_list.empty()) {
    Operation* op = work_list.pop_back_val();
    if (direction_ == Direction::kBegin) {
      MoveTransposeBefore(op, &work_list);
    } else if (direction_ == Direction::kEnd) {
      MoveTransposeAfter(op, &work_list, fold_transpose_in_ops_);
    }
  }

  func.walk([&](TransposeOp transpose) {
    OpBuilder builder(transpose);
    SmallVector<Value, 1> fold_result;
    if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
      assert(fold_result.size() == 1);
      transpose.replaceAllUsesWith(fold_result[0]);
    }
  });
}

}  // namespace

void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options) {
  using Direction = MoveTransposesPass::Direction;

  // Assign optimal layout for layout sensitive ops.
  pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));

  // Move transposes to the beginning of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kBegin, !options.skip_fold_transpose_in_ops));

  // Move transposes to the end of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kEnd, !options.skip_fold_transpose_in_ops));
}

static PassRegistration<LayoutAssignmentPass> layout_assignment;
static PassRegistration<MoveTransposesPass> move_transposes;

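// Example invocation of the pipeline registered below (illustrative; tf-opt
// is the TensorFlow MLIR optimizer driver):
//
//   tf-opt -tf-layout-optimization input.mlir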
static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
    pipeline("tf-layout-optimization",
             "Assigns optimal data layout to all layout sensitive operations "
             "and cancels redundant transpose operations.",
             CreateLayoutOptimizationPipeline);

}  // namespace TF
}  // namespace mlir