• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
17 
18 #include <cstdint>
19 #include <initializer_list>
20 
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/iterator_range.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Block.h"  // from @llvm-project
32 #include "mlir/IR/Builders.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/Location.h"  // from @llvm-project
36 #include "mlir/IR/Operation.h"  // from @llvm-project
37 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
40 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
41 #include "mlir/Support/LLVM.h"  // from @llvm-project
42 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
48 
49 namespace mlir {
50 namespace TF {
51 namespace {
52 
53 constexpr auto kUnknownResourceId =
54     ResourceAliasAnalysis::Info::kUnknownResourceId;
55 
56 //===----------------------------------------------------------------------===//
57 // SideEffectAnalysisInfo helper functions.
58 //===----------------------------------------------------------------------===//
59 
60 // Returns a set that contains only kUnknownResourceId.
UnknownResourceSet()61 llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
62   llvm::SmallDenseSet<int64_t, 8> unknown_set;
63   unknown_set.insert(kUnknownResourceId);
64   return unknown_set;
65 }
66 
67 // Returns all resources that could be accessed by op, or UnknownResourceSet()
68 // if we cannot find all of them.
FindAccessedResources(Operation * op,const ResourceAliasAnalysis::Info & alias_analysis)69 llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
70     Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) {
71   VLOG(1) << "Find accessed resources for: " << debugString(*op);
72   llvm::SmallDenseSet<int64_t, 8> resources;
73 
74   for (auto operand : filter_resources(op->getOperands())) {
75     if (alias_analysis.IsUnknownResource(operand)) {
76       VLOG(1) << "\tunknown resource in operand";
77       return UnknownResourceSet();
78     }
79     const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
80     resources.insert(ids.begin(), ids.end());
81   }
82   for (auto result : filter_resources(op->getResults())) {
83     if (alias_analysis.IsUnknownResource(result)) {
84       VLOG(1) << "\tunknown resource in result";
85       return UnknownResourceSet();
86     }
87     const auto& ids = alias_analysis.GetResourceUniqueIds(result);
88     resources.insert(ids.begin(), ids.end());
89   }
90   return resources;
91 }
92 
93 // Helper struct defining what memory effects are present for a resource.
94 struct SideEffects {
95   bool alloc = false;
96   bool free = false;
97   bool read = false;
98   bool write = false;
99 
IsAllocOnlymlir::TF::__anon6d956f0b0111::SideEffects100   bool IsAllocOnly() const { return alloc && !free && !read && !write; }
IsReadOnlymlir::TF::__anon6d956f0b0111::SideEffects101   bool IsReadOnly() const { return !alloc && !free && read && !write; }
102 };
103 
104 using SideEffectsByValue = llvm::SmallDenseMap<Value, SideEffects>;
105 
MustExecute(const MemoryEffects::EffectInstance & effect)106 bool MustExecute(const MemoryEffects::EffectInstance& effect) {
107   VLOG(1) << "MustExecute check with: "
108           << std::string(effect.getResource()->getName());
109   if (llvm::isa<ResourceEffects::TPUEmbedding>(effect.getResource())) {
110     assert(!effect.getValue() && !effect.getParameters() &&
111            isa<MemoryEffects::Write>(effect.getEffect()));
112     return true;
113   }
114   return false;
115 }
116 
117 // Collects memory side effects for an operation by value (operands and
118 // results).
GetSideEffectsByValue(Operation * op,SideEffectsByValue & side_effects_by_value,bool & must_execute)119 void GetSideEffectsByValue(Operation* op,
120                            SideEffectsByValue& side_effects_by_value,
121                            bool& must_execute) {
122   VLOG(1) << "Querying for " << mlir::debugString(*op);
123   auto interface = dyn_cast<MemoryEffectOpInterface>(op);
124   if (!interface) return;
125 
126   llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
127   interface.getEffects(effects);
128 
129   for (auto& effect : effects) {
130     if (MustExecute(effect)) {
131       VLOG(1) << "\tmust execute";
132       must_execute = true;
133       continue;
134     }
135 
136     // TODO(lyandy): Support effects with no value defined.
137     if (!effect.getValue()) {
138       VLOG(1) << "\teffect with no value, skipping";
139       side_effects_by_value.clear();
140       must_execute = false;
141       return;
142     }
143     auto it = side_effects_by_value.try_emplace(effect.getValue());
144     auto& side_effect = it.first->getSecond();
145     auto* resource_effect = effect.getEffect();
146     if (isa<MemoryEffects::Allocate>(resource_effect)) {
147       VLOG(1) << "\tallocate effect";
148       side_effect.alloc = true;
149     } else if (isa<MemoryEffects::Free>(resource_effect)) {
150       VLOG(1) << "\tfree effect";
151       side_effect.free = true;
152     } else if (isa<MemoryEffects::Read>(resource_effect)) {
153       VLOG(1) << "\tread effect";
154       side_effect.read = true;
155     } else if (isa<MemoryEffects::Write>(resource_effect)) {
156       VLOG(1) << "\twrite effect";
157       side_effect.write = true;
158     } else {
159       VLOG(1) << "\tunknown effect, skipping";
160       side_effects_by_value.clear();
161       must_execute = false;
162       return;
163     }
164   }
165 }
166 
167 // Checks if a value is a result of `op`.
IsOperationResult(Operation * op,Value value)168 bool IsOperationResult(Operation* op, Value value) {
169   return value.getDefiningOp() == op;
170 }
171 
172 // Checks if an operation's resource operands are read only. Operation results
173 // are ignored.
IsResourceOpReadOnly(Operation * op,const SideEffectsByValue & side_effects_by_value)174 bool IsResourceOpReadOnly(Operation* op,
175                           const SideEffectsByValue& side_effects_by_value) {
176   if (side_effects_by_value.empty()) return false;
177 
178   for (const auto& value_side_effect : side_effects_by_value) {
179     Value value = value_side_effect.getFirst();
180     if (IsOperationResult(op, value)) continue;
181     const SideEffects& side_effects = value_side_effect.getSecond();
182     if (!side_effects.IsReadOnly()) return false;
183   }
184 
185   return true;
186 }
187 
188 // Checks if an operation's resource results are alloc only and no side effects
189 // are present for its operands.
IsResourceOpAllocOnly(Operation * op,const SideEffectsByValue & side_effects_by_value)190 bool IsResourceOpAllocOnly(Operation* op,
191                            const SideEffectsByValue& side_effects_by_value) {
192   if (side_effects_by_value.empty()) return false;
193 
194   for (const auto& value_side_effect : side_effects_by_value) {
195     // Operand with side effect.
196     Value value = value_side_effect.getFirst();
197     if (!IsOperationResult(op, value)) return false;
198     const SideEffects& side_effects = value_side_effect.getSecond();
199     if (!side_effects.IsAllocOnly()) return false;
200   }
201 
202   return true;
203 }
204 
205 // Returns if `op` is a resource declaration.
OpIsDeclaration(Operation * op,const ResourceAliasAnalysis::Info & alias_analysis)206 bool OpIsDeclaration(Operation* op,
207                      const ResourceAliasAnalysis::Info& alias_analysis) {
208   return llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
209          !FindAccessedResources(op, alias_analysis).empty();
210 }
211 
212 // A vector of resource variable id's with their associated resource value.
213 using ResourceIdsByValue =
214     llvm::SmallVector<std::pair<Value, const llvm::SmallSet<int64_t, 8>*>, 4>;
215 
216 // Collects resource id's by resource value. If operation resource side effects
217 // are unknown or a resource is unknown, an empty optional is returned.
GetResourceIdsByValue(Operation * op,const ResourceAliasAnalysis::Info & alias_analysis,const SideEffectsByValue & side_effects_by_value)218 llvm::Optional<ResourceIdsByValue> GetResourceIdsByValue(
219     Operation* op, const ResourceAliasAnalysis::Info& alias_analysis,
220     const SideEffectsByValue& side_effects_by_value) {
221   ResourceIdsByValue resource_ids_by_value;
222   if (side_effects_by_value.empty()) return llvm::None;
223 
224   // Returns true iff all side-effect-related values are known to
225   // `alias_analysis`.
226   auto collect_ids = [&](ValueRange values) {
227     for (auto value : values) {
228       // Value is not related to any side-effect, skip.
229       if (side_effects_by_value.count(value) == 0) continue;
230       // Value is not a resource variable, thus not known to `alias_analysis`.
231       if (!getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>())
232         return false;
233       // Value is a resource variable not known to `alias_analysis`.
234       if (alias_analysis.IsUnknownResource(value)) return false;
235       // Value is a resource variable known to `alias_analysis`.
236       const auto& ids = alias_analysis.GetResourceUniqueIds(value);
237       resource_ids_by_value.push_back({value, &ids});
238     }
239     return true;
240   };
241 
242   if (collect_ids(op->getOperands()) && collect_ids(op->getResults()))
243     // No unknown side-effect-related values.
244     return resource_ids_by_value;
245   else
246     return llvm::None;
247 }
248 
249 // Returns true if `op` is known to not have any side effect.
OpIsKnownToHaveNoSideEffect(Operation * op)250 bool OpIsKnownToHaveNoSideEffect(Operation* op) {
251   // For op's in the Tensorflow dialect, query the dialect.
252   if (isa_and_nonnull<TF::TensorFlowDialect>(op->getDialect()))
253     return !TensorFlowDialect::CanHaveSideEffects(op);
254 
255   // Otherwise, conservatively assume that there can be side effects.
256   return false;
257 }
258 
259 }  // namespace
260 
261 namespace detail {
262 //===----------------------------------------------------------------------===//
263 // SideEffectAnalysisInfo
264 //===----------------------------------------------------------------------===//
265 
TrackAccess(int64_t resource_id,Operation * op,bool read_only)266 void SideEffectAnalysisInfo::TrackAccess(int64_t resource_id, Operation* op,
267                                          bool read_only) {
268   VLOG(1) << "TrackAccess for " << debugString(*op);
269   if (resource_id == kUnknownResourceId) {
270     VLOG(1) << "\tunknown resource id";
271     if (read_only) {
272       // New unknown read is not tracked by any known resource access.
273       for (auto& entry : per_resource_access_info_) {
274         entry.getSecond().tracked_last_unknown_read = false;
275       }
276     } else {
277       // Unknown write can clear all other tracked information, since it acts
278       // like a barrier.
279       VLOG(1) << "\tclearing per resource access info";
280       per_resource_access_info_.clear();
281     }
282   }
283   VLOG(1) << "\tinfo for " << resource_id;
284   auto& info = per_resource_access_info_[resource_id];
285   if (read_only) {
286     info.reads_since_last_write.push_back(op);
287     // Resource read must have carried control dependencies of unknown write. It
288     // can only avoid adding control edges (from uknown accesses) for a later
289     // write, but not for a later read, because this read can be reordered with
290     // a later read.
291     info.tracked_last_unknown_write_for_write = true;
292   } else {
293     // Resource write must have carried control dependencies of unknown access.
294     info.tracked_last_unknown_write_for_read = true;
295     info.tracked_last_unknown_write_for_write = true;
296     info.tracked_last_unknown_read = true;
297     info.last_write = op;
298     info.reads_since_last_write.clear();
299   }
300 }
301 
AddPredecessorsForAccess(int64_t resource_id,Operation * op,bool read_only)302 void SideEffectAnalysisInfo::AddPredecessorsForAccess(int64_t resource_id,
303                                                       Operation* op,
304                                                       bool read_only) {
305   VLOG(1) << "Adding predecessors for resource " << resource_id << " and op "
306           << debugString(*op);
307   auto it = per_resource_access_info_.find(resource_id);
308   if (it == per_resource_access_info_.end()) return;
309   const auto& access_info = it->getSecond();
310   auto& control_predecessors = control_predecessors_[op];
311   bool read_tracked = false;
312   if (!read_only) {
313     control_predecessors.insert(access_info.reads_since_last_write.begin(),
314                                 access_info.reads_since_last_write.end());
315     read_tracked = !access_info.reads_since_last_write.empty();
316   }
317   if (access_info.last_write && !read_tracked) {
318     control_predecessors.insert(access_info.last_write);
319   }
320 }
321 
AnalyzeFunction(FuncOp func_op,const TF::ResourceAliasAnalysis::Info & alias_analysis)322 void SideEffectAnalysisInfo::AnalyzeFunction(
323     FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
324   // AnalyzeRegion() recursively analyzes the function body, and only populates
325   // control_predecessors_.
326   AnalyzeRegion(&func_op.getBody(), alias_analysis);
327   // Populate sorted_control_predecessors_ and sorted_control_successors_ based
328   // on control_predecessors.
329   for (auto& entry : control_predecessors_) {
330     auto op = entry.getFirst();
331     auto& sorted_predecessors = sorted_control_predecessors_[op];
332     for (auto predecessor : entry.getSecond()) {
333       sorted_predecessors.push_back(predecessor);
334       sorted_control_successors_[predecessor].push_back(op);
335     }
336   }
337   control_predecessors_.clear();
338   for (auto& entry : sorted_control_predecessors_) {
339     llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
340       return a->isBeforeInBlock(b);
341     });
342   }
343   for (auto& entry : sorted_control_successors_) {
344     llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
345       return a->isBeforeInBlock(b);
346     });
347   }
348 }
349 
AnalyzeRegion(Region * region,const TF::ResourceAliasAnalysis::Info & alias_analysis)350 void SideEffectAnalysisInfo::AnalyzeRegion(
351     Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
352   // This function populates control_predecessors_ by walking through the
353   // region, and tracking resource accesses in per_resource_access_info_.
354 
355   // Returns whether an access to `resource` can skip control edges from
356   // previous accesses to unknown resources, due to that earlier accesses to
357   // `resource` already indirectly tracked previous accesses to unknown
358   // resources. `read_only` specifies the type of access of the current op being
359   // considered.
360   auto unknown_access_indirectly_tracked_by_resource = [&](int64_t resource,
361                                                            bool read_only) {
362     VLOG(1) << "\tunknown access indirectly tracked by resource " << resource;
363     auto it = per_resource_access_info_.find(resource);
364     if (it == per_resource_access_info_.end()) {
365       VLOG(1) << "\t\tnot found";
366       return false;
367     }
368     auto unknown_it = per_resource_access_info_.find(kUnknownResourceId);
369     const bool no_unknown_read =
370         unknown_it == per_resource_access_info_.end() ||
371         unknown_it->getSecond().reads_since_last_write.empty();
372     bool ret = read_only ? it->second.tracked_last_unknown_write_for_read
373                          : it->second.tracked_last_unknown_write_for_write &&
374                                (it->second.tracked_last_unknown_read ||
375                                 no_unknown_read);
376     VLOG(1) << "\t\tunknown access inderictly tracked by resource: " << ret;
377     return ret;
378   };
379 
380   // We explicitly iterates through the regions and blocks, in order to handle
381   // different nested regions separately.
382   for (auto& block : *region) {
383     llvm::SmallPtrSet<Operation*, 8> non_resource_control_predecessors;
384     for (auto& op : block) {
385       for (Region& child : op.getRegions()) {
386         SideEffectAnalysisInfo child_analysis(&child, alias_analysis);
387         // Moves the control_predecessors_ fields in child region to current
388         // region
389         for (auto& entry : child_analysis.control_predecessors_)
390           control_predecessors_[entry.first] = std::move(entry.second);
391       }
392 
393       // We do not need explicit control edges for declaration ops.
394       if (OpIsDeclaration(&op, alias_analysis)) continue;
395 
396       SideEffectsByValue side_effects_by_value;
397       bool must_execute = false;
398       GetSideEffectsByValue(&op, side_effects_by_value, must_execute);
399 
400       if (side_effects_by_value.empty() && OpIsKnownToHaveNoSideEffect(&op))
401         continue;
402 
403       // TODO(jpienaar): This only currently uses unknown when not per value
404       // resource is used.
405       if (side_effects_by_value.empty() && must_execute) {
406         VLOG(1) << "No resources & must execute: " << debugString(op);
407         // Add unknown resource ops as predecessors of the op that must execute,
408         // to guarantee ordering between unknown resource ops.
409         AddPredecessorsForAccess(kUnknownResourceId, &op, /*read_only=*/false);
410         non_resource_control_predecessors.insert(&op);
411         continue;
412       }
413 
414       if (IsResourceOpAllocOnly(&op, side_effects_by_value)) {
415         VLOG(1) << "Resource alloc only: " << debugString(op);
416         continue;
417       }
418 
419       auto resource_ids_by_value =
420           GetResourceIdsByValue(&op, alias_analysis, side_effects_by_value);
421       const bool read_only = IsResourceOpReadOnly(&op, side_effects_by_value);
422       bool indirectly_tracked_unknown_access = false;
423       // First add edges from known resources.
424       if (!resource_ids_by_value.hasValue()) {
425         VLOG(1) << "Resource not by value: " << debugString(op);
426         for (auto& entry : per_resource_access_info_) {
427           if (entry.getFirst() == kUnknownResourceId) {
428             VLOG(1) << "\tskipping over unknown resource id";
429             continue;
430           }
431           AddPredecessorsForAccess(entry.getFirst(), &op, read_only);
432           indirectly_tracked_unknown_access |=
433               unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
434                                                             read_only);
435         }
436       } else {
437         // Collect all resource id's and whether their side effect is read only.
438         llvm::SmallDenseMap<int64_t, bool> read_only_by_resource_id;
439         for (const auto& resource_ids : *resource_ids_by_value) {
440           const bool is_result = resource_ids.first.getDefiningOp() == &op;
441           auto value_side_effect =
442               side_effects_by_value.find(resource_ids.first);
443           bool resource_read_only = false;
444           if (value_side_effect != side_effects_by_value.end()) {
445             if (is_result && value_side_effect->getSecond().IsAllocOnly())
446               continue;
447             resource_read_only = value_side_effect->getSecond().IsReadOnly();
448           }
449 
450           for (const auto& id : *resource_ids.second) {
451             auto it =
452                 read_only_by_resource_id.try_emplace(id, resource_read_only);
453             if (!it.second && !resource_read_only)
454               it.first->getSecond() = resource_read_only;
455           }
456         }
457 
458         for (const auto& resource : read_only_by_resource_id) {
459           const auto& resource_id = resource.getFirst();
460           const auto& resource_read_only = resource.getSecond();
461           AddPredecessorsForAccess(resource_id, &op, resource_read_only);
462           indirectly_tracked_unknown_access |=
463               unknown_access_indirectly_tracked_by_resource(resource_id,
464                                                             resource_read_only);
465           // Update access info for known resources.
466           TrackAccess(resource_id, &op, resource_read_only);
467         }
468       }
469 
470       // If not indirectly tracked, add edges from the resource.
471       if (!indirectly_tracked_unknown_access) {
472         VLOG(1) << "Not indirectly tracked with unknown access: "
473                 << debugString(op);
474         if (auto interface = dyn_cast<MemoryEffectOpInterface>(op)) {
475           llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
476           interface.getEffects(effects);
477         }
478         AddPredecessorsForAccess(kUnknownResourceId, &op, read_only);
479       }
480       if (!resource_ids_by_value.hasValue()) {
481         VLOG(1) << "Indirectly tracked with no value: " << debugString(op);
482 
483         // Update access info for unknown resource.
484         TrackAccess(kUnknownResourceId, &op, read_only);
485         // Add ops that must execute to unknown resource op predecessors.
486         auto& control_predecessors = control_predecessors_[&op];
487         control_predecessors.insert(non_resource_control_predecessors.begin(),
488                                     non_resource_control_predecessors.end());
489         // Ops that must execute currently tracked are cleared as transitively
490         // unknown resource ops will allow for such ops to be transitively
491         // reachable.
492         non_resource_control_predecessors.clear();
493       }
494     }
495   }
496 }
497 
498 llvm::SmallVector<Operation*, 4>
DirectControlPredecessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const499 SideEffectAnalysisInfo::DirectControlPredecessors(
500     Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
501   llvm::SmallVector<Operation*, 4> result;
502   auto it = sorted_control_predecessors_.find(op);
503   if (it == sorted_control_predecessors_.end()) return result;
504   result.reserve(it->getSecond().size());
505   for (auto predecessor : it->getSecond()) {
506     if (!filter || filter(predecessor)) result.push_back(predecessor);
507   }
508   return result;
509 }
510 
511 llvm::SmallVector<Operation*, 4>
DirectControlSuccessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const512 SideEffectAnalysisInfo::DirectControlSuccessors(
513     Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
514   llvm::SmallVector<Operation*, 4> result;
515   auto it = sorted_control_successors_.find(op);
516   if (it == sorted_control_successors_.end()) return result;
517   result.reserve(it->getSecond().size());
518   for (auto successor : it->getSecond()) {
519     if (!filter || filter(successor)) result.push_back(successor);
520   }
521   return result;
522 }
523 }  // namespace detail
524 
SideEffectAnalysis(ModuleOp module)525 SideEffectAnalysis::SideEffectAnalysis(ModuleOp module) {
526   // Analyze entire module for alias analysis info.
527   ResourceAliasAnalysis alias_analysis(module);
528 
529   // Analyze all functions.
530   for (auto func : module.getOps<FuncOp>())
531     this->info_map_.try_emplace(func, func,
532                                 alias_analysis.GetAnalysisForFunc(func));
533 }
534 
535 }  // namespace TF
536 }  // namespace mlir
537