1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
17
18 #include <cstdint>
19 #include <initializer_list>
20
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/iterator_range.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Block.h" // from @llvm-project
32 #include "mlir/IR/Builders.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35 #include "mlir/IR/Location.h" // from @llvm-project
36 #include "mlir/IR/Operation.h" // from @llvm-project
37 #include "mlir/IR/OperationSupport.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
40 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
41 #include "mlir/Support/LLVM.h" // from @llvm-project
42 #include "mlir/Support/LogicalResult.h" // from @llvm-project
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
48
49 namespace mlir {
50 namespace TF {
51 namespace {
52
// File-local shorthand for the alias analysis sentinel id that stands for a
// resource whose identity could not be determined (see UnknownResourceSet()).
constexpr auto kUnknownResourceId =
    ResourceAliasAnalysis::Info::kUnknownResourceId;
55
56 //===----------------------------------------------------------------------===//
57 // SideEffectAnalysisInfo helper functions.
58 //===----------------------------------------------------------------------===//
59
60 // Returns a set that contains only kUnknownResourceId.
UnknownResourceSet()61 llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
62 llvm::SmallDenseSet<int64_t, 8> unknown_set;
63 unknown_set.insert(kUnknownResourceId);
64 return unknown_set;
65 }
66
67 // Returns all resources that could be accessed by op, or UnknownResourceSet()
68 // if we cannot find all of them.
FindAccessedResources(Operation * op,const ResourceAliasAnalysis::Info & alias_analysis)69 llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
70 Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) {
71 VLOG(1) << "Find accessed resources for: " << debugString(*op);
72 llvm::SmallDenseSet<int64_t, 8> resources;
73
74 for (auto operand : filter_resources(op->getOperands())) {
75 if (alias_analysis.IsUnknownResource(operand)) {
76 VLOG(1) << "\tunknown resource in operand";
77 return UnknownResourceSet();
78 }
79 const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
80 resources.insert(ids.begin(), ids.end());
81 }
82 for (auto result : filter_resources(op->getResults())) {
83 if (alias_analysis.IsUnknownResource(result)) {
84 VLOG(1) << "\tunknown resource in result";
85 return UnknownResourceSet();
86 }
87 const auto& ids = alias_analysis.GetResourceUniqueIds(result);
88 resources.insert(ids.begin(), ids.end());
89 }
90 return resources;
91 }
92
// Flags for which memory effect kinds (allocate, free, read, write) were
// reported for a single value.
struct SideEffects {
  bool alloc{false};
  bool free{false};
  bool read{false};
  bool write{false};

  // True iff allocation is the only recorded effect.
  bool IsAllocOnly() const { return alloc && !(free || read || write); }
  // True iff reading is the only recorded effect.
  bool IsReadOnly() const { return read && !(alloc || free || write); }
};
103
// Memory effects of a single operation, keyed by the value each effect is
// attached to (populated by GetSideEffectsByValue).
using SideEffectsByValue = llvm::SmallDenseMap<Value, SideEffects>;
105
MustExecute(const MemoryEffects::EffectInstance & effect)106 bool MustExecute(const MemoryEffects::EffectInstance& effect) {
107 VLOG(1) << "MustExecute check with: "
108 << std::string(effect.getResource()->getName());
109 if (llvm::isa<ResourceEffects::TPUEmbedding>(effect.getResource())) {
110 assert(!effect.getValue() && !effect.getParameters() &&
111 isa<MemoryEffects::Write>(effect.getEffect()));
112 return true;
113 }
114 return false;
115 }
116
117 // Collects memory side effects for an operation by value (operands and
118 // results).
GetSideEffectsByValue(Operation * op,SideEffectsByValue & side_effects_by_value,bool & must_execute)119 void GetSideEffectsByValue(Operation* op,
120 SideEffectsByValue& side_effects_by_value,
121 bool& must_execute) {
122 VLOG(1) << "Querying for " << mlir::debugString(*op);
123 auto interface = dyn_cast<MemoryEffectOpInterface>(op);
124 if (!interface) return;
125
126 llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
127 interface.getEffects(effects);
128
129 for (auto& effect : effects) {
130 if (MustExecute(effect)) {
131 VLOG(1) << "\tmust execute";
132 must_execute = true;
133 continue;
134 }
135
136 // TODO(lyandy): Support effects with no value defined.
137 if (!effect.getValue()) {
138 VLOG(1) << "\teffect with no value, skipping";
139 side_effects_by_value.clear();
140 must_execute = false;
141 return;
142 }
143 auto it = side_effects_by_value.try_emplace(effect.getValue());
144 auto& side_effect = it.first->getSecond();
145 auto* resource_effect = effect.getEffect();
146 if (isa<MemoryEffects::Allocate>(resource_effect)) {
147 VLOG(1) << "\tallocate effect";
148 side_effect.alloc = true;
149 } else if (isa<MemoryEffects::Free>(resource_effect)) {
150 VLOG(1) << "\tfree effect";
151 side_effect.free = true;
152 } else if (isa<MemoryEffects::Read>(resource_effect)) {
153 VLOG(1) << "\tread effect";
154 side_effect.read = true;
155 } else if (isa<MemoryEffects::Write>(resource_effect)) {
156 VLOG(1) << "\twrite effect";
157 side_effect.write = true;
158 } else {
159 VLOG(1) << "\tunknown effect, skipping";
160 side_effects_by_value.clear();
161 must_execute = false;
162 return;
163 }
164 }
165 }
166
167 // Checks if a value is a result of `op`.
IsOperationResult(Operation * op,Value value)168 bool IsOperationResult(Operation* op, Value value) {
169 return value.getDefiningOp() == op;
170 }
171
172 // Checks if an operation's resource operands are read only. Operation results
173 // are ignored.
IsResourceOpReadOnly(Operation * op,const SideEffectsByValue & side_effects_by_value)174 bool IsResourceOpReadOnly(Operation* op,
175 const SideEffectsByValue& side_effects_by_value) {
176 if (side_effects_by_value.empty()) return false;
177
178 for (const auto& value_side_effect : side_effects_by_value) {
179 Value value = value_side_effect.getFirst();
180 if (IsOperationResult(op, value)) continue;
181 const SideEffects& side_effects = value_side_effect.getSecond();
182 if (!side_effects.IsReadOnly()) return false;
183 }
184
185 return true;
186 }
187
188 // Checks if an operation's resource results are alloc only and no side effects
189 // are present for its operands.
IsResourceOpAllocOnly(Operation * op,const SideEffectsByValue & side_effects_by_value)190 bool IsResourceOpAllocOnly(Operation* op,
191 const SideEffectsByValue& side_effects_by_value) {
192 if (side_effects_by_value.empty()) return false;
193
194 for (const auto& value_side_effect : side_effects_by_value) {
195 // Operand with side effect.
196 Value value = value_side_effect.getFirst();
197 if (!IsOperationResult(op, value)) return false;
198 const SideEffects& side_effects = value_side_effect.getSecond();
199 if (!side_effects.IsAllocOnly()) return false;
200 }
201
202 return true;
203 }
204
205 // Returns if `op` is a resource declaration.
OpIsDeclaration(Operation * op,const ResourceAliasAnalysis::Info & alias_analysis)206 bool OpIsDeclaration(Operation* op,
207 const ResourceAliasAnalysis::Info& alias_analysis) {
208 return llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
209 !FindAccessedResources(op, alias_analysis).empty();
210 }
211
// Pairs of (resource value, pointer to the set of unique resource ids that the
// value may refer to), as built by GetResourceIdsByValue. The pointers are not
// owned.
using ResourceIdsByValue =
    llvm::SmallVector<std::pair<Value, const llvm::SmallSet<int64_t, 8>*>, 4>;
215
216 // Collects resource id's by resource value. If operation resource side effects
217 // are unknown or a resource is unknown, an empty optional is returned.
GetResourceIdsByValue(Operation * op,const ResourceAliasAnalysis::Info & alias_analysis,const SideEffectsByValue & side_effects_by_value)218 llvm::Optional<ResourceIdsByValue> GetResourceIdsByValue(
219 Operation* op, const ResourceAliasAnalysis::Info& alias_analysis,
220 const SideEffectsByValue& side_effects_by_value) {
221 ResourceIdsByValue resource_ids_by_value;
222 if (side_effects_by_value.empty()) return llvm::None;
223
224 // Returns true iff all side-effect-related values are known to
225 // `alias_analysis`.
226 auto collect_ids = [&](ValueRange values) {
227 for (auto value : values) {
228 // Value is not related to any side-effect, skip.
229 if (side_effects_by_value.count(value) == 0) continue;
230 // Value is not a resource variable, thus not known to `alias_analysis`.
231 if (!getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>())
232 return false;
233 // Value is a resource variable not known to `alias_analysis`.
234 if (alias_analysis.IsUnknownResource(value)) return false;
235 // Value is a resource variable known to `alias_analysis`.
236 const auto& ids = alias_analysis.GetResourceUniqueIds(value);
237 resource_ids_by_value.push_back({value, &ids});
238 }
239 return true;
240 };
241
242 if (collect_ids(op->getOperands()) && collect_ids(op->getResults()))
243 // No unknown side-effect-related values.
244 return resource_ids_by_value;
245 else
246 return llvm::None;
247 }
248
249 // Returns true if `op` is known to not have any side effect.
OpIsKnownToHaveNoSideEffect(Operation * op)250 bool OpIsKnownToHaveNoSideEffect(Operation* op) {
251 // For op's in the Tensorflow dialect, query the dialect.
252 if (isa_and_nonnull<TF::TensorFlowDialect>(op->getDialect()))
253 return !TensorFlowDialect::CanHaveSideEffects(op);
254
255 // Otherwise, conservatively assume that there can be side effects.
256 return false;
257 }
258
259 } // namespace
260
261 namespace detail {
262 //===----------------------------------------------------------------------===//
263 // SideEffectAnalysisInfo
264 //===----------------------------------------------------------------------===//
265
TrackAccess(int64_t resource_id,Operation * op,bool read_only)266 void SideEffectAnalysisInfo::TrackAccess(int64_t resource_id, Operation* op,
267 bool read_only) {
268 VLOG(1) << "TrackAccess for " << debugString(*op);
269 if (resource_id == kUnknownResourceId) {
270 VLOG(1) << "\tunknown resource id";
271 if (read_only) {
272 // New unknown read is not tracked by any known resource access.
273 for (auto& entry : per_resource_access_info_) {
274 entry.getSecond().tracked_last_unknown_read = false;
275 }
276 } else {
277 // Unknown write can clear all other tracked information, since it acts
278 // like a barrier.
279 VLOG(1) << "\tclearing per resource access info";
280 per_resource_access_info_.clear();
281 }
282 }
283 VLOG(1) << "\tinfo for " << resource_id;
284 auto& info = per_resource_access_info_[resource_id];
285 if (read_only) {
286 info.reads_since_last_write.push_back(op);
287 // Resource read must have carried control dependencies of unknown write. It
288 // can only avoid adding control edges (from uknown accesses) for a later
289 // write, but not for a later read, because this read can be reordered with
290 // a later read.
291 info.tracked_last_unknown_write_for_write = true;
292 } else {
293 // Resource write must have carried control dependencies of unknown access.
294 info.tracked_last_unknown_write_for_read = true;
295 info.tracked_last_unknown_write_for_write = true;
296 info.tracked_last_unknown_read = true;
297 info.last_write = op;
298 info.reads_since_last_write.clear();
299 }
300 }
301
AddPredecessorsForAccess(int64_t resource_id,Operation * op,bool read_only)302 void SideEffectAnalysisInfo::AddPredecessorsForAccess(int64_t resource_id,
303 Operation* op,
304 bool read_only) {
305 VLOG(1) << "Adding predecessors for resource " << resource_id << " and op "
306 << debugString(*op);
307 auto it = per_resource_access_info_.find(resource_id);
308 if (it == per_resource_access_info_.end()) return;
309 const auto& access_info = it->getSecond();
310 auto& control_predecessors = control_predecessors_[op];
311 bool read_tracked = false;
312 if (!read_only) {
313 control_predecessors.insert(access_info.reads_since_last_write.begin(),
314 access_info.reads_since_last_write.end());
315 read_tracked = !access_info.reads_since_last_write.empty();
316 }
317 if (access_info.last_write && !read_tracked) {
318 control_predecessors.insert(access_info.last_write);
319 }
320 }
321
AnalyzeFunction(FuncOp func_op,const TF::ResourceAliasAnalysis::Info & alias_analysis)322 void SideEffectAnalysisInfo::AnalyzeFunction(
323 FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
324 // AnalyzeRegion() recursively analyzes the function body, and only populates
325 // control_predecessors_.
326 AnalyzeRegion(&func_op.getBody(), alias_analysis);
327 // Populate sorted_control_predecessors_ and sorted_control_successors_ based
328 // on control_predecessors.
329 for (auto& entry : control_predecessors_) {
330 auto op = entry.getFirst();
331 auto& sorted_predecessors = sorted_control_predecessors_[op];
332 for (auto predecessor : entry.getSecond()) {
333 sorted_predecessors.push_back(predecessor);
334 sorted_control_successors_[predecessor].push_back(op);
335 }
336 }
337 control_predecessors_.clear();
338 for (auto& entry : sorted_control_predecessors_) {
339 llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
340 return a->isBeforeInBlock(b);
341 });
342 }
343 for (auto& entry : sorted_control_successors_) {
344 llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
345 return a->isBeforeInBlock(b);
346 });
347 }
348 }
349
AnalyzeRegion(Region * region,const TF::ResourceAliasAnalysis::Info & alias_analysis)350 void SideEffectAnalysisInfo::AnalyzeRegion(
351 Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
352 // This function populates control_predecessors_ by walking through the
353 // region, and tracking resource accesses in per_resource_access_info_.
354
355 // Returns whether an access to `resource` can skip control edges from
356 // previous accesses to unknown resources, due to that earlier accesses to
357 // `resource` already indirectly tracked previous accesses to unknown
358 // resources. `read_only` specifies the type of access of the current op being
359 // considered.
360 auto unknown_access_indirectly_tracked_by_resource = [&](int64_t resource,
361 bool read_only) {
362 VLOG(1) << "\tunknown access indirectly tracked by resource " << resource;
363 auto it = per_resource_access_info_.find(resource);
364 if (it == per_resource_access_info_.end()) {
365 VLOG(1) << "\t\tnot found";
366 return false;
367 }
368 auto unknown_it = per_resource_access_info_.find(kUnknownResourceId);
369 const bool no_unknown_read =
370 unknown_it == per_resource_access_info_.end() ||
371 unknown_it->getSecond().reads_since_last_write.empty();
372 bool ret = read_only ? it->second.tracked_last_unknown_write_for_read
373 : it->second.tracked_last_unknown_write_for_write &&
374 (it->second.tracked_last_unknown_read ||
375 no_unknown_read);
376 VLOG(1) << "\t\tunknown access inderictly tracked by resource: " << ret;
377 return ret;
378 };
379
380 // We explicitly iterates through the regions and blocks, in order to handle
381 // different nested regions separately.
382 for (auto& block : *region) {
383 llvm::SmallPtrSet<Operation*, 8> non_resource_control_predecessors;
384 for (auto& op : block) {
385 for (Region& child : op.getRegions()) {
386 SideEffectAnalysisInfo child_analysis(&child, alias_analysis);
387 // Moves the control_predecessors_ fields in child region to current
388 // region
389 for (auto& entry : child_analysis.control_predecessors_)
390 control_predecessors_[entry.first] = std::move(entry.second);
391 }
392
393 // We do not need explicit control edges for declaration ops.
394 if (OpIsDeclaration(&op, alias_analysis)) continue;
395
396 SideEffectsByValue side_effects_by_value;
397 bool must_execute = false;
398 GetSideEffectsByValue(&op, side_effects_by_value, must_execute);
399
400 if (side_effects_by_value.empty() && OpIsKnownToHaveNoSideEffect(&op))
401 continue;
402
403 // TODO(jpienaar): This only currently uses unknown when not per value
404 // resource is used.
405 if (side_effects_by_value.empty() && must_execute) {
406 VLOG(1) << "No resources & must execute: " << debugString(op);
407 // Add unknown resource ops as predecessors of the op that must execute,
408 // to guarantee ordering between unknown resource ops.
409 AddPredecessorsForAccess(kUnknownResourceId, &op, /*read_only=*/false);
410 non_resource_control_predecessors.insert(&op);
411 continue;
412 }
413
414 if (IsResourceOpAllocOnly(&op, side_effects_by_value)) {
415 VLOG(1) << "Resource alloc only: " << debugString(op);
416 continue;
417 }
418
419 auto resource_ids_by_value =
420 GetResourceIdsByValue(&op, alias_analysis, side_effects_by_value);
421 const bool read_only = IsResourceOpReadOnly(&op, side_effects_by_value);
422 bool indirectly_tracked_unknown_access = false;
423 // First add edges from known resources.
424 if (!resource_ids_by_value.hasValue()) {
425 VLOG(1) << "Resource not by value: " << debugString(op);
426 for (auto& entry : per_resource_access_info_) {
427 if (entry.getFirst() == kUnknownResourceId) {
428 VLOG(1) << "\tskipping over unknown resource id";
429 continue;
430 }
431 AddPredecessorsForAccess(entry.getFirst(), &op, read_only);
432 indirectly_tracked_unknown_access |=
433 unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
434 read_only);
435 }
436 } else {
437 // Collect all resource id's and whether their side effect is read only.
438 llvm::SmallDenseMap<int64_t, bool> read_only_by_resource_id;
439 for (const auto& resource_ids : *resource_ids_by_value) {
440 const bool is_result = resource_ids.first.getDefiningOp() == &op;
441 auto value_side_effect =
442 side_effects_by_value.find(resource_ids.first);
443 bool resource_read_only = false;
444 if (value_side_effect != side_effects_by_value.end()) {
445 if (is_result && value_side_effect->getSecond().IsAllocOnly())
446 continue;
447 resource_read_only = value_side_effect->getSecond().IsReadOnly();
448 }
449
450 for (const auto& id : *resource_ids.second) {
451 auto it =
452 read_only_by_resource_id.try_emplace(id, resource_read_only);
453 if (!it.second && !resource_read_only)
454 it.first->getSecond() = resource_read_only;
455 }
456 }
457
458 for (const auto& resource : read_only_by_resource_id) {
459 const auto& resource_id = resource.getFirst();
460 const auto& resource_read_only = resource.getSecond();
461 AddPredecessorsForAccess(resource_id, &op, resource_read_only);
462 indirectly_tracked_unknown_access |=
463 unknown_access_indirectly_tracked_by_resource(resource_id,
464 resource_read_only);
465 // Update access info for known resources.
466 TrackAccess(resource_id, &op, resource_read_only);
467 }
468 }
469
470 // If not indirectly tracked, add edges from the resource.
471 if (!indirectly_tracked_unknown_access) {
472 VLOG(1) << "Not indirectly tracked with unknown access: "
473 << debugString(op);
474 if (auto interface = dyn_cast<MemoryEffectOpInterface>(op)) {
475 llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
476 interface.getEffects(effects);
477 }
478 AddPredecessorsForAccess(kUnknownResourceId, &op, read_only);
479 }
480 if (!resource_ids_by_value.hasValue()) {
481 VLOG(1) << "Indirectly tracked with no value: " << debugString(op);
482
483 // Update access info for unknown resource.
484 TrackAccess(kUnknownResourceId, &op, read_only);
485 // Add ops that must execute to unknown resource op predecessors.
486 auto& control_predecessors = control_predecessors_[&op];
487 control_predecessors.insert(non_resource_control_predecessors.begin(),
488 non_resource_control_predecessors.end());
489 // Ops that must execute currently tracked are cleared as transitively
490 // unknown resource ops will allow for such ops to be transitively
491 // reachable.
492 non_resource_control_predecessors.clear();
493 }
494 }
495 }
496 }
497
498 llvm::SmallVector<Operation*, 4>
DirectControlPredecessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const499 SideEffectAnalysisInfo::DirectControlPredecessors(
500 Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
501 llvm::SmallVector<Operation*, 4> result;
502 auto it = sorted_control_predecessors_.find(op);
503 if (it == sorted_control_predecessors_.end()) return result;
504 result.reserve(it->getSecond().size());
505 for (auto predecessor : it->getSecond()) {
506 if (!filter || filter(predecessor)) result.push_back(predecessor);
507 }
508 return result;
509 }
510
511 llvm::SmallVector<Operation*, 4>
DirectControlSuccessors(Operation * op,llvm::function_ref<bool (Operation *)> filter) const512 SideEffectAnalysisInfo::DirectControlSuccessors(
513 Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
514 llvm::SmallVector<Operation*, 4> result;
515 auto it = sorted_control_successors_.find(op);
516 if (it == sorted_control_successors_.end()) return result;
517 result.reserve(it->getSecond().size());
518 for (auto successor : it->getSecond()) {
519 if (!filter || filter(successor)) result.push_back(successor);
520 }
521 return result;
522 }
523 } // namespace detail
524
SideEffectAnalysis(ModuleOp module)525 SideEffectAnalysis::SideEffectAnalysis(ModuleOp module) {
526 // Analyze entire module for alias analysis info.
527 ResourceAliasAnalysis alias_analysis(module);
528
529 // Analyze all functions.
530 for (auto func : module.getOps<FuncOp>())
531 this->info_map_.try_emplace(func, func,
532 alias_analysis.GetAnalysisForFunc(func));
533 }
534
535 } // namespace TF
536 } // namespace mlir
537