• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/eigen_support.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/lite/arena_planner.h"
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
24 #include "tensorflow/lite/kernels/op_macros.h"
25 
26 namespace tflite {
27 namespace eigen_support {
28 namespace {
29 
// Thread count used when the caller does not request one explicitly.
// The value 4 is kept for legacy reasons.
constexpr int kDefaultNumThreadpoolThreads = 4;
33 
// Returns true for the "use the default" sentinel (-1) and for any
// non-negative thread count; anything below -1 is rejected.
bool IsValidNumThreads(int num_threads) {
  const bool below_sentinel = num_threads < -1;
  return !below_sentinel;
}
GetNumThreads(int num_threads)35 int GetNumThreads(int num_threads) {
36   return num_threads > -1 ? num_threads : kDefaultNumThreadpoolThreads;
37 }
38 
39 #ifndef EIGEN_DONT_ALIGN
40 // Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
41 // hardware architecture and build configurations.
42 // If the static assertion fails, try to increase `kDefaultTensorAlignment` to
43 // in `arena_planner.h` to 32 or 64.
44 static_assert(
45     kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,
46     "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement.");
47 #endif  // EIGEN_DONT_ALIGN
48 
// Pushes `threads` into Eigen's global thread count. That setting is only
// consulted by OpenMP builds, so the call is compiled out everywhere else.
void SetEigenNbThreads(int threads) {
#ifdef EIGEN_HAS_OPENMP
  // Restricted to OpenMP builds because this call causes problems with tsan.
  Eigen::setNbThreads(threads);
#else
  static_cast<void>(threads);  // Unused without OpenMP.
#endif  // EIGEN_HAS_OPENMP
}
57 
58 // We have a single global threadpool for all convolution operations. This means
59 // that inferences started from different threads may block each other, but
60 // since the underlying resource of CPU cores should be consumed by the
61 // operations anyway, it shouldn't affect overall performance. Note that we
62 // also avoid ThreadPool creation if the target thread count is 1, avoiding
63 // unnecessary overhead, and more closely mimicking Gemmlowp threadpool
64 // behavior.
65 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
66  public:
67   // Takes ownership of 'pool'
EigenThreadPoolWrapper(int num_threads)68   explicit EigenThreadPoolWrapper(int num_threads) {
69     // Avoid creating any threads for the single-threaded case.
70     if (num_threads > 1) {
71       pool_.reset(new Eigen::ThreadPool(num_threads));
72     }
73   }
~EigenThreadPoolWrapper()74   ~EigenThreadPoolWrapper() override {}
75 
Schedule(std::function<void ()> fn)76   void Schedule(std::function<void()> fn) override {
77     if (pool_) {
78       pool_->Schedule(std::move(fn));
79     } else {
80       fn();
81     }
82   }
NumThreads() const83   int NumThreads() const override { return pool_ ? pool_->NumThreads() : 1; }
CurrentThreadId() const84   int CurrentThreadId() const override {
85     return pool_ ? pool_->CurrentThreadId() : 0;
86   }
87 
88  private:
89   // May be null if num_threads <= 1.
90   std::unique_ptr<Eigen::ThreadPool> pool_;
91 };
92 
93 // Utility class for lazily creating an Eigen thread pool/device only when used.
94 class LazyEigenThreadPoolHolder {
95  public:
LazyEigenThreadPoolHolder(int num_threads)96   explicit LazyEigenThreadPoolHolder(int num_threads) {
97     SetNumThreads(num_threads);
98   }
99 
100   // Gets the ThreadPoolDevice, creating if necessary.
GetThreadPoolDevice()101   const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
102     if (!device_) {
103       thread_pool_wrapper_.reset(
104           new EigenThreadPoolWrapper(target_num_threads_));
105       device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(),
106                                                 target_num_threads_));
107     }
108     return device_.get();
109   }
110 
111   // Updates the thread count, invalidating the ThreadPoolDevice if necessary.
SetNumThreads(int num_threads)112   void SetNumThreads(int num_threads) {
113     const int target_num_threads = GetNumThreads(num_threads);
114     if (target_num_threads_ != target_num_threads) {
115       target_num_threads_ = target_num_threads;
116       // As the device references the thread pool wrapper, destroy it first.
117       device_.reset();
118       thread_pool_wrapper_.reset();
119     }
120   }
121 
122  private:
123   int target_num_threads_ = kDefaultNumThreadpoolThreads;
124   // Both device_ and thread_pool_wrapper_ are lazily created.
125   std::unique_ptr<Eigen::ThreadPoolDevice> device_;
126   std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper_;
127 };
128 
129 struct RefCountedEigenContext : public TfLiteExternalContext {
130   std::unique_ptr<LazyEigenThreadPoolHolder> thread_pool_holder;
131   int num_references = 0;
132 };
133 
GetEigenContext(TfLiteContext * context)134 RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
135   return reinterpret_cast<RefCountedEigenContext*>(
136       context->GetExternalContext(context, kTfLiteEigenContext));
137 }
138 
Refresh(TfLiteContext * context)139 TfLiteStatus Refresh(TfLiteContext* context) {
140   if (IsValidNumThreads(context->recommended_num_threads)) {
141     SetEigenNbThreads(GetNumThreads(context->recommended_num_threads));
142   }
143 
144   auto* ptr = GetEigenContext(context);
145   if (ptr != nullptr) {
146     ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads);
147   }
148 
149   return kTfLiteOk;
150 }
151 
152 }  // namespace
153 
IncrementUsageCounter(TfLiteContext * context)154 void IncrementUsageCounter(TfLiteContext* context) {
155   auto* ptr = GetEigenContext(context);
156   if (ptr == nullptr) {
157     if (IsValidNumThreads(context->recommended_num_threads)) {
158       SetEigenNbThreads(context->recommended_num_threads);
159     }
160     ptr = new RefCountedEigenContext;
161     ptr->type = kTfLiteEigenContext;
162     ptr->Refresh = Refresh;
163     ptr->thread_pool_holder.reset(
164         new LazyEigenThreadPoolHolder(context->recommended_num_threads));
165     ptr->num_references = 0;
166     context->SetExternalContext(context, kTfLiteEigenContext, ptr);
167   }
168   ptr->num_references++;
169 }
170 
DecrementUsageCounter(TfLiteContext * context)171 void DecrementUsageCounter(TfLiteContext* context) {
172   auto* ptr = GetEigenContext(context);
173   if (ptr == nullptr) {
174     TF_LITE_FATAL(
175         "Call to DecrementUsageCounter() not preceded by "
176         "IncrementUsageCounter()");
177   }
178   if (--ptr->num_references == 0) {
179     delete ptr;
180     context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
181   }
182 }
183 
GetThreadPoolDevice(TfLiteContext * context)184 const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
185   auto* ptr = GetEigenContext(context);
186   if (ptr == nullptr) {
187     TF_LITE_FATAL(
188         "Call to GetFromContext() not preceded by IncrementUsageCounter()");
189   }
190   return ptr->thread_pool_holder->GetThreadPoolDevice();
191 }
192 
193 }  // namespace eigen_support
194 }  // namespace tflite
195