/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass forms one `tf_executor.island` op per replica from a single
// `tf_device.replicate` island.
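//
// Illustrative usage (not part of the original file): the pass is registered
// below under the flag `tf-replicate-to-island`, so it can typically be
// exercised with the TensorFlow MLIR opt tool, e.g.:
//
//   tf-opt -tf-replicate-to-island input.mlir
//
// or added to a C++ pass pipeline via `CreateReplicateToIslandPass()`.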

#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
constexpr char kDeviceOrdinalAttr[] = "device_ordinal";
constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0";

struct ReplicateToIslandPass
    : public PassWrapper<ReplicateToIslandPass, FunctionPass> {
  void runOnFunction() override;
};

// Returns whether op requires `_xla_replica_id` attribute.
bool RequiresReplicaIDAttribute(Operation* op) {
  return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
}

// Collects the TPU device ordinal for outside compilation communication ops.
// This currently assumes outside compilation only uses the
// `TPU_REPLICATED_CORE_0` aliased device for the device computation.
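//
// For example (illustrative values, not from the original source), given
//   devices = {TPU_REPLICATED_CORE_0 =
//                  ["/job:worker/replica:0/task:0/device:TPU:0",
//                   "/job:worker/replica:0/task:0/device:TPU:3"]}
// and replica_id = 1, the device string for replica 1 is ".../device:TPU:3",
// so this would return an ordinal of 3, assuming
// `GetDeviceOrdinalFromDeviceString` extracts the trailing device id.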
llvm::Optional<int64_t> GetDeviceOrdinal(
    const llvm::Optional<DictionaryAttr>& devices, Location loc,
    unsigned replica_id) {
  int64_t device_ordinal = 0;
  if (devices.hasValue()) {
    if (auto tpu_replica_0 = devices.getValue().get(kTPUCore0)) {
      llvm::StringRef tpu_device = tpu_replica_0.cast<ArrayAttr>()[replica_id]
                                       .cast<StringAttr>()
                                       .getValue();
      if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString(
              loc, tpu_device, &device_ordinal))) {
        return llvm::Optional<int64_t>(device_ordinal);
      }
    }
  }
  return llvm::None;
}

// Updates replica variant ops in a region based on replica `replica_id`.
// TODO(b/157624749): Replace this with better abstraction to differentiate ops
// for different replicas. Some ops, such as XlaHostCompute op or TPU Embedding
// ops, require replica id to be added as an op attribute to be used during
// execution. Handle such ops separately and add an integer attribute that
// represents replica id.
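//
// As an illustration (values not from the original source), within replica 1 a
// `tf._TPUDeviceOrdinalPlaceholder` op is replaced by a constant holding the
// ordinal resolved from TPU_REPLICATED_CORE_0:
//   %ordinal = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
// and TPU embedding enqueue ops get the attribute `_xla_replica_id = 1`.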
LogicalResult UpdateRegionReplicateVariantOps(
    OpBuilder& builder, Location loc, Region& region, int replica_id,
    const llvm::Optional<DictionaryAttr>& devices) {
  llvm::Optional<int64_t> device_ordinal =
      GetDeviceOrdinal(devices, loc, replica_id);

  auto result = region.walk([&](Operation* op) -> WalkResult {
    if (RequiresReplicaIDAttribute(op)) {
      op->setAttr(kReplicaIdAttr, builder.getI64IntegerAttr(replica_id));
      return WalkResult::advance();
    }

    if (isa<TF::_TPUDeviceOrdinalPlaceholderOp>(op)) {
      if (!device_ordinal.hasValue())
        return op->emitOpError()
               << "requires device ordinal from device " << kTPUCore0
               << " to be present in 'tf_device.replicate' op";

      OpBuilder builder(op);
      auto const_op = builder.create<TF::ConstOp>(
          op->getLoc(), DenseIntElementsAttr::get(
                            RankedTensorType::get({}, builder.getI64Type()),
                            {device_ordinal.getValue()}));
      op->replaceAllUsesWith(const_op);
      op->erase();
      return WalkResult::advance();
    }

    if (!devices.hasValue()) return WalkResult::advance();

    // Map aliased devices to explicit devices based on replica.
    if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
      if (auto device_by_replica = devices.getValue().get(launch.device()))
        launch->setAttr(
            kDeviceAttr,
            device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());

    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}

// Creates islands per replica from `tf_device.replicate` region. If for a
// `tf_device.launch` op the device is an aliased device of the
// `tf_device.replicate`, the device will be remapped to an explicit device
// for the associated replica island.
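//
// Each replica island receives the control inputs of the original island; the
// replicate body's `tf_device.return` terminator is replaced with a
// `tf_executor.yield` before the body is cloned into each replica island.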
LogicalResult ExpandReplicateIntoReplicas(
    const Dialect* tf_dialect, OpBuilder& builder,
    tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
    int num_replicas, llvm::SmallVectorImpl<tf_executor::IslandOp>& replicas) {
  replicas.reserve(num_replicas);
  auto devices = replicate_op.devices();

  // Collect result types and operands.
  Operation& terminator = replicate_op.GetBody().back();
  llvm::SmallVector<Type, 8> output_types(terminator.getOperandTypes());
  auto control_type = tf_executor::ControlType::get(island_op.getContext());
  llvm::SmallVector<Value, 8> replica_inputs(island_op.controlInputs());

  // Replace replicate terminator with YieldOp.
  builder.setInsertionPoint(&terminator);
  builder.create<tf_executor::YieldOp>(terminator.getLoc(),
                                       terminator.getOperands());
  terminator.erase();

  builder.setInsertionPoint(island_op);
  BlockAndValueMapping mapping;
  for (int i : llvm::seq<int>(0, num_replicas)) {
    // Create new island for replica.
    auto replica = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), output_types, control_type, replica_inputs);

    // Map block arg to replica arg.
    mapping.clear();
    for (auto& block_arg : replicate_op.GetBody().getArguments())
      mapping.map(block_arg,
                  replicate_op.GetReplicaOperandForBlockArgument(block_arg, i));

    // Copy over replicate region into replica island.
    replicate_op.body().cloneInto(&replica.body(), mapping);

    if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(),
                                               replica.body(),
                                               /*replica_id=*/i, devices)))
      return failure();

    replicas.push_back(replica);
  }

  return success();
}

// Creates islands per replica from the `tf_device.replicate` region and remaps
// replicate results to the new island outputs. A single island is created to
// forward control dependencies if there is a control dependency output from
// the replicate island. For `tf_device.launch` ops, devices are remapped from
// aliased devices to explicit devices.
//
// For example, the following:
//
// %0:2 = tf_executor.island(%control) {
//   %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
//              {n = 2 : i32,
//               devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
//                          DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
//     %a = "tf_device.launch"() ( {
//       %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %2 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
//     %b = "tf_device.launch"() ( {
//       %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %3 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
//     tf_device.return %a, %b : tensor<i1>, tensor<i1>
//   }
//   tf_executor.yield %1#0 : tensor<i1>
// }
//
// gets lowered to:
//
// %0:3 = tf_executor.island(%control) {
//   %a0 = "tf_device.launch"() ( {
//     %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %1 : tensor<i1>
//   }) {device = "/DEVICE:0"} : () -> tensor<i1>
//   %b0 = "tf_device.launch"() ( {
//     %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %2 : tensor<i1>
//   }) {device = "/DEVICE:2"} : () -> tensor<i1>
//   tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
// }
// %3:3 = tf_executor.island(%control) {
//   %a1 = "tf_device.launch"() ( {
//     %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %4 : tensor<i1>
//   }) {device = "/DEVICE:1"} : () -> tensor<i1>
//   %b1 = "tf_device.launch"() ( {
//     %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %5 : tensor<i1>
//   }) {device = "/DEVICE:3"} : () -> tensor<i1>
//   tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
// }
LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                         tf_executor::GraphOp graph_op,
                                         tf_executor::IslandOp island_op,
                                         tf_device::ReplicateOp replicate_op) {
  OpBuilder builder(island_op);
  const int num_replicas = replicate_op.n();

  // Create islands per replica.
  llvm::SmallVector<tf_executor::IslandOp, 8> replicas;
  if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, island_op,
                                         replicate_op, num_replicas, replicas)))
    return failure();

  // Collect all replica results.
  llvm::SmallVector<Value, 8> replicas_outputs(replicate_op.getNumResults(),
                                               nullptr);
  for (auto replica_and_idx : llvm::enumerate(replicas))
    for (auto replica_result_and_idx :
         llvm::enumerate(replica_and_idx.value().outputs()))
      replicas_outputs[num_replicas * replica_result_and_idx.index() +
                       replica_and_idx.index()] =
          replica_result_and_idx.value();

  // Remap replicate results to per-replica results.
  for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  // Add sink island to pin all replicas as a control dependency if there is a
  // control dependency leading from the replicate originally.
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> island_operands;
    for (auto& replica : replicas) island_operands.push_back(replica.control());

    builder.setInsertionPoint(island_op);
    auto island_sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        tf_executor::ControlType::get(island_op.getContext()), island_operands);
    island_sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&island_sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(island_sink.control());
  }

  // Replicas with no uses should be pinned to a graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> unused_replica_controls;
  for (auto& replica : replicas)
    if (replica.use_empty())
      unused_replica_controls.push_back(replica.control());

  if (!unused_replica_controls.empty()) {
    tf_executor::FetchOp fetch = graph_op.GetFetch();
    auto fetches = llvm::to_vector<8>(fetch.getOperands());
    fetches.append(unused_replica_controls.begin(),
                   unused_replica_controls.end());
    builder.setInsertionPoint(fetch);
    builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
    fetch.erase();
  }

  island_op.erase();
  return success();
}

void ReplicateToIslandPass::runOnFunction() {
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }

  // Find islands with a single `tf_device.replicate` and create individual
  // islands per replica of the replicate.
  llvm::SmallVector<tf_executor::IslandOp, 4> replicate_op_islands;
  getOperation().walk([&](tf_executor::GraphOp graph_op) {
    for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
      if (!island_op.WrapsSingleOp()) continue;

      if (isa<tf_device::ReplicateOp>(&island_op.GetBody().front()))
        replicate_op_islands.push_back(island_op);
    }
  });

  for (tf_executor::IslandOp island_op : replicate_op_islands) {
    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
    auto replicate_op =
        cast<tf_device::ReplicateOp>(island_op.GetBody().front());
    if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
                                          replicate_op)))
      return signalPassFailure();
  }
}
}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass() {
  return std::make_unique<ReplicateToIslandPass>();
}

static PassRegistration<ReplicateToIslandPass> pass(
    "tf-replicate-to-island", "Lowers device replicate to executor islands");
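
// Illustrative example (not part of the original file): the pass can also be
// added to a pass pipeline with the factory above, e.g.
//
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<FuncOp>(CreateReplicateToIslandPass());
//   (void)pm.run(module);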

}  // namespace TFDevice
}  // namespace mlir