• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/eigen_support.h"
16 
17 #include <utility>
18 
19 #include "tensorflow/lite/arena_planner.h"
20 #include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
21 #include "tensorflow/lite/kernels/op_macros.h"
22 
23 namespace tflite {
24 namespace eigen_support {
25 namespace {
26 
// Thread count used for the Eigen thread pool when the context leaves
// recommended_num_threads unspecified (-1). The value 4 is kept for legacy
// reasons.
constexpr int kDefaultNumThreadpoolThreads = 4;
30 
31 #ifndef EIGEN_DONT_ALIGN
32 // Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
33 // hardware architecture and build configurations.
34 // If the static assertion fails, try to increase `kDefaultTensorAlignment` to
35 // in `arena_planner.h` to 32 or 64.
36 static_assert(
37     kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,
38     "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement.");
39 #endif  // EIGEN_DONT_ALIGN
40 
// Forwards `threads` to Eigen's global thread count. That count is only
// consulted by OpenMP builds, so in all other builds this is a no-op.
void SetEigenNbThreads(int threads) {
#ifdef EIGEN_HAS_OPENMP
  // Restricted to OpenMP builds because this global setter causes problems
  // under tsan.
  Eigen::setNbThreads(threads);
#else
  static_cast<void>(threads);
#endif  // EIGEN_HAS_OPENMP
}
49 
50 // We have a single global threadpool for all convolution operations. This means
51 // that inferences started from different threads may block each other, but
52 // since the underlying resource of CPU cores should be consumed by the
53 // operations anyway, it shouldn't affect overall performance.
54 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
55  public:
56   // Takes ownership of 'pool'
EigenThreadPoolWrapper(Eigen::ThreadPool * pool)57   explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
~EigenThreadPoolWrapper()58   ~EigenThreadPoolWrapper() override {}
59 
Schedule(std::function<void ()> fn)60   void Schedule(std::function<void()> fn) override {
61     pool_->Schedule(std::move(fn));
62   }
NumThreads() const63   int NumThreads() const override { return pool_->NumThreads(); }
CurrentThreadId() const64   int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
65 
66  private:
67   std::unique_ptr<Eigen::ThreadPool> pool_;
68 };
69 
70 // Utility class for lazily creating an Eigen thread pool/device only when used.
71 class LazyEigenThreadPoolHolder {
72  public:
LazyEigenThreadPoolHolder(int num_threads)73   explicit LazyEigenThreadPoolHolder(int num_threads) {
74     SetNumThreads(num_threads);
75   }
76 
77   // Gets the ThreadPoolDevice, creating if necessary.
GetThreadPoolDevice()78   const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
79     if (!device_) {
80       thread_pool_wrapper_.reset(new EigenThreadPoolWrapper(
81           new Eigen::ThreadPool(target_num_threads_)));
82       device_.reset(new Eigen::ThreadPoolDevice(thread_pool_wrapper_.get(),
83                                                 target_num_threads_));
84     }
85     return device_.get();
86   }
87 
88   // Updates the thread count, invalidating the ThreadPoolDevice if necessary.
SetNumThreads(int num_threads)89   void SetNumThreads(int num_threads) {
90     const int target_num_threads =
91         num_threads != -1 ? num_threads : kDefaultNumThreadpoolThreads;
92     if (target_num_threads_ != target_num_threads) {
93       target_num_threads_ = target_num_threads;
94       // As the device references the thread pool wrapper, destroy it first.
95       device_.reset();
96       thread_pool_wrapper_.reset();
97     }
98   }
99 
100  private:
101   int target_num_threads_ = kDefaultNumThreadpoolThreads;
102   // Both device_ and thread_pool_wrapper_ are lazily created.
103   std::unique_ptr<Eigen::ThreadPoolDevice> device_;
104   std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper_;
105 };
106 
107 struct RefCountedEigenContext : public TfLiteExternalContext {
108   std::unique_ptr<LazyEigenThreadPoolHolder> thread_pool_holder;
109   int num_references = 0;
110 };
111 
GetEigenContext(TfLiteContext * context)112 RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
113   return reinterpret_cast<RefCountedEigenContext*>(
114       context->GetExternalContext(context, kTfLiteEigenContext));
115 }
116 
Refresh(TfLiteContext * context)117 TfLiteStatus Refresh(TfLiteContext* context) {
118   SetEigenNbThreads(context->recommended_num_threads);
119 
120   auto* ptr = GetEigenContext(context);
121   if (ptr != nullptr) {
122     ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads);
123   }
124 
125   return kTfLiteOk;
126 }
127 
128 }  // namespace
129 
IncrementUsageCounter(TfLiteContext * context)130 void IncrementUsageCounter(TfLiteContext* context) {
131   auto* ptr = GetEigenContext(context);
132   if (ptr == nullptr) {
133     if (context->recommended_num_threads != -1) {
134       SetEigenNbThreads(context->recommended_num_threads);
135     }
136     ptr = new RefCountedEigenContext;
137     ptr->type = kTfLiteEigenContext;
138     ptr->Refresh = Refresh;
139     ptr->thread_pool_holder.reset(
140         new LazyEigenThreadPoolHolder(context->recommended_num_threads));
141     ptr->num_references = 0;
142     context->SetExternalContext(context, kTfLiteEigenContext, ptr);
143   }
144   ptr->num_references++;
145 }
146 
DecrementUsageCounter(TfLiteContext * context)147 void DecrementUsageCounter(TfLiteContext* context) {
148   auto* ptr = GetEigenContext(context);
149   if (ptr == nullptr) {
150     TF_LITE_FATAL(
151         "Call to DecrementUsageCounter() not preceded by "
152         "IncrementUsageCounter()");
153   }
154   if (--ptr->num_references == 0) {
155     delete ptr;
156     context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
157   }
158 }
159 
GetThreadPoolDevice(TfLiteContext * context)160 const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
161   auto* ptr = GetEigenContext(context);
162   if (ptr == nullptr) {
163     TF_LITE_FATAL(
164         "Call to GetFromContext() not preceded by IncrementUsageCounter()");
165   }
166   return ptr->thread_pool_holder->GetThreadPoolDevice();
167 }
168 
169 }  // namespace eigen_support
170 }  // namespace tflite
171