1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/eigen_support.h"
16
17 #include <utility>
18
19 #include "tensorflow/lite/arena_planner.h"
20 #include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
21 #include "tensorflow/lite/kernels/op_macros.h"
22
23 namespace tflite {
24 namespace eigen_support {
25 namespace {
26
// Thread count used when the context does not explicitly specify one (i.e.
// when -1 is passed). The value 4 is kept for legacy reasons.
constexpr int kDefaultNumThreadpoolThreads = 4;
30
31 #ifndef EIGEN_DONT_ALIGN
32 // Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
33 // hardware architecture and build configurations.
34 // If the static assertion fails, try to increase `kDefaultTensorAlignment` to
35 // in `arena_planner.h` to 32 or 64.
36 static_assert(
37 kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,
38 "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement.");
39 #endif // EIGEN_DONT_ALIGN
40
// Forwards `threads` to Eigen's global thread count. That global count is only
// used when OpenMP is enabled, and the call causes problems with tsan, so in
// builds without OpenMP this helper deliberately does nothing.
void SetEigenNbThreads(int threads) {
#ifdef EIGEN_HAS_OPENMP
  Eigen::setNbThreads(threads);
#else
  (void)threads;  // Intentionally unused when OpenMP is not available.
#endif  // EIGEN_HAS_OPENMP
}
49
50 // We have a single global threadpool for all convolution operations. This means
51 // that inferences started from different threads may block each other, but
52 // since the underlying resource of CPU cores should be consumed by the
53 // operations anyway, it shouldn't affect overall performance.
54 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
55 public:
56 // Takes ownership of 'pool'
EigenThreadPoolWrapper(Eigen::ThreadPool * pool)57 explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
~EigenThreadPoolWrapper()58 ~EigenThreadPoolWrapper() override {}
59
Schedule(std::function<void ()> fn)60 void Schedule(std::function<void()> fn) override {
61 pool_->Schedule(std::move(fn));
62 }
NumThreads() const63 int NumThreads() const override { return pool_->NumThreads(); }
CurrentThreadId() const64 int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
65
66 private:
67 std::unique_ptr<Eigen::ThreadPool> pool_;
68 };
69
70 // Utility class for lazily creating an Eigen thread pool/device only when used.
71 class LazyEigenThreadPoolHolder {
72 public:
LazyEigenThreadPoolHolder(int num_threads)73 explicit LazyEigenThreadPoolHolder(int num_threads) {
74 SetNumThreads(num_threads);
75 }
76
77 // Gets the ThreadPoolDevice, creating if necessary.
GetThreadPoolDevice()78 const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
79 if (!device_) {
80 thread_pool_wrapper_.reset(new EigenThreadPoolWrapper(
81 new Eigen::ThreadPool(target_num_threads_)));
82 device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(),
83 target_num_threads_));
84 }
85 return device_.get();
86 }
87
88 // Updates the thread count, invalidating the ThreadPoolDevice if necessary.
SetNumThreads(int num_threads)89 void SetNumThreads(int num_threads) {
90 const int target_num_threads =
91 num_threads != -1 ? num_threads : kDefaultNumThreadpoolThreads;
92 if (target_num_threads_ != target_num_threads) {
93 target_num_threads_ = target_num_threads;
94 // As the device references the thread pool wrapper, destroy it first.
95 device_.reset();
96 thread_pool_wrapper_.reset();
97 }
98 }
99
100 private:
101 int target_num_threads_ = kDefaultNumThreadpoolThreads;
102 // Both device_ and thread_pool_wrapper_ are lazily created.
103 std::unique_ptr<Eigen::ThreadPoolDevice> device_;
104 std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper_;
105 };
106
107 struct RefCountedEigenContext : public TfLiteExternalContext {
108 std::unique_ptr<LazyEigenThreadPoolHolder> thread_pool_holder;
109 int num_references = 0;
110 };
111
GetEigenContext(TfLiteContext * context)112 RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
113 return reinterpret_cast<RefCountedEigenContext*>(
114 context->GetExternalContext(context, kTfLiteEigenContext));
115 }
116
Refresh(TfLiteContext * context)117 TfLiteStatus Refresh(TfLiteContext* context) {
118 SetEigenNbThreads(context->recommended_num_threads);
119
120 auto* ptr = GetEigenContext(context);
121 if (ptr != nullptr) {
122 ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads);
123 }
124
125 return kTfLiteOk;
126 }
127
128 } // namespace
129
IncrementUsageCounter(TfLiteContext * context)130 void IncrementUsageCounter(TfLiteContext* context) {
131 auto* ptr = GetEigenContext(context);
132 if (ptr == nullptr) {
133 if (context->recommended_num_threads != -1) {
134 SetEigenNbThreads(context->recommended_num_threads);
135 }
136 ptr = new RefCountedEigenContext;
137 ptr->type = kTfLiteEigenContext;
138 ptr->Refresh = Refresh;
139 ptr->thread_pool_holder.reset(
140 new LazyEigenThreadPoolHolder(context->recommended_num_threads));
141 ptr->num_references = 0;
142 context->SetExternalContext(context, kTfLiteEigenContext, ptr);
143 }
144 ptr->num_references++;
145 }
146
DecrementUsageCounter(TfLiteContext * context)147 void DecrementUsageCounter(TfLiteContext* context) {
148 auto* ptr = GetEigenContext(context);
149 if (ptr == nullptr) {
150 TF_LITE_FATAL(
151 "Call to DecrementUsageCounter() not preceded by "
152 "IncrementUsageCounter()");
153 }
154 if (--ptr->num_references == 0) {
155 delete ptr;
156 context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
157 }
158 }
159
GetThreadPoolDevice(TfLiteContext * context)160 const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
161 auto* ptr = GetEigenContext(context);
162 if (ptr == nullptr) {
163 TF_LITE_FATAL(
164 "Call to GetFromContext() not preceded by IncrementUsageCounter()");
165 }
166 return ptr->thread_pool_holder->GetThreadPoolDevice();
167 }
168
169 } // namespace eigen_support
170 } // namespace tflite
171