/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

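// This pass reorders `tf.TPUReplicatedInput` ops whose operands are all
// produced by `tf.TPUPartitionedInput` ops so that replication happens before
// partitioning: TPUReplicatedInput(TPUPartitionedInput(...), ...) becomes
// TPUPartitionedInput(TPUReplicatedInput(...), ...).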
struct TPUReorderReplicateAndPartitionedInputsPass
    : public TF::TPUReorderReplicateAndPartitionedInputsPassBase<
          TPUReorderReplicateAndPartitionedInputsPass> {
  void runOnFunction() override;
};

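// Rewrites one `tf.TPUReplicatedInput` op that is fed exclusively by
// `tf.TPUPartitionedInput` ops. Illustrative sketch (simplified IR, value
// names are hypothetical):
//
//   %pi0 = "tf.TPUPartitionedInput"(%a0, %a1)   // replica 0
//   %pi1 = "tf.TPUPartitionedInput"(%b0, %b1)   // replica 1
//   %ri  = "tf.TPUReplicatedInput"(%pi0, %pi1)
//
// becomes
//
//   %ri0 = "tf.TPUReplicatedInput"(%a0, %b0)    // core 0
//   %ri1 = "tf.TPUReplicatedInput"(%a1, %b1)    // core 1
//   %pi  = "tf.TPUPartitionedInput"(%ri0, %ri1)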
LogicalResult ReorderReplicateAndPartitionedInputs(
    TF::TPUReplicatedInputOp replicated_input) {
  if (!llvm::all_of(replicated_input.inputs(), [](Value input) {
        return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
            input.getDefiningOp());
      }))
    return replicated_input.emitOpError()
           << "expects all inputs from 'tf.TPUPartitionedInput' ops";

  if (replicated_input.index() != -1)
    return replicated_input->emitOpError()
           << "unsupported index = " << replicated_input.index();

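  // Use the first partitioned input as the reference: all of them must agree
  // on the XLA sharding, the partition dimension, and the number of operands
  // (i.e. the number of cores per replica).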
  auto first_partitioned_input = llvm::cast<TF::TPUPartitionedInputOp>(
      replicated_input.getOperand(0).getDefiningOp());
  llvm::Optional<::llvm::StringRef> xla_sharding =
      first_partitioned_input._XlaSharding();
  int64_t partition_dim = first_partitioned_input.partition_dim();
  size_t num_cores_per_replica = first_partitioned_input.getNumOperands();

  for (auto operand : replicated_input.inputs().drop_front()) {
    auto partitioned_input =
        llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    llvm::Optional<::llvm::StringRef> op_xla_sharding =
        partitioned_input._XlaSharding();
    int64_t op_partition_dim = partitioned_input.partition_dim();
    // Abort if TPUPartitionedInput(s) do not have the same attributes.
    if (partition_dim != op_partition_dim)
      return partitioned_input->emitOpError()
             << "expects partition_dim = " << partition_dim << " but found "
             << op_partition_dim;
    if (partitioned_input.getNumOperands() != num_cores_per_replica)
      return partitioned_input->emitOpError()
             << "expects " << num_cores_per_replica << " operands but found "
             << partitioned_input.getNumOperands();
    if (xla_sharding != op_xla_sharding)
      return replicated_input.emitOpError()
             << "expects all inputs from 'tf.TPUPartitionedInput' ops to have "
                "identical XLA sharding";
  }

  // 2D matrix to store the per-core, per-replica operands. The matrix
  // dimensions are num_cores_per_replica x num_replicas: the i-th row holds
  // the operands for the i-th core, and the j-th column holds the operands
  // for the j-th replica.
  llvm::SmallVector<llvm::SmallVector<Value, 4>, 4>
      operands_per_replica_per_core;
  operands_per_replica_per_core.resize(num_cores_per_replica);

  // Collect all operands in the 2D matrix.
  for (auto operand : replicated_input.inputs()) {
    auto pi = llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    for (auto& pi_operand : pi->getOpOperands()) {
      unsigned core_id = pi_operand.getOperandNumber();
      operands_per_replica_per_core[core_id].push_back(pi_operand.get());
    }
  }

  // Create new `tf.TPUReplicatedInput` ops feeding into one
  // `tf.TPUPartitionedInput` op.
  OpBuilder builder(replicated_input);
  llvm::SmallVector<Value, 4> operands_per_core;
  for (const auto& operands_per_replica : operands_per_replica_per_core) {
    auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
        replicated_input.getLoc(), replicated_input.getType(),
        operands_per_replica, replicated_input->getAttrs());
    operands_per_core.push_back(replicate_op);
  }

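  // Feed the per-core replicated values into a single `tf.TPUPartitionedInput`
  // op that keeps the attributes (e.g. sharding, partition dimension) of the
  // original partitioned inputs.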
  auto pi = builder.create<TF::TPUPartitionedInputOp>(
      first_partitioned_input.getLoc(), replicated_input.getType(),
      operands_per_core, first_partitioned_input->getAttrs());
  replicated_input.replaceAllUsesWith(pi.output());
  return success();
}

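// Walks the function and rewrites every `tf.TPUReplicatedInput` that is fed
// exclusively by `tf.TPUPartitionedInput` ops. Ops with no such producers are
// skipped; a mix of producers is reported as an error by the rewrite above.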
void TPUReorderReplicateAndPartitionedInputsPass::runOnFunction() {
  auto result =
      getFunction()->walk([](TF::TPUReplicatedInputOp replicated_input) {
        if (llvm::none_of(replicated_input.inputs(), [](Value input) {
              return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
                  input.getDefiningOp());
            }))
          return WalkResult::advance();
        if (failed(ReorderReplicateAndPartitionedInputs(replicated_input)))
          return WalkResult::interrupt();

        assert(replicated_input->use_empty());
        replicated_input->erase();
        return WalkResult::advance();
      });

  if (result.wasInterrupted()) {
    signalPassFailure();
    return;
  }

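  // Remove `tf.TPUPartitionedInput` ops that became dead after the rewrite.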
  getFunction()->walk([](TF::TPUPartitionedInputOp partitioned_input) {
    if (partitioned_input->use_empty()) partitioned_input->erase();
  });
}

}  // namespace

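// Factory used for pass registration. Illustrative usage only, assuming an
// MLIR OpPassManager `pm`:
//   pm.addNestedPass<FuncOp>(
//       CreateTPUReorderReplicateAndPartitionedInputsPass());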
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUReorderReplicateAndPartitionedInputsPass() {
  return std::make_unique<TPUReorderReplicateAndPartitionedInputsPass>();
}

}  // namespace TFTPU
}  // namespace mlir