1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/eigen_support.h"
16
17 #include <functional>
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/lite/arena_planner.h"
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
24 #include "tensorflow/lite/kernels/op_macros.h"
25
26 namespace tflite {
27 namespace eigen_support {
28 namespace {
29
// For legacy reasons, we use 4 threads by default unless the thread count is
// explicitly specified by the context.
constexpr int kDefaultNumThreadpoolThreads = 4;

// Returns true if `num_threads` is an acceptable thread-count request.
// -1 is valid and means "no explicit recommendation".
bool IsValidNumThreads(int num_threads) { return num_threads >= -1; }

// Maps a requested thread count to the effective one: a negative request
// (i.e. -1) falls back to kDefaultNumThreadpoolThreads.
int GetNumThreads(int num_threads) {
  return num_threads > -1 ? num_threads : kDefaultNumThreadpoolThreads;
}
38
39 #ifndef EIGEN_DONT_ALIGN
40 // Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
41 // hardware architecture and build configurations.
42 // If the static assertion fails, try to increase `kDefaultTensorAlignment` to
43 // in `arena_planner.h` to 32 or 64.
44 static_assert(
45 kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,
46 "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement.");
47 #endif // EIGEN_DONT_ALIGN
48
// Helper that forwards a thread count to Eigen's global setting.
void SetEigenNbThreads(int threads) {
#if defined(EIGEN_HAS_OPENMP)
  // The global Eigen thread count is only consulted when OpenMP is enabled.
  // Making the call unconditionally causes problems with tsan, so restrict it
  // to OpenMP builds.
  Eigen::setNbThreads(threads);
#else
  (void)threads;  // Unused without OpenMP.
#endif  // defined(EIGEN_HAS_OPENMP)
}
57
58 // We have a single global threadpool for all convolution operations. This means
59 // that inferences started from different threads may block each other, but
60 // since the underlying resource of CPU cores should be consumed by the
61 // operations anyway, it shouldn't affect overall performance. Note that we
62 // also avoid ThreadPool creation if the target thread count is 1, avoiding
63 // unnecessary overhead, and more closely mimicking Gemmlowp threadpool
64 // behavior.
65 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
66 public:
67 // Takes ownership of 'pool'
EigenThreadPoolWrapper(int num_threads)68 explicit EigenThreadPoolWrapper(int num_threads) {
69 // Avoid creating any threads for the single-threaded case.
70 if (num_threads > 1) {
71 pool_.reset(new Eigen::ThreadPool(num_threads));
72 }
73 }
~EigenThreadPoolWrapper()74 ~EigenThreadPoolWrapper() override {}
75
Schedule(std::function<void ()> fn)76 void Schedule(std::function<void()> fn) override {
77 if (pool_) {
78 pool_->Schedule(std::move(fn));
79 } else {
80 fn();
81 }
82 }
NumThreads() const83 int NumThreads() const override { return pool_ ? pool_->NumThreads() : 1; }
CurrentThreadId() const84 int CurrentThreadId() const override {
85 return pool_ ? pool_->CurrentThreadId() : 0;
86 }
87
88 private:
89 // May be null if num_threads <= 1.
90 std::unique_ptr<Eigen::ThreadPool> pool_;
91 };
92
93 // Utility class for lazily creating an Eigen thread pool/device only when used.
94 class LazyEigenThreadPoolHolder {
95 public:
LazyEigenThreadPoolHolder(int num_threads)96 explicit LazyEigenThreadPoolHolder(int num_threads) {
97 SetNumThreads(num_threads);
98 }
99
100 // Gets the ThreadPoolDevice, creating if necessary.
GetThreadPoolDevice()101 const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
102 if (!device_) {
103 thread_pool_wrapper_.reset(
104 new EigenThreadPoolWrapper(target_num_threads_));
105 device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(),
106 target_num_threads_));
107 }
108 return device_.get();
109 }
110
111 // Updates the thread count, invalidating the ThreadPoolDevice if necessary.
SetNumThreads(int num_threads)112 void SetNumThreads(int num_threads) {
113 const int target_num_threads = GetNumThreads(num_threads);
114 if (target_num_threads_ != target_num_threads) {
115 target_num_threads_ = target_num_threads;
116 // As the device references the thread pool wrapper, destroy it first.
117 device_.reset();
118 thread_pool_wrapper_.reset();
119 }
120 }
121
122 private:
123 int target_num_threads_ = kDefaultNumThreadpoolThreads;
124 // Both device_ and thread_pool_wrapper_ are lazily created.
125 std::unique_ptr<Eigen::ThreadPoolDevice> device_;
126 std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper_;
127 };
128
// External context registered on a TfLiteContext under kTfLiteEigenContext.
// Manually reference-counted: IncrementUsageCounter()/DecrementUsageCounter()
// below adjust num_references, deleting the object when it reaches zero.
struct RefCountedEigenContext : public TfLiteExternalContext {
  std::unique_ptr<LazyEigenThreadPoolHolder> thread_pool_holder;
  int num_references = 0;
};
133
GetEigenContext(TfLiteContext * context)134 RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
135 return reinterpret_cast<RefCountedEigenContext*>(
136 context->GetExternalContext(context, kTfLiteEigenContext));
137 }
138
Refresh(TfLiteContext * context)139 TfLiteStatus Refresh(TfLiteContext* context) {
140 if (IsValidNumThreads(context->recommended_num_threads)) {
141 SetEigenNbThreads(GetNumThreads(context->recommended_num_threads));
142 }
143
144 auto* ptr = GetEigenContext(context);
145 if (ptr != nullptr) {
146 ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads);
147 }
148
149 return kTfLiteOk;
150 }
151
152 } // namespace
153
IncrementUsageCounter(TfLiteContext * context)154 void IncrementUsageCounter(TfLiteContext* context) {
155 auto* ptr = GetEigenContext(context);
156 if (ptr == nullptr) {
157 if (IsValidNumThreads(context->recommended_num_threads)) {
158 SetEigenNbThreads(context->recommended_num_threads);
159 }
160 ptr = new RefCountedEigenContext;
161 ptr->type = kTfLiteEigenContext;
162 ptr->Refresh = Refresh;
163 ptr->thread_pool_holder.reset(
164 new LazyEigenThreadPoolHolder(context->recommended_num_threads));
165 ptr->num_references = 0;
166 context->SetExternalContext(context, kTfLiteEigenContext, ptr);
167 }
168 ptr->num_references++;
169 }
170
DecrementUsageCounter(TfLiteContext * context)171 void DecrementUsageCounter(TfLiteContext* context) {
172 auto* ptr = GetEigenContext(context);
173 if (ptr == nullptr) {
174 TF_LITE_FATAL(
175 "Call to DecrementUsageCounter() not preceded by "
176 "IncrementUsageCounter()");
177 }
178 if (--ptr->num_references == 0) {
179 delete ptr;
180 context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
181 }
182 }
183
GetThreadPoolDevice(TfLiteContext * context)184 const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
185 auto* ptr = GetEigenContext(context);
186 if (ptr == nullptr) {
187 TF_LITE_FATAL(
188 "Call to GetFromContext() not preceded by IncrementUsageCounter()");
189 }
190 return ptr->thread_pool_holder->GetThreadPoolDevice();
191 }
192
193 } // namespace eigen_support
194 } // namespace tflite
195