/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

#define DEBUG_TYPE "tf-layout-optimization"

namespace mlir {
namespace TF {

namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
// Helper method that returns an op from 'transpose_ops' that matches the
// criteria for the given 'operand' and 'permutation'.
TransposeOp ReuseExistingTranspose(const OpOperand* operand,
                                   const SmallVector<int64_t, 4>& permutation,
                                   Operation* op, ConstOp permutation_op,
                                   SmallVector<TransposeOp, 2>* transpose_ops) {
  for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) {
    auto transpose_op = *it;
    for (auto transpose_operand : transpose_op.getOperands()) {
      auto ranked_transpose_type =
          transpose_operand.getType().dyn_cast_or_null<RankedTensorType>();
      if (!ranked_transpose_type) continue;
      if (ranked_transpose_type.getRank() == permutation.size() &&
          operand->get().getType() ==
              ShuffleRankedTensorType(ranked_transpose_type, permutation)) {
        TransposeOp transpose = transpose_op;
        transpose.getOperation()->moveBefore(op);
        transpose.setOperand(0, operand->get());
        transpose.setOperand(1, permutation_op);
        transpose_ops->erase(it);
        return transpose;
      }
    }
  }
  return nullptr;
}

// LayoutAssignmentPass assigns an optimal data layout (data format) to all
// layout sensitive operations.
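//
// For example (sketch): forcing "NCHW" on a layout sensitive op such as
// "tf.Conv2D" running in "NHWC" updates the op's data_format in place and
// wraps its layout dependent arguments and results in "tf.Transpose" ops,
// using permutations {0, 3, 1, 2} for arguments and {0, 2, 3, 1} for results.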
class LayoutAssignmentPass
    : public PassWrapper<LayoutAssignmentPass, FunctionPass> {
 public:
  LayoutAssignmentPass() = default;
  explicit LayoutAssignmentPass(const std::string& force_data_format) {
    force_data_format_ = force_data_format;
  }

  LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}

  void runOnFunction() final;

 private:
  // Force a specified data format for all layout sensitive operations.
  Option<std::string> force_data_format_{
      *this, "force-data-format",
      llvm::cl::desc("Force data format for all layout sensitive ops")};
};

// MoveTransposesPass moves all Transpose ops to the beginning or to the end of
// the basic block where they are defined. This allows the canonicalizer to
// delete redundant transposes.
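//
// For example (sketch), once transposes have been moved to the end of the
// block, a pair with inverse permutations
//   %1 = "tf.Transpose"(%0, %perm)      // e.g. NHWC -> NCHW
//   %2 = "tf.Transpose"(%1, %perm_inv)  // e.g. NCHW -> NHWC
// cancels out, and users of %2 can read %0 directly.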
class MoveTransposesPass
    : public PassWrapper<MoveTransposesPass, FunctionPass> {
 public:
  enum class Direction { kBegin, kEnd };

  MoveTransposesPass() = default;
  explicit MoveTransposesPass(Direction direction, bool fold_transpose_in_ops) {
    direction_ = direction;
    fold_transpose_in_ops_ = fold_transpose_in_ops;
  }
  MoveTransposesPass(const MoveTransposesPass& pass) {}

  void runOnFunction() final;

 private:
  Option<bool> fold_transpose_in_ops_{
      *this, "fold-transpose-in-ops",
      llvm::cl::desc(
          "Whether to fold transposes in ops which can support folding."),
      llvm::cl::init(true)};

  Option<Direction> direction_{
      *this, "direction",
      llvm::cl::desc("Move transposes to the beginning or the end of the block "
                     "where they are defined."),
      llvm::cl::values(
          clEnumValN(Direction::kBegin, "begin", "beginning of the block"),
          clEnumValN(Direction::kEnd, "end", "end of the block"))};
};

using Permutation = SmallVector<int64_t, 4>;

void LayoutAssignmentPass::runOnFunction() {
  FuncOp func = getFunction();

  // Get runtime devices information from the closest parent module.
  RuntimeDevices devices;
  if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
                                            &devices)))
    return signalPassFailure();

  // If there is no runtime device information and data format is not
  // explicitly forced, there is nothing to do.
  if (devices.NumDevices() == 0 && force_data_format_.empty()) return;

  func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
    // Get desired op data format.
    StringRef target_data_format = force_data_format_;
    if (target_data_format.empty()) {
      target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
    }

    // Skip ops that already use target data format.
    auto data_format = layout_sensitive_interface.data_format();
    if (data_format == target_data_format) return;

    // Transpose arguments into the target data format.
    Permutation args_permutation =
        GetDataFormatPermutation(data_format, target_data_format);

    // Transpose results back to the original data format.
    Permutation res_permutation =
        GetDataFormatPermutation(target_data_format, data_format);
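
    // For example (sketch), for "NHWC" -> "NCHW" the argument permutation is
    // {0, 3, 1, 2} and the result permutation back is {0, 2, 3, 1}.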

    if (args_permutation.empty() || res_permutation.empty()) return;

    mlir::Operation* op = layout_sensitive_interface.getOperation();
    Location loc = op->getLoc();
    OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());

    auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
      auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64));
      return DenseIntElementsAttr::get(perm_ty, permutation);
    };

    // Change operation data format.
    if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
      return;

    // Permute arguments into the target data format.
    builder.setInsertionPoint(op);
    auto arg_perm = builder.create<ConstOp>(loc, perm_attr(args_permutation));

    for (int64_t arg : layout_sensitive_interface.GetLayoutDependentArgs()) {
      op->setOperand(
          arg, builder.create<TransposeOp>(loc, op->getOperand(arg), arg_perm));
    }

    // Permute results back to the original data format.
    builder.setInsertionPointAfter(op);
    auto res_perm = builder.create<ConstOp>(loc, perm_attr(res_permutation));

    for (int64_t res : layout_sensitive_interface.GetLayoutDependentResults()) {
      OpResult result = op->getResult(res);

      auto transposed_res = builder.create<TransposeOp>(loc, result, res_perm);
      result.replaceAllUsesWith(transposed_res);
      transposed_res.setOperand(0, result);
    }
  });
}

// Move Transpose operations that permute `op` results before the `op`.
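//
// For example (sketch), with a layout agnostic op:
//   %0 = "tf.Tanh"(%arg)
//   %1 = "tf.Transpose"(%0, %perm)
// is rewritten to
//   %0 = "tf.Transpose"(%arg, %perm)
//   %1 = "tf.Tanh"(%0)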
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
  // TODO(ezhulenev): Move transpose across layout sensitive operations.
  if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;

  // Transpose operations that use operation results.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for result transposes.
  ConstOp permutation_op;

  // All operation results must be used by transpose operations with the same
  // permutation indices.
  for (OpResult result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // Result user must be a transpose operation.
      TransposeOp transpose = dyn_cast<TransposeOp>(user);
      if (!transpose) return;

      // With permutation defined by constant operation.
      ConstOp perm =
          dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
      if (!perm) return;

      // With the same permutation indices.
      auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
      if (!dense_elem_attr) return;

      if (!permutation_op) permutation_op = perm;

      // Check that permutation matches for all result transposes.
      if (perm.value() != permutation_op.value()) return;

      // Add a transpose operation for later reuse.
      transpose_ops.push_back(transpose);
    }
  }

  // Nothing to do here.
  if (!permutation_op || transpose_ops.empty()) return;
  SmallVector<int64_t, 4> permutation;
  auto perm_attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : perm_attr.getIntValues())
    permutation.push_back(value.getSExtValue());

  // We want to make sure the shape of the operand equals the transposed shape.
  // A mismatch can happen if 'op' supports broadcasting and the operands have
  // different ranks.
  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>()) {
    auto transpose_op = *transpose_ops.begin();
    auto result_type =
        transpose_op.getResult().getType().dyn_cast_or_null<ShapedType>();
    auto is_valid_move =
        llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool {
          auto operand_type = operand.getType().dyn_cast_or_null<ShapedType>();
          return result_type && operand_type && result_type.hasRank() &&
                 operand_type.hasRank() &&
                 result_type.getRank() == operand_type.getRank();
        });
    if (!is_valid_move) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // before `op` and bypass all result transposes.
  Location loc = op->getLoc();

  // Move the constant op defining the result permutation to the beginning of
  // the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for all results.
  for (OpResult result : op->getResults()) {
    result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
    for (Operation* transpose : result.getUsers()) {
      transpose->getResult(0).replaceAllUsesWith(result);
    }
  }

  // Maybe add a Transpose node for all operands (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (OpOperand& operand : op->getOpOperands()) {
    // Try to push the transpose further up.
    if (Operation* operand_op = operand.get().getDefiningOp())
      work_list->push_back(operand_op);

    // Try to reuse result transposes.
    TransposeOp transpose = ReuseExistingTranspose(
        &operand, permutation, op, permutation_op, &transpose_ops);
    // If no transpose is available for reuse, create a new one.
    if (!transpose)
      transpose =
          builder.create<TransposeOp>(loc, operand.get(), permutation_op);

    operand.set(transpose);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

// Move Transpose operations that permute `op` operands after the `op`.
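//
// For example (sketch), with a layout agnostic op:
//   %0 = "tf.Transpose"(%arg, %perm)
//   %1 = "tf.Tanh"(%0)
// is rewritten to
//   %0 = "tf.Tanh"(%arg)
//   %1 = "tf.Transpose"(%0, %perm)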
void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
                        bool fold_transpose_in_ops) {
  // Indices of operands and results that depend on data layout.
  SmallVector<unsigned, 4> layout_dependent_operands;
  SmallVector<unsigned, 4> layout_dependent_results;

  auto fold_operands = dyn_cast<FoldOperandsTransposeInterface>(op);
  bool layout_agnostic = op->hasTrait<OpTrait::TF::LayoutAgnostic>();

  if (fold_operands && fold_transpose_in_ops) {
    layout_dependent_operands = fold_operands.GetLayoutDependentArgs();
    layout_dependent_results = fold_operands.GetLayoutDependentResults();

  } else if (layout_agnostic) {
    // For layout-agnostic operations (e.g. element-wise operations) all
    // operands and results must have the same data layout.
    for (unsigned i = 0; i < op->getNumOperands(); ++i)
      layout_dependent_operands.push_back(i);
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      layout_dependent_results.push_back(i);
  }
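
  // For example (sketch), for a layout agnostic binary op:
  //   %0 = "tf.Transpose"(%a, %perm)
  //   %1 = "tf.Transpose"(%b, %perm)
  //   %2 = "tf.AddV2"(%0, %1)
  // all operands and the result are layout dependent, so the two operand
  // transposes can be bypassed and replaced with a single transpose of %2.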

  // Transpose operations that are operands of the `op`.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for operand
  // transposes.
  ConstOp permutation_op;

  // Layout dependent operands must be transpose operations with the same
  // permutation indices.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);

    // Operand must be defined by a transpose op.
    TransposeOp transpose =
        dyn_cast_or_null<TransposeOp>(operand.get().getDefiningOp());
    if (!transpose) return;

    // With permutation defined by constant operation.
    ConstOp perm =
        dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
    if (!perm) return;

    // With the same permutation indices.
    auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
    if (!dense_elem_attr) return;

    if (!permutation_op) permutation_op = perm;

    // Check that the permutation matches for all operand transposes.
    if (perm.value() != permutation_op.value()) return;

    // Add a transpose operation for later reuse only if it's used once.
    if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose);
  }

  // Nothing to do here.
  if (!permutation_op) return;

  // All results after transpose must preserve the original result type.
  SmallVector<Type, 4> original_type(op->getNumResults());
  for (unsigned idx : layout_dependent_results)
    original_type[idx] = op->getResult(idx).getType();

  // Check if we can fold the transpose into the operation.
  if (fold_operands && fold_transpose_in_ops) {
    SmallVector<int64_t, 8> permutation;

    auto attr = permutation_op.value().cast<DenseElementsAttr>();
    for (const auto& value : attr.getIntValues())
      permutation.push_back(value.getSExtValue());

    if (failed(fold_operands.FoldOperandsPermutation(permutation))) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // after `op`, bypass all operand transposes, and transpose the op results.
  Location loc = op->getLoc();

  // Move the constant op defining the operand permutation to the beginning of
  // the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for layout dependent operands.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);
    TransposeOp transpose =
        dyn_cast<TransposeOp>(operand.get().getDefiningOp());
    operand.set(transpose.getOperand(0));
  }

  // Maybe add Transpose nodes for layout dependent results
  // (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (unsigned idx : layout_dependent_results) {
    OpResult result = op->getResult(idx);

    // Forward the operand type only for layout agnostic operations; operations
    // with custom folding update the result type in `FoldOperandsPermutation`.
    if (layout_agnostic) result.setType(op->getOperand(0).getType());

    // Try to push the transpose further down.
    for (Operation* user : result.getUsers()) work_list->push_back(user);

    // Try to reuse operand transposes.
    TransposeOp transpose;
    if (!transpose_ops.empty()) {
      transpose = transpose_ops.pop_back_val();
      transpose.getOperation()->moveBefore(op->getNextNode());
      transpose.setOperand(0, result);
      transpose.setOperand(1, permutation_op);
      transpose.getResult().setType(original_type[idx]);
    } else {
      transpose = builder.create<TransposeOp>(loc, result, permutation_op);
    }

    // Forward all users to the transpose operation.
    result.replaceAllUsesWith(transpose);
    transpose.setOperand(0, result);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}
void MoveTransposesPass::runOnFunction() {
  FuncOp func = getFunction();

  SmallVector<Operation*, 8> work_list;

  func.walk([&](TransposeOp transpose) {
    if (direction_ == Direction::kBegin) {
      // Try to push transpose before the operand operation.
      for (auto operand : transpose.getOperands()) {
        if (auto op = operand.getDefiningOp()) work_list.push_back(op);
      }
    } else {
      // Try to push transpose after the user operation.
      for (Operation* user : transpose.y().getUsers()) {
        work_list.push_back(user);
      }
    }
  });

  while (!work_list.empty()) {
    Operation* op = work_list.pop_back_val();
    if (direction_ == Direction::kBegin) {
      MoveTransposeBefore(op, &work_list);
    } else if (direction_ == Direction::kEnd) {
      MoveTransposeAfter(op, &work_list, fold_transpose_in_ops_);
    }
  }
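  // Finally, try to fold the remaining transposes in place; for example, a
  // transpose with an identity permutation folds to its input.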
  func.walk([&](TransposeOp transpose) {
    OpBuilder builder(transpose);
    SmallVector<Value, 1> fold_result;
    if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
      assert(fold_result.size() == 1);
      transpose.replaceAllUsesWith(fold_result[0]);
    }
  });
}

}  // namespace

void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options) {
  using Direction = MoveTransposesPass::Direction;

  // Assign optimal layout for layout sensitive ops.
  pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));

  // Move transposes to the beginning of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kBegin, !options.skip_fold_transpose_in_ops));

  // Move transposes to the end of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      Direction::kEnd, !options.skip_fold_transpose_in_ops));
}
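
// The passes registered below can be exercised from the command line through
// the standard MLIR pass infrastructure (a sketch; the exact tool name and
// option syntax may differ), e.g.:
//   tf-opt -pass-pipeline='func(tf-layout-assignment{force-data-format=NCHW})'
//   tf-opt -tf-layout-optimization input.mlir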

static PassRegistration<LayoutAssignmentPass> layout_assignment(
    "tf-layout-assignment", "Layout assignment pass");
static PassRegistration<MoveTransposesPass> move_transposes(
    "tf-move-transposes", "Move transposes pass");

static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
    pipeline("tf-layout-optimization",
             "Assigns optimal data layout to all layout sensitive operations "
             "and cancels redundant transpose operations.",
             CreateLayoutOptimizationPipeline);

}  // namespace TF
}  // namespace mlir