/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass forms one `tf_executor.island` per replica from a
// `tf_executor.island` that wraps a single `tf_device.replicate` op.
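//
// The pass is registered as `-tf-replicate-to-island` (see the
// PassRegistration at the end of this file).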

#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
constexpr char kDeviceOrdinalAttr[] = "device_ordinal";
constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0";

struct ReplicateToIslandPass
    : public PassWrapper<ReplicateToIslandPass, FunctionPass> {
  void runOnFunction() override;
};

// Returns whether op requires `_xla_replica_id` attribute.
bool RequiresReplicaIDAttribute(Operation* op) {
  return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
}

// Collects the TPU device ordinal for outside compilation communication ops.
// This currently assumes outside compilation only uses the
// `TPU_REPLICATED_CORE_0` aliased device for the device computation.
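// For example, if the replica's device for that alias is
// "/job:worker/replica:0/task:0/device:TPU:1", the collected ordinal is 1.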
llvm::Optional<int64_t> GetDeviceOrdinal(
    const llvm::Optional<DictionaryAttr>& devices, Location loc,
    unsigned replica_id) {
  int64_t device_ordinal = 0;
  if (devices.hasValue()) {
    if (auto tpu_replica_0 = devices.getValue().get(kTPUCore0)) {
      llvm::StringRef tpu_device = tpu_replica_0.cast<ArrayAttr>()[replica_id]
                                       .cast<StringAttr>()
                                       .getValue();
      if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString(
              loc, tpu_device, &device_ordinal))) {
        return llvm::Optional<int64_t>(device_ordinal);
      }
    }
  }
  return llvm::None;
}

// Updates replica-variant ops in a region based on the replica id
// `replica_id`.
// TODO(b/157624749): Replace this with a better abstraction to differentiate
// ops for different replicas. Some ops, such as the XlaHostCompute op or TPU
// Embedding ops, require the replica id to be added as an op attribute to be
// used during execution. Handle such ops separately and add an integer
// attribute that represents the replica id.
LogicalResult UpdateRegionReplicateVariantOps(
    OpBuilder& builder, Location loc, Region& region, int replica_id,
    const llvm::Optional<DictionaryAttr>& devices) {
  llvm::Optional<int64_t> device_ordinal =
      GetDeviceOrdinal(devices, loc, replica_id);

  auto result = region.walk([&](Operation* op) -> WalkResult {
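    // Attach the replica id as an integer attribute to ops that need it at
    // execution time (TPU embedding enqueue ops).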
    if (RequiresReplicaIDAttribute(op)) {
      op->setAttr(kReplicaIdAttr, builder.getI64IntegerAttr(replica_id));
      return WalkResult::advance();
    }

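    // Replace TPU device ordinal placeholders with a constant holding this
    // replica's device ordinal for `TPU_REPLICATED_CORE_0`.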
    if (isa<TF::_TPUDeviceOrdinalPlaceholderOp>(op)) {
      if (!device_ordinal.hasValue())
        return op->emitOpError()
               << "requires device ordinal from device " << kTPUCore0
               << " to be present in 'tf_device.replicate' op";

      OpBuilder builder(op);
      auto const_op = builder.create<TF::ConstOp>(
          op->getLoc(), DenseIntElementsAttr::get(
                            RankedTensorType::get({}, builder.getI64Type()),
                            {device_ordinal.getValue()}));
      op->replaceAllUsesWith(const_op);
      op->erase();
      return WalkResult::advance();
    }

    if (!devices.hasValue()) return WalkResult::advance();

    // Map aliased devices to explicit devices based on replica.
    if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
      if (auto device_by_replica = devices.getValue().get(launch.device()))
        launch->setAttr(
            kDeviceAttr,
            device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());

    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}

// Creates islands per replica from the `tf_device.replicate` region. If a
// `tf_device.launch` op's device is an aliased device of the
// `tf_device.replicate`, it is remapped to the explicit device of the
// associated replica.
LogicalResult ExpandReplicateIntoReplicas(
    const Dialect* tf_dialect, OpBuilder& builder,
    tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
    int num_replicas, llvm::SmallVectorImpl<tf_executor::IslandOp>& replicas) {
  replicas.reserve(num_replicas);
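  // `devices` optionally maps device aliases (e.g. `TPU_REPLICATED_CORE_0`) to
  // a list of explicit devices, one per replica.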
  auto devices = replicate_op.devices();

  // Collect result types and operands.
  Operation& terminator = replicate_op.GetBody().back();
  llvm::SmallVector<Type, 8> output_types(terminator.getOperandTypes());
  auto control_type = tf_executor::ControlType::get(island_op.getContext());
  llvm::SmallVector<Value, 8> replica_inputs(island_op.controlInputs());

  // Replace replicate terminator with YieldOp.
  builder.setInsertionPoint(&terminator);
  builder.create<tf_executor::YieldOp>(terminator.getLoc(),
                                       terminator.getOperands());
  terminator.erase();

  builder.setInsertionPoint(island_op);
  BlockAndValueMapping mapping;
  for (int i : llvm::seq<int>(0, num_replicas)) {
    // Create new island for replica.
    auto replica = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), output_types, control_type, replica_inputs);

    // Map block arg to replica arg.
    mapping.clear();
    for (auto& block_arg : replicate_op.GetBody().getArguments())
      mapping.map(block_arg,
                  replicate_op.GetReplicaOperandForBlockArgument(block_arg, i));

    // Copy over replicate region into replica island.
    replicate_op.body().cloneInto(&replica.body(), mapping);

    if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(),
                                               replica.body(),
                                               /*replica_id=*/i, devices)))
      return failure();

    replicas.push_back(replica);
  }

  return success();
}

// Creates islands per replica from the `tf_device.replicate` region and remaps
// the replicate results to the new island outputs. A single island is created
// to forward control dependencies if there is a control dependency output from
// the replicate island. For `tf_device.launch` ops, aliased devices are
// remapped to explicit devices.
//
// For example, the following:
//
// %0:2 = tf_executor.island(%control) {
//   %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
//              {n = 2 : i32,
//               devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
//                          DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
//     %a = "tf_device.launch"() ( {
//       %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %2 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
//     %b = "tf_device.launch"() ( {
//       %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %3 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
//     tf_device.return %a, %b : tensor<i1>, tensor<i1>
//   }
//   tf_executor.yield %1#0 : tensor<i1>
// }
//
// gets lowered to:
//
// %0:3 = tf_executor.island(%control) {
//   %a0 = "tf_device.launch"() ( {
//     %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %1 : tensor<i1>
//   }) {device = "/DEVICE:0"} : () -> tensor<i1>
//   %b0 = "tf_device.launch"() ( {
//     %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %2 : tensor<i1>
//   }) {device = "/DEVICE:2"} : () -> tensor<i1>
//   tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
// }
// %3:3 = tf_executor.island(%control) {
//   %a1 = "tf_device.launch"() ( {
//     %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %4 : tensor<i1>
//   }) {device = "/DEVICE:1"} : () -> tensor<i1>
//   %b1 = "tf_device.launch"() ( {
//     %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %5 : tensor<i1>
//   }) {device = "/DEVICE:3"} : () -> tensor<i1>
//   tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
// }
LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                         tf_executor::GraphOp graph_op,
                                         tf_executor::IslandOp island_op,
                                         tf_device::ReplicateOp replicate_op) {
  OpBuilder builder(island_op);
  const int num_replicas = replicate_op.n();

  // Create islands per replica.
  llvm::SmallVector<tf_executor::IslandOp, 8> replicas;
  if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, island_op,
                                         replicate_op, num_replicas, replicas)))
    return failure();

  // Collect all replica results.
  llvm::SmallVector<Value, 8> replicas_outputs(replicate_op.getNumResults(),
                                               nullptr);
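  // Replicate results are ordered by the wrapped op's result index, with all
  // replicas of a given result grouped together; this matches the
  // `num_replicas * result_index + replica_index` indexing below.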
  for (auto replica_and_idx : llvm::enumerate(replicas))
    for (auto replica_result_and_idx :
         llvm::enumerate(replica_and_idx.value().outputs()))
      replicas_outputs[num_replicas * replica_result_and_idx.index() +
                       replica_and_idx.index()] =
          replica_result_and_idx.value();

  // Remap replicate results to per replica result.
  for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  // Add sink island to pin all replicas as a control dependency if there is a
  // control dependency leading from the replicate originally.
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> island_operands;
    for (auto& replica : replicas) island_operands.push_back(replica.control());

    builder.setInsertionPoint(island_op);
    auto island_sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        tf_executor::ControlType::get(island_op.getContext()), island_operands);
    island_sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&island_sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(island_sink.control());
  }

  // Replicas with no uses should be pinned to a graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> unused_replica_controls;
  for (auto& replica : replicas)
    if (replica.use_empty())
      unused_replica_controls.push_back(replica.control());

  if (!unused_replica_controls.empty()) {
    tf_executor::FetchOp fetch = graph_op.GetFetch();
    auto fetches = llvm::to_vector<8>(fetch.getOperands());
    fetches.append(unused_replica_controls.begin(),
                   unused_replica_controls.end());
    builder.setInsertionPoint(fetch);
    builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
    fetch.erase();
  }

  island_op.erase();
  return success();
}

void ReplicateToIslandPass::runOnFunction() {
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }

  // Find islands with a single `tf_device.replicate` and create individual
  // islands per replica of the replicate.
  llvm::SmallVector<tf_executor::IslandOp, 4> replicate_op_islands;
  getOperation().walk([&](tf_executor::GraphOp graph_op) {
    for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
      if (!island_op.WrapsSingleOp()) continue;

      if (isa<tf_device::ReplicateOp>(&island_op.GetBody().front()))
        replicate_op_islands.push_back(island_op);
    }
  });

  for (tf_executor::IslandOp island_op : replicate_op_islands) {
    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
    auto replicate_op =
        cast<tf_device::ReplicateOp>(island_op.GetBody().front());
    if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
                                          replicate_op)))
      return signalPassFailure();
  }
}
}  // anonymous namespace

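// Typically scheduled on each function in a pass pipeline, e.g. via
// `pm.addNestedPass<FuncOp>(CreateReplicateToIslandPass());`.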
std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass() {
  return std::make_unique<ReplicateToIslandPass>();
}

static PassRegistration<ReplicateToIslandPass> pass(
    "tf-replicate-to-island", "Lowers device replicate to executor islands");

}  // namespace TFDevice
}  // namespace mlir