/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/local_device.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/process_state.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_feature_guard.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace {

// Returns true iff the TF_OVERRIDE_GLOBAL_THREADPOOL environment variable is
// set to a truthy value. The result is computed once and cached.
bool OverrideGlobalThreadPoolFromEnvironment() {
  static const bool override_global_threadpool = [] {
    bool flag;
    auto status = ReadBoolFromEnvVar("TF_OVERRIDE_GLOBAL_THREADPOOL",
                                     /*default_val=*/false, &flag);
    if (!status.ok()) {
      LOG(ERROR) << "OverrideGlobalThreadPool: " << status.error_message();
      return false;
    }
    return flag;
  }();
  return override_global_threadpool;
}
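
// Usage note (a sketch, not part of this file's logic): setting
//
//   export TF_OVERRIDE_GLOBAL_THREADPOOL=1
//
// before the process creates its first device makes the LocalDevice
// constructor below call set_use_global_threadpool(false), so every device
// owns a private threadpool instead of sharing the global one. The value is
// read once and cached, so later changes to the variable have no effect.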

}  // namespace

/* static */
bool LocalDevice::use_global_threadpool_ = true;
mutex LocalDevice::global_tp_mu_;
// Process-wide threadpool state, shared by all LocalDevices while
// use_global_threadpool_ is true. When NUMA affinity is enabled the vector is
// indexed by NUMA node; otherwise only element 0 is used.
gtl::InlinedVector<LocalDevice::EigenThreadPoolInfo*, 4>
    LocalDevice::global_tp_info_;

struct LocalDevice::EigenThreadPoolInfo {
  // Wrapper so we can provide the CPUAllocator to Eigen for use
  // when ops need extra tmp memory.
  class EigenAllocator : public Eigen::Allocator {
   public:
    explicit EigenAllocator(tensorflow::Allocator* a) : allocator_(a) {}
    void* allocate(size_t num_bytes) const override {
      // Request 64-byte-aligned memory, which satisfies Eigen's alignment
      // requirements for vectorized kernels.
      return allocator_->AllocateRaw(64, num_bytes);
    }
    void deallocate(void* buffer) const override {
      allocator_->DeallocateRaw(buffer);
    }
    tensorflow::Allocator* allocator_;
  };

  explicit EigenThreadPoolInfo(const SessionOptions& options, int numa_node,
                               Allocator* allocator) {
    // Use session setting if specified.
    int32_t intra_op_parallelism_threads =
        options.config.intra_op_parallelism_threads();
    // If no session setting, use environment setting.
    if (intra_op_parallelism_threads == 0) {
      static int env_num_threads = NumIntraOpThreadsFromEnvironment();
      intra_op_parallelism_threads = env_num_threads;
      // If no session setting or environment, compute a reasonable default.
      if (intra_op_parallelism_threads == 0) {
        intra_op_parallelism_threads = port::MaxParallelism(numa_node);
      }
    }
    ThreadOptions thread_opts;
    thread_opts.numa_node = numa_node;
    eigen_worker_threads_.num_threads = intra_op_parallelism_threads;
    eigen_worker_threads_.workers = new thread::ThreadPool(
        options.env, thread_opts, strings::StrCat("numa_", numa_node, "_Eigen"),
        intra_op_parallelism_threads,
        !options.config.experimental().disable_thread_spinning(),
        /*allocator=*/nullptr);
    Eigen::ThreadPoolInterface* threadpool =
        eigen_worker_threads_.workers->AsEigenThreadPool();
    if (allocator != nullptr) {
      eigen_allocator_.reset(new EigenAllocator(allocator));
    }
    eigen_device_.reset(new Eigen::ThreadPoolDevice(
        threadpool, eigen_worker_threads_.num_threads, eigen_allocator_.get()));
  }

  ~EigenThreadPoolInfo() {
    // Destroy the Eigen device before the threadpool it points at.
    eigen_device_.reset();
    delete eigen_worker_threads_.workers;
  }

  DeviceBase::CpuWorkerThreads eigen_worker_threads_;
  std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
  std::unique_ptr<EigenAllocator> eigen_allocator_;
};
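
// Thread-count resolution above, as a minimal usage sketch (not part of this
// file): an explicit session setting wins, then the environment value read by
// NumIntraOpThreadsFromEnvironment(), then port::MaxParallelism(numa_node).
//
//   tensorflow::SessionOptions opts;
//   opts.config.set_intra_op_parallelism_threads(4);  // pin intra-op pool to 4
//
// Leaving the field at its default of 0 falls through to the environment
// setting, then to the hardware-derived default.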

LocalDevice::LocalDevice(const SessionOptions& options,
                         const DeviceAttributes& attributes)
    : Device(options.env, attributes), owned_tp_info_(nullptr) {
  // Log info messages if TensorFlow is not compiled with instructions that
  // could speed up performance and are available on the current CPU.
  port::InfoAboutUnusedCPUFeatures();
  LocalDevice::EigenThreadPoolInfo* tp_info;

  if (OverrideGlobalThreadPoolFromEnvironment()) {
    set_use_global_threadpool(false);
  }

  if (use_global_threadpool_) {
    mutex_lock l(global_tp_mu_);
    if (options.config.experimental().use_numa_affinity()) {
      int numa_node = attributes.locality().numa_node();
      int num_numa_nodes = port::NUMANumNodes();
      DCHECK_LT(numa_node, num_numa_nodes);
      Allocator* numa_allocator =
          ProcessState::singleton()->GetCPUAllocator(numa_node);
      while (numa_node >= global_tp_info_.size()) {
        global_tp_info_.push_back(nullptr);
      }
      if (!global_tp_info_[numa_node]) {
        global_tp_info_[numa_node] = new LocalDevice::EigenThreadPoolInfo(
            options, numa_node, numa_allocator);
      }
      tp_info = global_tp_info_[numa_node];
    } else {
      if (global_tp_info_.empty()) {
        global_tp_info_.push_back(new LocalDevice::EigenThreadPoolInfo(
            options, port::kNUMANoAffinity, nullptr));
      }
      tp_info = global_tp_info_[0];
    }
  } else {
    // Each LocalDevice owns a separate ThreadPoolDevice for numerical
    // computations.
    // TODO(tucker): NUMA for these too?
    owned_tp_info_.reset(new LocalDevice::EigenThreadPoolInfo(
        options, port::kNUMANoAffinity, nullptr));
    tp_info = owned_tp_info_.get();
  }
  set_tensorflow_cpu_worker_threads(&tp_info->eigen_worker_threads_);
  set_eigen_cpu_device(tp_info->eigen_device_.get());
}
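
// A hedged sketch of how kernels consume the device wired up above: the
// runtime exposes tp_info->eigen_device_ through OpKernelContext, so a CPU
// kernel can evaluate Eigen expressions on the shared pool, e.g.
//
//   const Eigen::ThreadPoolDevice& d =
//       ctx->eigen_device<Eigen::ThreadPoolDevice>();  // ctx: OpKernelContext*
//   output.device(d) = input1 + input2;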

LocalDevice::~LocalDevice() {}

}  // namespace tensorflow