• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 #include <memory>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/Location.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor_shape.pb.h"
51 #include "tensorflow/core/framework/types.pb.h"
52 #include "tensorflow/core/platform/random.h"
53 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
54 
55 namespace mlir {
56 namespace TFTPU {
57 
58 namespace {
59 
60 constexpr char kDeviceAttr[] = "device";
61 constexpr char kFuncDeviceAttr[] = "tf.device";
62 constexpr char kDefaultShardingValue[] = "";
63 constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
64 
GetRandomStateVariableName()65 std::string GetRandomStateVariableName() {
66   return absl::StrCat("VariablesFormatState_", tensorflow::random::New64());
67 }
68 
69 // A pass that takes advantage of a loop to add ops that allow the execution to
70 // avoid repeatedly formatting variables back and forth. The desired formatting
71 // is determined by TPU program compilation, so this pass does not include how
72 // to reformat the variables, but only inserts general TPUReshardVariablesOps in
73 // proper places, and TPUReshardVariablesOps interpret the compilation.
74 //
75 // The core idea of this optimization is to keep track of the formatting state
76 // of variables, and when the next desired state does not change, it can avoid
77 // reformatting. We associate a set of variables on a device with a formatting
78 // state, and TPUReshardVariablesOps compares the current state with a desired
79 // state (which can be the compilation result). If they mismatch,
80 // TPUReshardVariablesOp reformats the variables to the desired state; if they
81 // match, TPUReshardVariablesOp is a no-op.
82 //
83 // A major use of this pass is weight-update sharding in data parallelism, so we
84 // require there is a tf_device.replicate in the loop.
85 //
86 // For example, suppose we have a training loop (for simplicity we write the
//  loop body inline):
88 //
89 //  %var0 = ...
90 //  %var1 = ...
91 //  tf.while (..., %var0, %var1) {
92 //    tf_device.replicate ([%var0, %var1] as %rvar) {
93 //      %compile:2 = "tf._TPUCompileMlir"()
94 //      tf.TPUExecuteAndUpdateVariablesOp(%rvar, compile#1)
95 //    }
96 //  }
97 //
98 // This pass will transform it into
99 //
100 //  %var0 = ...
101 //  %var1 = ...
102 //  %state_var0 = ...
103 //  %state_var1 = ...
104 //  tf.while (..., %var0, %var1, %state_var0, %state_var1) {
105 //    tf_device.replicate ([%var0, %var1] as %rvar,
106 //                         [%state_var0, %state_var1] as %rstate) {
107 //      %compile:2 = "tf._TPUCompileMlir"()
108 //      tf.TPUReshardVariablesOp(%rvar, %compile#1, %rstate)
109 //      tf.TPUExecuteAndUpdateVariablesOp(%rvar, compile#1)
110 //    }
111 //  }
112 //  %default_format = tf.constant()
113 //  tf_device.replicate ([%var0, %var1] as %rvar,
114 //                       [%state_var0, %state_var1] as %rstate) {
115 //    tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
116 //  }
// Module pass implementing the transformation described in the comment above:
// it inserts TPUReshardVariablesOps into loops containing a
// tf_device.replicate so variables are only reformatted when the desired
// format changes.
struct TPUVariableRuntimeReformattingPass
    : public PassWrapper<TPUVariableRuntimeReformattingPass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;

  // Command-line flag used to select this pass.
  StringRef getArgument() const final {
    return "tf-tpu-variable-runtime-reformatting";
  }

  // One-line summary shown in generated pass documentation.
  StringRef getDescription() const final {
    return "Adds device variable formatting op to allow compilation-guided "
           "variable formatting.";
  }
};
131 
132 // Returns the earlier value of which `v` is an identity. If `skipped` is
133 // provided, it will be used to store the identity nodes skipped.
SkipIdentity(Value v,bool allow_other_use,llvm::SmallPtrSet<Operation *,4> * skipped=nullptr)134 Value SkipIdentity(Value v, bool allow_other_use,
135                    llvm::SmallPtrSet<Operation*, 4>* skipped = nullptr) {
136   while (auto result = v.dyn_cast<OpResult>()) {
137     if (!(allow_other_use || v.hasOneUse())) break;
138     auto op = result.getDefiningOp();
139     if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
140       break;
141     }
142     v = op->getOperand(result.getResultNumber());
143     if (skipped) skipped->insert(op);
144   }
145   return v;
146 }
147 
148 // Finds the formattable arguments of `execute` and annotates the metadata of
149 // `compile` to record these arguments. In addition, it returns a mapping from
150 // the formattable arguments of `execute` to the corresponding operand of
151 // `replicate`. The
152 // entries in the mapping are sorted in the order of operands of `execute`.
// Finds the formattable arguments of `execute` and annotates the metadata of
// `compile` to record these arguments. In addition, it returns a mapping from
// the formattable arguments of `execute` to the corresponding operand of
// `replicate`. The
// entries in the mapping are sorted in the order of operands of `execute`.
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
    TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
    TF::TPUExecuteAndUpdateVariablesOp execute,
    tf_device::LaunchOp compile_launch) {
  Region& body = while_op.body();
  Region& cond = while_op.cond();

  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
  // Without the mirrored-variable annotation on `replicate` there is nothing
  // to format; return an empty mapping.
  auto mirrored_variable_indices_attr =
      replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
  if (!mirrored_variable_indices_attr) return mapping;

  // Finds the mapping from a replicate argument to an execute operand.
  // Only single-use resource-typed values that are block arguments of
  // `replicate` (possibly behind identities) qualify.
  llvm::SmallDenseMap<int64_t, int64_t, 8> replicate_arg_to_execute_arg;
  for (auto index_and_arg : llvm::enumerate(execute.args())) {
    auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false);
    if (!arg.hasOneUse() ||
        !getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    auto block_arg = arg.dyn_cast<BlockArgument>();
    if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue;
    assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 &&
           "Found duplicate use of a resource in the execute op.");
    replicate_arg_to_execute_arg[block_arg.getArgNumber()] =
        index_and_arg.index();
  }
  if (replicate_arg_to_execute_arg.empty()) return mapping;

  // Parse the original compile metadata.
  Operation& compile = compile_launch.GetBody().front();
  auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
  assert(metadata_str && "Missing compilation metadata");
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(std::string(metadata_str.getValue()));
  int64_t num_replicas = replicate.n();
  // Find the formattable operands of `execute`, which must be mirrored
  // variables (arguments of `replicate`), and must be pass-throughs from while
  // operands.
  for (const auto& mirrored_index : mirrored_variable_indices_attr) {
    int64_t replicate_arg = mirrored_index.cast<IntegerAttr>().getInt();
    // Check if the mirrored variable is an input to `execute`.
    auto it = replicate_arg_to_execute_arg.find(replicate_arg);
    if (it == replicate_arg_to_execute_arg.end()) continue;
    // Get the data type of the resource. Resources with zero or multiple
    // subtypes are skipped since the element type is not uniquely determined.
    auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second))
                        .cast<TF::ResourceType>()
                        .getSubtypes();
    if (subtypes.size() != 1) continue;
    auto data_type = getElementTypeOrSelf(subtypes[0]);
    // The XLA backend does not yet support formatting 64-bit data types.
    if (data_type.getIntOrFloatBitWidth() == 64) continue;

    const auto& block_arg = replicate.GetBody().getArgument(replicate_arg);

    // Replicated block arguments have one operand per replica; packed block
    // arguments share a single operand.
    int64_t num_inputs = 0;
    if (replicate.IsReplicatedBlockArgument(block_arg)) {
      num_inputs = num_replicas;
    } else {
      num_inputs = 1;
    }

    // We have found a mirrored variable which is an input to the replicated
    // `execute`. Now find if this mirrored variable is a pass-through of while
    // arguments.
    llvm::SmallVector<Value, 4> replicate_args;
    for (int64_t i = 0; i < num_inputs; ++i) {
      llvm::SmallPtrSet<Operation*, 4> skipped_identities;

      auto replicate_operand = SkipIdentity(
          replicate.GetReplicaOperandForBlockArgument(block_arg, i),
          /*allow_other_use=*/false, &skipped_identities);
      // For region based control flow, the resource operand for the replicate
      // should be a region capture. If this has any use other than the
      // replicate op (within the body of the while) or the skipped identities,
      // then do not apply the transformation to this variable.
      bool is_region_capture =
          replicate_operand.getParentRegion()->isProperAncestor(&body);
      bool has_other_use_in_body =
          llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
            // Ignore uses that are not in the while body or condition.
            if (!body.isAncestor(user->getParentRegion()) &&
                !cond.isAncestor(user->getParentRegion()))
              return false;
            // Within the body or cond, only uses in replicate and the skipped
            // identities is allowed.
            return user != replicate && skipped_identities.count(user) == 0;
          });

      if (!is_region_capture || has_other_use_in_body) {
        // One replica failed the pass-through check; drop the whole variable.
        replicate_args.clear();
        break;
      }
      replicate_args.push_back(replicate_operand);
    }
    if (replicate_args.empty()) continue;
    // Now set the enable_xla_sharding field in the metadata to inform the
    // compile op.
    auto metadata_arg = metadata.mutable_args(it->second);
    metadata_arg->set_enable_xla_sharding(
        ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
    mapping.emplace_back(it->second, std::move(replicate_args));
  }
  // Sort the mapping according to execute operand order.
  llvm::sort(mapping, llvm::less_first());
  // Populate the `retval_index_for_sharding` field of the argument metadata:
  // for each formattable variable read, record which execute result writes it
  // back.
  for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) {
    int64_t arg_index = entry.value().cast<IntegerAttr>().getInt();
    auto arg_metadata = metadata.mutable_args(arg_index);
    if (arg_metadata->enable_xla_sharding() ==
        ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) {
      int64_t ret_index = execute.device_var_updates_indices()
                              .getValue()[entry.index()]
                              .cast<IntegerAttr>()
                              .getInt();
      arg_metadata->set_retval_index_for_sharding(ret_index);
    }
  }
  // Update the metadata of the compile op.
  compile.setAttr("metadata", StringAttr::get(compile.getContext(),
                                              metadata.SerializeAsString()));
  return mapping;
}
277 
278 // Adds a new replicated input to the replicate op.
// Adds a new replicated input to the replicate op. Because ReplicateOp inputs
// cannot be grown in place, this rebuilds the op with the extra input, moves
// the body over, and erases the original; returns the new op.
tf_device::ReplicateOp AddInputsToReplicateOp(
    tf_device::ReplicateOp replicate,
    MutableArrayRef<TF::VarHandleOp> new_inputs,
    const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
        devices) {
  int64_t num_replicas = replicate.n();
  // One new state variable is expected per replica.
  assert(new_inputs.size() == num_replicas);

  // As model parallelism is not yet supported, we assume that all ops are
  // placed in logical core 0.
  // TODO(b/148913020): Remove this constraint once model parallelism is
  // supported.
  assert(devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))
             ->getSecond()
             .size() == num_replicas);

  // Gather the existing replicated and packed operands so they can be passed
  // to the rebuilt op unchanged.
  llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
  llvm::SmallVector<Value, 8> new_packed_inputs;
  llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
  replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
  new_packed_inputs.reserve(replicate.GetNumPackedBlockArguments());
  for (const auto& arg : replicate.GetReplicatedBlockArguments()) {
    replicated_inputs.emplace_back();
    for (int64_t i = 0; i < num_replicas; ++i) {
      replicated_inputs.back().push_back(
          replicate.GetReplicaOperandForBlockArgument(arg, i));
    }
    new_replicated_inputs.emplace_back(replicated_inputs.back(), arg.getType());
  }
  for (const auto& arg : replicate.GetPackedBlockArguments()) {
    new_packed_inputs.emplace_back(
        replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
  }
  // Append the state variable resources as one additional replicated input
  // (it becomes the last replicated block argument).
  SmallVector<Value, 4> new_input_values;
  new_input_values.reserve(new_inputs.size());
  for (auto var : new_inputs) new_input_values.push_back(var.resource());
  new_replicated_inputs.emplace_back(new_input_values,
                                     new_input_values.front().getType());
  OpBuilder builder(replicate);
  auto new_replicate = builder.create<tf_device::ReplicateOp>(
      replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
      new_packed_inputs,
      replicate.GetBody().getTerminator()->getOperandTypes());
  // Redirect uses of the old block arguments to the new block's arguments.
  for (auto arg : replicate.GetBody().getArguments()) {
    if (replicate.IsReplicatedBlockArgument(arg)) {
      arg.replaceAllUsesWith(
          new_replicate.GetBody().getArgument(arg.getArgNumber()));
    } else {
      // There is a new added replicated state variable between replicated args
      // and packed args.
      arg.replaceAllUsesWith(
          new_replicate.GetBody().getArgument(arg.getArgNumber() + 1));
    }
  }
  // Move the body ops (including the terminator) into the new replicate.
  for (auto& op : llvm::make_early_inc_range(replicate.GetBody())) {
    op.moveBefore(&new_replicate.GetBody(), new_replicate.GetBody().end());
  }
  replicate.replaceAllUsesWith(new_replicate);
  replicate.erase();
  return new_replicate;
}
340 
341 // Creates the per-device variables that represent the formatting state of each
342 // device.
// Creates the per-device variables that represent the formatting state of each
// device.
llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
    const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
        devices,
    Location loc, RankedTensorType key_type, OpBuilder* builder) {
  llvm::SmallVector<TF::VarHandleOp, 4> state_vars;

  // TODO(b/148913020): Remove this constraint once model parallelism is
  // supported.
  const auto& device_list =
      devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))->getSecond();

  // The state variable is a scalar resource whose single subtype is the
  // format key type; the type is identical for every device, so build it once.
  auto var_handle_type = RankedTensorType::get(
      {}, TF::ResourceType::get(llvm::ArrayRef<TensorType>{key_type},
                                builder->getContext()));

  // Create the state variable for each device.
  state_vars.reserve(device_list.size());
  for (llvm::StringRef device : device_list) {
    llvm::SmallVector<NamedAttribute, 3> attrs = {
        builder->getNamedAttr(kDeviceAttr, builder->getStringAttr(device)),
        builder->getNamedAttr("container", builder->getStringAttr("")),
        builder->getNamedAttr(
            "shared_name",
            builder->getStringAttr(GetRandomStateVariableName()))};
    state_vars.push_back(builder->create<TF::VarHandleOp>(
        loc, llvm::ArrayRef<Type>{var_handle_type}, llvm::ArrayRef<Value>{},
        llvm::ArrayRef<NamedAttribute>(attrs)));
  }
  return state_vars;
}
371 
372 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)373 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
374                     llvm::StringRef device) {
375   OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
376 
377   auto launch = builder->create<tf_device::LaunchOp>(
378       loc, builder->getStringAttr(device), op->getResultTypes());
379   launch.body().push_back(new Block);
380 
381   builder->setInsertionPointToEnd(&launch.GetBody());
382   builder->create<tf_device::ReturnOp>(loc, op->getResults());
383 
384   // Move op inside launch.
385   op->moveBefore(launch.GetBody().getTerminator());
386 
387   builder->restoreInsertionPoint(insert_point);
388 }
389 
390 // Performs the transformation for a replicate op inside a while loop.
// Performs the transformation for a replicate op inside a while loop. Finds
// the single TPUExecuteAndUpdateVariables launch and its compile op, creates
// per-device state variables, inserts a TPUReshardVariablesOp before the
// execute, and emits an "unformat" replicate after the loop that restores the
// default format.
void HandleReplicateOp(TF::WhileRegionOp while_op,
                       tf_device::ReplicateOp replicate) {
  int64_t num_replicas = replicate.n();
  // Nothing to gain from this optimization with a single replica.
  if (num_replicas == 1) return;
  // Find the unique launch wrapping a TPUExecuteAndUpdateVariablesOp.
  tf_device::LaunchOp execute_launch;
  for (auto execute_launch_op :
       replicate.GetBody().getOps<tf_device::LaunchOp>()) {
    if (!execute_launch_op.WrapsSingleOp() ||
        !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
            execute_launch_op.GetBody().front()))
      continue;

    if (execute_launch == nullptr) {
      execute_launch = execute_launch_op;
    } else {
      // We only support one execute op inside replicate.
      execute_launch = nullptr;
      break;
    }
  }
  if (!execute_launch) return;
  auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
      execute_launch.GetBody().front());
  // Trace the execute's compilation key back (through identities) to the
  // launch wrapping the _TPUCompileMlir op.
  auto compile =
      SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
  if (!compile) return;
  auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
  if (!compile_launch || !compile_launch.WrapsSingleOp() ||
      !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
    return;

  // Analyze the formattable inputs.
  auto execute_arg_to_outer_args =
      AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
          while_op, replicate, execute, compile_launch);
  if (execute_arg_to_outer_args.empty()) return;

  // Extract the replicated devices.
  auto devices_attr = replicate.devices();
  if (!devices_attr) return;

  // Convert the devices attribute into an alias -> device-name-list map.
  auto device_map = devices_attr.getValue();
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>> devices;
  devices.reserve(device_map.size());

  for (auto it : device_map) {
    auto device_alias = it.first.strref();
    auto device_list = it.second.cast<ArrayAttr>();
    llvm::SmallVector<StringRef, 4> device_list_for_alias;
    device_list_for_alias.reserve(device_list.size());

    for (auto device : device_list)
      device_list_for_alias.emplace_back(device.cast<StringAttr>().getValue());

    devices.insert({device_alias, device_list_for_alias});
  }

  OpBuilder builder(replicate);
  builder.setInsertionPoint(while_op);
  // Create per-device variables for formatting state, and add them to the while
  // loop.
  auto key_type =
      RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
  auto state_vars =
      CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
  replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
  // Build the reformat according to the compilation. Build it inside
  // `replicate`. Operands are: formattable variables, compile key, state var.
  llvm::SmallVector<Value, 8> reformat_operands;
  for (const auto& entry : execute_arg_to_outer_args) {
    reformat_operands.push_back(execute.args()[entry.first]);
  }
  reformat_operands.push_back(compile_launch.getResult(1));
  // The state var is the last replicated block argument (added by
  // AddInputsToReplicateOp).
  reformat_operands.push_back(replicate.GetBody().getArgument(
      replicate.GetNumReplicatedBlockArguments() - 1));
  builder.setInsertionPoint(execute_launch);
  auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
      execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
  WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
                 execute_launch.device());

  // Build the replicated unformat op after the loop. First prepare building the
  // replicate op.
  llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
  llvm::SmallVector<Value, 8> unformat_packed_operands;
  for (const auto& entry : execute_arg_to_outer_args) {
    if (entry.second.size() > 1) {
      unformat_replicate_operands.emplace_back(entry.second,
                                               entry.second.front().getType());
    } else {
      unformat_packed_operands.emplace_back(entry.second.front());
    }
  }
  llvm::SmallVector<Value, 4> state_var_vals(state_vars.size());
  for (const auto& entry : llvm::enumerate(state_vars)) {
    state_var_vals[entry.index()] = entry.value().resource();
  }
  // Add the replicated state var to the end of the replicate operands.
  unformat_replicate_operands.emplace_back(state_var_vals,
                                           state_var_vals.front().getType());
  // Build a constant default key to specify that the unformatting should
  // transform the variables to the original format.
  builder.setInsertionPointAfter(while_op);
  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
  default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
  default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
  default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
  auto default_state_key = builder.create<TF::ConstOp>(
      while_op.getLoc(),
      tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
  // With all replicated inputs, now build the replicate op.
  auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
      while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
      unformat_packed_operands, TypeRange{});
  // Then build the unformat op in the replicate op.
  builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
  llvm::SmallVector<Value, 8> unformat_operands;
  // Add the replicated state var (the last replicated operand of the
  // ReplicateOp) as the last operand of TPUReshardVariablesOp.
  BlockArgument state = unformat_replicate.GetReplicatedBlockArguments().back();
  auto replicated_block_args =
      unformat_replicate.GetReplicatedBlockArguments().drop_back(1);
  auto packed_block_args = unformat_replicate.GetPackedBlockArguments();
  unformat_operands.append(replicated_block_args.begin(),
                           replicated_block_args.end());
  unformat_operands.append(packed_block_args.begin(), packed_block_args.end());
  unformat_operands.push_back(state);

  // Insert the default key as the second last operand.
  unformat_operands.insert(
      unformat_operands.begin() + unformat_operands.size() - 1,
      default_state_key.getResult());
  // Unformat op.
  auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
      while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
  WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
                 execute_launch.device());
  // Terminate the unformat replicate body (it yields no values).
  builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
}
530 
runOnOperation()531 void TPUVariableRuntimeReformattingPass::runOnOperation() {
532   auto module = getOperation();
533   module.walk([&](TF::WhileRegionOp while_op) {
534     tf_device::ReplicateOp replicate;
535     while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
536       if (replicate == nullptr) {
537         replicate = replicate_op;
538         return WalkResult::advance();
539       }
540       // We do not handle loops with multiple replicate ops.
541       replicate = nullptr;
542       return WalkResult::interrupt();
543     });
544     // Model parallelism is not supported, and can be detected when a
545     // `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
546     if (replicate &&
547         replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
548       HandleReplicateOp(while_op, replicate);
549   });
550 }
551 
552 }  // namespace
553 
CreateTPUVariableReformattingPass()554 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass() {
555   return std::make_unique<TPUVariableRuntimeReformattingPass>();
556 }
557 
// Registers the pass with the global pass registry so it can be selected by
// its command-line argument.
static PassRegistration<TPUVariableRuntimeReformattingPass> pass;
559 
560 }  // namespace TFTPU
561 }  // namespace mlir
562