/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

// NOLINTNEXTLINE
static llvm::cl::opt<bool> tpu_compile_metadata_debug(
    "tpu_compile_metadata_debug",
    llvm::cl::desc("Serialize TPUCompileMetadataProto metadata in "
                   "'tf._TPUCompileMlir' op as a proto debug string"));

constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kPaddingMapAttr[] = "padding_map";
constexpr char kDeviceAttr[] = "device";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";
constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning";

constexpr char kBadStringArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not a string";
constexpr char kBadArrayElementMsg[] =
    "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
constexpr char kBadArrayAttrLengthMsg[] =
    "bad '{0}' attribute, expected array attribute of size {1}, got size {2}";

namespace {
struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
  void runOnOperation() override;
};

// Creates a missing attribute error message.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  return llvm::formatv("requires attribute '{0}'", attribute).str();
}

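// Clones the entry function and every function it transitively references into
// a fresh module (carrying over the `tf.versions` attribute), renames the
// entry function to `main`, and serializes the resulting module into
// `serialized_func_module`.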
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
                                          std::string* serialized_func_module) {
  ModuleOp module = entry_func->getParentOfType<ModuleOp>();
  SymbolTable entry_module_table(module);
  llvm::SmallVector<FuncOp, 4> referenced({entry_func});

  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  auto parent_module = entry_func->getParentOfType<ModuleOp>();
  auto versions_attr = parent_module->getAttr(kVersionsAttr);
  if (!versions_attr)
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
  SymbolTable symbol_table(module_for_func.get());

  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (symbol_table.lookup<FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
    // all found FuncOps to the new module to make sure it is self-contained.
    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
    assert(uses && "expected to be able to collect symbol uses");
    for (SymbolTable::SymbolUse use : *uses) {
      FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
          use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());

      // Skip Symbols that do not map to a function.
      if (!referenced_func) continue;

      referenced.emplace_back(referenced_func);
    }

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // We can simply change the name of the TPU program's main function
      // because there should be no other reference to it.
      clone.setName("main");
      clone.setPublic();
    } else {
      clone.setPrivate();
    }
    symbol_table.insert(clone);
  }

  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
  return success();
}

// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoStepMarkerLocation(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto step_marker_location =
      op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          std::string(step_marker_location.getValue()), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));

  metadata->set_step_marker_location(location);

  return success();
}

// Populates a TPUCompileMetadataProto with PaddingMap from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoPaddingMap(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto padding_map = op->getAttrOfType<ArrayAttr>(kPaddingMapAttr);
  if (!padding_map)
    return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));

  for (const auto& padding_and_idx : llvm::enumerate(padding_map)) {
    auto& padding_attr = padding_and_idx.value();
    auto padding_attr_str = padding_attr.dyn_cast<StringAttr>();
    if (!padding_attr_str)
      return op.emitOpError(llvm::formatv(
          kBadStringArrayElementMsg, kPaddingMapAttr, padding_and_idx.index()));

    tensorflow::tpu::PaddingMap* padding =
        metadata->mutable_padding_maps()->Add();
    if (!padding->ParseFromString(std::string(padding_attr_str.getValue())))
      return op.emitOpError(llvm::formatv(
          kBadArrayElementMsg, kPaddingMapAttr, padding_and_idx.index(),
          padding_attr_str.getValue(), "tpu::PaddingMap"));
  }

  return success();
}

// Parses an xla::OpSharding from a string attribute.
LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name,
                            int index, xla::OpSharding* sharding) {
  auto sharding_str = attr.dyn_cast<StringAttr>();
  if (!sharding_str)
    return op->emitOpError(
        llvm::formatv(kBadStringArrayElementMsg, name, index));

  if (!sharding->ParseFromString(sharding_str.getValue().str()))
    return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index,
                                         sharding_str.getValue(),
                                         "xla::OpSharding"));

  return success();
}

// Populates a TPUCompileMetadataProto with argument types and sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoArgs(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto input_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
  if (!input_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));

  if (input_shardings.size() != op.getNumOperands())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr,
                      op.getNumOperands(), input_shardings.size()));

  // Set args metadata in proto.
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int index = operand_type_and_idx.index();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        index, status.error_message()));

    arg->set_dtype(dtype);
    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes.
    *arg->mutable_shape() = tensorflow::TensorShapeProto();
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    if (failed(SetOpSharding(op, input_shardings.getValue()[index],
                             tensorflow::kInputShardingAttr, index,
                             arg->mutable_sharding())))
      return failure();
  }

  return success();
}

// Populates a TPUCompileMetadataProto with result sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoRetvals(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto output_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
  if (!output_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));

  if (output_shardings.size() != op.getNumResults())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr,
                      op.getNumResults(), output_shardings.size()));

  // Set retvals metadata in proto.
  for (auto output_sharding_and_idx : llvm::enumerate(output_shardings))
    if (failed(SetOpSharding(op, output_sharding_and_idx.value(),
                             tensorflow::kOutputShardingAttr,
                             output_sharding_and_idx.index(),
                             metadata->add_retvals()->mutable_sharding())))
      return failure();

  return success();
}

// Populates a TPUCompileMetadataProto from attributes of a
// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the
// op, a failure will be returned.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromClusterFuncOp(
    tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  metadata->set_num_replicas(num_replicas);
  metadata->set_num_cores_per_replica(num_cores_per_replica);

  if (failed(SetMetadataProtoStepMarkerLocation(op, metadata)))
    return failure();

  if (failed(SetMetadataProtoPaddingMap(op, metadata))) return failure();

  if (xla_device_assignment.hasValue())
    *metadata->mutable_device_assignment() =
        std::move(xla_device_assignment.getValue());

  auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
  if (!use_spmd_attr)
    return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
  metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());

  if (failed(SetMetadataProtoArgs(op, metadata))) return failure();

  return SetMetadataProtoRetvals(op, metadata);
}

// Wraps a single op in a `tf_device.launch` for explicit device assignment.
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
                                   Operation* op, llvm::StringRef device) {
  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();

  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside cluster.
  op->moveBefore(launch.GetBody().getTerminator());

  builder->restoreInsertionPoint(insert_point);

  return launch;
}

// Creates a `tf._TPUCompileMlir` op that contains an MLIR module that is
// functionally equivalent to the function referenced by cluster_func.
Operation* BuildCompileOp(
    tf_device::ClusterFuncOp cluster_func, int num_replicas,
    int num_cores_per_replica, llvm::StringRef compilation_device,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    OpBuilder* builder) {
  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromClusterFuncOp(
          cluster_func, num_replicas, num_cores_per_replica,
          std::move(xla_device_assignment), &metadata)))
    return nullptr;

  std::string txt_metadata;
  if (tpu_compile_metadata_debug)
    txt_metadata = metadata.DebugString();
  else
    metadata.SerializeToString(&txt_metadata);

  // Build a shape op for each input to cluster_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value, 4> compile_op_operands;
  compile_op_operands.reserve(cluster_func.getNumOperands());

  for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        cluster_func.getLoc(),
        RankedTensorType::get({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }

  FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      func_attr.getValue());

  std::string txt_module;
  if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;

  auto compilation_status_type =
      RankedTensorType::get({}, builder->getType<TF::StringType>());
  auto program_type =
      RankedTensorType::get({3}, builder->getType<TF::StringType>());

  auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
      cluster_func.getLoc(),
      /*compilation_status=*/compilation_status_type, /*program=*/
      llvm::SmallVector<Type, 8>(num_cores_per_replica, program_type),
      compile_op_operands, txt_module, txt_metadata);

  return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
                        compilation_device);
}

// Assigns explicit devices to a replicate op. An aliased device is created per
// core, and all replica devices per core are grouped together.
void AssignDevicesToReplicate(
    tf_device::ReplicateOp replicate,
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    OpBuilder* builder) {
  if (!replicate) return;

  const int num_replicas = tpu_devices.size();
  const int num_cores_per_replica = tpu_devices.front().size();

  llvm::SmallVector<NamedAttribute, 8> device_attrs;
  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<StringRef, 8> devices_by_core;
    devices_by_core.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      devices_by_core.push_back(tpu_devices[replica][core].device);

    device_attrs.push_back(
        builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
                              builder->getStrArrayAttr(devices_by_core)));
  }

  // For data parallelism, also add replicated host devices, as these are
  // necessary for outside compilation.
  if (num_cores_per_replica == 1) {
    llvm::SmallVector<StringRef, 8> hosts;
    hosts.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      hosts.push_back(tpu_devices[replica][0].host);

    device_attrs.push_back(builder->getNamedAttr(
        tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
  }

  replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
}

// Creates a `tf.TPUExecute` op that executes the TPU program.
LogicalResult BuildExecuteOp(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::ArrayRef<Value> inputs, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, TF::TPUExecuteOp* execute_op) {
  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.
  llvm::SmallVector<Type, 4> output_types;
  auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
      core_id, output_sharding_config, cluster_func, &output_types);
  if (failed(result)) return failure();

  // TPUExecute has the same output types as cluster_func.
  *execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
                                                  output_types, inputs);
  return success();
}

// Creates a tf_device.parallel_execute op that wraps TPUExecute ops to
// represent execution of the TPU program across multiple logical cores.
LogicalResult BuildParallelExecuteOp(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
  const int num_cores_per_replica = tpu_devices.front().size();
  // parallel_execute op returns the concatenated list of return values of
  // all its regions.
  //
  // TODO(b/149102702): Correctly map inputs to parallel_execute op via
  // identifying xla_sharding op in the cluster_func function.
  const auto cluster_result_types = cluster_func.getResultTypes();
  llvm::SmallVector<Type, 8> concatenated_output_types;
  concatenated_output_types.reserve(cluster_result_types.size() *
                                    num_cores_per_replica);

  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<Type, 4> output_types;
    auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
        core, output_sharding_config, cluster_func, &output_types);
    if (failed(result)) return failure();

    for (Type t : output_types) concatenated_output_types.emplace_back(t);
  }

  *parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
      cluster_func.getLoc(), num_cores_per_replica, concatenated_output_types);

  // Extract inputs for each region of the parallel_execute op. The i-th
  // element in the list represents the input list to the TPU computation for
  // the i-th logical core.
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
  builder->setInsertionPoint(*parallel_execute_op);
  auto result = tensorflow::ExtractInputsForLogicalDevices(
      num_cores_per_replica, cluster_func, builder, &input_list);
  if (failed(result)) return failure();

  const bool replicated = tpu_devices.size() != 1;
  // For each logical core, create a region with a TPUExecute op.
  assert(input_list.size() == num_cores_per_replica);
  for (int core = 0; core < num_cores_per_replica; ++core) {
    auto& region = parallel_execute_op->GetRegionBlockWithIndex(core);
    builder->setInsertionPointToEnd(&region);

    // Create Execute op.
    //
    // TODO(b/148913294): Identify inputs/return values specific to each
    // logical core TPU execution by parsing xla_sharding op in
    // cluster_func.
    auto execute_inputs = input_list[core];
    execute_inputs.emplace_back(compile_op->getResult(core + 1));

    TF::TPUExecuteOp execute;
    result = BuildExecuteOp(core, output_sharding_config, execute_inputs,
                            cluster_func, builder, &execute);
    if (failed(result)) return failure();

    // If computation is replicated, use the aliased device. Otherwise there is
    // only one execution device per core and the device is assigned to the
    // execute op.
    std::string device = replicated
                             ? tensorflow::GetDeviceAliasForLogicalCore(core)
                             : tpu_devices.front()[core].device;

    auto region_launch_op =
        WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);

    builder->create<tf_device::ReturnOp>(region.getParent()->getLoc(),
                                         region_launch_op.getResults());
  }

  return success();
}

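// Wraps `execute_op` in a `tf_device.launch` op, assigning it the aliased
// logical core 0 device when the computation is replicated, or the single
// execution device otherwise.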
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    Operation* execute_op, OpBuilder* builder) {
  const bool replicated = tpu_devices.size() != 1;
  // If computation is replicated, use the aliased device. Otherwise there is
  // only one execution device and the device is assigned to the execute op.
  std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
                                  : tpu_devices.front().front().device;

  return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}

// Creates a `tf.TPUCompileSucceededAssert` operation that parses the
// compilation status of `compile_op` to check whether compilation succeeded.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      llvm::StringRef compilation_device,
                                      OpBuilder* builder) {
  auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
      compile_op->getLoc(), compile_op->getResult(0));
  WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}

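// Rewrites a single `tf_device.cluster_func` into a `tf._TPUCompileMlir` op, a
// `tf.TPUCompileSucceededAssert` op, and one or more `tf.TPUExecute` ops
// (wrapped in a `tf_device.parallel_execute` when model parallelism is used),
// based on the replication and device assignment attributes of the cluster.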
LogicalResult Rewrite(
    tf_device::ClusterFuncOp cluster_func,
    llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
    OpBuilder* builder) {
  // Collect `num_replicas` and `num_cores_per_replica` attributes.
  int num_replicas = 1;
  tf_device::ReplicateOp replicate =
      cluster_func->getParentOfType<tf_device::ReplicateOp>();
  if (replicate) num_replicas = replicate.n();

  auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));

  int num_cores_per_replica = num_cores_per_replica_attr.getInt();

  auto topology_attr =
      cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kTopologyAttr));

  auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster_func.emitOpError(
        llvm::formatv("requires attribute '{0}'",
                      tensorflow::kDeviceAssignmentAttr)
            .str());

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);
  if (!status_or_device_coordinates.ok())
    return cluster_func.emitError()
           << "error in fetching tpu device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices, num_replicas, num_cores_per_replica,
          topology_attr.getValue(),
          status_or_device_coordinates.ConsumeValueOrDie());
  if (!status_or_tpu_device_assignment.ok())
    return cluster_func.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();

  // Create compile op.
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
  builder->setInsertionPoint(cluster_func);

  // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of the
  // parallel_execute region if it exists.
  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
    // Currently, outside compilation and model parallelism are not supported
    // together.
    assert(num_cores_per_replica == 1);
    builder->setInsertionPoint(cluster_func->getParentOp());
  }

  Operation* compile_op = BuildCompileOp(
      cluster_func, num_replicas, num_cores_per_replica,
      tpu_device_assignment.compilation_device,
      std::move(tpu_device_assignment.xla_device_assignment), builder);
  if (!compile_op) return failure();

  // This replaces _TPUCompileMlir placeholder ops that are required by
  // XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
  // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp
  // and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
  // structured lowering.
  if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          cluster_func->getParentOp())) {
    parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
      key_op.replaceAllUsesWith(compile_op->getResult(1));
      key_op.erase();
    });
  }

  // After rewrite, find if there is a TPUCompilationResultOp in the block with
  // the same _tpu_replicate attribute and replace it with the result of the
  // compile op. This op is used as a placeholder during graph construction to
  // hook up the other ops that are intended to consume the compile result.
  Block* block = cluster_func.getOperation()->getBlock();
  for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
    compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));

  BuildTPUCompileSucceededAssertOp(
      compile_op, tpu_device_assignment.compilation_device, builder);

  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                           builder);

  llvm::SmallVector<xla::OpSharding, 4> output_shardings;
  auto result = tensorflow::ParseAndValidateOutputSharding(
      num_cores_per_replica, cluster_func, &output_shardings);
  if (failed(result)) return failure();

  builder->setInsertionPoint(cluster_func);
  if (num_cores_per_replica > 1) {
    // For model parallelism, tf_device.parallel_execute is used to express
    // concurrent device execution across multiple logical devices.

    tf_device::ParallelExecuteOp execute_op;
    result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
                                    output_shardings, compile_op, cluster_func,
                                    builder, &execute_op);
    if (failed(result)) return failure();

    // As tf_device.parallel_execute wraps one TPUExecute op per logical core,
    // the parallel_execute op has more return values than the cluster_func op.
    // As such, each return value of the parallel_execute op must be remapped
    // to the corresponding return value uses of cluster_func.
    return tensorflow::RemapOutputsFromLogicalDevices(
        cluster_func.getLoc(), output_shardings, cluster_func, execute_op,
        builder);
  }

  llvm::SmallVector<Value, 4> execute_inputs(cluster_func.getOperands());
  execute_inputs.emplace_back(compile_op->getResult(1));

  TF::TPUExecuteOp execute_op;
  result = BuildExecuteOp(
      /*core_id=*/0, output_shardings, execute_inputs, cluster_func, builder,
      &execute_op);
  if (failed(result)) return failure();

  tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
      tpu_device_assignment.tpu_devices, execute_op, builder);
  cluster_func.replaceAllUsesWith(launch_op);
  return success();
}

// Erases rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
// TPUPartitionedOutputOp ops are present, they must be removed along with the
// ClusterFuncOp(s).
void EraseClusterFuncs(
    llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
  for (auto cluster : to_be_erased) {
    for (auto result : cluster.results()) {
      for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
        if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
          assert(user->use_empty());
          user->erase();
        }
      }
    }

    for (auto operand : cluster.operands()) {
      Operation* def = operand.getDefiningOp();
      if (operand.hasOneUse() &&
          llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
        operand.dropAllUses();
        def->erase();
      }
    }

    assert(cluster->use_empty());
    cluster->erase();
  }
}

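// Pass entry point. Rewrites every TPU-targeted `tf_device.cluster_func`
// (identified by the `_tpu_replicate` attribute) into compile and execute ops,
// then erases the rewritten cluster funcs and leftover TPUCompilationResultOps.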
void TPURewritePass::runOnOperation() {
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
    return signalPassFailure();

  llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
  OpBuilder builder(&getContext());
  auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
    // Skip non-tpu device cluster_func.
    auto replicate_attr = op->getAttrOfType<StringAttr>("_tpu_replicate");
    if (!replicate_attr) return WalkResult::advance();

    if (failed(Rewrite(op, devices.device_names(), &builder)))
      return WalkResult::interrupt();

    to_be_erased.push_back(op);
    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  EraseClusterFuncs(to_be_erased);

  // Eliminate TPUCompilationResultOp now that the rewrite is complete.
  getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });

  // TODO(b/139377366): Remove functions that are no longer needed.
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
  return std::make_unique<TPURewritePass>();
}

}  // namespace TFTPU
}  // namespace mlir