/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

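// Pass that swaps the order of `tf.TPUReplicatedInput` and
// `tf.TPUPartitionedInput` ops, so that inputs are replicated first and
// partitioned second (see the sketch above
// `ReorderReplicateAndPartitionedInputs` below).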
struct TPUReorderReplicateAndPartitionedInputsPass
    : public TF::TPUReorderReplicateAndPartitionedInputsPassBase<
          TPUReorderReplicateAndPartitionedInputsPass> {
  void runOnFunction() override;
};

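// Rewrites a `tf.TPUReplicatedInput` whose operands are all produced by
// `tf.TPUPartitionedInput` ops, so that replication is applied before
// partitioning. An illustrative sketch with 2 replicas and 2 cores per
// replica (operand types and attributes omitted):
//
//   %pi0 = "tf.TPUPartitionedInput"(%arg0, %arg1)  // replica 0
//   %pi1 = "tf.TPUPartitionedInput"(%arg2, %arg3)  // replica 1
//   %ri  = "tf.TPUReplicatedInput"(%pi0, %pi1)
//
// becomes
//
//   %ri0 = "tf.TPUReplicatedInput"(%arg0, %arg2)  // core 0
//   %ri1 = "tf.TPUReplicatedInput"(%arg1, %arg3)  // core 1
//   %pi  = "tf.TPUPartitionedInput"(%ri0, %ri1)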
LogicalResult ReorderReplicateAndPartitionedInputs(
    TF::TPUReplicatedInputOp replicated_input) {
  if (!llvm::all_of(replicated_input.inputs(), [](Value input) {
        return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
            input.getDefiningOp());
      }))
    return replicated_input.emitOpError()
           << "expects all inputs from 'tf.TPUPartitionedInput' ops";

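  // Only the unset `index` value of -1 is supported; explicitly indexed
  // replicated inputs are rejected.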
  if (replicated_input.index() != -1)
    return replicated_input->emitOpError()
           << "unsupported index = " << replicated_input.index();

  auto first_partitioned_input = llvm::cast<TF::TPUPartitionedInputOp>(
      replicated_input.getOperand(0).getDefiningOp());
  llvm::Optional<::llvm::StringRef> xla_sharding =
      first_partitioned_input._XlaSharding();
  int64_t partition_dim = first_partitioned_input.partition_dim();
  size_t num_cores_per_replica = first_partitioned_input.getNumOperands();

  for (auto operand : replicated_input.inputs().drop_front()) {
    auto partitioned_input =
        llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    llvm::Optional<::llvm::StringRef> op_xla_sharding =
        partitioned_input._XlaSharding();
    int64_t op_partition_dim = partitioned_input.partition_dim();
    // Abort if TPUPartitionedInput(s) do not have the same attributes.
    if (partition_dim != op_partition_dim)
      return partitioned_input->emitOpError()
             << "expects partition_dim = " << partition_dim << " but found "
             << op_partition_dim;
    if (partitioned_input.getNumOperands() != num_cores_per_replica)
      return partitioned_input->emitOpError()
             << "expects " << num_cores_per_replica << " operands but found "
             << partitioned_input.getNumOperands();
    if (xla_sharding != op_xla_sharding)
      return replicated_input.emitOpError()
             << "expects all inputs from 'tf.TPUPartitionedInput' ops to have "
                "identical XLA sharding";
  }

  // 2D matrix storing per-core, per-replica operands. The matrix dimensions
  // are num_cores_per_replica x num_replicas: the i-th row holds the operands
  // for the i-th core, and the j-th column holds the operands for the j-th
  // replica.
  llvm::SmallVector<llvm::SmallVector<Value, 4>, 4>
      operands_per_replica_per_core;
  operands_per_replica_per_core.resize(num_cores_per_replica);

  // Collect all operands in the 2D matrix.
  for (auto operand : replicated_input.inputs()) {
    auto pi = llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    for (auto& pi_operand : pi->getOpOperands()) {
      unsigned core_id = pi_operand.getOperandNumber();
      operands_per_replica_per_core[core_id].push_back(pi_operand.get());
    }
  }
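  // In the 2x2 sketch above, this yields:
  //   operands_per_replica_per_core[0] = {%arg0, %arg2}  // core 0
  //   operands_per_replica_per_core[1] = {%arg1, %arg3}  // core 1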

  // Create new `tf.TPUReplicatedInput` ops feeding into one
  // `tf.TPUPartitionedInput` op.
  OpBuilder builder(replicated_input);
  llvm::SmallVector<Value, 4> operands_per_core;
  for (const auto& operands_per_replica : operands_per_replica_per_core) {
    auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
        replicated_input.getLoc(), replicated_input.getType(),
        operands_per_replica, replicated_input->getAttrs());
    operands_per_core.push_back(replicate_op);
  }

  auto pi = builder.create<TF::TPUPartitionedInputOp>(
      first_partitioned_input.getLoc(), replicated_input.getType(),
      operands_per_core, first_partitioned_input->getAttrs());
  replicated_input.replaceAllUsesWith(pi.output());
  return success();
}

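// Rewrites every `tf.TPUReplicatedInput` in the function whose inputs are
// produced by `tf.TPUPartitionedInput` ops; a mix of partitioned and
// non-partitioned inputs is reported as an error by the rewrite above.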
void TPUReorderReplicateAndPartitionedInputsPass::runOnFunction() {
  auto result =
      getFunction()->walk([](TF::TPUReplicatedInputOp replicated_input) {
        if (llvm::none_of(replicated_input.inputs(), [](Value input) {
              return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
                  input.getDefiningOp());
            }))
          return WalkResult::advance();
        if (failed(ReorderReplicateAndPartitionedInputs(replicated_input)))
          return WalkResult::interrupt();

        assert(replicated_input->use_empty());
        replicated_input->erase();
        return WalkResult::advance();
      });

  if (result.wasInterrupted()) {
    signalPassFailure();
    return;
  }

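  // Erase `tf.TPUPartitionedInput` ops that are left without uses after the
  // rewrite.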
  getFunction()->walk([](TF::TPUPartitionedInputOp partitioned_input) {
    if (partitioned_input->use_empty()) partitioned_input->erase();
  });
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateTPUReorderReplicateAndPartitionedInputsPass() {
  return std::make_unique<TPUReorderReplicateAndPartitionedInputsPass>();
}

}  // namespace TFTPU
}  // namespace mlir