/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"

#include <cstdint>
#include <initializer_list>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/CallGraph.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {
namespace detail {

//===----------------------------------------------------------------------===//
// BacktrackAnalysisInfo
//===----------------------------------------------------------------------===//
// Class to hold the backtrack analysis for the results of a region. Backtrack
// analysis traces the definition of a region's return values back through
// pass-through operations, so that each return value of the region has the
// same value as its backtracked value.
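//
// Illustrative example (hand-written; `!res` is shorthand for a resource
// tensor type, not an actual type name): given a region
//
//   ^bb0(%arg0: !res):
//     %0 = "tf.Identity"(%arg0) : (!res) -> !res
//     "tf.Yield"(%0) : (!res) -> ()
//
// result #0 backtracks through the tf.Identity to %arg0, so GetValue(0)
// returns %arg0 and GetArg(0) returns 0.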
class BacktrackAnalysisInfo {
 public:
  // Initializes the backtrack analysis for the given region.
  explicit BacktrackAnalysisInfo(Region& region,
                                 detail::BacktrackAnalysis& backtrack_analysis);

  BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;

  // Returns the value to which the given result number of the region can be
  // backtracked.
  Value GetValue(int result_index) const {
    return backtracked_values_[result_index];
  }

  // Returns the argument index of the region to which the given result number
  // can be backtracked. Such results are called "function passthrough". If the
  // result cannot be backtracked to a region argument, returns llvm::None.
  llvm::Optional<int> GetArg(int result_index) const {
    if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
      if (arg.getParentBlock() == &region_->front()) return arg.getArgNumber();
    return llvm::None;
  }

 private:
  friend class detail::BacktrackAnalysis;

  // Region for which this object holds the analysis info.
  Region* region_;

  // Backtracked values indexed by the result number.
  llvm::SmallVector<Value, 4> backtracked_values_;
};

//===----------------------------------------------------------------------===//
// BacktrackAnalysis
//===----------------------------------------------------------------------===//
// Holds backtrack analysis for all functions and regions within a module.
class BacktrackAnalysis {
 public:
  using InfoT = BacktrackAnalysisInfo;

  // Constructs the analysis by analyzing the given module.
  explicit BacktrackAnalysis(ModuleOp module);

  // Returns the backtracking analysis for the given region.
  const InfoT& GetAnalysisForRegion(Region& region) const {
    auto it = info_map_.find(&region);
    assert(it != info_map_.end());
    return it->second;
  }

  // Returns the backtracking analysis for the given function.
  const InfoT& GetAnalysisForFunc(FuncOp func) const {
    return GetAnalysisForRegion(func.getBody());
  }

  // Backtracks the given value.
  Value BacktrackValue(Value value);

 private:
  // Returns the analysis for the given region (analyzing the region if it has
  // not yet been analyzed).
  const InfoT& GetOrCreateAnalysis(Region& region) {
    auto it = info_map_.find(&region);
    if (it == info_map_.end()) {
      // Note: Keep object construction and insertion separate. If we used
      // emplace() to construct and insert in a single shot, then, while this
      // region is being analyzed, calls to BacktrackValue() could insert
      // additional entries into the map and cause the underlying storage to be
      // moved. That would also move the partially constructed object we had
      // just inserted into the map and are still constructing. To avoid this
      // issue, construct the analysis object separately and then insert it
      // into the map.
      InfoT info(region, *this);
      info_map_.insert({&region, std::move(info)});
    }

    return GetAnalysisForRegion(region);
  }

  // Returns the backtrack analysis for the given region if it exists.
  // If the region has not yet been analyzed, returns llvm::None.
  Optional<const InfoT*> GetAnalysisIfExists(Region& region) const {
    auto it = info_map_.find(&region);
    if (it == info_map_.end()) return llvm::None;
    return &it->second;
  }

  Optional<const InfoT*> GetAnalysisIfExists(FuncOp func) const {
    return GetAnalysisIfExists(func.getBody());
  }

 private:
  llvm::SmallDenseMap<Region*, InfoT> info_map_;
};

// Analyzes all regions attached to all operations in the module.
BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
  const CallGraph call_graph(module);

  // Visit functions bottom-up when doing the analysis. Note that the SCC
  // iterator has the property that if there is an edge from SCC1 to SCC2, then
  // SCC1 is visited after SCC2, i.e., the graph is traversed bottom-up just
  // the way we want.
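  //
  // Illustrative example (function names are made up): if @caller calls
  // @callee and there is no cycle between them, the SCC containing @callee is
  // visited before the SCC containing @caller, so by the time @caller is
  // analyzed its calls to @callee can be backtracked through @callee's
  // analysis.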
  auto scc_begin = llvm::scc_begin(&call_graph);
  auto scc_end = llvm::scc_end(&call_graph);
  for (auto& scc : make_range(scc_begin, scc_end)) {
    // Each SCC node is a collection of callgraph nodes that form a cycle. We
    // will visit these nodes in an arbitrary order. If a node being visited
    // calls a function that has not yet been analyzed, we will not be able to
    // backtrack through that function call (our analysis will be correct but
    // pessimistic).
    for (CallGraphNode* node : scc) {
      if (node->isExternal()) continue;
      Region* region = node->getCallableRegion();
      GetOrCreateAnalysis(*region);
    }
  }

  // The call graph analysis above covers all regions attached to functions,
  // but we also need to analyze regions attached to other ops.
  module->walk([this](Operation* op) {
    if (op->hasTrait<OpTrait::NoTerminator>()) return;
    for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
  });
}

// Backtracks the definition of `value`, looking through passthrough ops.
// Always returns a non-null value; may return `value` itself if backtracking
// is not possible.
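//
// Illustrative sketch (names made up, types elided): for
//
//   %island:2 = tf_executor.island { ... tf_executor.yield %x }
//   %id = "tf.Identity"(%island#0)
//
// BacktrackValue(%id) looks through the tf.Identity and the island yield and
// continues backtracking from %x; if %x cannot be backtracked further, %x
// itself is returned.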
Value BacktrackAnalysis::BacktrackValue(Value value) {
  while (Operation* op = value.getDefiningOp()) {
    int res_index = value.cast<OpResult>().getResultNumber();
    if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
      value = graph.GetFetch().getOperand(res_index);
    } else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
      // The control output is generated by the IslandOp, not by the yield in
      // the island body.
      if (value == island.control()) break;
      value = island.GetYield().getOperand(res_index);
    } else if (isa<IdentityNOp, IdentityOp>(op)) {
      value = op->getOperand(res_index);
    } else if (auto call = dyn_cast<CallOpInterface>(op)) {
      FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
      if (!func) break;
      // Check if the function being called has been analyzed. If not, we
      // cannot backtrack the value further.
      Optional<const InfoT*> callee_info = GetAnalysisIfExists(func);
      if (!callee_info) break;
      Optional<int> passthrough_arg = callee_info.getValue()->GetArg(res_index);
      if (!passthrough_arg) break;
      value = call.getArgOperands()[passthrough_arg.getValue()];
    } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
      value = op->getRegion(0).front().getTerminator()->getOperand(res_index);
    } else {
      break;
    }
  }
  return value;
}

// Analyzes the region.
BacktrackAnalysisInfo::BacktrackAnalysisInfo(
    Region& region, detail::BacktrackAnalysis& backtrack_analysis)
    : region_(&region) {
  if (region.empty()) return;

  assert(llvm::hasSingleElement(region.getBlocks()));

  auto results = region.front().getTerminator()->getOperands();
  if (results.empty()) return;

  backtracked_values_.reserve(results.size());
  for (auto result : results)
    backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
}

//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo
//===----------------------------------------------------------------------===//

namespace {

constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";

bool IsResourceAllocatingOp(Operation* op) {
  auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!mem_interface) return false;

  for (Value value : filter_resources(op->getResults())) {
    llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
    mem_interface.getEffectsOnValue(value, effects);
    for (auto& effect_instance : effects) {
      if (isa<MemoryEffects::Allocate>(effect_instance.getEffect())) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;

void IncrementResourceTypeId(int64_t& resource_type_id) {
  if (resource_type_id == ResourceAliasAnalysisInfo::kMaxResourceTypeId) {
    // We don't expect this to happen; currently there are only 10 resource
    // types in the TF dialect. Still, it should be visible if this ever
    // happens.
    LOG(WARNING) << "reached limit for supported number of resource types ("
                 << ResourceAliasAnalysisInfo::kMaxResourceTypeId
                 << "); this could lead to overly conservative execution order";
    // Note: By not incrementing `resource_type_id` we still maintain
    // correctness; we may just treat different resource types as the same
    // type (for ID `kMaxResourceTypeId`), which is overly conservative.
  } else {
    ++resource_type_id;
  }
}

// Constructs the analysis info by analyzing the given function.
ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
    FuncOp func_op, const BacktrackAnalysis& backtrack_analysis) {
  // This function populates resource_value_to_ids_ and id_to_resource_values_.

  // See the `ResourceAliasAnalysisInfo` class for ID semantics.
  int64_t next_unique_type_id = 0;
  int64_t next_unique_instance_id = kMaxResourceTypeId + 1;

  // Helper to assign a new unique ID to all resources in the given list of
  // values.
  auto assign_unique_id_to_all = [&](ValueRange values) {
    for (Value value : filter_resources(values)) {
      AddValueUniqueIDMapping(value, next_unique_instance_id++);
    }
  };

  // Helper to assign the unknown ID to all resources in the given list of
  // values.
  auto assign_unknown_id_to_all = [&](ValueRange values) {
    for (Value value : filter_resources(values)) {
      AddValueUniqueIDMapping(value, kUnknownResourceId);
    }
  };

  // If `tf._resource_arg_unique_id` argument attributes are present for
  // resource-type arguments, use those to decide which arguments correspond to
  // the same resource (and thus need the same ID). Otherwise, they must not
  // alias.
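  //
  // Illustrative example (attribute value 42 is arbitrary, `!res` is a
  // placeholder for a resource tensor type): for
  //
  //   func @f(%a: !res {tf._resource_arg_unique_id = 42 : i64},
  //           %b: !res {tf._resource_arg_unique_id = 42 : i64})
  //
  // %a and %b share one internal instance ID and are reported as aliases;
  // without the attributes each argument would get its own unique ID.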
  const bool has_arg_unique_id_attrs =
      llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) {
        return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr);
      });
  if (has_arg_unique_id_attrs) {
    // Resource arguments have IDs attached (via `kResourceArgUniqueIdAttr`)
    // that represent different resources. Map those IDs to the internal
    // instance IDs used by this pass.
    llvm::SmallDenseMap<int64_t, int64_t> attr_id_to_internal_id;
    for (auto arg : filter_resources(func_op.getArguments())) {
      auto id_attr = func_op.getArgAttrOfType<IntegerAttr>(
          arg.getArgNumber(), kResourceArgUniqueIdAttr);
      assert(id_attr &&
             "tf.resource_arg_unique_id attribute should exist on either "
             "none or all arguments.");
      auto emplace_res = attr_id_to_internal_id.try_emplace(
          id_attr.getInt(), next_unique_instance_id);
      AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
      // Only increment the ID if it has been used.
      if (emplace_res.second) ++next_unique_instance_id;
    }
  } else {
    // No `kResourceArgUniqueIdAttr` attribute is present, so all resource
    // arguments must correspond to different resources and we can assign
    // unique IDs.
    assign_unique_id_to_all(func_op.getArguments());
  }

  // Since this analysis is neither inter-procedural nor inter-regional, each
  // region attached to an op within the function is analyzed independently.
  // Seed this analysis for each such region by mapping all of its resource
  // arguments to new unique IDs. This is required because walk() visits the
  // attached regions before visiting the op itself, so there is no opportunity
  // during the walk to seed region arguments. Also note that the walk
  // eventually visits the op on which walk() was called, so make sure we do
  // not overwrite the function argument mapping here.
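  // For example (illustrative), the block arguments of a tf.WhileRegion body
  // are given fresh unique IDs here; the op's results are related to its
  // operands only later, when the walk reaches the tf.WhileRegion op itself
  // and AnalyzeWhileLoop runs.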
  func_op.walk([&](Operation* op) {
    if (op == func_op) return;
    for (Region& region : op->getRegions()) {
      assign_unique_id_to_all(region.getArguments());
    }
  });

  llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
  func_op.walk([&](Operation* op) {
    if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
      llvm::SmallVector<ResourceHandleValueAndId, 4> resources =
          resource_alloc.GetResourceHandleValueAndIdList(
              resource_handle_id_map, next_unique_instance_id);
      for (auto& resource_handle : resources) {
        AddValueUniqueIDMapping(resource_handle.value, resource_handle.id);
      }
    } else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
      for (auto result : filter_resources(op->getResults()))
        PropagateInputToOutput(op->getOperand(result.getResultNumber()),
                               result);
    } else if (auto while_op = dyn_cast<WhileOp>(op)) {
      AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc(
                                     while_op.body_function()));
    } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
      AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion(
                                         while_region.body()));
    } else if (auto case_op = dyn_cast<CaseOp>(op)) {
      llvm::SmallVector<FuncOp, 4> functions;
      case_op.get_branch_functions(functions);
      AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
    } else if (auto if_op = dyn_cast<IfOp>(op)) {
      AnalyzeFunctionalCaseOrIfOp(
          if_op, {if_op.then_function(), if_op.else_function()},
          backtrack_analysis);
    } else if (llvm::isa<CaseRegionOp, IfRegionOp>(op)) {
      AnalyzeRegionCaseOrIfOp(op, backtrack_analysis);
    } else if (auto call = dyn_cast<CallOpInterface>(op)) {
      FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
      if (!func) {
        assign_unknown_id_to_all(op->getResults());
        return WalkResult::advance();
      }
      const auto& func_info = backtrack_analysis.GetAnalysisForFunc(func);
      for (auto result : filter_resources(op->getResults())) {
        auto passthrough_arg = func_info.GetArg(result.getResultNumber());
        if (passthrough_arg) {
          PropagateInputToOutput(
              call.getArgOperands()[passthrough_arg.getValue()], result);
        } else {
          AddValueUniqueIDMapping(result, kUnknownResourceId);
        }
      }
    } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
      Region& region = op->getRegion(0);
      const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region);
      for (auto result : filter_resources(op->getResults())) {
        Value body_result = body_info.GetValue(result.getResultNumber());
        PropagateInputToOutput(body_result, result);
      }
    } else {
      auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
      for (Value value : filter_resources(op->getResults())) {
        // Set the unknown ID first; reset it later if applicable.
        int64_t resource_id = kUnknownResourceId;

        if (mem_interface) {
          auto alloc_effect =
              mem_interface.getEffectOnValue<MemoryEffects::Allocate>(value);
          if (alloc_effect) {
            TypeID mlir_type_id =
                alloc_effect.getValue().getResource()->getResourceID();
            // Update or look up the internal type ID.
            auto emplace_result = type_id_to_internal_type_id_.try_emplace(
                mlir_type_id, next_unique_type_id);
            // Change the unknown ID to a type-based ID.
            resource_id = emplace_result.first->getSecond();
            // Only increment the ID if we have encountered a new resource
            // type.
            if (emplace_result.second)
              IncrementResourceTypeId(next_unique_type_id);
          }
        }
        AddValueUniqueIDMapping(value, resource_id);
      }
    }
    return WalkResult::advance();
  });
}

// Propagates the resource IDs from an input operand to a result. Returns true
// if the mapping changed.
bool ResourceAliasAnalysisInfo::PropagateInputToOutput(const Value& operand,
                                                       const OpResult& result) {
  auto operand_it = resource_value_to_ids_.find(operand);
  assert(operand_it != resource_value_to_ids_.end() &&
         "A resource-type output does not have the corresponding "
         "resource-type input.");
  bool change = false;
  for (int64_t id : operand_it->second)
    change = AddValueUniqueIDMapping(result, id) || change;
  return change;
}

// Analyzes while loops to compute resource IDs for the loop results.
//
// (1) The base case for the analysis is that if the loop body does not execute
//     at all, the resource IDs for each result are the same as the resource
//     IDs of the corresponding input.
// (2) If the loop does execute one or more times, then we need to account for
//     data flow through the body of the while loop. If result #r is the same
//     as arg #a of the loop body (a pass-through argument), then we can reason
//     further; otherwise we mark the result as unknown.
// (3) For passthrough results, if result #r is the same as arg #a of the loop
//     body, then after one iteration result #r = arg #a, so we also need to
//     propagate arg #a to result #r. After another iteration, arg #a of the
//     loop body will be result #a of the previous iteration, so then we need
//     to propagate from result #a to result #r. Generalizing, the resource ID
//     propagation (for results which are passthrough) looks like:
//
//     for r in (0, num_results) : result[r] = arg[r];
//     repeat till no change {
//       for r in (0, num_results) {
//         a = passthrough arg for result #r;
//         result[r] += result[a];
//       }
//     }
//
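// Illustrative example (hand-constructed): suppose a while body returns
// (arg #1, arg #0, arg #2), i.e. results #0 and #1 swap the first two args
// and result #2 passes arg #2 through. Seeding assigns result #0 the IDs of
// operand #1, result #1 the IDs of operand #0, and result #2 the IDs of
// operand #2. Because result #0's passthrough arg index (1) differs from its
// result index (0), the iterative propagation below then adds result #1's IDs
// into result #0 and vice versa, so results #0 and #1 both converge to the
// union of the IDs of operands #0 and #1.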
void ResourceAliasAnalysisInfo::AnalyzeWhileLoop(
    Operation* while_op, const BacktrackAnalysisInfo& body_info) {
  // Seed the resource IDs for the results using either the resource ID of the
  // passthrough arg, or unknown. We need to perform further analysis if we
  // find a passthrough arg which is not the same as the corresponding result
  // index.
  llvm::SmallVector<Optional<int>, 4> passthrough_args(
      while_op->getNumResults());
  bool need_analysis = false;
  for (auto result : filter_resources(while_op->getResults())) {
    int result_index = result.getResultNumber();
    passthrough_args[result_index] = body_info.GetArg(result_index);
    if (passthrough_args[result_index]) {
      int passthru_index = passthrough_args[result_index].getValue();
      PropagateInputToOutput(while_op->getOperand(passthru_index), result);
      need_analysis |=
          !IsUnknownResource(result) && passthru_index != result_index;
    } else {
      AddValueUniqueIDMapping(result, kUnknownResourceId);
    }
  }

  if (!need_analysis) return;

  // We found a result that is not unknown and whose passthrough operand index
  // is not the same as the result index, which means there is "crosstalk"
  // between 2 or more operands. In that case, we do an iterative propagation
  // of resource IDs until the results converge.
  bool change = true;
  while (change) {
    change = false;
    for (auto result : filter_resources(while_op->getResults())) {
      if (IsUnknownResource(result)) continue;
      // If this result has a valid passthrough arg, propagate resource IDs
      // from the result of the passthrough arg.
      int result_index = result.getResultNumber();
      int passthru_index = passthrough_args[result_index].getValue();
      change =
          PropagateInputToOutput(while_op->getResult(passthru_index), result) ||
          change;
    }
  }
}

template <class CaseOrIfOp>
void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp(
    CaseOrIfOp case_or_if_op, llvm::ArrayRef<FuncOp> functions,
    const BacktrackAnalysis& backtrack_analysis) {
  llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
  infos.reserve(functions.size());
  for (FuncOp func : functions)
    infos.push_back(&backtrack_analysis.GetAnalysisForFunc(func));

  // If a result is a passthrough of all branches' inputs, merge the resource
  // IDs of the corresponding operands for all the branches.
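  // Illustrative example: for a tf.If whose then-branch returns its arg #0
  // and whose else-branch returns its arg #1, the corresponding result
  // receives the union of the IDs of the op's input operands #0 and #1. If
  // any branch result is not a passthrough of a branch argument, the result
  // is marked unknown instead.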
  for (auto result : filter_resources(case_or_if_op.getResults())) {
    llvm::SmallVector<llvm::Optional<int>, 2> passthrough_args;
    passthrough_args.reserve(functions.size());
    for (const auto* info : infos)
      passthrough_args.emplace_back(info->GetArg(result.getResultNumber()));

    const bool all_passthrough_args_known = llvm::all_of(
        passthrough_args, [](const llvm::Optional<int>& passthrough_arg) {
          return passthrough_arg.hasValue();
        });
    if (all_passthrough_args_known) {
      for (const auto& passthrough_arg : passthrough_args) {
        Value operand = case_or_if_op.input()[passthrough_arg.getValue()];
        PropagateInputToOutput(operand, result);
      }
    } else {
      AddValueUniqueIDMapping(result, kUnknownResourceId);
    }
  }
}

void ResourceAliasAnalysisInfo::AnalyzeRegionCaseOrIfOp(
    Operation* case_or_if_op, const BacktrackAnalysis& backtrack_analysis) {
  llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
  infos.reserve(case_or_if_op->getNumRegions());
  for (Region& region : case_or_if_op->getRegions())
    infos.push_back(&backtrack_analysis.GetAnalysisForRegion(region));

  // For region-based Case/If ops, the walk will have visited all branch
  // regions before visiting the Case/If op itself. Backtracking each region's
  // results will yield either a value computed within that region or a value
  // captured from the enclosing region. If it is a capture computed before
  // this Case/If, it will have been visited earlier and a mapping will already
  // exist for that value. If it is computed within the region, then again a
  // mapping will already exist.
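  // Illustrative example (types and yields elided): for
  //
  //   %r = "tf.IfRegion"(%cond) ({ "tf.Yield"(%v) }, { "tf.Yield"(%w) })
  //
  // where %v and %w are captured from the enclosing region, result %r
  // receives the union of the IDs already recorded for %v and %w.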
  for (auto result : filter_resources(case_or_if_op->getResults())) {
    for (const auto* info : infos) {
      Value region_result = info->GetValue(result.getResultNumber());
      PropagateInputToOutput(region_result, result);
    }
  }
}

bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const {
  auto it = resource_value_to_ids_.find(resource);
  assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
  // The set is sorted, so we only need to check the first element since
  // kUnknownResourceId < 0.
  static_assert(kUnknownResourceId < 0,
                "kUnknownResourceId should be negative");
  return *it->getSecond().begin() == kUnknownResourceId;
}

const llvm::SmallSet<int64_t, 8>&
ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
  assert(!IsUnknownResource(resource));
  auto it = resource_value_to_ids_.find(resource);
  assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
  return it->getSecond();
}

const llvm::SmallSetVector<Value, 8>&
ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
  auto it = id_to_resource_values_.find(id);
  assert(it != id_to_resource_values_.end() && "Unseen id was queried");
  return it->getSecond();
}

llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
    Value resource) const {
  assert(!IsUnknownResource(resource) && "Unknown resource was queried");
  llvm::SmallSetVector<Value, 8> aliases;
  for (int64_t id : GetResourceUniqueIds(resource)) {
    const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
        GetUniqueIdResources(id);
    aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
  }
  // If there are resources that were marked as unknown, they alias with all
  // other resources.
  auto it = id_to_resource_values_.find(kUnknownResourceId);
  if (it != id_to_resource_values_.end())
    aliases.insert(it->getSecond().begin(), it->getSecond().end());
  return aliases;
}

}  // namespace detail

//===----------------------------------------------------------------------===//
// ResourceAliasAnalysis
//===----------------------------------------------------------------------===//

ResourceAliasAnalysis::ResourceAliasAnalysis(ModuleOp module) {
  // Analyze all regions for backtracking info.
  detail::BacktrackAnalysis backtrack_analysis(module);

  // Analyze each function.
  for (auto func : module.getOps<FuncOp>())
    this->info_map_.try_emplace(func, func, backtrack_analysis);
}

}  // namespace TF
}  // namespace mlir