• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
18 #define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
19 #ifdef INTEL_MKL
20 
21 #include <list>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 #include "mkldnn.hpp"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/platform/threadpool.h"
30 #define EIGEN_USE_THREADS
31 #ifdef ENABLE_MKLDNN_THREADPOOL
32 using dnnl::stream_attr;
33 using dnnl::threadpool_iface;
34 
35 namespace tensorflow {
36 
// Splits 'n' units of work as evenly as possible among 'team' workers and
// reports the half-open range [*n_start, *n_end) assigned to worker 'tid'.
// If 'n' is not divisible by 'team' and leaves a remainder 'r', each of the
// first 'r' workers receives one unit of work more than the rest.
//
// 'T' and 'U' may be different integer types: 'team' and 'tid' are converted
// to 'T' before the split is computed.
//
// Parameters
//   n        Total number of jobs.
//   team     Number of workers.
//   tid      Index of the current worker; the split assumes 0 <= tid < team.
//   n_start  Output: first job of the range operated on by the worker.
//   n_end    Output: one past the last job of the range operated on by the
//            worker.
//
// When 'team' <= 1 or 'n' == 0, the whole range [0, n) is returned regardless
// of 'tid'.
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  // Bring both worker quantities into the job-count type so that the
  // arithmetic and std::min below operate on a single type.
  const T workers = static_cast<T>(team);
  const T worker = static_cast<T>(tid);
  const T min_per_team = n / workers;
  const T remainder = n - min_per_team * workers;  // i.e., n % team.
  // Every worker ahead of this one that holds an extra unit shifts the start
  // by one; there are min(worker, remainder) of those.
  *n_start = worker * min_per_team + std::min(worker, remainder);
  *n_end = *n_start + min_per_team + (worker < remainder ? 1 : 0);
}
60 
61 struct MklDnnThreadPool : public dnnl::threadpool_iface {
62   MklDnnThreadPool() = default;
63 
MklDnnThreadPoolMklDnnThreadPool64   MklDnnThreadPool(OpKernelContext* ctx)
65       : eigen_interface_(ctx->device()
66                              ->tensorflow_cpu_worker_threads()
67                              ->workers->AsEigenThreadPool()) {}
get_num_threadsMklDnnThreadPool68   virtual int get_num_threads() const override {
69     return eigen_interface_->NumThreads();
70   }
get_in_parallelMklDnnThreadPool71   virtual bool get_in_parallel() const override {
72     return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
73   }
get_flagsMklDnnThreadPool74   virtual uint64_t get_flags() const override { return ASYNCHRONOUS; }
parallel_forMklDnnThreadPool75   virtual void parallel_for(int n,
76                             const std::function<void(int, int)>& fn) override {
77     // Should never happen (handled by DNNL)
78     if (n == 0) return;
79 
80     // Should never happen (handled by DNNL)
81     if (n == 1) {
82       fn(0, 1);
83       return;
84     }
85 
86     int nthr = get_num_threads();
87     int njobs = std::min(n, nthr);
88     bool balance = (nthr < n);
89     for (int i = 0; i < njobs; i++) {
90       eigen_interface_->ScheduleWithHint(
91           [balance, i, n, njobs, fn]() {
92             if (balance) {
93               int start, end;
94               balance211(n, njobs, i, &start, &end);
95               for (int j = start; j < end; j++) fn(j, n);
96             } else {
97               fn(i, n);
98             }
99           },
100           i, i + 1);
101     }
102   }
~MklDnnThreadPoolMklDnnThreadPool103   ~MklDnnThreadPool() {}
104 
105  private:
106   Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
107 };
108 
109 class MklDnnThreadPoolWrapper {
110  public:
GetInstance()111   static MklDnnThreadPoolWrapper& GetInstance() {
112     static MklDnnThreadPoolWrapper instance_;
113     return instance_;
114   }
CreateThreadPoolPtr(OpKernelContext * ctx)115   MklDnnThreadPool* CreateThreadPoolPtr(OpKernelContext* ctx) {
116     mutex_lock l(m_);
117     if (threadpool_map_.empty() ||
118         threadpool_map_.find(ctx->device()) == threadpool_map_.end()) {
119       auto tp_iface = new MklDnnThreadPool(ctx);
120       threadpool_map_.emplace(std::make_pair(ctx->device(), tp_iface));
121       return tp_iface;
122     } else {
123       auto entry = threadpool_map_.find(ctx->device());
124       return entry->second;
125     }
126   }
127 
128  private:
129   mutex m_;
130   std::unordered_map<DeviceBase*, MklDnnThreadPool*> threadpool_map_;
MklDnnThreadPoolWrapper()131   MklDnnThreadPoolWrapper() {}
132   MklDnnThreadPoolWrapper(const MklDnnThreadPoolWrapper&) = delete;
133   MklDnnThreadPoolWrapper& operator=(const MklDnnThreadPoolWrapper&) = delete;
~MklDnnThreadPoolWrapper()134   ~MklDnnThreadPoolWrapper() {
135     for (auto& tp : threadpool_map_) {
136       delete tp.second;
137     }
138   }
139 };
140 
141 }  // namespace tensorflow
142 #endif  // ENABLE_MKLDNN_THREADPOOL
143 #endif  // INTEL_MKL
144 #endif  // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
145