/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kHostFunctionAttr[] = "host_func";
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";

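// This pass extracts ops marked for outside compilation (i.e. tagged with the
// `_xla_outside_compilation` attribute) out of `tf_device.cluster` ops and
// into a host computation, adding communication ops to transfer values across
// the device/host boundary. A rough sketch of the transformation
// (illustrative, simplified IR only; op names such as tf.OpA are
// placeholders):
//
//   tf_device.cluster {
//     %a = "tf.OpA"() : () -> tensor<i32>
//     %b = "tf.OpB"(%a) {_xla_outside_compilation = "cluster1"} : ...
//     "tf.OpC"(%b) : ...
//   }
//
// becomes a `tf_device.parallel_execute` with one region holding a host
// `tf_device.launch` that runs tf.OpB (receiving %a via tf._XlaRecvAtHost and
// sending %b back via tf._XlaSendFromHost), and a second region holding the
// device cluster where the outside compiled op is replaced by a
// tf._XlaHostComputeMlir op.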
struct TPUExtractOutsideCompilation
    : public TF::TPUExtractOutsideCompilationPassBase<
          TPUExtractOutsideCompilation> {
  void runOnOperation() override;
};

// Build a function containing `ops` with `inputs` and `outputs` using
// `builder`. The `ops` are cloned and modified to use the function arguments
// as inputs.
FuncOp BuildFunction(llvm::ArrayRef<Operation*> ops,
                     llvm::ArrayRef<Value> inputs,
                     llvm::ArrayRef<Value> outputs, OpBuilder* builder) {
  llvm::SmallVector<Type, 4> operand_types;
  operand_types.reserve(inputs.size());
  for (Value v : inputs) operand_types.emplace_back(v.getType());
  llvm::SmallVector<Type, 4> output_types;
  output_types.reserve(outputs.size());
  for (Value v : outputs) output_types.emplace_back(v.getType());

  auto func_type = builder->getFunctionType(operand_types, output_types);

  FuncOp outlined_func =
      FuncOp::create(ops.front()->getLoc(), kHostFunctionAttr, func_type);

  // Create function body.
  Block* outlined_func_block = outlined_func.addEntryBlock();

  // Clone the operations and remap the inputs to use the function arguments.
  BlockAndValueMapping mapping;
  mapping.map(inputs, outlined_func.getArguments());
  builder->setInsertionPoint(outlined_func_block, outlined_func_block->begin());
  for (Operation* op : ops) {
    builder->clone(*op, mapping);
  }

  // Set the returned values to use cloned ops results using mapping.
  llvm::SmallVector<Value, 4> results_after_mapping;
  for (Value result : outputs) {
    results_after_mapping.push_back(mapping.lookupOrDefault(result));
  }

  builder->create<ReturnOp>(ops.front()->getLoc(), results_after_mapping);
  return outlined_func;
}

// Encapsulates `func` in a module and serializes that module.
// `serialized_func_module` is set to the serialized module.
void EncapsulateFuncAndSerialize(FuncOp func,
                                 std::string* serialized_func_module) {
  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(func.getContext()));
  SymbolTable symbol_table(module_for_func.get());

  symbol_table.insert(func);
  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
}

// Returns whether `op` or ops nested in `op` are outside compiled.
bool HasOutsideCompilationNested(Operation* op) {
  return op
      ->walk([&](Operation* walked_op) {
        if (op == walked_op) return WalkResult::advance();
        if (walked_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

// Returns whether `op` or any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestor(Operation* op) {
  while (op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      return true;
    }
    op = op->getParentOp();
  }
  return false;
}

// Returns whether any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestorExclusive(Operation* op) {
  Operation* parent_op = op->getParentOp();
  if (!parent_op) return false;
  return HasOutsideCompilationAncestor(parent_op);
}

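// Marks `op` as performing a host transfer by setting its
// `_xla_has_host_transfer` attribute to true, and returns `op` for chaining.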
Operation* ApplyXlaHostTransferAttr(Operation* op, OpBuilder& builder) {
  op->setAttr("_xla_has_host_transfer", builder.getBoolAttr(true));
  return op;
}

// Creates a tf._XlaSendFromHost op. If `device_ordinal` is present, a
// tf._XlaSendFromHostV2 op is created instead.
Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc,
                                ValueRange inputs, Value compilation_key,
                                Value device_ordinal,
                                llvm::StringRef communication_key) {
  if (device_ordinal)
    return ApplyXlaHostTransferAttr(
        builder.create<TF::_XlaSendFromHostV2Op>(
            loc, inputs,
            /*dynamic_key=*/compilation_key, device_ordinal,
            builder.getStringAttr(communication_key)),
        builder);

  return ApplyXlaHostTransferAttr(
      builder.create<TF::_XlaSendFromHostOp>(
          loc, inputs,
          /*dynamic_key=*/compilation_key,
          builder.getStringAttr(communication_key),
          /*device_ordinal=*/builder.getI64IntegerAttr(0)),
      builder);
}

// Creates a tf._XlaRecvAtHost op. If `device_ordinal` is present, a
// tf._XlaRecvAtHostV2 op is created instead.
Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc,
                              TypeRange output_types, Value compilation_key,
                              Value device_ordinal,
                              llvm::StringRef communication_key) {
  if (device_ordinal)
    return ApplyXlaHostTransferAttr(
        builder.create<TF::_XlaRecvAtHostV2Op>(
            loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal,
            builder.getStringAttr(communication_key)),
        builder);

  return ApplyXlaHostTransferAttr(
      builder.create<TF::_XlaRecvAtHostOp>(
          loc, output_types, /*dynamic_key=*/compilation_key,
          builder.getStringAttr(communication_key),
          /*device_ordinal=*/builder.getI64IntegerAttr(0)),
      builder);
}

// Clones `if_region` with its attributes and predicate into a new IfRegionOp
// whose then/else regions each contain only an empty block terminated by a
// yield op.
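// A rough sketch of the op this produces (illustrative IR only):
//
//   "tf.IfRegion"(%predicate) ({
//     "tf.Yield"() : () -> ()
//   }, {
//     "tf.Yield"() : () -> ()
//   }) {is_stateless = ...} : (tensor<i1>) -> ()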
TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
                                         OpBuilder& builder) {
  auto host_side_if = builder.create<TF::IfRegionOp>(
      if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
      if_region.is_stateless(), if_region._then_func_nameAttr(),
      if_region._else_func_nameAttr());

  // Create empty then branch region.
  auto& then_branch = host_side_if.then_branch();
  then_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});

  // Create empty else branch region.
  auto& else_branch = host_side_if.else_branch();
  else_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});
  return host_side_if;
}

// Creates a WhileRegionOp with an empty cond region and a body region
// containing a single empty block terminated by a yield op.
TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
                                  uint64_t parallel_iterations, Location loc,
                                  OpBuilder& builder) {
  auto host_side_while = builder.create<TF::WhileRegionOp>(
      loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
      parallel_iterations, is_stateless, /*shape_invariant=*/false);

  // Create empty body region.
  auto& body = host_side_while.body();
  body.push_back(new Block);
  builder.setInsertionPointToEnd(&body.front());
  builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
  return host_side_while;
}

// TODO(b/157054714): Use a better abstraction instead of
// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
// Creates a compilation key as placeholder. A placeholder compilation cache
// key is created because it is a required input to _XlaRecvAtHost and
// _XlaSendFromHost but the _TPUCompileMlir has not yet been created for the
// TPU cluster that contains the outside compiled ops. This placeholder should
// be replaced by the TPU cluster _TPUCompileMlir in a subsequent pass.
TF::_TPUCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder(
    Location loc, OpBuilder& builder) {
  auto result_type =
      RankedTensorType::get({3}, builder.getType<TF::StringType>());
  return builder.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
      loc, /*program=*/result_type, llvm::ArrayRef<Value>{});
}

// Creates a `tf_device.launch` to wrap cluster ops.
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
    OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device) {
  // The launch is placed on `host_device`, the host device derived for the
  // associated TPU cluster.
  auto launch_op = builder.create<tf_device::LaunchOp>(
      loc_op->getLoc(), builder.getStringAttr(host_device),
      /*result_types=*/ArrayRef<Type>{});

  launch_op.body().push_back(new Block);
  builder.setInsertionPointToEnd(&launch_op.GetBody());
  builder.create<tf_device::ReturnOp>(loc_op->getLoc(),
                                      llvm::ArrayRef<Value>{});

  return launch_op;
}

// Returns true if `op` has non-static shaped outputs.
bool HasDynamicOutputs(Operation* op) {
  for (Value v : op->getResults()) {
    if (TF::CanBeRefined(v.getType())) return true;
  }
  return false;
}

// Returns true if any op in `cluster_ops` has outputs consumed by ops not in
// `cluster_ops` with a non-static shape.
bool HasDynamicOutputs(const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  for (Operation* op : cluster_ops) {
    for (const OpOperand& use : op->getUses()) {
      if (cluster_ops.count(use.getOwner())) {
        continue;
      }
      if (TF::CanBeRefined(use.get().getType())) return true;
    }
  }
  return false;
}

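// Returns true if `op` or any op nested in `op` has an operand with a
// non-static shape.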
bool HasDynamicExternalValues(Operation* op) {
  return op
      ->walk([](Operation* walked_op) {
        for (Value v : walked_op->getOperands()) {
          if (TF::CanBeRefined(v.getType())) {
            return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

// Returns operands of `cluster_ops` that need to be
// communicated from device->host. This is for the case when all operands have
// a static shape.
llvm::SmallSetVector<Value, 4> GetStaticExternalOperands(
    tf_device::ClusterOp tpu_cluster,
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_values;
  for (Operation* op : cluster_ops) {
    op->walk([&](Operation* walked_op) {
      if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(
              walked_op))
        return WalkResult::advance();
      for (Value v : walked_op->getOperands()) {
        if (auto* defining_op = v.getDefiningOp()) {
          if (!op->isAncestor(defining_op) &&
              tpu_cluster->isAncestor(defining_op) &&
              !HasOutsideCompilationAncestor(defining_op) &&
              !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
            external_values.insert(v);
          }
          continue;
        }
        auto block_arg = v.cast<BlockArgument>();
        if (block_arg.getParentRegion() == op->getParentRegion())
          external_values.insert(v);
      }
      return WalkResult::advance();
    });
  }
  return external_values;
}

// Returns every operand of `cluster_ops` that does not come from an op in
// `cluster_ops`.
llvm::SmallSetVector<Value, 4> GetAllExternalOperands(
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_values;
  for (Operation* op : cluster_ops) {
    op->walk([&](Operation* walked_op) {
      for (Value v : walked_op->getOperands()) {
        Operation* defining_op = v.getDefiningOp();
        if (!defining_op || !cluster_ops.count(defining_op)) {
          external_values.insert(v);
        }
      }
    });
  }
  return external_values;
}

// Returns a SmallSetVector containing all of the operands that need to be
// communicated from device->host.
llvm::SmallSetVector<Value, 4> GetExternalOperands(
    tf_device::ClusterOp tpu_cluster,
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  // If there are any dynamic outputs, get all of the operands which are
  // defined external to `cluster_ops`.
  bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
  if (has_dynamic_outputs) {
    return GetAllExternalOperands(cluster_ops);
  } else {
    return GetStaticExternalOperands(tpu_cluster, cluster_ops);
  }
}

// Gets all outputs that need to be communicated from host->device.
llvm::SmallSetVector<Value, 4> GetExternalOutputs(
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_outputs;
  bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
  for (Operation* op : cluster_ops) {
    for (Operation* user : op->getUsers()) {
      // We skip any operations that are in the same outside compilation
      // cluster that will be moved to the host at the same time since both
      // defining op and user op will be moved to host.
      if (cluster_ops.count(user)) {
        continue;
      }
      // This is pessimistic and in some cases will add extra communication.
      if (!HasOutsideCompilationAncestor(user) || has_dynamic_outputs ||
          HasDynamicOutputs(user)) {
        for (Value v : user->getOperands()) {
          if (v.getDefiningOp() == op) external_outputs.insert(v);
        }
      }
    }
  }
  return external_outputs;
}

// Creates the HostCompute with `inputs` and `outputs`
// using `args_communication_key` and `retvals_communication_key`.
TF::_XlaHostComputeMlirOp CreateHostCompute(
    OpBuilder& builder, Location loc,
    const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
    llvm::StringRef args_communication_key,
    llvm::StringRef retvals_communication_key,
    llvm::StringRef serialized_func_module) {
  llvm::SmallVector<Type, 4> device_output_types;
  for (const auto& output : outputs)
    device_output_types.push_back(output.getType());
  auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
      loc, device_output_types, inputs.getArrayRef(),
      builder.getStringAttr(args_communication_key),
      builder.getStringAttr(retvals_communication_key),
      /*tpu_core=*/builder.getI64IntegerAttr(0),
      /*host_mlir_module=*/builder.getStringAttr(serialized_func_module));
  return host_compute;
}

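// Marks `op` as outside compiled by setting its `_xla_outside_compilation`
// attribute to a temporary placeholder value.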
void MarkOutsideCompiled(Operation* op) {
  op->setAttr(kXlaOutsideCompilationAttr,
              StringAttr::get(op->getContext(), "temp"));
}

// Returns whether an outside compilation cluster should be closed. True when:
// 1. There is a dynamically shaped output consumed by a non-outside compiled
//    op.
// 2. There is no dynamically shaped output.
bool ShouldCloseCluster(llvm::ArrayRef<Value> outputs) {
  bool has_dynamic_output = false;
  for (Value v : outputs) {
    if (TF::CanBeRefined(v.getType())) {
      has_dynamic_output = true;
      for (Operation* user : v.getUsers()) {
        if (!HasOutsideCompilationAncestor(user)) return true;
      }
    }
  }
  return !has_dynamic_output;
}

// Replaces `external_operands` with the results from `recv_at_host`.
// For non-static shapes, only replaces an operand usage if the op is in the
// same region as the insertion point.
// For static shapes, replaces operand usages if the op is in the same region
// as the insertion point or if the op is outside compiled and will be moved
// to the host later.
void ReplaceExternalOperandUsage(
    const llvm::SmallSetVector<Value, 4>& external_operands,
    Operation* recv_at_host, Operation* insertion_point,
    Block* original_op_block) {
  auto replace_operand_usage = [&](OpOperand& operand) {
    if (TF::CanBeRefined(operand.get().getType()) ||
        HasDynamicOutputs(operand.getOwner())) {
      return insertion_point->getParentRegion()->isAncestor(
          operand.getOwner()->getParentRegion());
    }
    return insertion_point->getParentRegion()->isAncestor(
               operand.getOwner()->getParentRegion()) ||
           (HasOutsideCompilationAncestor(operand.getOwner()) &&
            original_op_block == operand.getOwner()->getBlock());
  };
  for (auto result : llvm::zip(external_operands, recv_at_host->getResults())) {
    Value external_operand = std::get<0>(result);
    external_operand.replaceUsesWithIf(std::get<1>(result),
                                       replace_operand_usage);
  }
}

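// Returns true if any of `outputs` has a non-static shape.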
bool HasDynamicOutputs(llvm::ArrayRef<Value> outputs) {
  for (Value v : outputs) {
    if (TF::CanBeRefined(v.getType())) {
      return true;
    }
  }
  return false;
}

// Replaces usages of `external_outputs` which are values returned by outside
// compilation with the corresponding outputs from `host_compute`.
void ReplaceExternalOutputUsage(
    const llvm::SmallSetVector<Value, 4>& external_outputs,
    TF::_XlaHostComputeMlirOp host_compute) {
  bool has_dynamic_outputs = HasDynamicOutputs(external_outputs.getArrayRef());

  auto replace_output_usage = [&](OpOperand& operand) {
    // Don't replace output usages if in host computation (defining op and user
    // in same region).
    bool in_same_region =
        operand.get().getDefiningOp()->getParentRegion()->isAncestor(
            operand.getOwner()->getParentRegion());
    if (has_dynamic_outputs || HasDynamicOutputs(operand.getOwner())) {
      return !in_same_region;
    } else {
      // Don't replace output usages in host computation or for outside
      // compiled ops.
      return !in_same_region &&
             !HasOutsideCompilationAncestor(operand.getOwner());
    }
  };
  for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
    Value external_output = std::get<0>(result);
    external_output.replaceUsesWithIf(std::get<1>(result),
                                      replace_output_usage);
  }
}

// Moves `clustered_ops` to run on host and adds communication ops to transfer
// `external_operands` and `external_outputs` to/from device/host. Inserts
// ops at `insertion_point` and uses `compilation_key` and `device_ordinal`
// when creating comm ops.
void MoveOpsToHost(const llvm::SmallSetVector<Operation*, 4>& clustered_ops,
                   const llvm::SmallSetVector<Value, 4>& external_operands,
                   const llvm::SmallSetVector<Value, 4>& external_outputs,
                   Operation* insertion_point, Value compilation_key,
                   Value device_ordinal, int& communication_key_index) {
  OpBuilder builder(insertion_point);
  Operation& op = *clustered_ops.back();
  std::string args_communication_key =
      llvm::formatv("host_compute_channel_{0}_args", (communication_key_index))
          .str();
  std::string retvals_communication_key =
      llvm::formatv("host_compute_channel_{0}_retvals",
                    (communication_key_index))
          .str();
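  // For example, with `communication_key_index` == 0 the keys above are
  // "host_compute_channel_0_args" and "host_compute_channel_0_retvals".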

  // Use a unique name when sending just the IfRegion predicate. This is
  // for readability and to match the key in the TF2XLA bridge.
  if (clustered_ops.size() == 1 && llvm::isa<TF::IfRegionOp>(op) &&
      external_operands.size() == 1) {
    args_communication_key =
        llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
            .str();
  }

  std::string serialized_func_module;
  if (HasDynamicOutputs(external_outputs.getArrayRef())) {
    FuncOp shape_op = BuildFunction(clustered_ops.getArrayRef(),
                                    external_operands.getArrayRef(),
                                    external_outputs.getArrayRef(), &builder);
    EncapsulateFuncAndSerialize(shape_op, &serialized_func_module);
  }

  builder.setInsertionPoint(&op);
  auto host_compute =
      CreateHostCompute(builder, op.getLoc(), external_operands,
                        external_outputs.getArrayRef(), args_communication_key,
                        retvals_communication_key, serialized_func_module);
  // Insert ops on the host side computation to receive data from device.
  builder.setInsertionPoint(insertion_point);
  llvm::SmallVector<Type, 4> host_operand_types;
  for (const auto& operand : external_operands)
    host_operand_types.push_back(operand.getType());

  Operation* recv_at_host = CreateRecvAtHostOp(
      builder, op.getLoc(), host_operand_types, compilation_key,
      device_ordinal, args_communication_key);
  Block* original_op_block = op.getBlock();
  Operation* after_op = recv_at_host;
  for (Operation* cluster_op : clustered_ops) {
    cluster_op->moveAfter(after_op);
    cluster_op->removeAttr(Identifier::get(kDeviceAttr, op.getContext()));
    after_op = cluster_op;
  }

  if (!external_outputs.empty()) {
    CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
                         compilation_key, device_ordinal,
                         retvals_communication_key);
  }

  if (external_operands.empty()) {
    recv_at_host->erase();
  } else {
    ReplaceExternalOperandUsage(external_operands,
                                /*recv_at_host=*/recv_at_host,
                                /*insertion_point=*/insertion_point,
                                /*original_op_block=*/original_op_block);
  }

  ReplaceExternalOutputUsage(external_outputs, host_compute);

  if (external_operands.empty() && external_outputs.empty()) {
    host_compute.erase();
  } else {
    ++communication_key_index;
  }
}

// Moves outside compiled ops in `src` to `insertion_point` in host
// computation (may be temporarily with `tpu_cluster` but moved in a
// subsequent call to this method). Communication ops are added in both `src`
// and at `insertion_point` using `compilation_key`, `device_ordinal` and
// `communication_key_index`, which is incremented when used. Communication
// ops are added only when needed and at the location needed. There are checks
// to ensure that duplicate communication between device and host is not added.
LogicalResult MoveOpsToHost(tf_device::ClusterOp tpu_cluster, Block* src,
                            Operation* insertion_point, Value compilation_key,
                            Value device_ordinal,
                            int& communication_key_index) {
  // Contains all of the outside compiled operations that should be moved to
  // the host using a single `_XlaHostComputeMlir` op. This should only contain
  // a single op except in the case where some of the input/output shapes are
  // non-static.
  llvm::SmallSetVector<Operation*, 4> clustered_ops;

  for (Operation& op : llvm::make_early_inc_range(*src)) {
    if (HasOutsideCompilationAncestorExclusive(&op) ||
        !op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      continue;

    // We want to move the clustered_ops if the op to be added has all
    // statically shaped operands since we can't ensure that the static shapes
    // have been sent back to the host in all cases. See
    // @static_shapes_sandwiched_outside_compilation MLIR test for an example.
    if (!HasDynamicExternalValues(&op) && !clustered_ops.empty()) {
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      llvm::SmallSetVector<Value, 4> external_outputs =
          GetExternalOutputs(clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }

    clustered_ops.insert(&op);

    // Get the outputs that need to be communicated from host -> device.
    llvm::SmallSetVector<Value, 4> external_outputs =
        GetExternalOutputs(clustered_ops);

    if (ShouldCloseCluster(external_outputs.getArrayRef())) {
      // Get the operands that need to be communicated from device -> host.
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }
  }
  return success();
}

// Decomposes control flow in `tpu_cluster` into device computation and host
// (outside compiled) computation as two separate control flow ops with
// communication between the device/host for data dependencies. Both device
// and host control flow initially remain within `tpu_cluster`, and a
// subsequent call to MoveOpsToHost moves the host side control flow to the
// host launch in tf_device.parallel_execute. Uses `compilation_key`,
// `device_ordinal` and `communication_key_index` when creating communication
// ops.
LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,
                                   Value compilation_key, Value device_ordinal,
                                   int& communication_key_index) {
  auto result = tpu_cluster.GetBody().walk([&](Operation* op) {
    if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
      OpBuilder builder(if_op);
      auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
      if (failed(MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
                               host_if.then_branch().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      if (failed(MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(),
                               host_if.else_branch().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      MarkOutsideCompiled(host_if.getOperation());
    }
    if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
      OpBuilder builder(while_op);
      auto host_while = CloneEmptyWhile(while_op.is_stateless(),
                                        while_op.parallel_iterations(),
                                        while_op.getLoc(), builder);
      const auto condition_send_recv_key =
          llvm::formatv("while_condition_channel_{0}",
                        communication_key_index++)
              .str();
      auto& cond = host_while.cond();
      cond.push_back(new Block);
      auto condition = while_op.cond().front().getTerminator()->getOperand(0);
      builder.setInsertionPoint(while_op.cond().front().getTerminator());
      builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition,
                                          condition_send_recv_key);
      builder.setInsertionPointToEnd(&cond.front());
      auto recv_condition_at_host = CreateRecvAtHostOp(
          builder, while_op.getLoc(), TypeRange{condition.getType()},
          compilation_key, device_ordinal, condition_send_recv_key);
      builder.create<TF::YieldOp>(while_op.getLoc(),
                                  recv_condition_at_host->getResults());

      if (failed(MoveOpsToHost(tpu_cluster, &while_op.cond().front(),
                               recv_condition_at_host, compilation_key,
                               device_ordinal, communication_key_index)))
        return WalkResult::interrupt();
      if (failed(MoveOpsToHost(tpu_cluster, &while_op.body().front(),
                               host_while.body().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      MarkOutsideCompiled(host_while.getOperation());
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted()) return failure();
  return success();
}

// Removes outside compilation from all ops inside `host_launch_op`. Should
// only be run after all outside compiled ops have been moved to
// `host_launch_op`.
void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) {
  host_launch_op.GetBody().walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      op->removeAttr(
          Identifier::get(kXlaOutsideCompilationAttr, op->getContext()));
    }
  });
}

// Creates a `parallel_execute` op with a region for host computation and
// a region for `tpu_cluster` computation by extracting outside compiled ops to
// host computation.
LogicalResult CreateParallelExecuteForOutsideCompilation(
    ModuleOp module, tf_device::ClusterOp tpu_cluster,
    llvm::StringRef host_device) {
  OpBuilder builder(tpu_cluster);
  // Create parallel_execute regions, one for the host computation for outside
  // compilation and the second for the original TPU cluster computation.
  const int num_regions = 2;
  auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
      tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
  Block& host_computation_block =
      parallel_execute_op.GetRegionBlockWithIndex(0);
  builder.setInsertionPointToEnd(&host_computation_block);

  // Create a single launch op for all outside compiled ops.
  tf_device::LaunchOp host_launch_op =
      CreateLaunchOpForOutsideCluster(builder, tpu_cluster, host_device);
  builder.setInsertionPoint(host_launch_op.GetBody().getTerminator());
  auto compilation_key_op =
      CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder);
  Value compilation_key = compilation_key_op.program();
  auto device_ordinal_op = builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
      tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type()));
  Value device_ordinal = nullptr;
  if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
    device_ordinal = device_ordinal_op.device_ordinal();
  }

  int communication_key_index = 0;
  // Decompose control flow into device and host control flow when outside
  // compilation is included.
  if (failed(DecomposeControlFlow(tpu_cluster, compilation_key, device_ordinal,
                                  communication_key_index)))
    return failure();

  // Move all outside compiled ops including control flow to host launch.
  if (failed(MoveOpsToHost(tpu_cluster, &tpu_cluster.GetBody(),
                           host_launch_op.GetBody().getTerminator(),
                           compilation_key, device_ordinal,
                           communication_key_index)))
    return failure();

  if (communication_key_index == 0) compilation_key_op.erase();
  if (communication_key_index == 0 || device_ordinal == nullptr)
    device_ordinal_op.erase();

  RemoveOutsideCompilation(host_launch_op);

  builder.setInsertionPointToEnd(&host_computation_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), ArrayRef<Value>{});

  // Move the TPU cluster into the second parallel_execute block, before its
  // terminator.
  Block& parallel_execute_tpu_block =
      parallel_execute_op.GetRegionBlockWithIndex(1);
  builder.setInsertionPointToEnd(&parallel_execute_tpu_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
                                      tpu_cluster.getResults());
  tpu_cluster.getOperation()->moveBefore(
      parallel_execute_tpu_block.getTerminator());

  // Remap cluster results with parallel_execute results if user is outside of
  // parallel_execute.
  for (auto result :
       llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
    Value tpu_cluster_result = std::get<0>(result);
    Value parallel_execute_result = std::get<1>(result);
    for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
      if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
        use.set(parallel_execute_result);
  }
  return success();
}

void TPUExtractOutsideCompilation::runOnOperation() {
  // Get runtime devices information from the closest parent module.
  auto module = getOperation();
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
    return signalPassFailure();

  module.walk([&](tf_device::ClusterOp tpu_cluster) {
    if (HasOutsideCompilationNested(tpu_cluster.getOperation())) {
      std::string host_device;
      if (tensorflow::HasModelParallelism(tpu_cluster)) {
        tpu_cluster.emitOpError(
            "outside compilation is not supported with model parallelism.");
        return signalPassFailure();
      }
      if (failed(tensorflow::GetHostDeviceOutsideComputation(
              devices, tpu_cluster, &host_device)))
        return signalPassFailure();
      if (failed(CreateParallelExecuteForOutsideCompilation(module, tpu_cluster,
                                                            host_device)))
        return signalPassFailure();
    }
  });
  // Remove the `_xla_outside_compilation` attribute from all ops. These ops
  // will be outside of the device cluster. The `_xla_outside_compilation`
  // attribute on ops outside of tf_device.cluster doesn't have any meaning and
  // can lead to errors later on. These ops were likely lifted out of the
  // tf_device.cluster in an earlier pass.
  module.walk(
      [](Operation* op) { op->removeAttr("_xla_outside_compilation"); });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() {
  return std::make_unique<TPUExtractOutsideCompilation>();
}
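
// A minimal usage sketch (assuming a standard mlir::PassManager setup; this
// pass is normally scheduled as part of the TF/XLA bridge pipeline):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::TFTPU::CreateTPUExtractOutsideCompilationPass());
//   if (failed(pm.run(module))) { /* handle pass failure */ }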

}  // namespace TFTPU
}  // namespace mlir