1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
17
18 #include <bitset>
19 #include <string>
20
21 #include "absl/container/node_hash_map.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/iterator_range.h"
27 #include "llvm/Support/Casting.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/Block.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/Operation.h" // from @llvm-project
34 #include "mlir/IR/Value.h" // from @llvm-project
35 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
36 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
37 #include "mlir/Support/LLVM.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
44
45 namespace mlir {
46 namespace TF {
47 namespace {
48
49 constexpr ResourceId kUnknownResourceId =
50 ResourceAliasAnalysis::Info::kUnknownResourceId;
51 static_assert(kUnknownResourceId < 0, "kUnknownResourceId must be < 0");
52
53 // A collection of Resource IDs. Note that `kUnknownResourceId` is smaller than
54 // all other resource IDs which are nonnegative (see check above) so it will
55 // always be the first element of a `ResourceIdSet` (we make use of this).
56 using ResourceIdSet = llvm::SmallSet<ResourceId, 8>;
57
58 // Note that we cannot simply define a `static const llvm::SmallSet` here
59 // because of missing `initializer_list` support for `llvm::SmallSet`.
UnknownResourceSet()60 const ResourceIdSet& UnknownResourceSet() {
61 // clang-format off
62 static auto* id_set = new ResourceIdSet();
63 id_set->insert(kUnknownResourceId);
64 return *id_set;
65 }
66
67 // Helper function to avoid frequent checks for unknown IDs.
GetResourceUniqueIdsOrUnknown(Value value,const ResourceAliasAnalysis::Info & alias_analysis)68 const ResourceIdSet& GetResourceUniqueIdsOrUnknown(
69 Value value,
70 const ResourceAliasAnalysis::Info& alias_analysis) {
71 if (!getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>() ||
72 alias_analysis.IsUnknownResource(value)) return UnknownResourceSet();
73 return alias_analysis.GetResourceUniqueIds(value);
74 }
75
76 // Helper class for a collection of side effects for one resource.
77 class SideEffects {
78 enum Type {
79 kAlloc = 0,
80 kFree = 1,
81 kRead = 2,
82 kWrite = 3
83 };
84
85 public:
IsAlloc() const86 bool IsAlloc() const { return effects_.test(kAlloc); }
IsFree() const87 bool IsFree() const { return effects_.test(kFree); }
IsRead() const88 bool IsRead() const { return effects_.test(kRead); }
IsWrite() const89 bool IsWrite() const { return effects_.test(kWrite); }
IsAllocOnly() const90 bool IsAllocOnly() const { return IsAlloc() && effects_.count() == 1; }
IsReadOnly() const91 bool IsReadOnly() const { return IsRead() && effects_.count() == 1; }
GetResourceId() const92 ResourceId GetResourceId() const { return resource_id_; }
93
SetAlloc()94 void SetAlloc() { effects_.set(kAlloc); }
SetFree()95 void SetFree() { effects_.set(kFree); }
SetRead()96 void SetRead() { effects_.set(kRead); }
SetWrite()97 void SetWrite() { effects_.set(kWrite); }
SetUnknownEffect()98 void SetUnknownEffect() { effects_.set(); }
SetResourceId(ResourceId resource_id)99 void SetResourceId(ResourceId resource_id) { resource_id_ = resource_id; }
AddEffects(const SideEffects & other_effects)100 void AddEffects(const SideEffects& other_effects) {
101 effects_ |= other_effects.effects_;
102 }
103
104 private:
105 std::bitset<4> effects_ = 0;
106 ResourceId resource_id_ = kUnknownResourceId;
107 };
108
109 // We use `std::map` here because we rely on the order of elements.
110 using SideEffectsByResourceId = std::map<ResourceId, SideEffects>;
111
112 // We use `std::unordered_map` here for pointer stability reasons.
113 // Note: If memory usage ever becomes a bottleneck here (not expected) we could
114 // use a Trie-like data structure to avoid storing side effects in both parent
115 // op and all its child ops (recursively), at the expense of lookup time.
116 using OpSideEffectMap = std::unordered_map<Operation*, SideEffectsByResourceId>;
117
118 // Update `side_effects_by_resource_id` with `side_effects`.
UpdateSideEffectsByResourceId(const SideEffects & side_effects,SideEffectsByResourceId & side_effects_by_resource_id)119 void UpdateSideEffectsByResourceId(
120 const SideEffects& side_effects,
121 SideEffectsByResourceId& side_effects_by_resource_id) {
122 ResourceId id = side_effects.GetResourceId();
123 auto iter = side_effects_by_resource_id.find(id);
124 if (iter == side_effects_by_resource_id.end()) {
125 side_effects_by_resource_id[id] = side_effects;
126 } else {
127 iter->second.AddEffects(side_effects);
128 }
129 }
130
MayHaveSideEffect(Operation * op)131 bool MayHaveSideEffect(Operation* op) {
132 if (isa_and_nonnull<TF::TensorFlowDialect>(op->getDialect()))
133 return TensorFlowDialect::CanHaveSideEffects(op);
134
135 if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return false;
136 // Conservatively assume that there can be side effects.
137 return true;
138 }
139
ShouldUseResourceAliasAnalysis(const MemoryEffects::EffectInstance & effect)140 bool ShouldUseResourceAliasAnalysis(
141 const MemoryEffects::EffectInstance& effect) {
142 Value value = effect.getValue();
143 if (value && getElementTypeOrSelf(value.getType()).isa<ResourceType>()) {
144 // For value-based effects on resource values we can use resource alias
145 // analysis.
146 return true;
147 }
148 // For all other effects don't rely on resource alias analysis. Note that
149 // non-resource values are not processed in resource alias analysis.
150 return false;
151 }
152
153 //===----------------------------------------------------------------------===//
154 // SideEffectAnalysisInfo helper functions.
155 //===----------------------------------------------------------------------===//
156
GetSideEffectsFromEffectInstance(const MemoryEffects::EffectInstance & effect_instance,Operation * op)157 SideEffects GetSideEffectsFromEffectInstance(
158 const MemoryEffects::EffectInstance& effect_instance, Operation* op) {
159 mlir::SideEffects::Effect* effect = effect_instance.getEffect();
160 SideEffects side_effects;
161 if (isa<MemoryEffects::Allocate>(effect)) {
162 side_effects.SetAlloc();
163 } else if (isa<MemoryEffects::Free>(effect)) {
164 side_effects.SetFree();
165 } else if (isa<MemoryEffects::Read>(effect)) {
166 side_effects.SetRead();
167 } else if (isa<MemoryEffects::Write>(effect)) {
168 side_effects.SetWrite();
169 } else {
170 LOG(WARNING) << "Unsupported effect for op "
171 << op->getName().getStringRef().str();
172 side_effects.SetUnknownEffect();
173 }
174 return side_effects;
175 }
176
177 } // namespace
178
179 namespace detail {
180
181 // Class for propagating op-based side effects bottom-up and collecting them
182 // per op, by resource ID.
183 class OpSideEffectCollector {
184 public:
185 // Recursively collects op-based side effects for all ops in module and
186 // populates `op_side_effect_map_`.
OpSideEffectCollector(ModuleOp module)187 explicit OpSideEffectCollector(ModuleOp module) {
188 symbol_table_collection_.getSymbolTable(module);
189 for (auto func : module.getOps<func::FuncOp>()) {
190 CollectOpSideEffects(func);
191 }
192 }
193
194 // Returns op-based side effects by resource ID for `op`.
GetSideEffectsForOp(Operation * op) const195 const SideEffectsByResourceId& GetSideEffectsForOp(Operation* op) const {
196 auto iter = op_side_effect_map_.find(op);
197 if (iter != op_side_effect_map_.end()) return iter->second;
198 return empty_side_effects_map_;
199 }
200
201 // Returns true iff resource with given ID is only self-dependent, i.e., there
202 // are no dependencies to other resources (including unknown resources).
IsOnlySelfDependent(ResourceId resource_id) const203 bool IsOnlySelfDependent(ResourceId resource_id) const {
204 return self_dependent_only_ids_.contains(resource_id);
205 }
206
207 private:
208 // Adds op-based side effects from all ops in `region` to `op` side effects.
209 // Collects side effects for ops that weren't visited before.
AddRegionSideEffectsForOp(Region & region,Operation * op)210 void AddRegionSideEffectsForOp(Region& region, Operation* op) {
211 for (Block& block : region) {
212 for (Operation& curr_op : block) {
213 if (op_side_effect_map_.count(&curr_op) == 0) {
214 CollectOpSideEffects(&curr_op);
215 }
216 for (const auto& entry : op_side_effect_map_[&curr_op]) {
217 UpdateSideEffectsByResourceId(entry.second, op_side_effect_map_[op]);
218 }
219 }
220 }
221 }
222
223 // Collects op-based side effects for `op` in `op_side_effect_map_[op]`.
CollectOpSideEffects(Operation * op)224 void CollectOpSideEffects(Operation* op) {
225 if (!MayHaveSideEffect(op)) return;
226 // Skip following ops to avoid that every island, graph and function is
227 // classified as unknown side-effecting.
228 if (isa<tf_executor::YieldOp, tf_executor::FetchOp,
229 mlir::func::ReturnOp>(op))
230 return;
231
232 // Propagate side effects from regions or functions attached to `op` for
233 // some special cases.
234 if (auto func = llvm::dyn_cast<func::FuncOp>(op)) {
235 AddRegionSideEffectsForOp(func.getBody(), op);
236 } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) {
237 func::FuncOp func_op = dyn_cast<func::FuncOp>(
238 call.resolveCallable(&symbol_table_collection_));
239 if (func_op) {
240 AddRegionSideEffectsForOp(func_op.getBody(), op);
241 }
242 } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
243 AddRegionSideEffectsForOp(if_op.then_function().getBody(), op);
244 AddRegionSideEffectsForOp(if_op.else_function().getBody(), op);
245 } else if (auto while_op = dyn_cast<WhileOp>(op)) {
246 AddRegionSideEffectsForOp(while_op.body_function().getBody(), op);
247 } else if (auto while_region_op = dyn_cast<WhileRegionOp>(op)) {
248 AddRegionSideEffectsForOp(while_region_op.body(), op);
249 } else if (auto case_op = dyn_cast<CaseOp>(op)) {
250 llvm::SmallVector<func::FuncOp, 4> branch_funcs;
251 case_op.get_branch_functions(branch_funcs);
252 for (auto branch_func : branch_funcs) {
253 AddRegionSideEffectsForOp(branch_func.getBody(), op);
254 }
255 } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp,
256 tf_executor::IslandOp, tf_executor::GraphOp, IfRegionOp,
257 CaseRegionOp>(op)) {
258 for (Region& region : op->getRegions()) {
259 AddRegionSideEffectsForOp(region, op);
260 }
261 } else {
262 // Now handle all other ops.
263 auto& side_effects_by_resource_id = op_side_effect_map_[op];
264 llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
265 auto interface = dyn_cast<MemoryEffectOpInterface>(op);
266 if (interface) interface.getEffects(effects);
267 if (effects.empty()) {
268 // The op is potentially side-effecting and doesn't have any effect
269 // assigned, treat it as unknown side effect.
270 SideEffects side_effects;
271 side_effects.SetResourceId(kUnknownResourceId);
272 side_effects.SetUnknownEffect();
273 UpdateSideEffectsByResourceId(side_effects,
274 side_effects_by_resource_id);
275 // An unknown side effect dominates other side effects so we don't have
276 // to add them and can return here.
277 return;
278 }
279 // Add op-based side effects from regions (if any).
280 for (Region& region : op->getRegions()) {
281 AddRegionSideEffectsForOp(region, op);
282 }
283 // Add op-based side effects for the op itself.
284 for (const auto& effect : effects) {
285 // We handle value-based side effects for which we can use resource
286 // alias analysis at a different place, skip here.
287 if (ShouldUseResourceAliasAnalysis(effect)) continue;
288 if (llvm::isa<ResourceEffects::MustExecute>(effect.getResource()))
289 // We have this fake resource to avoid that certain ops are considered
290 // dead or get pruned, ignore it for side effect analysis.
291 continue;
292
293 // Add side effects for op resource ID.
294 std::string instance_str = "";
295 SideEffects side_effects(GetSideEffectsFromEffectInstance(effect, op));
296 if (auto resource_instance_op =
297 dyn_cast<GetResourceInstanceInterface>(op)) {
298 instance_str = resource_instance_op.GetResourceInstanceStr();
299 }
300 TypeID type_id = effect.getResource()->getResourceID();
301 ResourceId resource_id = GetOpResourceId(type_id, instance_str);
302 side_effects.SetResourceId(resource_id);
303 UpdateSideEffectsByResourceId(side_effects,
304 side_effects_by_resource_id);
305 if (ResourceEffects::IsOnlySelfDependent(type_id)) {
306 self_dependent_only_ids_.insert(resource_id);
307 }
308 }
309 }
310 }
311
312 // Get internal op resource ID from MLIR type ID and instance ID.
GetOpResourceId(TypeID type_id,std::string instance_str)313 ResourceId GetOpResourceId(TypeID type_id, std::string instance_str) {
314 auto emplace_result = type_instance_str_to_op_resource_id_.try_emplace(
315 std::make_pair(type_id.getAsOpaquePointer(), instance_str),
316 next_op_resource_id_);
317 // Increment type ID if we have encountered a new resource type.
318 if (emplace_result.second) ++next_op_resource_id_;
319 return emplace_result.first->second;
320 }
321
322 // We use [0, kMaxResourceId] for resource IDs returned by resource alias
323 // analysis and [kMaxResourceId + 1, ...] for resource IDs which we generate
324 // for op-based side effects.
325 const ResourceId kMaxResourceId =
326 std::numeric_limits<ResourceId>::max() / 2;
327 // Next available ID for op-based resources (resources not handled by resource
328 // alias analysis).
329 ResourceId next_op_resource_id_ = kMaxResourceId + 1;
330 // Maps (type ID, instance ID) pairs to internal IDs for op-based resources.
331 // Also see comment above. Instead of using TypeID directly we use its opaque
332 // pointer.
333 absl::node_hash_map<std::pair<const void*, std::string>, ResourceId>
334 type_instance_str_to_op_resource_id_;
335 // Used for faster callable resolution.
336 SymbolTableCollection symbol_table_collection_;
337 // Collect all op-based side effects here.
338 OpSideEffectMap op_side_effect_map_;
339 const SideEffectsByResourceId empty_side_effects_map_;
340
341 // Set of all resource IDs which only have dependencies to themselves, not to
342 // any other resource ID (including unknown resource ID).
343 llvm::SmallDenseSet<ResourceId, 8> self_dependent_only_ids_;
344 };
345
346 // Collects all op-based and value-based side effects for `op` per resource ID.
CollectSideEffectsByResourceId(Operation * op,const OpSideEffectCollector & op_side_effect_collector,const TF::ResourceAliasAnalysis::Info & alias_analysis)347 SideEffectsByResourceId CollectSideEffectsByResourceId(
348 Operation* op,
349 const OpSideEffectCollector& op_side_effect_collector,
350 const TF::ResourceAliasAnalysis::Info& alias_analysis) {
351 SideEffectsByResourceId side_effects_by_resource_id;
352 if (!MayHaveSideEffect(op)) return side_effects_by_resource_id;
353
354 if (isa<tf_device::LaunchOp, tf_device::ClusterOp, tf_executor::IslandOp,
355 tf_executor::GraphOp, IfRegionOp, CaseRegionOp, WhileRegionOp>(op)) {
356 // For ops that are side-effecting only if their attached regions are,
357 // collect effects for all ops in the regions instead of collecting effects
358 // for the op itself. This is important to avoid conservatism and to find
359 // resource variable accesses in regions which are not exposed to the op
360 // interface.
361 for (Region& region : op->getRegions()) {
362 for (Operation& region_op : region.front().without_terminator()) {
363 SideEffectsByResourceId region_op_effects =
364 CollectSideEffectsByResourceId(
365 ®ion_op,
366 op_side_effect_collector,
367 alias_analysis);
368 for (const auto& [resource_id, side_effect] : region_op_effects) {
369 UpdateSideEffectsByResourceId(side_effect,
370 side_effects_by_resource_id);
371 }
372 }
373 }
374 return side_effects_by_resource_id;
375 }
376
377 // Copy op-based side effects.
378 side_effects_by_resource_id =
379 op_side_effect_collector.GetSideEffectsForOp(op);
380 bool found_any_effect = !side_effects_by_resource_id.empty();
381
382 // Collect value-based side effects from op interface.
383 llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
384 auto interface = dyn_cast<MemoryEffectOpInterface>(op);
385 if (interface) interface.getEffects(effects);
386
387 llvm::SmallDenseSet<Value, 8> processed_values;
388 for (const auto& effect : effects) {
389 Value value = effect.getValue();
390 found_any_effect = true;
391
392 // We only collect value-based side effects here for which we can use
393 // resource alias analysis. Other side effects are treated as op-based
394 // side effects.
395 if (!ShouldUseResourceAliasAnalysis(effect)) continue;
396 if (value) processed_values.insert(value);
397
398 TypeID type_id = effect.getResource()->getResourceID();
399 if (ResourceEffects::IsOnlySelfDependent(type_id)) {
400 // For value-based side effects we currently treat resource types that are
401 // only self-dependent conservatively, i.e., we do add dependencies
402 // to/from unknown resource types. Currently, we don't have such cases and
403 // there is no indication that we will need to support them in the future.
404 LOG(WARNING) << "Self-dependent-only resource types are treated "
405 "conservatively for value-based side effects.";
406 }
407
408 // Add side effects for every potentially accessed resource ID.
409 SideEffects side_effects(GetSideEffectsFromEffectInstance(effect, op));
410 const auto& ids = GetResourceUniqueIdsOrUnknown(value, alias_analysis);
411 for (ResourceId id : ids) {
412 side_effects.SetResourceId(id);
413 UpdateSideEffectsByResourceId(side_effects, side_effects_by_resource_id);
414 }
415 }
416
417 auto add_remaining_effects = [&](auto resource_values) {
418 for (Value resource_value : resource_values) {
419 // If we already processed this value before, skip it.
420 if (processed_values.count(resource_value) > 0) continue;
421 found_any_effect = true;
422
423 // Conservatively set unknown effect.
424 SideEffects unknown_effect;
425 unknown_effect.SetUnknownEffect();
426
427 // Add side effects for every potentially accessed resource ID.
428 const auto& ids =
429 GetResourceUniqueIdsOrUnknown(resource_value, alias_analysis);
430 for (ResourceId id : ids) {
431 unknown_effect.SetResourceId(id);
432 UpdateSideEffectsByResourceId(unknown_effect,
433 side_effects_by_resource_id);
434 }
435 }
436 };
437 // Add value-based side effects for resource values which are not covered by
438 // any side effect so far, for example, resource values being passed to
439 // `tf.While` or `tf.If` ops which are not part of the op definition but
440 // appear in a variadic input list.
441 add_remaining_effects(filter_resources(op->getOperands()));
442 add_remaining_effects(filter_resources(op->getResults()));
443
444 if (!found_any_effect) {
445 // We haven't collected any side effect but the op is potentially
446 // side-effecting (otherwise we would have returned), therefore we have an
447 // unknown side effect for an unknown resource.
448 SideEffects unknown_effect;
449 unknown_effect.SetUnknownEffect();
450 unknown_effect.SetResourceId(kUnknownResourceId);
451 UpdateSideEffectsByResourceId(unknown_effect,
452 side_effects_by_resource_id);
453 }
454 return side_effects_by_resource_id;
455 }
456
457 //===----------------------------------------------------------------------===//
458 // SideEffectAnalysisInfo
459 //===----------------------------------------------------------------------===//
460
AddPredecessorsForAccess(ResourceId resource_id,Operation * op,bool read_only)461 void SideEffectAnalysisInfo::AddPredecessorsForAccess(ResourceId resource_id,
462 Operation* op,
463 bool read_only) {
464 VLOG(2) << " Adding predecessors for resource " << resource_id;
465 auto it = per_resource_access_info_.find(resource_id);
466 if (it == per_resource_access_info_.end()) return;
467 const auto& access_info = it->getSecond();
468
469 auto& control_predecessors = control_predecessors_[op];
470 bool is_last_write_indirectly_tracked = false;
471 if (!read_only) {
472 // Add reads after last write as predecessors.
473 control_predecessors.insert(access_info.reads_since_last_write.begin(),
474 access_info.reads_since_last_write.end());
475 // Last write is indirectly tracked by any read predecessor we added.
476 is_last_write_indirectly_tracked =
477 !access_info.reads_since_last_write.empty();
478 }
479 if (access_info.last_write && !is_last_write_indirectly_tracked) {
480 // Add last write as predecessor.
481 control_predecessors.insert(access_info.last_write);
482 }
483 }
484
UpdateAccess(ResourceId resource_id,Operation * op,bool read_only)485 void SideEffectAnalysisInfo::UpdateAccess(ResourceId resource_id,
486 Operation* op,
487 bool read_only) {
488 VLOG(2) << " Updating access for resource " << resource_id;
489 op_to_resource_ids_[op].push_back({resource_id, read_only});
490 if (resource_id == kUnknownResourceId) {
491 if (read_only) {
492 // New unknown read is not tracked by any known resource access.
493 for (auto& entry : per_resource_access_info_) {
494 entry.getSecond().are_last_unknown_reads_tracked = false;
495 }
496 } else {
497 // Unknown write can clear all other tracked information, since it acts
498 // like a barrier.
499 per_resource_access_info_.clear();
500 }
501 }
502 auto& access_info = per_resource_access_info_[resource_id];
503 if (read_only) {
504 access_info.reads_since_last_write.push_back(op);
505 // Last unknown write is indirectly tracked by this read (we have added the
506 // write as a predecessor for `op` before).
507 access_info.is_last_unknown_write_tracked = true;
508 } else {
509 access_info.last_write = op;
510 access_info.reads_since_last_write.clear();
511 // Last unknown read(s) and write are indirectly tracked by this write (we
512 // have added the read(s) and write as predecessors for `op` before).
513 access_info.are_last_unknown_reads_tracked = true;
514 access_info.is_last_unknown_write_tracked = true;
515 access_info.is_last_unknown_write_tracked_by_write = true;
516 }
517 }
518
AnalyzeFunction(func::FuncOp func_op)519 void SideEffectAnalysisInfo::AnalyzeFunction(func::FuncOp func_op) {
520 // AnalyzeRegion() recursively analyzes the function body, and only populates
521 // control_predecessors_.
522 AnalyzeRegion(&func_op.getBody());
523 // Populate sorted_control_predecessors_ and sorted_control_successors_ based
524 // on control_predecessors.
525 for (auto& entry : control_predecessors_) {
526 auto op = entry.getFirst();
527 auto& predecessors = entry.getSecond();
528 auto& sorted_predecessors = sorted_control_predecessors_[op];
529 for (Operation* predecessor : predecessors) {
530 sorted_predecessors.push_back(predecessor);
531 sorted_control_successors_[predecessor].push_back(op);
532 }
533 }
534 control_predecessors_.clear();
535 for (auto& entry : sorted_control_predecessors_) {
536 llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
537 return a->isBeforeInBlock(b);
538 });
539 }
540 for (auto& entry : sorted_control_successors_) {
541 llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
542 return a->isBeforeInBlock(b);
543 });
544 }
545
546 // Populate the control sinks (i.e. side-effecting ops with no control
547 // successors) in the top level block.
548 for (const auto& entry : sorted_control_predecessors_) {
549 auto* op = entry.getFirst();
550 if (op->getBlock() == &func_op.front() &&
551 sorted_control_successors_.count(op) == 0) {
552 sorted_control_sinks_.push_back(op);
553 }
554 }
555 llvm::sort(sorted_control_sinks_, [](Operation* a, Operation* b) {
556 return a->isBeforeInBlock(b);
557 });
558 }
559
AnalyzeRegion(Region * region)560 void SideEffectAnalysisInfo::AnalyzeRegion(Region* region) {
561 // We explicitly iterate through the regions and blocks in order to handle
562 // different nested regions separately.
563 for (Block& block : *region) {
564 for (Operation& op : block) {
565 for (Region& child_region : op.getRegions()) {
566 SideEffectAnalysisInfo child_analysis(
567 &child_region, op_side_effect_collector_, alias_analysis_);
568 // Move data from `child_analysis` to current region.
569 for (auto& entry : child_analysis.control_predecessors_)
570 control_predecessors_[entry.first] = std::move(entry.second);
571 for (auto& entry : child_analysis.op_to_resource_ids_)
572 op_to_resource_ids_[entry.first] = std::move(entry.second);
573 }
574 AnalyzeOp(&op);
575 }
576 }
577 }
578
579 ResourceIdSet
GetConflictingIds(ResourceId resource_id,bool is_fetch_op) const580 SideEffectAnalysisInfo::GetConflictingIds(ResourceId resource_id,
581 bool is_fetch_op) const {
582 ResourceIdSet conflicting_ids;
583 if (resource_id == kUnknownResourceId) {
584 // Unknown resource has potential conflict with all other resources, except
585 // those that are only self-dependent. For `Fetch` op make every resource
586 // conflicting in any case to ensure that all side-effecting ops in
587 // `Graph` feed into `Fetch` (its terminator).
588 for (auto& entry : per_resource_access_info_) {
589 ResourceId other_id = entry.getFirst();
590 if (!op_side_effect_collector_.IsOnlySelfDependent(other_id) ||
591 is_fetch_op)
592 conflicting_ids.insert(other_id);
593 }
594 } else {
595 conflicting_ids.insert(resource_id);
596 // Resource has potential conflict with unknown resource, if not only
597 // self-dependent.
598 if (!op_side_effect_collector_.IsOnlySelfDependent(resource_id))
599 conflicting_ids.insert(kUnknownResourceId);
600 }
601 return conflicting_ids;
602 }
603
AnalyzeOp(Operation * op)604 void SideEffectAnalysisInfo::AnalyzeOp(Operation* op) {
605 VLOG(2) << "Processing op " << mlir::debugString(*op);
606 SideEffectsByResourceId side_effects_by_resource_id =
607 CollectSideEffectsByResourceId(
608 op,
609 op_side_effect_collector_,
610 alias_analysis_);
611
612 // If the side-effecting op is a control source (i.e. it has no control
613 // predecessors), then `control_predecessors_` won't be updated below.
614 // However, we still want to track this op as it may have side effects visible
615 // to ops outside the function.
616 if (!side_effects_by_resource_id.empty()) control_predecessors_[op];
617
618 // Traverse all resource IDs and their associated side effects.
619 bool had_unknown_resource_read = false;
620 for (auto pair : side_effects_by_resource_id) {
621 ResourceId resource_id = pair.first;
622 const SideEffects& side_effects = pair.second;
623 const bool read_only = side_effects.IsReadOnly();
624 VLOG(2) << " Processing resource ID: " << resource_id
625 << ", read-only effect: " << read_only;
626 // An op that only allocates a resource is expected to return a handle that
627 // is used by all other accesses of the same resource. That means, other ops
628 // that access the same resource already have a data dependency on the
629 // allocating op so it doesn't need any control predecessors or successors.
630 if (side_effects.IsAllocOnly()) continue;
631 // Effect is dominated by previous unknown resource read effect.
632 if (read_only && had_unknown_resource_read) continue;
633
634 ResourceIdSet conflicting_ids = GetConflictingIds(
635 resource_id, isa<tf_executor::FetchOp>(op));
636
637 // Add predecessors for conflicting IDs.
638 bool is_unknown_access_indirectly_tracked = false;
639 for (ResourceId id : conflicting_ids) {
640 // Handle unknown resource later, access might already be indirectly
641 // tracked by another resource access.
642 if (id == kUnknownResourceId) continue;
643
644 AddPredecessorsForAccess(id, op, read_only);
645 is_unknown_access_indirectly_tracked |=
646 IsUnknownAccessIndirectlyTrackedByResource(id, read_only);
647 }
648 // Add predecessors for unknown resource if necessary.
649 if (conflicting_ids.contains(kUnknownResourceId) &&
650 !is_unknown_access_indirectly_tracked)
651 AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
652 // Update resource access.
653 UpdateAccess(resource_id, op, read_only);
654
655 // If this effect dominates all other possible effects, return here. Note
656 // that if there is any effect for an unknown resource, then we encounter it
657 // in the first iteration since `kUnknownResourceId` is smaller than all
658 // other resource IDs.
659 if (resource_id == kUnknownResourceId && !read_only) return;
660 if (resource_id == kUnknownResourceId && read_only) {
661 had_unknown_resource_read = true;
662 }
663 }
664 }
665
IsUnknownAccessIndirectlyTrackedByResource(ResourceId resource_id,bool read_only)666 bool SideEffectAnalysisInfo::IsUnknownAccessIndirectlyTrackedByResource(
667 ResourceId resource_id, bool read_only) {
668 auto it = per_resource_access_info_.find(resource_id);
669 if (it == per_resource_access_info_.end()) return false;
670 auto access_info = it->getSecond();
671
672 auto unknown_it = per_resource_access_info_.find(kUnknownResourceId);
673 if (unknown_it == per_resource_access_info_.end()) return true;
674 auto unknown_access_info = unknown_it->getSecond();
675
676 bool no_unknown_read = unknown_access_info.reads_since_last_write.empty();
677 bool no_unknown_write = (unknown_access_info.last_write == nullptr);
678
679 // For the read-only case we only need that the last unknown write is already
680 // tracked by the last `resource` write since we don't have dependencies to
681 // any other read accesses.
682 // Otherwise, we need that the last unknown read(s) and write are already
683 // tracked by any read or write accesses of `resource`.
684 bool is_tracked = read_only ?
685 no_unknown_write || access_info.is_last_unknown_write_tracked_by_write :
686 (no_unknown_write || access_info.is_last_unknown_write_tracked) &&
687 (no_unknown_read || access_info.are_last_unknown_reads_tracked);
688 if (is_tracked) {
689 VLOG(2) << " Unknown access indirectly tracked by resource "
690 << resource_id;
691 }
692 return is_tracked;
693 }
694
695 llvm::SmallVector<Operation*, 4>
DirectControlPredecessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const696 SideEffectAnalysisInfo::DirectControlPredecessors(
697 Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
698 llvm::SmallVector<Operation*, 4> result;
699 auto it = sorted_control_predecessors_.find(op);
700 if (it == sorted_control_predecessors_.end()) return result;
701 result.reserve(it->getSecond().size());
702 for (auto predecessor : it->getSecond()) {
703 if (!filter || filter(predecessor)) result.push_back(predecessor);
704 }
705 return result;
706 }
707
708 llvm::SmallVector<Operation*, 4>
DirectControlSuccessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const709 SideEffectAnalysisInfo::DirectControlSuccessors(
710 Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
711 llvm::SmallVector<Operation*, 4> result;
712 auto it = sorted_control_successors_.find(op);
713 if (it == sorted_control_successors_.end()) return result;
714 result.reserve(it->getSecond().size());
715 for (auto successor : it->getSecond()) {
716 if (!filter || filter(successor)) result.push_back(successor);
717 }
718 return result;
719 }
720
721 const llvm::SmallVector<std::pair<ResourceId, bool>>&
GetResourceIds(Operation * op) const722 SideEffectAnalysisInfo::GetResourceIds(Operation* op) const {
723 auto it = op_to_resource_ids_.find(op);
724 if (it == op_to_resource_ids_.end()) return empty_resource_ids_;
725 return it->getSecond();
726 }
727
728 } // namespace detail
729
SideEffectAnalysis(ModuleOp module)730 SideEffectAnalysis::SideEffectAnalysis(ModuleOp module)
731 // Analyze entire module for alias analysis info.
732 : alias_analysis_(module) {
733 // Collect op-based side effects for entire module.
734 detail::OpSideEffectCollector op_side_effect_collector(module);
735
736 // Analyze side effects for all functions in module.
737 for (auto func : module.getOps<func::FuncOp>())
738 this->info_map_.try_emplace(func, func,
739 op_side_effect_collector,
740 alias_analysis_.GetAnalysisForFunc(func));
741 }
742
743 } // namespace TF
744 } // namespace mlir
745