/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";
constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning";

constexpr char kBadStringArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not a string";
constexpr char kBadArrayElementMsg[] =
    "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
constexpr char kBadArrayAttrLengthMsg[] =
    "bad '{0}' attribute, expected array attribute of size {1}, got size {2}";

namespace {
struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
  void runOnOperation() override;
};

// Creates a missing attribute error message.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  return llvm::formatv("requires attribute '{0}'", attribute).str();
}

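// Clones `entry_func` and all functions it transitively references into a
// fresh module and serializes that module to `serialized_func_module`. The
// cloned entry function is renamed to `main` so the serialized module can be
// compiled as a standalone TPU program.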
LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func,
                                          std::string* serialized_func_module) {
  ModuleOp module = entry_func->getParentOfType<ModuleOp>();
  SymbolTable entry_module_table(module);
  llvm::SmallVector<func::FuncOp, 4> referenced({entry_func});

  // Create a new module to hold func and all referenced functions.
  OwningOpRef<mlir::ModuleOp> module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  auto parent_module = entry_func->getParentOfType<ModuleOp>();
  auto versions_attr = parent_module->getAttr(kVersionsAttr);
  if (!versions_attr)
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
  SymbolTable symbol_table(module_for_func.get());

  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (symbol_table.lookup<func::FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
    // all found FuncOps to the new module to make sure it is self-contained.
    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
    assert(uses && "expected to be able to collect symbol uses");
    for (SymbolTable::SymbolUse use : *uses) {
      func::FuncOp referenced_func = entry_module_table.lookup<func::FuncOp>(
          use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());

      // Skip symbols that do not map to a function.
      if (!referenced_func) continue;

      referenced.emplace_back(referenced_func);
    }

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // We can simply change the name of the TPU program's main function
      // because there should be no other references to it.
      clone.setName(StringAttr::get(clone.getContext(), "main"));
      clone.setPublic();
    } else {
      clone.setPrivate();
    }
    symbol_table.insert(clone);
  }

  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
  return success();
}

// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoStepMarkerLocation(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto step_marker_location =
      op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          std::string(step_marker_location.getValue()), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));

  metadata->set_step_marker_location(location);

  return success();
}

// Parses an xla::OpSharding from a string attribute.
LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name,
                            int index, xla::OpSharding* sharding) {
  auto sharding_str = attr.dyn_cast<StringAttr>();
  if (!sharding_str)
    return op->emitOpError(
        llvm::formatv(kBadStringArrayElementMsg, name, index));

  if (!sharding->ParseFromString(sharding_str.getValue().str()))
    return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index,
                                         sharding_str.getValue(),
                                         "xla::OpSharding"));

  return success();
}

// Populates a TPUCompileMetadataProto with argument types and sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoArgs(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto input_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
  if (!input_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));

  if (input_shardings.size() != op.getNumOperands())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr,
                      op.getNumOperands(), input_shardings.size()));

  // Set args metadata in proto.
  mlir::StringAttr replication_attr_name = mlir::StringAttr::get(
      op.getContext(), "mhlo.is_same_data_across_replicas");
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int index = operand_type_and_idx.index();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        index, status.error_message()));

    arg->set_dtype(dtype);
    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes.
    *arg->mutable_shape() = tensorflow::TensorShapeProto();
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    if (failed(SetOpSharding(op, input_shardings.getValue()[index],
                             tensorflow::kInputShardingAttr, index,
                             arg->mutable_sharding())))
      return failure();

    // Populate is_same_data_across_replicas.
    // Note: this information is duplicated and can be removed from the proto
    // and here once MLIR bridge phase 2 doesn't fallback to the old bridge.
    mlir::UnitAttr attr = op.getFunc().getArgAttrOfType<mlir::UnitAttr>(
        index, replication_attr_name);
    arg->set_is_same_data_across_replicas(attr != nullptr);
  }

  return success();
}

// Populates a TPUCompileMetadataProto with result sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoRetvals(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto output_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
  if (!output_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));

  if (output_shardings.size() != op.getNumResults())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr,
                      op.getNumResults(), output_shardings.size()));

  // Set retvals metadata in proto.
  for (auto output_sharding_and_idx : llvm::enumerate(output_shardings))
    if (failed(SetOpSharding(op, output_sharding_and_idx.value(),
                             tensorflow::kOutputShardingAttr,
                             output_sharding_and_idx.index(),
                             metadata->add_retvals()->mutable_sharding())))
      return failure();

  return success();
}

// Populates a TPUCompileMetadataProto from attributes of a
// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the
// op, a failure will be returned.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromClusterFuncOp(
    tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
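  // If the op carries serialized compile options, parse them into the metadata
  // proto; a parse failure aborts the rewrite.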
  if (auto options_attr =
          op->getAttrOfType<StringAttr>("tpu_compile_options_proto")) {
    if (!metadata->mutable_compile_options()->ParseFromArray(
            options_attr.data(), options_attr.size())) {
      return failure();
    }
  }
  metadata->set_num_replicas(num_replicas);
  metadata->set_num_cores_per_replica(num_cores_per_replica);

  if (failed(SetMetadataProtoStepMarkerLocation(op, metadata)))
    return failure();

  if (xla_device_assignment.has_value())
    *metadata->mutable_device_assignment() =
        std::move(xla_device_assignment.getValue());
  auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
  if (!use_spmd_attr)
    return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
  metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());

  if (failed(SetMetadataProtoArgs(op, metadata))) return failure();

  return SetMetadataProtoRetvals(op, metadata);
}

// Wraps a single op in a `tf_device.launch` for explicit device assignment.
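// The op is moved into the launch body and its results are forwarded through a
// `tf_device.return`, producing IR roughly of the form (illustrative sketch,
// op and device names are made up):
//
//   %0 = "tf_device.launch"() ({
//     %1 = "tf.SomeOp"() : () -> tensor<i32>
//     tf_device.return %1 : tensor<i32>
//   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
//       : () -> tensor<i32>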
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
                                   Operation* op, llvm::StringRef device) {
  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();

  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside cluster.
  op->moveBefore(launch.GetBody().getTerminator());

  builder->restoreInsertionPoint(insert_point);

  return launch;
}

// Creates a `tf._TPUCompileMlir` op that contains an MLIR module functionally
// equivalent to the function referenced by `cluster_func`.
Operation* BuildCompileOp(
    tf_device::ClusterFuncOp cluster_func, int num_replicas,
    int num_cores_per_replica, llvm::StringRef compilation_device,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    OpBuilder* builder, bool tpu_compile_metadata_debug) {
  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromClusterFuncOp(
          cluster_func, num_replicas, num_cores_per_replica,
          std::move(xla_device_assignment), &metadata)))
    return nullptr;

  // Build a shape op for each input to cluster_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value, 4> compile_op_operands;
  compile_op_operands.reserve(cluster_func.getNumOperands());

  for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        cluster_func.getLoc(),
        RankedTensorType::get({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }

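  // Look up the function the cluster_func refers to; it is cloned and
  // serialized into the compile op below.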
  FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
  func::FuncOp func =
      cluster_func->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(
          func_attr.getValue());

  std::string txt_module;
  if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;

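  // The compile op returns a scalar string compilation status plus, for each
  // core, a program handle modeled as a rank-1 tensor of 3 strings.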
  auto compilation_status_type =
      RankedTensorType::get({}, builder->getType<TF::StringType>());
  auto program_type =
      RankedTensorType::get({3}, builder->getType<TF::StringType>());

  // Add MLIR module's fingerprint to compile metadata.
  uint64_t mlir_fingerprint = tensorflow::Fingerprint64(txt_module);
  metadata.set_mlir_fingerprint(mlir_fingerprint);

  std::string txt_metadata;
  if (tpu_compile_metadata_debug) {
    ::tensorflow::protobuf::TextFormat::Printer printer;
    printer.SetExpandAny(true);
    printer.PrintToString(metadata, &txt_metadata);
  } else {
    metadata.SerializeToString(&txt_metadata);
  }

  auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
      cluster_func.getLoc(),
      /*compilation_status=*/compilation_status_type, /*program=*/
      llvm::SmallVector<Type, 8>(num_cores_per_replica, program_type),
      compile_op_operands, txt_module, txt_metadata);

  return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
                        compilation_device);
}

// Assigns explicit devices to a replicate op. An aliased device is created per
// core, and all replica devices per core are grouped together.
void AssignDevicesToReplicate(
    tf_device::ReplicateOp replicate,
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    OpBuilder* builder) {
  if (!replicate) return;

  const int num_replicas = tpu_devices.size();
  const int num_cores_per_replica = tpu_devices.front().size();

  llvm::SmallVector<NamedAttribute, 8> device_attrs;
  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<StringRef, 8> devices_by_core;
    devices_by_core.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      devices_by_core.push_back(tpu_devices[replica][core].device);

    device_attrs.push_back(
        builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
                              builder->getStrArrayAttr(devices_by_core)));
  }

  // For data parallelism, also add replicated host devices, as these are
  // necessary for outside compilation.
  llvm::SmallVector<StringRef, 8> hosts;
  hosts.reserve(num_replicas);
  for (int replica = 0; replica < num_replicas; ++replica)
    hosts.push_back(tpu_devices[replica][0].host);

  device_attrs.push_back(builder->getNamedAttr(
      tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));

  replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
}

// Creates a `tf.TPUExecute` op that executes a TPU program.
LogicalResult BuildExecuteOp(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::ArrayRef<Value> inputs, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, TF::TPUExecuteOp* execute_op) {
  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.
  llvm::SmallVector<Type, 4> output_types;
  llvm::SmallVector<int, 4> cluster_to_core_index;
  auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
      core_id, output_sharding_config, cluster_func, &output_types,
      &cluster_to_core_index);
  if (failed(result)) return failure();

  // TPUExecute has same output types as cluster_func.
  *execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
                                                  output_types, inputs);
  auto producer_name_attr = cluster_func->getAttr("_producer_name");
  if (producer_name_attr)
    (*execute_op)->setAttr("_producer_name", producer_name_attr);
  return success();
}

// Moves the children of the ParallelExecute that do not change. These are all
// children except for the one containing the ClusterFunc.
//
// Returns the index of the child with the ClusterFunc.
int MovePreservedParallelExecuteChildren(
    const int num_cores_per_replica,
    llvm::SmallVector<Type, 8>& concatenated_output_types, OpBuilder* builder,
    tf_device::ClusterFuncOp cluster_func,
    tf_device::ParallelExecuteOp old_parallel_execute,
    tf_device::ParallelExecuteOp* new_parallel_execute) {
  // `num_moved_children` is the number of children that will be preserved.
  const int num_moved_children = old_parallel_execute.regions().size() - 1;
  *new_parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
      old_parallel_execute->getLoc(),
      num_moved_children + num_cores_per_replica, concatenated_output_types);

  // `cluster_idx` is the index of the child with the `ClusterFuncOp`, which
  // will be replaced.
  int cluster_idx = -1;
  for (int child_idx = 0; child_idx < old_parallel_execute.regions().size();
       ++child_idx) {
    auto& block = old_parallel_execute.GetRegionBlockWithIndex(child_idx);
    if (cluster_func->getBlock() == &block) {
      assert(cluster_idx == -1);
      cluster_idx = child_idx;
    }
  }
  assert(cluster_idx != -1);

  // Move children that are preserved.
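  // Children before the ClusterFunc child keep their index; children after it
  // shift by `num_cores_per_replica - 1` to make room for the per-core regions
  // that replace the ClusterFunc child.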
  for (int child_idx = 0; child_idx < num_moved_children; ++child_idx) {
    int old_idx = child_idx >= cluster_idx ? child_idx + 1 : child_idx;
    int new_idx = child_idx >= cluster_idx ? child_idx + num_cores_per_replica
                                           : child_idx;
    new_parallel_execute->getRegions()[new_idx]->takeBody(
        *old_parallel_execute.getRegions()[old_idx]);
  }

  return cluster_idx;
}

// Given a `ParallelExecute`, replaces it with a new `ParallelExecute`. The
// new `ParallelExecute` will replace the child that contains the
// `ClusterFunc` with `num_cores_per_replica` children. It keeps the other
// children the same. Return values from the child with the `ClusterFunc` will
// be duplicated `num_cores_per_replica` times.
LogicalResult AddToParallelExecuteOp(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVectorImpl<llvm::SmallVector<int, 4>>* cluster_to_core_index,
    Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, tf_device::ParallelExecuteOp old_parallel_execute,
    tf_device::ParallelExecuteOp* new_parallel_execute, int* cluster_idx) {
  const int num_cores_per_replica = tpu_devices.front().size();
  // parallel_execute op returns concatenated list of return values of
  // all its regions.
  //
  // TODO(b/149102702): Correctly map inputs to parallel_execute op via
  // identifying xla_sharding op in the cluster_func function.
  const auto cluster_result_types = cluster_func.getResultTypes();
  llvm::SmallVector<Type, 8> concatenated_output_types;
  concatenated_output_types.reserve(cluster_result_types.size() *
                                    num_cores_per_replica);

  for (int core = 0; core < num_cores_per_replica; ++core) {
    cluster_to_core_index->emplace_back(llvm::SmallVector<int, 4>());
    llvm::SmallVector<Type, 4> output_types;
    auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
        core, output_sharding_config, cluster_func, &output_types,
        &(*cluster_to_core_index)[core]);
    if (failed(result)) return failure();

    for (Type t : output_types) concatenated_output_types.emplace_back(t);
  }

  *cluster_idx = MovePreservedParallelExecuteChildren(
      num_cores_per_replica, concatenated_output_types, builder, cluster_func,
      old_parallel_execute, new_parallel_execute);

  // Extract inputs for each block of the parallel_execute op. The i-th
  // element in the list represents the input lists to TPU computation for
  // i-th logical core.
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
  builder->setInsertionPoint(*new_parallel_execute);
  auto result = tensorflow::ExtractInputsForLogicalDevices(
      num_cores_per_replica, cluster_func, builder, &input_list);
  if (failed(result)) return failure();

  const bool replicated = tpu_devices.size() != 1;
  // For each logical core, create a region with TPUExecute op.
  assert(input_list.size() == num_cores_per_replica);
  for (int core = 0; core < num_cores_per_replica; ++core) {
    auto& block =
        new_parallel_execute->GetRegionBlockWithIndex((*cluster_idx) + core);
    builder->setInsertionPointToEnd(&block);

    // Create Execute op.
    //
    // TODO(b/148913294): Identify inputs/return values specific to each
    // logical core TPU execution by parsing xla_sharding op in
    // cluster_func.
    auto execute_inputs = input_list[core];
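    // Result 0 of the compile op is the compilation status; results 1..N are
    // the per-core program handles, so core `core` consumes result `core + 1`.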
    execute_inputs.emplace_back(compile_op->getResult(core + 1));

    TF::TPUExecuteOp execute;
    result = BuildExecuteOp(core, output_sharding_config, execute_inputs,
                            cluster_func, builder, &execute);
    if (failed(result)) return failure();

    // If computation is replicated, use aliased device. Otherwise there is
    // only one execution device per core and the device is assigned to the
    // execute op.
    std::string device = replicated
                             ? tensorflow::GetDeviceAliasForLogicalCore(core)
                             : tpu_devices.front()[core].device;

    auto block_launch_op =
        WrapOpInLaunch(builder, block.getParent()->getLoc(), execute, device);

    builder->create<tf_device::ReturnOp>(block.getParent()->getLoc(),
                                         block_launch_op.getResults());
  }

  return success();
}

tf_device::LaunchOp AssignDevicesToReplicatedExecute(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    Operation* execute_op, OpBuilder* builder) {
  const bool replicated = tpu_devices.size() != 1;
  // If computation is replicated, use aliased device. Otherwise there is only
  // one execution device and the device is assigned to the execute op.
  std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
                                  : tpu_devices.front().front().device;

  return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}

// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
// status of `compile_op` to check whether compilation is successful.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      Operation* result_id,
                                      llvm::StringRef compilation_device,
                                      OpBuilder* builder) {
  auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
      compile_op->getLoc(), result_id->getResult(0));
  WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}

// Wraps the `ClusterFunc` in a `ParallelExecute` with only one child. This
// creates a canonical form, regardless of whether a `ParallelExecute` was
// created in a previous pass.
void BuildSingletonParallelExecuteOp(
    tf_device::ClusterFuncOp cluster_func, OpBuilder* builder,
    tf_device::ParallelExecuteOp* parallel_execute) {
  if (!*parallel_execute) {
    const auto output_types = cluster_func.getResultTypes();
    builder->setInsertionPoint(cluster_func);
    *parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
        cluster_func.getLoc(), 1, output_types);
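    // Move the cluster_func into the single region of the new parallel_execute
    // and forward its results through a tf_device.return.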
    cluster_func->remove();
    auto& block = parallel_execute->GetRegionBlockWithIndex(0);
    builder->setInsertionPointToEnd(&block);
    builder->insert(cluster_func);
    cluster_func.replaceAllUsesWith(*parallel_execute);
    builder->create<tf_device::ReturnOp>(block.getParent()->getLoc(),
                                         cluster_func.getResults());
  }
}

// Unwraps the `ParallelExecute`'s contents if it only has one child.
LogicalResult RemoveSingletonParallelExecuteOp(
    tf_device::ParallelExecuteOp parallel_execute, OpBuilder* builder) {
  if (parallel_execute.regions().size() == 1) {
    builder->setInsertionPoint(parallel_execute);
    auto& block = parallel_execute.GetRegionBlockWithIndex(0);
    llvm::SmallVector<Operation*, 2> ops_move;
    for (Operation& op : block) {
      ops_move.push_back(&op);
    }
    if (ops_move.size() != 2) {
      parallel_execute.emitError() << "Expected 2 ops in parallel_execute.";
      return failure();
    }
    auto launch = llvm::dyn_cast<tf_device::LaunchOp>(ops_move[0]);
    if (!launch) {
      parallel_execute.emitError()
          << "Expected the op in parallel_execute to be a tf_device.launch";
      return failure();
    }
    launch->remove();
    builder->insert(launch);
    parallel_execute.replaceAllUsesWith(launch);
    parallel_execute.erase();
  }
  return success();
}

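// Checks that TPUPartitionedInput and TPUPartitionedOutput ops only appear
// inside the TPU computation: outputs of the parallel_execute must not feed a
// TPUPartitionedOutput, and resource reads feeding the cluster must not come
// from a TPUPartitionedInput.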
LogicalResult CheckTPUPartitionedInputAndOutputAreValid(
    tf_device::ClusterFuncOp cluster,
    tf_device::ParallelExecuteOp parallel_execute) {
  for (auto cluster_result : parallel_execute.execute_outputs()) {
    for (Operation* user :
         llvm::make_early_inc_range(cluster_result.getUsers())) {
      // Check that none of the user's results feed a TPUPartitionedOutput.
      for (auto result : user->getResults()) {
        for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
          if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
            user->emitError() << "Input of TPUPartitionedOutput must "
                              << "be in tpu computation.";
            return failure();
          }
        }
      }
    }
  }
  for (auto cluster_operand : cluster.operands()) {
    Operation* def = cluster_operand.getDefiningOp();
    // This pass assumes that a TPUPartitionedInput is preceded by
    // ReadVariable ops, and not vice versa. An earlier pass,
    // TPUResourceReadsWritesPartitioning, should have ensured this
    // precondition.
    if (!def) continue;
    for (auto operand : def->getOperands()) {
      Operation* def_of_read = operand.getDefiningOp();
      if (llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def_of_read)) {
        def_of_read->emitError() << "Output of TPUPartitionedInput must "
                                 << "be in tpu computation.";
        return failure();
      }
    }
  }
  return success();
}

LogicalResult Rewrite(
    tf_device::ClusterFuncOp cluster_func,
    llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
    ArrayRef<TF::TPUCompilationResultOp> compilation_result, OpBuilder* builder,
    bool tpu_compile_metadata_debug) {
  // Fetch the ParallelExecute parent of `cluster_func`, or create it if it
  // does not exist.
  tf_device::ParallelExecuteOp old_parallel_execute =
      cluster_func->getParentOfType<tf_device::ParallelExecuteOp>();
  if (old_parallel_execute &&
      cluster_func->getParentOp() != old_parallel_execute) {
    cluster_func->emitError() << "The ParallelExecute ancestor of a "
                                 "ClusterFunc must be its direct parent.";
    return failure();
  }
  BuildSingletonParallelExecuteOp(cluster_func, builder, &old_parallel_execute);

  // Check that TPUPartitionedInput and TPUPartitionedOutput ops are used in a
  // valid pattern.
  if (failed(CheckTPUPartitionedInputAndOutputAreValid(cluster_func,
                                                       old_parallel_execute)))
    return failure();

  // Collect `num_replicas` and `num_cores_per_replica` attributes.
  int num_replicas = 1;
  tf_device::ReplicateOp replicate =
      cluster_func->getParentOfType<tf_device::ReplicateOp>();
  if (replicate) num_replicas = replicate.n();

  auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));

  int num_cores_per_replica = num_cores_per_replica_attr.getInt();

  auto topology_attr =
      cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kTopologyAttr));

  auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster_func.emitOpError(
        llvm::formatv("requires attribute '{0}'",
                      tensorflow::kDeviceAssignmentAttr)
            .str());

  auto status_or_device_coodinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);
  if (!status_or_device_coodinates.ok())
    return cluster_func.emitError()
           << "error in fetching tpu device coordinates: "
           << status_or_device_coodinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices, num_replicas, num_cores_per_replica,
          topology_attr.getValue(), status_or_device_coodinates.value());
  if (!status_or_tpu_device_assignment.ok())
    return cluster_func.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();

  // Create compile op.
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();

  // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
  // the parallel_execute.
  builder->setInsertionPoint(old_parallel_execute);
  Operation* compile_op =
      BuildCompileOp(cluster_func, num_replicas, num_cores_per_replica,
                     tpu_device_assignment.compilation_device,
                     std::move(tpu_device_assignment.xla_device_assignment),
                     builder, tpu_compile_metadata_debug);
  if (!compile_op) return failure();

  // This replaces _TPUCompileMlir placeholder ops that are required
  // by XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
  // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp
  // and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
  // structured lowering.
  old_parallel_execute.walk(
      [&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
        key_op.replaceAllUsesWith(compile_op->getResult(1));
        key_op.erase();
      });

  // After rewrite, if there is a TPUCompilationResultOp from the same cluster,
  // replace it with the result of the compile op. The TPUCompilationResultOp
  // is used as a placeholder to hook up, during graph creation, the other ops
  // that are intended to consume the compile result.
  Operation* result_id = compile_op;
  // TODO(jpienaar): Remove this later.
  auto compile_device_op = compile_op->getAttr("device");
  for (auto res : compilation_result) {
    // Build an identity op with the same location/name as the original
    // compilation result op.
    result_id = builder->create<TF::IdentityOp>(
        res.getLoc(), compile_op->getResult(0).getType(),
        result_id->getResult(0));
    // Assign to the same device as the result op, unless that is unset, in
    // which case assign to the device on which compilation will happen.
    // TODO(jpienaar): Remove this later.
    if (auto device = res->getAttrOfType<StringAttr>("device")) {
      if (!device.getValue().empty())
        result_id->setAttr("device", device);
      else
        result_id->setAttr("device", compile_device_op);
    } else if (compile_device_op) {
      result_id->setAttr("device", compile_device_op);
    }
    res.output().replaceAllUsesWith(compile_op->getResult(0));
  }

  BuildTPUCompileSucceededAssertOp(
      compile_op, result_id, tpu_device_assignment.compilation_device, builder);

  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                           builder);

  llvm::SmallVector<xla::OpSharding, 4> output_shardings;
  auto result = tensorflow::ParseAndValidateOutputSharding(
      num_cores_per_replica, cluster_func, &output_shardings);
  if (failed(result)) return failure();

  // For model parallelism, tf_device.parallel_execute is used to express
  // concurrent device execution across multiple logical devices.
  tf_device::ParallelExecuteOp new_parallel_execute;
  int cluster_idx;
  llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index;
  cluster_to_core_index.reserve(num_cores_per_replica);
  result = AddToParallelExecuteOp(
      tpu_device_assignment.tpu_devices, output_shardings,
      &cluster_to_core_index, compile_op, cluster_func, builder,
      old_parallel_execute, &new_parallel_execute, &cluster_idx);
  if (failed(result)) return failure();

  // As tf_device.parallel_execute wraps one TPUExecute op per logical core,
  // the parallel_execute op has more return values than the cluster_func op.
  // As such, each return value of the parallel_execute op must be mapped to
  // the corresponding return value uses of the cluster_func.
  result = tensorflow::RemapOutputsFromLogicalDevices(
      cluster_func.getLoc(), output_shardings, cluster_to_core_index,
      old_parallel_execute, cluster_idx, new_parallel_execute, builder);
  if (failed(result)) return failure();

  return RemoveSingletonParallelExecuteOp(new_parallel_execute, builder);
}

// Erases rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
// TPUPartitionedOutputOp are present, they must be removed along with the
// ClusterFuncOp(s).
void EraseClusterFuncs(
    llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
  for (auto cluster : to_be_erased) {
    auto old_parallel_execute =
        cluster->getParentOfType<tf_device::ParallelExecuteOp>();
    assert(old_parallel_execute);

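    // Erase TPUPartitionedOutput ops consuming the old parallel_execute
    // results; they are expected to have no remaining uses at this point.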
    for (auto result : old_parallel_execute.execute_outputs()) {
      for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
        if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
          assert(user->use_empty());
          user->erase();
        }
      }
    }

    for (auto operand : cluster.operands()) {
      Operation* def = operand.getDefiningOp();
      if (operand.hasOneUse() &&
          llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
        operand.dropAllUses();
        def->erase();
      }
    }
    assert(old_parallel_execute->use_empty());
    old_parallel_execute->erase();
  }
}

void TPURewritePass::runOnOperation() {
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
    return signalPassFailure();

  // Collect compilation results.
  llvm::DenseMap<Attribute, SmallVector<TF::TPUCompilationResultOp, 1>>
      compilation_results;
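  // Group TPUCompilationResultOp ops by their `_tpu_compilation_status`
  // cluster id so each rewritten cluster can update its own compilation
  // results.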
  auto result_init = getOperation().walk([&](TF::TPUCompilationResultOp op) {
    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_compilation_status");
    if (!cluster_id) {
      op->emitOpError("missing '_tpu_compilation_status'");
      return WalkResult::interrupt();
    }
    compilation_results[cluster_id].push_back(op);
    return WalkResult::advance();
  });
  if (result_init.wasInterrupted()) return signalPassFailure();
  llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
  OpBuilder builder(&getContext());
  auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
    if (failed(TF::HasValidCompilationAndReplicationAttributes(*op)))
      return WalkResult::interrupt();
    // Skip non-TPU device cluster_func.
    auto cluster_id = op->getAttrOfType<StringAttr>(TF::kReplicationInfoAttr);
    if (!cluster_id) return WalkResult::advance();

    if (failed(Rewrite(op, devices.device_names(),
                       compilation_results[cluster_id], &builder,
                       tpu_compile_metadata_debug_)))
      return WalkResult::interrupt();

    to_be_erased.push_back(op);
    return WalkResult::advance();
  });
  if (result.wasInterrupted()) return signalPassFailure();

  EraseClusterFuncs(to_be_erased);

  // Eliminate TPUCompilationResultOp now that the rewrite is complete.
  for (auto& it : compilation_results) {
    for (auto op : it.second) {
      if (!op.use_empty()) {
        mlir::InFlightDiagnostic err = op.emitError("uses remain post rewrite");
        for (auto user : op->getUsers())
          err.attachNote(user->getLoc()) << "remaining user";
        return signalPassFailure();
      }
      op.erase();
    }
  }

  // TODO(b/139377366): Remove functions that are no longer needed.
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
  return std::make_unique<TPURewritePass>();
}

}  // namespace TFTPU
}  // namespace mlir