• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"

#include <cstdint>
#include <initializer_list>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/CallGraph.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 
45 namespace mlir {
46 namespace TF {
47 namespace detail {
48 
49 //===----------------------------------------------------------------------===//
50 // BacktrackAnalysisInfo
51 //===----------------------------------------------------------------------===//
52 // Class to hold backtrack analysis for a results of a region. Backtrack
53 // analysis will trace back the definition of return values of regions through
54 // pass-through operations, so that the return value of the region will have the
55 // same value as the backtracked value.
56 class BacktrackAnalysisInfo {
57  public:
58   // Initializes the backtrack analysis for the given region.
59   explicit BacktrackAnalysisInfo(Region& region,
60                                  detail::BacktrackAnalysis& backtrack_analysis);
61 
62   BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;
63 
64   // Returns the value to which the given result number of the region can be
65   // backtracked to.
GetValue(int result_index) const66   Value GetValue(int result_index) const {
67     return backtracked_values_[result_index];
68   }
69 
70   // Returns the argument index of the region to which the given result number
71   // can backtracked to. Such results will be called "function passthrough". If
72   // the result cannot be backtracked to a region argument, returns llvm::None.
GetArg(int result_index) const73   llvm::Optional<int> GetArg(int result_index) const {
74     if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
75       if (arg.getParentBlock() == &region_->front()) return arg.getArgNumber();
76     return llvm::None;
77   }
78 
79  private:
80   friend class detail::BacktrackAnalysis;
81 
82   // Region for which this object holds the analysis info.
83   Region* region_;
84 
85   // Backtracked values indexed by the result number.
86   llvm::SmallVector<Value, 4> backtracked_values_;
87 };
88 
89 //===----------------------------------------------------------------------===//
90 // BacktrackAnalysis
91 //===----------------------------------------------------------------------===//
92 // Holds backtrack analysis for all functions and regions within a module.
93 class BacktrackAnalysis {
94  public:
95   using InfoT = BacktrackAnalysisInfo;
96 
97   // Constructs the analysis by analyzing the given module.
98   explicit BacktrackAnalysis(ModuleOp module);
99 
100   // Returns backtracking analysis for the given region.
GetAnalysisForRegion(Region & region) const101   const InfoT& GetAnalysisForRegion(Region& region) const {
102     auto it = info_map_.find(&region);
103     assert(it != info_map_.end());
104     return it->second;
105   }
106 
107   // Returns backtracking analysis for the given function.
GetAnalysisForFunc(FuncOp func) const108   const InfoT& GetAnalysisForFunc(FuncOp func) const {
109     return GetAnalysisForRegion(func.getBody());
110   }
111 
112   // Backtracks the given value.
113   Value BacktrackValue(Value value);
114 
115  private:
116   // Returns the analysis for the given region (analyzing the region if it has
117   // not yet been analyzed).
GetOrCreateAnalysis(Region & region)118   const InfoT& GetOrCreateAnalysis(Region& region) {
119     auto it = info_map_.find(&region);
120     if (it == info_map_.end()) {
121       // Note: Keep object construction and insertion separate. If we use
122       // emplace() to construct and insert in a single shot, when analyzing
123       // this region, calls to BacktrackValue() may end up inserting additional
124       // entries in the map, causing the underlying storage to be moved. This
125       // would also include this pertially constructed object that we have just
126       // inserted into the map and are constructing it. To avoid this issue,
127       // construct the analysis object separately and then insert it into the
128       // map.
129       InfoT info(region, *this);
130       info_map_.insert({&region, std::move(info)});
131     }
132 
133     return GetAnalysisForRegion(region);
134   }
135 
136   // Returns the backtrack analysis for the given region if it exists.
137   // If the region has not yet been analyzed, returns llvm::None.
GetAnalysisIfExists(Region & region) const138   Optional<const InfoT*> GetAnalysisIfExists(Region& region) const {
139     auto it = info_map_.find(&region);
140     if (it == info_map_.end()) return llvm::None;
141     return &it->second;
142   }
143 
GetAnalysisIfExists(FuncOp func) const144   Optional<const InfoT*> GetAnalysisIfExists(FuncOp func) const {
145     return GetAnalysisIfExists(func.getBody());
146   }
147 
148  private:
149   llvm::SmallDenseMap<Region*, InfoT> info_map_;
150 };
151 
152 // Analyzes all regions attached to all operations in the module.
BacktrackAnalysis(ModuleOp module)153 BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
154   const CallGraph call_graph(module);
155 
156   // Visit functions bottom up when doing the analysis. Note that SCC iterator
157   // has the property that if there is an edge from SCC1->SCC2, SCC1 is visited
158   // after SCC2, i.e., the graph is traversed bottom up just the way we want.
159   auto scc_begin = llvm::scc_begin(&call_graph);
160   auto scc_end = llvm::scc_end(&call_graph);
161   for (auto& scc : make_range(scc_begin, scc_end)) {
162     // Each SCC node is a collection of callgraph nodes that form a cycle. We
163     // will visit these nodes in an arbitrary order. If a node being visited
164     // calls a function that has not yet been analyzed, we will not be able to
165     // backtrack through that function call (our analysis will be correct but
166     // pessimistic).
167     for (CallGraphNode* node : scc) {
168       if (node->isExternal()) continue;
169       Region* region = node->getCallableRegion();
170       GetOrCreateAnalysis(*region);
171     }
172   }
173 
174   // This above call graph analysis will cover all regions attached to functions
175   // but we also need to analyze regions attached to other ops.
176   module.walk([this](Operation* op) {
177     for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
178   });
179 }
180 
181 // Backtracks the definition of `value` looking through passthrough ops.
182 // Returns a non-null value and can return `value` if backtracking is not
183 // possible.
BacktrackValue(Value value)184 Value BacktrackAnalysis::BacktrackValue(Value value) {
185   while (Operation* op = value.getDefiningOp()) {
186     int res_index = value.cast<OpResult>().getResultNumber();
187     if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
188       value = graph.GetFetch().getOperand(res_index);
189     } else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
190       // Control output is generated by the IslandOp, not the yield in
191       // in the Island body.
192       if (value == island.control()) break;
193       value = island.GetYield().getOperand(res_index);
194     } else if (isa<IdentityNOp, IdentityOp>(op)) {
195       value = op->getOperand(res_index);
196     } else if (auto call = dyn_cast<CallOpInterface>(op)) {
197       FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
198       if (!func) break;
199       // Check if the function being called has been analyzed. if not,
200       // we cannot backtrack the value further.
201       Optional<const InfoT*> callee_info = GetAnalysisIfExists(func);
202       if (!callee_info) break;
203       Optional<int> passthrough_arg = callee_info.getValue()->GetArg(res_index);
204       if (!passthrough_arg) break;
205       value = call.getArgOperands()[passthrough_arg.getValue()];
206     } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
207       value = op->getRegion(0).front().getTerminator()->getOperand(res_index);
208     } else {
209       break;
210     }
211   }
212   return value;
213 }
214 
215 // Analyze the region.
BacktrackAnalysisInfo(Region & region,detail::BacktrackAnalysis & backtrack_analysis)216 BacktrackAnalysisInfo::BacktrackAnalysisInfo(
217     Region& region, detail::BacktrackAnalysis& backtrack_analysis)
218     : region_(&region) {
219   if (region.empty()) return;
220 
221   assert(llvm::hasSingleElement(region.getBlocks()));
222   auto results = region.front().getTerminator()->getOperands();
223   if (results.empty()) return;
224 
225   backtracked_values_.reserve(results.size());
226   for (auto result : results)
227     backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // ResourceAliasAnalysisInfo
232 //===----------------------------------------------------------------------===//
233 
234 namespace {
235 
236 constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
237 
238 }  // namespace
239 
240 constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;
241 
242 // Constructs the analysis info by analyzing the given function.
ResourceAliasAnalysisInfo(FuncOp func_op,const BacktrackAnalysis & backtrack_analysis)243 ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
244     FuncOp func_op, const BacktrackAnalysis& backtrack_analysis) {
245   // This function populates resource_value_to_ids_ and id_to_resource_values_.
246 
247   int64_t next_unique_id = 0;
248 
249   // Helper to assign new unique id for all resources in the given list of
250   // values.
251   auto assign_unique_id_to_all = [&](ValueRange values) {
252     for (Value value : filter_resources(values)) {
253       AddValueUniqueIDMapping(value, next_unique_id++);
254     }
255   };
256 
257   // Helper to assign new unknown id for all resources in the given list of
258   // values.
259   auto assign_unknown_id_to_all = [&](ValueRange values) {
260     for (Value value : filter_resources(values)) {
261       AddValueUniqueIDMapping(value, kUnknownResourceId);
262     }
263   };
264 
265   // If the "tf.resource_arg_unique_id" argument attributes are present for
266   // resource-type arguments, respect them when choosing IDs; otherwise, they
267   // must not alias.
268   const bool has_arg_unique_id_attrs =
269       llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) {
270         return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr);
271       });
272   // Maps the kResourceArgUniqueIdAttr attribute value to the internal integer
273   // ID used by this pass.
274   if (has_arg_unique_id_attrs) {
275     llvm::SmallDenseMap<int64_t, int64_t> attr_id_to_internal_id;
276     for (auto arg : filter_resources(func_op.getArguments())) {
277       auto id_attr = func_op.getArgAttrOfType<IntegerAttr>(
278           arg.getArgNumber(), kResourceArgUniqueIdAttr);
279       assert(id_attr &&
280              "tf.resource_arg_unique_id attribute should exist on either "
281              "none or all arguments.");
282       auto emplace_res = attr_id_to_internal_id.try_emplace(id_attr.getInt(),
283                                                             next_unique_id++);
284       AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
285     }
286   } else {
287     assign_unique_id_to_all(func_op.getArguments());
288   }
289 
290   // Since this analysis is neither inter-procedural nor inter-regional,
291   // each region attached to Op's within a function is analyzed independently.
292   // Seed this analysis for each such region by mapping all resource arguments
293   // for such regions to a new unique-id. This is required because walk() walks
294   // the attached regions first before visiting the op, so there is no
295   // opportunity during the walk to seed region arguments. Also note that walk
296   // eventually also visits the Op on which the walk() is called, so make sure
297   // we do not overwrite the function argument mapping here.
298   func_op.walk([&](Operation* op) {
299     if (op == func_op) return;
300     for (Region& region : op->getRegions()) {
301       assign_unique_id_to_all(region.getArguments());
302     }
303   });
304 
305   llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
306   func_op.walk([&](Operation* op) {
307     if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
308       ResourceHandleValueAndId resource =
309           resource_alloc.GetResourceHandleValueAndId(resource_handle_id_map,
310                                                      next_unique_id);
311       AddValueUniqueIDMapping(resource.value, resource.id);
312     } else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
313       for (auto result : filter_resources(op->getResults()))
314         PropagateInputToOutput(op->getOperand(result.getResultNumber()),
315                                result);
316     } else if (auto while_op = dyn_cast<WhileOp>(op)) {
317       AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc(
318                                      while_op.body_function()));
319     } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
320       AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion(
321                                          while_region.body()));
322     } else if (auto case_op = dyn_cast<CaseOp>(op)) {
323       llvm::SmallVector<FuncOp, 4> functions;
324       case_op.get_branch_functions(functions);
325       AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
326     } else if (auto if_op = dyn_cast<IfOp>(op)) {
327       AnalyzeFunctionalCaseOrIfOp(
328           if_op, {if_op.then_function(), if_op.else_function()},
329           backtrack_analysis);
330     } else if (llvm::isa<CaseRegionOp, IfRegionOp>(op)) {
331       AnalyzeRegionCaseOrIfOp(op, backtrack_analysis);
332     } else if (auto call = dyn_cast<CallOpInterface>(op)) {
333       FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
334       if (!func) {
335         assign_unknown_id_to_all(op->getResults());
336         return WalkResult::advance();
337       }
338       const auto& func_info = backtrack_analysis.GetAnalysisForFunc(func);
339       for (auto result : filter_resources(op->getResults())) {
340         auto passthrough_arg = func_info.GetArg(result.getResultNumber());
341         if (passthrough_arg) {
342           PropagateInputToOutput(
343               call.getArgOperands()[passthrough_arg.getValue()], result);
344         } else {
345           AddValueUniqueIDMapping(result, kUnknownResourceId);
346         }
347       }
348     } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
349       Region& region = op->getRegion(0);
350       const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region);
351       for (auto result : filter_resources(op->getResults())) {
352         Value body_result = body_info.GetValue(result.getResultNumber());
353         PropagateInputToOutput(body_result, result);
354       }
355     } else {
356       assign_unknown_id_to_all(op->getResults());
357     }
358     return WalkResult::advance();
359   });
360 }
361 
362 // Propagates the resource ID's from an input operand to a result. Returns true
363 // if the mapping changed.
PropagateInputToOutput(const Value & operand,const OpResult & result)364 bool ResourceAliasAnalysisInfo::PropagateInputToOutput(const Value& operand,
365                                                        const OpResult& result) {
366   auto operand_it = resource_value_to_ids_.find(operand);
367   assert(operand_it != resource_value_to_ids_.end() &&
368          "A resource-type output does not have the corresponding "
369          "resource-type input.");
370   bool change = false;
371   for (int64_t id : operand_it->second)
372     change = AddValueUniqueIDMapping(result, id) || change;
373   return change;
374 }
375 
376 // Analyzes while loops to compute resourceIDs for the loop results.
377 //
378 // (1) The base case for the analysis is that if the loop body does not execute
379 //     at all, the resource IDs for each result is the same as the resource IDs
380 //     of the corresponding input.
381 // (2) If the loop does execute one or more times, then we need to account for
382 //     data flow through the body of the while loop. If result #r is the same
383 //     as arg #a of the loop body (pass through argument), then we can reason
384 //     further, else if the result is not a passthrough, we mark it as unknown.
385 // (3) For passthrough results, if result #r is the same as arg #a of the loop
386 //     body, after one iteration, result #r = arg #a, so we need to also
387 //     propagate arg #a to result #r. After another iteration, arg #a of the
388 //     loop body will be result #a of the previous iteration. So then we need
389 //     propagate from result #a to result #r. Generalizing, the resource ID
390 //     propagation (for results which are passthrough) looks like:
391 //
392 //     for r in (0, num_results) : result[r] = arg[r];
393 //     repeat till no change {
394 //       a = passthrough arg for result #r;
395 //       result[r] += result[a];
396 //     }
397 //
AnalyzeWhileLoop(Operation * while_op,const BacktrackAnalysisInfo & body_info)398 void ResourceAliasAnalysisInfo::AnalyzeWhileLoop(
399     Operation* while_op, const BacktrackAnalysisInfo& body_info) {
400   // Seed the resource ID's for the results using either the resource ID of the
401   // passthrough arg, or unknown. We need to perform further analysis if we
402   // find a passthrough arg which is not the same as corresponding the result #.
403   llvm::SmallVector<Optional<int>, 4> passthrough_args(
404       while_op->getNumResults());
405   bool need_analysis = false;
406   for (auto result : filter_resources(while_op->getResults())) {
407     int result_index = result.getResultNumber();
408     passthrough_args[result_index] = body_info.GetArg(result_index);
409     if (passthrough_args[result_index]) {
410       int passthru_index = passthrough_args[result_index].getValue();
411       PropagateInputToOutput(while_op->getOperand(passthru_index), result);
412       need_analysis |=
413           !IsUnknownResource(result) && passthru_index != result_index;
414     } else {
415       AddValueUniqueIDMapping(result, kUnknownResourceId);
416     }
417   }
418 
419   if (!need_analysis) return;
420 
421   // We found a result that is not unknown and whose passthrough operand index
422   // is not the same as the result index, which means there is "crosstalk"
423   // between 2 or more operands. In that case, we do an iterative propagation
424   // of resource ID's till the results converge.
425   bool change = true;
426   while (change) {
427     change = false;
428     for (auto result : filter_resources(while_op->getResults())) {
429       if (IsUnknownResource(result)) continue;
430       // If this result has a valid passthrough arg, propagate resource ID's
431       // from the result of the passthrough arg
432       int result_index = result.getResultNumber();
433       int passthru_index = passthrough_args[result_index].getValue();
434       change =
435           PropagateInputToOutput(while_op->getResult(passthru_index), result) ||
436           change;
437     }
438   }
439 }
440 
441 template <class CaseOrIfOp>
AnalyzeFunctionalCaseOrIfOp(CaseOrIfOp case_or_if_op,llvm::ArrayRef<FuncOp> functions,const BacktrackAnalysis & backtrack_analysis)442 void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp(
443     CaseOrIfOp case_or_if_op, llvm::ArrayRef<FuncOp> functions,
444     const BacktrackAnalysis& backtrack_analysis) {
445   llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
446   infos.reserve(functions.size());
447   for (FuncOp func : functions)
448     infos.push_back(&backtrack_analysis.GetAnalysisForFunc(func));
449 
450   // If a result is a passthrough of all branches' inputs, merge the resource
451   // IDs of corresponding operands for all the inputs.
452   for (auto result : filter_resources(case_or_if_op.getResults())) {
453     llvm::SmallVector<llvm::Optional<int>, 2> passthrough_args;
454     passthrough_args.reserve(functions.size());
455     for (const auto* info : infos)
456       passthrough_args.emplace_back(info->GetArg(result.getResultNumber()));
457 
458     const bool all_passthrough_args_known = llvm::all_of(
459         passthrough_args, [](const llvm::Optional<int>& passthrough_arg) {
460           return passthrough_arg.hasValue();
461         });
462     if (all_passthrough_args_known) {
463       for (const auto& passthrough_arg : passthrough_args) {
464         Value operand = case_or_if_op.input()[passthrough_arg.getValue()];
465         PropagateInputToOutput(operand, result);
466       }
467     } else {
468       AddValueUniqueIDMapping(result, kUnknownResourceId);
469     }
470   }
471 }
472 
AnalyzeRegionCaseOrIfOp(Operation * case_or_if_op,const BacktrackAnalysis & backtrack_analysis)473 void ResourceAliasAnalysisInfo::AnalyzeRegionCaseOrIfOp(
474     Operation* case_or_if_op, const BacktrackAnalysis& backtrack_analysis) {
475   llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
476   infos.reserve(case_or_if_op->getNumRegions());
477   for (Region& region : case_or_if_op->getRegions())
478     infos.push_back(&backtrack_analysis.GetAnalysisForRegion(region));
479 
480   // For region Case/If, the walk would have visited all branch regions before
481   // visiting the Case/If op. Backtracking of each region results will either
482   // give a value computed within these regions, or a region capture. If it is a
483   // region capture computed before this Case/If, it will have been visited
484   // earlier and a mapping would exist for that value. If it is computed within
485   // the region, then again a mapping would exist.
486   for (auto result : filter_resources(case_or_if_op->getResults())) {
487     for (const auto* info : infos) {
488       Value region_result = info->GetValue(result.getResultNumber());
489       PropagateInputToOutput(region_result, result);
490     }
491   }
492 }
493 
IsUnknownResource(Value resource) const494 bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const {
495   auto it = resource_value_to_ids_.find(resource);
496   assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
497   // The set is sorted so we only need to check the first element since
498   // kUnknownResourceId < 0.
499   static_assert(kUnknownResourceId < 0,
500                 "kUnknownResourceId should be negative");
501   return *it->getSecond().begin() == kUnknownResourceId;
502 }
503 
504 const llvm::SmallSet<int64_t, 8>&
GetResourceUniqueIds(Value resource) const505 ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
506   assert(!IsUnknownResource(resource));
507   auto it = resource_value_to_ids_.find(resource);
508   assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
509   return it->getSecond();
510 }
511 
512 const llvm::SmallSetVector<Value, 8>&
GetUniqueIdResources(const int64_t id) const513 ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
514   auto it = id_to_resource_values_.find(id);
515   assert(it != id_to_resource_values_.end() && "Unseen id was queried");
516   return it->getSecond();
517 }
518 
GetResourceAliases(Value resource) const519 llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
520     Value resource) const {
521   assert(!IsUnknownResource(resource) && "Unknown resource was queried");
522   llvm::SmallSetVector<Value, 8> aliases;
523   for (int64_t id : GetResourceUniqueIds(resource)) {
524     const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
525         GetUniqueIdResources(id);
526     aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
527   }
528   // If there are resources that were marked as unknown, they alias with all
529   // other resources.
530   auto it = id_to_resource_values_.find(kUnknownResourceId);
531   if (it != id_to_resource_values_.end())
532     aliases.insert(it->getSecond().begin(), it->getSecond().end());
533   return aliases;
534 }
535 
536 }  // namespace detail
537 
538 //===----------------------------------------------------------------------===//
539 // ResourceAliasAnalysis
540 //===----------------------------------------------------------------------===//
541 
ResourceAliasAnalysis(ModuleOp module)542 ResourceAliasAnalysis::ResourceAliasAnalysis(ModuleOp module) {
543   // Analyze all regions for backtracking info.
544   detail::BacktrackAnalysis backtrack_analysis(module);
545 
546   // Analyze each function.
547   for (auto func : module.getOps<FuncOp>())
548     this->info_map_.try_emplace(func, func, backtrack_analysis);
549 }
550 
551 }  // namespace TF
552 }  // namespace mlir
553