/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

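// When set, the TPUCompileMetadataProto attached to the generated
// `tf._TPUCompileMlir` op is emitted as a human-readable proto debug string
// rather than as binary-serialized bytes. A hypothetical invocation (the flag
// name comes from the declaration below; the tool and pass flag are
// assumptions for illustration):
//
//   tf-opt -tf-tpu-rewrite -tpu_compile_metadata_debug input.mlir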
// NOLINTNEXTLINE
static llvm::cl::opt<bool> tpu_compile_metadata_debug(
    "tpu_compile_metadata_debug",
    llvm::cl::desc("Serialize TPUCompileMetadataProto metadata in "
                   "'tf._TPUCompileMlir' op as a proto debug string"));

constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kPaddingMapAttr[] = "padding_map";
constexpr char kDeviceAttr[] = "device";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";
constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning";

constexpr char kBadStringArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not a string";
constexpr char kBadArrayElementMsg[] =
    "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
constexpr char kBadArrayAttrLengthMsg[] =
    "bad '{0}' attribute, expected array attribute of size {1}, got size {2}";

namespace {
struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
  void runOnOperation() override;
};

// Creates a missing attribute error message.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  return llvm::formatv("requires attribute '{0}'", attribute).str();
}

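// Clones `entry_func` and every function it transitively references into a
// fresh module (renaming `entry_func` to "main"), then serializes that module
// into `serialized_func_module` so the TPU program can be compiled
// independently of the surrounding module.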
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
                                          std::string* serialized_func_module) {
  ModuleOp module = entry_func->getParentOfType<ModuleOp>();
  SymbolTable entry_module_table(module);
  llvm::SmallVector<FuncOp, 4> referenced({entry_func});

  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  auto parent_module = entry_func->getParentOfType<ModuleOp>();
  auto versions_attr = parent_module->getAttr(kVersionsAttr);
  if (!versions_attr)
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
  SymbolTable symbol_table(module_for_func.get());

  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (symbol_table.lookup<FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
    // all found FuncOps into module_for_func to make sure it is
    // self-contained.
    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
    assert(uses && "expected to be able to collect symbol uses");
    for (SymbolTable::SymbolUse use : *uses) {
      FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
          use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());

      // Skip symbols that do not map to a function.
      if (!referenced_func) continue;

      referenced.emplace_back(referenced_func);
    }

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // We can simply rename the TPU program's main function because there
      // should be no other references to it.
      clone.setName("main");
      clone.setPublic();
    } else {
      clone.setPrivate();
    }
    symbol_table.insert(clone);
  }

  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
  return success();
}

// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoStepMarkerLocation(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto step_marker_location =
      op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          std::string(step_marker_location.getValue()), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));

  metadata->set_step_marker_location(location);

  return success();
}

// Populates a TPUCompileMetadataProto with PaddingMap from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoPaddingMap(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto padding_map = op->getAttrOfType<ArrayAttr>(kPaddingMapAttr);
  if (!padding_map)
    return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));

  for (const auto& padding_and_idx : llvm::enumerate(padding_map)) {
    auto& padding_attr = padding_and_idx.value();
    auto padding_attr_str = padding_attr.dyn_cast<StringAttr>();
    if (!padding_attr_str)
      return op.emitOpError(llvm::formatv(
          kBadStringArrayElementMsg, kPaddingMapAttr, padding_and_idx.index()));

    tensorflow::tpu::PaddingMap* padding =
        metadata->mutable_padding_maps()->Add();
    if (!padding->ParseFromString(std::string(padding_attr_str.getValue())))
      return op.emitOpError(llvm::formatv(
          kBadArrayElementMsg, kPaddingMapAttr, padding_and_idx.index(),
          padding_attr_str.getValue(), "tpu::PaddingMap"));
  }

  return success();
}

// Parses an xla::OpSharding from a string attribute.
LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name,
                            int index, xla::OpSharding* sharding) {
  auto sharding_str = attr.dyn_cast<StringAttr>();
  if (!sharding_str)
    return op->emitOpError(
        llvm::formatv(kBadStringArrayElementMsg, name, index));

  if (!sharding->ParseFromString(sharding_str.getValue().str()))
    return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index,
                                         sharding_str.getValue(),
                                         "xla::OpSharding"));

  return success();
}

// Populates a TPUCompileMetadataProto with argument types and sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoArgs(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto input_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
  if (!input_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));

  if (input_shardings.size() != op.getNumOperands())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr,
                      op.getNumOperands(), input_shardings.size()));

  // Set args metadata in proto.
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int index = operand_type_and_idx.index();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        index, status.error_message()));

    arg->set_dtype(dtype);
    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes.
    *arg->mutable_shape() = tensorflow::TensorShapeProto();
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    if (failed(SetOpSharding(op, input_shardings.getValue()[index],
                             tensorflow::kInputShardingAttr, index,
                             arg->mutable_sharding())))
      return failure();
  }

  return success();
}

// Populates a TPUCompileMetadataProto with result sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoRetvals(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto output_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
  if (!output_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));

  if (output_shardings.size() != op.getNumResults())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr,
                      op.getNumResults(), output_shardings.size()));

  // Set retvals metadata in proto.
  for (auto output_sharding_and_idx : llvm::enumerate(output_shardings))
    if (failed(SetOpSharding(op, output_sharding_and_idx.value(),
                             tensorflow::kOutputShardingAttr,
                             output_sharding_and_idx.index(),
                             metadata->add_retvals()->mutable_sharding())))
      return failure();

  return success();
}

// Populates a TPUCompileMetadataProto from attributes of a
// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the
// op, a failure will be returned.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromClusterFuncOp(
    tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  metadata->set_num_replicas(num_replicas);
  metadata->set_num_cores_per_replica(num_cores_per_replica);

  if (failed(SetMetadataProtoStepMarkerLocation(op, metadata)))
    return failure();

  if (failed(SetMetadataProtoPaddingMap(op, metadata))) return failure();

  if (xla_device_assignment.hasValue())
    *metadata->mutable_device_assignment() =
        std::move(xla_device_assignment.getValue());
  auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
  if (!use_spmd_attr)
    return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
  metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());

  if (failed(SetMetadataProtoArgs(op, metadata))) return failure();

  return SetMetadataProtoRetvals(op, metadata);
}

// Wraps a single op in a `tf_device.launch` for explicit device assignment.
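//
// Schematically (an illustrative sketch only; op name, types, and attribute
// spellings are placeholders), an op such as
//
//   %0 = "tf.SomeOp"(%arg0)
//
// ends up as
//
//   %0 = "tf_device.launch"() ({
//     %1 = "tf.SomeOp"(%arg0)
//     tf_device.return %1
//   }) {device = "<device>"}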
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
                                   Operation* op, llvm::StringRef device) {
  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();

  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside the launch region.
  op->moveBefore(launch.GetBody().getTerminator());

  builder->restoreInsertionPoint(insert_point);

  return launch;
}

// Creates a `tf._TPUCompileMlir` op that contains an MLIR module that is
// functionally equivalent to the function referenced by cluster_func.
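//
// Roughly (an illustrative sketch; operand/result types and most attributes
// omitted, and the attribute names shown are assumptions), the generated op
// looks like
//
//   %status, %program = "tf._TPUCompileMlir"(%dynamic_shapes...)
//       {metadata = "<serialized TPUCompileMetadataProto>",
//        mlir_module = "<serialized MLIR module>"}
//
// wrapped in a `tf_device.launch` on the compilation device.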
Operation* BuildCompileOp(
    tf_device::ClusterFuncOp cluster_func, int num_replicas,
    int num_cores_per_replica, llvm::StringRef compilation_device,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    OpBuilder* builder) {
  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromClusterFuncOp(
          cluster_func, num_replicas, num_cores_per_replica,
          std::move(xla_device_assignment), &metadata)))
    return nullptr;

  std::string txt_metadata;
  if (tpu_compile_metadata_debug)
    txt_metadata = metadata.DebugString();
  else
    metadata.SerializeToString(&txt_metadata);

  // Build a shape op for each input to cluster_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value, 4> compile_op_operands;
  compile_op_operands.reserve(cluster_func.getNumOperands());

  for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        cluster_func.getLoc(),
        RankedTensorType::get({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }

  FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      func_attr.getValue());

  std::string txt_module;
  if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;

  auto compilation_status_type =
      RankedTensorType::get({}, builder->getType<TF::StringType>());
  auto program_type =
      RankedTensorType::get({3}, builder->getType<TF::StringType>());

  auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
      cluster_func.getLoc(),
      /*compilation_status=*/compilation_status_type, /*program=*/
      llvm::SmallVector<Type, 8>(num_cores_per_replica, program_type),
      compile_op_operands, txt_module, txt_metadata);

  return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
                        compilation_device);
}

// Assigns explicit devices to a `tf_device.replicate` op. An aliased device is
// created per core, and all replica devices per core are grouped together.
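//
// For example (hypothetical device names; the attribute is shown roughly),
// with 2 replicas and 1 core per replica the replicate op gets a dictionary
// attribute along the lines of
//
//   devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/.../device:TPU:0",
//                                       "/job:worker/.../device:TPU:1"],
//              TPU_REPLICATED_HOST   = ["/job:worker/.../device:CPU:0",
//                                       "/job:worker/.../device:CPU:0"]}
//
// The alias names come from GetDeviceAliasForLogicalCore and
// kTPUReplicatedHost; the device strings above are purely illustrative.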
void AssignDevicesToReplicate(
    tf_device::ReplicateOp replicate,
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    OpBuilder* builder) {
  if (!replicate) return;

  const int num_replicas = tpu_devices.size();
  const int num_cores_per_replica = tpu_devices.front().size();

  llvm::SmallVector<NamedAttribute, 8> device_attrs;
  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<StringRef, 8> devices_by_core;
    devices_by_core.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      devices_by_core.push_back(tpu_devices[replica][core].device);

    device_attrs.push_back(
        builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
                              builder->getStrArrayAttr(devices_by_core)));
  }

  // For data parallelism, also add replicated host devices, as these are
  // necessary for outside compilation.
  if (num_cores_per_replica == 1) {
    llvm::SmallVector<StringRef, 8> hosts;
    hosts.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      hosts.push_back(tpu_devices[replica][0].host);

    device_attrs.push_back(builder->getNamedAttr(
        tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
  }

  replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
}

// Creates a `tf.TPUExecute` op that executes the TPU program.
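//
// Schematically (illustrative only; types elided) the op has the form
//
//   %results:N = "tf.TPUExecute"(%cluster_inputs..., %program)
//
// where %program is one of the program results of the `tf._TPUCompileMlir`
// op built above.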
LogicalResult BuildExecuteOp(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::ArrayRef<Value> inputs, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, TF::TPUExecuteOp* execute_op) {
  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.
  llvm::SmallVector<Type, 4> output_types;
  auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
      core_id, output_sharding_config, cluster_func, &output_types);
  if (failed(result)) return failure();

  // TPUExecute has the same output types as cluster_func.
  *execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
                                                  output_types, inputs);
  return success();
}

// Creates a `tf_device.parallel_execute` op that wraps one TPUExecute op per
// logical core to represent execution of the TPU program across multiple
// logical cores.
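//
// Schematically, for two cores (an illustrative sketch; types, devices, and
// most operands elided):
//
//   "tf_device.parallel_execute"() ({
//     ... tf_device.launch wrapping "tf.TPUExecute" for core 0 ...
//   }, {
//     ... tf_device.launch wrapping "tf.TPUExecute" for core 1 ...
//   })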
LogicalResult BuildParallelExecuteOp(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
  const int num_cores_per_replica = tpu_devices.front().size();
  // parallel_execute op returns the concatenated list of return values of
  // all its regions.
  //
  // TODO(b/149102702): Correctly map inputs to parallel_execute op via
  // identifying xla_sharding op in the cluster_func function.
  const auto cluster_result_types = cluster_func.getResultTypes();
  llvm::SmallVector<Type, 8> concatenated_output_types;
  concatenated_output_types.reserve(cluster_result_types.size() *
                                    num_cores_per_replica);

  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<Type, 4> output_types;
    auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
        core, output_sharding_config, cluster_func, &output_types);
    if (failed(result)) return failure();

    for (Type t : output_types) concatenated_output_types.emplace_back(t);
  }

  *parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
      cluster_func.getLoc(), num_cores_per_replica, concatenated_output_types);

  // Extract inputs for each region of the parallel_execute op. The i-th
  // element in the list represents the input list to the TPU computation for
  // the i-th logical core.
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
  builder->setInsertionPoint(*parallel_execute_op);
  auto result = tensorflow::ExtractInputsForLogicalDevices(
      num_cores_per_replica, cluster_func, builder, &input_list);
  if (failed(result)) return failure();

  const bool replicated = tpu_devices.size() != 1;
  // For each logical core, create a region with a TPUExecute op.
  assert(input_list.size() == num_cores_per_replica);
  for (int core = 0; core < num_cores_per_replica; ++core) {
    auto& region = parallel_execute_op->GetRegionBlockWithIndex(core);
    builder->setInsertionPointToEnd(&region);

    // Create Execute op.
    //
    // TODO(b/148913294): Identify inputs/return values specific to each
    // logical core TPU execution by parsing xla_sharding op in
    // cluster_func.
    auto execute_inputs = input_list[core];
    execute_inputs.emplace_back(compile_op->getResult(core + 1));

    TF::TPUExecuteOp execute;
    result = BuildExecuteOp(core, output_sharding_config, execute_inputs,
                            cluster_func, builder, &execute);
    if (failed(result)) return failure();

    // If the computation is replicated, use the aliased device. Otherwise
    // there is only one execution device per core and that device is assigned
    // to the execute op.
    std::string device = replicated
                             ? tensorflow::GetDeviceAliasForLogicalCore(core)
                             : tpu_devices.front()[core].device;

    auto region_launch_op =
        WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);

    builder->create<tf_device::ReturnOp>(region.getParent()->getLoc(),
                                         region_launch_op.getResults());
  }

  return success();
}

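// Wraps `execute_op` in a `tf_device.launch`, using the replicated core 0
// device alias when the computation is replicated, or the single execution
// device otherwise.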
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    Operation* execute_op, OpBuilder* builder) {
  const bool replicated = tpu_devices.size() != 1;
  // If the computation is replicated, use the aliased device. Otherwise there
  // is only one execution device and that device is assigned to the execute
  // op.
  std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
                                  : tpu_devices.front().front().device;

  return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}

// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
// status of `compile_op` to check whether compilation is successful.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      llvm::StringRef compilation_device,
                                      OpBuilder* builder) {
  auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
      compile_op->getLoc(), compile_op->getResult(0));
  WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}

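// Rewrites a single TPU `tf_device.cluster_func` into compile and execute ops.
// Schematically (an illustrative sketch; device names, types, and attributes
// elided, and @tpu_fn is a placeholder symbol), a non-replicated, single-core
// cluster such as
//
//   %0 = "tf_device.cluster_func"(%arg0) {func = @tpu_fn, ...}
//
// is roughly replaced by
//
//   %compile:2 = "tf._TPUCompileMlir"(...)        // in a tf_device.launch
//   "tf.TPUCompileSucceededAssert"(%compile#0)    // in a tf_device.launch
//   %0 = "tf.TPUExecute"(%arg0, %compile#1)       // in a tf_device.launch
//
// With model parallelism (num_cores_per_replica > 1), the per-core execute
// ops are instead wrapped in a `tf_device.parallel_execute`.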
LogicalResult Rewrite(
    tf_device::ClusterFuncOp cluster_func,
    llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
    OpBuilder* builder) {
  // Collect `num_replicas` and `num_cores_per_replica` attributes.
  int num_replicas = 1;
  tf_device::ReplicateOp replicate =
      cluster_func->getParentOfType<tf_device::ReplicateOp>();
  if (replicate) num_replicas = replicate.n();

  auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));

  int num_cores_per_replica = num_cores_per_replica_attr.getInt();

  auto topology_attr =
      cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kTopologyAttr));

  auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster_func.emitOpError(
        llvm::formatv("requires attribute '{0}'",
                      tensorflow::kDeviceAssignmentAttr)
            .str());

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);
  if (!status_or_device_coordinates.ok())
    return cluster_func.emitError()
           << "error in fetching tpu device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices, num_replicas, num_cores_per_replica,
          topology_attr.getValue(),
          status_or_device_coordinates.ConsumeValueOrDie());
  if (!status_or_tpu_device_assignment.ok())
    return cluster_func.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();

  // Create compile op.
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
  builder->setInsertionPoint(cluster_func);

  // Create the TPUCompileMlir and TPUCompileSucceededAssert ops outside of
  // the parallel_execute region, if one exists.
  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
    // Currently, outside compilation and model parallelism are not supported
    // together.
    assert(num_cores_per_replica == 1);
    builder->setInsertionPoint(cluster_func->getParentOp());
  }

  Operation* compile_op = BuildCompileOp(
      cluster_func, num_replicas, num_cores_per_replica,
      tpu_device_assignment.compilation_device,
      std::move(tpu_device_assignment.xla_device_assignment), builder);
  if (!compile_op) return failure();

  // This replaces _TPUCompileMlir placeholder ops that are required by
  // XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
  // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp
  // and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
  // structured lowering.
  if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          cluster_func->getParentOp())) {
    parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
      key_op.replaceAllUsesWith(compile_op->getResult(1));
      key_op.erase();
    });
  }

  // After the rewrite, check if there is a TPUCompilationResultOp in the block
  // with the same _tpu_replicate attribute and replace it with the result of
  // the compile op. This op is used as a placeholder during graph creation to
  // hook up the other ops that are intended to consume the compile result.
  Block* block = cluster_func.getOperation()->getBlock();
  for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
    compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));

  BuildTPUCompileSucceededAssertOp(
      compile_op, tpu_device_assignment.compilation_device, builder);

  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                           builder);

  llvm::SmallVector<xla::OpSharding, 4> output_shardings;
  auto result = tensorflow::ParseAndValidateOutputSharding(
      num_cores_per_replica, cluster_func, &output_shardings);
  if (failed(result)) return failure();

  builder->setInsertionPoint(cluster_func);
  if (num_cores_per_replica > 1) {
    // For model parallelism, tf_device.parallel_execute is used to express
    // concurrent device execution across multiple logical devices.

    tf_device::ParallelExecuteOp execute_op;
    result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
                                    output_shardings, compile_op, cluster_func,
                                    builder, &execute_op);
    if (failed(result)) return failure();

    // Since tf_device.parallel_execute wraps one TPUExecute op per logical
    // core, the number of return values of the parallel_execute op exceeds
    // that of the cluster_func op. Therefore, each return value of the
    // parallel_execute op must be mapped to the corresponding return value
    // uses of cluster_func.
    return tensorflow::RemapOutputsFromLogicalDevices(
        cluster_func.getLoc(), output_shardings, cluster_func, execute_op,
        builder);
  }

  llvm::SmallVector<Value, 4> execute_inputs(cluster_func.getOperands());
  execute_inputs.emplace_back(compile_op->getResult(1));

  TF::TPUExecuteOp execute_op;
  result = BuildExecuteOp(
      /*core_id=*/0, output_shardings, execute_inputs, cluster_func, builder,
      &execute_op);
  if (failed(result)) return failure();

  tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
      tpu_device_assignment.tpu_devices, execute_op, builder);
  cluster_func.replaceAllUsesWith(launch_op);
  return success();
}

// Erases rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
// TPUPartitionedOutputOp ops are present, they must be removed along with the
// ClusterFuncOp(s).
void EraseClusterFuncs(
    llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
  for (auto cluster : to_be_erased) {
    for (auto result : cluster.results()) {
      for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
        if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
          assert(user->use_empty());
          user->erase();
        }
      }
    }

    for (auto operand : cluster.operands()) {
      Operation* def = operand.getDefiningOp();
      if (operand.hasOneUse() &&
          llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
        operand.dropAllUses();
        def->erase();
      }
    }

    assert(cluster->use_empty());
    cluster->erase();
  }
}

void TPURewritePass::runOnOperation() {
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
    return signalPassFailure();

  llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
  OpBuilder builder(&getContext());
  auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
    // Skip cluster_func ops that do not target TPU.
    auto replicate_attr = op->getAttrOfType<StringAttr>("_tpu_replicate");
    if (!replicate_attr) return WalkResult::advance();

    if (failed(Rewrite(op, devices.device_names(), &builder)))
      return WalkResult::interrupt();

    to_be_erased.push_back(op);
    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  EraseClusterFuncs(to_be_erased);

  // Eliminate TPUCompilationResultOp now that the rewrite is complete.
  getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });

  // TODO(b/139377366): Remove functions that are no longer needed.
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
  return std::make_unique<TPURewritePass>();
}

}  // namespace TFTPU
}  // namespace mlir