1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
17
18 #include <cstdint>
19 #include <initializer_list>
20 #include <utility>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/Optional.h"
25 #include "llvm/ADT/SCCIterator.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/iterator_range.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Analysis/CallGraph.h" // from @llvm-project
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/Block.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/Value.h" // from @llvm-project
37 #include "mlir/IR/Visitors.h" // from @llvm-project
38 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
39 #include "mlir/Support/LLVM.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45
46 namespace mlir {
47 namespace TF {
48 namespace detail {
49
50 //===----------------------------------------------------------------------===//
51 // BacktrackAnalysisInfo
52 //===----------------------------------------------------------------------===//
53 // Class to hold backtrack analysis for a results of a region. Backtrack
54 // analysis will trace back the definition of return values of regions through
55 // pass-through operations, so that the return value of the region will have the
56 // same value as the backtracked value.
57 class BacktrackAnalysisInfo {
58 public:
59 // Initializes the backtrack analysis for the given region.
60 explicit BacktrackAnalysisInfo(Region& region,
61 detail::BacktrackAnalysis& backtrack_analysis);
62
63 BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;
64
65 // Returns the value to which the given result number of the region can be
66 // backtracked to.
GetValue(int result_index) const67 Value GetValue(int result_index) const {
68 return backtracked_values_[result_index];
69 }
70
71 // Returns the argument index of the region to which the given result number
72 // can backtracked to. Such results will be called "function passthrough". If
73 // the result cannot be backtracked to a region argument, returns llvm::None.
GetArg(int result_index) const74 llvm::Optional<int> GetArg(int result_index) const {
75 if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
76 if (arg.getParentBlock() == ®ion_->front()) return arg.getArgNumber();
77 return llvm::None;
78 }
79
80 private:
81 friend class detail::BacktrackAnalysis;
82
83 // Region for which this object holds the analysis info.
84 Region* region_;
85
86 // Backtracked values indexed by the result number.
87 llvm::SmallVector<Value, 4> backtracked_values_;
88 };
89
90 //===----------------------------------------------------------------------===//
91 // BacktrackAnalysis
92 //===----------------------------------------------------------------------===//
93 // Holds backtrack analysis for all functions and regions within a module.
94 class BacktrackAnalysis {
95 public:
96 using InfoT = BacktrackAnalysisInfo;
97
98 // Constructs the analysis by analyzing the given module.
99 explicit BacktrackAnalysis(ModuleOp module);
100
101 // Returns backtracking analysis for the given region.
GetAnalysisForRegion(Region & region) const102 const InfoT& GetAnalysisForRegion(Region& region) const {
103 auto it = info_map_.find(®ion);
104 assert(it != info_map_.end());
105 return it->second;
106 }
107
108 // Returns backtracking analysis for the given function.
GetAnalysisForFunc(FuncOp func) const109 const InfoT& GetAnalysisForFunc(FuncOp func) const {
110 return GetAnalysisForRegion(func.getBody());
111 }
112
113 // Backtracks the given value.
114 Value BacktrackValue(Value value);
115
116 private:
117 // Returns the analysis for the given region (analyzing the region if it has
118 // not yet been analyzed).
GetOrCreateAnalysis(Region & region)119 const InfoT& GetOrCreateAnalysis(Region& region) {
120 auto it = info_map_.find(®ion);
121 if (it == info_map_.end()) {
122 // Note: Keep object construction and insertion separate. If we use
123 // emplace() to construct and insert in a single shot, when analyzing
124 // this region, calls to BacktrackValue() may end up inserting additional
125 // entries in the map, causing the underlying storage to be moved. This
126 // would also include this pertially constructed object that we have just
127 // inserted into the map and are constructing it. To avoid this issue,
128 // construct the analysis object separately and then insert it into the
129 // map.
130 InfoT info(region, *this);
131 info_map_.insert({®ion, std::move(info)});
132 }
133
134 return GetAnalysisForRegion(region);
135 }
136
137 // Returns the backtrack analysis for the given region if it exists.
138 // If the region has not yet been analyzed, returns llvm::None.
GetAnalysisIfExists(Region & region) const139 Optional<const InfoT*> GetAnalysisIfExists(Region& region) const {
140 auto it = info_map_.find(®ion);
141 if (it == info_map_.end()) return llvm::None;
142 return &it->second;
143 }
144
GetAnalysisIfExists(FuncOp func) const145 Optional<const InfoT*> GetAnalysisIfExists(FuncOp func) const {
146 return GetAnalysisIfExists(func.getBody());
147 }
148
149 private:
150 llvm::SmallDenseMap<Region*, InfoT> info_map_;
151 };
152
153 // Analyzes all regions attached to all operations in the module.
BacktrackAnalysis(ModuleOp module)154 BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
155 const CallGraph call_graph(module);
156
157 // Visit functions bottom up when doing the analysis. Note that SCC iterator
158 // has the property that if there is an edge from SCC1->SCC2, SCC1 is visited
159 // after SCC2, i.e., the graph is traversed bottom up just the way we want.
160 auto scc_begin = llvm::scc_begin(&call_graph);
161 auto scc_end = llvm::scc_end(&call_graph);
162 for (auto& scc : make_range(scc_begin, scc_end)) {
163 // Each SCC node is a collection of callgraph nodes that form a cycle. We
164 // will visit these nodes in an arbitrary order. If a node being visited
165 // calls a function that has not yet been analyzed, we will not be able to
166 // backtrack through that function call (our analysis will be correct but
167 // pessimistic).
168 for (CallGraphNode* node : scc) {
169 if (node->isExternal()) continue;
170 Region* region = node->getCallableRegion();
171 GetOrCreateAnalysis(*region);
172 }
173 }
174
175 // This above call graph analysis will cover all regions attached to functions
176 // but we also need to analyze regions attached to other ops.
177 module->walk([this](Operation* op) {
178 if (op->hasTrait<OpTrait::NoTerminator>()) return;
179 for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
180 });
181 }
182
183 // Backtracks the definition of `value` looking through passthrough ops.
184 // Returns a non-null value and can return `value` if backtracking is not
185 // possible.
BacktrackValue(Value value)186 Value BacktrackAnalysis::BacktrackValue(Value value) {
187 while (Operation* op = value.getDefiningOp()) {
188 int res_index = value.cast<OpResult>().getResultNumber();
189 if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
190 value = graph.GetFetch().getOperand(res_index);
191 } else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
192 // Control output is generated by the IslandOp, not the yield in
193 // in the Island body.
194 if (value == island.control()) break;
195 value = island.GetYield().getOperand(res_index);
196 } else if (isa<IdentityNOp, IdentityOp>(op)) {
197 value = op->getOperand(res_index);
198 } else if (auto call = dyn_cast<CallOpInterface>(op)) {
199 FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
200 if (!func) break;
201 // Check if the function being called has been analyzed. if not,
202 // we cannot backtrack the value further.
203 Optional<const InfoT*> callee_info = GetAnalysisIfExists(func);
204 if (!callee_info) break;
205 Optional<int> passthrough_arg = callee_info.getValue()->GetArg(res_index);
206 if (!passthrough_arg) break;
207 value = call.getArgOperands()[passthrough_arg.getValue()];
208 } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
209 value = op->getRegion(0).front().getTerminator()->getOperand(res_index);
210 } else {
211 break;
212 }
213 }
214 return value;
215 }
216
217 // Analyze the region.
BacktrackAnalysisInfo(Region & region,detail::BacktrackAnalysis & backtrack_analysis)218 BacktrackAnalysisInfo::BacktrackAnalysisInfo(
219 Region& region, detail::BacktrackAnalysis& backtrack_analysis)
220 : region_(®ion) {
221 if (region.empty()) return;
222
223 assert(llvm::hasSingleElement(region.getBlocks()));
224
225 auto results = region.front().getTerminator()->getOperands();
226 if (results.empty()) return;
227
228 backtracked_values_.reserve(results.size());
229 for (auto result : results)
230 backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
231 }
232
233 //===----------------------------------------------------------------------===//
234 // ResourceAliasAnalysisInfo
235 //===----------------------------------------------------------------------===//
236
237 namespace {
238
239 constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
240
IsResourceAllocatingOp(Operation * op)241 bool IsResourceAllocatingOp(Operation* op) {
242 auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
243 if (!mem_interface) return false;
244
245 for (Value value : filter_resources(op->getResults())) {
246 llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
247 mem_interface.getEffectsOnValue(value, effects);
248 for (auto& effect_instance : effects) {
249 if (isa<MemoryEffects::Allocate>(effect_instance.getEffect())) {
250 return true;
251 }
252 }
253 }
254 return false;
255 }
256
257 } // namespace
258
259 constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;
260
IncrementResourceTypeId(int64_t & resource_type_id)261 void IncrementResourceTypeId(int64_t& resource_type_id) {
262 if (resource_type_id == ResourceAliasAnalysisInfo::kMaxResourceTypeId) {
263 // We don't expect this to happen, currently there are 10 resource types in
264 // TF dialect. Still, it should be visible if this ever happens.
265 LOG(WARNING) << "reached limit for supported number of resource types ("
266 << ResourceAliasAnalysisInfo::kMaxResourceTypeId
267 << "); this could lead to overly conservative execution order";
268 // Note: By not incrementing `resource_type_id` we still maintain
269 // correctness, we might only handle different resource types as the same
270 // type (for ID `kMaxResourceTypeId`) which is overly conservative.
271 } else {
272 ++resource_type_id;
273 }
274 }
275
276 // Constructs the analysis info by analyzing the given function.
ResourceAliasAnalysisInfo(FuncOp func_op,const BacktrackAnalysis & backtrack_analysis)277 ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
278 FuncOp func_op, const BacktrackAnalysis& backtrack_analysis) {
279 // This function populates resource_value_to_ids_ and id_to_resource_values_.
280
281 // See `ResourceAliasAnalysisInfo` class for ID semantics.
282 int64_t next_unique_type_id = 0;
283 int64_t next_unique_instance_id = kMaxResourceTypeId + 1;
284
285 // Helper to assign new unique id for all resources in the given list of
286 // values.
287 auto assign_unique_id_to_all = [&](ValueRange values) {
288 for (Value value : filter_resources(values)) {
289 AddValueUniqueIDMapping(value, next_unique_instance_id++);
290 }
291 };
292
293 // Helper to assign new unknown id for all resources in the given list of
294 // values.
295 auto assign_unknown_id_to_all = [&](ValueRange values) {
296 for (Value value : filter_resources(values)) {
297 AddValueUniqueIDMapping(value, kUnknownResourceId);
298 }
299 };
300
301 // If `tf.resource_arg_unique_id` argument attributes are present for
302 // resource-type arguments, use those to decide which arguments correspond to
303 // the same resource (and thus need the same ID). Otherwise, they must not
304 // alias.
305 const bool has_arg_unique_id_attrs =
306 llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) {
307 return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr);
308 });
309 if (has_arg_unique_id_attrs) {
310 // Resource arguments have ID's attached (via `kResourceArgUniqueIdAttr`)
311 // that represent different resources. Map those ID's to the internal
312 // instance ID's used by this pass.
313 llvm::SmallDenseMap<int64_t, int64_t> attr_id_to_internal_id;
314 for (auto arg : filter_resources(func_op.getArguments())) {
315 auto id_attr = func_op.getArgAttrOfType<IntegerAttr>(
316 arg.getArgNumber(), kResourceArgUniqueIdAttr);
317 assert(id_attr &&
318 "tf.resource_arg_unique_id attribute should exist on either "
319 "none or all arguments.");
320 auto emplace_res = attr_id_to_internal_id.try_emplace(
321 id_attr.getInt(), next_unique_instance_id);
322 AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
323 // Only increment ID if it has been used.
324 if (emplace_res.second) ++next_unique_instance_id;
325 }
326 } else {
327 // No `kResourceArgUniqueIdAttr` attribute is present, so all resource
328 // arguments must correspond to different resources and we can assign unique
329 // ID's.
330 assign_unique_id_to_all(func_op.getArguments());
331 }
332
333 // Since this analysis is neither inter-procedural nor inter-regional,
334 // each region attached to Op's within a function is analyzed independently.
335 // Seed this analysis for each such region by mapping all resource arguments
336 // for such regions to a new unique-id. This is required because walk() walks
337 // the attached regions first before visiting the op, so there is no
338 // opportunity during the walk to seed region arguments. Also note that walk
339 // eventually also visits the Op on which the walk() is called, so make sure
340 // we do not overwrite the function argument mapping here.
341 func_op.walk([&](Operation* op) {
342 if (op == func_op) return;
343 for (Region& region : op->getRegions()) {
344 assign_unique_id_to_all(region.getArguments());
345 }
346 });
347
348 llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
349 func_op.walk([&](Operation* op) {
350 if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
351 llvm::SmallVector<ResourceHandleValueAndId, 4> resources =
352 resource_alloc.GetResourceHandleValueAndIdList(
353 resource_handle_id_map, next_unique_instance_id);
354 for (auto& resource_handle : resources) {
355 AddValueUniqueIDMapping(resource_handle.value, resource_handle.id);
356 }
357 } else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
358 for (auto result : filter_resources(op->getResults()))
359 PropagateInputToOutput(op->getOperand(result.getResultNumber()),
360 result);
361 } else if (auto while_op = dyn_cast<WhileOp>(op)) {
362 AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc(
363 while_op.body_function()));
364 } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
365 AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion(
366 while_region.body()));
367 } else if (auto case_op = dyn_cast<CaseOp>(op)) {
368 llvm::SmallVector<FuncOp, 4> functions;
369 case_op.get_branch_functions(functions);
370 AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
371 } else if (auto if_op = dyn_cast<IfOp>(op)) {
372 AnalyzeFunctionalCaseOrIfOp(
373 if_op, {if_op.then_function(), if_op.else_function()},
374 backtrack_analysis);
375 } else if (llvm::isa<CaseRegionOp, IfRegionOp>(op)) {
376 AnalyzeRegionCaseOrIfOp(op, backtrack_analysis);
377 } else if (auto call = dyn_cast<CallOpInterface>(op)) {
378 FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
379 if (!func) {
380 assign_unknown_id_to_all(op->getResults());
381 return WalkResult::advance();
382 }
383 const auto& func_info = backtrack_analysis.GetAnalysisForFunc(func);
384 for (auto result : filter_resources(op->getResults())) {
385 auto passthrough_arg = func_info.GetArg(result.getResultNumber());
386 if (passthrough_arg) {
387 PropagateInputToOutput(
388 call.getArgOperands()[passthrough_arg.getValue()], result);
389 } else {
390 AddValueUniqueIDMapping(result, kUnknownResourceId);
391 }
392 }
393 } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
394 Region& region = op->getRegion(0);
395 const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region);
396 for (auto result : filter_resources(op->getResults())) {
397 Value body_result = body_info.GetValue(result.getResultNumber());
398 PropagateInputToOutput(body_result, result);
399 }
400 } else {
401 auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
402 for (Value value : filter_resources(op->getResults())) {
403 // Set unknown ID first, reset later if applicable.
404 int64_t resource_id = kUnknownResourceId;
405
406 if (mem_interface) {
407 auto alloc_effect =
408 mem_interface.getEffectOnValue<MemoryEffects::Allocate>(value);
409 if (alloc_effect) {
410 TypeID mlir_type_id =
411 alloc_effect.getValue().getResource()->getResourceID();
412 // Update or lookup internal type ID.
413 auto emplace_result = type_id_to_internal_type_id_.try_emplace(
414 mlir_type_id, next_unique_type_id);
415 // Change unknown ID to type-based ID.
416 resource_id = emplace_result.first->getSecond();
417 // Only increment ID if we have encountered a new resource type.
418 if (emplace_result.second)
419 IncrementResourceTypeId(next_unique_type_id);
420 }
421 }
422 AddValueUniqueIDMapping(value, resource_id);
423 }
424 }
425 return WalkResult::advance();
426 });
427 }
428
429 // Propagates the resource ID's from an input operand to a result. Returns true
430 // if the mapping changed.
PropagateInputToOutput(const Value & operand,const OpResult & result)431 bool ResourceAliasAnalysisInfo::PropagateInputToOutput(const Value& operand,
432 const OpResult& result) {
433 auto operand_it = resource_value_to_ids_.find(operand);
434 assert(operand_it != resource_value_to_ids_.end() &&
435 "A resource-type output does not have the corresponding "
436 "resource-type input.");
437 bool change = false;
438 for (int64_t id : operand_it->second)
439 change = AddValueUniqueIDMapping(result, id) || change;
440 return change;
441 }
442
443 // Analyzes while loops to compute resourceIDs for the loop results.
444 //
445 // (1) The base case for the analysis is that if the loop body does not execute
446 // at all, the resource IDs for each result is the same as the resource IDs
447 // of the corresponding input.
448 // (2) If the loop does execute one or more times, then we need to account for
449 // data flow through the body of the while loop. If result #r is the same
450 // as arg #a of the loop body (pass through argument), then we can reason
451 // further, else if the result is not a passthrough, we mark it as unknown.
452 // (3) For passthrough results, if result #r is the same as arg #a of the loop
453 // body, after one iteration, result #r = arg #a, so we need to also
454 // propagate arg #a to result #r. After another iteration, arg #a of the
455 // loop body will be result #a of the previous iteration. So then we need
456 // propagate from result #a to result #r. Generalizing, the resource ID
457 // propagation (for results which are passthrough) looks like:
458 //
459 // for r in (0, num_results) : result[r] = arg[r];
460 // repeat till no change {
461 // a = passthrough arg for result #r;
462 // result[r] += result[a];
463 // }
464 //
AnalyzeWhileLoop(Operation * while_op,const BacktrackAnalysisInfo & body_info)465 void ResourceAliasAnalysisInfo::AnalyzeWhileLoop(
466 Operation* while_op, const BacktrackAnalysisInfo& body_info) {
467 // Seed the resource ID's for the results using either the resource ID of the
468 // passthrough arg, or unknown. We need to perform further analysis if we
469 // find a passthrough arg which is not the same as corresponding the result #.
470 llvm::SmallVector<Optional<int>, 4> passthrough_args(
471 while_op->getNumResults());
472 bool need_analysis = false;
473 for (auto result : filter_resources(while_op->getResults())) {
474 int result_index = result.getResultNumber();
475 passthrough_args[result_index] = body_info.GetArg(result_index);
476 if (passthrough_args[result_index]) {
477 int passthru_index = passthrough_args[result_index].getValue();
478 PropagateInputToOutput(while_op->getOperand(passthru_index), result);
479 need_analysis |=
480 !IsUnknownResource(result) && passthru_index != result_index;
481 } else {
482 AddValueUniqueIDMapping(result, kUnknownResourceId);
483 }
484 }
485
486 if (!need_analysis) return;
487
488 // We found a result that is not unknown and whose passthrough operand index
489 // is not the same as the result index, which means there is "crosstalk"
490 // between 2 or more operands. In that case, we do an iterative propagation
491 // of resource ID's till the results converge.
492 bool change = true;
493 while (change) {
494 change = false;
495 for (auto result : filter_resources(while_op->getResults())) {
496 if (IsUnknownResource(result)) continue;
497 // If this result has a valid passthrough arg, propagate resource ID's
498 // from the result of the passthrough arg
499 int result_index = result.getResultNumber();
500 int passthru_index = passthrough_args[result_index].getValue();
501 change =
502 PropagateInputToOutput(while_op->getResult(passthru_index), result) ||
503 change;
504 }
505 }
506 }
507
508 template <class CaseOrIfOp>
AnalyzeFunctionalCaseOrIfOp(CaseOrIfOp case_or_if_op,llvm::ArrayRef<FuncOp> functions,const BacktrackAnalysis & backtrack_analysis)509 void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp(
510 CaseOrIfOp case_or_if_op, llvm::ArrayRef<FuncOp> functions,
511 const BacktrackAnalysis& backtrack_analysis) {
512 llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
513 infos.reserve(functions.size());
514 for (FuncOp func : functions)
515 infos.push_back(&backtrack_analysis.GetAnalysisForFunc(func));
516
517 // If a result is a passthrough of all branches' inputs, merge the resource
518 // IDs of corresponding operands for all the inputs.
519 for (auto result : filter_resources(case_or_if_op.getResults())) {
520 llvm::SmallVector<llvm::Optional<int>, 2> passthrough_args;
521 passthrough_args.reserve(functions.size());
522 for (const auto* info : infos)
523 passthrough_args.emplace_back(info->GetArg(result.getResultNumber()));
524
525 const bool all_passthrough_args_known = llvm::all_of(
526 passthrough_args, [](const llvm::Optional<int>& passthrough_arg) {
527 return passthrough_arg.hasValue();
528 });
529 if (all_passthrough_args_known) {
530 for (const auto& passthrough_arg : passthrough_args) {
531 Value operand = case_or_if_op.input()[passthrough_arg.getValue()];
532 PropagateInputToOutput(operand, result);
533 }
534 } else {
535 AddValueUniqueIDMapping(result, kUnknownResourceId);
536 }
537 }
538 }
539
AnalyzeRegionCaseOrIfOp(Operation * case_or_if_op,const BacktrackAnalysis & backtrack_analysis)540 void ResourceAliasAnalysisInfo::AnalyzeRegionCaseOrIfOp(
541 Operation* case_or_if_op, const BacktrackAnalysis& backtrack_analysis) {
542 llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
543 infos.reserve(case_or_if_op->getNumRegions());
544 for (Region& region : case_or_if_op->getRegions())
545 infos.push_back(&backtrack_analysis.GetAnalysisForRegion(region));
546
547 // For region Case/If, the walk would have visited all branch regions before
548 // visiting the Case/If op. Backtracking of each region results will either
549 // give a value computed within these regions, or a region capture. If it is a
550 // region capture computed before this Case/If, it will have been visited
551 // earlier and a mapping would exist for that value. If it is computed within
552 // the region, then again a mapping would exist.
553 for (auto result : filter_resources(case_or_if_op->getResults())) {
554 for (const auto* info : infos) {
555 Value region_result = info->GetValue(result.getResultNumber());
556 PropagateInputToOutput(region_result, result);
557 }
558 }
559 }
560
IsUnknownResource(Value resource) const561 bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const {
562 auto it = resource_value_to_ids_.find(resource);
563 assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
564 // The set is sorted so we only need to check the first element since
565 // kUnknownResourceId < 0.
566 static_assert(kUnknownResourceId < 0,
567 "kUnknownResourceId should be negative");
568 return *it->getSecond().begin() == kUnknownResourceId;
569 }
570
571 const llvm::SmallSet<int64_t, 8>&
GetResourceUniqueIds(Value resource) const572 ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
573 assert(!IsUnknownResource(resource));
574 auto it = resource_value_to_ids_.find(resource);
575 assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
576 return it->getSecond();
577 }
578
579 const llvm::SmallSetVector<Value, 8>&
GetUniqueIdResources(const int64_t id) const580 ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
581 auto it = id_to_resource_values_.find(id);
582 assert(it != id_to_resource_values_.end() && "Unseen id was queried");
583 return it->getSecond();
584 }
585
GetResourceAliases(Value resource) const586 llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
587 Value resource) const {
588 assert(!IsUnknownResource(resource) && "Unknown resource was queried");
589 llvm::SmallSetVector<Value, 8> aliases;
590 for (int64_t id : GetResourceUniqueIds(resource)) {
591 const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
592 GetUniqueIdResources(id);
593 aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
594 }
595 // If there are resources that were marked as unknown, they alias with all
596 // other resources.
597 auto it = id_to_resource_values_.find(kUnknownResourceId);
598 if (it != id_to_resource_values_.end())
599 aliases.insert(it->getSecond().begin(), it->getSecond().end());
600 return aliases;
601 }
602
603 } // namespace detail
604
605 //===----------------------------------------------------------------------===//
606 // ResourceAliasAnalysis
607 //===----------------------------------------------------------------------===//
608
ResourceAliasAnalysis(ModuleOp module)609 ResourceAliasAnalysis::ResourceAliasAnalysis(ModuleOp module) {
610 // Analyze all regions for backtracking info.
611 detail::BacktrackAnalysis backtrack_analysis(module);
612
613 // Analyze each function.
614 for (auto func : module.getOps<FuncOp>())
615 this->info_map_.try_emplace(func, func, backtrack_analysis);
616 }
617
618 } // namespace TF
619 } // namespace mlir
620