/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

// NOLINTNEXTLINE
static llvm::cl::opt<bool> tpu_compile_metadata_debug(
    "tpu_compile_metadata_debug",
    llvm::cl::desc("Serialize TPUCompileMetadataProto metadata in "
                   "'tf._TPUCompileMlir' op as a proto debug string"));

constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kDeviceAttr[] = "device";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";
constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning";

constexpr char kBadStringArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not a string";
constexpr char kBadArrayElementMsg[] =
    "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
constexpr char kBadArrayAttrLengthMsg[] =
    "bad '{0}' attribute, expected array attribute of size {1}, got size {2}";

namespace {
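// Pass that rewrites TPU-targeting `tf_device.cluster_func` ops into TPU
// compile and execute ops wrapped in `tf_device.launch` ops.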
struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
  void runOnOperation() override;
};

// Creates a missing attribute error message.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  return llvm::formatv("requires attribute '{0}'", attribute).str();
}

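// Clones `entry_func` and all functions it transitively references into a new
// module, renames the entry function to `main`, and serializes that module
// into `serialized_func_module`.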
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
                                          std::string* serialized_func_module) {
  ModuleOp module = entry_func->getParentOfType<ModuleOp>();
  SymbolTable entry_module_table(module);
  llvm::SmallVector<FuncOp, 4> referenced({entry_func});

  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  auto parent_module = entry_func->getParentOfType<ModuleOp>();
  auto versions_attr = parent_module->getAttr(kVersionsAttr);
  if (!versions_attr)
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
  SymbolTable symbol_table(module_for_func.get());

  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (symbol_table.lookup<FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. All found FuncOps
    // must be cloned into the new module to keep it self-contained.
    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
    assert(uses && "expected to be able to collect symbol uses");
    for (SymbolTable::SymbolUse use : *uses) {
      FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
          use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());

      // Skip symbols that do not map to a function.
      if (!referenced_func) continue;

      referenced.emplace_back(referenced_func);
    }

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // The name of the TPU program's entry function can simply be changed, as
      // there should be no other references to it.
      clone.setName("main");
      clone.setPublic();
    } else {
      clone.setPrivate();
    }
    symbol_table.insert(clone);
  }

  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
  return success();
}

// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoStepMarkerLocation(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto step_marker_location =
      op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          std::string(step_marker_location.getValue()), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));

  metadata->set_step_marker_location(location);

  return success();
}

// Parses an xla::OpSharding from a string attribute.
LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name,
                            int index, xla::OpSharding* sharding) {
  auto sharding_str = attr.dyn_cast<StringAttr>();
  if (!sharding_str)
    return op->emitOpError(
        llvm::formatv(kBadStringArrayElementMsg, name, index));

  if (!sharding->ParseFromString(sharding_str.getValue().str()))
    return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index,
                                         sharding_str.getValue(),
                                         "xla::OpSharding"));

  return success();
}

// Populates a TPUCompileMetadataProto with argument types and sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoArgs(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto input_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
  if (!input_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));

  if (input_shardings.size() != op.getNumOperands())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr,
                      op.getNumOperands(), input_shardings.size()));

  // Set args metadata in proto.
  mlir::Identifier replication_attr_name = mlir::Identifier::get(
      "mhlo.is_same_data_across_replicas", op.getContext());
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int index = operand_type_and_idx.index();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        index, status.error_message()));

    arg->set_dtype(dtype);
    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes.
    *arg->mutable_shape() = tensorflow::TensorShapeProto();
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    if (failed(SetOpSharding(op, input_shardings.getValue()[index],
                             tensorflow::kInputShardingAttr, index,
                             arg->mutable_sharding())))
      return failure();

    // Populate is_same_data_across_replicas.
    // Note: this information is duplicated and can be removed from the proto
    // and here once MLIR bridge phase 2 doesn't fall back to the old bridge.
    mlir::UnitAttr attr = op.getFunc().getArgAttrOfType<mlir::UnitAttr>(
        index, replication_attr_name);
    arg->set_is_same_data_across_replicas(attr != nullptr);
  }

  return success();
}

// Populates a TPUCompileMetadataProto with result sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoRetvals(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto output_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
  if (!output_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));

  if (output_shardings.size() != op.getNumResults())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr,
                      op.getNumResults(), output_shardings.size()));

  // Set retvals metadata in proto.
  for (auto output_sharding_and_idx : llvm::enumerate(output_shardings))
    if (failed(SetOpSharding(op, output_sharding_and_idx.value(),
                             tensorflow::kOutputShardingAttr,
                             output_sharding_and_idx.index(),
                             metadata->add_retvals()->mutable_sharding())))
      return failure();

  return success();
}

// Populates a TPUCompileMetadataProto from attributes of a
// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the
// op, a failure will be returned.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromClusterFuncOp(
    tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  metadata->set_num_replicas(num_replicas);
  metadata->set_num_cores_per_replica(num_cores_per_replica);

  if (failed(SetMetadataProtoStepMarkerLocation(op, metadata)))
    return failure();

  if (xla_device_assignment.hasValue())
    *metadata->mutable_device_assignment() =
        std::move(xla_device_assignment.getValue());
  auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
  if (!use_spmd_attr)
    return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
  metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());

  if (failed(SetMetadataProtoArgs(op, metadata))) return failure();

  return SetMetadataProtoRetvals(op, metadata);
}

// Wraps a single op in a `tf_device.launch` for explicit device assignment.
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
                                   Operation* op, llvm::StringRef device) {
  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();

  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside the launch.
  op->moveBefore(launch.GetBody().getTerminator());

  builder->restoreInsertionPoint(insert_point);

  return launch;
}

// Creates a `tf._TPUCompileMlir` op containing an MLIR module that is
// functionally equivalent to the function referenced by cluster_func.
Operation* BuildCompileOp(
    tf_device::ClusterFuncOp cluster_func, int num_replicas,
    int num_cores_per_replica, llvm::StringRef compilation_device,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    OpBuilder* builder) {
  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromClusterFuncOp(
          cluster_func, num_replicas, num_cores_per_replica,
          std::move(xla_device_assignment), &metadata)))
    return nullptr;

  // Build a shape op for each input to cluster_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value, 4> compile_op_operands;
  compile_op_operands.reserve(cluster_func.getNumOperands());

  for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        cluster_func.getLoc(),
        RankedTensorType::get({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }

  FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
  FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      func_attr.getValue());

  std::string txt_module;
  if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;

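  // The compile op produces a scalar string holding the compilation status
  // plus, for each logical core, a program tensor of three strings.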
  auto compilation_status_type =
      RankedTensorType::get({}, builder->getType<TF::StringType>());
  auto program_type =
      RankedTensorType::get({3}, builder->getType<TF::StringType>());

  // Add MLIR module's fingerprint to compile metadata.
  uint64_t mlir_fingerprint = tensorflow::Fingerprint64(txt_module);
  metadata.set_mlir_fingerprint(mlir_fingerprint);

  std::string txt_metadata;
  if (tpu_compile_metadata_debug)
    txt_metadata = metadata.DebugString();
  else
    metadata.SerializeToString(&txt_metadata);

  auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
      cluster_func.getLoc(),
      /*compilation_status=*/compilation_status_type, /*program=*/
      llvm::SmallVector<Type, 8>(num_cores_per_replica, program_type),
      compile_op_operands, txt_module, txt_metadata);

  return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
                        compilation_device);
}

// Assigns explicit devices to a replicate op. An aliased device is created per
// core, and all replica devices per core are grouped together.
void AssignDevicesToReplicate(
    tf_device::ReplicateOp replicate,
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    OpBuilder* builder) {
  if (!replicate) return;

  const int num_replicas = tpu_devices.size();
  const int num_cores_per_replica = tpu_devices.front().size();

  llvm::SmallVector<NamedAttribute, 8> device_attrs;
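  // Each logical core gets a device alias (e.g. "TPU_REPLICATED_CORE_0") that
  // maps to the list of that core's execution devices, one per replica.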
  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<StringRef, 8> devices_by_core;
    devices_by_core.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      devices_by_core.push_back(tpu_devices[replica][core].device);

    device_attrs.push_back(
        builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
                              builder->getStrArrayAttr(devices_by_core)));
  }

  // For data parallelism, also add replicated host devices, as these are
  // necessary for outside compilation.
  if (num_cores_per_replica == 1) {
    llvm::SmallVector<StringRef, 8> hosts;
    hosts.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      hosts.push_back(tpu_devices[replica][0].host);

    device_attrs.push_back(builder->getNamedAttr(
        tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
  }

  replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
}

// Creates a `tf.TPUExecute` op that executes a TPU program.
LogicalResult BuildExecuteOp(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::ArrayRef<Value> inputs, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, TF::TPUExecuteOp* execute_op) {
  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.
  llvm::SmallVector<Type, 4> output_types;
  auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
      core_id, output_sharding_config, cluster_func, &output_types);
  if (failed(result)) return failure();

  // TPUExecute has the same output types as cluster_func.
  *execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
                                                  output_types, inputs);
  return success();
}

// Creates a `tf_device.parallel_execute` op that wraps TPUExecute ops to
// represent execution of a TPU program on multiple logical cores.
LogicalResult BuildParallelExecuteOp(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
  const int num_cores_per_replica = tpu_devices.front().size();
  // parallel_execute op returns the concatenated list of return values of
  // all its regions.
  //
  // TODO(b/149102702): Correctly map inputs to parallel_execute op via
  // identifying xla_sharding op in the cluster_func function.
  const auto cluster_result_types = cluster_func.getResultTypes();
  llvm::SmallVector<Type, 8> concatenated_output_types;
  concatenated_output_types.reserve(cluster_result_types.size() *
                                    num_cores_per_replica);

  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<Type, 4> output_types;
    auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
        core, output_sharding_config, cluster_func, &output_types);
    if (failed(result)) return failure();

    for (Type t : output_types) concatenated_output_types.emplace_back(t);
  }

  *parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
      cluster_func.getLoc(), num_cores_per_replica, concatenated_output_types);

  // Extract inputs for each region of the parallel_execute op. The i-th
  // element in the list represents the input list to the TPU computation for
  // the i-th logical core.
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
  builder->setInsertionPoint(*parallel_execute_op);
  auto result = tensorflow::ExtractInputsForLogicalDevices(
      num_cores_per_replica, cluster_func, builder, &input_list);
  if (failed(result)) return failure();

  const bool replicated = tpu_devices.size() != 1;
  // For each logical core, create a region with a TPUExecute op.
  assert(input_list.size() == num_cores_per_replica);
  for (int core = 0; core < num_cores_per_replica; ++core) {
    auto& region = parallel_execute_op->GetRegionBlockWithIndex(core);
    builder->setInsertionPointToEnd(&region);

    // Create Execute op.
    //
    // TODO(b/148913294): Identify inputs/return values specific to each
    // logical core TPU execution by parsing xla_sharding op in
    // cluster_func.
    auto execute_inputs = input_list[core];
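    // Result 0 of the compile op is the compilation status; results 1..N hold
    // the per-core programs, so the program for logical core `core` is result
    // `core + 1`.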
    execute_inputs.emplace_back(compile_op->getResult(core + 1));

    TF::TPUExecuteOp execute;
    result = BuildExecuteOp(core, output_sharding_config, execute_inputs,
                            cluster_func, builder, &execute);
    if (failed(result)) return failure();

    // If the computation is replicated, use the aliased device. Otherwise
    // there is only one execution device per core, and that device is assigned
    // to the execute op.
    std::string device = replicated
                             ? tensorflow::GetDeviceAliasForLogicalCore(core)
                             : tpu_devices.front()[core].device;

    auto region_launch_op =
        WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);

    builder->create<tf_device::ReturnOp>(region.getParent()->getLoc(),
                                         region_launch_op.getResults());
  }

  return success();
}

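// Wraps `execute_op` in a `tf_device.launch` op, assigning it either the
// aliased device for logical core 0 (replicated case) or the single execution
// device.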
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    Operation* execute_op, OpBuilder* builder) {
  const bool replicated = tpu_devices.size() != 1;
  // If the computation is replicated, use the aliased device. Otherwise there
  // is only one execution device, and that device is assigned to the execute
  // op.
  std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
                                  : tpu_devices.front().front().device;

  return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}

// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
// status of `compile_op` to check whether compilation is successful.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      Operation* result_id,
                                      llvm::StringRef compilation_device,
                                      OpBuilder* builder) {
  auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
      compile_op->getLoc(), result_id->getResult(0));
  WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}

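// Rewrites a single `tf_device.cluster_func` op into `tf._TPUCompileMlir` and
// `tf.TPUExecute` ops wrapped in `tf_device.launch` ops (and a
// `tf_device.parallel_execute` op when there are multiple logical cores),
// threading the compile result through any associated TPUCompilationResultOps.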
LogicalResult Rewrite(
    tf_device::ClusterFuncOp cluster_func,
    llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
    ArrayRef<TF::TPUCompilationResultOp> compilation_result,
    OpBuilder* builder) {
  // Collect `num_replicas` and `num_cores_per_replica` attributes.
  int num_replicas = 1;
  tf_device::ReplicateOp replicate =
      cluster_func->getParentOfType<tf_device::ReplicateOp>();
  if (replicate) num_replicas = replicate.n();

  auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));

  int num_cores_per_replica = num_cores_per_replica_attr.getInt();

  auto topology_attr =
      cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kTopologyAttr));

  auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kDeviceAssignmentAttr));

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);
  if (!status_or_device_coordinates.ok())
    return cluster_func.emitError()
           << "error in fetching TPU device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices, num_replicas, num_cores_per_replica,
          topology_attr.getValue(),
          status_or_device_coordinates.ConsumeValueOrDie());
  if (!status_or_tpu_device_assignment.ok())
    return cluster_func.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();

  // Create compile op.
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
  builder->setInsertionPoint(cluster_func);

  // Create the TPUCompileMlir and TPUCompileSucceededAssert ops outside of
  // the parallel_execute region, if one exists.
  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
    // Currently, outside compilation and model parallelism are not supported
    // together.
    assert(num_cores_per_replica == 1);
    builder->setInsertionPoint(cluster_func->getParentOp());
  }

  Operation* compile_op = BuildCompileOp(
      cluster_func, num_replicas, num_cores_per_replica,
      tpu_device_assignment.compilation_device,
      std::move(tpu_device_assignment.xla_device_assignment), builder);
  if (!compile_op) return failure();

  // This replaces _TPUCompileMlir placeholder ops that are required by
  // XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
  // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp
  // and _XlaRecvAtHostOp and _XlaSendFromHostOp is used, update to a more
  // structured lowering.
  if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          cluster_func->getParentOp())) {
    parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
      key_op.replaceAllUsesWith(compile_op->getResult(1));
      key_op.erase();
    });
  }

  // After the rewrite, if there is a TPUCompilationResultOp from the same
  // cluster, replace it with the result of the compile op. The
  // TPUCompilationResultOp is used as a placeholder during graph creation to
  // hook up the other ops that are intended to consume the compile result.
  Operation* result_id = compile_op;
  // TODO(jpienaar): Remove this later.
  auto compile_device_op = compile_op->getAttr("device");
  for (auto res : compilation_result) {
    // Build an identity op with the same location/name as the original
    // compilation result op.
    result_id = builder->create<TF::IdentityOp>(
        res.getLoc(), compile_op->getResult(0).getType(),
        result_id->getResult(0));
    // Assign the device currently set on the result op, or, if unset, the
    // device on which compilation will happen.
    // TODO(jpienaar): Remove this later.
    if (auto device = res->getAttrOfType<StringAttr>("device")) {
      if (!device.getValue().empty())
        result_id->setAttr("device", device);
      else
        result_id->setAttr("device", compile_device_op);
    } else if (compile_device_op) {
      result_id->setAttr("device", compile_device_op);
    }
    res.output().replaceAllUsesWith(compile_op->getResult(0));
  }

  BuildTPUCompileSucceededAssertOp(
      compile_op, result_id, tpu_device_assignment.compilation_device, builder);

  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                           builder);

  llvm::SmallVector<xla::OpSharding, 4> output_shardings;
  auto result = tensorflow::ParseAndValidateOutputSharding(
      num_cores_per_replica, cluster_func, &output_shardings);
  if (failed(result)) return failure();

  builder->setInsertionPoint(cluster_func);
  if (num_cores_per_replica > 1) {
    // For model parallelism, tf_device.parallel_execute is used to express
    // concurrent device execution across multiple logical devices.

    tf_device::ParallelExecuteOp execute_op;
    result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
                                    output_shardings, compile_op, cluster_func,
                                    builder, &execute_op);
    if (failed(result)) return failure();

    // Because tf_device.parallel_execute wraps one TPUExecute op per logical
    // core, it returns more values than the cluster_func op. Each return value
    // of the parallel_execute op must therefore be remapped to the
    // corresponding uses of the cluster_func return values.
    return tensorflow::RemapOutputsFromLogicalDevices(
        cluster_func.getLoc(), output_shardings, cluster_func, execute_op,
        builder);
  }

  llvm::SmallVector<Value, 4> execute_inputs(cluster_func.getOperands());
  execute_inputs.emplace_back(compile_op->getResult(1));

  TF::TPUExecuteOp execute_op;
  result = BuildExecuteOp(
      /*core_id=*/0, output_shardings, execute_inputs, cluster_func, builder,
      &execute_op);
  if (failed(result)) return failure();

  tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
      tpu_device_assignment.tpu_devices, execute_op, builder);
  cluster_func.replaceAllUsesWith(launch_op);
  return success();
}

// Erases rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
// TPUPartitionedOutputOp ops are present, they must be removed along with the
// ClusterFuncOp(s).
void EraseClusterFuncs(
    llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
  for (auto cluster : to_be_erased) {
    for (auto result : cluster.results()) {
      for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
        if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
          assert(user->use_empty());
          user->erase();
        }
      }
    }

    for (auto operand : cluster.operands()) {
      Operation* def = operand.getDefiningOp();
      if (operand.hasOneUse() &&
          llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
        operand.dropAllUses();
        def->erase();
      }
    }

    assert(cluster->use_empty());
    cluster->erase();
  }
}

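// Pass entry point: collects TPU compilation result ops per cluster, rewrites
// each TPU-targeting `tf_device.cluster_func`, and erases the rewritten ops.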
void TPURewritePass::runOnOperation() {
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
    return signalPassFailure();

  // Collect compilation results.
  llvm::DenseMap<Attribute, SmallVector<TF::TPUCompilationResultOp, 1>>
      compilation_results;
  auto result_init = getOperation().walk([&](TF::TPUCompilationResultOp op) {
    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_compilation_status");
    if (!cluster_id) {
      op->emitOpError("missing '_tpu_compilation_status'");
      return WalkResult::interrupt();
    }
    compilation_results[cluster_id].push_back(op);
    return WalkResult::advance();
  });
  if (result_init.wasInterrupted()) return signalPassFailure();

  llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
  OpBuilder builder(&getContext());
  auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
    // Skip cluster_func ops that do not target TPU devices.
    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_replicate");
    if (!cluster_id) return WalkResult::advance();

    if (failed(Rewrite(op, devices.device_names(),
                       compilation_results[cluster_id], &builder)))
      return WalkResult::interrupt();

    to_be_erased.push_back(op);
    return WalkResult::advance();
  });
  if (result.wasInterrupted()) return signalPassFailure();

  EraseClusterFuncs(to_be_erased);

  // Eliminate TPUCompilationResultOp now that the rewrite is complete.
  for (auto& it : compilation_results) {
    for (auto op : it.second) {
      if (!op.use_empty()) {
        mlir::InFlightDiagnostic err = op.emitError("uses remain post rewrite");
        for (auto user : op->getUsers())
          err.attachNote(user->getLoc()) << "remaining user";
        return signalPassFailure();
      }
      op.erase();
    }
  }

  // TODO(b/139377366): Remove functions that are no longer needed.
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
  return std::make_unique<TPURewritePass>();
}

}  // namespace TFTPU
}  // namespace mlir