/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

#define DEBUG_TYPE "tf-layout-optimization"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"

// Returns a transpose op from `transpose_ops` that can be reused for the
// given `operand` and `permutation`, or nullptr if no existing op matches.
TransposeOp ReuseExistingTranspose(const OpOperand* operand,
                                   const SmallVector<int64_t, 4>& permutation,
                                   Operation* op, ConstOp permutation_op,
                                   SmallVector<TransposeOp, 2>* transpose_ops) {
  for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) {
    auto transpose_op = *it;
    for (auto transpose_operand : transpose_op.getOperands()) {
      auto ranked_transpose_type =
          transpose_operand.getType().dyn_cast_or_null<RankedTensorType>();
      if (!ranked_transpose_type) continue;
      if (ranked_transpose_type.getRank() == permutation.size() &&
          operand->get().getType() ==
              ShuffleRankedTensorType(ranked_transpose_type, permutation)) {
        TransposeOp transpose = transpose_op;
        transpose.getOperation()->moveBefore(op);
        transpose.setOperand(0, operand->get());
        transpose.setOperand(1, permutation_op);
        transpose_ops->erase(it);
        return transpose;
      }
    }
  }
  return nullptr;
}

// LayoutAssignmentPass assigns an optimal data layout (data format) to all
// layout sensitive operations.
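//
// For example, assigning NCHW to a NHWC convolution wraps its layout
// dependent argument and result in transposes (illustrative, simplified IR;
// attributes and types elided):
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC"}
//
// becomes
//
//   %perm = "tf.Const"() {value = dense<[0, 3, 1, 2]>}
//   %in = "tf.Transpose"(%input, %perm)
//   %0 = "tf.Conv2D"(%in, %filter) {data_format = "NCHW"}
//   %inv = "tf.Const"() {value = dense<[0, 2, 3, 1]>}
//   %1 = "tf.Transpose"(%0, %inv)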
class LayoutAssignmentPass
    : public PassWrapper<LayoutAssignmentPass, FunctionPass> {
 public:
  LayoutAssignmentPass() = default;
  explicit LayoutAssignmentPass(const std::string& force_data_format) {
    force_data_format_ = force_data_format;
  }

  LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}

  void runOnFunction() final;

 private:
  // Force a specified data format for all layout sensitive operations.
  Option<std::string> force_data_format_{
      *this, "force-data-format",
      llvm::cl::desc("Force data format for all layout sensitive ops")};
};
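
// A sketch of invoking this pass from `tf-opt` (assumes the pass registration
// at the bottom of this file; exact pipeline syntax depends on the MLIR
// version):
//
//   tf-opt input.mlir \
//     -pass-pipeline='func(tf-layout-assignment{force-data-format=NCHW})'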

// MoveTransposesPass moves all Transpose ops to the beginning or to the end of
// the basic block where they are defined. This allows the canonicalizer to
// delete redundant transposes.
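//
// For example (illustrative IR), moving the last transpose to the beginning of
// the block places it next to its inverse, so the canonicalizer can remove the
// pair:
//
//   %0 = "tf.Transpose"(%x, %perm)
//   %1 = "tf.Tanh"(%0)
//   %2 = "tf.Transpose"(%1, %inv_perm)
//
// becomes
//
//   %0 = "tf.Transpose"(%x, %perm)
//   %1 = "tf.Transpose"(%0, %inv_perm)   // folds away together with %0
//   %2 = "tf.Tanh"(%1)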
class MoveTransposesPass
    : public PassWrapper<MoveTransposesPass, FunctionPass> {
 public:
  enum class Direction { kBegin, kEnd };

  MoveTransposesPass() = default;
  explicit MoveTransposesPass(Direction direction, bool fold_transpose_in_ops) {
    direction_ = direction;
    fold_transpose_in_ops_ = fold_transpose_in_ops;
  }
  MoveTransposesPass(const MoveTransposesPass& pass) {}

  void runOnFunction() final;

 private:
  Option<bool> fold_transpose_in_ops_{
      *this, "fold-transpose-in-ops",
      llvm::cl::desc(
          "Whether to fold transposes in ops which can support folding."),
      llvm::cl::init(true)};

  Option<Direction> direction_{
      *this, "direction",
      llvm::cl::desc("Move transposes to the beginning or the end of the block "
                     "where they are defined."),
      llvm::cl::values(
          clEnumValN(Direction::kBegin, "begin", "beginning of the block"),
          clEnumValN(Direction::kEnd, "end", "end of the block"))};
};
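
// A sketch of invoking this pass from `tf-opt` (assumes the pass registration
// at the bottom of this file; exact pipeline syntax depends on the MLIR
// version):
//
//   tf-opt input.mlir \
//     -pass-pipeline='func(tf-move-transposes{direction=end fold-transpose-in-ops=true})'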

using Permutation = SmallVector<int64_t, 4>;

void LayoutAssignmentPass::runOnFunction() {
  FuncOp func = getFunction();

  // Get runtime devices information from the closest parent module.
  RuntimeDevices devices;
  if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
                                            &devices)))
    return signalPassFailure();

  // If there is no runtime device information and the data format is not
  // explicitly forced, there is nothing to do.
  if (devices.NumDevices() == 0 && force_data_format_.empty()) return;

  func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
    // Get desired op data format.
    StringRef target_data_format = force_data_format_;
    if (target_data_format.empty()) {
      target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
    }

    // Skip ops that already use the target data format.
    auto data_format = layout_sensitive_interface.data_format();
    if (data_format == target_data_format) return;

    // Transpose arguments into the target data format.
    Permutation args_permutation =
        GetDataFormatPermutation(data_format, target_data_format);

    // Transpose results back to the original data format.
    Permutation res_permutation =
        GetDataFormatPermutation(target_data_format, data_format);
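    // For example, for a NHWC to NCHW conversion the argument permutation is
    // [0, 3, 1, 2], and the result permutation back to NHWC is [0, 2, 3, 1].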

    if (args_permutation.empty() || res_permutation.empty()) return;

    mlir::Operation* op = layout_sensitive_interface.getOperation();
    Location loc = op->getLoc();
    OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());

    auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
      auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64));
      return DenseIntElementsAttr::get(perm_ty, permutation);
    };

    // Change operation data format.
    if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
      return;

    // Permute arguments into the target data format.
    builder.setInsertionPoint(op);
    auto arg_perm = builder.create<ConstOp>(loc, perm_attr(args_permutation));

    for (int64_t arg : layout_sensitive_interface.GetLayoutDependentArgs()) {
      op->setOperand(
          arg, builder.create<TransposeOp>(loc, op->getOperand(arg), arg_perm));
    }

    // Permute results back to the original data format.
    builder.setInsertionPointAfter(op);
    auto res_perm = builder.create<ConstOp>(loc, perm_attr(res_permutation));

    for (int64_t res : layout_sensitive_interface.GetLayoutDependentResults()) {
      OpResult result = op->getResult(res);

      auto transposed_res = builder.create<TransposeOp>(loc, result, res_perm);
      result.replaceAllUsesWith(transposed_res);
      transposed_res.setOperand(0, result);
    }
  });
}

// Move Transpose operations that permute `op` results before the `op`.
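//
// For example (illustrative IR), with a layout agnostic `op`:
//
//   %0 = "tf.Tanh"(%x)
//   %1 = "tf.Transpose"(%0, %perm)
//
// is rewritten to:
//
//   %0 = "tf.Transpose"(%x, %perm)
//   %1 = "tf.Tanh"(%0)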
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
  // TODO(ezhulenev): Move transpose across layout sensitive operations.
  if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;

  // Transpose operations that use operation results.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for result transposes.
  ConstOp permutation_op;

  // All operation results must be used by transpose operations with the same
  // permutation indices.
  for (OpResult result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // Result user must be a transpose operation.
      TransposeOp transpose = dyn_cast<TransposeOp>(user);
      if (!transpose) return;

      // With permutation defined by a constant operation.
      ConstOp perm =
          dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
      if (!perm) return;

      // With the same permutation indices.
      auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
      if (!dense_elem_attr) return;

      if (!permutation_op) permutation_op = perm;

      // Check that the permutation matches for all result transposes.
      if (perm.value() != permutation_op.value()) return;

      // Add a transpose operation for later reuse.
      transpose_ops.push_back(transpose);
    }
  }

  // Nothing to do here.
  if (!permutation_op || transpose_ops.empty()) return;
  SmallVector<int64_t, 4> permutation;
  auto perm_attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : perm_attr.getIntValues())
    permutation.push_back(value.getSExtValue());

  // We want to make sure the shape of the operand equals the transposed shape.
  // A mismatch can happen if `op` supports broadcasting and the operands have
  // different ranks.
  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>()) {
    auto transpose_op = *transpose_ops.begin();
    auto result_type =
        transpose_op.getResult().getType().dyn_cast_or_null<ShapedType>();
    auto is_valid_move =
        llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool {
          auto operand_type = operand.getType().dyn_cast_or_null<ShapedType>();
          return result_type && operand_type && result_type.hasRank() &&
                 operand_type.hasRank() &&
                 result_type.getRank() == operand_type.getRank();
        });
    if (!is_valid_move) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // before `op` and bypass all result transposes.
  Location loc = op->getLoc();

  // Move the constant op defining the result permutation to the beginning of
  // the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for all results.
  for (OpResult result : op->getResults()) {
    result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
    for (Operation* transpose : result.getUsers()) {
      transpose->getResult(0).replaceAllUsesWith(result);
    }
  }

  // Maybe add a Transpose node for all operands (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (OpOperand& operand : op->getOpOperands()) {
    // Try to push the transpose further up.
    if (Operation* operand_op = operand.get().getDefiningOp())
      work_list->push_back(operand_op);

    // Try to reuse result transposes.
    TransposeOp transpose = ReuseExistingTranspose(
        &operand, permutation, op, permutation_op, &transpose_ops);
    // If no transpose is available for reuse, create a new one.
    if (!transpose)
      transpose =
          builder.create<TransposeOp>(loc, operand.get(), permutation_op);

    operand.set(transpose);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

// Move Transpose operations that permute `op` operands after the `op`.
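//
// For example (illustrative IR), with a layout agnostic `op`:
//
//   %0 = "tf.Transpose"(%x, %perm)
//   %1 = "tf.Tanh"(%0)
//
// is rewritten to:
//
//   %0 = "tf.Tanh"(%x)
//   %1 = "tf.Transpose"(%0, %perm)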
void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
                        bool fold_transpose_in_ops) {
  // Indices of operands and results that depend on data layout.
  SmallVector<unsigned, 4> layout_dependent_operands;
  SmallVector<unsigned, 4> layout_dependent_results;

  auto fold_operands = dyn_cast<FoldOperandsTransposeInterface>(op);
  bool layout_agnostic = op->hasTrait<OpTrait::TF::LayoutAgnostic>();

  if (fold_operands && fold_transpose_in_ops) {
    layout_dependent_operands = fold_operands.GetLayoutDependentArgs();
    layout_dependent_results = fold_operands.GetLayoutDependentResults();

  } else if (layout_agnostic) {
    // For layout agnostic operations (e.g. element wise operations) all
    // operands and results must have the same data layout.
    for (unsigned i = 0; i < op->getNumOperands(); ++i)
      layout_dependent_operands.push_back(i);
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      layout_dependent_results.push_back(i);
  }

  // Transpose operations that are operands of the `op`.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for operand transposes.
  ConstOp permutation_op;

  // Layout dependent operands must be transpose operations with the same
  // permutation indices.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);

    // Operand must be defined by a transpose op.
    TransposeOp transpose =
        dyn_cast_or_null<TransposeOp>(operand.get().getDefiningOp());
    if (!transpose) return;

    // With permutation defined by a constant operation.
    ConstOp perm =
        dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
    if (!perm) return;

    // With the same permutation indices.
    auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
    if (!dense_elem_attr) return;

    if (!permutation_op) permutation_op = perm;

    // Check that the permutation matches for all operand transposes.
    if (perm.value() != permutation_op.value()) return;

    // Add a transpose operation for later reuse only if it's used once.
    if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose);
  }

  // Nothing to do here.
  if (!permutation_op) return;

  // All results after transpose must preserve the original result type.
  SmallVector<Type, 4> original_type(op->getNumResults());
  for (unsigned idx : layout_dependent_results)
    original_type[idx] = op->getResult(idx).getType();

  // Check if we can fold the transpose into the operation.
  if (fold_operands && fold_transpose_in_ops) {
    SmallVector<int64_t, 8> permutation;

    auto attr = permutation_op.value().cast<DenseElementsAttr>();
    for (const auto& value : attr.getIntValues())
      permutation.push_back(value.getSExtValue());

    if (failed(fold_operands.FoldOperandsPermutation(permutation))) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // after `op`, bypass all operand transposes, and transpose op results.
  Location loc = op->getLoc();

  // Move the constant op defining the permutation to the beginning of the
  // block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for layout dependent operands.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);
    TransposeOp transpose =
        dyn_cast<TransposeOp>(operand.get().getDefiningOp());
    operand.set(transpose.getOperand(0));
  }

  // Maybe add Transpose nodes for layout dependent results
  // (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (unsigned idx : layout_dependent_results) {
    OpResult result = op->getResult(idx);

    // Forward the operand type only for layout agnostic operations; operations
    // with custom folding will update the result type in
    // `FoldOperandsPermutation`.
    if (layout_agnostic) result.setType(op->getOperand(0).getType());

    // Try to push the transpose further down.
    for (Operation* user : result.getUsers()) work_list->push_back(user);

    // Try to reuse operand transposes.
    TransposeOp transpose;
    if (!transpose_ops.empty()) {
      transpose = transpose_ops.pop_back_val();
      transpose.getOperation()->moveBefore(op->getNextNode());
      transpose.setOperand(0, result);
      transpose.setOperand(1, permutation_op);
      transpose.getResult().setType(original_type[idx]);
    } else {
      transpose = builder.create<TransposeOp>(loc, result, permutation_op);
    }

    // Forward all users to the transpose operation.
    result.replaceAllUsesWith(transpose);
    transpose.setOperand(0, result);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

void MoveTransposesPass::runOnFunction() {
  FuncOp func = getFunction();

  SmallVector<Operation*, 8> work_list;

  func.walk([&](TransposeOp transpose) {
    if (direction_ == Direction::kBegin) {
      // Try to push transpose before the operand operation.
      for (auto operand : transpose.getOperands()) {
        if (auto op = operand.getDefiningOp()) work_list.push_back(op);
      }
    } else {
      // Try to push transpose after the user operation.
      for (Operation* user : transpose.y().getUsers()) {
        work_list.push_back(user);
      }
    }
  });

  while (!work_list.empty()) {
    Operation* op = work_list.pop_back_val();
    if (direction_ == Direction::kBegin) {
      MoveTransposeBefore(op, &work_list);
    } else if (direction_ == Direction::kEnd) {
      MoveTransposeAfter(op, &work_list, fold_transpose_in_ops_);
    }
  }

  func.walk([&](TransposeOp transpose) {
    OpBuilder builder(transpose);
    SmallVector<Value, 1> fold_result;
    if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
      assert(fold_result.size() == 1);
      transpose.replaceAllUsesWith(fold_result[0]);
    }
  });
}

}  // namespace

void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options) {
  using Direction = MoveTransposesPass::Direction;

  // Assign optimal layout for layout sensitive ops.
  pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));

  // Move transposes to the beginning of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kBegin, !options.skip_fold_transpose_in_ops));

  // Move transposes to the end of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kEnd, !options.skip_fold_transpose_in_ops));
}
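
// A sketch of running the full pipeline with `tf-opt` (assumes the pipeline
// registration below; the option flags come from
// LayoutOptimizationPipelineOptions, defined in the corresponding header):
//
//   tf-opt -tf-layout-optimization input.mlir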

static PassRegistration<LayoutAssignmentPass> layout_assignment(
    "tf-layout-assignment", "Layout assignment pass");
static PassRegistration<MoveTransposesPass> move_transposes(
    "tf-move-transposes", "Move transposes pass");

static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
    pipeline("tf-layout-optimization",
             "Assigns optimal data layout to all layout sensitive operations "
             "and cancels redundant transpose operations.",
             CreateLayoutOptimizationPipeline);

}  // namespace TF
}  // namespace mlir