/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";
constexpr char kDefaultShardingValue[] = "";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

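// Generates a fresh name for a formatting state variable. The random suffix
// keeps the shared_name unique so that state variables created for different
// loops do not collide.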
std::string GetRandomStateVariableName() {
  return absl::StrCat("VariablesFormatState_", tensorflow::random::New64());
}

// A pass that takes advantage of a loop to add ops that allow the execution to
// avoid repeatedly formatting variables back and forth. The desired formatting
// is determined by TPU program compilation, so this pass does not encode how
// to reformat the variables; it only inserts general TPUReshardVariablesOps in
// the proper places, and each TPUReshardVariablesOp interprets the compilation
// result at runtime.
//
// The core idea of this optimization is to keep track of the formatting state
// of variables: when the next desired state is unchanged, reformatting can be
// skipped. We associate a set of variables on a device with a formatting
// state, and TPUReshardVariablesOp compares the current state with a desired
// state (which can be the compilation result). If they mismatch,
// TPUReshardVariablesOp reformats the variables to the desired state; if they
// match, TPUReshardVariablesOp is a no-op.
//
// A major use of this pass is weight-update sharding in data parallelism, so
// we require that there is a tf_device.replicate in the loop.
//
// For example, suppose we have a training loop (for simplicity we write the
// loop body inline):
//
// %var0 = ...
// %var1 = ...
// tf.while (..., %var0, %var1) {
//   tf_device.replicate ([%var0, %var1] as %rvar) {
//     %compile:2 = "tf._TPUCompileMlir"()
//     tf.TPUExecuteAndUpdateVariablesOp(%rvar, %compile#1)
//   }
// }
//
// This pass will transform it into:
//
// %var0 = ...
// %var1 = ...
// %state_var0 = ...
// %state_var1 = ...
// tf.while (..., %var0, %var1, %state_var0, %state_var1) {
//   tf_device.replicate ([%var0, %var1] as %rvar,
//                        [%state_var0, %state_var1] as %rstate) {
//     %compile:2 = "tf._TPUCompileMlir"()
//     tf.TPUReshardVariablesOp(%rvar, %compile#1, %rstate)
//     tf.TPUExecuteAndUpdateVariablesOp(%rvar, %compile#1)
//   }
// }
// %default_format = tf.constant()
// tf_device.replicate ([%var0, %var1] as %rvar,
//                      [%state_var0, %state_var1] as %rstate) {
//   tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
// }
struct TPUVariableRuntimeReformattingPass
    : public PassWrapper<TPUVariableRuntimeReformattingPass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;

  StringRef getArgument() const final {
    return "tf-tpu-variable-runtime-reformatting";
  }

  StringRef getDescription() const final {
    return "Adds device variable formatting op to allow compilation-guided "
           "variable formatting.";
  }
};

// Walks up a chain of tf.Identity / tf.IdentityN ops and returns the earliest
// value of which `v` is an identity. Stops early if a value has more than one
// use, unless `allow_other_use` is true. If `skipped` is provided, the skipped
// identity ops are recorded in it.
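// For example, given
//   %1 = "tf.Identity"(%0)
//   %2 = "tf.Identity"(%1)
// SkipIdentity(%2, ...) returns %0 and records both identity ops in `skipped`.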
Value SkipIdentity(Value v, bool allow_other_use,
                   llvm::SmallPtrSet<Operation*, 4>* skipped = nullptr) {
  while (auto result = v.dyn_cast<OpResult>()) {
    if (!(allow_other_use || v.hasOneUse())) break;
    auto op = result.getDefiningOp();
    if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
      break;
    }
    v = op->getOperand(result.getResultNumber());
    if (skipped) skipped->insert(op);
  }
  return v;
}

// Finds the formattable arguments of `execute` and annotates the metadata of
// `compile` to record these arguments. In addition, returns a mapping from
// the formattable arguments of `execute` to the corresponding operands of
// `replicate`. The entries in the mapping are sorted in the order of the
// operands of `execute`.
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
    TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
    TF::TPUExecuteAndUpdateVariablesOp execute,
    tf_device::LaunchOp compile_launch) {
  Region& body = while_op.body();
  Region& cond = while_op.cond();

  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
  auto mirrored_variable_indices_attr =
      replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
  if (!mirrored_variable_indices_attr) return mapping;

  // Finds the mapping from a replicate argument to an execute operand.
  llvm::SmallDenseMap<int64_t, int64_t, 8> replicate_arg_to_execute_arg;
  for (auto index_and_arg : llvm::enumerate(execute.args())) {
    auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false);
    if (!arg.hasOneUse() ||
        !getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    auto block_arg = arg.dyn_cast<BlockArgument>();
    if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue;
    assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 &&
           "Found duplicate use of a resource in the execute op.");
    replicate_arg_to_execute_arg[block_arg.getArgNumber()] =
        index_and_arg.index();
  }
  if (replicate_arg_to_execute_arg.empty()) return mapping;

  // Parse the original compile metadata.
  Operation& compile = compile_launch.GetBody().front();
  auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
  assert(metadata_str && "Missing compilation metadata");
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(std::string(metadata_str.getValue()));
  int64_t num_replicas = replicate.n();
  // Find the formattable operands of `execute`, which must be mirrored
  // variables (arguments of `replicate`), and must be pass-throughs from while
  // operands.
  for (const auto& mirrored_index : mirrored_variable_indices_attr) {
    int64_t replicate_arg = mirrored_index.cast<IntegerAttr>().getInt();
    // Check if the mirrored variable is an input to `execute`.
    auto it = replicate_arg_to_execute_arg.find(replicate_arg);
    if (it == replicate_arg_to_execute_arg.end()) continue;
    // Get the data type of the resource.
    auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second))
                        .cast<TF::ResourceType>()
                        .getSubtypes();
    if (subtypes.size() != 1) continue;
    auto data_type = getElementTypeOrSelf(subtypes[0]);
    // The XLA backend does not yet support formatting 64-bit data types.
    if (data_type.getIntOrFloatBitWidth() == 64) continue;

    const auto& block_arg = replicate.GetBody().getArgument(replicate_arg);

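    // A replicated block argument has one operand per replica; a packed block
    // argument has a single operand shared by all replicas.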
    int64_t num_inputs = 0;
    if (replicate.IsReplicatedBlockArgument(block_arg)) {
      num_inputs = num_replicas;
    } else {
      num_inputs = 1;
    }

    // We have found a mirrored variable which is an input to the replicated
    // `execute`. Now check whether this mirrored variable is a pass-through of
    // the while arguments.
    llvm::SmallVector<Value, 4> replicate_args;
    for (int64_t i = 0; i < num_inputs; ++i) {
      llvm::SmallPtrSet<Operation*, 4> skipped_identities;

      auto replicate_operand = SkipIdentity(
          replicate.GetReplicaOperandForBlockArgument(block_arg, i),
          /*allow_other_use=*/false, &skipped_identities);
      // For region-based control flow, the resource operand for the replicate
      // should be a region capture. If it has any use other than the
      // replicate op (within the body of the while) or the skipped identities,
      // then do not apply the transformation to this variable.
      bool is_region_capture =
          replicate_operand.getParentRegion()->isProperAncestor(&body);
      bool has_other_use_in_body =
          llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
            // Ignore uses that are not in the while body or condition.
            if (!body.isAncestor(user->getParentRegion()) &&
                !cond.isAncestor(user->getParentRegion()))
              return false;
            // Within the body or cond, only uses in the replicate op and the
            // skipped identities are allowed.
            return user != replicate && skipped_identities.count(user) == 0;
          });

      if (!is_region_capture || has_other_use_in_body) {
        replicate_args.clear();
        break;
      }
      replicate_args.push_back(replicate_operand);
    }
    if (replicate_args.empty()) continue;
    // Now set the enable_xla_sharding field in the metadata to inform the
    // compile op.
    auto metadata_arg = metadata.mutable_args(it->second);
    metadata_arg->set_enable_xla_sharding(
        ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
    mapping.emplace_back(it->second, std::move(replicate_args));
  }
  // Sort the mapping according to execute operand order.
  llvm::sort(mapping, llvm::less_first());
  // Populate the `retval_index_for_sharding` field of the argument metadata.
  for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) {
    int64_t arg_index = entry.value().cast<IntegerAttr>().getInt();
    auto arg_metadata = metadata.mutable_args(arg_index);
    if (arg_metadata->enable_xla_sharding() ==
        ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) {
      int64_t ret_index = execute.device_var_updates_indices()
                              .getValue()[entry.index()]
                              .cast<IntegerAttr>()
                              .getInt();
      arg_metadata->set_retval_index_for_sharding(ret_index);
    }
  }
  // Update the metadata of the compile op.
  compile.setAttr("metadata", StringAttr::get(compile.getContext(),
                                              metadata.SerializeAsString()));
  return mapping;
}

// Adds a new replicated input to the replicate op.
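// The new input is appended as the last replicated input, after the existing
// replicated inputs and before the packed inputs, so the block arguments of
// packed inputs shift up by one (see the argument remapping below).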
tf_device::ReplicateOp AddInputsToReplicateOp(
    tf_device::ReplicateOp replicate,
    MutableArrayRef<TF::VarHandleOp> new_inputs,
    const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
        devices) {
  int64_t num_replicas = replicate.n();
  assert(new_inputs.size() == num_replicas);

  // As model parallelism is not yet supported, we assume that all ops are
  // placed in logical core 0.
  // TODO(b/148913020): Remove this constraint once model parallelism is
  // supported.
  assert(devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))
             ->getSecond()
             .size() == num_replicas);

  llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
  llvm::SmallVector<Value, 8> new_packed_inputs;
  llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
  replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
  new_packed_inputs.reserve(replicate.GetNumPackedBlockArguments());
  for (const auto& arg : replicate.GetReplicatedBlockArguments()) {
    replicated_inputs.emplace_back();
    for (int64_t i = 0; i < num_replicas; ++i) {
      replicated_inputs.back().push_back(
          replicate.GetReplicaOperandForBlockArgument(arg, i));
    }
    new_replicated_inputs.emplace_back(replicated_inputs.back(), arg.getType());
  }
  for (const auto& arg : replicate.GetPackedBlockArguments()) {
    new_packed_inputs.emplace_back(
        replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
  }
  SmallVector<Value, 4> new_input_values;
  new_input_values.reserve(new_inputs.size());
  for (auto var : new_inputs) new_input_values.push_back(var.resource());
  new_replicated_inputs.emplace_back(new_input_values,
                                     new_input_values.front().getType());
  OpBuilder builder(replicate);
  auto new_replicate = builder.create<tf_device::ReplicateOp>(
      replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
      new_packed_inputs,
      replicate.GetBody().getTerminator()->getOperandTypes());
  for (auto arg : replicate.GetBody().getArguments()) {
    if (replicate.IsReplicatedBlockArgument(arg)) {
      arg.replaceAllUsesWith(
          new_replicate.GetBody().getArgument(arg.getArgNumber()));
    } else {
      // The newly added replicated state variable sits between the replicated
      // args and the packed args, so packed args shift up by one.
      arg.replaceAllUsesWith(
          new_replicate.GetBody().getArgument(arg.getArgNumber() + 1));
    }
  }
  for (auto& op : llvm::make_early_inc_range(replicate.GetBody())) {
    op.moveBefore(&new_replicate.GetBody(), new_replicate.GetBody().end());
  }
  replicate.replaceAllUsesWith(new_replicate);
  replicate.erase();
  return new_replicate;
}

// Creates the per-device variables that represent the formatting state of each
// device.
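// Each state variable is a scalar resource whose subtype is `key_type`; the
// caller in this pass uses tensor<2x!tf.string>, matching the format key
// consumed by TPUReshardVariablesOp.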
llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
    const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
        devices,
    Location loc, RankedTensorType key_type, OpBuilder* builder) {
  llvm::SmallVector<TF::VarHandleOp, 4> state_vars;

  // TODO(b/148913020): Remove this constraint once model parallelism is
  // supported.
  const auto& device_list =
      devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))->getSecond();

  // Create the state variable for each device.
  for (llvm::StringRef device : device_list) {
    state_vars.push_back(builder->create<TF::VarHandleOp>(
        loc,
        llvm::ArrayRef<Type>{RankedTensorType::get(
            {}, TF::ResourceType::get(llvm::ArrayRef<TensorType>{key_type},
                                      builder->getContext()))},
        llvm::ArrayRef<Value>{},
        llvm::ArrayRef<NamedAttribute>{
            builder->getNamedAttr(kDeviceAttr, builder->getStringAttr(device)),
            builder->getNamedAttr("container", builder->getStringAttr("")),
            builder->getNamedAttr(
                "shared_name",
                builder->getStringAttr(GetRandomStateVariableName()))}));
  }
  return state_vars;
}

// Wraps single op in `tf_device.launch` for explicit device assignment.
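// For example, wrapping `%0 = "tf.OpA"()` produces roughly:
//
//   %launch = "tf_device.launch"() ({
//     %0 = "tf.OpA"()
//     tf_device.return %0
//   }) {device = "..."}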
void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
                    llvm::StringRef device) {
  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();

  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside launch.
  op->moveBefore(launch.GetBody().getTerminator());

  builder->restoreInsertionPoint(insert_point);
}

// Performs the transformation for a replicate op inside a while loop.
void HandleReplicateOp(TF::WhileRegionOp while_op,
                       tf_device::ReplicateOp replicate) {
  int64_t num_replicas = replicate.n();
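  // With a single replica there is no weight-update sharding to coordinate,
  // so the transformation would not be useful.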
  if (num_replicas == 1) return;
  tf_device::LaunchOp execute_launch;
  for (auto execute_launch_op :
       replicate.GetBody().getOps<tf_device::LaunchOp>()) {
    if (!execute_launch_op.WrapsSingleOp() ||
        !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
            execute_launch_op.GetBody().front()))
      continue;

    if (execute_launch == nullptr) {
      execute_launch = execute_launch_op;
    } else {
      // We only support one execute op inside replicate.
      execute_launch = nullptr;
      break;
    }
  }
  if (!execute_launch) return;
  auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
      execute_launch.GetBody().front());
  auto compile =
      SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
  if (!compile) return;
  auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
  if (!compile_launch || !compile_launch.WrapsSingleOp() ||
      !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
    return;

  // Analyze the formattable inputs.
  auto execute_arg_to_outer_args =
      AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
          while_op, replicate, execute, compile_launch);
  if (execute_arg_to_outer_args.empty()) return;

  // Extract the replicated devices.
  auto devices_attr = replicate.devices();
  if (!devices_attr) return;

  auto device_map = devices_attr.getValue();
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>> devices;
  devices.reserve(device_map.size());

  for (auto it : device_map) {
    auto device_alias = it.first.strref();
    auto device_list = it.second.cast<ArrayAttr>();
    llvm::SmallVector<StringRef, 4> device_list_for_alias;
    device_list_for_alias.reserve(device_list.size());

    for (auto device : device_list)
      device_list_for_alias.emplace_back(device.cast<StringAttr>().getValue());

    devices.insert({device_alias, device_list_for_alias});
  }

  OpBuilder builder(replicate);
  builder.setInsertionPoint(while_op);
  // Create per-device variables for formatting state, and add them to the
  // while loop.
  auto key_type =
      RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
  auto state_vars =
      CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
  replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
  // Build the reformat op inside `replicate`, guided by the compilation
  // result.
  llvm::SmallVector<Value, 8> reformat_operands;
  for (const auto& entry : execute_arg_to_outer_args) {
    reformat_operands.push_back(execute.args()[entry.first]);
  }
  reformat_operands.push_back(compile_launch.getResult(1));
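  // The last replicated block argument is the formatting state variable that
  // AddInputsToReplicateOp appended above.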
  reformat_operands.push_back(replicate.GetBody().getArgument(
      replicate.GetNumReplicatedBlockArguments() - 1));
  builder.setInsertionPoint(execute_launch);
  auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
      execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
  WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
                 execute_launch.device());

  // Build the replicated unformat op after the loop. First, prepare the
  // operands of the replicate op.
  llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
  llvm::SmallVector<Value, 8> unformat_packed_operands;
  for (const auto& entry : execute_arg_to_outer_args) {
    if (entry.second.size() > 1) {
      unformat_replicate_operands.emplace_back(entry.second,
                                               entry.second.front().getType());
    } else {
      unformat_packed_operands.emplace_back(entry.second.front());
    }
  }
  llvm::SmallVector<Value, 4> state_var_vals(state_vars.size());
  for (const auto& entry : llvm::enumerate(state_vars)) {
    state_var_vals[entry.index()] = entry.value().resource();
  }
  // Add the replicated state var to the end of the replicate operands.
  unformat_replicate_operands.emplace_back(state_var_vals,
                                           state_var_vals.front().getType());
  // Build a constant default key to specify that the unformatting should
  // transform the variables to the original format.
  builder.setInsertionPointAfter(while_op);
  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
  default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
  default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
  default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
  auto default_state_key = builder.create<TF::ConstOp>(
      while_op.getLoc(),
      tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
  // With all replicated inputs, now build the replicate op.
  auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
      while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
      unformat_packed_operands, TypeRange{});
  // Then build the unformat op in the replicate op.
  builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
  llvm::SmallVector<Value, 8> unformat_operands;
  // Add the replicated state var (the last replicated operand of the
  // ReplicateOp) as the last operand of TPUReshardVariablesOp.
  BlockArgument state = unformat_replicate.GetReplicatedBlockArguments().back();
  auto replicated_block_args =
      unformat_replicate.GetReplicatedBlockArguments().drop_back(1);
  auto packed_block_args = unformat_replicate.GetPackedBlockArguments();
  unformat_operands.append(replicated_block_args.begin(),
                           replicated_block_args.end());
  unformat_operands.append(packed_block_args.begin(), packed_block_args.end());
  unformat_operands.push_back(state);

  // Insert the default key as the second-to-last operand.
  unformat_operands.insert(
      unformat_operands.begin() + unformat_operands.size() - 1,
      default_state_key.getResult());
  // Unformat op.
  auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
      while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
  WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
                 execute_launch.device());
  builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
}

void TPUVariableRuntimeReformattingPass::runOnOperation() {
  auto module = getOperation();
  module.walk([&](TF::WhileRegionOp while_op) {
    tf_device::ReplicateOp replicate;
    while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
      if (replicate == nullptr) {
        replicate = replicate_op;
        return WalkResult::advance();
      }
      // We do not handle loops with multiple replicate ops.
      replicate = nullptr;
      return WalkResult::interrupt();
    });
    // Model parallelism is not supported; it can be detected by the presence
    // of a `tf_device.parallel_execute` op in the `tf_device.replicate`.
    if (replicate &&
        replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
      HandleReplicateOp(while_op, replicate);
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass() {
  return std::make_unique<TPUVariableRuntimeReformattingPass>();
}

static PassRegistration<TPUVariableRuntimeReformattingPass> pass;

}  // namespace TFTPU
}  // namespace mlir