1
2 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16
17 #ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
18 #define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
19 #ifdef INTEL_MKL
20
21 #include <list>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 #include "mkldnn.hpp"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/platform/threadpool.h"
30 #define EIGEN_USE_THREADS
31 #ifdef ENABLE_MKLDNN_THREADPOOL
32 using dnnl::stream_attr;
33 using dnnl::threadpool_iface;
34
35 namespace tensorflow {
36
// Divides `n` units of work as evenly as possible among `team` workers and
// returns, through `n_start`/`n_end`, the half-open range [*n_start, *n_end)
// assigned to worker `tid`. If `n` is not divisible by `team` and leaves a
// remainder `r`, workers 0..r-1 each receive one unit more than the rest.
// Example: n = 10, team = 3 yields [0, 4), [4, 7) and [7, 10).
//
// Parameters
//   n        Total number of work units.
//   team     Number of workers. If team <= 1 (or n == 0), every `tid`
//            receives the whole range [0, n).
//   tid      Index of the current worker, expected to be in [0, team).
//   n_start  Out: first unit assigned to `tid`.
//   n_end    Out: one past the last unit assigned to `tid`.
//
// `n` and the worker indices may have different integer types (e.g. a
// 64-bit `n` with `int` worker ids); all arithmetic is carried out in T.
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  // Convert the worker count/index to T up front: std::min requires both
  // arguments to have the same type, and mixing T and U in the products
  // below would otherwise depend on implicit promotions.
  const T teams = static_cast<T>(team);
  const T id = static_cast<T>(tid);
  const T min_per_team = n / teams;
  const T remainder = n % teams;
  // The first `remainder` workers each absorb one extra unit, so worker `id`
  // is shifted right by min(id, remainder) extra units.
  *n_start = id * min_per_team + std::min(id, remainder);
  *n_end = *n_start + min_per_team + static_cast<T>(id < remainder);
}
60
61 struct MklDnnThreadPool : public dnnl::threadpool_iface {
62 MklDnnThreadPool() = default;
63
MklDnnThreadPoolMklDnnThreadPool64 MklDnnThreadPool(OpKernelContext* ctx)
65 : eigen_interface_(ctx->device()
66 ->tensorflow_cpu_worker_threads()
67 ->workers->AsEigenThreadPool()) {}
get_num_threadsMklDnnThreadPool68 virtual int get_num_threads() const override {
69 return eigen_interface_->NumThreads();
70 }
get_in_parallelMklDnnThreadPool71 virtual bool get_in_parallel() const override {
72 return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
73 }
get_flagsMklDnnThreadPool74 virtual uint64_t get_flags() const override { return ASYNCHRONOUS; }
parallel_forMklDnnThreadPool75 virtual void parallel_for(int n,
76 const std::function<void(int, int)>& fn) override {
77 // Should never happen (handled by DNNL)
78 if (n == 0) return;
79
80 // Should never happen (handled by DNNL)
81 if (n == 1) {
82 fn(0, 1);
83 return;
84 }
85
86 int nthr = get_num_threads();
87 int njobs = std::min(n, nthr);
88 bool balance = (nthr < n);
89 for (int i = 0; i < njobs; i++) {
90 eigen_interface_->ScheduleWithHint(
91 [balance, i, n, njobs, fn]() {
92 if (balance) {
93 int start, end;
94 balance211(n, njobs, i, &start, &end);
95 for (int j = start; j < end; j++) fn(j, n);
96 } else {
97 fn(i, n);
98 }
99 },
100 i, i + 1);
101 }
102 }
~MklDnnThreadPoolMklDnnThreadPool103 ~MklDnnThreadPool() {}
104
105 private:
106 Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
107 };
108
109 class MklDnnThreadPoolWrapper {
110 public:
GetInstance()111 static MklDnnThreadPoolWrapper& GetInstance() {
112 static MklDnnThreadPoolWrapper instance_;
113 return instance_;
114 }
CreateThreadPoolPtr(OpKernelContext * ctx)115 MklDnnThreadPool* CreateThreadPoolPtr(OpKernelContext* ctx) {
116 mutex_lock l(m_);
117 if (threadpool_map_.empty() ||
118 threadpool_map_.find(ctx->device()) == threadpool_map_.end()) {
119 auto tp_iface = new MklDnnThreadPool(ctx);
120 threadpool_map_.emplace(std::make_pair(ctx->device(), tp_iface));
121 return tp_iface;
122 } else {
123 auto entry = threadpool_map_.find(ctx->device());
124 return entry->second;
125 }
126 }
127
128 private:
129 mutex m_;
130 std::unordered_map<DeviceBase*, MklDnnThreadPool*> threadpool_map_;
MklDnnThreadPoolWrapper()131 MklDnnThreadPoolWrapper() {}
132 MklDnnThreadPoolWrapper(const MklDnnThreadPoolWrapper&) = delete;
133 MklDnnThreadPoolWrapper& operator=(const MklDnnThreadPoolWrapper&) = delete;
~MklDnnThreadPoolWrapper()134 ~MklDnnThreadPoolWrapper() {
135 for (auto& tp : threadpool_map_) {
136 delete tp.second;
137 }
138 }
139 };
140
141 } // namespace tensorflow
142 #endif // ENABLE_MKLDNN_THREADPOOL
143 #endif // INTEL_MKL
144 #endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
145