/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass forms one `tf_executor.island` op per replica from a single island
// that wraps a `tf_device.replicate` op.

#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
constexpr char kDeviceOrdinalAttr[] = "device_ordinal";
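// Aliased device name in `tf_device.replicate` referring to the first TPU core
// of each replica; see GetDeviceOrdinal below.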
constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0";

struct ReplicateToIslandPass
    : public PassWrapper<ReplicateToIslandPass, FunctionPass> {
  StringRef getArgument() const final { return "tf-replicate-to-island"; }

  StringRef getDescription() const final {
    return "Lowers device replicate to executor islands";
  }

  void runOnFunction() override;
};

// Returns whether op requires `_xla_replica_id` attribute.
bool RequiresReplicaIDAttribute(Operation* op) {
  return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
}

// Collects the TPU device ordinal for outside compilation communication ops.
// This currently assumes outside compilation only uses the
// `TPU_REPLICATED_CORE_0` aliased device for the device computation.
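// For example, if the `TPU_REPLICATED_CORE_0` entry for `replica_id` is the
// device string "/job:worker/replica:0/task:0/device:TPU:1", the returned
// ordinal is 1.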
llvm::Optional<int64_t> GetDeviceOrdinal(
    const llvm::Optional<DictionaryAttr>& devices, Location loc,
    unsigned replica_id) {
  int64_t device_ordinal = 0;
  if (devices.hasValue()) {
    if (auto tpu_replica_0 = devices.getValue().get(kTPUCore0)) {
      llvm::StringRef tpu_device = tpu_replica_0.cast<ArrayAttr>()[replica_id]
                                       .cast<StringAttr>()
                                       .getValue();
      if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString(
              loc, tpu_device, &device_ordinal))) {
        return llvm::Optional<int64_t>(device_ordinal);
      }
    }
  }
  return llvm::None;
}

// Updates replica-variant ops in a region based on the given `replica_id`.
// TODO(b/157624749): Replace this with a better abstraction to differentiate
// ops for different replicas. Some ops, such as the XlaHostCompute op or TPU
// Embedding ops, require the replica id to be added as an op attribute so it
// can be used during execution. Handle such ops separately and add an integer
// attribute representing the replica id.
LogicalResult UpdateRegionReplicateVariantOps(
    OpBuilder& builder, Location loc, Region& region, int replica_id,
    const llvm::Optional<DictionaryAttr>& devices) {
  llvm::Optional<int64_t> device_ordinal =
      GetDeviceOrdinal(devices, loc, replica_id);

  auto result = region.walk([&](Operation* op) -> WalkResult {
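    // Attach the replica id as an integer attribute to ops that require it at
    // execution time (TPU embedding enqueue ops).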
    if (RequiresReplicaIDAttribute(op)) {
      op->setAttr(kReplicaIdAttr, builder.getI64IntegerAttr(replica_id));
      return WalkResult::advance();
    }

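    // Replace `tf._TPUDeviceOrdinalPlaceholder` ops with a constant op holding
    // the replica's TPU device ordinal.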
    if (isa<TF::_TPUDeviceOrdinalPlaceholderOp>(op)) {
      if (!device_ordinal.hasValue())
        return op->emitOpError()
               << "requires device ordinal from device " << kTPUCore0
               << " to be present in 'tf_device.replicate' op";

      OpBuilder builder(op);
      auto const_op = builder.create<TF::ConstOp>(
          op->getLoc(), DenseIntElementsAttr::get(
                            RankedTensorType::get({}, builder.getI64Type()),
                            {device_ordinal.getValue()}));
      op->replaceAllUsesWith(const_op);
      op->erase();
      return WalkResult::advance();
    }

    if (!devices.hasValue()) return WalkResult::advance();

    // Map aliased devices to explicit devices based on replica.
    if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
      if (auto device_by_replica = devices.getValue().get(launch.device()))
        launch->setAttr(
            kDeviceAttr,
            device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());

    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}

// Creates islands per replica from the `tf_device.replicate` region. If the
// device of a `tf_device.launch` op is an aliased device of the
// `tf_device.replicate`, it is remapped to the explicit device of the
// associated replica island.
LogicalResult ExpandReplicateIntoReplicas(
    const Dialect* tf_dialect, OpBuilder& builder,
    tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
    int num_replicas, llvm::SmallVectorImpl<tf_executor::IslandOp>& replicas) {
  replicas.reserve(num_replicas);
  auto devices = replicate_op.devices();

  // Collect result types and operands.
  Operation& terminator = replicate_op.GetBody().back();
  llvm::SmallVector<Type, 8> output_types(terminator.getOperandTypes());
  auto control_type = tf_executor::ControlType::get(island_op.getContext());
  llvm::SmallVector<Value, 8> replica_inputs(island_op.controlInputs());

  // Replace replicate terminator with YieldOp.
  builder.setInsertionPoint(&terminator);
  builder.create<tf_executor::YieldOp>(terminator.getLoc(),
                                       terminator.getOperands());
  terminator.erase();

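  // Create one island per replica, cloning the replicate region into it and
  // remapping the region's block arguments to that replica's operands.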
  builder.setInsertionPoint(island_op);
  BlockAndValueMapping mapping;
  for (int i : llvm::seq<int>(0, num_replicas)) {
    // Create new island for replica.
    auto replica = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), output_types, control_type, replica_inputs);

    // Map block arg to replica arg.
    mapping.clear();
    for (auto& block_arg : replicate_op.GetBody().getArguments())
      mapping.map(block_arg,
                  replicate_op.GetReplicaOperandForBlockArgument(block_arg, i));

    // Copy over replicate region into replica island.
    replicate_op.body().cloneInto(&replica.body(), mapping);

    if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(),
                                               replica.body(),
                                               /*replica_id=*/i, devices)))
      return failure();

    replicas.push_back(replica);
  }

  return success();
}

// Creates islands per replica from the `tf_device.replicate` region and remaps
// the replicate results to the new island outputs. A single island is created
// to forward control dependencies if there is a control dependency output from
// the replicate island. For `tf_device.launch` ops, aliased devices are
// remapped to explicit devices.
//
// For example, the following:
//
// %0:2 = tf_executor.island(%control) {
//   %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
//              {n = 2 : i32,
//               devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
//                          DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
//     %a = "tf_device.launch"() ( {
//       %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %2 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
//     %b = "tf_device.launch"() ( {
//       %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %3 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
//     tf_device.return %a, %b : tensor<i1>, tensor<i1>
//   }
//   tf_executor.yield %1#0 : tensor<i1>
// }
//
// gets lowered to:
//
// %0:3 = tf_executor.island(%control) {
//   %a0 = "tf_device.launch"() ( {
//     %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %1 : tensor<i1>
//   }) {device = "/DEVICE:0"} : () -> tensor<i1>
//   %b0 = "tf_device.launch"() ( {
//     %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %2 : tensor<i1>
//   }) {device = "/DEVICE:2"} : () -> tensor<i1>
//   tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
// }
// %3:3 = tf_executor.island(%control) {
//   %a1 = "tf_device.launch"() ( {
//     %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %4 : tensor<i1>
//   }) {device = "/DEVICE:1"} : () -> tensor<i1>
//   %b1 = "tf_device.launch"() ( {
//     %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %5 : tensor<i1>
//   }) {device = "/DEVICE:3"} : () -> tensor<i1>
//   tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
// }
LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                         tf_executor::GraphOp graph_op,
                                         tf_executor::IslandOp island_op,
                                         tf_device::ReplicateOp replicate_op) {
  OpBuilder builder(island_op);
  const int num_replicas = replicate_op.n();

  // Create islands per replica.
  llvm::SmallVector<tf_executor::IslandOp, 8> replicas;
  if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, island_op,
                                         replicate_op, num_replicas, replicas)))
    return failure();

  // Collect all replica results.
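  // The replicate op's results are grouped by result index, so per-replica
  // result `o` of replica `r` maps to replicate result `num_replicas * o + r`.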
  llvm::SmallVector<Value, 8> replicas_outputs(replicate_op.getNumResults(),
                                               nullptr);
  for (auto replica_and_idx : llvm::enumerate(replicas))
    for (auto replica_result_and_idx :
         llvm::enumerate(replica_and_idx.value().outputs()))
      replicas_outputs[num_replicas * replica_result_and_idx.index() +
                       replica_and_idx.index()] =
          replica_result_and_idx.value();

  // Remap replicate results to per replica result.
  for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  // Add sink island to pin all replicas as a control dependency if there is a
  // control dependency leading from the replicate originally.
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> island_operands;
    for (auto& replica : replicas) island_operands.push_back(replica.control());

    builder.setInsertionPoint(island_op);
    auto island_sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        tf_executor::ControlType::get(island_op.getContext()), island_operands);
    island_sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&island_sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(island_sink.control());
  }

  // Replicas with no uses should be pinned to a graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> unused_replica_controls;
  for (auto& replica : replicas)
    if (replica.use_empty())
      unused_replica_controls.push_back(replica.control());

  if (!unused_replica_controls.empty()) {
    tf_executor::FetchOp fetch = graph_op.GetFetch();
    auto fetches = llvm::to_vector<8>(fetch.getOperands());
    fetches.append(unused_replica_controls.begin(),
                   unused_replica_controls.end());
    builder.setInsertionPoint(fetch);
    builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
    fetch.erase();
  }

  island_op.erase();
  return success();
}

void ReplicateToIslandPass::runOnFunction() {
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }

  // Find islands with a single `tf_device.replicate` and create individual
  // islands per replica of the replicate.
  llvm::SmallVector<tf_executor::IslandOp, 4> replicate_op_islands;
  getOperation().walk([&](tf_executor::GraphOp graph_op) {
    for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
      if (!island_op.WrapsSingleOp()) continue;

      if (isa<tf_device::ReplicateOp>(&island_op.GetBody().front()))
        replicate_op_islands.push_back(island_op);
    }
  });

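  // Lower the collected islands outside of the walk, since the rewrite erases
  // the original island ops.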
  for (tf_executor::IslandOp island_op : replicate_op_islands) {
    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
    auto replicate_op =
        cast<tf_device::ReplicateOp>(island_op.GetBody().front());
    if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
                                          replicate_op)))
      return signalPassFailure();
  }
}
}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass() {
  return std::make_unique<ReplicateToIslandPass>();
}

static PassRegistration<ReplicateToIslandPass> pass;

}  // namespace TFDevice
}  // namespace mlir