• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
17 
18 #include <bitset>
19 #include <string>
20 
21 #include "absl/container/node_hash_map.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/iterator_range.h"
27 #include "llvm/Support/Casting.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/Block.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/Operation.h"  // from @llvm-project
34 #include "mlir/IR/Value.h"  // from @llvm-project
35 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
36 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
44 
45 namespace mlir {
46 namespace TF {
47 namespace {
48 
49 constexpr ResourceId kUnknownResourceId =
50     ResourceAliasAnalysis::Info::kUnknownResourceId;
51 static_assert(kUnknownResourceId < 0, "kUnknownResourceId must be < 0");
52 
53 // A collection of Resource IDs. Note that `kUnknownResourceId` is smaller than
54 // all other resource IDs which are nonnegative (see check above) so it will
55 // always be the first element of a `ResourceIdSet` (we make use of this).
56 using ResourceIdSet = llvm::SmallSet<ResourceId, 8>;
57 
58 // Note that we cannot simply define a `static const llvm::SmallSet` here
59 // because of missing `initializer_list` support for `llvm::SmallSet`.
UnknownResourceSet()60 const ResourceIdSet& UnknownResourceSet() {
61   // clang-format off
62   static auto* id_set = new ResourceIdSet();
63   id_set->insert(kUnknownResourceId);
64   return *id_set;
65 }
66 
67 // Helper function to avoid frequent checks for unknown IDs.
GetResourceUniqueIdsOrUnknown(Value value,const ResourceAliasAnalysis::Info & alias_analysis)68 const ResourceIdSet& GetResourceUniqueIdsOrUnknown(
69     Value value,
70     const ResourceAliasAnalysis::Info& alias_analysis) {
71   if (!getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>() ||
72       alias_analysis.IsUnknownResource(value)) return UnknownResourceSet();
73   return alias_analysis.GetResourceUniqueIds(value);
74 }
75 
76 // Helper class for a collection of side effects for one resource.
77 class SideEffects {
78   enum Type {
79     kAlloc = 0,
80     kFree = 1,
81     kRead = 2,
82     kWrite = 3
83   };
84 
85  public:
IsAlloc() const86   bool IsAlloc() const { return effects_.test(kAlloc); }
IsFree() const87   bool IsFree() const { return effects_.test(kFree); }
IsRead() const88   bool IsRead() const { return effects_.test(kRead); }
IsWrite() const89   bool IsWrite() const { return effects_.test(kWrite); }
IsAllocOnly() const90   bool IsAllocOnly() const { return IsAlloc() && effects_.count() == 1; }
IsReadOnly() const91   bool IsReadOnly() const { return IsRead() && effects_.count() == 1; }
GetResourceId() const92   ResourceId GetResourceId() const { return resource_id_; }
93 
SetAlloc()94   void SetAlloc() { effects_.set(kAlloc); }
SetFree()95   void SetFree() { effects_.set(kFree); }
SetRead()96   void SetRead() { effects_.set(kRead); }
SetWrite()97   void SetWrite() { effects_.set(kWrite); }
SetUnknownEffect()98   void SetUnknownEffect() { effects_.set(); }
SetResourceId(ResourceId resource_id)99   void SetResourceId(ResourceId resource_id) { resource_id_ = resource_id; }
AddEffects(const SideEffects & other_effects)100   void AddEffects(const SideEffects& other_effects) {
101     effects_ |= other_effects.effects_;
102   }
103 
104  private:
105   std::bitset<4> effects_ = 0;
106   ResourceId resource_id_ = kUnknownResourceId;
107 };
108 
109 // We use `std::map` here because we rely on the order of elements.
110 using SideEffectsByResourceId = std::map<ResourceId, SideEffects>;
111 
112 // We use `std::unordered_map` here for pointer stability reasons.
113 // Note: If memory usage ever becomes a bottleneck here (not expected) we could
114 // use a Trie-like data structure to avoid storing side effects in both parent
115 // op and all its child ops (recursively), at the expense of lookup time.
116 using OpSideEffectMap = std::unordered_map<Operation*, SideEffectsByResourceId>;
117 
118 // Update `side_effects_by_resource_id` with `side_effects`.
UpdateSideEffectsByResourceId(const SideEffects & side_effects,SideEffectsByResourceId & side_effects_by_resource_id)119 void UpdateSideEffectsByResourceId(
120     const SideEffects& side_effects,
121     SideEffectsByResourceId& side_effects_by_resource_id) {
122   ResourceId id = side_effects.GetResourceId();
123   auto iter = side_effects_by_resource_id.find(id);
124   if (iter == side_effects_by_resource_id.end()) {
125     side_effects_by_resource_id[id] = side_effects;
126   } else {
127     iter->second.AddEffects(side_effects);
128   }
129 }
130 
MayHaveSideEffect(Operation * op)131 bool MayHaveSideEffect(Operation* op) {
132   if (isa_and_nonnull<TF::TensorFlowDialect>(op->getDialect()))
133     return TensorFlowDialect::CanHaveSideEffects(op);
134 
135   if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return false;
136   // Conservatively assume that there can be side effects.
137   return true;
138 }
139 
ShouldUseResourceAliasAnalysis(const MemoryEffects::EffectInstance & effect)140 bool ShouldUseResourceAliasAnalysis(
141     const MemoryEffects::EffectInstance& effect) {
142   Value value = effect.getValue();
143   if (value && getElementTypeOrSelf(value.getType()).isa<ResourceType>()) {
144     // For value-based effects on resource values we can use resource alias
145     // analysis.
146     return true;
147   }
148   // For all other effects don't rely on resource alias analysis. Note that
149   // non-resource values are not processed in resource alias analysis.
150   return false;
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // SideEffectAnalysisInfo helper functions.
155 //===----------------------------------------------------------------------===//
156 
GetSideEffectsFromEffectInstance(const MemoryEffects::EffectInstance & effect_instance,Operation * op)157 SideEffects GetSideEffectsFromEffectInstance(
158     const MemoryEffects::EffectInstance& effect_instance, Operation* op) {
159   mlir::SideEffects::Effect* effect = effect_instance.getEffect();
160   SideEffects side_effects;
161   if (isa<MemoryEffects::Allocate>(effect)) {
162     side_effects.SetAlloc();
163   } else if (isa<MemoryEffects::Free>(effect)) {
164     side_effects.SetFree();
165   } else if (isa<MemoryEffects::Read>(effect)) {
166     side_effects.SetRead();
167   } else if (isa<MemoryEffects::Write>(effect)) {
168     side_effects.SetWrite();
169   } else {
170     LOG(WARNING) << "Unsupported effect for op "
171                  << op->getName().getStringRef().str();
172     side_effects.SetUnknownEffect();
173   }
174   return side_effects;
175 }
176 
177 }  // namespace
178 
179 namespace detail {
180 
181 // Class for propagating op-based side effects bottom-up and collecting them
182 // per op, by resource ID.
183 class OpSideEffectCollector {
184  public:
185   // Recursively collects op-based side effects for all ops in module and
186   // populates `op_side_effect_map_`.
OpSideEffectCollector(ModuleOp module)187   explicit OpSideEffectCollector(ModuleOp module) {
188     symbol_table_collection_.getSymbolTable(module);
189     for (auto func : module.getOps<func::FuncOp>()) {
190       CollectOpSideEffects(func);
191     }
192   }
193 
194   // Returns op-based side effects by resource ID for `op`.
GetSideEffectsForOp(Operation * op) const195   const SideEffectsByResourceId& GetSideEffectsForOp(Operation* op) const {
196     auto iter = op_side_effect_map_.find(op);
197     if (iter != op_side_effect_map_.end()) return iter->second;
198     return empty_side_effects_map_;
199   }
200 
201   // Returns true iff resource with given ID is only self-dependent, i.e., there
202   // are no dependencies to other resources (including unknown resources).
IsOnlySelfDependent(ResourceId resource_id) const203   bool IsOnlySelfDependent(ResourceId resource_id) const {
204     return self_dependent_only_ids_.contains(resource_id);
205   }
206 
207  private:
208   // Adds op-based side effects from all ops in `region` to `op` side effects.
209   // Collects side effects for ops that weren't visited before.
AddRegionSideEffectsForOp(Region & region,Operation * op)210   void AddRegionSideEffectsForOp(Region& region, Operation* op) {
211     for (Block& block : region) {
212       for (Operation& curr_op : block) {
213         if (op_side_effect_map_.count(&curr_op) == 0) {
214           CollectOpSideEffects(&curr_op);
215         }
216         for (const auto& entry : op_side_effect_map_[&curr_op]) {
217           UpdateSideEffectsByResourceId(entry.second, op_side_effect_map_[op]);
218         }
219       }
220     }
221   }
222 
223   // Collects op-based side effects for `op` in `op_side_effect_map_[op]`.
CollectOpSideEffects(Operation * op)224   void CollectOpSideEffects(Operation* op) {
225     if (!MayHaveSideEffect(op)) return;
226     // Skip following ops to avoid that every island, graph and function is
227     // classified as unknown side-effecting.
228     if (isa<tf_executor::YieldOp, tf_executor::FetchOp,
229             mlir::func::ReturnOp>(op))
230       return;
231 
232     // Propagate side effects from regions or functions attached to `op` for
233     // some special cases.
234     if (auto func = llvm::dyn_cast<func::FuncOp>(op)) {
235       AddRegionSideEffectsForOp(func.getBody(), op);
236     } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) {
237       func::FuncOp func_op = dyn_cast<func::FuncOp>(
238           call.resolveCallable(&symbol_table_collection_));
239       if (func_op) {
240         AddRegionSideEffectsForOp(func_op.getBody(), op);
241       }
242     } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
243       AddRegionSideEffectsForOp(if_op.then_function().getBody(), op);
244       AddRegionSideEffectsForOp(if_op.else_function().getBody(), op);
245     } else if (auto while_op = dyn_cast<WhileOp>(op)) {
246       AddRegionSideEffectsForOp(while_op.body_function().getBody(), op);
247     } else if (auto while_region_op = dyn_cast<WhileRegionOp>(op)) {
248       AddRegionSideEffectsForOp(while_region_op.body(), op);
249     } else if (auto case_op = dyn_cast<CaseOp>(op)) {
250       llvm::SmallVector<func::FuncOp, 4> branch_funcs;
251       case_op.get_branch_functions(branch_funcs);
252       for (auto branch_func : branch_funcs) {
253         AddRegionSideEffectsForOp(branch_func.getBody(), op);
254       }
255     } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp,
256                    tf_executor::IslandOp, tf_executor::GraphOp, IfRegionOp,
257                    CaseRegionOp>(op)) {
258       for (Region& region : op->getRegions()) {
259         AddRegionSideEffectsForOp(region, op);
260       }
261     } else {
262       // Now handle all other ops.
263       auto& side_effects_by_resource_id = op_side_effect_map_[op];
264       llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
265       auto interface = dyn_cast<MemoryEffectOpInterface>(op);
266       if (interface) interface.getEffects(effects);
267       if (effects.empty()) {
268         // The op is potentially side-effecting and doesn't have any effect
269         // assigned, treat it as unknown side effect.
270         SideEffects side_effects;
271         side_effects.SetResourceId(kUnknownResourceId);
272         side_effects.SetUnknownEffect();
273         UpdateSideEffectsByResourceId(side_effects,
274                                       side_effects_by_resource_id);
275         // An unknown side effect dominates other side effects so we don't have
276         // to add them and can return here.
277         return;
278       }
279       // Add op-based side effects from regions (if any).
280       for (Region& region : op->getRegions()) {
281         AddRegionSideEffectsForOp(region, op);
282       }
283       // Add op-based side effects for the op itself.
284       for (const auto& effect : effects) {
285         // We handle value-based side effects for which we can use resource
286         // alias analysis at a different place, skip here.
287         if (ShouldUseResourceAliasAnalysis(effect)) continue;
288         if (llvm::isa<ResourceEffects::MustExecute>(effect.getResource()))
289           // We have this fake resource to avoid that certain ops are considered
290           // dead or get pruned, ignore it for side effect analysis.
291           continue;
292 
293         // Add side effects for op resource ID.
294         std::string instance_str = "";
295         SideEffects side_effects(GetSideEffectsFromEffectInstance(effect, op));
296         if (auto resource_instance_op =
297             dyn_cast<GetResourceInstanceInterface>(op)) {
298           instance_str = resource_instance_op.GetResourceInstanceStr();
299         }
300         TypeID type_id = effect.getResource()->getResourceID();
301         ResourceId resource_id = GetOpResourceId(type_id, instance_str);
302         side_effects.SetResourceId(resource_id);
303         UpdateSideEffectsByResourceId(side_effects,
304                                       side_effects_by_resource_id);
305         if (ResourceEffects::IsOnlySelfDependent(type_id)) {
306           self_dependent_only_ids_.insert(resource_id);
307         }
308       }
309     }
310   }
311 
312   // Get internal op resource ID from MLIR type ID and instance ID.
GetOpResourceId(TypeID type_id,std::string instance_str)313   ResourceId GetOpResourceId(TypeID type_id, std::string instance_str) {
314     auto emplace_result = type_instance_str_to_op_resource_id_.try_emplace(
315         std::make_pair(type_id.getAsOpaquePointer(), instance_str),
316         next_op_resource_id_);
317     // Increment type ID if we have encountered a new resource type.
318     if (emplace_result.second) ++next_op_resource_id_;
319     return emplace_result.first->second;
320   }
321 
322   // We use [0, kMaxResourceId] for resource IDs returned by resource alias
323   // analysis and [kMaxResourceId + 1, ...] for resource IDs which we generate
324   // for op-based side effects.
325   const ResourceId kMaxResourceId =
326       std::numeric_limits<ResourceId>::max() / 2;
327   // Next available ID for op-based resources (resources not handled by resource
328   // alias analysis).
329   ResourceId next_op_resource_id_ = kMaxResourceId + 1;
330   // Maps (type ID, instance ID) pairs to internal IDs for op-based resources.
331   // Also see comment above. Instead of using TypeID directly we use its opaque
332   // pointer.
333   absl::node_hash_map<std::pair<const void*, std::string>, ResourceId>
334     type_instance_str_to_op_resource_id_;
335   // Used for faster callable resolution.
336   SymbolTableCollection symbol_table_collection_;
337   // Collect all op-based side effects here.
338   OpSideEffectMap op_side_effect_map_;
339   const SideEffectsByResourceId empty_side_effects_map_;
340 
341   // Set of all resource IDs which only have dependencies to themselves, not to
342   // any other resource ID (including unknown resource ID).
343   llvm::SmallDenseSet<ResourceId, 8> self_dependent_only_ids_;
344 };
345 
346 // Collects all op-based and value-based side effects for `op` per resource ID.
CollectSideEffectsByResourceId(Operation * op,const OpSideEffectCollector & op_side_effect_collector,const TF::ResourceAliasAnalysis::Info & alias_analysis)347 SideEffectsByResourceId CollectSideEffectsByResourceId(
348     Operation* op,
349     const OpSideEffectCollector& op_side_effect_collector,
350     const TF::ResourceAliasAnalysis::Info& alias_analysis) {
351   SideEffectsByResourceId side_effects_by_resource_id;
352   if (!MayHaveSideEffect(op)) return side_effects_by_resource_id;
353 
354   if (isa<tf_device::LaunchOp, tf_device::ClusterOp, tf_executor::IslandOp,
355           tf_executor::GraphOp, IfRegionOp, CaseRegionOp, WhileRegionOp>(op)) {
356     // For ops that are side-effecting only if their attached regions are,
357     // collect effects for all ops in the regions instead of collecting effects
358     // for the op itself. This is important to avoid conservatism and to find
359     // resource variable accesses in regions which are not exposed to the op
360     // interface.
361     for (Region& region : op->getRegions()) {
362       for (Operation& region_op : region.front().without_terminator()) {
363         SideEffectsByResourceId region_op_effects =
364             CollectSideEffectsByResourceId(
365                 &region_op,
366                 op_side_effect_collector,
367                 alias_analysis);
368         for (const auto& [resource_id, side_effect] : region_op_effects) {
369           UpdateSideEffectsByResourceId(side_effect,
370                                         side_effects_by_resource_id);
371         }
372       }
373     }
374     return side_effects_by_resource_id;
375   }
376 
377   // Copy op-based side effects.
378   side_effects_by_resource_id =
379       op_side_effect_collector.GetSideEffectsForOp(op);
380   bool found_any_effect = !side_effects_by_resource_id.empty();
381 
382   // Collect value-based side effects from op interface.
383   llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
384   auto interface = dyn_cast<MemoryEffectOpInterface>(op);
385   if (interface) interface.getEffects(effects);
386 
387   llvm::SmallDenseSet<Value, 8> processed_values;
388   for (const auto& effect : effects) {
389     Value value = effect.getValue();
390     found_any_effect = true;
391 
392     // We only collect value-based side effects here for which we can use
393     // resource alias analysis. Other side effects are treated as op-based
394     // side effects.
395     if (!ShouldUseResourceAliasAnalysis(effect)) continue;
396     if (value) processed_values.insert(value);
397 
398     TypeID type_id = effect.getResource()->getResourceID();
399     if (ResourceEffects::IsOnlySelfDependent(type_id)) {
400       // For value-based side effects we currently treat resource types that are
401       // only self-dependent conservatively, i.e., we do add dependencies
402       // to/from unknown resource types. Currently, we don't have such cases and
403       // there is no indication that we will need to support them in the future.
404       LOG(WARNING) << "Self-dependent-only resource types are treated "
405                       "conservatively for value-based side effects.";
406     }
407 
408     // Add side effects for every potentially accessed resource ID.
409     SideEffects side_effects(GetSideEffectsFromEffectInstance(effect, op));
410     const auto& ids = GetResourceUniqueIdsOrUnknown(value, alias_analysis);
411     for (ResourceId id : ids) {
412       side_effects.SetResourceId(id);
413       UpdateSideEffectsByResourceId(side_effects, side_effects_by_resource_id);
414     }
415   }
416 
417   auto add_remaining_effects = [&](auto resource_values) {
418     for (Value resource_value : resource_values) {
419       // If we already processed this value before, skip it.
420       if (processed_values.count(resource_value) > 0) continue;
421       found_any_effect = true;
422 
423       // Conservatively set unknown effect.
424       SideEffects unknown_effect;
425       unknown_effect.SetUnknownEffect();
426 
427       // Add side effects for every potentially accessed resource ID.
428       const auto& ids =
429           GetResourceUniqueIdsOrUnknown(resource_value, alias_analysis);
430       for (ResourceId id : ids) {
431         unknown_effect.SetResourceId(id);
432         UpdateSideEffectsByResourceId(unknown_effect,
433                                       side_effects_by_resource_id);
434       }
435     }
436   };
437   // Add value-based side effects for resource values which are not covered by
438   // any side effect so far, for example, resource values being passed to
439   // `tf.While` or `tf.If` ops which are not part of the op definition but
440   // appear in a variadic input list.
441   add_remaining_effects(filter_resources(op->getOperands()));
442   add_remaining_effects(filter_resources(op->getResults()));
443 
444   if (!found_any_effect) {
445     // We haven't collected any side effect but the op is potentially
446     // side-effecting (otherwise we would have returned), therefore we have an
447     // unknown side effect for an unknown resource.
448     SideEffects unknown_effect;
449     unknown_effect.SetUnknownEffect();
450     unknown_effect.SetResourceId(kUnknownResourceId);
451     UpdateSideEffectsByResourceId(unknown_effect,
452                                   side_effects_by_resource_id);
453   }
454   return side_effects_by_resource_id;
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // SideEffectAnalysisInfo
459 //===----------------------------------------------------------------------===//
460 
AddPredecessorsForAccess(ResourceId resource_id,Operation * op,bool read_only)461 void SideEffectAnalysisInfo::AddPredecessorsForAccess(ResourceId resource_id,
462                                                       Operation* op,
463                                                       bool read_only) {
464   VLOG(2) << "    Adding predecessors for resource " << resource_id;
465   auto it = per_resource_access_info_.find(resource_id);
466   if (it == per_resource_access_info_.end()) return;
467   const auto& access_info = it->getSecond();
468 
469   auto& control_predecessors = control_predecessors_[op];
470   bool is_last_write_indirectly_tracked = false;
471   if (!read_only) {
472     // Add reads after last write as predecessors.
473     control_predecessors.insert(access_info.reads_since_last_write.begin(),
474                                 access_info.reads_since_last_write.end());
475     // Last write is indirectly tracked by any read predecessor we added.
476     is_last_write_indirectly_tracked =
477         !access_info.reads_since_last_write.empty();
478   }
479   if (access_info.last_write && !is_last_write_indirectly_tracked) {
480     // Add last write as predecessor.
481     control_predecessors.insert(access_info.last_write);
482   }
483 }
484 
UpdateAccess(ResourceId resource_id,Operation * op,bool read_only)485 void SideEffectAnalysisInfo::UpdateAccess(ResourceId resource_id,
486                                           Operation* op,
487                                           bool read_only) {
488   VLOG(2) << "    Updating access for resource " << resource_id;
489   op_to_resource_ids_[op].push_back({resource_id, read_only});
490   if (resource_id == kUnknownResourceId) {
491     if (read_only) {
492       // New unknown read is not tracked by any known resource access.
493       for (auto& entry : per_resource_access_info_) {
494         entry.getSecond().are_last_unknown_reads_tracked = false;
495       }
496     } else {
497       // Unknown write can clear all other tracked information, since it acts
498       // like a barrier.
499       per_resource_access_info_.clear();
500     }
501   }
502   auto& access_info = per_resource_access_info_[resource_id];
503   if (read_only) {
504     access_info.reads_since_last_write.push_back(op);
505     // Last unknown write is indirectly tracked by this read (we have added the
506     // write as a predecessor for `op` before).
507     access_info.is_last_unknown_write_tracked = true;
508   } else {
509     access_info.last_write = op;
510     access_info.reads_since_last_write.clear();
511     // Last unknown read(s) and write are indirectly tracked by this write (we
512     // have added the read(s) and write as predecessors for `op` before).
513     access_info.are_last_unknown_reads_tracked = true;
514     access_info.is_last_unknown_write_tracked = true;
515     access_info.is_last_unknown_write_tracked_by_write = true;
516   }
517 }
518 
AnalyzeFunction(func::FuncOp func_op)519 void SideEffectAnalysisInfo::AnalyzeFunction(func::FuncOp func_op) {
520   // AnalyzeRegion() recursively analyzes the function body, and only populates
521   // control_predecessors_.
522   AnalyzeRegion(&func_op.getBody());
523   // Populate sorted_control_predecessors_ and sorted_control_successors_ based
524   // on control_predecessors.
525   for (auto& entry : control_predecessors_) {
526     auto op = entry.getFirst();
527     auto& predecessors = entry.getSecond();
528     auto& sorted_predecessors = sorted_control_predecessors_[op];
529     for (Operation* predecessor : predecessors) {
530       sorted_predecessors.push_back(predecessor);
531       sorted_control_successors_[predecessor].push_back(op);
532     }
533   }
534   control_predecessors_.clear();
535   for (auto& entry : sorted_control_predecessors_) {
536     llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
537       return a->isBeforeInBlock(b);
538     });
539   }
540   for (auto& entry : sorted_control_successors_) {
541     llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
542       return a->isBeforeInBlock(b);
543     });
544   }
545 
546   // Populate the control sinks (i.e. side-effecting ops with no control
547   // successors) in the top level block.
548   for (const auto& entry : sorted_control_predecessors_) {
549     auto* op = entry.getFirst();
550     if (op->getBlock() == &func_op.front() &&
551         sorted_control_successors_.count(op) == 0) {
552       sorted_control_sinks_.push_back(op);
553     }
554   }
555   llvm::sort(sorted_control_sinks_, [](Operation* a, Operation* b) {
556     return a->isBeforeInBlock(b);
557   });
558 }
559 
AnalyzeRegion(Region * region)560 void SideEffectAnalysisInfo::AnalyzeRegion(Region* region) {
561   // We explicitly iterate through the regions and blocks in order to handle
562   // different nested regions separately.
563   for (Block& block : *region) {
564     for (Operation& op : block) {
565       for (Region& child_region : op.getRegions()) {
566         SideEffectAnalysisInfo child_analysis(
567             &child_region, op_side_effect_collector_, alias_analysis_);
568         // Move data from `child_analysis` to current region.
569         for (auto& entry : child_analysis.control_predecessors_)
570           control_predecessors_[entry.first] = std::move(entry.second);
571         for (auto& entry : child_analysis.op_to_resource_ids_)
572           op_to_resource_ids_[entry.first] = std::move(entry.second);
573       }
574       AnalyzeOp(&op);
575     }
576   }
577 }
578 
579 ResourceIdSet
GetConflictingIds(ResourceId resource_id,bool is_fetch_op) const580 SideEffectAnalysisInfo::GetConflictingIds(ResourceId resource_id,
581                                           bool is_fetch_op)  const {
582   ResourceIdSet conflicting_ids;
583   if (resource_id == kUnknownResourceId) {
584     // Unknown resource has potential conflict with all other resources, except
585     // those that are only self-dependent. For `Fetch` op make every resource
586     // conflicting in any case to ensure that all side-effecting ops in
587     // `Graph` feed into `Fetch` (its terminator).
588     for (auto& entry : per_resource_access_info_) {
589       ResourceId other_id = entry.getFirst();
590       if (!op_side_effect_collector_.IsOnlySelfDependent(other_id) ||
591           is_fetch_op)
592         conflicting_ids.insert(other_id);
593     }
594   } else {
595     conflicting_ids.insert(resource_id);
596     // Resource has potential conflict with unknown resource, if not only
597     // self-dependent.
598     if (!op_side_effect_collector_.IsOnlySelfDependent(resource_id))
599       conflicting_ids.insert(kUnknownResourceId);
600   }
601   return conflicting_ids;
602 }
603 
AnalyzeOp(Operation * op)604 void SideEffectAnalysisInfo::AnalyzeOp(Operation* op) {
605   VLOG(2) << "Processing op " << mlir::debugString(*op);
606   SideEffectsByResourceId side_effects_by_resource_id =
607         CollectSideEffectsByResourceId(
608             op,
609             op_side_effect_collector_,
610             alias_analysis_);
611 
612   // If the side-effecting op is a control source (i.e. it has no control
613   // predecessors), then `control_predecessors_` won't be updated below.
614   // However, we still want to track this op as it may have side effects visible
615   // to ops outside the function.
616   if (!side_effects_by_resource_id.empty()) control_predecessors_[op];
617 
618   // Traverse all resource IDs and their associated side effects.
619   bool had_unknown_resource_read = false;
620   for (auto pair : side_effects_by_resource_id) {
621     ResourceId resource_id = pair.first;
622     const SideEffects& side_effects = pair.second;
623     const bool read_only = side_effects.IsReadOnly();
624     VLOG(2) << "  Processing resource ID: " << resource_id
625             << ", read-only effect: " << read_only;
626     // An op that only allocates a resource is expected to return a handle that
627     // is used by all other accesses of the same resource. That means, other ops
628     // that access the same resource already have a data dependency on the
629     // allocating op so it doesn't need any control predecessors or successors.
630     if (side_effects.IsAllocOnly()) continue;
631     // Effect is dominated by previous unknown resource read effect.
632     if (read_only && had_unknown_resource_read) continue;
633 
634     ResourceIdSet conflicting_ids = GetConflictingIds(
635         resource_id, isa<tf_executor::FetchOp>(op));
636 
637     // Add predecessors for conflicting IDs.
638     bool is_unknown_access_indirectly_tracked = false;
639     for (ResourceId id : conflicting_ids) {
640       // Handle unknown resource later, access might already be indirectly
641       // tracked by another resource access.
642       if (id == kUnknownResourceId) continue;
643 
644       AddPredecessorsForAccess(id, op, read_only);
645       is_unknown_access_indirectly_tracked |=
646           IsUnknownAccessIndirectlyTrackedByResource(id, read_only);
647     }
648     // Add predecessors for unknown resource if necessary.
649     if (conflicting_ids.contains(kUnknownResourceId) &&
650         !is_unknown_access_indirectly_tracked)
651       AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
652     // Update resource access.
653     UpdateAccess(resource_id, op, read_only);
654 
655     // If this effect dominates all other possible effects, return here. Note
656     // that if there is any effect for an unknown resource, then we encounter it
657     // in the first iteration since `kUnknownResourceId` is smaller than all
658     // other resource IDs.
659     if (resource_id == kUnknownResourceId && !read_only) return;
660     if (resource_id == kUnknownResourceId && read_only) {
661       had_unknown_resource_read = true;
662     }
663   }
664 }
665 
IsUnknownAccessIndirectlyTrackedByResource(ResourceId resource_id,bool read_only)666 bool SideEffectAnalysisInfo::IsUnknownAccessIndirectlyTrackedByResource(
667     ResourceId resource_id, bool read_only) {
668   auto it = per_resource_access_info_.find(resource_id);
669   if (it == per_resource_access_info_.end()) return false;
670   auto access_info = it->getSecond();
671 
672   auto unknown_it = per_resource_access_info_.find(kUnknownResourceId);
673   if (unknown_it == per_resource_access_info_.end()) return true;
674   auto unknown_access_info = unknown_it->getSecond();
675 
676   bool no_unknown_read = unknown_access_info.reads_since_last_write.empty();
677   bool no_unknown_write = (unknown_access_info.last_write == nullptr);
678 
679   // For the read-only case we only need that the last unknown write is already
680   // tracked by the last `resource` write since we don't have dependencies to
681   // any other read accesses.
682   // Otherwise, we need that the last unknown read(s) and write are already
683   // tracked by any read or write accesses of `resource`.
684   bool is_tracked = read_only ?
685       no_unknown_write || access_info.is_last_unknown_write_tracked_by_write :
686       (no_unknown_write || access_info.is_last_unknown_write_tracked) &&
687       (no_unknown_read || access_info.are_last_unknown_reads_tracked);
688   if (is_tracked) {
689     VLOG(2) << "      Unknown access indirectly tracked by resource "
690             << resource_id;
691   }
692   return is_tracked;
693 }
694 
695 llvm::SmallVector<Operation*, 4>
DirectControlPredecessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const696 SideEffectAnalysisInfo::DirectControlPredecessors(
697     Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
698   llvm::SmallVector<Operation*, 4> result;
699   auto it = sorted_control_predecessors_.find(op);
700   if (it == sorted_control_predecessors_.end()) return result;
701   result.reserve(it->getSecond().size());
702   for (auto predecessor : it->getSecond()) {
703     if (!filter || filter(predecessor)) result.push_back(predecessor);
704   }
705   return result;
706 }
707 
708 llvm::SmallVector<Operation*, 4>
DirectControlSuccessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const709 SideEffectAnalysisInfo::DirectControlSuccessors(
710     Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
711   llvm::SmallVector<Operation*, 4> result;
712   auto it = sorted_control_successors_.find(op);
713   if (it == sorted_control_successors_.end()) return result;
714   result.reserve(it->getSecond().size());
715   for (auto successor : it->getSecond()) {
716     if (!filter || filter(successor)) result.push_back(successor);
717   }
718   return result;
719 }
720 
721 const llvm::SmallVector<std::pair<ResourceId, bool>>&
GetResourceIds(Operation * op) const722 SideEffectAnalysisInfo::GetResourceIds(Operation* op) const {
723   auto it = op_to_resource_ids_.find(op);
724   if (it == op_to_resource_ids_.end()) return empty_resource_ids_;
725   return it->getSecond();
726 }
727 
728 }  // namespace detail
729 
SideEffectAnalysis(ModuleOp module)730 SideEffectAnalysis::SideEffectAnalysis(ModuleOp module)
731   // Analyze entire module for alias analysis info.
732     : alias_analysis_(module) {
733   // Collect op-based side effects for entire module.
734   detail::OpSideEffectCollector op_side_effect_collector(module);
735 
736   // Analyze side effects for all functions in module.
737   for (auto func : module.getOps<func::FuncOp>())
738     this->info_map_.try_emplace(func, func,
739                                 op_side_effect_collector,
740                                 alias_analysis_.GetAnalysisForFunc(func));
741 }
742 
743 }  // namespace TF
744 }  // namespace mlir
745