1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
16 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
17 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"
18 #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
19 #include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime
20 #include "tfrt/compiler/stream_analysis.h" // from @tf_runtime
21
22 namespace tensorflow {
23 namespace tfrt_compiler {
24 namespace {
25
26 // This pass inserts copy kernels for fallback tensors when they are passed to
27 // multiple threads, to avoid atomic contention on their refcounts.
28 class InsertFallbackTensorCopy
29 : public mlir::PassWrapper<InsertFallbackTensorCopy,
30 mlir::OperationPass<mlir::FuncOp>> {
getDependentDialects(mlir::DialectRegistry & registry) const31 void getDependentDialects(mlir::DialectRegistry& registry) const override {
32 registry.insert<tfrt::fallback_async::FallbackAsyncDialect>();
33 }
34
getArgument() const35 llvm::StringRef getArgument() const final {
36 return "tfrt-insert-fallback-tensor-copy";
37 }
38
getDescription() const39 llvm::StringRef getDescription() const final {
40 return "Inserts copy kernels for fallback tensors when they are passed to "
41 "multiple threads, to avoid atomic contention on refcounts.";
42 }
43
44 public:
runOnOperation()45 void runOnOperation() override {
46 mlir::FuncOp func_op = getOperation();
47
48 // Use stream analysis to know whether a value is passed to different
49 // threads.
50 tfrt::compiler::StreamAnalysis stream_analysis(func_op);
51
52 auto builder = mlir::OpBuilder::atBlockBegin(&func_op.front());
53
54 // Process function arguments first.
55 for (auto arg : func_op.getArguments()) {
56 if (!arg.getType().isa<tfrt::fallback::TFTensorType>()) continue;
57 InsertFallbackTensorCopyForValue(arg, func_op->getLoc(), builder,
58 stream_analysis);
59 }
60
61 // Then process each operations in the block.
62 for (mlir::Operation& op : llvm::make_early_inc_range(func_op.front())) {
63 if (llvm::isa<tfrt::fallback_async::ExecuteOp,
64 tfrt::fallback_async::ExecuteOpSeq>(&op)) {
65 InsertFallbackTensorCopyForFallbackOp(&op, builder, stream_analysis);
66 }
67 }
68 }
69
70 private:
InsertFallbackTensorCopyForFallbackOp(mlir::Operation * op,mlir::OpBuilder & builder,const tfrt::compiler::StreamAnalysis & stream_analysis)71 void InsertFallbackTensorCopyForFallbackOp(
72 mlir::Operation* op, mlir::OpBuilder& builder,
73 const tfrt::compiler::StreamAnalysis& stream_analysis) {
74 builder.setInsertionPointAfter(op);
75
76 // Process each result value.
77 for (auto result : op->getResults()) {
78 if (!result.getType().isa<tfrt::fallback::TFTensorType>()) continue;
79 InsertFallbackTensorCopyForValue(result, op->getLoc(), builder,
80 stream_analysis);
81 }
82 }
83
84 // Insert copy kernels to copy the result, and allocate new atomic refcount
85 // if the value is going to be used by different streams/threads, in order to
86 // avoid contention on the atomic counter.
InsertFallbackTensorCopyForValue(mlir::Value value,mlir::Location loc,mlir::OpBuilder & builder,const tfrt::compiler::StreamAnalysis & stream_analysis)87 void InsertFallbackTensorCopyForValue(
88 mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder,
89 const tfrt::compiler::StreamAnalysis& stream_analysis) {
90 llvm::DenseMap<int, llvm::SmallVector<mlir::OpOperand*, 4>> stream_map;
91
92 // Find out streams that use this value and the corresponding uses.
93 for (mlir::OpOperand& use : value.getUses()) {
94 // Skip return op as there should not be atomic contention on the return
95 // op.
96 if (llvm::isa<tfrt::compiler::ReturnOp>(use.getOwner())) continue;
97
98 int stream_id = stream_analysis.GetStream(use.getOwner()).id();
99 stream_map[stream_id].push_back(&use);
100 }
101
102 // Organize these uses into groups. If a stream has many uses of this value,
103 // put these uses into one stream. Otherwise, streams with small number
104 // of uses are grouped with each other to form groups with enough uses.
105 constexpr int kCopyGroupThreshold = 16;
106 llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> small_copies;
107 llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> copies;
108 for (const auto& iter : stream_map) {
109 if (iter.second.size() >= kCopyGroupThreshold) {
110 copies.push_back(iter.second);
111 } else {
112 if (small_copies.empty() ||
113 small_copies.back().size() >= kCopyGroupThreshold) {
114 small_copies.push_back(iter.second);
115 } else {
116 small_copies.back().append(iter.second.begin(), iter.second.end());
117 }
118 }
119 }
120
121 if (!small_copies.empty())
122 copies.append(small_copies.begin(), small_copies.end());
123
124 // If it is only used by one group, then we don't need to copy.
125 if (copies.size() <= 1) return;
126
127 // Remove one group from the candidates, as we can just use the original
128 // value for this group.
129 copies.pop_back();
130
131 // For each stream, we will create one new value that replaces the uses in
132 // that stream.
133
134 assert(value.getType().isa<tfrt::fallback::TFTensorType>());
135
136 // The number of results is the number candidate streams.
137 llvm::SmallVector<mlir::Type, 4> result_types(copies.size(),
138 value.getType());
139 assert(!result_types.empty());
140
141 // Create the tfrt_fallback_async.copy_if_small kernel.
142 auto copy_op = builder.create<tfrt::fallback_async::CopyIfSmallOp>(
143 loc, result_types, value);
144
145 // Finally, replaces all uses with the new value.
146 for (int i = 0; i < copies.size(); ++i) {
147 const auto& uses = copies[i];
148 auto new_value = copy_op.getResult(i);
149 for (auto* use : uses) {
150 use->set(new_value);
151 }
152 }
153 }
154 };
155
156 } // namespace
157
158 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
CreateInsertFallbackTensorCopyPass()159 CreateInsertFallbackTensorCopyPass() {
160 return std::make_unique<InsertFallbackTensorCopy>();
161 }
162
163 static mlir::PassRegistration<InsertFallbackTensorCopy> register_pass(
164 CreateInsertFallbackTensorCopyPass);
165
166 } // namespace tfrt_compiler
167 } // namespace tensorflow
168