/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";

struct TPUExtractOutsideCompilation
    : public TF::TPUExtractOutsideCompilationPassBase<
          TPUExtractOutsideCompilation> {
  void runOnOperation() override;
};

// Returns whether `op` or ops nested in `op` are outside compiled.
bool HasOutsideCompilationNested(Operation* op) {
  return op
      ->walk([&](Operation* walked_op) {
        if (op == walked_op) return WalkResult::advance();
        if (walked_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

// Returns whether `op` or any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestor(Operation* op) {
  while (op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      return true;
    }
    op = op->getParentOp();
  }
  return false;
}

// Returns whether any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestorExclusive(Operation* op) {
  Operation* parent_op = op->getParentOp();
  if (!parent_op) return false;
  return HasOutsideCompilationAncestor(parent_op);
}

// Creates a tf._XlaSendFromHost op, or a tf._XlaSendFromHostV2 op if a device
// ordinal is provided.
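// For example, with a device ordinal present the created op is roughly (a
// schematic sketch; operand types and exact attribute syntax are elided):
//
//   "tf._XlaSendFromHostV2"(%inputs, %program_key, %device_ordinal)
//       {key = "host_compute_channel_0_retvals"}
//
// and without a device ordinal:
//
//   "tf._XlaSendFromHost"(%inputs, %program_key)
//       {key = "host_compute_channel_0_retvals", device_ordinal = 0}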
Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc,
                                ValueRange inputs, Value compilation_key,
                                Value device_ordinal,
                                llvm::StringRef communication_key) {
  if (device_ordinal)
    return builder.create<TF::_XlaSendFromHostV2Op>(
        loc, inputs,
        /*dynamic_key=*/compilation_key, device_ordinal,
        builder.getStringAttr(communication_key));

  return builder.create<TF::_XlaSendFromHostOp>(
      loc, inputs,
      /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key),
      /*device_ordinal=*/builder.getI64IntegerAttr(0));
}

// Creates a tf._XlaRecvAtHost op, or a tf._XlaRecvAtHostV2 op if a device
// ordinal is provided.
Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc,
                              TypeRange output_types, Value compilation_key,
                              Value device_ordinal,
                              llvm::StringRef communication_key) {
  if (device_ordinal)
    return builder.create<TF::_XlaRecvAtHostV2Op>(
        loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal,
        builder.getStringAttr(communication_key));

  return builder.create<TF::_XlaRecvAtHostOp>(
      loc, output_types, /*dynamic_key=*/compilation_key,
      builder.getStringAttr(communication_key),
      /*device_ordinal=*/builder.getI64IntegerAttr(0));
}

// Clones `if_region`, with its predicate and attributes, into a new IfRegionOp
// whose then/else regions each hold a single empty block terminated by a
// yield op.
TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
                                         OpBuilder& builder) {
  auto host_side_if = builder.create<TF::IfRegionOp>(
      if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
      if_region.is_stateless(), if_region._then_func_nameAttr(),
      if_region._else_func_nameAttr());

  // Create empty then branch region.
  auto& then_branch = host_side_if.then_branch();
  then_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});

  // Create empty else branch region.
  auto& else_branch = host_side_if.else_branch();
  else_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});
  return host_side_if;
}

// Creates a WhileRegionOp whose body region holds a single empty block
// terminated by a yield op; the cond region is created empty, to be populated
// by the caller.
TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
                                  uint64_t parallel_iterations, Location loc,
                                  OpBuilder& builder) {
  auto host_side_while = builder.create<TF::WhileRegionOp>(
      loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
      parallel_iterations, is_stateless, /*shape_invariant=*/false);

  // Create empty body region and terminate it with a yield.
  auto& body = host_side_while.body();
  body.push_back(new Block);
  builder.setInsertionPointToEnd(&body.front());
  builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
  return host_side_while;
}

// TODO(b/157054714): Use a better abstraction instead of
// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
// Creates a compilation key as a placeholder. A placeholder compilation cache
// key is created because it is a required input to _XlaRecvAtHost and
// _XlaSendFromHost, but the _TPUCompileMlir op has not yet been created for
// the TPU cluster that contains the outside compiled ops. This placeholder
// should be replaced by the TPU cluster's _TPUCompileMlir op in a subsequent
// pass.
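// The placeholder produced is roughly (a schematic sketch; the exact string
// type spelling may differ):
//
//   %program_key = "tf._TPUCompileMlirPlaceholderProgramKey"()
//       : () -> tensor<3x!tf.string>
//
// where a later pass rewrites %program_key to the program key produced by the
// cluster's _TPUCompileMlir op.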
TF::_TPUCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder(
    Location loc, OpBuilder& builder) {
  auto result_type =
      RankedTensorType::get({3}, builder.getType<TF::StringType>());
  return builder.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
      loc, /*program=*/result_type, llvm::ArrayRef<Value>{});
}

// Creates a `tf_device.launch` to wrap cluster ops.
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
    OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device) {
  // The launch op is assigned `host_device` so that the wrapped outside
  // compiled ops are placed on the host associated with the TPU cluster.
  auto launch_op = builder.create<tf_device::LaunchOp>(
      loc_op->getLoc(), builder.getStringAttr(host_device),
      /*result_types=*/ArrayRef<Type>{});

  launch_op.body().push_back(new Block);
  builder.setInsertionPointToEnd(&launch_op.GetBody());
  builder.create<tf_device::ReturnOp>(loc_op->getLoc(),
                                      llvm::ArrayRef<Value>{});

  return launch_op;
}

// Returns the external values used by `op` or its nested ops: values defined
// above `op` but inside `tpu_cluster` (and not already outside compiled or
// produced by a receive op). These must be sent from device to host.
llvm::SmallSetVector<Value, 4> GetExternalOperands(
    tf_device::ClusterOp tpu_cluster, Operation* op) {
  llvm::SmallSetVector<Value, 4> external_values;
  op->walk([&](Operation* walked_op) {
    if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(walked_op))
      return WalkResult::advance();
    for (Value v : walked_op->getOperands()) {
      if (auto* defining_op = v.getDefiningOp()) {
        if (!op->isAncestor(defining_op) &&
            tpu_cluster->isAncestor(defining_op) &&
            !HasOutsideCompilationAncestor(defining_op) &&
            !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
          external_values.insert(v);
        }
        continue;
      }
      auto block_arg = v.cast<BlockArgument>();
      if (block_arg.getParentRegion() == op->getParentRegion())
        external_values.insert(v);
    }
    return WalkResult::advance();
  });
  return external_values;
}

// Returns the results of `op` that are consumed by ops which are not outside
// compiled. These must be sent from host back to device.
llvm::SmallSetVector<Value, 4> GetExternalOutputs(Operation* op) {
  llvm::SmallSetVector<Value, 4> external_outputs;
  for (Operation* user : op->getUsers()) {
    if (!HasOutsideCompilationAncestor(user)) {
      for (Value v : user->getOperands()) {
        if (v.getDefiningOp() == op) external_outputs.insert(v);
      }
    }
  }
  return external_outputs;
}

// Creates a device side `tf._XlaHostComputeMlir` op with `inputs` and
// `outputs`, using `args_communication_key` and `retvals_communication_key`
// for the transfer channels.
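// The created op is roughly (a schematic sketch; types and exact attribute
// spellings are elided):
//
//   %outputs = "tf._XlaHostComputeMlir"(%inputs)
//       {send_key = "host_compute_channel_0_args",
//        recv_key = "host_compute_channel_0_retvals", tpu_core = 0}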
TF::_XlaHostComputeMlirOp CreateHostCompute(
    OpBuilder& builder, Operation* loc_op,
    const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
    llvm::StringRef args_communication_key,
    llvm::StringRef retvals_communication_key) {
  llvm::SmallVector<Type, 4> device_output_types;
  for (const auto& output : outputs)
    device_output_types.push_back(output.getType());
  auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
      loc_op->getLoc(), device_output_types, inputs.getArrayRef(),
      builder.getStringAttr(args_communication_key),
      builder.getStringAttr(retvals_communication_key),
      /*tpu_core=*/builder.getI64IntegerAttr(0));
  return host_compute;
}

// Marks `op` as outside compiled by setting its `_xla_outside_compilation`
// attribute to a placeholder value.
void MarkOutsideCompiled(Operation* op) {
  op->setAttr(kXlaOutsideCompilationAttr,
              StringAttr::get(op->getContext(), "temp"));
}

// Moves outside compiled ops in `src` to `insertion_point` in the host
// computation (the ops may temporarily remain in `tpu_cluster` and be moved in
// a subsequent call to this method). Communication ops are added in both `src`
// and at `insertion_point` using `compilation_key`, `device_ordinal` and
// `communication_key_index`, which is incremented each time it is used.
// Communication ops are added only when needed and at the locations needed;
// checks ensure that no duplicate communication between device and host is
// added.
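// As a schematic sketch (types and most attributes elided; tf.OpB is a
// stand-in op name), an outside compiled device op such as
//
//   %b = "tf.OpB"(%a) {_xla_outside_compilation = "cluster1"}
//
// is replaced on the device side by
//
//   %b = "tf._XlaHostComputeMlir"(%a) {...}
//
// while on the host side the moved op is bracketed by communication ops:
//
//   %a_host = "tf._XlaRecvAtHostV2"(%program_key, %device_ordinal)
//       {key = "host_compute_channel_0_args"}
//   %b_host = "tf.OpB"(%a_host)
//   "tf._XlaSendFromHostV2"(%b_host, %program_key, %device_ordinal)
//       {key = "host_compute_channel_0_retvals"}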
void MoveOpsToHost(tf_device::ClusterOp tpu_cluster, Block* src,
                   Operation* insertion_point, Value compilation_key,
                   Value device_ordinal, int& communication_key_index) {
  OpBuilder builder(insertion_point);
  for (Operation& op : llvm::make_early_inc_range(*src)) {
    if (HasOutsideCompilationAncestorExclusive(&op) ||
        !op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      continue;
    // Get the operands and outputs that need to be communicated between host
    // and device. External operands are from device -> host and external
    // outputs are from host -> device.
    auto external_operands = GetExternalOperands(tpu_cluster, &op);
    auto external_outputs = GetExternalOutputs(&op);
    builder.setInsertionPoint(&op);
    std::string args_communication_key =
        llvm::formatv("host_compute_channel_{0}_args",
                      (communication_key_index))
            .str();
    if (llvm::isa<TF::IfRegionOp>(op) && external_operands.size() == 1) {
      args_communication_key =
          llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
              .str();
    }
    std::string retvals_communication_key =
        llvm::formatv("host_compute_channel_{0}_retvals",
                      (communication_key_index))
            .str();
    auto host_compute = CreateHostCompute(
        builder, &op, external_operands, external_outputs.getArrayRef(),
        args_communication_key, retvals_communication_key);
    // Insert ops on the host side computation to receive data from device.
    builder.setInsertionPoint(insertion_point);
    llvm::SmallVector<Type, 4> host_operand_types;
    for (const auto& operand : external_operands)
      host_operand_types.push_back(operand.getType());

    auto recv_at_host = CreateRecvAtHostOp(
        builder, op.getLoc(), host_operand_types, compilation_key,
        device_ordinal, args_communication_key);
    auto original_op_block = op.getBlock();
    op.moveAfter(recv_at_host);
    op.removeAttr(Identifier::get(kDeviceAttr, op.getContext()));
    if (!external_outputs.empty()) {
      CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
                           compilation_key, device_ordinal,
                           retvals_communication_key);
    }
    // Replace operand usages if op is in the same region as insertion or if
    // the op is outside compiled and will be moved to host later.
    auto replace_operand_usage = [&](OpOperand& operand) {
      return insertion_point->getParentRegion()->isAncestor(
                 operand.getOwner()->getParentRegion()) ||
             (HasOutsideCompilationAncestor(operand.getOwner()) &&
              original_op_block == operand.getOwner()->getBlock());
    };
    if (external_operands.empty()) {
      recv_at_host->erase();
    } else {
      for (auto result :
           llvm::zip(external_operands, recv_at_host->getResults())) {
        Value external_operand = std::get<0>(result);
        external_operand.replaceUsesWithIf(std::get<1>(result),
                                           replace_operand_usage);
      }
    }
    // Don't replace output usages in host computation or for outside
    // compiled ops.
    auto replace_output_usage = [&](OpOperand& operand) {
      return !op.getParentRegion()->isAncestor(
                 operand.getOwner()->getParentRegion()) &&
             !HasOutsideCompilationAncestor(operand.getOwner());
    };
    for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
      Value external_output = std::get<0>(result);
      external_output.replaceUsesWithIf(std::get<1>(result),
                                        replace_output_usage);
    }
    if (external_operands.empty() && external_outputs.empty()) {
      host_compute.erase();
    } else {
      ++communication_key_index;
    }
  }
}

// Decomposes control flow in `tpu_cluster` into separate device and host
// (outside compiled) control flow ops, with communication between device and
// host for data dependencies. Both device and host control flow initially
// remain within `tpu_cluster`; a subsequent call to MoveOpsToHost moves the
// host side control flow to the host launch in tf_device.parallel_execute.
// Uses `compilation_key`, `device_ordinal` and `communication_key_index` when
// creating communication ops.
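// As a schematic sketch, a device side
//
//   "tf.IfRegion"(%pred) ({ ...outside compiled ops... }, { ... })
//
// gains a host side IfRegion clone with empty branches; the outside compiled
// branch bodies are then moved into the clone by MoveOpsToHost, and the
// predicate reaches the host through a communication channel (e.g. a key such
// as "if_predicate_channel_0"). While ops are handled similarly, with the
// loop condition forwarded via XlaSendToHost/_XlaRecvAtHostV2 using a
// "while_condition_channel_{N}" key.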
void DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,
                          Value compilation_key, Value device_ordinal,
                          int& communication_key_index) {
  tpu_cluster.GetBody().walk([&](Operation* op) {
    if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return;
      OpBuilder builder(if_op);
      auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
      MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
                    host_if.then_branch().front().getTerminator(),
                    compilation_key, device_ordinal, communication_key_index);
      MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(),
                    host_if.else_branch().front().getTerminator(),
                    compilation_key, device_ordinal, communication_key_index);
      MarkOutsideCompiled(host_if.getOperation());
    }
    if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return;
      OpBuilder builder(while_op);
      auto host_while = CloneEmptyWhile(while_op.is_stateless(),
                                        while_op.parallel_iterations(),
                                        while_op.getLoc(), builder);
      const auto condition_send_recv_key =
          llvm::formatv("while_condition_channel_{0}",
                        communication_key_index++)
              .str();
      auto& cond = host_while.cond();
      cond.push_back(new Block);
      auto condition = while_op.cond().front().getTerminator()->getOperand(0);
      builder.setInsertionPoint(while_op.cond().front().getTerminator());
      builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition,
                                          condition_send_recv_key);
      builder.setInsertionPointToEnd(&cond.front());
      auto recv_condition_at_host = CreateRecvAtHostOp(
          builder, while_op.getLoc(), TypeRange{condition.getType()},
          compilation_key, device_ordinal, condition_send_recv_key);
      builder.create<TF::YieldOp>(while_op.getLoc(),
                                  recv_condition_at_host->getResults());

      MoveOpsToHost(tpu_cluster, &while_op.cond().front(),
                    recv_condition_at_host, compilation_key, device_ordinal,
                    communication_key_index);
      MoveOpsToHost(tpu_cluster, &while_op.body().front(),
                    host_while.body().front().getTerminator(), compilation_key,
                    device_ordinal, communication_key_index);
      MarkOutsideCompiled(host_while.getOperation());
    }
  });
}

// Removes outside compilation from all ops inside `host_launch_op`. Should
// only be run after all outside compiled ops have been moved to
// `host_launch_op`.
void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) {
  host_launch_op.GetBody().walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      op->removeAttr(
          Identifier::get(kXlaOutsideCompilationAttr, op->getContext()));
    }
  });
}

// Creates a `parallel_execute` op with a region for host computation and
// a region for `tpu_cluster` computation by extracting outside compiled ops to
// host computation.
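// As a schematic sketch, the result is roughly:
//
//   "tf_device.parallel_execute"() ({
//     // Host region.
//     "tf_device.launch"() ({
//       %program_key = "tf._TPUCompileMlirPlaceholderProgramKey"()
//       ...host side communication ops and moved outside compiled ops...
//       tf_device.return
//     }) {device = "<host_device>"}
//     tf_device.return
//   }, {
//     // Device region.
//     %results = "tf_device.cluster"() ({ ...device ops... })
//     tf_device.return %results
//   })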
void CreateParallelExecuteForOutsideCompilation(
    ModuleOp module, tf_device::ClusterOp tpu_cluster,
    llvm::StringRef host_device) {
  OpBuilder builder(tpu_cluster);
  // Create parallel_execute regions, one for the host computation for outside
  // compilation and the second for the original TPU cluster computation.
  const int num_regions = 2;
  auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
      tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
  Block& host_computation_block =
      parallel_execute_op.GetRegionBlockWithIndex(0);
  builder.setInsertionPointToEnd(&host_computation_block);

  // Create a single launch op for all outside compiled ops.
  tf_device::LaunchOp host_launch_op =
      CreateLaunchOpForOutsideCluster(builder, tpu_cluster, host_device);
  builder.setInsertionPoint(host_launch_op.GetBody().getTerminator());
  auto compilation_key_op =
      CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder);
  Value compilation_key = compilation_key_op.program();
  auto device_ordinal_op = builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
      tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type()));
  Value device_ordinal = nullptr;
  if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
    device_ordinal = device_ordinal_op.device_ordinal();
  }

  int communication_key_index = 0;
  // Decompose control flow into device and host control flow when outside
  // compilation is included.
  DecomposeControlFlow(tpu_cluster, compilation_key, device_ordinal,
                       communication_key_index);

  // Move all outside compiled ops including control flow to host launch.
  MoveOpsToHost(tpu_cluster, &tpu_cluster.GetBody(),
                host_launch_op.GetBody().getTerminator(), compilation_key,
                device_ordinal, communication_key_index);

  if (communication_key_index == 0) compilation_key_op.erase();
  if (communication_key_index == 0 || device_ordinal == nullptr)
    device_ordinal_op.erase();

  RemoveOutsideCompilation(host_launch_op);

  builder.setInsertionPointToEnd(&host_computation_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), ArrayRef<Value>{});

  // Move the TPU cluster into the last parallel_execute block, before its
  // newly created terminator.
  Block& parallel_execute_tpu_block =
      parallel_execute_op.GetRegionBlockWithIndex(1);
  builder.setInsertionPointToEnd(&parallel_execute_tpu_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
                                      tpu_cluster.getResults());
  tpu_cluster.getOperation()->moveBefore(
      parallel_execute_tpu_block.getTerminator());

  // Remap cluster results with parallel_execute results if user is outside of
  // parallel_execute.
  for (auto result :
       llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
    Value tpu_cluster_result = std::get<0>(result);
    Value parallel_execute_result = std::get<1>(result);
    for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
      if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
        use.set(parallel_execute_result);
  }
}

void TPUExtractOutsideCompilation::runOnOperation() {
  // Get runtime device information from the closest parent module.
  auto module = getOperation();
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
    return signalPassFailure();

  module.walk([&](tf_device::ClusterOp tpu_cluster) {
    if (HasOutsideCompilationNested(tpu_cluster.getOperation())) {
      std::string host_device;
      (void)tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
                                                        &host_device);
      CreateParallelExecuteForOutsideCompilation(module, tpu_cluster,
                                                 host_device);
    }
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() {
  return std::make_unique<TPUExtractOutsideCompilation>();
}

}  // namespace TFTPU
}  // namespace mlir