• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <utility>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
29 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
32 #include "mlir/IR/TypeRange.h"  // from @llvm-project
33 #include "mlir/IR/Visitors.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
36 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
41 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
44 
45 namespace mlir {
46 namespace TFTPU {
47 
48 namespace {
49 
// Attribute naming the device an op is assigned to.
constexpr char kDeviceAttr[] = "device";
// Marker attribute: ops carrying this attribute are outside compiled, i.e.
// they must run on the host instead of the TPU.
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
52 
// Pass that extracts ops marked for outside compilation out of a TPU cluster
// into a host-side computation inside a tf_device.parallel_execute.
struct TPUExtractOutsideCompilation
    : public TF::TPUExtractOutsideCompilationPassBase<
          TPUExtractOutsideCompilation> {
  void runOnOperation() override;
};
58 
59 // Returns whether `op` or ops nested in `op` are outside compiled.
HasOutsideCompilationNested(Operation * op)60 bool HasOutsideCompilationNested(Operation* op) {
61   return op
62       ->walk([&](Operation* walked_op) {
63         if (op == walked_op) return WalkResult::advance();
64         if (walked_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
65           return WalkResult::interrupt();
66         }
67         return WalkResult::advance();
68       })
69       .wasInterrupted();
70 }
71 
72 // Returns whether `op` or any ancestors of `op` are outside compiled.
HasOutsideCompilationAncestor(Operation * op)73 bool HasOutsideCompilationAncestor(Operation* op) {
74   while (op) {
75     if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
76       return true;
77     }
78     op = op->getParentOp();
79   }
80   return false;
81 }
82 
83 // Returns whether any ancestors of `op` are outside compiled.
HasOutsideCompilationAncestorExclusive(Operation * op)84 bool HasOutsideCompilationAncestorExclusive(Operation* op) {
85   Operation* parent_op = op->getParentOp();
86   if (!parent_op) return false;
87   return HasOutsideCompilationAncestor(parent_op);
88 }
89 
90 // Creates a tf._XlaSendFromHost or tf._XlaSendFromHostV2 op. If device ordinal
91 // is present, a tf._XlaSendFromHostV2 op is created instead.
CreateSendFromHostOp(OpBuilder & builder,Location loc,ValueRange inputs,Value compilation_key,Value device_ordinal,llvm::StringRef communication_key)92 Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc,
93                                 ValueRange inputs, Value compilation_key,
94                                 Value device_ordinal,
95                                 llvm::StringRef communication_key) {
96   if (device_ordinal)
97     return builder.create<TF::_XlaSendFromHostV2Op>(
98         loc, inputs,
99         /*dynamic_key=*/compilation_key, device_ordinal,
100         builder.getStringAttr(communication_key));
101 
102   return builder.create<TF::_XlaSendFromHostOp>(
103       loc, inputs,
104       /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key),
105       /*device_ordinal=*/builder.getI64IntegerAttr(0));
106 }
107 
108 // Creates a tf._XlaRecvAtHost or tf._XlaRecvAtHostV2 op. If device ordinal is
109 // present, a tf._XlaRecvAtHostV2 op is created instead.
CreateRecvAtHostOp(OpBuilder & builder,Location loc,TypeRange output_types,Value compilation_key,Value device_ordinal,llvm::StringRef communication_key)110 Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc,
111                               TypeRange output_types, Value compilation_key,
112                               Value device_ordinal,
113                               llvm::StringRef communication_key) {
114   if (device_ordinal)
115     return builder.create<TF::_XlaRecvAtHostV2Op>(
116         loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal,
117         builder.getStringAttr(communication_key));
118 
119   return builder.create<TF::_XlaRecvAtHostOp>(
120       loc, output_types, /*dynamic_key=*/compilation_key,
121       builder.getStringAttr(communication_key),
122       /*device_ordinal=*/builder.getI64IntegerAttr(0));
123 }
124 
125 // Clones an IfRegionOp 'if_region' and attributes and creates then/else regions
126 // with yield op and an empty block.
CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,OpBuilder & builder)127 TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
128                                          OpBuilder& builder) {
129   auto host_side_if = builder.create<TF::IfRegionOp>(
130       if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
131       if_region.is_stateless(), if_region._then_func_nameAttr(),
132       if_region._else_func_nameAttr());
133 
134   // Create empty then branch region.
135   auto& then_branch = host_side_if.then_branch();
136   then_branch.push_back(new Block);
137   builder.setInsertionPointToEnd(&then_branch.front());
138   builder.create<TF::YieldOp>(if_region.getLoc(),
139                               /*operands=*/ArrayRef<Value>{});
140 
141   // Create empty else branch region.
142   auto& else_branch = host_side_if.else_branch();
143   else_branch.push_back(new Block);
144   builder.setInsertionPointToEnd(&else_branch.front());
145   builder.create<TF::YieldOp>(if_region.getLoc(),
146                               /*operands=*/ArrayRef<Value>{});
147   return host_side_if;
148 }
149 // Creates a WhileRegionOp cond and body regions with yield op and
150 // an empty body.
CloneEmptyWhile(bool is_stateless,uint64_t parallel_iterations,Location loc,OpBuilder & builder)151 TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
152                                   uint64_t parallel_iterations, Location loc,
153                                   OpBuilder& builder) {
154   auto host_side_while = builder.create<TF::WhileRegionOp>(
155       loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
156       parallel_iterations, is_stateless, /*shape_invariant=*/false);
157 
158   // Create empty else branch region.
159   auto& body = host_side_while.body();
160   body.push_back(new Block);
161   builder.setInsertionPointToEnd(&body.front());
162   builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
163   return host_side_while;
164 }
165 
166 // TODO(b/157054714): Use a better abstraction instead of
167 // _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
168 // Creates a compilation key as placeholder. A placeholder compilation cache key
169 // is created because it is a required input to _XlaRecvAtHost and
170 // _XlaSendFromHost but the _TPUCompileMlir has not yet been created for the TPU
171 // cluster that contains the outside compiled ops. This placeholder should be
172 // replaced by the TPU cluster _TPUCompileMlir in a subsequent pass.
CreateCompilationKeyPlaceholder(Location loc,OpBuilder & builder)173 TF::_TPUCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder(
174     Location loc, OpBuilder& builder) {
175   auto result_type =
176       RankedTensorType::get({3}, builder.getType<TF::StringType>());
177   return builder.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
178       loc, /*program=*/result_type, llvm::ArrayRef<Value>{});
179 }
180 
181 // Creates a `tf_device.launch` to wrap cluster ops.
CreateLaunchOpForOutsideCluster(OpBuilder & builder,Operation * loc_op,llvm::StringRef host_device)182 tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
183     OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device) {
184   // An empty string placeholder is used for the device as that will be later
185   // populated with the device of the associated TPUReplicateMetadata op.
186   auto launch_op = builder.create<tf_device::LaunchOp>(
187       loc_op->getLoc(), builder.getStringAttr(host_device),
188       /*result_types=*/ArrayRef<Type>{});
189 
190   launch_op.body().push_back(new Block);
191   builder.setInsertionPointToEnd(&launch_op.GetBody());
192   builder.create<tf_device::ReturnOp>(loc_op->getLoc(),
193                                       llvm::ArrayRef<Value>{});
194 
195   return launch_op;
196 }
197 
GetExternalOperands(tf_device::ClusterOp tpu_cluster,Operation * op)198 llvm::SmallSetVector<Value, 4> GetExternalOperands(
199     tf_device::ClusterOp tpu_cluster, Operation* op) {
200   llvm::SmallSetVector<Value, 4> external_values;
201   op->walk([&](Operation* walked_op) {
202     if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(walked_op))
203       return WalkResult::advance();
204     for (Value v : walked_op->getOperands()) {
205       if (auto* defining_op = v.getDefiningOp()) {
206         if (!op->isAncestor(defining_op) &&
207             tpu_cluster->isAncestor(defining_op) &&
208             !HasOutsideCompilationAncestor(defining_op) &&
209             !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
210           external_values.insert(v);
211         }
212         continue;
213       }
214       auto block_arg = v.cast<BlockArgument>();
215       if (block_arg.getParentRegion() == op->getParentRegion())
216         external_values.insert(v);
217     }
218     return WalkResult::advance();
219   });
220   return external_values;
221 }
222 
GetExternalOutputs(Operation * op)223 llvm::SmallSetVector<Value, 4> GetExternalOutputs(Operation* op) {
224   llvm::SmallSetVector<Value, 4> external_outputs;
225   for (Operation* user : op->getUsers()) {
226     if (!HasOutsideCompilationAncestor(user)) {
227       for (Value v : user->getOperands()) {
228         if (v.getDefiningOp() == op) external_outputs.insert(v);
229       }
230     }
231   }
232   return external_outputs;
233 }
234 
235 // Creates the HostCompute with `inputs` and `outputs`
236 // using `communication_key`.
CreateHostCompute(OpBuilder & builder,Operation * loc_op,const llvm::SmallSetVector<Value,4> & inputs,llvm::ArrayRef<Value> outputs,llvm::StringRef args_communication_key,llvm::StringRef retvals_communication_key)237 TF::_XlaHostComputeMlirOp CreateHostCompute(
238     OpBuilder& builder, Operation* loc_op,
239     const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
240     llvm::StringRef args_communication_key,
241     llvm::StringRef retvals_communication_key) {
242   llvm::SmallVector<Type, 4> device_output_types;
243   for (const auto& output : outputs)
244     device_output_types.push_back(output.getType());
245   auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
246       loc_op->getLoc(), device_output_types, inputs.getArrayRef(),
247       builder.getStringAttr(args_communication_key),
248       builder.getStringAttr(retvals_communication_key),
249       /*tpu_core=*/builder.getI64IntegerAttr(0));
250   return host_compute;
251 }
252 
MarkOutsideCompiled(Operation * op)253 void MarkOutsideCompiled(Operation* op) {
254   op->setAttr(kXlaOutsideCompilationAttr,
255               StringAttr::get(op->getContext(), "temp"));
256 }
257 
// Move outside compiled ops in `src` to `insertion_point` in host
// computation (may be temporarily with `tpu_cluster` but moved in subsequent
// call to this method).  Communication ops are added in both `src` and at
// `insertion_point` using `compilation_key`, `device_ordinal` and
// `communication_key_index` which is incremented when used. Communication ops
// are added only when needed and at the location needed.  There are checks to
// ensure that duplicate communication between device and host is not added.
void MoveOpsToHost(tf_device::ClusterOp tpu_cluster, Block* src,
                   Operation* insertion_point, Value compilation_key,
                   Value device_ordinal, int& communication_key_index) {
  OpBuilder builder(insertion_point);
  // early_inc_range: ops may be moved out of `src` while iterating.
  for (Operation& op : llvm::make_early_inc_range(*src)) {
    // Only handle top-level outside compiled ops of this block; ops nested
    // under an outside compiled ancestor move together with that ancestor.
    if (HasOutsideCompilationAncestorExclusive(&op) ||
        !op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      continue;
    // Get the operands and outputs that need to be communicated between host
    // and device.  External operands are from device -> host and external
    // outputs are from host -> device.
    auto external_operands = GetExternalOperands(tpu_cluster, &op);
    auto external_outputs = GetExternalOutputs(&op);
    builder.setInsertionPoint(&op);
    // Channel names are derived from the running index so device-side
    // host-compute ops pair up with the host-side recv/send ops.
    std::string args_communication_key =
        llvm::formatv("host_compute_channel_{0}_args",
                      (communication_key_index))
            .str();
    // A host IfRegion with exactly one external operand is only receiving its
    // predicate; use a dedicated channel name for that case.
    if (llvm::isa<TF::IfRegionOp>(op) && external_operands.size() == 1) {
      args_communication_key =
          llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
              .str();
    }
    std::string retvals_communication_key =
        llvm::formatv("host_compute_channel_{0}_retvals",
                      (communication_key_index))
            .str();
    // Device side: placeholder compute op at the op's original position.
    auto host_compute = CreateHostCompute(
        builder, &op, external_operands, external_outputs.getArrayRef(),
        args_communication_key, retvals_communication_key);
    // Insert ops on the host side computation to receive data from device.
    builder.setInsertionPoint(insertion_point);
    llvm::SmallVector<Type, 4> host_operand_types;
    for (const auto& operand : external_operands)
      host_operand_types.push_back(operand.getType());

    auto recv_at_host = CreateRecvAtHostOp(
        builder, op.getLoc(), host_operand_types, compilation_key,
        device_ordinal, args_communication_key);
    // Remember the op's original block before the move: outside compiled
    // siblings from that block will be moved here by later iterations.
    auto original_op_block = op.getBlock();
    // Move the outside compiled op to the host, after the recv.
    op.moveAfter(recv_at_host);
    op.removeAttr(Identifier::get(kDeviceAttr, op.getContext()));
    if (!external_outputs.empty()) {
      CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
                           compilation_key, device_ordinal,
                           retvals_communication_key);
    }
    // Replace operand usages if op is in the same region as insertion or if
    // the op is outside compiled and will be moved to host later.
    auto replace_operand_usage = [&](OpOperand& operand) {
      return insertion_point->getParentRegion()->isAncestor(
                 operand.getOwner()->getParentRegion()) ||
             (HasOutsideCompilationAncestor(operand.getOwner()) &&
              original_op_block == operand.getOwner()->getBlock());
    };
    if (external_operands.empty()) {
      // Nothing is received; the recv op is unnecessary.
      recv_at_host->erase();
    } else {
      // Host-side uses of the external operands now read the recv results.
      for (auto result :
           llvm::zip(external_operands, recv_at_host->getResults())) {
        Value external_operand = std::get<0>(result);
        external_operand.replaceUsesWithIf(std::get<1>(result),
                                           replace_operand_usage);
      }
    }
    // Don't replace output usages in host computation or for outside
    // compiled ops.
    auto replace_output_usage = [&](OpOperand& operand) {
      return !op.getParentRegion()->isAncestor(
                 operand.getOwner()->getParentRegion()) &&
             !HasOutsideCompilationAncestor(operand.getOwner());
    };
    // Device-side uses of the moved op's results now read the host compute's
    // results.
    for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
      Value external_output = std::get<0>(result);
      external_output.replaceUsesWithIf(std::get<1>(result),
                                        replace_output_usage);
    }
    if (external_operands.empty() && external_outputs.empty()) {
      // No communication at all: drop the placeholder compute and reuse the
      // channel index for the next op.
      host_compute.erase();
    } else {
      ++communication_key_index;
    }
  }
}
349 
// Decompose control flow in `tpu_cluster` into device computation and host
// (outside compiled) computation into two separate control flow ops with
// communication between the device/host for data dependencies.  Both device
// and host control flow initially remain within `tpu_cluster` and a subsequent
// call to MoveOpsToHost moves the host side control flow to the host launch in
// tf_device.parallel_execute.  Uses `compilation_key`, `device_ordinal` and
// `communication_key_index` when creating communication ops.
void DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,
                          Value compilation_key, Value device_ordinal,
                          int& communication_key_index) {
  tpu_cluster.GetBody().walk([&](Operation* op) {
    if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
      // Only IfRegions that contain outside compiled ops are decomposed.
      if (!HasOutsideCompilationNested(op)) return;
      OpBuilder builder(if_op);
      // Clone an empty host-side IfRegion sharing the device predicate, then
      // move the outside compiled ops of each branch into it.
      auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
      MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
                    host_if.then_branch().front().getTerminator(),
                    compilation_key, device_ordinal, communication_key_index);
      MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(),
                    host_if.else_branch().front().getTerminator(),
                    compilation_key, device_ordinal, communication_key_index);
      // Mark the host clone outside compiled so a later MoveOpsToHost call
      // moves it (and its contents) into the host launch.
      MarkOutsideCompiled(host_if.getOperation());
    }
    if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return;
      OpBuilder builder(while_op);
      auto host_while = CloneEmptyWhile(while_op.is_stateless(),
                                        while_op.parallel_iterations(),
                                        while_op.getLoc(), builder);
      // The host while must follow the device-computed loop condition: send
      // it from the device cond region, receive it in the host cond region.
      // Note: communication_key_index is consumed (post-incremented) here.
      const auto condition_send_recv_key =
          llvm::formatv("while_condition_channel_{0}",
                        communication_key_index++)
              .str();
      auto& cond = host_while.cond();
      cond.push_back(new Block);
      // The device cond's yielded value is the loop condition.
      auto condition = while_op.cond().front().getTerminator()->getOperand(0);
      builder.setInsertionPoint(while_op.cond().front().getTerminator());
      builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition,
                                          condition_send_recv_key);
      builder.setInsertionPointToEnd(&cond.front());
      auto recv_condition_at_host = CreateRecvAtHostOp(
          builder, while_op.getLoc(), TypeRange{condition.getType()},
          compilation_key, device_ordinal, condition_send_recv_key);
      builder.create<TF::YieldOp>(while_op.getLoc(),
                                  recv_condition_at_host->getResults());

      MoveOpsToHost(tpu_cluster, &while_op.cond().front(),
                    recv_condition_at_host, compilation_key, device_ordinal,
                    communication_key_index);
      MoveOpsToHost(tpu_cluster, &while_op.body().front(),
                    host_while.body().front().getTerminator(), compilation_key,
                    device_ordinal, communication_key_index);
      MarkOutsideCompiled(host_while.getOperation());
    }
  });
}
406 
407 // Removes outside compilation from all ops inside `host_launch_op`.  Should
408 // only be run after all outside compiled ops have been moved to
409 // `host_launch_op`.
RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op)410 void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) {
411   host_launch_op.GetBody().walk([&](Operation* op) {
412     if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
413       op->removeAttr(
414           Identifier::get(kXlaOutsideCompilationAttr, op->getContext()));
415     }
416   });
417 }
418 
// Creates a `parallel_execute` op with a region for host computation and
// a region for `tpu_cluster` computation by extracting outside compiled ops to
// host computation.
void CreateParallelExecuteForOutsideCompilation(
    ModuleOp module, tf_device::ClusterOp tpu_cluster,
    llvm::StringRef host_device) {
  // NOTE(review): `module` is not referenced in this function body.
  OpBuilder builder(tpu_cluster);
  // Create parallel_execute regions, one for the host computation for outside
  // compilation and the second for the original TPU cluster computation.
  const int num_regions = 2;
  auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
      tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
  Block& host_computation_block =
      parallel_execute_op.GetRegionBlockWithIndex(0);
  builder.setInsertionPointToEnd(&host_computation_block);

  // Create a single launch op for all outside compiled ops.
  tf_device::LaunchOp host_launch_op =
      CreateLaunchOpForOutsideCluster(builder, tpu_cluster, host_device);
  builder.setInsertionPoint(host_launch_op.GetBody().getTerminator());
  // Placeholder key/ordinal ops; erased below if no communication is created.
  auto compilation_key_op =
      CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder);
  Value compilation_key = compilation_key_op.program();
  auto device_ordinal_op = builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
      tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type()));
  // A device ordinal is only passed along when the cluster is replicated;
  // otherwise the non-V2 (ordinal-less) communication ops get created.
  Value device_ordinal = nullptr;
  if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
    device_ordinal = device_ordinal_op.device_ordinal();
  }

  int communication_key_index = 0;
  // Decompose control flow into device and host control flow when outside
  // compilation is included.
  DecomposeControlFlow(tpu_cluster, compilation_key, device_ordinal,
                       communication_key_index);

  // Move all outside compiled ops including control flow to host launch.
  MoveOpsToHost(tpu_cluster, &tpu_cluster.GetBody(),
                host_launch_op.GetBody().getTerminator(), compilation_key,
                device_ordinal, communication_key_index);

  // No channels were created: the placeholder key/ordinal ops are unused.
  if (communication_key_index == 0) compilation_key_op.erase();
  if (communication_key_index == 0 || device_ordinal == nullptr)
    device_ordinal_op.erase();

  RemoveOutsideCompilation(host_launch_op);

  builder.setInsertionPointToEnd(&host_computation_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), ArrayRef<Value>{});

  // Move the launch body to last parallel_execute block.
  Block& parallel_execute_tpu_block =
      parallel_execute_op.GetRegionBlockWithIndex(1);
  builder.setInsertionPointToEnd(&parallel_execute_tpu_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
                                      tpu_cluster.getResults());
  tpu_cluster.getOperation()->moveBefore(
      parallel_execute_tpu_block.getTerminator());

  // Remap cluster results with parallel_execute results if user is outside of
  // parallel_execute.
  for (auto result :
       llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
    Value tpu_cluster_result = std::get<0>(result);
    Value parallel_execute_result = std::get<1>(result);
    for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
      if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
        use.set(parallel_execute_result);
  }
}
489 
runOnOperation()490 void TPUExtractOutsideCompilation::runOnOperation() {
491   // Get runtime devices information from the closest parent module.
492   auto module = getOperation();
493   mlir::TF::RuntimeDevices devices;
494   if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
495     return signalPassFailure();
496 
497   module.walk([&](tf_device::ClusterOp tpu_cluster) {
498     if (HasOutsideCompilationNested(tpu_cluster.getOperation())) {
499       std::string host_device;
500       (void)tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
501                                                         &host_device);
502       CreateParallelExecuteForOutsideCompilation(module, tpu_cluster,
503                                                  host_device);
504     }
505   });
506 }
507 
508 }  // namespace
509 
// Creates an instance of the TPU outside compilation extraction pass.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() {
  return std::make_unique<TPUExtractOutsideCompilation>();
}
514 
515 }  // namespace TFTPU
516 }  // namespace mlir
517