/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

// NOLINTNEXTLINE
static llvm::cl::opt<bool> tpu_compile_metadata_debug(
    "tpu_compile_metadata_debug",
    llvm::cl::desc("Serialize TPUCompileMetadataProto metadata in "
                   "'tf._TPUCompileMlir' op as a proto debug string"));

constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kDeviceAttr[] = "device";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";
constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning";

constexpr char kBadStringArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not a string";
constexpr char kBadArrayElementMsg[] =
    "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
constexpr char kBadArrayAttrLengthMsg[] =
    "bad '{0}' attribute, expected array attribute of size {1}, got size {2}";

namespace {
struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
  void runOnOperation() override;
};

// Creates a missing attribute error message.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  return llvm::formatv("requires attribute '{0}'", attribute).str();
}

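// Clones `entry_func` and every function it transitively references into a
// fresh module, renames the entry function clone to "main", and serializes
// that module into `serialized_func_module`.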
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
                                          std::string* serialized_func_module) {
  ModuleOp module = entry_func->getParentOfType<ModuleOp>();
  SymbolTable entry_module_table(module);
  llvm::SmallVector<FuncOp, 4> referenced({entry_func});

  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  auto parent_module = entry_func->getParentOfType<ModuleOp>();
  auto versions_attr = parent_module->getAttr(kVersionsAttr);
  if (!versions_attr)
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
  SymbolTable symbol_table(module_for_func.get());

  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (symbol_table.lookup<FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
    // all found FuncOps to new_module to make sure new_module is
    // self-contained.
    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
    assert(uses && "expected to be able to collect symbol uses");
    for (SymbolTable::SymbolUse use : *uses) {
      FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
          use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());

      // Skip symbols that do not map to a function.
      if (!referenced_func) continue;

      referenced.emplace_back(referenced_func);
    }

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // We can simply change the name of the TPU program's main function
      // because there should be no other reference to it.
      clone.setName("main");
      clone.setPublic();
    } else {
      clone.setPrivate();
    }
    symbol_table.insert(clone);
  }

  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
  return success();
}

// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoStepMarkerLocation(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto step_marker_location =
      op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          std::string(step_marker_location.getValue()), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));

  metadata->set_step_marker_location(location);

  return success();
}

// Parses an xla::OpSharding from a string attribute.
LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name,
                            int index, xla::OpSharding* sharding) {
  auto sharding_str = attr.dyn_cast<StringAttr>();
  if (!sharding_str)
    return op->emitOpError(
        llvm::formatv(kBadStringArrayElementMsg, name, index));

  if (!sharding->ParseFromString(sharding_str.getValue().str()))
    return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index,
                                         sharding_str.getValue(),
                                         "xla::OpSharding"));

  return success();
}

// Populates a TPUCompileMetadataProto with argument types and sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoArgs(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto input_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
  if (!input_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));

  if (input_shardings.size() != op.getNumOperands())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr,
                      op.getNumOperands(), input_shardings.size()));

  // Set args metadata in proto.
  mlir::Identifier replication_attr_name = mlir::Identifier::get(
      "mhlo.is_same_data_across_replicas", op.getContext());
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int index = operand_type_and_idx.index();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        index, status.error_message()));

    arg->set_dtype(dtype);
    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes.
    *arg->mutable_shape() = tensorflow::TensorShapeProto();
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    if (failed(SetOpSharding(op, input_shardings.getValue()[index],
                             tensorflow::kInputShardingAttr, index,
                             arg->mutable_sharding())))
      return failure();

    // Populate set_is_same_data_across_replicas.
    // Note: this information is duplicated and can be removed from the proto
    // and here once MLIR bridge phase 2 doesn't fall back to the old bridge.
    mlir::UnitAttr attr = op.getFunc().getArgAttrOfType<mlir::UnitAttr>(
        index, replication_attr_name);
    arg->set_is_same_data_across_replicas(attr != nullptr);
  }

  return success();
}

// Populates a TPUCompileMetadataProto with result sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoRetvals(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto output_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
  if (!output_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));

  if (output_shardings.size() != op.getNumResults())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr,
                      op.getNumResults(), output_shardings.size()));

  // Set retvals metadata in proto.
  for (auto output_sharding_and_idx : llvm::enumerate(output_shardings))
    if (failed(SetOpSharding(op, output_sharding_and_idx.value(),
                             tensorflow::kOutputShardingAttr,
                             output_sharding_and_idx.index(),
                             metadata->add_retvals()->mutable_sharding())))
      return failure();

  return success();
}

// Populates a TPUCompileMetadataProto from attributes of a
// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the
// op, a failure will be returned.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromClusterFuncOp(
    tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  metadata->set_num_replicas(num_replicas);
  metadata->set_num_cores_per_replica(num_cores_per_replica);

  if (failed(SetMetadataProtoStepMarkerLocation(op, metadata)))
    return failure();

  if (xla_device_assignment.hasValue())
    *metadata->mutable_device_assignment() =
        std::move(xla_device_assignment.getValue());
  auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
  if (!use_spmd_attr)
    return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
  metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());

  if (failed(SetMetadataProtoArgs(op, metadata))) return failure();

  return SetMetadataProtoRetvals(op, metadata);
}

// Wraps single op in `tf_device.launch` for explicit device assignment.
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
                                   Operation* op, llvm::StringRef device) {
  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();

  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside cluster.
  op->moveBefore(launch.GetBody().getTerminator());

  builder->restoreInsertionPoint(insert_point);

  return launch;
}

// Creates a `tf._TPUCompileMlir` op that contains an MLIR module functionally
// equivalent to the function referenced by `cluster_func`.
Operation* BuildCompileOp(
    tf_device::ClusterFuncOp cluster_func, int num_replicas,
    int num_cores_per_replica, llvm::StringRef compilation_device,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    OpBuilder* builder) {
  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromClusterFuncOp(
          cluster_func, num_replicas, num_cores_per_replica,
          std::move(xla_device_assignment), &metadata)))
    return nullptr;

  // Build a shape op for each input to cluster_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value, 4> compile_op_operands;
  compile_op_operands.reserve(cluster_func.getNumOperands());

  for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        cluster_func.getLoc(),
        RankedTensorType::get({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }

  FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      func_attr.getValue());

  std::string txt_module;
  if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;

  auto compilation_status_type =
      RankedTensorType::get({}, builder->getType<TF::StringType>());
  auto program_type =
      RankedTensorType::get({3}, builder->getType<TF::StringType>());

  // Add MLIR module's fingerprint to compile metadata.
  uint64_t mlir_fingerprint = tensorflow::Fingerprint64(txt_module);
  metadata.set_mlir_fingerprint(mlir_fingerprint);

  std::string txt_metadata;
  if (tpu_compile_metadata_debug)
    txt_metadata = metadata.DebugString();
  else
    metadata.SerializeToString(&txt_metadata);

  auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
      cluster_func.getLoc(),
      /*compilation_status=*/compilation_status_type, /*program=*/
      llvm::SmallVector<Type, 8>(num_cores_per_replica, program_type),
      compile_op_operands, txt_module, txt_metadata);

  return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
                        compilation_device);
}

// Assigns explicit devices to a replicate op. An aliased device is created per
// core, and all replica devices per core are grouped together.
void AssignDevicesToReplicate(
    tf_device::ReplicateOp replicate,
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    OpBuilder* builder) {
  if (!replicate) return;

  const int num_replicas = tpu_devices.size();
  const int num_cores_per_replica = tpu_devices.front().size();

  llvm::SmallVector<NamedAttribute, 8> device_attrs;
  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<StringRef, 8> devices_by_core;
    devices_by_core.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      devices_by_core.push_back(tpu_devices[replica][core].device);

    device_attrs.push_back(
        builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
                              builder->getStrArrayAttr(devices_by_core)));
  }

  // For data parallelism, also add replicated host devices, as these are
  // necessary for outside compilation.
  if (num_cores_per_replica == 1) {
    llvm::SmallVector<StringRef, 8> hosts;
    hosts.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      hosts.push_back(tpu_devices[replica][0].host);

    device_attrs.push_back(builder->getNamedAttr(
        tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
  }

  replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
}

// Creates a `tf.TPUExecute` op that executes a TPU program.
LogicalResult BuildExecuteOp(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::ArrayRef<Value> inputs, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, TF::TPUExecuteOp* execute_op) {
  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.
  llvm::SmallVector<Type, 4> output_types;
  auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
      core_id, output_sharding_config, cluster_func, &output_types);
  if (failed(result)) return failure();

  // TPUExecute has the same output types as cluster_func.
  *execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
                                                  output_types, inputs);
  return success();
}

// Creates a tf_device.parallel_execute op that wraps a TPUExecute op per
// logical core to represent execution of a TPU program across multiple logical
// cores.
LogicalResult BuildParallelExecuteOp(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
  const int num_cores_per_replica = tpu_devices.front().size();
  // parallel_execute op returns a concatenated list of return values of
  // all its regions.
  //
  // TODO(b/149102702): Correctly map inputs to parallel_execute op via
  // identifying xla_sharding op in the cluster_func function.
  const auto cluster_result_types = cluster_func.getResultTypes();
  llvm::SmallVector<Type, 8> concatenated_output_types;
  concatenated_output_types.reserve(cluster_result_types.size() *
                                    num_cores_per_replica);

  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<Type, 4> output_types;
    auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
        core, output_sharding_config, cluster_func, &output_types);
    if (failed(result)) return failure();

    for (Type t : output_types) concatenated_output_types.emplace_back(t);
  }

  *parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
      cluster_func.getLoc(), num_cores_per_replica, concatenated_output_types);

  // Extract inputs for each region of the parallel_execute op. The i-th
  // element in the list represents the input list to the TPU computation for
  // the i-th logical core.
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
  builder->setInsertionPoint(*parallel_execute_op);
  auto result = tensorflow::ExtractInputsForLogicalDevices(
      num_cores_per_replica, cluster_func, builder, &input_list);
  if (failed(result)) return failure();

  const bool replicated = tpu_devices.size() != 1;
  // For each logical core, create a region with a TPUExecute op.
  assert(input_list.size() == num_cores_per_replica);
  for (int core = 0; core < num_cores_per_replica; ++core) {
    auto& region = parallel_execute_op->GetRegionBlockWithIndex(core);
    builder->setInsertionPointToEnd(&region);

    // Create Execute op.
    //
    // TODO(b/148913294): Identify inputs/return values specific to each
    // logical core TPU execution by parsing xla_sharding op in
    // cluster_func.
    auto execute_inputs = input_list[core];
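    // Result 0 of the compile op is the compilation status and results
    // 1..num_cores_per_replica are the per-core compiled programs, so the
    // program for logical core `core` is result `core + 1`.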
    execute_inputs.emplace_back(compile_op->getResult(core + 1));

    TF::TPUExecuteOp execute;
    result = BuildExecuteOp(core, output_sharding_config, execute_inputs,
                            cluster_func, builder, &execute);
    if (failed(result)) return failure();

    // If the computation is replicated, use the aliased device. Otherwise
    // there is only one execution device per core and the device is assigned
    // to the execute op.
    std::string device = replicated
                             ? tensorflow::GetDeviceAliasForLogicalCore(core)
                             : tpu_devices.front()[core].device;

    auto region_launch_op =
        WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);

    builder->create<tf_device::ReturnOp>(region.getParent()->getLoc(),
                                         region_launch_op.getResults());
  }

  return success();
}

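// Wraps `execute_op` in a `tf_device.launch`, assigning it the logical core 0
// device alias when the computation is replicated and the single execution
// device otherwise.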
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    Operation* execute_op, OpBuilder* builder) {
  const bool replicated = tpu_devices.size() != 1;
  // If computation is replicated, use aliased device. Otherwise there is only
  // one execution device and the device is assigned to the execute op.
  std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
                                  : tpu_devices.front().front().device;

  return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}

// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
// status of `compile_op` to check whether compilation is successful.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      Operation* result_id,
                                      llvm::StringRef compilation_device,
                                      OpBuilder* builder) {
  auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
      compile_op->getLoc(), result_id->getResult(0));
  WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}

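// Rewrites a single `tf_device.cluster_func` into a compile op, a compile
// succeeded assert op, and either a single `tf.TPUExecute` op or a
// `tf_device.parallel_execute` of per-core `tf.TPUExecute` ops when the
// computation is model parallel. Any associated TPUCompilationResultOp is
// rerouted to the compilation status result.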
LogicalResult Rewrite(
    tf_device::ClusterFuncOp cluster_func,
    llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
    ArrayRef<TF::TPUCompilationResultOp> compilation_result,
    OpBuilder* builder) {
  // Collect `num_replicas` and `num_cores_per_replica` attributes.
  int num_replicas = 1;
  tf_device::ReplicateOp replicate =
      cluster_func->getParentOfType<tf_device::ReplicateOp>();
  if (replicate) num_replicas = replicate.n();

  auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));

  int num_cores_per_replica = num_cores_per_replica_attr.getInt();

  auto topology_attr =
      cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kTopologyAttr));

  auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster_func.emitOpError(
        llvm::formatv("requires attribute '{0}'",
                      tensorflow::kDeviceAssignmentAttr)
            .str());

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);
  if (!status_or_device_coordinates.ok())
    return cluster_func.emitError()
           << "error in fetching tpu device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices, num_replicas, num_cores_per_replica,
          topology_attr.getValue(),
          status_or_device_coordinates.ConsumeValueOrDie());
  if (!status_or_tpu_device_assignment.ok())
    return cluster_func.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();

  // Create compile op.
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
  builder->setInsertionPoint(cluster_func);

  // Create the TPUCompileMlir and TPUCompileSucceededAssert ops outside of
  // the parallel_execute region if it exists.
  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
    // Currently, outside compilation and model parallelism are not supported
    // together.
    assert(num_cores_per_replica == 1);
    builder->setInsertionPoint(cluster_func->getParentOp());
  }

  Operation* compile_op = BuildCompileOp(
      cluster_func, num_replicas, num_cores_per_replica,
      tpu_device_assignment.compilation_device,
      std::move(tpu_device_assignment.xla_device_assignment), builder);
  if (!compile_op) return failure();

  // This replaces the _TPUCompileMlir placeholder ops that are required by
  // the XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
  // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp
  // and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
  // structured lowering.
  if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          cluster_func->getParentOp())) {
    parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
      key_op.replaceAllUsesWith(compile_op->getResult(1));
      key_op.erase();
    });
  }

  // After the rewrite, if there is a TPUCompilationResultOp from the same
  // cluster, replace it with the result of the compile op. The
  // TPUCompilationResultOp is used as a placeholder during graph creation to
  // hook up the other ops that are intended to consume the compile result.
  Operation* result_id = compile_op;
  // TODO(jpienaar): Remove this later.
  auto compile_device_op = compile_op->getAttr("device");
  for (auto res : compilation_result) {
    // Build an identity op with the same location/name as the original
    // compilation result op.
    result_id = builder->create<TF::IdentityOp>(
        res.getLoc(), compile_op->getResult(0).getType(),
        result_id->getResult(0));
    // Assign to the same device the result op currently has set; if unset,
    // assign to the device on which compilation will happen.
    // TODO(jpienaar): Remove this later.
    if (auto device = res->getAttrOfType<StringAttr>("device")) {
      if (!device.getValue().empty())
        result_id->setAttr("device", device);
      else
        result_id->setAttr("device", compile_device_op);
    } else if (compile_device_op) {
      result_id->setAttr("device", compile_device_op);
    }
    res.output().replaceAllUsesWith(compile_op->getResult(0));
  }

  BuildTPUCompileSucceededAssertOp(
      compile_op, result_id, tpu_device_assignment.compilation_device, builder);

  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                           builder);

  llvm::SmallVector<xla::OpSharding, 4> output_shardings;
  auto result = tensorflow::ParseAndValidateOutputSharding(
      num_cores_per_replica, cluster_func, &output_shardings);
  if (failed(result)) return failure();

  builder->setInsertionPoint(cluster_func);
  if (num_cores_per_replica > 1) {
    // For model parallelism, tf_device.parallel_execute is used to express
    // concurrent device execution across multiple logical devices.

    tf_device::ParallelExecuteOp execute_op;
    result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
                                    output_shardings, compile_op, cluster_func,
                                    builder, &execute_op);
    if (failed(result)) return failure();

    // As tf_device.parallel_execute wraps one TPUExecute op per logical core,
    // the number of return values of the parallel_execute op exceeds that of
    // the cluster_func op, so each return value of the parallel_execute op
    // must be mapped to the corresponding return value usages of cluster_func.
    return tensorflow::RemapOutputsFromLogicalDevices(
        cluster_func.getLoc(), output_shardings, cluster_func, execute_op,
        builder);
  }

  llvm::SmallVector<Value, 4> execute_inputs(cluster_func.getOperands());
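  // The single compiled program is result 1 of the compile op (result 0 is the
  // compilation status); append it as the final TPUExecute operand.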
  execute_inputs.emplace_back(compile_op->getResult(1));

  TF::TPUExecuteOp execute_op;
  result = BuildExecuteOp(
      /*core_id=*/0, output_shardings, execute_inputs, cluster_func, builder,
      &execute_op);
  if (failed(result)) return failure();

  tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
      tpu_device_assignment.tpu_devices, execute_op, builder);
  cluster_func.replaceAllUsesWith(launch_op);
  return success();
}

// Erases rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
// TPUPartitionedOutputOp ops are present, they must be removed along with the
// ClusterFuncOp(s).
void EraseClusterFuncs(
    llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
  for (auto cluster : to_be_erased) {
    for (auto result : cluster.results()) {
      for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
        if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
          assert(user->use_empty());
          user->erase();
        }
      }
    }

    for (auto operand : cluster.operands()) {
      Operation* def = operand.getDefiningOp();
      if (operand.hasOneUse() &&
          llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
        operand.dropAllUses();
        def->erase();
      }
    }

    assert(cluster->use_empty());
    cluster->erase();
  }
}

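// Pass entry point: collects TPUCompilationResultOp ops keyed by cluster id,
// rewrites every TPU `tf_device.cluster_func`, and erases the rewritten
// cluster funcs along with the consumed compilation result ops.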
void TPURewritePass::runOnOperation() {
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
    return signalPassFailure();

  // Collect compilation results.
  llvm::DenseMap<Attribute, SmallVector<TF::TPUCompilationResultOp, 1>>
      compilation_results;
  auto result_init = getOperation().walk([&](TF::TPUCompilationResultOp op) {
    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_compilation_status");
    if (!cluster_id) {
      op->emitOpError("missing '_tpu_compilation_status'");
      return WalkResult::interrupt();
    }
    compilation_results[cluster_id].push_back(op);
    return WalkResult::advance();
  });
  if (result_init.wasInterrupted()) return signalPassFailure();

  llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
  OpBuilder builder(&getContext());
  auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
    // Skip non-tpu device cluster_func.
    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_replicate");
    if (!cluster_id) return WalkResult::advance();

    if (failed(Rewrite(op, devices.device_names(),
                       compilation_results[cluster_id], &builder)))
      return WalkResult::interrupt();

    to_be_erased.push_back(op);
    return WalkResult::advance();
  });
  if (result.wasInterrupted()) return signalPassFailure();

  EraseClusterFuncs(to_be_erased);

  // Eliminate TPUCompilationResultOp now that the rewrite is complete.
  for (auto& it : compilation_results) {
    for (auto op : it.second) {
      if (!op.use_empty()) {
        mlir::InFlightDiagnostic err = op.emitError("uses remain post rewrite");
        for (auto user : op->getUsers())
          err.attachNote(user->getLoc()) << "remaining user";
        return signalPassFailure();
      }
      op.erase();
    }
  }

  // TODO(b/139377366): Remove functions that are no longer needed.
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
  return std::make_unique<TPURewritePass>();
}

}  // namespace TFTPU
}  // namespace mlir