/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kHostFunctionAttr[] = "host_func";
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
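
// This pass extracts ops annotated with `_xla_outside_compilation` from a
// `tf_device.cluster` into a host-side `tf_device.launch`, placing both
// inside a `tf_device.parallel_execute`.  A simplified sketch of the result
// (illustrative; devices, communication keys and types elided):
//
//   "tf_device.parallel_execute"() ({
//     "tf_device.launch"() ({
//       // Outside compiled ops, fed via _XlaRecvAtHost/_XlaSendFromHost.
//       tf_device.return
//     }) {device = "..."} : () -> ()
//     tf_device.return
//   }, {
//     "tf_device.cluster"() ({
//       // Device ops, communicating with the host via _XlaHostComputeMlir.
//     }) : () -> ...
//     tf_device.return ...
//   }) : () -> ...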

struct TPUExtractOutsideCompilation
    : public TF::TPUExtractOutsideCompilationPassBase<
          TPUExtractOutsideCompilation> {
  void runOnOperation() override;
};

// Builds a function containing `ops` with `inputs` and `outputs` using
// `builder`.  The `ops` are cloned and modified to use the function arguments
// as inputs.
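// For example, with a single cloned op, one input and one output, the built
// function looks roughly like (illustrative; tf.SomeOp is a stand-in):
//
//   func @host_func(%arg0: tensor<i32>) -> tensor<i32> {
//     %0 = "tf.SomeOp"(%arg0) : (tensor<i32>) -> tensor<i32>
//     return %0 : tensor<i32>
//   }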
FuncOp BuildFunction(llvm::ArrayRef<Operation*> ops,
                     llvm::ArrayRef<Value> inputs,
                     llvm::ArrayRef<Value> outputs, OpBuilder* builder) {
  llvm::SmallVector<Type, 4> operand_types;
  operand_types.reserve(inputs.size());
  for (Value v : inputs) operand_types.emplace_back(v.getType());
  llvm::SmallVector<Type, 4> output_types;
  output_types.reserve(outputs.size());
  for (Value v : outputs) output_types.emplace_back(v.getType());

  auto func_type = builder->getFunctionType(operand_types, output_types);

  FuncOp outlined_func =
      FuncOp::create(ops.front()->getLoc(), kHostFunctionAttr, func_type);

  // Create function body.
  Block* outlined_func_block = outlined_func.addEntryBlock();

  // Clone the operations and remap the inputs to use the function arguments.
  BlockAndValueMapping mapping;
  mapping.map(inputs, outlined_func.getArguments());
  builder->setInsertionPoint(outlined_func_block, outlined_func_block->begin());
  for (Operation* op : ops) {
    builder->clone(*op, mapping);
  }

  // Set the returned values to use cloned ops results using mapping.
  llvm::SmallVector<Value, 4> results_after_mapping;
  for (Value result : outputs) {
    results_after_mapping.push_back(mapping.lookupOrDefault(result));
  }

  builder->create<ReturnOp>(ops.front()->getLoc(), results_after_mapping);
  return outlined_func;
}

// Encapsulates `func` in a module and serializes that module.
// `serialized_func_module` is set to the serialized module.
void EncapsulateFuncAndSerialize(FuncOp func,
                                 std::string* serialized_func_module) {
  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(func.getContext()));
  SymbolTable symbol_table(module_for_func.get());

  symbol_table.insert(func);
  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
}

// Returns whether `op` or ops nested in `op` are outside compiled.
bool HasOutsideCompilationNested(Operation* op) {
  return op
      ->walk([&](Operation* walked_op) {
        if (op == walked_op) return WalkResult::advance();
        if (walked_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

// Returns whether `op` or any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestor(Operation* op) {
  while (op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      return true;
    }
    op = op->getParentOp();
  }
  return false;
}

// Returns whether any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestorExclusive(Operation* op) {
  Operation* parent_op = op->getParentOp();
  if (!parent_op) return false;
  return HasOutsideCompilationAncestor(parent_op);
}

// Sets the `_xla_has_host_transfer` attribute on `op` and returns it.
Operation* ApplyXlaHostTransferAttr(Operation* op, OpBuilder& builder) {
  op->setAttr("_xla_has_host_transfer", builder.getBoolAttr(true));
  return op;
}
// Creates a tf._XlaSendFromHost op. If a device ordinal is present, a
// tf._XlaSendFromHostV2 op is created instead.
Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc,
                                ValueRange inputs, Value compilation_key,
                                Value device_ordinal,
                                llvm::StringRef communication_key) {
  if (device_ordinal)
    return ApplyXlaHostTransferAttr(
        builder.create<TF::_XlaSendFromHostV2Op>(
            loc, inputs,
            /*dynamic_key=*/compilation_key, device_ordinal,
            builder.getStringAttr(communication_key)),
        builder);

  return ApplyXlaHostTransferAttr(
      builder.create<TF::_XlaSendFromHostOp>(
          loc, inputs,
          /*dynamic_key=*/compilation_key,
          builder.getStringAttr(communication_key),
          /*device_ordinal=*/builder.getI64IntegerAttr(0)),
      builder);
}

// Creates a tf._XlaRecvAtHost op. If a device ordinal is present, a
// tf._XlaRecvAtHostV2 op is created instead.
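// A sketch of the op created in the replicated case (illustrative; key and
// types elided):
//
//   %0 = "tf._XlaRecvAtHostV2"(%program_key, %device_ordinal)
//          {key = "host_compute_channel_0_args"} : (...) -> tensor<...>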
Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc,
                              TypeRange output_types, Value compilation_key,
                              Value device_ordinal,
                              llvm::StringRef communication_key) {
  if (device_ordinal)
    return ApplyXlaHostTransferAttr(
        builder.create<TF::_XlaRecvAtHostV2Op>(
            loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal,
            builder.getStringAttr(communication_key)),
        builder);

  return ApplyXlaHostTransferAttr(
      builder.create<TF::_XlaRecvAtHostOp>(
          loc, output_types, /*dynamic_key=*/compilation_key,
          builder.getStringAttr(communication_key),
          /*device_ordinal=*/builder.getI64IntegerAttr(0)),
      builder);
}

// Clones an IfRegionOp `if_region`, preserving its attributes, and creates
// then/else regions that contain only a yield op.
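// A sketch of the cloned op (illustrative):
//
//   "tf.IfRegion"(%pred) ({
//     "tf.Yield"() : () -> ()
//   }, {
//     "tf.Yield"() : () -> ()
//   }) {is_stateless = true} : (tensor<i1>) -> ()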
TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
                                         OpBuilder& builder) {
  auto host_side_if = builder.create<TF::IfRegionOp>(
      if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
      if_region.is_stateless(), if_region._then_func_nameAttr(),
      if_region._else_func_nameAttr());

  // Create empty then branch region.
  auto& then_branch = host_side_if.then_branch();
  then_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});

  // Create empty else branch region.
  auto& else_branch = host_side_if.else_branch();
  else_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});
  return host_side_if;
}

// Creates a WhileRegionOp with an empty cond region and a body region that
// contains only a yield op.
TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
                                  uint64_t parallel_iterations, Location loc,
                                  OpBuilder& builder) {
  auto host_side_while = builder.create<TF::WhileRegionOp>(
      loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
      parallel_iterations, is_stateless, /*shape_invariant=*/false);

  // Create empty body region with a yield op.
  auto& body = host_side_while.body();
  body.push_back(new Block);
  builder.setInsertionPointToEnd(&body.front());
  builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
  return host_side_while;
}

// TODO(b/157054714): Use a better abstraction instead of
// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
// Creates a compilation key as a placeholder. A placeholder compilation cache
// key is created because it is a required input to _XlaRecvAtHost and
// _XlaSendFromHost, but the _TPUCompileMlir op has not yet been created for
// the TPU cluster that contains the outside compiled ops. This placeholder
// should be replaced by the TPU cluster's _TPUCompileMlir op in a subsequent
// pass.
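// A sketch of the placeholder (illustrative); the program key is a
// tensor<3x!tf.string>:
//
//   %key = "tf._TPUCompileMlirPlaceholderProgramKey"()
//            : () -> tensor<3x!tf.string>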
TF::_TPUCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder(
    Location loc, OpBuilder& builder) {
  auto result_type =
      RankedTensorType::get({3}, builder.getType<TF::StringType>());
  return builder.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
      loc, /*program=*/result_type, llvm::ArrayRef<Value>{});
}

// Creates a `tf_device.launch` to wrap cluster ops.
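// A sketch of the created wrapper (illustrative; the device string is an
// example host device):
//
//   "tf_device.launch"() ({
//     // Outside compiled ops are moved here.
//     tf_device.return
//   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()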
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
    OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device) {
  // Set `host_device` as the device of the launch op so that the wrapped ops
  // are placed on the host.
  auto launch_op = builder.create<tf_device::LaunchOp>(
      loc_op->getLoc(), builder.getStringAttr(host_device),
      /*result_types=*/ArrayRef<Type>{});

  launch_op.body().push_back(new Block);
  builder.setInsertionPointToEnd(&launch_op.GetBody());
  builder.create<tf_device::ReturnOp>(loc_op->getLoc(),
                                      llvm::ArrayRef<Value>{});

  return launch_op;
}

// Returns true if `op` has non-static shaped outputs.
bool HasDynamicOutputs(Operation* op) {
  for (Value v : op->getResults()) {
    if (TF::CanBeRefined(v.getType())) return true;
  }
  return false;
}

// Returns true if any op in `cluster_ops` has an output with a non-static
// shape that is consumed by an op not in `cluster_ops`.
bool HasDynamicOutputs(const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  for (Operation* op : cluster_ops) {
    for (const OpOperand& use : op->getUses()) {
      if (cluster_ops.count(use.getOwner())) {
        continue;
      }
      if (TF::CanBeRefined(use.get().getType())) return true;
    }
  }
  return false;
}

bool HasDynamicExternalValues(Operation* op) {
  return op
      ->walk([](Operation* walked_op) {
        for (Value v : walked_op->getOperands()) {
          if (TF::CanBeRefined(v.getType())) {
            return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

// Returns operands of `cluster_ops` that need to be communicated from
// device->host.  This is for the case when all operands have a static shape.
llvm::SmallSetVector<Value, 4> GetStaticExternalOperands(
    tf_device::ClusterOp tpu_cluster,
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_values;
  for (Operation* op : cluster_ops) {
    op->walk([&](Operation* walked_op) {
      if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(
              walked_op))
        return WalkResult::advance();
      for (Value v : walked_op->getOperands()) {
        if (auto* defining_op = v.getDefiningOp()) {
          if (!op->isAncestor(defining_op) &&
              tpu_cluster->isAncestor(defining_op) &&
              !HasOutsideCompilationAncestor(defining_op) &&
              !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
            external_values.insert(v);
          }
          continue;
        }
        auto block_arg = v.cast<BlockArgument>();
        if (block_arg.getParentRegion() == op->getParentRegion())
          external_values.insert(v);
      }
      return WalkResult::advance();
    });
  }
  return external_values;
}

// Returns every operand of `cluster_ops` that does not come from an op in
// `cluster_ops`.
llvm::SmallSetVector<Value, 4> GetAllExternalOperands(
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_values;
  for (Operation* op : cluster_ops) {
    op->walk([&](Operation* walked_op) {
      for (Value v : walked_op->getOperands()) {
        Operation* defining_op = v.getDefiningOp();
        if (!defining_op || !cluster_ops.count(defining_op)) {
          external_values.insert(v);
        }
      }
    });
  }
  return external_values;
}

// Returns a SmallSetVector containing all of the operands that need to be
// communicated from device->host.
llvm::SmallSetVector<Value, 4> GetExternalOperands(
    tf_device::ClusterOp tpu_cluster,
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  // If there are any dynamic outputs, get all of the operands which are
  // defined external to `cluster_ops`.
  bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
  if (has_dynamic_outputs) {
    return GetAllExternalOperands(cluster_ops);
  } else {
    return GetStaticExternalOperands(tpu_cluster, cluster_ops);
  }
}

// Gets all outputs that need to be communicated from host->device.
llvm::SmallSetVector<Value, 4> GetExternalOutputs(
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_outputs;
  bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
  for (Operation* op : cluster_ops) {
    for (Operation* user : op->getUsers()) {
      // We skip any operations that are in the same outside compilation
      // cluster that will be moved to the host at the same time, since both
      // the defining op and the user op will be moved to the host.
      if (cluster_ops.count(user)) {
        continue;
      }
      // This is pessimistic and in some cases will add extra communication.
      if (!HasOutsideCompilationAncestor(user) || has_dynamic_outputs ||
          HasDynamicOutputs(user)) {
        for (Value v : user->getOperands()) {
          if (v.getDefiningOp() == op) external_outputs.insert(v);
        }
      }
    }
  }
  return external_outputs;
}

// Creates the HostCompute op with `inputs` and `outputs`, using
// `args_communication_key` and `retvals_communication_key` for communication.
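// A sketch of the created op (illustrative; types elided):
//
//   %0 = "tf._XlaHostComputeMlir"(%input)
//          {send_key = "host_compute_channel_0_args",
//           recv_key = "host_compute_channel_0_retvals",
//           tpu_core = 0 : i64, host_mlir_module = ""}
//          : (tensor<i32>) -> tensor<i32>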
TF::_XlaHostComputeMlirOp CreateHostCompute(
    OpBuilder& builder, Location loc,
    const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
    llvm::StringRef args_communication_key,
    llvm::StringRef retvals_communication_key,
    llvm::StringRef serialized_func_module) {
  llvm::SmallVector<Type, 4> device_output_types;
  for (const auto& output : outputs)
    device_output_types.push_back(output.getType());
  auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
      loc, device_output_types, inputs.getArrayRef(),
      builder.getStringAttr(args_communication_key),
      builder.getStringAttr(retvals_communication_key),
      /*tpu_core=*/builder.getI64IntegerAttr(0),
      /*host_mlir_module=*/builder.getStringAttr(serialized_func_module));
  return host_compute;
}

void MarkOutsideCompiled(Operation* op) {
  op->setAttr(kXlaOutsideCompilationAttr,
              StringAttr::get(op->getContext(), "temp"));
}

// Returns whether an outside compilation cluster should be closed.  True when:
// 1. There is a dynamically shaped output consumed by a non-outside compiled
//    op.
// 2. There is no dynamically shaped output.
bool ShouldCloseCluster(llvm::ArrayRef<Value> outputs) {
  bool has_dynamic_output = false;
  for (Value v : outputs) {
    if (TF::CanBeRefined(v.getType())) {
      has_dynamic_output = true;
      for (Operation* user : v.getUsers()) {
        if (!HasOutsideCompilationAncestor(user)) return true;
      }
    }
  }
  return !has_dynamic_output;
}

// Replaces `external_operands` with the results from `recv_at_host`.
// For non-static shapes, an operand usage is only replaced if the op is in
// the same region as the insertion point.
// For static shapes, an operand usage is replaced if the op is in the same
// region as the insertion point, or if the op is outside compiled and will be
// moved to the host later.
void ReplaceExternalOperandUsage(
    const llvm::SmallSetVector<Value, 4>& external_operands,
    Operation* recv_at_host, Operation* insertion_point,
    Block* original_op_block) {
  auto replace_operand_usage = [&](OpOperand& operand) {
    if (TF::CanBeRefined(operand.get().getType()) ||
        HasDynamicOutputs(operand.getOwner())) {
      return insertion_point->getParentRegion()->isAncestor(
          operand.getOwner()->getParentRegion());
    }
    return insertion_point->getParentRegion()->isAncestor(
               operand.getOwner()->getParentRegion()) ||
           (HasOutsideCompilationAncestor(operand.getOwner()) &&
            original_op_block == operand.getOwner()->getBlock());
  };
  for (auto result : llvm::zip(external_operands, recv_at_host->getResults())) {
    Value external_operand = std::get<0>(result);
    external_operand.replaceUsesWithIf(std::get<1>(result),
                                       replace_operand_usage);
  }
}

bool HasDynamicOutputs(llvm::ArrayRef<Value> outputs) {
  for (Value v : outputs) {
    if (TF::CanBeRefined(v.getType())) {
      return true;
    }
  }
  return false;
}

// Replaces usages of `external_outputs` which are values returned by outside
// compilation with the corresponding outputs from `host_compute`.
void ReplaceExternalOutputUsage(
    const llvm::SmallSetVector<Value, 4>& external_outputs,
    TF::_XlaHostComputeMlirOp host_compute) {
  bool has_dynamic_outputs = HasDynamicOutputs(external_outputs.getArrayRef());

  auto replace_output_usage = [&](OpOperand& operand) {
    // Don't replace output usages if in host computation (defining op and user
    // in same region).
    bool in_same_region =
        operand.get().getDefiningOp()->getParentRegion()->isAncestor(
            operand.getOwner()->getParentRegion());
    if (has_dynamic_outputs || HasDynamicOutputs(operand.getOwner())) {
      return !in_same_region;
    } else {
      // Don't replace output usages in host computation or for outside
      // compiled ops.
      return !in_same_region &&
             !HasOutsideCompilationAncestor(operand.getOwner());
    }
  };
  for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
    Value external_output = std::get<0>(result);
    external_output.replaceUsesWithIf(std::get<1>(result),
                                      replace_output_usage);
  }
}

// Moves `clustered_ops` to run on the host and adds communication ops to
// transfer `external_operands` and `external_outputs` between device and
// host.  Inserts ops at `insertion_point` and uses `compilation_key` and
// `device_ordinal` when creating the communication ops.
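// For a single outside compiled op %b = "tf.B"(%a), the result is roughly
// (illustrative; tf.B is a stand-in, keys and types elided):
//
//   Device side:
//     %b = "tf._XlaHostComputeMlir"(%a) {...}
//   Host side:
//     %a_host = "tf._XlaRecvAtHostV2"(%key, %ordinal) {...}
//     %b_host = "tf.B"(%a_host)
//     "tf._XlaSendFromHostV2"(%b_host, %key, %ordinal) {...}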
void MoveOpsToHost(const llvm::SmallSetVector<Operation*, 4>& clustered_ops,
                   const llvm::SmallSetVector<Value, 4>& external_operands,
                   const llvm::SmallSetVector<Value, 4>& external_outputs,
                   Operation* insertion_point, Value compilation_key,
                   Value device_ordinal, int& communication_key_index) {
  OpBuilder builder(insertion_point);
  Operation& op = *clustered_ops.back();
  std::string args_communication_key =
      llvm::formatv("host_compute_channel_{0}_args", (communication_key_index))
          .str();
  std::string retvals_communication_key =
      llvm::formatv("host_compute_channel_{0}_retvals",
                    (communication_key_index))
          .str();

  // Use a unique name when sending just the IfRegion predicate.  This is
  // for readability and to match the key in the TF2XLA bridge.
  if (clustered_ops.size() == 1 && llvm::isa<TF::IfRegionOp>(op) &&
      external_operands.size() == 1) {
    args_communication_key =
        llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
            .str();
  }

  std::string serialized_func_module;
  if (HasDynamicOutputs(external_outputs.getArrayRef())) {
    FuncOp shape_op = BuildFunction(clustered_ops.getArrayRef(),
                                    external_operands.getArrayRef(),
                                    external_outputs.getArrayRef(), &builder);
    EncapsulateFuncAndSerialize(shape_op, &serialized_func_module);
  }

  builder.setInsertionPoint(&op);
  auto host_compute =
      CreateHostCompute(builder, op.getLoc(), external_operands,
                        external_outputs.getArrayRef(), args_communication_key,
                        retvals_communication_key, serialized_func_module);
  // Insert ops on the host side computation to receive data from device.
  builder.setInsertionPoint(insertion_point);
  llvm::SmallVector<Type, 4> host_operand_types;
  for (const auto& operand : external_operands)
    host_operand_types.push_back(operand.getType());

  Operation* recv_at_host = CreateRecvAtHostOp(
      builder, op.getLoc(), host_operand_types, compilation_key,
      device_ordinal, args_communication_key);
  Block* original_op_block = op.getBlock();
  Operation* after_op = recv_at_host;
  for (Operation* cluster_op : clustered_ops) {
    cluster_op->moveAfter(after_op);
    cluster_op->removeAttr(Identifier::get(kDeviceAttr, op.getContext()));
    after_op = cluster_op;
  }

  if (!external_outputs.empty()) {
    CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
                         compilation_key, device_ordinal,
                         retvals_communication_key);
  }

  if (external_operands.empty()) {
    recv_at_host->erase();
  } else {
    ReplaceExternalOperandUsage(external_operands,
                                /*recv_at_host=*/recv_at_host,
                                /*insertion_point=*/insertion_point,
                                /*original_op_block=*/original_op_block);
  }

  ReplaceExternalOutputUsage(external_outputs, host_compute);

  if (external_operands.empty() && external_outputs.empty()) {
    host_compute.erase();
  } else {
    ++communication_key_index;
  }
}

// Moves outside compiled ops in `src` to `insertion_point` in the host
// computation (possibly temporarily with `tpu_cluster`, to be moved in a
// subsequent call to this method).  Communication ops are added in both `src`
// and at `insertion_point` using `compilation_key`, `device_ordinal` and
// `communication_key_index`, which is incremented when used.  Communication
// ops are added only when and where needed.  There are checks to ensure that
// duplicate communication between device and host is not added.
LogicalResult MoveOpsToHost(tf_device::ClusterOp tpu_cluster, Block* src,
                            Operation* insertion_point, Value compilation_key,
                            Value device_ordinal,
                            int& communication_key_index) {
  // Contains all of the outside compiled operations that should be moved to
  // the host using a single `_XlaHostComputeMlir` op.  This should only
  // contain a single op except in the case where some of the input/output
  // shapes are non-static.
  llvm::SmallSetVector<Operation*, 4> clustered_ops;

  for (Operation& op : llvm::make_early_inc_range(*src)) {
    if (HasOutsideCompilationAncestorExclusive(&op) ||
        !op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      continue;

    // Move the current clustered_ops first if the op to be added has all
    // statically shaped operands, since we can't ensure that the static
    // shapes have been sent back to the host in all cases.  See
    // @static_shapes_sandwiched_outside_compilation MLIR test for an example.
    if (!HasDynamicExternalValues(&op) && !clustered_ops.empty()) {
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      llvm::SmallSetVector<Value, 4> external_outputs =
          GetExternalOutputs(clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }

    clustered_ops.insert(&op);

    // Get the outputs that need to be communicated from host -> device.
    llvm::SmallSetVector<Value, 4> external_outputs =
        GetExternalOutputs(clustered_ops);

    if (ShouldCloseCluster(external_outputs.getArrayRef())) {
      // Get the operands that need to be communicated from device -> host.
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }
  }
  return success();
}

// Decomposes control flow in `tpu_cluster` into device computation and host
// (outside compiled) computation, as two separate control flow ops with
// communication between the device/host for data dependencies.  Both device
// and host control flow initially remain within `tpu_cluster`, and a
// subsequent call to MoveOpsToHost moves the host side control flow to the
// host launch in tf_device.parallel_execute.  Uses `compilation_key`,
// `device_ordinal` and `communication_key_index` when creating communication
// ops.
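// For example, a tf.IfRegion with outside compiled ops in a branch becomes a
// device-side tf.IfRegion plus a host-side tf.IfRegion with initially empty
// branches (see CloneEmptyIfWithPredicate) that are then populated via
// MoveOpsToHost; for a tf.WhileRegion, the loop condition is additionally
// sent to the host through a "while_condition_channel_*" key.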
LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,
                                   Value compilation_key, Value device_ordinal,
                                   int& communication_key_index) {
  auto result = tpu_cluster.GetBody().walk([&](Operation* op) {
    if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
      OpBuilder builder(if_op);
      auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
      if (failed(MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
                               host_if.then_branch().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      if (failed(MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(),
                               host_if.else_branch().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      MarkOutsideCompiled(host_if.getOperation());
    }
    if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
      OpBuilder builder(while_op);
      auto host_while = CloneEmptyWhile(while_op.is_stateless(),
                                        while_op.parallel_iterations(),
                                        while_op.getLoc(), builder);
      const auto condition_send_recv_key =
          llvm::formatv("while_condition_channel_{0}",
                        communication_key_index++)
              .str();
      auto& cond = host_while.cond();
      cond.push_back(new Block);
      auto condition = while_op.cond().front().getTerminator()->getOperand(0);
      builder.setInsertionPoint(while_op.cond().front().getTerminator());
      builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition,
                                          condition_send_recv_key);
      builder.setInsertionPointToEnd(&cond.front());
      auto recv_condition_at_host = CreateRecvAtHostOp(
          builder, while_op.getLoc(), TypeRange{condition.getType()},
          compilation_key, device_ordinal, condition_send_recv_key);
      builder.create<TF::YieldOp>(while_op.getLoc(),
                                  recv_condition_at_host->getResults());

      if (failed(MoveOpsToHost(tpu_cluster, &while_op.cond().front(),
                               recv_condition_at_host, compilation_key,
                               device_ordinal, communication_key_index)))
        return WalkResult::interrupt();
      if (failed(MoveOpsToHost(tpu_cluster, &while_op.body().front(),
                               host_while.body().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      MarkOutsideCompiled(host_while.getOperation());
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted()) return failure();
  return success();
}

// Removes outside compilation from all ops inside `host_launch_op`.  Should
// only be run after all outside compiled ops have been moved to
// `host_launch_op`.
void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) {
  host_launch_op.GetBody().walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      op->removeAttr(
          Identifier::get(kXlaOutsideCompilationAttr, op->getContext()));
    }
  });
}

// Creates a `parallel_execute` op with a region for the host computation and
// a region for the `tpu_cluster` computation, by extracting outside compiled
// ops to the host computation.
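// Region 0 of the `parallel_execute` holds the host launch and region 1
// holds the original TPU cluster, roughly (illustrative; types elided):
//
//   %0 = "tf_device.parallel_execute"() ({
//     "tf_device.launch"() ({...}) {device = host_device} : () -> ()
//     tf_device.return
//   }, {
//     %1 = "tf_device.cluster"() ({...}) : () -> tensor<...>
//     tf_device.return %1 : tensor<...>
//   }) : () -> tensor<...>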
LogicalResult CreateParallelExecuteForOutsideCompilation(
    ModuleOp module, tf_device::ClusterOp tpu_cluster,
    llvm::StringRef host_device) {
  OpBuilder builder(tpu_cluster);
  // Create parallel_execute regions, one for the host computation for outside
  // compilation and the second for the original TPU cluster computation.
  const int num_regions = 2;
  auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
      tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
  Block& host_computation_block =
      parallel_execute_op.GetRegionBlockWithIndex(0);
  builder.setInsertionPointToEnd(&host_computation_block);

  // Create a single launch op for all outside compiled ops.
  tf_device::LaunchOp host_launch_op =
      CreateLaunchOpForOutsideCluster(builder, tpu_cluster, host_device);
  builder.setInsertionPoint(host_launch_op.GetBody().getTerminator());
  auto compilation_key_op =
      CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder);
  Value compilation_key = compilation_key_op.program();
  auto device_ordinal_op = builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
      tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type()));
  Value device_ordinal = nullptr;
  if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
    device_ordinal = device_ordinal_op.device_ordinal();
  }

  int communication_key_index = 0;
  // Decompose control flow into device and host control flow when outside
  // compilation is included.
  if (failed(DecomposeControlFlow(tpu_cluster, compilation_key, device_ordinal,
                                  communication_key_index)))
    return failure();

  // Move all outside compiled ops including control flow to the host launch.
  if (failed(MoveOpsToHost(tpu_cluster, &tpu_cluster.GetBody(),
                           host_launch_op.GetBody().getTerminator(),
                           compilation_key, device_ordinal,
                           communication_key_index)))
    return failure();

  if (communication_key_index == 0) compilation_key_op.erase();
  if (communication_key_index == 0 || device_ordinal == nullptr)
    device_ordinal_op.erase();

  RemoveOutsideCompilation(host_launch_op);

  builder.setInsertionPointToEnd(&host_computation_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), ArrayRef<Value>{});

  // Move the TPU cluster into the last parallel_execute block, before its
  // terminator.
  Block& parallel_execute_tpu_block =
      parallel_execute_op.GetRegionBlockWithIndex(1);
  builder.setInsertionPointToEnd(&parallel_execute_tpu_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
                                      tpu_cluster.getResults());
  tpu_cluster.getOperation()->moveBefore(
      parallel_execute_tpu_block.getTerminator());

  // Remap cluster results with parallel_execute results if the user is
  // outside of the parallel_execute.
  for (auto result :
       llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
    Value tpu_cluster_result = std::get<0>(result);
    Value parallel_execute_result = std::get<1>(result);
    for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
      if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
        use.set(parallel_execute_result);
  }
  return success();
}

void TPUExtractOutsideCompilation::runOnOperation() {
  // Get runtime devices information from the closest parent module.
  auto module = getOperation();
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
    return signalPassFailure();

  module.walk([&](tf_device::ClusterOp tpu_cluster) {
    if (HasOutsideCompilationNested(tpu_cluster.getOperation())) {
      std::string host_device;
      if (tensorflow::HasModelParallelism(tpu_cluster)) {
        tpu_cluster.emitOpError(
            "outside compilation is not supported with model parallelism.");
        return signalPassFailure();
      }
      if (failed(tensorflow::GetHostDeviceOutsideComputation(
              devices, tpu_cluster, &host_device)))
        return signalPassFailure();
      if (failed(CreateParallelExecuteForOutsideCompilation(module, tpu_cluster,
                                                            host_device)))
        return signalPassFailure();
    }
  });
  // Remove the `_xla_outside_compilation` attribute from all ops.  These ops
  // will be outside of the device cluster.  The `_xla_outside_compilation`
  // attribute on ops outside of tf_device.cluster doesn't have any meaning
  // and can lead to errors later on.  These ops were likely lifted out of the
  // tf_device.cluster in an earlier pass.
  module.walk(
      [](Operation* op) { op->removeAttr("_xla_outside_compilation"); });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() {
  return std::make_unique<TPUExtractOutsideCompilation>();
}

}  // namespace TFTPU
}  // namespace mlir