• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp
2index b678200a1..09736ccae 100644
3--- a/src/cpu/nchw_pooling.cpp
4+++ b/src/cpu/nchw_pooling.cpp
5@@ -609,10 +609,12 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
6     int od_end = min(OD, 1 + (padF + ID - 1) / SD);
7
8     dim_t c_blk = pd()->channel_block_size_;
9-    int c_blk_tail = C % c_blk;
10+    dim_t c_blk_tail = C % c_blk;
11+    const int nthr = pd()->nthr_;
12+
13     if (alg == alg_kind::pooling_max) {
14-        parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
15-                [&](int ithr, int, int mb, int cb) {
16+        parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
17+                [&](int ithr, int, dim_t mb, dim_t cb) {
18                     bool is_last_c_block
19                             = c_blk_tail > 0 && (cb + 1) * c_blk > C;
20                     int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
21@@ -649,8 +651,8 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
22                             diff_src_fp32, src_sp_size * curr_c_block);
23                 });
24     } else {
25-        parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
26-                [&](int ithr, int, int mb, int cb) {
27+        parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
28+                [&](int ithr, int, dim_t mb, dim_t cb) {
29                     bool is_last_c_block
30                             = c_blk_tail > 0 && (cb + 1) * c_blk > C;
31                     int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
32diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp
33index 9d649f3f5..2a73f6ae6 100644
34--- a/src/cpu/nchw_pooling.hpp
35+++ b/src/cpu/nchw_pooling.hpp
36@@ -139,6 +139,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
37                 ws_md_ = *hint_fwd_pd_->workspace_md();
38             }
39
40+            nthr_ = dnnl_get_max_threads();
41             calculate_channel_block_size();
42             init_scratchpad();
43
44@@ -146,6 +147,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
45         }
46
47         dim_t channel_block_size_;
48+        int nthr_; // Thread count captured at pd creation; execute() must not exceed it (scratchpad is sized for it).
49
50     private:
51         void init_scratchpad() {
52@@ -153,13 +155,12 @@ struct nchw_pooling_bwd_t : public primitive_t {
53             if (diff_dst_md()->data_type == data_type::bf16) {
54                 size_t dst_sz_ = OD() * OH() * OW();
55                 size_t src_sz_ = ID() * IH() * IW();
56-                size_t nthrs = dnnl_get_max_threads();
57                 auto scratchpad = scratchpad_registry().registrar();
58
59                 scratchpad.template book<float>(key_pool_src_bf16cvt,
60-                        src_sz_ * nthrs * channel_block_size_);
61+                        src_sz_ * nthr_ * channel_block_size_);
62                 scratchpad.template book<float>(key_pool_dst_bf16cvt,
63-                        dst_sz_ * nthrs * channel_block_size_);
64+                        dst_sz_ * nthr_ * channel_block_size_);
65             }
66         }
67
68@@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
69             // spatial
70             dim_t dst_sz_ = OD() * OH() * OW();
71             dim_t src_sz_ = ID() * IH() * IW();
72-            dim_t nthrs = dnnl_get_max_threads();
73-            dim_t C_per_thr = nstl::min(MB() * C() / nthrs, C());
74+            dim_t C_per_thr = nstl::min(MB() * C() / nthr_, C());
75             const dim_t max_block_size
76                     = platform::get_per_core_cache_size(1) / 2;
77             dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16
78diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp
79index 48d9e1240..efe3083f7 100644
80--- a/src/cpu/nhwc_pooling.cpp
81+++ b/src/cpu/nhwc_pooling.cpp
82@@ -378,8 +378,9 @@ status_t nhwc_pooling_fwd_t<data_type::bf16>::execute_forward(
83         return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
84     };
85     const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());
86+    const int nthr = pd()->nthr_;
87
88-    parallel_nd_ext(0, MB, OD, OH, OW,
89+    parallel_nd_ext(nthr, MB, OD, OH, OW,
90             [&](int ithr, int, int mb, int od, int oh, int ow) {
91                 const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
92                         od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
93@@ -682,8 +683,9 @@ status_t nhwc_pooling_bwd_t<data_type::bf16>::execute_backward(
94     auto apply_offset = [=](int index, int offset) {
95         return (index > offset) ? index - offset : 0;
96     };
97+    const int nthr = pd()->nthr_;
98
99-    parallel_nd_ext(0, MB, ID, IH, IW,
100+    parallel_nd_ext(nthr, MB, ID, IH, IW,
101             [&](int ithr, int, int mb, int id, int ih, int iw) {
102                 size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
103                         id, diff_src_d_stride, ih, diff_src_h_stride, iw,
104diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp
105index c65196a94..c16e840a2 100644
106--- a/src/cpu/nhwc_pooling.hpp
107+++ b/src/cpu/nhwc_pooling.hpp
108@@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t {
109                 init_default_ws();
110             }
111
112+            nthr_ = dnnl_get_max_threads();
113             init_scratchpad();
114
115             return status::success;
116         }
117
118+        int nthr_; // Thread count captured at pd creation; execute() must not exceed it (scratchpad is sized for it).
119+
120     private:
121         void init_scratchpad() {
122             using namespace memory_tracking::names;
123             if (src_md()->data_type == data_type::bf16) {
124-                const size_t bf16cvt_sz_ = C() * dnnl_get_max_threads();
125+                const size_t bf16cvt_sz_ = C() * nthr_;
126                 auto scratchpad = scratchpad_registry().registrar();
127                 scratchpad.template book<float>(
128                         key_pool_src_bf16cvt, bf16cvt_sz_);
129@@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t {
130                 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
131             }
132
133+            nthr_ = dnnl_get_max_threads();
134             init_scratchpad();
135
136             return status::success;
137         }
138
139+        int nthr_; // Thread count captured at pd creation; execute() must not exceed it (scratchpad is sized for it).
140+
141     private:
142         void init_scratchpad() {
143             using namespace memory_tracking::names;
144             if (diff_src_md()->data_type == data_type::bf16) {
145-                size_t bf16cvt_sz_ = C() * dnnl_get_max_threads();
146+                size_t bf16cvt_sz_ = C() * nthr_;
147                 auto scratchpad = scratchpad_registry().registrar();
148                 scratchpad.template book<float>(
149                         key_pool_src_bf16cvt, bf16cvt_sz_);
150diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp
151index a2a181cfa..5befb81ac 100644
152--- a/src/cpu/x64/jit_primitive_conf.hpp
153+++ b/src/cpu/x64/jit_primitive_conf.hpp
154@@ -672,6 +672,7 @@ struct jit_pool_conf_t {
155     bool with_postops;
156     bool with_eltwise;
157     bool with_binary;
158+    int nthr;
159 };
160
161 struct jit_pool_call_s {
162diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp
163index 36d129e6d..ebd4f3af1 100644
164--- a/src/cpu/x64/jit_uni_pool_kernel.cpp
165+++ b/src/cpu/x64/jit_uni_pool_kernel.cpp
166@@ -76,8 +76,7 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
167
168 template <cpu_isa_t isa>
169 status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
170-        memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
171-        int nthreads) {
172+        memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) {
173
174     const auto &pd = *ppd->desc();
175     const memory_desc_wrapper src_d(
176@@ -87,6 +86,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
177
178     const int ndims = src_d.ndims();
179
180+    jpp.nthr = dnnl_get_max_threads();
181     jpp.is_training = pd.prop_kind == prop_kind::forward_training;
182     jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
183
184@@ -248,7 +248,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
185                     ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1)
186                     : (ndims == 5 ? jpp.od : jpp.oh);
187             work *= jpp.mb * nb2_c;
188-            auto eff = (float)work / utils::rnd_up(work, nthreads);
189+            auto eff = (float)work / utils::rnd_up(work, jpp.nthr);
190             if (eff > best_eff) {
191
192                 best_eff = eff;
193diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp
194index d5d5f25a2..57ce6f43d 100644
195--- a/src/cpu/x64/jit_uni_pool_kernel.hpp
196+++ b/src/cpu/x64/jit_uni_pool_kernel.hpp
197@@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator {
198     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)
199
200     static status_t init_conf(jit_pool_conf_t &jbp,
201-            memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
202-            int nthreads);
203+            memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd);
204
205 private:
206     using Xmm = Xbyak::Xmm;
207diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp
208index b2055f2a9..29987f70c 100644
209--- a/src/cpu/x64/jit_uni_pooling.cpp
210+++ b/src/cpu/x64/jit_uni_pooling.cpp
211@@ -612,6 +612,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
212         (*kernel_)(&arg);
213     };
214
215+    const int nthr = jpp.nthr;
216+
217     if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
218         const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
219         parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](int n, int oh, int b2_c) {
220@@ -622,7 +624,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
221     } else {
222         if (trans_src || trans_dst) {
223             // ncsp format
224-            parallel_nd_ext(0, jpp.mb, jpp.nb_c,
225+            parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
226                     [&](int ithr, int nthr, int n, int b_c) {
227                         if (trans_src)
228                             transpose_facade.execute_transpose_input(
229@@ -635,7 +637,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
230                     });
231         } else {
232             // nChw16c, nChw8c format
233-            parallel(0, [&](std::size_t ithr, std::size_t nthr) {
234+            parallel(nthr, [&](int ithr, int nthr) {
235                 const std::size_t work_amount
236                         = static_cast<std::size_t>(jpp.mb) * jpp.nb_c * jpp.oh;
237                 if (ithr >= work_amount) return;
238@@ -739,6 +741,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
239         (*kernel_)(&arg);
240     };
241
242+    const int nthr = jpp.nthr;
243+
244     if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
245         const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
246         parallel_nd(jpp.mb, jpp.od, nb2_c, [&](int n, int od, int b2_c) {
247@@ -757,7 +761,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
248         });
249     } else {
250         if (trans_src || trans_dst) {
251-            parallel_nd_ext(0, jpp.mb, jpp.nb_c,
252+            parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
253                     [&](int ithr, int nthr, int n, int b_c) {
254                         if (trans_src)
255                             transpose_facade.execute_transpose_input(
256@@ -948,7 +952,9 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward(
257             transpose_facade.execute_transpose_output(ithr, n, b_c);
258     };
259
260-    parallel(0, [&](int ithr, int nthr) {
261+    const int nthr = jpp.nthr;
262+
263+    parallel(nthr, [&](int ithr, int nthr) {
264         const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
265         const std::size_t work_amount
266                 = static_cast<std::size_t>(jpp.mb) * nb2_c;
267@@ -1098,6 +1104,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
268         }
269     };
270
271+    const int nthr = jpp.nthr;
272+
273     if (jpp.simple_alg) {
274         if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
275             const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
276@@ -1109,7 +1117,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
277         } else {
278             assert(jpp.ur_bc == 1);
279             if (trans_src || trans_dst) {
280-                parallel_nd_ext(0, jpp.mb, jpp.nb_c,
281+                parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
282                         [&](int ithr, int nthr, int n, int b_c) {
283                             if (trans_src)
284                                 transpose_facade.execute_transpose_input(
285@@ -1142,7 +1150,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
286             if (!trans_src) {
287                 const size_t chunk_size
288                         = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block;
289-                parallel_nd_ext(0, jpp.mb, jpp.nb_c,
290+                parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
291                         [&](int ithr, int nthr, int n, int b_c) {
292                             const size_t offset
293                                     = ((size_t)n * jpp.nb_c + b_c) * chunk_size;
294@@ -1155,8 +1163,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
295
296         const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
297         if (trans_src || trans_dst) {
298-            parallel_nd_ext(
299-                    0, jpp.mb, nb2_c, [&](int ithr, int nthr, int n, int b2_c) {
300+            parallel_nd_ext(nthr, jpp.mb, nb2_c,
301+                    [&](int ithr, int nthr, int n, int b2_c) {
302                         const auto b_c = b2_c * jpp.ur_bc;
303
304                         if (trans_dst) {
305diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp
306index ec4b04a2b..e25d9ce05 100644
307--- a/src/cpu/x64/jit_uni_pooling.hpp
308+++ b/src/cpu/x64/jit_uni_pooling.hpp
309@@ -66,8 +66,9 @@ struct jit_uni_pooling_fwd_t : public primitive_t {
310                 init_default_ws();
311
312             auto scratchpad = scratchpad_registry().registrar();
313-            return jit_uni_pool_kernel<isa>::init_conf(
314-                    jpp_, scratchpad, this, dnnl_get_max_threads());
315+            CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));
316+
317+            return status::success;
318         }
319
320         jit_pool_conf_t jpp_;
321@@ -130,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t {
322                 init_default_ws();
323                 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
324             }
325+
326             auto scratchpad = scratchpad_registry().registrar();
327-            return jit_uni_pool_kernel<isa>::init_conf(
328-                    jpp_, scratchpad, this, dnnl_get_max_threads());
329+            CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));
330+
331+            return status::success;
332         }
333
334         jit_pool_conf_t jpp_;
335