• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
16 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
17 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"
18 #include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
19 #include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
20 #include "tfrt/compiler/stream_analysis.h"  // from @tf_runtime
21 
22 namespace tensorflow {
23 namespace tfrt_compiler {
24 namespace {
25 
26 // This pass inserts copy kernels for fallback tensors when they are passed to
27 // multiple threads, to avoid atomic contention on their refcounts.
28 class InsertFallbackTensorCopy
29     : public mlir::PassWrapper<InsertFallbackTensorCopy,
30                                mlir::OperationPass<mlir::FuncOp>> {
getDependentDialects(mlir::DialectRegistry & registry) const31   void getDependentDialects(mlir::DialectRegistry& registry) const override {
32     registry.insert<tfrt::fallback_async::FallbackAsyncDialect>();
33   }
34 
getArgument() const35   llvm::StringRef getArgument() const final {
36     return "tfrt-insert-fallback-tensor-copy";
37   }
38 
getDescription() const39   llvm::StringRef getDescription() const final {
40     return "Inserts copy kernels for fallback tensors when they are passed to "
41            "multiple threads, to avoid atomic contention on refcounts.";
42   }
43 
44  public:
runOnOperation()45   void runOnOperation() override {
46     mlir::FuncOp func_op = getOperation();
47 
48     // Use stream analysis to know whether a value is passed to different
49     // threads.
50     tfrt::compiler::StreamAnalysis stream_analysis(func_op);
51 
52     auto builder = mlir::OpBuilder::atBlockBegin(&func_op.front());
53 
54     // Process function arguments first.
55     for (auto arg : func_op.getArguments()) {
56       if (!arg.getType().isa<tfrt::fallback::TFTensorType>()) continue;
57       InsertFallbackTensorCopyForValue(arg, func_op->getLoc(), builder,
58                                        stream_analysis);
59     }
60 
61     // Then process each operations in the block.
62     for (mlir::Operation& op : llvm::make_early_inc_range(func_op.front())) {
63       if (llvm::isa<tfrt::fallback_async::ExecuteOp,
64                     tfrt::fallback_async::ExecuteOpSeq>(&op)) {
65         InsertFallbackTensorCopyForFallbackOp(&op, builder, stream_analysis);
66       }
67     }
68   }
69 
70  private:
InsertFallbackTensorCopyForFallbackOp(mlir::Operation * op,mlir::OpBuilder & builder,const tfrt::compiler::StreamAnalysis & stream_analysis)71   void InsertFallbackTensorCopyForFallbackOp(
72       mlir::Operation* op, mlir::OpBuilder& builder,
73       const tfrt::compiler::StreamAnalysis& stream_analysis) {
74     builder.setInsertionPointAfter(op);
75 
76     // Process each result value.
77     for (auto result : op->getResults()) {
78       if (!result.getType().isa<tfrt::fallback::TFTensorType>()) continue;
79       InsertFallbackTensorCopyForValue(result, op->getLoc(), builder,
80                                        stream_analysis);
81     }
82   }
83 
84   // Insert copy kernels to copy the result, and allocate new atomic refcount
85   // if the value is going to be used by different streams/threads, in order to
86   // avoid contention on the atomic counter.
InsertFallbackTensorCopyForValue(mlir::Value value,mlir::Location loc,mlir::OpBuilder & builder,const tfrt::compiler::StreamAnalysis & stream_analysis)87   void InsertFallbackTensorCopyForValue(
88       mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder,
89       const tfrt::compiler::StreamAnalysis& stream_analysis) {
90     llvm::DenseMap<int, llvm::SmallVector<mlir::OpOperand*, 4>> stream_map;
91 
92     // Find out streams that use this value and the corresponding uses.
93     for (mlir::OpOperand& use : value.getUses()) {
94       // Skip return op as there should not be atomic contention on the return
95       // op.
96       if (llvm::isa<tfrt::compiler::ReturnOp>(use.getOwner())) continue;
97 
98       int stream_id = stream_analysis.GetStream(use.getOwner()).id();
99       stream_map[stream_id].push_back(&use);
100     }
101 
102     // Organize these uses into groups. If a stream has many uses of this value,
103     // put these uses into one stream. Otherwise, streams with small number
104     // of uses are grouped with each other to form groups with enough uses.
105     constexpr int kCopyGroupThreshold = 16;
106     llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> small_copies;
107     llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> copies;
108     for (const auto& iter : stream_map) {
109       if (iter.second.size() >= kCopyGroupThreshold) {
110         copies.push_back(iter.second);
111       } else {
112         if (small_copies.empty() ||
113             small_copies.back().size() >= kCopyGroupThreshold) {
114           small_copies.push_back(iter.second);
115         } else {
116           small_copies.back().append(iter.second.begin(), iter.second.end());
117         }
118       }
119     }
120 
121     if (!small_copies.empty())
122       copies.append(small_copies.begin(), small_copies.end());
123 
124     // If it is only used by one group, then we don't need to copy.
125     if (copies.size() <= 1) return;
126 
127     // Remove one group from the candidates, as we can just use the original
128     // value for this group.
129     copies.pop_back();
130 
131     // For each stream, we will create one new value that replaces the uses in
132     // that stream.
133 
134     assert(value.getType().isa<tfrt::fallback::TFTensorType>());
135 
136     // The number of results is the number candidate streams.
137     llvm::SmallVector<mlir::Type, 4> result_types(copies.size(),
138                                                   value.getType());
139     assert(!result_types.empty());
140 
141     // Create the tfrt_fallback_async.copy_if_small kernel.
142     auto copy_op = builder.create<tfrt::fallback_async::CopyIfSmallOp>(
143         loc, result_types, value);
144 
145     // Finally, replaces all uses with the new value.
146     for (int i = 0; i < copies.size(); ++i) {
147       const auto& uses = copies[i];
148       auto new_value = copy_op.getResult(i);
149       for (auto* use : uses) {
150         use->set(new_value);
151       }
152     }
153   }
154 };
155 
156 }  // namespace
157 
158 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
CreateInsertFallbackTensorCopyPass()159 CreateInsertFallbackTensorCopyPass() {
160   return std::make_unique<InsertFallbackTensorCopy>();
161 }
162 
// Statically register the pass so mlir-opt-style drivers can invoke it via
// the "tfrt-insert-fallback-tensor-copy" flag returned by getArgument().
static mlir::PassRegistration<InsertFallbackTensorCopy> register_pass(
    CreateInsertFallbackTensorCopyPass);
165 
166 }  // namespace tfrt_compiler
167 }  // namespace tensorflow
168