• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 #include <memory>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/Location.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor_shape.pb.h"
51 #include "tensorflow/core/framework/types.pb.h"
52 #include "tensorflow/core/platform/random.h"
53 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
54 
55 namespace mlir {
56 namespace TFTPU {
57 
58 namespace {
59 
// Name of the attribute that carries the device an op is assigned to.
constexpr char kDeviceAttr[] = "device";
// Attribute name "tf.device".
// NOTE(review): presumably the function-level device attribute; confirm it is
// still needed, since no use is visible alongside these constants.
constexpr char kFuncDeviceAttr[] = "tf.device";
// Value used for every element of the default (unformatted) state key.
constexpr char kDefaultShardingValue[] = "";
// Name of the `tf_device.replicate` attribute that lists the indices of the
// block arguments that are mirrored variables.
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
64 
GetRandomStateVariableName()65 std::string GetRandomStateVariableName() {
66   return absl::StrCat("VariablesFormatState_", tensorflow::random::New64());
67 }
68 
69 // A pass that takes advantage of a loop to add ops that allow the execution to
70 // avoid repeatedly formatting variables back and forth. The desired formatting
71 // is determined by TPU program compilation, so this pass does not include how
72 // to reformat the variables, but only inserts general TPUReshardVariablesOps in
73 // proper places, and TPUReshardVariablesOps interpret the compilation.
74 //
75 // The core idea of this optimization is to keep track of the formatting state
76 // of variables, and when the next desired state does not change, it can avoid
77 // reformatting. We associate a set of variables on a device with a formatting
78 // state, and TPUReshardVariablesOps compares the current state with a desired
79 // state (which can be the compilation result). If they mismatch,
80 // TPUReshardVariablesOp reformats the variables to the desired state; if they
81 // match, TPUReshardVariablesOp is a no-op.
82 //
83 // A major use of this pass is weight-update sharding in data parallelism, so we
84 // require there is a tf_device.replicate in the loop.
85 //
86 // For example, suppose we have a training loop (for simplicity we write the
87 // loop body inine):
88 //
89 //  %var0 = ...
90 //  %var1 = ...
91 //  tf.while (..., %var0, %var1) {
92 //    tf_device.replicate ([%var0, %var1] as %rvar) {
93 //      %compile:2 = "tf._TPUCompileMlir"()
94 //      tf.TPUExecuteAndUpdateVariablesOp(%rvar, compile#1)
95 //    }
96 //  }
97 //
98 // This pass will transform it into
99 //
100 //  %var0 = ...
101 //  %var1 = ...
102 //  %state_var0 = ...
103 //  %state_var1 = ...
104 //  tf.while (..., %var0, %var1, %state_var0, %state_var1) {
105 //    tf_device.replicate ([%var0, %var1] as %rvar,
106 //                         [%state_var0, %state_var1] as %rstate) {
107 //      %compile:2 = "tf._TPUCompileMlir"()
108 //      tf.TPUReshardVariablesOp(%rvar, %compile#1, %rstate)
109 //      tf.TPUExecuteAndUpdateVariablesOp(%rvar, compile#1)
110 //    }
111 //  }
112 //  %default_format = tf.constant()
113 //  tf_device.replicate ([%var0, %var1] as %rvar,
114 //                       [%state_var0, %state_var1] as %rstate) {
115 //    tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
116 //  }
117 struct TPUVariableRuntimeReformattingPass
118     : public PassWrapper<TPUVariableRuntimeReformattingPass,
119                          OperationPass<ModuleOp>> {
120   void runOnOperation() override;
121 };
122 
123 // Returns the earlier value of which `v` is an identity. If `skipped` is
124 // provided, it will be used to store the identity nodes skipped.
SkipIdentity(Value v,bool allow_other_use,llvm::SmallPtrSet<Operation *,4> * skipped=nullptr)125 Value SkipIdentity(Value v, bool allow_other_use,
126                    llvm::SmallPtrSet<Operation*, 4>* skipped = nullptr) {
127   while (auto result = v.dyn_cast<OpResult>()) {
128     if (!(allow_other_use || v.hasOneUse())) break;
129     auto op = result.getDefiningOp();
130     if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
131       break;
132     }
133     v = op->getOperand(result.getResultNumber());
134     if (skipped) skipped->insert(op);
135   }
136   return v;
137 }
138 
139 // Finds the formattable arguments of `execute` and annotates the metadata of
140 // `compile` to record these arguments. In addition, it returns a mapping from
141 // the formattable arguments of `execute` to the corresponding operand of
142 // `replicate`. The
143 // entries in the mapping are sorted in the order of operands of `execute`.
144 llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate,TF::TPUExecuteAndUpdateVariablesOp execute,tf_device::LaunchOp compile_launch)145 AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
146     TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
147     TF::TPUExecuteAndUpdateVariablesOp execute,
148     tf_device::LaunchOp compile_launch) {
149   Region& body = while_op.body();
150   Region& cond = while_op.cond();
151 
152   llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
153   auto mirrored_variable_indices_attr =
154       replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
155   if (!mirrored_variable_indices_attr) return mapping;
156 
157   // Finds the mapping from a replicate argument to an execute operand.
158   llvm::SmallDenseMap<int64_t, int64_t, 8> replicate_arg_to_execute_arg;
159   for (auto index_and_arg : llvm::enumerate(execute.args())) {
160     auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false);
161     if (!arg.hasOneUse() ||
162         !getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
163       continue;
164     }
165     auto block_arg = arg.dyn_cast<BlockArgument>();
166     if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue;
167     assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 &&
168            "Found duplicate use of a resource in the execute op.");
169     replicate_arg_to_execute_arg[block_arg.getArgNumber()] =
170         index_and_arg.index();
171   }
172   if (replicate_arg_to_execute_arg.empty()) return mapping;
173 
174   // Parse the original compile metadata.
175   Operation& compile = compile_launch.GetBody().front();
176   auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
177   assert(metadata_str && "Missing compilation metadata");
178   tensorflow::tpu::TPUCompileMetadataProto metadata;
179   metadata.ParseFromString(std::string(metadata_str.getValue()));
180   int64_t num_replicas = replicate.n();
181   // Find the formattable operands of `execute`, which must be mirrored
182   // variables (arguments of `replicate`), and must be pass-throughs from while
183   // operands.
184   for (const auto& mirrored_index : mirrored_variable_indices_attr) {
185     int64_t replicate_arg = mirrored_index.cast<IntegerAttr>().getInt();
186     // Check if the mirrored variable is an input to `execute`.
187     auto it = replicate_arg_to_execute_arg.find(replicate_arg);
188     if (it == replicate_arg_to_execute_arg.end()) continue;
189     // Get the data type of the resource.
190     auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second))
191                         .cast<TF::ResourceType>()
192                         .getSubtypes();
193     if (subtypes.size() != 1) continue;
194     auto data_type = getElementTypeOrSelf(subtypes[0]);
195     // The XLA backend does not yet support formatting 64-bit data types.
196     if (data_type.getIntOrFloatBitWidth() == 64) continue;
197 
198     const auto& block_arg = replicate.GetBody().getArgument(replicate_arg);
199 
200     int64_t num_inputs = 0;
201     if (replicate.IsReplicatedBlockArgument(block_arg)) {
202       num_inputs = num_replicas;
203     } else {
204       num_inputs = 1;
205     }
206 
207     // We have found a mirrored variable which is an input to the replicated
208     // `execute`. Now find if this mirrored variable is a pass-through of while
209     // arguments.
210     llvm::SmallVector<Value, 4> replicate_args;
211     for (int64_t i = 0; i < num_inputs; ++i) {
212       llvm::SmallPtrSet<Operation*, 4> skipped_identities;
213 
214       auto replicate_operand = SkipIdentity(
215           replicate.GetReplicaOperandForBlockArgument(block_arg, i),
216           /*allow_other_use=*/false, &skipped_identities);
217       // For region based control flow, the resource operand for the replicate
218       // should be a region capture. If this has any use other than the
219       // replicate op (within the body of the while) or the skipped identities,
220       // then do not apply the transformation to this variable.
221       bool is_region_capture =
222           replicate_operand.getParentRegion()->isProperAncestor(&body);
223       bool has_other_use_in_body =
224           llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
225             // Ignore uses that are not in the while body or condition.
226             if (!body.isAncestor(user->getParentRegion()) &&
227                 !cond.isAncestor(user->getParentRegion()))
228               return false;
229             // Within the body or cond, only uses in replicate and the skipped
230             // identities is allowed.
231             return user != replicate && skipped_identities.count(user) == 0;
232           });
233 
234       if (!is_region_capture || has_other_use_in_body) {
235         replicate_args.clear();
236         break;
237       }
238       replicate_args.push_back(replicate_operand);
239     }
240     if (replicate_args.empty()) continue;
241     // Now set the enable_xla_sharding field in the metadata to inform the
242     // compile op.
243     auto metadata_arg = metadata.mutable_args(it->second);
244     metadata_arg->set_enable_xla_sharding(
245         ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
246     mapping.emplace_back(it->second, std::move(replicate_args));
247   }
248   // Sort the mapping according to execute operand order.
249   llvm::sort(mapping, llvm::less_first());
250   // Populate the `retval_index_for_sharding` field of the argument metadate.
251   for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) {
252     int64_t arg_index = entry.value().cast<IntegerAttr>().getInt();
253     auto arg_metadata = metadata.mutable_args(arg_index);
254     if (arg_metadata->enable_xla_sharding() ==
255         ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) {
256       int64_t ret_index = execute.device_var_updates_indices()
257                               .getValue()[entry.index()]
258                               .cast<IntegerAttr>()
259                               .getInt();
260       arg_metadata->set_retval_index_for_sharding(ret_index);
261     }
262   }
263   // Update the metadata of the compile op.
264   compile.setAttr("metadata", StringAttr::get(compile.getContext(),
265                                               metadata.SerializeAsString()));
266   return mapping;
267 }
268 
269 // Adds a new replicated input to the replicate op.
AddInputsToReplicateOp(tf_device::ReplicateOp replicate,MutableArrayRef<TF::VarHandleOp> new_inputs,const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices)270 tf_device::ReplicateOp AddInputsToReplicateOp(
271     tf_device::ReplicateOp replicate,
272     MutableArrayRef<TF::VarHandleOp> new_inputs,
273     const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
274         devices) {
275   int64_t num_replicas = replicate.n();
276   assert(new_inputs.size() == num_replicas);
277 
278   // As model parallelism is not yet supported, we assume that all ops are
279   // placed in logical core 0.
280   // TODO(b/148913020): Remove this constraint once model parallelism is
281   // supported.
282   assert(devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))
283              ->getSecond()
284              .size() == num_replicas);
285 
286   llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
287   llvm::SmallVector<Value, 8> new_packed_inputs;
288   llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
289   replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
290   new_packed_inputs.reserve(replicate.GetNumPackedBlockArguments());
291   for (const auto& arg : replicate.GetReplicatedBlockArguments()) {
292     replicated_inputs.emplace_back();
293     for (int64_t i = 0; i < num_replicas; ++i) {
294       replicated_inputs.back().push_back(
295           replicate.GetReplicaOperandForBlockArgument(arg, i));
296     }
297     new_replicated_inputs.emplace_back(replicated_inputs.back(), arg.getType());
298   }
299   for (const auto& arg : replicate.GetPackedBlockArguments()) {
300     new_packed_inputs.emplace_back(
301         replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
302   }
303   SmallVector<Value, 4> new_input_values;
304   new_input_values.reserve(new_inputs.size());
305   for (auto var : new_inputs) new_input_values.push_back(var.resource());
306   new_replicated_inputs.emplace_back(new_input_values,
307                                      new_input_values.front().getType());
308   OpBuilder builder(replicate);
309   auto new_replicate = builder.create<tf_device::ReplicateOp>(
310       replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
311       new_packed_inputs,
312       replicate.GetBody().getTerminator()->getOperandTypes());
313   for (auto arg : replicate.GetBody().getArguments()) {
314     if (replicate.IsReplicatedBlockArgument(arg)) {
315       arg.replaceAllUsesWith(
316           new_replicate.GetBody().getArgument(arg.getArgNumber()));
317     } else {
318       // There is a new added replicated state variable between replicated args
319       // and packed args.
320       arg.replaceAllUsesWith(
321           new_replicate.GetBody().getArgument(arg.getArgNumber() + 1));
322     }
323   }
324   for (auto& op : llvm::make_early_inc_range(replicate.GetBody())) {
325     op.moveBefore(&new_replicate.GetBody(), new_replicate.GetBody().end());
326   }
327   replicate.replaceAllUsesWith(new_replicate);
328   replicate.erase();
329   return new_replicate;
330 }
331 
332 // Creates the per-device variables that represent the formatting state of each
333 // device.
CreateStateVars(const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices,Location loc,RankedTensorType key_type,OpBuilder * builder)334 llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
335     const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
336         devices,
337     Location loc, RankedTensorType key_type, OpBuilder* builder) {
338   llvm::SmallVector<TF::VarHandleOp, 4> state_vars;
339 
340   // TODO(b/148913020): Remove this constraint once model parallelism is
341   // supported.
342   const auto& device_list =
343       devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))->getSecond();
344 
345   // Create the state variable for each device.
346   for (llvm::StringRef device : device_list) {
347     state_vars.push_back(builder->create<TF::VarHandleOp>(
348         loc,
349         llvm::ArrayRef<Type>{RankedTensorType::get(
350             {}, TF::ResourceType::get(llvm::ArrayRef<TensorType>{key_type},
351                                       builder->getContext()))},
352         llvm::ArrayRef<Value>{},
353         llvm::ArrayRef<NamedAttribute>{
354             builder->getNamedAttr(kDeviceAttr, builder->getStringAttr(device)),
355             builder->getNamedAttr("container", builder->getStringAttr("")),
356             builder->getNamedAttr(
357                 "shared_name",
358                 builder->getStringAttr(GetRandomStateVariableName()))}));
359   }
360   return state_vars;
361 }
362 
363 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)364 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
365                     llvm::StringRef device) {
366   OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
367 
368   auto launch = builder->create<tf_device::LaunchOp>(
369       loc, builder->getStringAttr(device), op->getResultTypes());
370   launch.body().push_back(new Block);
371 
372   builder->setInsertionPointToEnd(&launch.GetBody());
373   builder->create<tf_device::ReturnOp>(loc, op->getResults());
374 
375   // Move op inside launch.
376   op->moveBefore(launch.GetBody().getTerminator());
377 
378   builder->restoreInsertionPoint(insert_point);
379 }
380 
381 // Performs the transformation for a replicate op inside a while loop.
HandleReplicateOp(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate)382 void HandleReplicateOp(TF::WhileRegionOp while_op,
383                        tf_device::ReplicateOp replicate) {
384   int64_t num_replicas = replicate.n();
385   if (num_replicas == 1) return;
386   tf_device::LaunchOp execute_launch;
387   for (auto execute_launch_op :
388        replicate.GetBody().getOps<tf_device::LaunchOp>()) {
389     if (!execute_launch_op.WrapsSingleOp() ||
390         !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
391             execute_launch_op.GetBody().front()))
392       continue;
393 
394     if (execute_launch == nullptr) {
395       execute_launch = execute_launch_op;
396     } else {
397       // We only support one execute op inside replicate.
398       execute_launch = nullptr;
399       break;
400     }
401   }
402   if (!execute_launch) return;
403   auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
404       execute_launch.GetBody().front());
405   auto compile =
406       SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
407   if (!compile) return;
408   auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
409   if (!compile_launch || !compile_launch.WrapsSingleOp() ||
410       !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
411     return;
412 
413   // Analyze the formattable inputs.
414   auto execute_arg_to_outer_args =
415       AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
416           while_op, replicate, execute, compile_launch);
417   if (execute_arg_to_outer_args.empty()) return;
418 
419   // Extract the replicated devices.
420   auto devices_attr = replicate.devices();
421   if (!devices_attr) return;
422 
423   auto device_map = devices_attr.getValue();
424   llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>> devices;
425   devices.reserve(device_map.size());
426 
427   for (auto it : device_map) {
428     auto device_alias = it.first.strref();
429     auto device_list = it.second.cast<ArrayAttr>();
430     llvm::SmallVector<StringRef, 4> device_list_for_alias;
431     device_list_for_alias.reserve(device_list.size());
432 
433     for (auto device : device_list)
434       device_list_for_alias.emplace_back(device.cast<StringAttr>().getValue());
435 
436     devices.insert({device_alias, device_list_for_alias});
437   }
438 
439   OpBuilder builder(replicate);
440   builder.setInsertionPoint(while_op);
441   // Create per-device variables for formatting state, and add them to the while
442   // loop.
443   auto key_type =
444       RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
445   auto state_vars =
446       CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
447   replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
448   // Build the reformat according to the compilation. Build it inside
449   // `replicate`.
450   llvm::SmallVector<Value, 8> reformat_operands;
451   for (const auto& entry : execute_arg_to_outer_args) {
452     reformat_operands.push_back(execute.args()[entry.first]);
453   }
454   reformat_operands.push_back(compile_launch.getResult(1));
455   reformat_operands.push_back(replicate.GetBody().getArgument(
456       replicate.GetNumReplicatedBlockArguments() - 1));
457   builder.setInsertionPoint(execute_launch);
458   auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
459       execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
460   WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
461                  execute_launch.device());
462 
463   // Build the replicated unformat op after the loop. First prepare building the
464   // replicate op.
465   llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
466   llvm::SmallVector<Value, 8> unformat_packed_operands;
467   for (const auto& entry : execute_arg_to_outer_args) {
468     if (entry.second.size() > 1) {
469       unformat_replicate_operands.emplace_back(entry.second,
470                                                entry.second.front().getType());
471     } else {
472       unformat_packed_operands.emplace_back(entry.second.front());
473     }
474   }
475   llvm::SmallVector<Value, 4> state_var_vals(state_vars.size());
476   for (const auto& entry : llvm::enumerate(state_vars)) {
477     state_var_vals[entry.index()] = entry.value().resource();
478   }
479   // Add the replicated state var to the end of the replicate operands.
480   unformat_replicate_operands.emplace_back(state_var_vals,
481                                            state_var_vals.front().getType());
482   // Build a constant default key to specify that the unformatting should
483   // transform the variables to the original format.
484   builder.setInsertionPointAfter(while_op);
485   tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
486   default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
487   default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
488   default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
489   auto default_state_key = builder.create<TF::ConstOp>(
490       while_op.getLoc(),
491       tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
492   // With all replicated inputs, now build the replicate op.
493   auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
494       while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
495       unformat_packed_operands, TypeRange{});
496   // Then build the unformat op in the replicate op.
497   builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
498   llvm::SmallVector<Value, 8> unformat_operands;
499   // Add the replicated state var (the last replicated operand of the
500   // ReplicateOp) as the last operand of TPUReshardVariablesOp.
501   BlockArgument state = unformat_replicate.GetReplicatedBlockArguments().back();
502   auto replicated_block_args =
503       unformat_replicate.GetReplicatedBlockArguments().drop_back(1);
504   auto packed_block_args = unformat_replicate.GetPackedBlockArguments();
505   unformat_operands.append(replicated_block_args.begin(),
506                            replicated_block_args.end());
507   unformat_operands.append(packed_block_args.begin(), packed_block_args.end());
508   unformat_operands.push_back(state);
509 
510   // Insert the default key as the second last operand.
511   unformat_operands.insert(
512       unformat_operands.begin() + unformat_operands.size() - 1,
513       default_state_key.getResult());
514   // Unformat op.
515   auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
516       while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
517   WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
518                  execute_launch.device());
519   builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
520 }
521 
runOnOperation()522 void TPUVariableRuntimeReformattingPass::runOnOperation() {
523   auto module = getOperation();
524   module.walk([&](TF::WhileRegionOp while_op) {
525     tf_device::ReplicateOp replicate;
526     while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
527       if (replicate == nullptr) {
528         replicate = replicate_op;
529         return WalkResult::advance();
530       }
531       // We do not handle loops with multiple replicate ops.
532       replicate = nullptr;
533       return WalkResult::interrupt();
534     });
535     // Model parallelism is not supported, and can be detected when a
536     // `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
537     if (replicate &&
538         replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
539       HandleReplicateOp(while_op, replicate);
540   });
541 }
542 
543 }  // namespace
544 
CreateTPUVariableReformattingPass()545 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass() {
546   return std::make_unique<TPUVariableRuntimeReformattingPass>();
547 }
548 
549 static PassRegistration<TPUVariableRuntimeReformattingPass> pass(
550     "tf-tpu-variable-runtime-reformatting",
551     "Adds device variable formatting op to allow compilation-guided variable "
552     "formatting.");
553 
554 }  // namespace TFTPU
555 }  // namespace mlir
556