/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

#define DEBUG_TYPE "tf-layout-optimization"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
// Helper method that returns an op from 'transpose_ops' that matches the
// criteria for the given 'operand' and 'permutation'.
TransposeOp ReuseExistingTranspose(const OpOperand* operand,
                                   const SmallVector<int64_t, 4>& permutation,
                                   Operation* op, ConstOp permutation_op,
                                   SmallVector<TransposeOp, 2>* transpose_ops) {
  for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) {
    auto transpose_op = *it;
    for (auto transpose_operand : transpose_op.getOperands()) {
      auto ranked_transpose_type =
          transpose_operand.getType().dyn_cast_or_null<RankedTensorType>();
      if (!ranked_transpose_type) continue;
      if (ranked_transpose_type.getRank() == permutation.size() &&
          operand->get().getType() ==
              ShuffleRankedTensorType(ranked_transpose_type, permutation)) {
        TransposeOp transpose = transpose_op;
        transpose.getOperation()->moveBefore(op);
        transpose.setOperand(0, operand->get());
        transpose.setOperand(1, permutation_op);
        transpose_ops->erase(it);
        return transpose;
      }
    }
  }
  return nullptr;
}

// LayoutAssignmentPass assigns optimal data layout (data format) for all
// layout sensitive operations.
class LayoutAssignmentPass
    : public PassWrapper<LayoutAssignmentPass, FunctionPass> {
 public:
  LayoutAssignmentPass() = default;
  explicit LayoutAssignmentPass(const std::string& force_data_format) {
    force_data_format_ = force_data_format;
  }

  LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}

  StringRef getArgument() const final { return "tf-layout-assignment"; }

  StringRef getDescription() const final { return "Layout assignment pass"; }

  void runOnFunction() final;

 private:
  // Force a specified data format for all layout sensitive operations.
  Option<std::string> force_data_format_{
      *this, "force-data-format",
      llvm::cl::desc("Force data format for all layout sensitive ops")};
};
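
// A minimal usage sketch (an illustration, not part of the original file):
// the pass is registered under the argument returned by getArgument() above,
// so with a tf-opt-style tool it can be run roughly as
//
//   tf-opt --pass-pipeline="func(tf-layout-assignment{force-data-format=NCHW})" input.mlir
//
// The exact pipeline nesting syntax depends on the MLIR version in use.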

// MoveTransposesPass moves all Transpose ops to the beginning or to the end
// of the basic block where they are defined. This allows the canonicalizer to
// delete redundant transposes.
class MoveTransposesPass
    : public PassWrapper<MoveTransposesPass, FunctionPass> {
 public:
  enum class Direction { kBegin, kEnd };

  MoveTransposesPass() = default;
  explicit MoveTransposesPass(Direction direction, bool fold_transpose_in_ops) {
    direction_ = direction;
    fold_transpose_in_ops_ = fold_transpose_in_ops;
  }
  MoveTransposesPass(const MoveTransposesPass& pass) {}

  StringRef getArgument() const final { return "tf-move-transposes"; }
  StringRef getDescription() const final { return "Move transposes pass"; }

  void runOnFunction() final;

 private:
  Option<bool> fold_transpose_in_ops_{
      *this, "fold-transpose-in-ops",
      llvm::cl::desc(
          "Whether to fold transposes in ops which can support folding."),
      llvm::cl::init(true)};

  Option<Direction> direction_{
      *this, "direction",
      llvm::cl::desc("Move transposes to the beginning or the end of the block "
                     "where they are defined."),
      llvm::cl::values(
          clEnumValN(Direction::kBegin, "begin", "beginning of the block"),
          clEnumValN(Direction::kEnd, "end", "end of the block"))};
};
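
// A minimal usage sketch (illustrative, same caveats as above):
//
//   tf-opt --pass-pipeline="func(tf-move-transposes{direction=end fold-transpose-in-ops=true})" input.mlir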

using Permutation = SmallVector<int64_t, 4>;
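// For example, converting NHWC to NCHW uses the permutation {0, 3, 1, 2}:
// result dimension i is taken from source dimension permutation[i], so
// [N, H, W, C] becomes [N, C, H, W].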

void LayoutAssignmentPass::runOnFunction() {
  FuncOp func = getFunction();

  // Get runtime devices information from the closest parent module.
  RuntimeDevices devices;
  if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
                                            &devices)))
    return signalPassFailure();

  // If there is no runtime device information and data format is not
  // explicitly forced, there is nothing to do.
  if (devices.NumDevices() == 0 && force_data_format_.empty()) return;

  func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
    // Get desired op data format.
    StringRef target_data_format = force_data_format_;
    if (target_data_format.empty()) {
      target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
    }

    // Skip ops that already use target data format.
    auto data_format = layout_sensitive_interface.data_format();
    if (data_format == target_data_format) return;

    // Transpose arguments into the target data format.
    Permutation args_permutation =
        GetDataFormatPermutation(data_format, target_data_format);

    // Transpose results back to the original data format.
    Permutation res_permutation =
        GetDataFormatPermutation(target_data_format, data_format);

    if (args_permutation.empty() || res_permutation.empty()) return;

    mlir::Operation* op = layout_sensitive_interface.getOperation();
    Location loc = op->getLoc();
    OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());

    auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
      auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64));
      return DenseIntElementsAttr::get(perm_ty, permutation);
    };

    // Change operation data format.
    if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
      return;

    // Permute arguments into the target data format.
    builder.setInsertionPoint(op);
    auto arg_perm = builder.create<ConstOp>(loc, perm_attr(args_permutation));

    for (int64_t arg : layout_sensitive_interface.GetLayoutDependentArgs()) {
      op->setOperand(
          arg, builder.create<TransposeOp>(loc, op->getOperand(arg), arg_perm));
    }

    // Permute results back to the original data format.
    builder.setInsertionPointAfter(op);
    auto res_perm = builder.create<ConstOp>(loc, perm_attr(res_permutation));

    for (int64_t res : layout_sensitive_interface.GetLayoutDependentResults()) {
      OpResult result = op->getResult(res);

      auto transposed_res = builder.create<TransposeOp>(loc, result, res_perm);
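      // Note: replaceAllUsesWith below also replaces the use of `result`
      // inside `transposed_res` itself, so its operand is restored right
      // after.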
      result.replaceAllUsesWith(transposed_res);
      transposed_res.setOperand(0, result);
    }
  });
}
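
// Illustrative before/after sketch (not from the original file; exact types
// and attributes depend on the input IR). With force-data-format=NCHW, an
// NHWC convolution such as
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", ...}
//
// is rewritten to operate in NCHW, with transposes materialized around it:
//
//   %perm  = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
//   %in    = "tf.Transpose"(%input, %perm)
//   %1     = "tf.Conv2D"(%in, %filter) {data_format = "NCHW", ...}
//   %rperm = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
//   %2     = "tf.Transpose"(%1, %rperm)  // forwards all former uses of %0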

// Move Transpose operations that permute `op` results before the `op`.
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
  // TODO(ezhulenev): Move transpose across layout sensitive operations.
  if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;

  // Transpose operations that use operation results.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for result transposes.
  ConstOp permutation_op;

  // All operation results must be used by transpose operations with the same
  // permutation indices.
  for (OpResult result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // Result user must be a transpose operation.
      TransposeOp transpose = dyn_cast<TransposeOp>(user);
      if (!transpose) return;

      // With permutation defined by constant operation.
      ConstOp perm =
          dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
      if (!perm) return;

      // With the same permutation indices.
      auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
      if (!dense_elem_attr) return;

      if (!permutation_op) permutation_op = perm;

      // Check that permutation matches for all result transposes.
      if (perm.value() != permutation_op.value()) return;

      // Add a transpose operation for later reuse.
      transpose_ops.push_back(transpose);
    }
  }

  // Nothing to do here.
  if (!permutation_op || transpose_ops.empty()) return;
  SmallVector<int64_t, 4> permutation;
  auto perm_attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : perm_attr.getIntValues())
    permutation.push_back(value.getSExtValue());

  // We want to make sure the shape of the operand equals the transposed shape.
  // A mismatch can happen if `op` supports broadcasting and the operands have
  // different ranks.
  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>()) {
    auto transpose_op = *transpose_ops.begin();
    auto result_type =
        transpose_op.getResult().getType().dyn_cast_or_null<ShapedType>();
    auto is_valid_move =
        llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool {
          auto operand_type = operand.getType().dyn_cast_or_null<ShapedType>();
          return result_type && operand_type && result_type.hasRank() &&
                 operand_type.hasRank() &&
                 result_type.getRank() == operand_type.getRank();
        });
    if (!is_valid_move) return;
  }

  // At this point we checked that we can safely move Transpose node before
  // `op`, and bypass all result transposes.
  Location loc = op->getLoc();

  // Move constant op defining result permutation to the beginning of the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for all results.
  for (OpResult result : op->getResults()) {
    result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
    for (Operation* transpose : result.getUsers()) {
      transpose->getResult(0).replaceAllUsesWith(result);
    }
  }

  // Maybe add a Transpose node for all operands (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (OpOperand& operand : op->getOpOperands()) {
    // Try to push transpose further up.
    if (Operation* operand_op = operand.get().getDefiningOp())
      work_list->push_back(operand_op);

    // Try to reuse result transposes.
    TransposeOp transpose = ReuseExistingTranspose(
        &operand, permutation, op, permutation_op, &transpose_ops);
    // If no transpose is available for reuse, create a new one.
    if (!transpose)
      transpose =
          builder.create<TransposeOp>(loc, operand.get(), permutation_op);

    operand.set(transpose);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}
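
// Illustrative rewrite performed by MoveTransposeBefore (a sketch): for a
// layout agnostic op such as tf.Relu,
//
//   %0 = "tf.Relu"(%x)
//   %1 = "tf.Transpose"(%0, %perm)
//
// becomes
//
//   %0 = "tf.Transpose"(%x, %perm)
//   %1 = "tf.Relu"(%0)
//
// so the transpose can keep moving toward the beginning of the block.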

// Reverts the permutation applied to the shape of `type`.
static mlir::ShapedType ReversePermuteShapedType(
    mlir::ShapedType type, ArrayRef<int64_t> permutation) {
  if (!type.hasRank()) return type;

  auto shape = type.getShape();
  SmallVector<int64_t, 4> new_shape(shape.size());

  for (int i = 0; i < permutation.size(); ++i) {
    int64_t index = permutation[i];
    assert(index < shape.size());
    new_shape[index] = shape[i];
  }

  return type.clone(new_shape);
}
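
// Worked example: for permutation {0, 3, 1, 2} (NHWC -> NCHW) and an
// NCHW-shaped tensor<1x64x112x112xf32>, new_shape[permutation[i]] = shape[i]
// yields tensor<1x112x112x64xf32>, i.e. the original NHWC shape.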

// Move Transpose operations that permute `op` operands after the `op`.
void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
                        bool fold_transpose_in_ops) {
  // Indices of operands and results that depend on data layout.
  SmallVector<unsigned, 4> layout_dependent_operands;
  SmallVector<unsigned, 4> layout_dependent_results;

  auto fold_operands = dyn_cast<FoldOperandsTransposeInterface>(op);
  bool layout_agnostic = op->hasTrait<OpTrait::TF::LayoutAgnostic>();

  if (fold_operands && fold_transpose_in_ops) {
    layout_dependent_operands = fold_operands.GetLayoutDependentArgs();
    layout_dependent_results = fold_operands.GetLayoutDependentResults();

  } else if (layout_agnostic) {
    // For layout agnostic operations (e.g. element-wise operations) all
    // operands and results must have the same data layout.
    for (unsigned i = 0; i < op->getNumOperands(); ++i)
      layout_dependent_operands.push_back(i);
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      layout_dependent_results.push_back(i);
  }

  // Transpose operations that are operands of the `op`.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for operand transposes.
  ConstOp permutation_op;

  // Layout dependent operands must be transpose operations with the same
  // permutation indices.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);

    // Operand must be defined by a transpose op.
    TransposeOp transpose =
        dyn_cast_or_null<TransposeOp>(operand.get().getDefiningOp());
    if (!transpose) return;

    // With permutation defined by constant operation.
    ConstOp perm =
        dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
    if (!perm) return;

    // With the same permutation indices.
    auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
    if (!dense_elem_attr) return;

    if (!permutation_op) permutation_op = perm;

    // Check that permutation matches for all operand transposes.
    if (perm.value() != permutation_op.value()) return;

    // Add a transpose operation for later reuse only if it's used once.
    if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose);
  }

  // Nothing to do here.
  if (!permutation_op) return;

  // All results after transpose must preserve the original result type.
  SmallVector<Type, 4> original_type(op->getNumResults());
  for (unsigned idx : layout_dependent_results)
    original_type[idx] = op->getResult(idx).getType();

  SmallVector<int64_t, 8> permutation;

  auto attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : attr.getIntValues())
    permutation.push_back(value.getSExtValue());

  // Check if we can fold transpose into the operation.
  if (fold_operands && fold_transpose_in_ops) {
    if (failed(fold_operands.FoldOperandsPermutation(permutation))) return;
  }

  // At this point we checked that we can safely move the Transpose node after
  // `op`, bypass all operand transposes, and transpose op results.
  Location loc = op->getLoc();

  // Move the constant op defining the operand permutation to the beginning of
  // the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for layout dependent operands.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);
    TransposeOp transpose =
        dyn_cast<TransposeOp>(operand.get().getDefiningOp());
    operand.set(transpose.getOperand(0));
  }

  // Maybe add Transpose nodes for layout dependent results
  // (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPointAfter(op);

  for (unsigned idx : layout_dependent_results) {
    OpResult result = op->getResult(idx);

    // If the op is layout agnostic, the new result type can be generated by
    // reverting `permutation`. Otherwise, operations with custom folding will
    // update the result type in `FoldOperandsPermutation`.
    if (layout_agnostic)
      result.setType(ReversePermuteShapedType(
          result.getType().cast<ShapedType>(), permutation));

    // Try to push transpose further down.
    for (Operation* user : result.getUsers()) {
      if (!llvm::isa<TransposeOp>(user)) work_list->push_back(user);
    }

    // Try to reuse operand transposes.
    TransposeOp transpose;
    if (!transpose_ops.empty()) {
      transpose = transpose_ops.pop_back_val();
      transpose.getOperation()->moveBefore(op->getNextNode());
      transpose.setOperand(0, result);
      transpose.setOperand(1, permutation_op);
      transpose.getResult().setType(original_type[idx]);
    } else {
      transpose = builder.create<TransposeOp>(loc, result, permutation_op);
    }

    // Forward all users to the transpose operation.
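    // As in LayoutAssignmentPass above, replaceAllUsesWith also rewrites the
    // transpose's own use of `result`, so its operand is restored right after.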
    result.replaceAllUsesWith(transpose);
    transpose.setOperand(0, result);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}
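
// Illustrative rewrite performed by MoveTransposeAfter (a sketch), the mirror
// of MoveTransposeBefore:
//
//   %0 = "tf.Transpose"(%x, %perm)
//   %1 = "tf.Relu"(%0)
//
// becomes
//
//   %0 = "tf.Relu"(%x)
//   %1 = "tf.Transpose"(%0, %perm)
//
// so the transpose can keep moving toward the end of the block.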

void MoveTransposesPass::runOnFunction() {
  FuncOp func = getFunction();

  SmallVector<Operation*, 8> work_list;

  func.walk([&](TransposeOp transpose) {
    if (direction_ == Direction::kBegin) {
      // Try to push transpose before the operand operation.
      for (auto operand : transpose.getOperands()) {
        if (auto op = operand.getDefiningOp()) work_list.push_back(op);
      }
    } else {
      // Try to push transpose after the user operation.
      for (Operation* user : transpose.y().getUsers()) {
        if (!llvm::isa<TransposeOp>(user)) work_list.push_back(user);
      }
    }
  });

  while (!work_list.empty()) {
    Operation* op = work_list.pop_back_val();
    if (direction_ == Direction::kBegin) {
      MoveTransposeBefore(op, &work_list);
    } else if (direction_ == Direction::kEnd) {
      MoveTransposeAfter(op, &work_list, fold_transpose_in_ops_);
    }
  }

  func.walk([&](TransposeOp transpose) {
    OpBuilder builder(transpose);
    SmallVector<Value, 1> fold_result;
    if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
      assert(fold_result.size() == 1);
      transpose.replaceAllUsesWith(fold_result[0]);
    }
  });
}

}  // namespace

void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options) {
  using Direction = MoveTransposesPass::Direction;

  // Assign optimal layout for layout sensitive ops.
  pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));

  // Move transposes to the beginning of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kBegin, !options.skip_fold_transpose_in_ops));

  // Move transposes to the end of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kEnd, !options.skip_fold_transpose_in_ops));
}
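
// A minimal end-to-end sketch (illustrative; the pipeline option names are
// defined by LayoutOptimizationPipelineOptions, declared elsewhere):
//
//   tf-opt --pass-pipeline="func(tf-layout-optimization)" input.mlir
//
// which runs layout assignment followed by the two transpose-motion passes.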

static PassRegistration<LayoutAssignmentPass> layout_assignment;
static PassRegistration<MoveTransposesPass> move_transposes;

static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
    pipeline("tf-layout-optimization",
             "Assigns optimal data layout to all layout sensitive operations "
             "and cancels redundant transpose operations.",
             CreateLayoutOptimizationPipeline);

}  // namespace TF
}  // namespace mlir