/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

// Needed for absl::StrJoin and VLOG used below.
#include "absl/strings/str_join.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "tf-tpu-merge-variables-with-execute"

namespace mlir {
namespace TFTPU {

namespace {
constexpr char kAliasingAttr[] = "tf.aliasing_output";
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";

class TPUMergeVariablesWithExecutePass
    : public TF::TPUMergeVariablesWithExecutePassBase<
          TPUMergeVariablesWithExecutePass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    // We need this here because at the moment we deserialize the
    // TPUCompileMlir operation, which contains annotations like
    // `mhlo.sharding` attributes.
    registry.insert<mhlo::MhloDialect>();
  }
  void runOnOperation() override;
};

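// This pass merges the ReadVariableOp/AssignVariableOp pairs surrounding a
// TPUExecute into a single TPUExecuteAndUpdateVariables op, so the runtime
// can access the variables in place. A sketch of the rewrite, shown
// schematically (types and the compilation-key operand omitted):
//
//   %0 = "tf.ReadVariableOp"(%var)
//   %1 = "tf_device.launch"() ({
//     %2 = "tf.TPUExecute"(%0)
//     tf_device.return %2
//   })
//   "tf.AssignVariableOp"(%var, %1)
//
// becomes
//
//   "tf_device.launch"() ({
//     "tf.TPUExecuteAndUpdateVariables"(%var)
//         {device_var_reads_indices = [0], device_var_updates_indices = [0]}
//     tf_device.return
//   })
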
// Information for a pair of input/output of the TPUExecute op and the
// surrounding read/assign ops.
struct VariableAccessInfo {
  int execute_input_index = -1;
  int execute_output_index = -1;
  Operation* read = nullptr;
  Operation* assign = nullptr;
};

// Information about all resource accesses to be merged into a TPUExecute op.
struct VariableAccessesForTPUExecute {
  // Maps each detected resource to a VariableAccessInfo. Eventually, this will
  // contain all values for which we want to merge the accessing ops with a
  // TPUExecute op.
  llvm::SmallDenseMap<Value, VariableAccessInfo, 8> per_resource_info;
  // The corresponding new output index in TPUExecuteAndUpdateVariables for
  // each old output index in TPUExecute.
  llvm::SmallVector<int, 8> old_to_new_output_mapping;
  // The resources read by ReadVariableOps that are inputs to TPUExecute,
  // ordered by the input indices to TPUExecute.
  llvm::SmallVector<Value, 8> resources_read;
  // Operands for the new TPUExecuteAndUpdateVariables.
  llvm::SmallVector<Value, 8> new_operand_values;
};

// Returns true iff the read or assign op associated with `resource` can be
// safely merged.
//
// `resource_ids` contains the IDs of all previously accessed resources.
// `previous_unknown_resource_access` is true if we had any previous unknown
// resource access.
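// For example, if two values may alias the same underlying variable, they
// share a resource unique ID, so a previous access through either value also
// blocks merging an access through the other.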
bool IsResourceSafeForMerge(
    Value resource,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    const VariableAccessesForTPUExecute& infos,
    const llvm::SmallDenseSet<int64_t>& resource_ids,
    bool previous_unknown_resource_access) {
  // If we had any unknown resource access before, then we conservatively
  // assume that `resource` has been accessed before.
  // If `resource` is an unknown resource, then we conservatively assume that
  // the same resource has been accessed before.
  if (previous_unknown_resource_access ||
      resource_analysis_info.IsUnknownResource(resource))
    return false;
  const auto& ids = resource_analysis_info.GetResourceUniqueIds(resource);
  for (int64_t id : ids) {
    if (resource_ids.contains(id)) return false;
  }
  return true;
}

// Adds the IDs of the resources which `op` accesses to `resource_ids`.
// Returns true iff `op` accesses a resource unknown to
// `resource_analysis_info`, in which case we have to conservatively assume
// that any resource might be accessed.
bool AddAccessedResourceIds(
    Operation* op,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    llvm::SmallDenseSet<int64_t>& resource_ids) {
  for (Value operand : TF::filter_resources(op->getOperands())) {
    if (resource_analysis_info.IsUnknownResource(operand)) {
      VLOG(2) << "  unknown access";
      return true;
    }
    const auto& ids = resource_analysis_info.GetResourceUniqueIds(operand);
    VLOG(2) << "  accesses following resources: " << absl::StrJoin(ids, ", ");
    resource_ids.insert(ids.begin(), ids.end());
  }
  return false;
}

// Finds the variable access info for a TPUExecute op.
//  - `check_device` specifies whether to check that the device assignment of
//  the variables matches the TPUExecute op. This is unnecessary in some
//  contexts, e.g., when matching devices is guaranteed by replication.
//  - `check_same_region` specifies whether the reads/assigns need to be in
//  the same region as `execute`. This is needed if `execute` is inside a
//  ReplicateOp.
VariableAccessesForTPUExecute BuildVariableAccessInfo(
    tf_device::LaunchOp execute_launch,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    bool check_device, bool check_same_region) {
  VariableAccessesForTPUExecute var_access_info;
  Attribute device_attr = execute_launch.deviceAttr();
  if (check_device && !device_attr) return var_access_info;
  auto func = execute_launch->getParentOfType<mlir::func::FuncOp>();

  // Track the first read op found, which is used later to check if there are
  // assign ops between it and the TPUExecute op. We will exclude reads before
  // interfering accesses in a conservative way (see below). We do not consider
  // resource accesses in other islands since their ordering is enforced by
  // inter-island dependencies.
  Operation* first_read = nullptr;
  auto execute = cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
  auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
      execute_launch->getParentOp());
  Operation* execute_parent =
      parallel_execute ? parallel_execute.getOperation() : execute_launch;
  // Collect all operands of `execute` whose defining ops are variable reads
  // that might get merged, and add relevant information to `var_access_info`.
  for (auto operand : llvm::enumerate(execute->getOpOperands())) {
    var_access_info.new_operand_values.push_back(operand.value().get());
    auto read_op = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.value().get().getDefiningOp());
    if (!read_op) continue;
    if (check_same_region &&
        read_op->getParentRegion() != execute_parent->getParentRegion())
      continue;

    auto resource = read_op.resource();
    if (check_device) {
      // TODO(lyandy): Wrap resource ops in tf_device.launch.
      if (auto* resource_op = resource.getDefiningOp()) {
        auto resource_attr = resource_op->getAttr(kDeviceAttr);
        // Check device matching for the node defining the resource.
        if (!resource_attr || resource_attr != device_attr) continue;
      } else {
        auto resource_arg = resource.dyn_cast<BlockArgument>();
        assert(resource_arg);
        if (resource_arg.getOwner() != &func.front()) continue;
        // Check device matching for the argument defining the resource.
        auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
            resource_arg.getArgNumber(), kFuncDeviceAttr);
        if (!resource_attr || resource_attr != device_attr) continue;
      }
    }

    auto emplace_res = var_access_info.per_resource_info.try_emplace(
        resource, VariableAccessInfo());
    if (!emplace_res.second) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple reads of a variable: "
                 << execute << "\n");
      var_access_info.per_resource_info.shrink_and_clear();
      return var_access_info;
    }

    VLOG(2) << "Adding read op to merge candidates: " << debugString(read_op);
    auto& info = emplace_res.first->getSecond();
    info.execute_input_index = operand.index();
    info.read = read_op;
    var_access_info.new_operand_values[operand.index()] = resource;
    var_access_info.resources_read.push_back(resource);
    if (!first_read || info.read->isBeforeInBlock(first_read)) {
      first_read = info.read;
    }
  }

  if (!first_read) return var_access_info;

  // Walk backwards from `execute_parent` to `first_read` and remove merge
  // candidates based on resource modifications.
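  // For example (schematically), a candidate read that is followed by a write
  // to an aliasing resource before the execute cannot be merged:
  //
  //   %0 = "tf.ReadVariableOp"(%var)   // candidate read
  //   "tf.AssignVariableOp"(%var, %c)  // interfering write
  //   "tf.TPUExecute"(%0, ...)         // the read must stay separate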
  llvm::SmallDenseSet<int64_t> resource_ids;
  bool previous_unknown_resource_access = false;
  for (Operation& op : llvm::reverse(llvm::make_range(
           first_read->getIterator(), execute_parent->getIterator()))) {
    if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
      VLOG(2) << "Processing read op " << debugString(op);
      auto info_it = var_access_info.per_resource_info.find(read_op.resource());
      bool is_merge_candidate =
          info_it != var_access_info.per_resource_info.end();

      if (is_merge_candidate &&
          !IsResourceSafeForMerge(read_op.resource(), resource_analysis_info,
                                  var_access_info, resource_ids,
                                  previous_unknown_resource_access)) {
        VLOG(2) << "  removing op from merge candidates";
        int input_index = info_it->getSecond().execute_input_index;
        var_access_info.new_operand_values[input_index] =
            execute.getOperand(input_index);
        var_access_info.per_resource_info.erase(info_it);
      }
    }
    previous_unknown_resource_access |=
        AddAccessedResourceIds(&op, resource_analysis_info, resource_ids);
  }

  if (var_access_info.per_resource_info.empty()) {
    return var_access_info;
  }

  // Find outputs that are variable assigns.
  Operation* last_assign = nullptr;
  llvm::SmallPtrSet<Operation*, 8> all_assigns;
  llvm::SmallVector<bool, 8> output_merged(execute_launch.getNumResults(),
                                           false);

  auto execute_outputs =
      parallel_execute
          ? parallel_execute.GetRegionOutputs(
                execute_launch->getParentRegion()->getRegionNumber())
          : execute_launch.getResults();
  for (auto execute_output : llvm::enumerate(execute_outputs)) {
    // TODO(lyandy): Handle updates to resource writes by remapping to parent
    // launch result and checking if launch result is an AssignVariableOp.
    auto result = execute_output.value();
    if (!result.hasOneUse()) continue;

    auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
    if (!assign_op) continue;
    auto resource = assign_op.resource();
    auto it = var_access_info.per_resource_info.find(resource);
    if (it == var_access_info.per_resource_info.end()) continue;
    auto& info = it->getSecond();
    if (info.assign) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple assigns of a variable: "
                 << execute << "\n");
      var_access_info.per_resource_info.shrink_and_clear();
      return var_access_info;
    }
    info.execute_output_index = execute_output.index();
    info.assign = assign_op;
    if (!last_assign || last_assign->isBeforeInBlock(assign_op)) {
      last_assign = assign_op;
    }
    VLOG(2) << "Adding assign op to merge candidates: "
            << debugString(assign_op);
    all_assigns.insert(assign_op);
    output_merged[execute_output.index()] = true;
  }

  if (last_assign != nullptr) {
    // Walk forward from `execute_parent` to `last_assign` and remove merge
    // candidates based on resource modifications.
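    // For example (schematically), a candidate assign whose resource is
    // accessed between the execute and the assign cannot be merged:
    //
    //   %1 = "tf.TPUExecute"(...)
    //   %2 = "tf.ReadVariableOp"(%var)   // interfering access
    //   "tf.AssignVariableOp"(%var, %1)  // the assign must stay separate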
    resource_ids.clear();
    previous_unknown_resource_access = false;
    for (Operation& op :
         llvm::make_range(std::next(execute_parent->getIterator()),
                          std::next(last_assign->getIterator()))) {
      if (auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
        VLOG(2) << "Processing assign op " << debugString(op);
        bool is_merge_candidate = true;
        if (all_assigns.count(assign_op) == 0) is_merge_candidate = false;
        auto info_it =
            var_access_info.per_resource_info.find(assign_op.resource());
        if (info_it == var_access_info.per_resource_info.end())
          is_merge_candidate = false;

        if (is_merge_candidate &&
            !IsResourceSafeForMerge(
                assign_op.resource(), resource_analysis_info, var_access_info,
                resource_ids, previous_unknown_resource_access)) {
          VLOG(2) << "  removing op from merge candidates";
          output_merged[info_it->second.execute_output_index] = false;
          info_it->second.execute_output_index = -1;
          info_it->second.assign = nullptr;
        }
      }
      previous_unknown_resource_access |=
          AddAccessedResourceIds(&op, resource_analysis_info, resource_ids);
    }
  }

  // Populate var_access_info.old_to_new_output_mapping.
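  // For example, if outputs 0 and 2 remain and output 1 was merged into a
  // variable update, the mapping is [0, -1, 1]: merged outputs map to -1 and
  // the remaining outputs are renumbered compactly.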
  int new_output_index = 0;
  var_access_info.old_to_new_output_mapping.resize(
      execute_launch.getNumResults());
  for (int i = 0, end = execute_launch.getNumResults(); i < end; ++i) {
    if (output_merged[i]) {
      var_access_info.old_to_new_output_mapping[i] = -1;
    } else {
      var_access_info.old_to_new_output_mapping[i] = new_output_index;
      ++new_output_index;
    }
  }
  return var_access_info;
}

// Appends the result types of the tf_device.parallel_execute regions from
// index `start` (inclusive) to index `end` (exclusive) to `output_types`, and
// returns the number of types added.
int AppendTypes(llvm::SmallVectorImpl<Type>* output_types,
                tf_device::ParallelExecuteOp parallel_execute, int start,
                int end) {
  const int size_before = output_types->size();
  for (int index = start; index < end; ++index) {
    Block& block = parallel_execute.GetRegionBlockWithIndex(index);
    auto terminator_operand_types = block.getTerminator()->getOperandTypes();
    output_types->append(terminator_operand_types.begin(),
                         terminator_operand_types.end());
  }
  return output_types->size() - size_before;
}

// Replaces TPUExecute with TPUExecuteAndUpdateVariables in a
// tf_device.parallel_execute op.
void ReplaceParallelExecute(
    tf_device::ParallelExecuteOp parallel_execute,
    tf_device::LaunchOp execute_launch,
    tf_device::LaunchOp merged_execute_launch,
    const VariableAccessesForTPUExecute& var_access_info, OpBuilder* builder) {
  Operation* parallel_execute_op = parallel_execute.getOperation();

  // Collect result types of tf_device.parallel_execute and update region
  // result types with the new merged execute result types.
  llvm::SmallVector<Type, 8> output_types;
  const int parallel_execute_num_results = parallel_execute_op->getNumResults();
  output_types.reserve(parallel_execute_num_results);
  Region* execute_region = merged_execute_launch->getParentRegion();
  const int region_index = execute_region->getRegionNumber();
  const int num_results_before_region =
      AppendTypes(&output_types, parallel_execute, 0, region_index);
  // Append updated results from merged execute.
  output_types.append(merged_execute_launch.getResultTypes().begin(),
                      merged_execute_launch.getResultTypes().end());
  const int num_regions = parallel_execute_op->getNumRegions();
  const int num_results_after_region = AppendTypes(
      &output_types, parallel_execute, region_index + 1, num_regions);

  builder->setInsertionPoint(parallel_execute);
  auto new_parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
      parallel_execute.getLoc(), num_regions, output_types);

  // Replace the uses of the original parallel_execute before the region
  // containing the merged execute.
  Operation* new_parallel_execute_op = new_parallel_execute.getOperation();
  for (int i = 0; i < num_results_before_region; ++i)
    parallel_execute_op->getResult(i).replaceAllUsesWith(
        new_parallel_execute_op->getResult(i));

  // Replace the uses of the original parallel_execute after the region
  // containing the merged execute. The number of results in that region has
  // changed, but the trailing regions' results still match, so they are
  // replaced starting from the ends of both parallel_execute ops.
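  // For example, with 5 old results, 4 new results, and 2 results after the
  // merged region, old result 4 maps to new result 3, and old result 3 maps
  // to new result 2.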
  const int new_parallel_execute_num_results =
      new_parallel_execute_op->getNumResults();
  for (int i = 0; i < num_results_after_region; ++i)
    parallel_execute_op->getResult(parallel_execute_num_results - i - 1)
        .replaceAllUsesWith(new_parallel_execute_op->getResult(
            new_parallel_execute_num_results - i - 1));

  // Replace the uses of the original parallel_execute for the region
  // containing the merged execute.
  auto old_region_results = parallel_execute.GetRegionOutputs(region_index);
  for (int i = 0, end = var_access_info.old_to_new_output_mapping.size();
       i < end; ++i) {
    if (var_access_info.old_to_new_output_mapping[i] < 0) continue;
    old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult(
        var_access_info.old_to_new_output_mapping[i] +
        num_results_before_region));
  }

  // Replace the original terminator with a new terminator returning the
  // merged execute results.
  Operation* old_terminator = execute_region->front().getTerminator();
  builder->setInsertionPointToEnd(&execute_region->front());
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       merged_execute_launch.getResults());
  old_terminator->erase();

  // Remove the original TPUExecute op.
  execute_launch.erase();

  // Move all regions from the old parallel_execute to the new parallel_execute.
  for (auto region : llvm::zip(new_parallel_execute_op->getRegions(),
                               parallel_execute_op->getRegions()))
    std::get<0>(region).takeBody(std::get<1>(region));

  // Remove the original parallel_execute.
  parallel_execute_op->dropAllUses();
  parallel_execute.erase();
}

// Replaces TPUExecute with TPUExecuteAndUpdateVariables.
void ReplaceExecute(tf_device::LaunchOp execute_launch,
                    tf_device::LaunchOp merged_execute_launch,
                    const VariableAccessesForTPUExecute& var_access_info) {
  // Replace the uses.
  for (int i = 0, end = var_access_info.old_to_new_output_mapping.size();
       i < end; ++i) {
    if (var_access_info.old_to_new_output_mapping[i] < 0) continue;
    execute_launch.getResult(i).replaceAllUsesWith(
        merged_execute_launch.getResult(
            var_access_info.old_to_new_output_mapping[i]));
  }

  // Remove the original TPUExecute op.
  execute_launch.getOperation()->dropAllUses();
  execute_launch.erase();
}

// Merges the variable accesses into one TPUExecute op.
LogicalResult MergeForOneTPUExecute(
    tf_device::LaunchOp execute_launch,
    const mlir::TF::ResourceAliasAnalysis::Info& resource_analysis_info,
    bool check_device, bool check_same_region, OpBuilder* builder) {
  auto var_access_info = BuildVariableAccessInfo(
      execute_launch, resource_analysis_info, check_device, check_same_region);
  if (var_access_info.per_resource_info.empty()) return success();

  // Start creating the new TPUExecuteAndUpdateVariables op.
  builder->setInsertionPoint(execute_launch);
  // Output types. Skip the original outputs for merged assigns.
  llvm::SmallVector<Type, 8> new_output_types;
  int old_output_index = 0;
  for (const auto& type : execute_launch.getResultTypes()) {
    if (var_access_info.old_to_new_output_mapping[old_output_index] >= 0) {
      new_output_types.push_back(type);
    }
    ++old_output_index;
  }
  // The attributes for merged variable reads and updates.
  llvm::SmallVector<int64_t, 8> device_var_reads_indices;
  llvm::SmallVector<int64_t, 8> device_var_updates_indices;
  for (auto resource : var_access_info.resources_read) {
    auto info_it = var_access_info.per_resource_info.find(resource);
    if (info_it == var_access_info.per_resource_info.end()) continue;
    device_var_reads_indices.push_back(info_it->second.execute_input_index);
    device_var_updates_indices.push_back(info_it->second.execute_output_index);
  }
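  // For example, a variable read at execute input 1 and updated at execute
  // output 0 yields device_var_reads_indices = [1] and
  // device_var_updates_indices = [0]; an updates index of -1 marks a variable
  // that is read but not updated.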

  // Check that all resources are either read or written to.
  for (auto it : llvm::enumerate(var_access_info.new_operand_values)) {
    Type type = it.value().getType();
    if (type.isa<TensorType>() &&
        type.cast<TensorType>().getElementType().isa<TF::ResourceType>()) {
      if (!llvm::is_contained(device_var_reads_indices, it.index()) &&
          !llvm::is_contained(device_var_updates_indices, it.index())) {
        return execute_launch.GetBody().front().emitError("operand #")
               << it.index()
               << " is a resource that was neither read nor written to; this "
                  "resource potentially failed to be hoisted";
      }
    }
  }

  // Create the merged execute and update variables op.
  auto merged_execute = builder->create<TF::TPUExecuteAndUpdateVariablesOp>(
      execute_launch.getLoc(), new_output_types,
      var_access_info.new_operand_values,
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr(
              "device_var_reads_indices",
              builder->getI64ArrayAttr(device_var_reads_indices)),
          builder->getNamedAttr(
              "device_var_updates_indices",
              builder->getI64ArrayAttr(device_var_updates_indices))});

  // Wrap in launch for device assignment.
  auto merged_execute_launch = builder->create<tf_device::LaunchOp>(
      merged_execute.getLoc(), execute_launch.deviceAttr(),
      merged_execute.getResultTypes());
  merged_execute_launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&merged_execute_launch.GetBody());
  builder->create<tf_device::ReturnOp>(merged_execute.getLoc(),
                                       merged_execute.getResults());

  merged_execute.getOperation()->moveBefore(
      merged_execute_launch.GetBody().getTerminator());

  if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          execute_launch->getParentOp()))
    ReplaceParallelExecute(parallel_execute, execute_launch,
                           merged_execute_launch, var_access_info, builder);
  else
    ReplaceExecute(execute_launch, merged_execute_launch, var_access_info);

  // Remove the assign ops.
  for (const auto& entry : var_access_info.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.assign) info.assign->erase();
  }

  // Remove the read ops if they have no more uses.
  for (const auto& entry : var_access_info.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.read->use_empty()) info.read->erase();
  }
  return success();
}

// Checks if an op's parent is a tf_device.parallel_execute and if the region
// the op is in perfectly wraps a single op.
bool ParentParallelExecuteWrapsSingleOp(Operation* op) {
  auto parallel_execute =
      llvm::dyn_cast<tf_device::ParallelExecuteOp>(op->getParentOp());
  if (!parallel_execute) return true;

  return parallel_execute.RegionWrapsSingleOp(
      op->getParentRegion()->getRegionNumber());
}

void TPUMergeVariablesWithExecutePass::runOnOperation() {
  ModuleOp module = getOperation();
  mlir::TF::ResourceAliasAnalysis resource_analysis(module);
  module.walk([&](func::FuncOp func) {
    const auto& resource_analysis_info =
        resource_analysis.GetAnalysisForFunc(func);
    // Find all the executes first, since we will mutate the nodes around each
    // execute.
    llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
    func.walk([&](tf_device::LaunchOp op) {
      if (op.WrapsSingleOp() &&
          llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()) &&
          ParentParallelExecuteWrapsSingleOp(op))
        execute_launches.push_back(op);
    });

    for (auto execute_launch : execute_launches) {
      OpBuilder builder(&getContext());
      const bool parent_is_replicate =
          llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
          (llvm::isa<tf_device::ParallelExecuteOp>(
               execute_launch->getParentOp()) &&
           llvm::isa<tf_device::ReplicateOp>(
               execute_launch->getParentOp()->getParentOp()));

      // If this is inside a tf_device::ReplicateOp, the variables are
      // guaranteed to be on the same device as the TPUExecute op. Skip device
      // checking in that case, but we need to check that we are only merging
      // reads/assigns that are also in this replicated region.
      if (failed(MergeForOneTPUExecute(
              execute_launch, resource_analysis_info,
              /*check_device=*/!parent_is_replicate,
              /*check_same_region=*/parent_is_replicate, &builder))) {
        signalPassFailure();
        return;
      }
    }
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUMergeVariablesWithExecutePass() {
  return std::make_unique<TPUMergeVariablesWithExecutePass>();
}

}  // namespace TFTPU
}  // namespace mlir