1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iterator>
17 #include <memory>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21
22 #include "absl/strings/str_cat.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/Location.h" // from @llvm-project
34 #include "mlir/IR/MLIRContext.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
37 #include "mlir/IR/Types.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/Pass/Pass.h" // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor_shape.pb.h"
51 #include "tensorflow/core/framework/types.pb.h"
52 #include "tensorflow/core/platform/random.h"
53 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
54
55 namespace mlir {
56 namespace TFTPU {
57
58 namespace {
59
// --- Attribute names used by this pass. ---

// Attribute holding the device string; set on the state `tf.VarHandleOp`s
// created by this pass.
constexpr char kDeviceAttr[] = "device";
// Device attribute name in its "tf."-prefixed spelling.
// NOTE(review): not referenced anywhere in the visible part of this file.
constexpr char kFuncDeviceAttr[] = "tf.device";
// Array attribute on `tf_device.replicate` listing the block-argument indices
// of mirrored variables.
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

// --- Values. ---

// Value written into every element of the default state key that is handed to
// the reshard op built after the loop.
constexpr char kDefaultShardingValue[] = "";
64
GetRandomStateVariableName()65 std::string GetRandomStateVariableName() {
66 return absl::StrCat("VariablesFormatState_", tensorflow::random::New64());
67 }
68
69 // A pass that takes advantage of a loop to add ops that allow the execution to
70 // avoid repeatedly formatting variables back and forth. The desired formatting
71 // is determined by TPU program compilation, so this pass does not include how
72 // to reformat the variables, but only inserts general TPUReshardVariablesOps in
73 // proper places, and TPUReshardVariablesOps interpret the compilation.
74 //
75 // The core idea of this optimization is to keep track of the formatting state
76 // of variables, and when the next desired state does not change, it can avoid
77 // reformatting. We associate a set of variables on a device with a formatting
78 // state, and TPUReshardVariablesOps compares the current state with a desired
79 // state (which can be the compilation result). If they mismatch,
80 // TPUReshardVariablesOp reformats the variables to the desired state; if they
81 // match, TPUReshardVariablesOp is a no-op.
82 //
83 // A major use of this pass is weight-update sharding in data parallelism, so we
84 // require there is a tf_device.replicate in the loop.
85 //
86 // For example, suppose we have a training loop (for simplicity we write the
87 // loop body inine):
88 //
89 // %var0 = ...
90 // %var1 = ...
91 // tf.while (..., %var0, %var1) {
92 // tf_device.replicate ([%var0, %var1] as %rvar) {
93 // %compile:2 = "tf._TPUCompileMlir"()
94 // tf.TPUExecuteAndUpdateVariablesOp(%rvar, compile#1)
95 // }
96 // }
97 //
98 // This pass will transform it into
99 //
100 // %var0 = ...
101 // %var1 = ...
102 // %state_var0 = ...
103 // %state_var1 = ...
104 // tf.while (..., %var0, %var1, %state_var0, %state_var1) {
105 // tf_device.replicate ([%var0, %var1] as %rvar,
106 // [%state_var0, %state_var1] as %rstate) {
107 // %compile:2 = "tf._TPUCompileMlir"()
108 // tf.TPUReshardVariablesOp(%rvar, %compile#1, %rstate)
109 // tf.TPUExecuteAndUpdateVariablesOp(%rvar, compile#1)
110 // }
111 // }
112 // %default_format = tf.constant()
113 // tf_device.replicate ([%var0, %var1] as %rvar,
114 // [%state_var0, %state_var1] as %rstate) {
115 // tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
116 // }
117 struct TPUVariableRuntimeReformattingPass
118 : public PassWrapper<TPUVariableRuntimeReformattingPass,
119 OperationPass<ModuleOp>> {
120 void runOnOperation() override;
121 };
122
123 // Returns the earlier value of which `v` is an identity. If `skipped` is
124 // provided, it will be used to store the identity nodes skipped.
SkipIdentity(Value v,bool allow_other_use,llvm::SmallPtrSet<Operation *,4> * skipped=nullptr)125 Value SkipIdentity(Value v, bool allow_other_use,
126 llvm::SmallPtrSet<Operation*, 4>* skipped = nullptr) {
127 while (auto result = v.dyn_cast<OpResult>()) {
128 if (!(allow_other_use || v.hasOneUse())) break;
129 auto op = result.getDefiningOp();
130 if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
131 break;
132 }
133 v = op->getOperand(result.getResultNumber());
134 if (skipped) skipped->insert(op);
135 }
136 return v;
137 }
138
139 // Finds the formattable arguments of `execute` and annotates the metadata of
140 // `compile` to record these arguments. In addition, it returns a mapping from
141 // the formattable arguments of `execute` to the corresponding operand of
142 // `replicate`. The
143 // entries in the mapping are sorted in the order of operands of `execute`.
144 llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate,TF::TPUExecuteAndUpdateVariablesOp execute,tf_device::LaunchOp compile_launch)145 AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
146 TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
147 TF::TPUExecuteAndUpdateVariablesOp execute,
148 tf_device::LaunchOp compile_launch) {
149 Region& body = while_op.body();
150 Region& cond = while_op.cond();
151
152 llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
153 auto mirrored_variable_indices_attr =
154 replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
155 if (!mirrored_variable_indices_attr) return mapping;
156
157 // Finds the mapping from a replicate argument to an execute operand.
158 llvm::SmallDenseMap<int64_t, int64_t, 8> replicate_arg_to_execute_arg;
159 for (auto index_and_arg : llvm::enumerate(execute.args())) {
160 auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false);
161 if (!arg.hasOneUse() ||
162 !getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
163 continue;
164 }
165 auto block_arg = arg.dyn_cast<BlockArgument>();
166 if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue;
167 assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 &&
168 "Found duplicate use of a resource in the execute op.");
169 replicate_arg_to_execute_arg[block_arg.getArgNumber()] =
170 index_and_arg.index();
171 }
172 if (replicate_arg_to_execute_arg.empty()) return mapping;
173
174 // Parse the original compile metadata.
175 Operation& compile = compile_launch.GetBody().front();
176 auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
177 assert(metadata_str && "Missing compilation metadata");
178 tensorflow::tpu::TPUCompileMetadataProto metadata;
179 metadata.ParseFromString(std::string(metadata_str.getValue()));
180 int64_t num_replicas = replicate.n();
181 // Find the formattable operands of `execute`, which must be mirrored
182 // variables (arguments of `replicate`), and must be pass-throughs from while
183 // operands.
184 for (const auto& mirrored_index : mirrored_variable_indices_attr) {
185 int64_t replicate_arg = mirrored_index.cast<IntegerAttr>().getInt();
186 // Check if the mirrored variable is an input to `execute`.
187 auto it = replicate_arg_to_execute_arg.find(replicate_arg);
188 if (it == replicate_arg_to_execute_arg.end()) continue;
189 // Get the data type of the resource.
190 auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second))
191 .cast<TF::ResourceType>()
192 .getSubtypes();
193 if (subtypes.size() != 1) continue;
194 auto data_type = getElementTypeOrSelf(subtypes[0]);
195 // The XLA backend does not yet support formatting 64-bit data types.
196 if (data_type.getIntOrFloatBitWidth() == 64) continue;
197
198 const auto& block_arg = replicate.GetBody().getArgument(replicate_arg);
199
200 int64_t num_inputs = 0;
201 if (replicate.IsReplicatedBlockArgument(block_arg)) {
202 num_inputs = num_replicas;
203 } else {
204 num_inputs = 1;
205 }
206
207 // We have found a mirrored variable which is an input to the replicated
208 // `execute`. Now find if this mirrored variable is a pass-through of while
209 // arguments.
210 llvm::SmallVector<Value, 4> replicate_args;
211 for (int64_t i = 0; i < num_inputs; ++i) {
212 llvm::SmallPtrSet<Operation*, 4> skipped_identities;
213
214 auto replicate_operand = SkipIdentity(
215 replicate.GetReplicaOperandForBlockArgument(block_arg, i),
216 /*allow_other_use=*/false, &skipped_identities);
217 // For region based control flow, the resource operand for the replicate
218 // should be a region capture. If this has any use other than the
219 // replicate op (within the body of the while) or the skipped identities,
220 // then do not apply the transformation to this variable.
221 bool is_region_capture =
222 replicate_operand.getParentRegion()->isProperAncestor(&body);
223 bool has_other_use_in_body =
224 llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
225 // Ignore uses that are not in the while body or condition.
226 if (!body.isAncestor(user->getParentRegion()) &&
227 !cond.isAncestor(user->getParentRegion()))
228 return false;
229 // Within the body or cond, only uses in replicate and the skipped
230 // identities is allowed.
231 return user != replicate && skipped_identities.count(user) == 0;
232 });
233
234 if (!is_region_capture || has_other_use_in_body) {
235 replicate_args.clear();
236 break;
237 }
238 replicate_args.push_back(replicate_operand);
239 }
240 if (replicate_args.empty()) continue;
241 // Now set the enable_xla_sharding field in the metadata to inform the
242 // compile op.
243 auto metadata_arg = metadata.mutable_args(it->second);
244 metadata_arg->set_enable_xla_sharding(
245 ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
246 mapping.emplace_back(it->second, std::move(replicate_args));
247 }
248 // Sort the mapping according to execute operand order.
249 llvm::sort(mapping, llvm::less_first());
250 // Populate the `retval_index_for_sharding` field of the argument metadate.
251 for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) {
252 int64_t arg_index = entry.value().cast<IntegerAttr>().getInt();
253 auto arg_metadata = metadata.mutable_args(arg_index);
254 if (arg_metadata->enable_xla_sharding() ==
255 ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) {
256 int64_t ret_index = execute.device_var_updates_indices()
257 .getValue()[entry.index()]
258 .cast<IntegerAttr>()
259 .getInt();
260 arg_metadata->set_retval_index_for_sharding(ret_index);
261 }
262 }
263 // Update the metadata of the compile op.
264 compile.setAttr("metadata", StringAttr::get(compile.getContext(),
265 metadata.SerializeAsString()));
266 return mapping;
267 }
268
269 // Adds a new replicated input to the replicate op.
AddInputsToReplicateOp(tf_device::ReplicateOp replicate,MutableArrayRef<TF::VarHandleOp> new_inputs,const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices)270 tf_device::ReplicateOp AddInputsToReplicateOp(
271 tf_device::ReplicateOp replicate,
272 MutableArrayRef<TF::VarHandleOp> new_inputs,
273 const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
274 devices) {
275 int64_t num_replicas = replicate.n();
276 assert(new_inputs.size() == num_replicas);
277
278 // As model parallelism is not yet supported, we assume that all ops are
279 // placed in logical core 0.
280 // TODO(b/148913020): Remove this constraint once model parallelism is
281 // supported.
282 assert(devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))
283 ->getSecond()
284 .size() == num_replicas);
285
286 llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
287 llvm::SmallVector<Value, 8> new_packed_inputs;
288 llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
289 replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
290 new_packed_inputs.reserve(replicate.GetNumPackedBlockArguments());
291 for (const auto& arg : replicate.GetReplicatedBlockArguments()) {
292 replicated_inputs.emplace_back();
293 for (int64_t i = 0; i < num_replicas; ++i) {
294 replicated_inputs.back().push_back(
295 replicate.GetReplicaOperandForBlockArgument(arg, i));
296 }
297 new_replicated_inputs.emplace_back(replicated_inputs.back(), arg.getType());
298 }
299 for (const auto& arg : replicate.GetPackedBlockArguments()) {
300 new_packed_inputs.emplace_back(
301 replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
302 }
303 SmallVector<Value, 4> new_input_values;
304 new_input_values.reserve(new_inputs.size());
305 for (auto var : new_inputs) new_input_values.push_back(var.resource());
306 new_replicated_inputs.emplace_back(new_input_values,
307 new_input_values.front().getType());
308 OpBuilder builder(replicate);
309 auto new_replicate = builder.create<tf_device::ReplicateOp>(
310 replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
311 new_packed_inputs,
312 replicate.GetBody().getTerminator()->getOperandTypes());
313 for (auto arg : replicate.GetBody().getArguments()) {
314 if (replicate.IsReplicatedBlockArgument(arg)) {
315 arg.replaceAllUsesWith(
316 new_replicate.GetBody().getArgument(arg.getArgNumber()));
317 } else {
318 // There is a new added replicated state variable between replicated args
319 // and packed args.
320 arg.replaceAllUsesWith(
321 new_replicate.GetBody().getArgument(arg.getArgNumber() + 1));
322 }
323 }
324 for (auto& op : llvm::make_early_inc_range(replicate.GetBody())) {
325 op.moveBefore(&new_replicate.GetBody(), new_replicate.GetBody().end());
326 }
327 replicate.replaceAllUsesWith(new_replicate);
328 replicate.erase();
329 return new_replicate;
330 }
331
332 // Creates the per-device variables that represent the formatting state of each
333 // device.
CreateStateVars(const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices,Location loc,RankedTensorType key_type,OpBuilder * builder)334 llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
335 const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
336 devices,
337 Location loc, RankedTensorType key_type, OpBuilder* builder) {
338 llvm::SmallVector<TF::VarHandleOp, 4> state_vars;
339
340 // TODO(b/148913020): Remove this constraint once model parallelism is
341 // supported.
342 const auto& device_list =
343 devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))->getSecond();
344
345 // Create the state variable for each device.
346 for (llvm::StringRef device : device_list) {
347 state_vars.push_back(builder->create<TF::VarHandleOp>(
348 loc,
349 llvm::ArrayRef<Type>{RankedTensorType::get(
350 {}, TF::ResourceType::get(llvm::ArrayRef<TensorType>{key_type},
351 builder->getContext()))},
352 llvm::ArrayRef<Value>{},
353 llvm::ArrayRef<NamedAttribute>{
354 builder->getNamedAttr(kDeviceAttr, builder->getStringAttr(device)),
355 builder->getNamedAttr("container", builder->getStringAttr("")),
356 builder->getNamedAttr(
357 "shared_name",
358 builder->getStringAttr(GetRandomStateVariableName()))}));
359 }
360 return state_vars;
361 }
362
363 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)364 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
365 llvm::StringRef device) {
366 OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
367
368 auto launch = builder->create<tf_device::LaunchOp>(
369 loc, builder->getStringAttr(device), op->getResultTypes());
370 launch.body().push_back(new Block);
371
372 builder->setInsertionPointToEnd(&launch.GetBody());
373 builder->create<tf_device::ReturnOp>(loc, op->getResults());
374
375 // Move op inside launch.
376 op->moveBefore(launch.GetBody().getTerminator());
377
378 builder->restoreInsertionPoint(insert_point);
379 }
380
381 // Performs the transformation for a replicate op inside a while loop.
HandleReplicateOp(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate)382 void HandleReplicateOp(TF::WhileRegionOp while_op,
383 tf_device::ReplicateOp replicate) {
384 int64_t num_replicas = replicate.n();
385 if (num_replicas == 1) return;
386 tf_device::LaunchOp execute_launch;
387 for (auto execute_launch_op :
388 replicate.GetBody().getOps<tf_device::LaunchOp>()) {
389 if (!execute_launch_op.WrapsSingleOp() ||
390 !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
391 execute_launch_op.GetBody().front()))
392 continue;
393
394 if (execute_launch == nullptr) {
395 execute_launch = execute_launch_op;
396 } else {
397 // We only support one execute op inside replicate.
398 execute_launch = nullptr;
399 break;
400 }
401 }
402 if (!execute_launch) return;
403 auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
404 execute_launch.GetBody().front());
405 auto compile =
406 SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
407 if (!compile) return;
408 auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
409 if (!compile_launch || !compile_launch.WrapsSingleOp() ||
410 !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
411 return;
412
413 // Analyze the formattable inputs.
414 auto execute_arg_to_outer_args =
415 AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
416 while_op, replicate, execute, compile_launch);
417 if (execute_arg_to_outer_args.empty()) return;
418
419 // Extract the replicated devices.
420 auto devices_attr = replicate.devices();
421 if (!devices_attr) return;
422
423 auto device_map = devices_attr.getValue();
424 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>> devices;
425 devices.reserve(device_map.size());
426
427 for (auto it : device_map) {
428 auto device_alias = it.first.strref();
429 auto device_list = it.second.cast<ArrayAttr>();
430 llvm::SmallVector<StringRef, 4> device_list_for_alias;
431 device_list_for_alias.reserve(device_list.size());
432
433 for (auto device : device_list)
434 device_list_for_alias.emplace_back(device.cast<StringAttr>().getValue());
435
436 devices.insert({device_alias, device_list_for_alias});
437 }
438
439 OpBuilder builder(replicate);
440 builder.setInsertionPoint(while_op);
441 // Create per-device variables for formatting state, and add them to the while
442 // loop.
443 auto key_type =
444 RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
445 auto state_vars =
446 CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
447 replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
448 // Build the reformat according to the compilation. Build it inside
449 // `replicate`.
450 llvm::SmallVector<Value, 8> reformat_operands;
451 for (const auto& entry : execute_arg_to_outer_args) {
452 reformat_operands.push_back(execute.args()[entry.first]);
453 }
454 reformat_operands.push_back(compile_launch.getResult(1));
455 reformat_operands.push_back(replicate.GetBody().getArgument(
456 replicate.GetNumReplicatedBlockArguments() - 1));
457 builder.setInsertionPoint(execute_launch);
458 auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
459 execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
460 WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
461 execute_launch.device());
462
463 // Build the replicated unformat op after the loop. First prepare building the
464 // replicate op.
465 llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
466 llvm::SmallVector<Value, 8> unformat_packed_operands;
467 for (const auto& entry : execute_arg_to_outer_args) {
468 if (entry.second.size() > 1) {
469 unformat_replicate_operands.emplace_back(entry.second,
470 entry.second.front().getType());
471 } else {
472 unformat_packed_operands.emplace_back(entry.second.front());
473 }
474 }
475 llvm::SmallVector<Value, 4> state_var_vals(state_vars.size());
476 for (const auto& entry : llvm::enumerate(state_vars)) {
477 state_var_vals[entry.index()] = entry.value().resource();
478 }
479 // Add the replicated state var to the end of the replicate operands.
480 unformat_replicate_operands.emplace_back(state_var_vals,
481 state_var_vals.front().getType());
482 // Build a constant default key to specify that the unformatting should
483 // transform the variables to the original format.
484 builder.setInsertionPointAfter(while_op);
485 tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
486 default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
487 default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
488 default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
489 auto default_state_key = builder.create<TF::ConstOp>(
490 while_op.getLoc(),
491 tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
492 // With all replicated inputs, now build the replicate op.
493 auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
494 while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
495 unformat_packed_operands, TypeRange{});
496 // Then build the unformat op in the replicate op.
497 builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
498 llvm::SmallVector<Value, 8> unformat_operands;
499 // Add the replicated state var (the last replicated operand of the
500 // ReplicateOp) as the last operand of TPUReshardVariablesOp.
501 BlockArgument state = unformat_replicate.GetReplicatedBlockArguments().back();
502 auto replicated_block_args =
503 unformat_replicate.GetReplicatedBlockArguments().drop_back(1);
504 auto packed_block_args = unformat_replicate.GetPackedBlockArguments();
505 unformat_operands.append(replicated_block_args.begin(),
506 replicated_block_args.end());
507 unformat_operands.append(packed_block_args.begin(), packed_block_args.end());
508 unformat_operands.push_back(state);
509
510 // Insert the default key as the second last operand.
511 unformat_operands.insert(
512 unformat_operands.begin() + unformat_operands.size() - 1,
513 default_state_key.getResult());
514 // Unformat op.
515 auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
516 while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
517 WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
518 execute_launch.device());
519 builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
520 }
521
runOnOperation()522 void TPUVariableRuntimeReformattingPass::runOnOperation() {
523 auto module = getOperation();
524 module.walk([&](TF::WhileRegionOp while_op) {
525 tf_device::ReplicateOp replicate;
526 while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
527 if (replicate == nullptr) {
528 replicate = replicate_op;
529 return WalkResult::advance();
530 }
531 // We do not handle loops with multiple replicate ops.
532 replicate = nullptr;
533 return WalkResult::interrupt();
534 });
535 // Model parallelism is not supported, and can be detected when a
536 // `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
537 if (replicate &&
538 replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
539 HandleReplicateOp(while_op, replicate);
540 });
541 }
542
543 } // namespace
544
CreateTPUVariableReformattingPass()545 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass() {
546 return std::make_unique<TPUVariableRuntimeReformattingPass>();
547 }
548
549 static PassRegistration<TPUVariableRuntimeReformattingPass> pass(
550 "tf-tpu-variable-runtime-reformatting",
551 "Adds device variable formatting op to allow compilation-guided variable "
552 "formatting.");
553
554 } // namespace TFTPU
555 } // namespace mlir
556