/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "absl/strings/str_join.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "tf-tpu-merge-variables-with-execute"

namespace mlir {
namespace TFTPU {

namespace {
constexpr char kAliasingAttr[] = "tf.aliasing_output";
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";

class TPUMergeVariablesWithExecutePass
    : public TF::TPUMergeVariablesWithExecutePassBase<
          TPUMergeVariablesWithExecutePass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    // We need this here because at the moment we deserialize the
    // TPUCompileMlir operation, which contains annotations like the
    // `mhlo.sharding` attribute.
    registry.insert<mhlo::MhloDialect>();
  }
  void runOnOperation() override;
};

// Information for a pair of input/output of the TPUExecute op and the
// surrounding read/assign ops.
struct VariableAccessInfo {
  int execute_input_index = -1;
  int execute_output_index = -1;
  Operation* read = nullptr;
  Operation* assign = nullptr;
};

// Information about all resource accesses to be merged into a TPUExecute op.
struct VariableAccessesForTPUExecute {
  // Maps each detected resource to a VariableAccessInfo. Eventually, this will
  // contain all values for which we want to merge the accessing ops with a
  // TPUExecute op.
  llvm::SmallDenseMap<Value, VariableAccessInfo, 8> per_resource_info;
  // The corresponding new output index in TPUExecuteAndUpdateVariables for
  // each old output index in TPUExecute.
  llvm::SmallVector<int, 8> old_to_new_output_mapping;
  // The resources read by ReadVariableOps that are inputs to TPUExecute,
  // ordered by the input indices to TPUExecute.
  llvm::SmallVector<Value, 8> resources_read;
  // Operands for the new TPUExecuteAndUpdateVariables.
  llvm::SmallVector<Value, 8> new_operand_values;
};

// Returns true iff the read or assign op associated with `resource` can be
// safely merged.
//
// `resource_ids` contains IDs of all previously accessed resources.
// `previous_unknown_resource_access` is true if we had any previous unknown
// resource access.
bool IsResourceSafeForMerge(
    Value resource,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    const VariableAccessesForTPUExecute& infos,
    const llvm::SmallDenseSet<int64_t>& resource_ids,
    bool previous_unknown_resource_access) {
  // If we had any unknown resource access before, then we conservatively
  // assume that `resource` has been accessed before.
  // If `resource` is an unknown resource, then we conservatively assume that
  // the same resource has been accessed before.
  if (previous_unknown_resource_access ||
      resource_analysis_info.IsUnknownResource(resource))
    return false;
  const auto& ids = resource_analysis_info.GetResourceUniqueIds(resource);
  for (int64_t id : ids) {
    if (resource_ids.contains(id)) return false;
  }
  return true;
}

// Adds IDs of resources which `op` accesses to `resource_ids`.
// Returns true iff `op` accesses a resource unknown to
// `resource_analysis_info`, in which case we have to conservatively assume
// that any resource might be accessed.
bool AddAccessedResourceIds(
    Operation* op,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    llvm::SmallDenseSet<int64_t>& resource_ids) {
  for (Value operand : TF::filter_resources(op->getOperands())) {
    if (resource_analysis_info.IsUnknownResource(operand)) {
      VLOG(2) << " unknown access";
      return true;
    }
    const auto& ids = resource_analysis_info.GetResourceUniqueIds(operand);
    VLOG(2) << " accesses following resources: " << absl::StrJoin(ids, ", ");
    resource_ids.insert(ids.begin(), ids.end());
  }
  return false;
}

// Finds the variable access info for a TPUExecute op.
// - `check_device` specifies whether it checks the device assignment of the
//   variables to match the TPUExecute op. This is optional in some contexts,
//   e.g., when it is guaranteed by replication.
// - `check_same_region` specifies whether the reads/assigns need to be in the
//   same region as `execute`. This is needed if `execute` is inside a
//   ReplicateOp.
VariableAccessesForTPUExecute BuildVariableAccessInfo(
    tf_device::LaunchOp execute_launch,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    bool check_device, bool check_same_region) {
  VariableAccessesForTPUExecute var_access_info;
  Attribute device_attr = execute_launch.deviceAttr();
  if (check_device && !device_attr) return var_access_info;
  auto func = execute_launch->getParentOfType<mlir::func::FuncOp>();

  // Track the first read op found, which is used later to check if there are
  // assign ops between it and the TPUExecute op. We will exclude reads before
  // interfering accesses in a conservative way (see below). We do not consider
  // resource accesses in other islands since their ordering is enforced by
  // inter-island dependencies.
  Operation* first_read = nullptr;
  auto execute = cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
  auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
      execute_launch->getParentOp());
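  // When the launch is wrapped in a tf_device.parallel_execute, the ordering
  // and region checks below are made relative to the parallel_execute op.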
  Operation* execute_parent =
      parallel_execute ? parallel_execute.getOperation() : execute_launch;
  // Collect all operands of `execute` whose defining ops are variable reads
  // that might get merged, and add relevant information to `var_access_info`.
  for (auto operand : llvm::enumerate(execute->getOpOperands())) {
    var_access_info.new_operand_values.push_back(operand.value().get());
    auto read_op = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.value().get().getDefiningOp());
    if (!read_op) continue;
    if (check_same_region &&
        read_op->getParentRegion() != execute_parent->getParentRegion())
      continue;

    auto resource = read_op.resource();
    if (check_device) {
      // TODO(lyandy): Wrap resource ops in tf_device.launch.
      if (auto* resource_op = resource.getDefiningOp()) {
        auto resource_attr = resource_op->getAttr(kDeviceAttr);
        // Check device matching for the node defining the resource.
        if (!resource_attr || resource_attr != device_attr) continue;
      } else {
        auto resource_arg = resource.dyn_cast<BlockArgument>();
        assert(resource_arg);
        if (resource_arg.getOwner() != &func.front()) continue;
        // Check device matching for the argument defining the resource.
        auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
            resource_arg.getArgNumber(), kFuncDeviceAttr);
        if (!resource_attr || resource_attr != device_attr) continue;
      }
    }

    auto emplace_res = var_access_info.per_resource_info.try_emplace(
        resource, VariableAccessInfo());
    if (!emplace_res.second) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple reads of a variable: "
                 << execute << "\n");
      var_access_info.per_resource_info.shrink_and_clear();
      return var_access_info;
    }

    VLOG(2) << "Adding read op to merge candidates: " << debugString(read_op);
    auto& info = emplace_res.first->getSecond();
    info.execute_input_index = operand.index();
    info.read = read_op;
    var_access_info.new_operand_values[operand.index()] = resource;
    var_access_info.resources_read.push_back(resource);
    if (!first_read || info.read->isBeforeInBlock(first_read)) {
      first_read = info.read;
    }
  }

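  // No variable reads were found to merge; nothing more to do.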
  if (!first_read) return var_access_info;

  // Walk backwards from `execute_parent` to `first_read` and remove merge
  // candidates based on resource modifications.
  llvm::SmallDenseSet<int64_t> resource_ids;
  bool previous_unknown_resource_access = false;
  for (Operation& op : llvm::reverse(llvm::make_range(
           first_read->getIterator(), execute_parent->getIterator()))) {
    if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
      VLOG(2) << "Processing read op " << debugString(op);
      auto info_it =
          var_access_info.per_resource_info.find(read_op.resource());
      bool is_merge_candidate =
          info_it != var_access_info.per_resource_info.end();

      if (is_merge_candidate &&
          !IsResourceSafeForMerge(read_op.resource(), resource_analysis_info,
                                  var_access_info, resource_ids,
                                  previous_unknown_resource_access)) {
        VLOG(2) << " removing op from merge candidates";
        int input_index = info_it->getSecond().execute_input_index;
        var_access_info.new_operand_values[input_index] =
            execute.getOperand(input_index);
        var_access_info.per_resource_info.erase(info_it);
      }
    }
    previous_unknown_resource_access |=
        AddAccessedResourceIds(&op, resource_analysis_info, resource_ids);
  }

  if (var_access_info.per_resource_info.empty()) {
    return var_access_info;
  }

  // Find outputs that are variable assigns.
  Operation* last_assign = nullptr;
  llvm::SmallPtrSet<Operation*, 8> all_assigns;
  llvm::SmallVector<bool, 8> output_merged(execute_launch.getNumResults(),
                                           false);

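  // When inside a parallel_execute, the launch results must be looked up
  // through the outputs of the enclosing region.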
  auto execute_outputs =
      parallel_execute
          ? parallel_execute.GetRegionOutputs(
                execute_launch->getParentRegion()->getRegionNumber())
          : execute_launch.getResults();
  for (auto execute_output : llvm::enumerate(execute_outputs)) {
    // TODO(lyandy): Handle updates to resource writes by remapping to parent
    // launch result and checking if launch result is an AssignVariableOp.
    auto result = execute_output.value();
    if (!result.hasOneUse()) continue;

    auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
    if (!assign_op) continue;
    auto resource = assign_op.resource();
    auto it = var_access_info.per_resource_info.find(resource);
    if (it == var_access_info.per_resource_info.end()) continue;
    auto& info = it->getSecond();
    if (info.assign) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple assigns of a variable: "
                 << execute << "\n");
      var_access_info.per_resource_info.shrink_and_clear();
      return var_access_info;
    }
    info.execute_output_index = execute_output.index();
    info.assign = assign_op;
    if (!last_assign || last_assign->isBeforeInBlock(assign_op)) {
      last_assign = assign_op;
    }
    VLOG(2) << "Adding assign op to merge candidates: "
            << debugString(assign_op);
    all_assigns.insert(assign_op);
    output_merged[execute_output.index()] = true;
  }

  if (last_assign != nullptr) {
    // Walk forward from `execute_parent` to `last_assign` and remove merge
    // candidates based on resource modifications.
    resource_ids.clear();
    previous_unknown_resource_access = false;
    for (Operation& op :
         llvm::make_range(std::next(execute_parent->getIterator()),
                          std::next(last_assign->getIterator()))) {
      if (auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
        VLOG(2) << "Processing assign op " << debugString(op);
        bool is_merge_candidate = true;
        if (all_assigns.count(assign_op) == 0) is_merge_candidate = false;
        auto info_it =
            var_access_info.per_resource_info.find(assign_op.resource());
        if (info_it == var_access_info.per_resource_info.end())
          is_merge_candidate = false;

        if (is_merge_candidate &&
            !IsResourceSafeForMerge(
                assign_op.resource(), resource_analysis_info, var_access_info,
                resource_ids, previous_unknown_resource_access)) {
          VLOG(2) << " removing op from merge candidates";
          output_merged[info_it->second.execute_output_index] = false;
          info_it->second.execute_output_index = -1;
          info_it->second.assign = nullptr;
        }
      }
      previous_unknown_resource_access |=
          AddAccessedResourceIds(&op, resource_analysis_info, resource_ids);
    }
  }

  // Populate var_access_info.old_to_new_output_mapping.
  int new_output_index = 0;
  var_access_info.old_to_new_output_mapping.resize(
      execute_launch.getNumResults());
  for (int i = 0, end = execute_launch.getNumResults(); i < end; ++i) {
    if (output_merged[i]) {
      var_access_info.old_to_new_output_mapping[i] = -1;
    } else {
      var_access_info.old_to_new_output_mapping[i] = new_output_index;
      ++new_output_index;
    }
  }
  return var_access_info;
}

// Appends result types of tf_device.parallel_execute from `start` index region
// (inclusive) to `end` index region (exclusive) to `output_types` and returns
// the number of types added.
int AppendTypes(llvm::SmallVectorImpl<Type>* output_types,
                tf_device::ParallelExecuteOp parallel_execute, int start,
                int end) {
  const int size_before = output_types->size();
  for (int index = start; index < end; ++index) {
    Block& block = parallel_execute.GetRegionBlockWithIndex(index);
    auto terminator_operand_types = block.getTerminator()->getOperandTypes();
    output_types->append(terminator_operand_types.begin(),
                         terminator_operand_types.end());
  }
  return output_types->size() - size_before;
}

// Replaces TPUExecute with TPUExecuteAndUpdateVariables in a
// tf_device.parallel_execute op.
void ReplaceParallelExecute(
    tf_device::ParallelExecuteOp parallel_execute,
    tf_device::LaunchOp execute_launch,
    tf_device::LaunchOp merged_execute_launch,
    const VariableAccessesForTPUExecute& var_access_info, OpBuilder* builder) {
  Operation* parallel_execute_op = parallel_execute.getOperation();

  // Collect result types of tf_device.parallel_execute and update region
  // result types with the new merged execute result types.
  llvm::SmallVector<Type, 8> output_types;
  const int parallel_execute_num_results = parallel_execute_op->getNumResults();
  output_types.reserve(parallel_execute_num_results);
  Region* execute_region = merged_execute_launch->getParentRegion();
  const int region_index = execute_region->getRegionNumber();
  const int num_results_before_region =
      AppendTypes(&output_types, parallel_execute, 0, region_index);
  // Append updated results from merged execute.
  output_types.append(merged_execute_launch.getResultTypes().begin(),
                      merged_execute_launch.getResultTypes().end());
  const int num_regions = parallel_execute_op->getNumRegions();
  const int num_results_after_region = AppendTypes(
      &output_types, parallel_execute, region_index + 1, num_regions);

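  // Create the replacement parallel_execute op with the updated result types.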
  builder->setInsertionPoint(parallel_execute);
  auto new_parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
      parallel_execute.getLoc(), num_regions, output_types);

  // Replace the uses of the original parallel_execute before the region
  // containing the merged execute.
  Operation* new_parallel_execute_op = new_parallel_execute.getOperation();
  for (int i = 0; i < num_results_before_region; ++i)
    parallel_execute_op->getResult(i).replaceAllUsesWith(
        new_parallel_execute_op->getResult(i));

  // Replace the uses of the original parallel_execute after the region
  // containing the merged execute. The number of results in that region has
  // changed, but the results after it still correspond one-to-one, so they
  // are replaced starting from the ends of both parallel_execute ops.
  const int new_parallel_execute_num_results =
      new_parallel_execute_op->getNumResults();
  for (int i = 0; i < num_results_after_region; ++i)
    parallel_execute_op->getResult(parallel_execute_num_results - i - 1)
        .replaceAllUsesWith(new_parallel_execute_op->getResult(
            new_parallel_execute_num_results - i - 1));

  // Replace the uses of the original parallel_execute for the region
  // containing the merged execute.
  auto old_region_results = parallel_execute.GetRegionOutputs(region_index);
  for (int i = 0, end = var_access_info.old_to_new_output_mapping.size();
       i < end; ++i) {
    if (var_access_info.old_to_new_output_mapping[i] < 0) continue;
    old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult(
        var_access_info.old_to_new_output_mapping[i] +
        num_results_before_region));
  }

  // Replace the original terminator with a new terminator returning the merged
  // execute results.
  Operation* old_terminator = execute_region->front().getTerminator();
  builder->setInsertionPointToEnd(&execute_region->front());
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       merged_execute_launch.getResults());
  old_terminator->erase();

  // Remove the original TPUExecute op.
  execute_launch.erase();

  // Move all regions from the old parallel_execute to the new parallel_execute.
  for (auto region : llvm::zip(new_parallel_execute_op->getRegions(),
                               parallel_execute_op->getRegions()))
    std::get<0>(region).takeBody(std::get<1>(region));

  // Remove the original parallel_execute.
  parallel_execute_op->dropAllUses();
  parallel_execute.erase();
}

// Replaces TPUExecute with TPUExecuteAndUpdateVariables.
void ReplaceExecute(tf_device::LaunchOp execute_launch,
                    tf_device::LaunchOp merged_execute_launch,
                    const VariableAccessesForTPUExecute& var_access_info) {
  // Replace the uses.
  for (int i = 0, end = var_access_info.old_to_new_output_mapping.size();
       i < end; ++i) {
    if (var_access_info.old_to_new_output_mapping[i] < 0) continue;
    execute_launch.getResult(i).replaceAllUsesWith(
        merged_execute_launch.getResult(
            var_access_info.old_to_new_output_mapping[i]));
  }

  // Remove the original TPUExecute op.
  execute_launch.getOperation()->dropAllUses();
  execute_launch.erase();
}

// Merges the variable accesses into one TPUExecute op.
LogicalResult MergeForOneTPUExecute(
    tf_device::LaunchOp execute_launch,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    bool check_device, bool check_same_region, OpBuilder* builder) {
  auto var_access_info = BuildVariableAccessInfo(
      execute_launch, resource_analysis_info, check_device, check_same_region);
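  // Nothing to merge; leave the original TPUExecute untouched.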
  if (var_access_info.per_resource_info.empty()) return success();

  // Start creating the new TPUExecuteAndUpdateVariables op.
  builder->setInsertionPoint(execute_launch);
  // Output types. Skip the original outputs for merged assigns.
  llvm::SmallVector<Type, 8> new_output_types;
  int old_output_index = 0;
  for (const auto& type : execute_launch.getResultTypes()) {
    if (var_access_info.old_to_new_output_mapping[old_output_index] >= 0) {
      new_output_types.push_back(type);
    }
    ++old_output_index;
  }
  // The attributes for merged variable reads and updates.
  llvm::SmallVector<int64_t, 8> device_var_reads_indices;
  llvm::SmallVector<int64_t, 8> device_var_updates_indices;
  for (auto resource : var_access_info.resources_read) {
    auto info_it = var_access_info.per_resource_info.find(resource);
    if (info_it == var_access_info.per_resource_info.end()) continue;
    device_var_reads_indices.push_back(info_it->second.execute_input_index);
    device_var_updates_indices.push_back(info_it->second.execute_output_index);
  }

  // Check that all resources are either read or written to.
  for (auto it : llvm::enumerate(var_access_info.new_operand_values)) {
    Type type = it.value().getType();
    if (type.isa<TensorType>() &&
        type.cast<TensorType>().getElementType().isa<TF::ResourceType>()) {
      if (!llvm::is_contained(device_var_reads_indices, it.index()) &&
          !llvm::is_contained(device_var_updates_indices, it.index())) {
        return execute_launch.GetBody().front().emitError("operand #")
               << it.index()
               << " is a resource that was neither read nor written to; this "
                  "resource potentially failed to be hoisted";
      }
    }
  }

  // Create the merged execute and update variables op.
  auto merged_execute = builder->create<TF::TPUExecuteAndUpdateVariablesOp>(
      execute_launch.getLoc(), new_output_types,
      var_access_info.new_operand_values,
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr(
              "device_var_reads_indices",
              builder->getI64ArrayAttr(device_var_reads_indices)),
          builder->getNamedAttr(
              "device_var_updates_indices",
              builder->getI64ArrayAttr(device_var_updates_indices))});

  // Wrap in launch for device assignment.
  auto merged_execute_launch = builder->create<tf_device::LaunchOp>(
      merged_execute.getLoc(), execute_launch.deviceAttr(),
      merged_execute.getResultTypes());
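  // The launch op is created without a body; add a block to hold the merged
  // execute and its terminator.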
  merged_execute_launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&merged_execute_launch.GetBody());
  builder->create<tf_device::ReturnOp>(merged_execute.getLoc(),
                                       merged_execute.getResults());

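  // Move the merged execute into the new launch body, ahead of its terminator.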
  merged_execute.getOperation()->moveBefore(
      merged_execute_launch.GetBody().getTerminator());

  if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          execute_launch->getParentOp()))
    ReplaceParallelExecute(parallel_execute, execute_launch,
                           merged_execute_launch, var_access_info, builder);
  else
    ReplaceExecute(execute_launch, merged_execute_launch, var_access_info);

  // Remove the assign ops.
  for (const auto& entry : var_access_info.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.assign) info.assign->erase();
  }

  // Remove the read ops if they have no more uses.
  for (const auto& entry : var_access_info.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.read->use_empty()) info.read->erase();
  }
  return success();
}

// Checks if an op's parent is a tf_device.parallel_execute and, if so, whether
// the region containing the op wraps only that single op.
bool ParentParallelExecuteWrapsSingleOp(Operation* op) {
  auto parallel_execute =
      llvm::dyn_cast<tf_device::ParallelExecuteOp>(op->getParentOp());
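  // Ops not nested in a parallel_execute trivially pass the check.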
  if (!parallel_execute) return true;

  return parallel_execute.RegionWrapsSingleOp(
      op->getParentRegion()->getRegionNumber());
}

void TPUMergeVariablesWithExecutePass::runOnOperation() {
  ModuleOp module = getOperation();
  mlir::TF::ResourceAliasAnalysis resource_analysis(module);
  module.walk([&](func::FuncOp func) {
    const auto& resource_analysis_info =
        resource_analysis.GetAnalysisForFunc(func);
    // Find all the executes first, since we will mutate the nodes around each
    // execute.
    llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
    func.walk([&](tf_device::LaunchOp op) {
      if (op.WrapsSingleOp() &&
          llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()) &&
          ParentParallelExecuteWrapsSingleOp(op))
        execute_launches.push_back(op);
    });

    for (auto execute_launch : execute_launches) {
      OpBuilder builder(&getContext());
      const bool parent_is_replicate =
          llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
          (llvm::isa<tf_device::ParallelExecuteOp>(
               execute_launch->getParentOp()) &&
           llvm::isa<tf_device::ReplicateOp>(
               execute_launch->getParentOp()->getParentOp()));

      // If this is inside a tf_device::ReplicateOp, the variables are
      // guaranteed to be on the same device as the TPUExecute op. Skip device
      // checking in that case, but we need to check that we are only merging
      // reads/assigns that are also in this replicated region.
      if (failed(MergeForOneTPUExecute(
              execute_launch, resource_analysis_info,
              /*check_device=*/!parent_is_replicate,
              /*check_same_region=*/parent_is_replicate, &builder))) {
        signalPassFailure();
        return;
      }
    }
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUMergeVariablesWithExecutePass() {
  return std::make_unique<TPUMergeVariablesWithExecutePass>();
}

}  // namespace TFTPU
}  // namespace mlir