/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass forms `tf_executor.island` per replica from a single
// `tf_device.replicate` island.

#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
constexpr char kDeviceOrdinalAttr[] = "device_ordinal";
constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0";

struct ReplicateToIslandPass
    : public PassWrapper<ReplicateToIslandPass, FunctionPass> {
  StringRef getArgument() const final { return "tf-replicate-to-island"; }

  StringRef getDescription() const final {
    return "Lowers device replicate to executor islands";
  }

  void runOnFunction() override;
};

// Returns whether op requires `_xla_replica_id` attribute.
bool RequiresReplicaIDAttribute(Operation* op) {
  return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
}
// Collects the TPU device ordinal for outside compilation communication ops.
// This currently assumes outside compilation only uses the
// `TPU_REPLICATED_CORE_0` aliased device for the device computation.
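//
// As a sketch (the device names here are hypothetical), given
//   devices = {TPU_REPLICATED_CORE_0 =
//                  ["/job:worker/replica:0/task:0/device:TPU:0",
//                   "/job:worker/replica:0/task:0/device:TPU:1"]},
// replica 1 resolves to ".../device:TPU:1" and the returned ordinal is 1; if
// the alias or the `devices` attribute is absent, `llvm::None` is returned.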
llvm::Optional<int64_t> GetDeviceOrdinal(
    const llvm::Optional<DictionaryAttr>& devices, Location loc,
    unsigned replica_id) {
  int64_t device_ordinal = 0;
  if (devices.hasValue()) {
    if (auto tpu_replica_0 = devices.getValue().get(kTPUCore0)) {
      llvm::StringRef tpu_device = tpu_replica_0.cast<ArrayAttr>()[replica_id]
                                       .cast<StringAttr>()
                                       .getValue();
      if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString(
              loc, tpu_device, &device_ordinal))) {
        return llvm::Optional<int64_t>(device_ordinal);
      }
    }
  }
  return llvm::None;
}

// Updates replica variant ops in a region based on replica `replica_id`.
// TODO(b/157624749): Replace this with better abstraction to differentiate ops
// for different replicas. Some ops, such as XlaHostCompute op or TPU Embedding
// ops, require replica id to be added as an op attribute to be used during
// execution. Handle such ops separately and add an integer attribute that
// represents replica id.
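//
// As a rough sketch of the rewrite for replica 1 with device ordinal 1:
//   %ordinal = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor<i64>
// becomes
//   %ordinal = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
// and ops that require a replica id (e.g. TPU embedding enqueue ops) gain the
// attribute {_xla_replica_id = 1 : i64}.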
LogicalResult UpdateRegionReplicateVariantOps(
    OpBuilder& builder, Location loc, Region& region, int replica_id,
    const llvm::Optional<DictionaryAttr>& devices) {
  llvm::Optional<int64_t> device_ordinal =
      GetDeviceOrdinal(devices, loc, replica_id);

  auto result = region.walk([&](Operation* op) -> WalkResult {
    if (RequiresReplicaIDAttribute(op)) {
      op->setAttr(kReplicaIdAttr, builder.getI64IntegerAttr(replica_id));
      return WalkResult::advance();
    }

    if (isa<TF::_TPUDeviceOrdinalPlaceholderOp>(op)) {
      if (!device_ordinal.hasValue())
        return op->emitOpError()
               << "requires device ordinal from device " << kTPUCore0
               << " to be present in 'tf.device.replicate' op";

      OpBuilder builder(op);
      auto const_op = builder.create<TF::ConstOp>(
          op->getLoc(), DenseIntElementsAttr::get(
                            RankedTensorType::get({}, builder.getI64Type()),
                            {device_ordinal.getValue()}));
      op->replaceAllUsesWith(const_op);
      op->erase();
      return WalkResult::advance();
    }

    if (!devices.hasValue()) return WalkResult::advance();

    // Map aliased devices to explicit devices based on replica.
    if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
      if (auto device_by_replica = devices.getValue().get(launch.device()))
        launch->setAttr(
            kDeviceAttr,
            device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());

    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}

// Creates islands per replica from `tf_device.replicate` region. If for a
// `tf_device.launch` op the device is an aliased device of the
// `tf_device.replicate`, the device will be remapped to an explicit device
// for the associated replica island.
LogicalResult ExpandReplicateIntoReplicas(
    const Dialect* tf_dialect, OpBuilder& builder,
    tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
    int num_replicas, llvm::SmallVectorImpl<tf_executor::IslandOp>& replicas) {
  replicas.reserve(num_replicas);
  auto devices = replicate_op.devices();

  // Collect result types and operands.
  Operation& terminator = replicate_op.GetBody().back();
  llvm::SmallVector<Type, 8> output_types(terminator.getOperandTypes());
  auto control_type = tf_executor::ControlType::get(island_op.getContext());
  llvm::SmallVector<Value, 8> replica_inputs(island_op.controlInputs());

  // Replace replicate terminator with YieldOp.
  builder.setInsertionPoint(&terminator);
  builder.create<tf_executor::YieldOp>(terminator.getLoc(),
                                       terminator.getOperands());
  terminator.erase();

  builder.setInsertionPoint(island_op);
  BlockAndValueMapping mapping;
  for (int i : llvm::seq<int>(0, num_replicas)) {
    // Create new island for replica.
    auto replica = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), output_types, control_type, replica_inputs);

    // Map block arg to replica arg.
    mapping.clear();
    for (auto& block_arg : replicate_op.GetBody().getArguments())
      mapping.map(block_arg,
                  replicate_op.GetReplicaOperandForBlockArgument(block_arg, i));

    // Copy over replicate region into replica island.
    replicate_op.body().cloneInto(&replica.body(), mapping);

    if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(),
                                               replica.body(),
                                               /*replica_id=*/i, devices)))
      return failure();

    replicas.push_back(replica);
  }

  return success();
}
// Creates islands per replica from `tf_device.replicate` region and remaps
// replicate results with new island outputs. A single island is created to
// forward control dependencies if there is a control dependency output from
// the replicate island. Devices are remapped from aliased devices to explicit
// devices, for `tf_device.launch` ops.
//
// For example, the following:
//
// %0:2 = tf_executor.island(%control) {
//   %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
//              {n = 2 : i32,
//               devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
//                          DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
//     %a = "tf_device.launch"() ( {
//       %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %2 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
//     %b = "tf_device.launch"() ( {
//       %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %3 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
//     tf_device.return %a, %b : tensor<i1>, tensor<i1>
//   }
//   tf_executor.yield %1#0 : tensor<i1>
// }
//
// gets lowered to:
//
// %0:3 = tf_executor.island(%control) {
//   %a0 = "tf_device.launch"() ( {
//     %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %1 : tensor<i1>
//   }) {device = "/DEVICE:0"} : () -> tensor<i1>
//   %b0 = "tf_device.launch"() ( {
//     %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %2 : tensor<i1>
//   }) {device = "/DEVICE:2"} : () -> tensor<i1>
//   tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
// }
// %3:3 = tf_executor.island(%control) {
//   %a1 = "tf_device.launch"() ( {
//     %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %4 : tensor<i1>
//   }) {device = "/DEVICE:1"} : () -> tensor<i1>
//   %b1 = "tf_device.launch"() ( {
//     %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %5 : tensor<i1>
//   }) {device = "/DEVICE:3"} : () -> tensor<i1>
//   tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
// }
LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                         tf_executor::GraphOp graph_op,
                                         tf_executor::IslandOp island_op,
                                         tf_device::ReplicateOp replicate_op) {
  OpBuilder builder(island_op);
  const int num_replicas = replicate_op.n();

  // Create islands per replica.
  llvm::SmallVector<tf_executor::IslandOp, 8> replicas;
  if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, island_op,
                                         replicate_op, num_replicas, replicas)))
    return failure();

  // Collect all replica results.
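  // The index math below matches the result ordering of `tf_device.replicate`:
  // replica `r`'s copy of the region's result `i` is stored at position
  // `num_replicas * i + r`.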
  llvm::SmallVector<Value, 8> replicas_outputs(replicate_op.getNumResults(),
                                               nullptr);
  for (auto replica_and_idx : llvm::enumerate(replicas))
    for (auto replica_result_and_idx :
         llvm::enumerate(replica_and_idx.value().outputs()))
      replicas_outputs[num_replicas * replica_result_and_idx.index() +
                       replica_and_idx.index()] =
          replica_result_and_idx.value();

  // Remap replicate results to per replica result.
  for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  // Add sink island to pin all replicas as a control dependency if there is a
  // control dependency leading from the replicate originally.
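  // E.g. (sketch) a control-only island of the form
  //   %sink = tf_executor.island(%replica_0_control, %replica_1_control) {
  //     tf_executor.yield
  //   }
  // takes over all uses of the original island's control result.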
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> island_operands;
    for (auto& replica : replicas) island_operands.push_back(replica.control());

    builder.setInsertionPoint(island_op);
    auto island_sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        tf_executor::ControlType::get(island_op.getContext()),
        island_operands);
    island_sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&island_sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(island_sink.control());
  }

  // Replicas with no uses should be pinned to a graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> unused_replica_controls;
  for (auto& replica : replicas)
    if (replica.use_empty())
      unused_replica_controls.push_back(replica.control());

  if (!unused_replica_controls.empty()) {
    tf_executor::FetchOp fetch = graph_op.GetFetch();
    auto fetches = llvm::to_vector<8>(fetch.getOperands());
    fetches.append(unused_replica_controls.begin(),
                   unused_replica_controls.end());
    builder.setInsertionPoint(fetch);
    builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
    fetch.erase();
  }

  island_op.erase();
  return success();
}

void ReplicateToIslandPass::runOnFunction() {
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }

  // Find islands with a single `tf_device.replicate` and create individual
  // islands per replica of the replicate.
  llvm::SmallVector<tf_executor::IslandOp, 4> replicate_op_islands;
  getOperation().walk([&](tf_executor::GraphOp graph_op) {
    for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
      if (!island_op.WrapsSingleOp()) continue;

      if (isa<tf_device::ReplicateOp>(&island_op.GetBody().front()))
        replicate_op_islands.push_back(island_op);
    }
  });

  for (tf_executor::IslandOp island_op : replicate_op_islands) {
    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
    auto replicate_op =
        cast<tf_device::ReplicateOp>(island_op.GetBody().front());
    if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
                                          replicate_op)))
      return signalPassFailure();
  }
}
}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass() {
  return std::make_unique<ReplicateToIslandPass>();
}

static PassRegistration<ReplicateToIslandPass> pass;

}  // namespace TFDevice
}  // namespace mlir