/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";
constexpr char kUseXlaSpmdAttr[] = "use_spmd_for_xla_partitioning";

constexpr char kBadStringArrayElementMsg[] =
    "bad '{0}' attribute at index {1}, not a string";
constexpr char kBadArrayElementMsg[] =
    "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
constexpr char kBadArrayAttrLengthMsg[] =
    "bad '{0}' attribute, expected array attribute of size {1}, got size {2}";

namespace {
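// A module pass that rewrites `tf_device.cluster_func` ops targeting TPUs
// into `tf._TPUCompileMlir`, `tf.TPUCompileSucceededAssert` and
// `tf.TPUExecute` ops.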
struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
  void runOnOperation() override;
};

// Creates a missing attribute error message.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  return llvm::formatv("requires attribute '{0}'", attribute).str();
}

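// Clones `entry_func` and every function it transitively references into a
// fresh module, renames the entry function to `main`, and writes the
// serialized module text to `serialized_func_module`.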
LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func,
                                          std::string* serialized_func_module) {
  ModuleOp module = entry_func->getParentOfType<ModuleOp>();
  SymbolTable entry_module_table(module);
  llvm::SmallVector<func::FuncOp, 4> referenced({entry_func});

  // Create a new module to hold func and all referenced functions.
  OwningOpRef<mlir::ModuleOp> module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  auto parent_module = entry_func->getParentOfType<ModuleOp>();
  auto versions_attr = parent_module->getAttr(kVersionsAttr);
  if (!versions_attr)
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
  SymbolTable symbol_table(module_for_func.get());

  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (symbol_table.lookup<func::FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
    // all found FuncOps to module_for_func to make sure the new module is
    // self-contained.
    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
    assert(uses && "expected to be able to collect symbol uses");
    for (SymbolTable::SymbolUse use : *uses) {
      func::FuncOp referenced_func = entry_module_table.lookup<func::FuncOp>(
          use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());

      // Skip symbols that do not map to a function.
      if (!referenced_func) continue;

      referenced.emplace_back(referenced_func);
    }

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // We can simply change the name of the TPU program's main function
      // because there should be no other references to it.
      clone.setName(StringAttr::get(clone.getContext(), "main"));
      clone.setPublic();
    } else {
      clone.setPrivate();
    }
    symbol_table.insert(clone);
  }

  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
  return success();
}

// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoStepMarkerLocation(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto step_marker_location =
      op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          std::string(step_marker_location.getValue()), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));

  metadata->set_step_marker_location(location);

  return success();
}

// Parses an xla::OpSharding from a string attribute.
LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name,
                            int index, xla::OpSharding* sharding) {
  auto sharding_str = attr.dyn_cast<StringAttr>();
  if (!sharding_str)
    return op->emitOpError(
        llvm::formatv(kBadStringArrayElementMsg, name, index));

  if (!sharding->ParseFromString(sharding_str.getValue().str()))
    return op->emitOpError(llvm::formatv(kBadArrayElementMsg, name, index,
                                         sharding_str.getValue(),
                                         "xla::OpSharding"));

  return success();
}

// Populates a TPUCompileMetadataProto with argument types and sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoArgs(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto input_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
  if (!input_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));

  if (input_shardings.size() != op.getNumOperands())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kInputShardingAttr,
                      op.getNumOperands(), input_shardings.size()));

  // Set args metadata in proto.
  mlir::StringAttr replication_attr_name = mlir::StringAttr::get(
      op.getContext(), "mhlo.is_same_data_across_replicas");
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    int index = operand_type_and_idx.index();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        index, status.error_message()));

    arg->set_dtype(dtype);
    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes.
    *arg->mutable_shape() = tensorflow::TensorShapeProto();
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    if (failed(SetOpSharding(op, input_shardings.getValue()[index],
                             tensorflow::kInputShardingAttr, index,
                             arg->mutable_sharding())))
      return failure();

    // Populate set_is_same_data_across_replicas.
    // Note: this information is duplicated and can be removed from the proto
    // and here once MLIR bridge phase 2 doesn't fall back to the old bridge.
    mlir::UnitAttr attr = op.getFunc().getArgAttrOfType<mlir::UnitAttr>(
        index, replication_attr_name);
    arg->set_is_same_data_across_replicas(attr != nullptr);
  }

  return success();
}

// Populates a TPUCompileMetadataProto with result sharding from a
// `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoRetvals(
    tf_device::ClusterFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  auto output_shardings =
      op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
  if (!output_shardings)
    return op.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));

  if (output_shardings.size() != op.getNumResults())
    return op.emitOpError(
        llvm::formatv(kBadArrayAttrLengthMsg, tensorflow::kOutputShardingAttr,
                      op.getNumResults(), output_shardings.size()));

  // Set retvals metadata in proto.
  for (auto output_sharding_and_idx : llvm::enumerate(output_shardings))
    if (failed(SetOpSharding(op, output_sharding_and_idx.value(),
                             tensorflow::kOutputShardingAttr,
                             output_sharding_and_idx.index(),
                             metadata->add_retvals()->mutable_sharding())))
      return failure();

  return success();
}

// Populates a TPUCompileMetadataProto from attributes of a
// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the
// op, a failure will be returned.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromClusterFuncOp(
    tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  if (auto options_attr =
          op->getAttrOfType<StringAttr>("tpu_compile_options_proto")) {
    if (!metadata->mutable_compile_options()->ParseFromArray(
            options_attr.data(), options_attr.size())) {
      return failure();
    }
  }
  metadata->set_num_replicas(num_replicas);
  metadata->set_num_cores_per_replica(num_cores_per_replica);

  if (failed(SetMetadataProtoStepMarkerLocation(op, metadata)))
    return failure();

  if (xla_device_assignment.has_value())
    *metadata->mutable_device_assignment() =
        std::move(xla_device_assignment.getValue());
  auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
  if (!use_spmd_attr)
    return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
  metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());

  if (failed(SetMetadataProtoArgs(op, metadata))) return failure();

  return SetMetadataProtoRetvals(op, metadata);
}

// Wraps a single op in a `tf_device.launch` for explicit device assignment.
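//
// As a rough sketch (types and exact attribute details elided), the wrapped
// op ends up as:
//
//   %launch_results = "tf_device.launch"() ({
//     %results = "some.op"(...)
//     tf_device.return %results
//   }) {device = "<device>"} : () -> (...)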
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
                                   Operation* op, llvm::StringRef device) {
  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();

  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside the launch.
  op->moveBefore(launch.GetBody().getTerminator());

  builder->restoreInsertionPoint(insert_point);

  return launch;
}

// Creates a `tf._TPUCompileMlir` op that contains an MLIR module that is
// functionally equivalent to the function referenced by cluster_func.
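//
// A rough sketch of the emitted op (operand/result types elided; the
// serialized module and metadata strings become string attributes on the op,
// here assumed to be named `mlir_module` and `metadata`):
//
//   %compilation_status, %program:N = "tf._TPUCompileMlir"(%dynamic_shapes...)
//       {metadata = "...", mlir_module = "..."} : ...
//
// where N is `num_cores_per_replica`.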
Operation* BuildCompileOp(
    tf_device::ClusterFuncOp cluster_func, int num_replicas,
    int num_cores_per_replica, llvm::StringRef compilation_device,
    llvm::Optional<xla::DeviceAssignmentProto>&& xla_device_assignment,
    OpBuilder* builder, bool tpu_compile_metadata_debug) {
  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromClusterFuncOp(
          cluster_func, num_replicas, num_cores_per_replica,
          std::move(xla_device_assignment), &metadata)))
    return nullptr;

  // Build a shape op for each input to cluster_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value, 4> compile_op_operands;
  compile_op_operands.reserve(cluster_func.getNumOperands());

  for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        cluster_func.getLoc(),
        RankedTensorType::get({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }

  FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
  func::FuncOp func =
      cluster_func->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(
          func_attr.getValue());

  std::string txt_module;
  if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;

  auto compilation_status_type =
      RankedTensorType::get({}, builder->getType<TF::StringType>());
  auto program_type =
      RankedTensorType::get({3}, builder->getType<TF::StringType>());

  // Add MLIR module's fingerprint to compile metadata.
  uint64_t mlir_fingerprint = tensorflow::Fingerprint64(txt_module);
  metadata.set_mlir_fingerprint(mlir_fingerprint);

  std::string txt_metadata;
  if (tpu_compile_metadata_debug) {
    ::tensorflow::protobuf::TextFormat::Printer printer;
    printer.SetExpandAny(true);
    printer.PrintToString(metadata, &txt_metadata);
  } else {
    metadata.SerializeToString(&txt_metadata);
  }

  auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
      cluster_func.getLoc(),
      /*compilation_status=*/compilation_status_type, /*program=*/
      llvm::SmallVector<Type, 8>(num_cores_per_replica, program_type),
      compile_op_operands, txt_module, txt_metadata);

  return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
                        compilation_device);
}

// Assigns explicit devices to a replicate op. An aliased device is created
// per core, and all replica devices per core are grouped together.
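//
// For example, with 2 replicas and 2 cores per replica, and assuming the
// usual TPU_REPLICATED_CORE_<n> aliases, the resulting `devices` attribute
// looks roughly like (illustrative device names):
//
//   devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:2"],
//              TPU_REPLICATED_CORE_1 = ["/device:TPU:1", "/device:TPU:3"],
//              TPU_REPLICATED_HOST = ["/device:CPU:0", "/device:CPU:0"]}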
void AssignDevicesToReplicate(
    tf_device::ReplicateOp replicate,
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    OpBuilder* builder) {
  if (!replicate) return;

  const int num_replicas = tpu_devices.size();
  const int num_cores_per_replica = tpu_devices.front().size();

  llvm::SmallVector<NamedAttribute, 8> device_attrs;
  for (int core = 0; core < num_cores_per_replica; ++core) {
    llvm::SmallVector<StringRef, 8> devices_by_core;
    devices_by_core.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      devices_by_core.push_back(tpu_devices[replica][core].device);

    device_attrs.push_back(
        builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
                              builder->getStrArrayAttr(devices_by_core)));
  }

  // For data parallelism, also add replicated host devices, as these are
  // necessary for outside compilation.
  llvm::SmallVector<StringRef, 8> hosts;
  hosts.reserve(num_replicas);
  for (int replica = 0; replica < num_replicas; ++replica)
    hosts.push_back(tpu_devices[replica][0].host);

  device_attrs.push_back(builder->getNamedAttr(
      tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));

  replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
}

// Creates a `tf.TPUExecute` op that executes the TPU program.
LogicalResult BuildExecuteOp(
    const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::ArrayRef<Value> inputs, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, TF::TPUExecuteOp* execute_op) {
  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.
  llvm::SmallVector<Type, 4> output_types;
  llvm::SmallVector<int, 4> cluster_to_core_index;
  auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
      core_id, output_sharding_config, cluster_func, &output_types,
      &cluster_to_core_index);
  if (failed(result)) return failure();

  // TPUExecute has the same output types as cluster_func.
  *execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
                                                  output_types, inputs);
  auto producer_name_attr = cluster_func->getAttr("_producer_name");
  if (producer_name_attr)
    (*execute_op)->setAttr("_producer_name", producer_name_attr);
  return success();
}

// Moves the children of the ParallelExecute that do not change. These are
// all children except for the child with the ClusterFunc.
//
// Returns the index of the child with the ClusterFunc.
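//
// For example, with old children [A, ClusterFunc, B] (cluster_idx = 1) and
// num_cores_per_replica = 2, the new ParallelExecute has 4 regions laid out
// as [A, <core 0>, <core 1>, B]: A keeps index 0, and B moves from old index
// 2 to new index 3.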
int MovePreservedParallelExecuteChildren(
    const int num_cores_per_replica,
    llvm::SmallVector<Type, 8>& concatenated_output_types, OpBuilder* builder,
    tf_device::ClusterFuncOp cluster_func,
    tf_device::ParallelExecuteOp old_parallel_execute,
    tf_device::ParallelExecuteOp* new_parallel_execute) {
  // `num_moved_children` is the number of children that will be preserved.
  const int num_moved_children = old_parallel_execute.regions().size() - 1;
  *new_parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
      old_parallel_execute->getLoc(),
      num_moved_children + num_cores_per_replica, concatenated_output_types);

  // `cluster_idx` is the index of the child with the `ClusterFuncOp`, which
  // will be replaced.
  int cluster_idx = -1;
  for (int child_idx = 0; child_idx < old_parallel_execute.regions().size();
       ++child_idx) {
    auto& block = old_parallel_execute.GetRegionBlockWithIndex(child_idx);
    if (cluster_func->getBlock() == &block) {
      assert(cluster_idx == -1);
      cluster_idx = child_idx;
    }
  }
  assert(cluster_idx != -1);

  // Move children that are preserved.
  for (int child_idx = 0; child_idx < num_moved_children; ++child_idx) {
    int old_idx = child_idx >= cluster_idx ? child_idx + 1 : child_idx;
    int new_idx = child_idx >= cluster_idx ? child_idx + num_cores_per_replica
                                           : child_idx;
    new_parallel_execute->getRegions()[new_idx]->takeBody(
        *old_parallel_execute.getRegions()[old_idx]);
  }

  return cluster_idx;
}

// Given a `ParallelExecute`, replaces it with a new `ParallelExecute`. The
// new `ParallelExecute` will replace the child that contains the
// `ClusterFunc` with `num_cores_per_replica` children. It keeps the other
// children the same. Return values from the child with the `ClusterFunc`
// will be duplicated `num_cores_per_replica` times.
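//
// For example, with 2 cores whose per-core output types are (t0, t1) for
// core 0 and (t2) for core 1, the combined result list of the new children
// is (t0, t1, t2): core 0's results first, then core 1's.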
LogicalResult AddToParallelExecuteOp(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    llvm::ArrayRef<xla::OpSharding> output_sharding_config,
    llvm::SmallVectorImpl<llvm::SmallVector<int, 4>>* cluster_to_core_index,
    Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
    OpBuilder* builder, tf_device::ParallelExecuteOp old_parallel_execute,
    tf_device::ParallelExecuteOp* new_parallel_execute, int* cluster_idx) {
  const int num_cores_per_replica = tpu_devices.front().size();
  // parallel_execute op returns a concatenated list of return values of
  // all its regions.
  //
  // TODO(b/149102702): Correctly map inputs to parallel_execute op via
  // identifying xla_sharding op in the cluster_func function.
  const auto cluster_result_types = cluster_func.getResultTypes();
  llvm::SmallVector<Type, 8> concatenated_output_types;
  concatenated_output_types.reserve(cluster_result_types.size() *
                                    num_cores_per_replica);

  for (int core = 0; core < num_cores_per_replica; ++core) {
    cluster_to_core_index->emplace_back(llvm::SmallVector<int, 4>());
    llvm::SmallVector<Type, 4> output_types;
    auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation(
        core, output_sharding_config, cluster_func, &output_types,
        &(*cluster_to_core_index)[core]);
    if (failed(result)) return failure();

    for (Type t : output_types) concatenated_output_types.emplace_back(t);
  }

  *cluster_idx = MovePreservedParallelExecuteChildren(
      num_cores_per_replica, concatenated_output_types, builder, cluster_func,
      old_parallel_execute, new_parallel_execute);

  // Extract inputs for each block of the parallel_execute op. The i-th
  // element in the list represents the input lists to the TPU computation
  // for the i-th logical core.
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
  builder->setInsertionPoint(*new_parallel_execute);
  auto result = tensorflow::ExtractInputsForLogicalDevices(
      num_cores_per_replica, cluster_func, builder, &input_list);
  if (failed(result)) return failure();

  const bool replicated = tpu_devices.size() != 1;
  // For each logical core, create a region with a TPUExecute op.
  assert(input_list.size() == num_cores_per_replica);
  for (int core = 0; core < num_cores_per_replica; ++core) {
    auto& block =
        new_parallel_execute->GetRegionBlockWithIndex((*cluster_idx) + core);
    builder->setInsertionPointToEnd(&block);

    // Create Execute op.
    //
    // TODO(b/148913294): Identify inputs/return values specific to each
    // logical core TPU execution by parsing xla_sharding op in
    // cluster_func.
    auto execute_inputs = input_list[core];
    execute_inputs.emplace_back(compile_op->getResult(core + 1));

    TF::TPUExecuteOp execute;
    result = BuildExecuteOp(core, output_sharding_config, execute_inputs,
                            cluster_func, builder, &execute);
    if (failed(result)) return failure();

    // If the computation is replicated, use the aliased device. Otherwise
    // there is only one execution device per core and the device is assigned
    // to the execute op.
    std::string device = replicated
                             ? tensorflow::GetDeviceAliasForLogicalCore(core)
                             : tpu_devices.front()[core].device;

    auto block_launch_op =
        WrapOpInLaunch(builder, block.getParent()->getLoc(), execute, device);

    builder->create<tf_device::ReturnOp>(block.getParent()->getLoc(),
                                         block_launch_op.getResults());
  }

  return success();
}

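// Wraps `execute_op` in a `tf_device.launch`. If the computation is
// replicated, the aliased device of logical core 0 is used; otherwise the
// single execution device is assigned directly.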
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
        tpu_devices,
    Operation* execute_op, OpBuilder* builder) {
  const bool replicated = tpu_devices.size() != 1;
  // If the computation is replicated, use the aliased device. Otherwise there
  // is only one execution device and the device is assigned to the execute op.
  std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
                                  : tpu_devices.front().front().device;

  return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}

// Creates a `tf.TPUCompileSucceededAssert` operation that parses the
// compilation status of `compile_op` to check whether compilation succeeded.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      Operation* result_id,
                                      llvm::StringRef compilation_device,
                                      OpBuilder* builder) {
  auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
      compile_op->getLoc(), result_id->getResult(0));
  WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}

// Wraps the `ClusterFunc` in a `ParallelExecute` with only one child. This
// creates a canonical form, regardless of whether a `ParallelExecute` was
// created in a previous pass.
void BuildSingletonParallelExecuteOp(
    tf_device::ClusterFuncOp cluster_func, OpBuilder* builder,
    tf_device::ParallelExecuteOp* parallel_execute) {
  if (!*parallel_execute) {
    const auto output_types = cluster_func.getResultTypes();
    builder->setInsertionPoint(cluster_func);
    *parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
        cluster_func.getLoc(), 1, output_types);
    cluster_func->remove();
    auto& block = parallel_execute->GetRegionBlockWithIndex(0);
    builder->setInsertionPointToEnd(&block);
    builder->insert(cluster_func);
    cluster_func.replaceAllUsesWith(*parallel_execute);
    builder->create<tf_device::ReturnOp>(block.getParent()->getLoc(),
                                         cluster_func.getResults());
  }
}

// Unwraps the `ParallelExecute`'s contents if it only has one child.
LogicalResult RemoveSingletonParallelExecuteOp(
    tf_device::ParallelExecuteOp parallel_execute, OpBuilder* builder) {
  if (parallel_execute.regions().size() == 1) {
    builder->setInsertionPoint(parallel_execute);
    auto& block = parallel_execute.GetRegionBlockWithIndex(0);
    llvm::SmallVector<Operation*, 2> ops_move;
    for (Operation& op : block) {
      ops_move.push_back(&op);
    }
    if (ops_move.size() != 2) {
      parallel_execute.emitError() << "Expected 2 ops in parallel_execute.";
      return failure();
    }
    auto launch = llvm::dyn_cast<tf_device::LaunchOp>(ops_move[0]);
    if (!launch) {
      parallel_execute.emitError()
          << "Expected the op in parallel_execute to be a tf_device.launch";
      return failure();
    }
    launch->remove();
    builder->insert(launch);
    parallel_execute.replaceAllUsesWith(launch);
    parallel_execute.erase();
  }
  return success();
}

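// Checks that TPUPartitionedInput and TPUPartitionedOutput ops only occur
// inside the TPU computation: no user of the cluster's outputs may feed a
// TPUPartitionedOutput, and no resource read feeding the cluster may be
// produced by a TPUPartitionedInput.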
LogicalResult CheckTPUPartitionedInputAndOutputAreValid(
    tf_device::ClusterFuncOp cluster,
    tf_device::ParallelExecuteOp parallel_execute) {
  for (auto cluster_result : parallel_execute.execute_outputs()) {
    for (Operation* user :
         llvm::make_early_inc_range(cluster_result.getUsers())) {
      // Check that the user has no outputs that are TPUPartitionedOutput.
      for (auto result : user->getResults()) {
        for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
          if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
            user->emitError() << "Input of TPUPartitionedOutput must "
                              << "be in tpu computation.";
            return failure();
          }
        }
      }
    }
  }
  for (auto cluster_operand : cluster.operands()) {
    Operation* def = cluster_operand.getDefiningOp();
    // This pass assumes that a TPUPartitionedInput is preceded by
    // ReadVariable ops, and not vice versa. An earlier pass,
    // TPUResourceReadsWritesPartitioning, should have ensured this
    // precondition.
    if (!def) continue;
    for (auto operand : def->getOperands()) {
      Operation* def_of_read = operand.getDefiningOp();
      if (llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def_of_read)) {
        def_of_read->emitError() << "Output of TPUPartitionedInput must "
                                 << "be in tpu computation.";
        return failure();
      }
    }
  }
  return success();
}

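// Rewrites a single `tf_device.cluster_func` op: a `_TPUCompileMlir` op and
// a `TPUCompileSucceededAssert` op are created on the compilation device,
// devices are assigned to the surrounding replicate op (if any), and the
// cluster_func is replaced by per-core `TPUExecute` ops inside a
// `tf_device.parallel_execute`.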
LogicalResult Rewrite(
    tf_device::ClusterFuncOp cluster_func,
    llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
    ArrayRef<TF::TPUCompilationResultOp> compilation_result, OpBuilder* builder,
    bool tpu_compile_metadata_debug) {
  // Fetch the ParallelExecute parent of `cluster_func`, or create it if it
  // does not exist.
  tf_device::ParallelExecuteOp old_parallel_execute =
      cluster_func->getParentOfType<tf_device::ParallelExecuteOp>();
  if (old_parallel_execute &&
      cluster_func->getParentOp() != old_parallel_execute) {
    cluster_func->emitError() << "The ParallelExecute ancestor of a "
                                 "ClusterFunc must be its direct parent.";
    return failure();
  }
  BuildSingletonParallelExecuteOp(cluster_func, builder, &old_parallel_execute);

  // Check that TPUPartitionedInput and TPUPartitionedOutput ops are used in a
  // valid pattern.
  if (failed(CheckTPUPartitionedInputAndOutputAreValid(cluster_func,
                                                       old_parallel_execute)))
    return failure();

  // Collect `num_replicas` and `num_cores_per_replica` attributes.
  int num_replicas = 1;
  tf_device::ReplicateOp replicate =
      cluster_func->getParentOfType<tf_device::ReplicateOp>();
  if (replicate) num_replicas = replicate.n();

  auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
      tensorflow::kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));

  int num_cores_per_replica = num_cores_per_replica_attr.getInt();

  auto topology_attr =
      cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
  if (!topology_attr)
    return cluster_func.emitOpError(
        CreateMissingAttributeMsg(tensorflow::kTopologyAttr));

  auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
      tensorflow::kDeviceAssignmentAttr);
  if (!device_assignment_attr)
    return cluster_func.emitOpError(
        llvm::formatv("requires attribute '{0}'",
                      tensorflow::kDeviceAssignmentAttr)
            .str());

  auto status_or_device_coordinates =
      tensorflow::GetDeviceCoordinates(device_assignment_attr);
  if (!status_or_device_coordinates.ok())
    return cluster_func.emitError()
           << "error in fetching tpu device coordinates: "
           << status_or_device_coordinates.status().error_message();

  // Determine compilation and execution devices.
  auto status_or_tpu_device_assignment =
      tensorflow::GetTPUCompilationAndExecutionDevices(
          devices, num_replicas, num_cores_per_replica,
          topology_attr.getValue(), status_or_device_coordinates.value());
  if (!status_or_tpu_device_assignment.ok())
    return cluster_func.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status_or_tpu_device_assignment.status().error_message();

  // Create the compile op.
  auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();

  // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
  // the parallel_execute.
  builder->setInsertionPoint(old_parallel_execute);
  Operation* compile_op =
      BuildCompileOp(cluster_func, num_replicas, num_cores_per_replica,
                     tpu_device_assignment.compilation_device,
                     std::move(tpu_device_assignment.xla_device_assignment),
                     builder, tpu_compile_metadata_debug);
  if (!compile_op) return failure();

  // This replaces _TPUCompileMlir placeholder ops that are required
  // by XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
  // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp
  // and _XlaRecvAtHostOp and _XlaSendFromHostOp is used, update to a more
  // structured lowering.
  old_parallel_execute.walk(
      [&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
        key_op.replaceAllUsesWith(compile_op->getResult(1));
        key_op.erase();
      });

  // After the rewrite, if there is a TPUCompilationResultOp from the same
  // cluster, replace it with the result of the compile op. The
  // TPUCompilationResultOp is used as a placeholder during graph creation to
  // hook up the other ops that are intended to consume the compile result.
  Operation* result_id = compile_op;
  // TODO(jpienaar): Remove this later.
  auto compile_device_op = compile_op->getAttr("device");
  for (auto res : compilation_result) {
    // Build an identity op with the same location/name as the original
    // compilation result op.
    result_id = builder->create<TF::IdentityOp>(
        res.getLoc(), compile_op->getResult(0).getType(),
        result_id->getResult(0));
    // Assign the same device the result currently has set, or, if unset, the
    // device on which compilation will happen.
    // TODO(jpienaar): Remove this later.
    if (auto device = res->getAttrOfType<StringAttr>("device")) {
      if (!device.getValue().empty())
        result_id->setAttr("device", device);
      else
        result_id->setAttr("device", compile_device_op);
    } else if (compile_device_op) {
      result_id->setAttr("device", compile_device_op);
    }
    res.output().replaceAllUsesWith(compile_op->getResult(0));
  }

  BuildTPUCompileSucceededAssertOp(
      compile_op, result_id, tpu_device_assignment.compilation_device, builder);

  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                           builder);

  llvm::SmallVector<xla::OpSharding, 4> output_shardings;
  auto result = tensorflow::ParseAndValidateOutputSharding(
      num_cores_per_replica, cluster_func, &output_shardings);
  if (failed(result)) return failure();

  // For model parallelism, tf_device.parallel_execute is used to express
  // concurrent device execution across multiple logical devices.
  tf_device::ParallelExecuteOp new_parallel_execute;
  int cluster_idx;
  llvm::SmallVector<llvm::SmallVector<int, 4>, 4> cluster_to_core_index;
  cluster_to_core_index.reserve(num_cores_per_replica);
  result = AddToParallelExecuteOp(
      tpu_device_assignment.tpu_devices, output_shardings,
      &cluster_to_core_index, compile_op, cluster_func, builder,
      old_parallel_execute, &new_parallel_execute, &cluster_idx);
  if (failed(result)) return failure();

  // As tf_device.parallel_execute wraps one TPUExecute op per logical core,
  // the parallel_execute op has more return values than the cluster_func op.
  // As such, each return value of the parallel_execute op must be mapped to
  // the corresponding return value usages of cluster_func.
  result = tensorflow::RemapOutputsFromLogicalDevices(
      cluster_func.getLoc(), output_shardings, cluster_to_core_index,
      old_parallel_execute, cluster_idx, new_parallel_execute, builder);
  if (failed(result)) return failure();

  return RemoveSingletonParallelExecuteOp(new_parallel_execute, builder);
}

// Erases rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
// TPUPartitionedOutputOp ops are present, they must be removed along with the
// ClusterFuncOp(s).
void EraseClusterFuncs(
    llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
  for (auto cluster : to_be_erased) {
    auto old_parallel_execute =
        cluster->getParentOfType<tf_device::ParallelExecuteOp>();
    assert(old_parallel_execute);

    for (auto result : old_parallel_execute.execute_outputs()) {
      for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
        if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
          assert(user->use_empty());
          user->erase();
        }
      }
    }

    for (auto operand : cluster.operands()) {
      Operation* def = operand.getDefiningOp();
      if (operand.hasOneUse() &&
          llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
        operand.dropAllUses();
        def->erase();
      }
    }
    assert(old_parallel_execute->use_empty());
    old_parallel_execute->erase();
  }
}

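// Performs the pass: collects TPUCompilationResult ops by cluster, rewrites
// each TPU `tf_device.cluster_func`, then erases the rewritten cluster funcs
// and the now-dead compilation result ops.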
void TPURewritePass::runOnOperation() {
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
    return signalPassFailure();

  // Collect compilation results.
  llvm::DenseMap<Attribute, SmallVector<TF::TPUCompilationResultOp, 1>>
      compilation_results;
  auto result_init = getOperation().walk([&](TF::TPUCompilationResultOp op) {
    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_compilation_status");
    if (!cluster_id) {
      op->emitOpError("missing '_tpu_compilation_status'");
      return WalkResult::interrupt();
    }
    compilation_results[cluster_id].push_back(op);
    return WalkResult::advance();
  });
  if (result_init.wasInterrupted()) return signalPassFailure();
  llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
  OpBuilder builder(&getContext());
  auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
    if (failed(TF::HasValidCompilationAndReplicationAttributes(*op)))
      return WalkResult::interrupt();
    // Skip cluster_func ops that are not for TPU (no replication info
    // attribute).
    auto cluster_id = op->getAttrOfType<StringAttr>(TF::kReplicationInfoAttr);
    if (!cluster_id) return WalkResult::advance();

    if (failed(Rewrite(op, devices.device_names(),
                       compilation_results[cluster_id], &builder,
                       tpu_compile_metadata_debug_)))
      return WalkResult::interrupt();

    to_be_erased.push_back(op);
    return WalkResult::advance();
  });
  if (result.wasInterrupted()) return signalPassFailure();

  EraseClusterFuncs(to_be_erased);

  // Eliminate TPUCompilationResultOps now that the rewrite is complete.
  for (auto& it : compilation_results) {
    for (auto op : it.second) {
      if (!op.use_empty()) {
        mlir::InFlightDiagnostic err = op.emitError("uses remain post rewrite");
        for (auto user : op->getUsers())
          err.attachNote(user->getLoc()) << "remaining user";
        return signalPassFailure();
      }
      op.erase();
    }
  }

  // TODO(b/139377366): Remove functions that are no longer needed.
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
  return std::make_unique<TPURewritePass>();
}

}  // namespace TFTPU
}  // namespace mlir