1diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp 2index b678200a1..09736ccae 100644 3--- a/src/cpu/nchw_pooling.cpp 4+++ b/src/cpu/nchw_pooling.cpp 5@@ -609,10 +609,12 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward( 6 int od_end = min(OD, 1 + (padF + ID - 1) / SD); 7 8 dim_t c_blk = pd()->channel_block_size_; 9- int c_blk_tail = C % c_blk; 10+ dim_t c_blk_tail = C % c_blk; 11+ const int nthr = pd()->nthr_; 12+ 13 if (alg == alg_kind::pooling_max) { 14- parallel_nd_ext(0, MB, utils::div_up(C, c_blk), 15- [&](int ithr, int, int mb, int cb) { 16+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), 17+ [&](int ithr, int, dim_t mb, dim_t cb) { 18 bool is_last_c_block 19 = c_blk_tail > 0 && (cb + 1) * c_blk > C; 20 int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; 21@@ -649,8 +651,8 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward( 22 diff_src_fp32, src_sp_size * curr_c_block); 23 }); 24 } else { 25- parallel_nd_ext(0, MB, utils::div_up(C, c_blk), 26- [&](int ithr, int, int mb, int cb) { 27+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), 28+ [&](int ithr, int, dim_t mb, dim_t cb) { 29 bool is_last_c_block 30 = c_blk_tail > 0 && (cb + 1) * c_blk > C; 31 int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; 32diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp 33index 9d649f3f5..2a73f6ae6 100644 34--- a/src/cpu/nchw_pooling.hpp 35+++ b/src/cpu/nchw_pooling.hpp 36@@ -139,6 +139,7 @@ struct nchw_pooling_bwd_t : public primitive_t { 37 ws_md_ = *hint_fwd_pd_->workspace_md(); 38 } 39 40+ nthr_ = dnnl_get_max_threads(); 41 calculate_channel_block_size(); 42 init_scratchpad(); 43 44@@ -146,6 +147,7 @@ struct nchw_pooling_bwd_t : public primitive_t { 45 } 46 47 dim_t channel_block_size_; 48+ int nthr_; // To not exceed the limit in execute used for set up. 49 50 private: 51 void init_scratchpad() { 52@@ -153,13 +155,12 @@ struct nchw_pooling_bwd_t : public primitive_t { 53 if (diff_dst_md()->data_type == data_type::bf16) { 54 size_t dst_sz_ = OD() * OH() * OW(); 55 size_t src_sz_ = ID() * IH() * IW(); 56- size_t nthrs = dnnl_get_max_threads(); 57 auto scratchpad = scratchpad_registry().registrar(); 58 59 scratchpad.template book<float>(key_pool_src_bf16cvt, 60- src_sz_ * nthrs * channel_block_size_); 61+ src_sz_ * nthr_ * channel_block_size_); 62 scratchpad.template book<float>(key_pool_dst_bf16cvt, 63- dst_sz_ * nthrs * channel_block_size_); 64+ dst_sz_ * nthr_ * channel_block_size_); 65 } 66 } 67 68@@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t { 69 // spatial 70 dim_t dst_sz_ = OD() * OH() * OW(); 71 dim_t src_sz_ = ID() * IH() * IW(); 72- dim_t nthrs = dnnl_get_max_threads(); 73- dim_t C_per_thr = nstl::min(MB() * C() / nthrs, C()); 74+ dim_t C_per_thr = nstl::min(MB() * C() / nthr_, C()); 75 const dim_t max_block_size 76 = platform::get_per_core_cache_size(1) / 2; 77 dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16 78diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp 79index 48d9e1240..efe3083f7 100644 80--- a/src/cpu/nhwc_pooling.cpp 81+++ b/src/cpu/nhwc_pooling.cpp 82@@ -378,8 +378,9 @@ status_t nhwc_pooling_fwd_t<data_type::bf16>::execute_forward( 83 return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow; 84 }; 85 const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty()); 86+ const int nthr = pd()->nthr_; 87 88- parallel_nd_ext(0, MB, OD, OH, OW, 89+ parallel_nd_ext(nthr, MB, OD, OH, OW, 90 [&](int ithr, int, int mb, int od, int oh, int ow) { 91 const size_t dst_offset_init = strided_offset(mb, dst_n_stride, 92 od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride); 93@@ -682,8 +683,9 @@ status_t nhwc_pooling_bwd_t<data_type::bf16>::execute_backward( 94 auto apply_offset = [=](int index, int offset) { 95 return (index > offset) ? index - offset : 0; 96 }; 97+ const int nthr = pd()->nthr_; 98 99- parallel_nd_ext(0, MB, ID, IH, IW, 100+ parallel_nd_ext(nthr, MB, ID, IH, IW, 101 [&](int ithr, int, int mb, int id, int ih, int iw) { 102 size_t src_offset_init = strided_offset(mb, diff_src_n_stride, 103 id, diff_src_d_stride, ih, diff_src_h_stride, iw, 104diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp 105index c65196a94..c16e840a2 100644 106--- a/src/cpu/nhwc_pooling.hpp 107+++ b/src/cpu/nhwc_pooling.hpp 108@@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t { 109 init_default_ws(); 110 } 111 112+ nthr_ = dnnl_get_max_threads(); 113 init_scratchpad(); 114 115 return status::success; 116 } 117 118+ int nthr_; // To not exceed the limit in execute used for set up. 119+ 120 private: 121 void init_scratchpad() { 122 using namespace memory_tracking::names; 123 if (src_md()->data_type == data_type::bf16) { 124- const size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); 125+ const size_t bf16cvt_sz_ = C() * nthr_; 126 auto scratchpad = scratchpad_registry().registrar(); 127 scratchpad.template book<float>( 128 key_pool_src_bf16cvt, bf16cvt_sz_); 129@@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t { 130 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; 131 } 132 133+ nthr_ = dnnl_get_max_threads(); 134 init_scratchpad(); 135 136 return status::success; 137 } 138 139+ int nthr_; // To not exceed the limit in execute used for set up. 140+ 141 private: 142 void init_scratchpad() { 143 using namespace memory_tracking::names; 144 if (diff_src_md()->data_type == data_type::bf16) { 145- size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); 146+ size_t bf16cvt_sz_ = C() * nthr_; 147 auto scratchpad = scratchpad_registry().registrar(); 148 scratchpad.template book<float>( 149 key_pool_src_bf16cvt, bf16cvt_sz_); 150diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp 151index a2a181cfa..5befb81ac 100644 152--- a/src/cpu/x64/jit_primitive_conf.hpp 153+++ b/src/cpu/x64/jit_primitive_conf.hpp 154@@ -672,6 +672,7 @@ struct jit_pool_conf_t { 155 bool with_postops; 156 bool with_eltwise; 157 bool with_binary; 158+ int nthr; 159 }; 160 161 struct jit_pool_call_s { 162diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp 163index 36d129e6d..ebd4f3af1 100644 164--- a/src/cpu/x64/jit_uni_pool_kernel.cpp 165+++ b/src/cpu/x64/jit_uni_pool_kernel.cpp 166@@ -76,8 +76,7 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel( 167 168 template <cpu_isa_t isa> 169 status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp, 170- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, 171- int nthreads) { 172+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) { 173 174 const auto &pd = *ppd->desc(); 175 const memory_desc_wrapper src_d( 176@@ -87,6 +86,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp, 177 178 const int ndims = src_d.ndims(); 179 180+ jpp.nthr = dnnl_get_max_threads(); 181 jpp.is_training = pd.prop_kind == prop_kind::forward_training; 182 jpp.is_backward = pd.prop_kind == prop_kind::backward_data; 183 184@@ -248,7 +248,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp, 185 ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1) 186 : (ndims == 5 ? jpp.od : jpp.oh); 187 work *= jpp.mb * nb2_c; 188- auto eff = (float)work / utils::rnd_up(work, nthreads); 189+ auto eff = (float)work / utils::rnd_up(work, jpp.nthr); 190 if (eff > best_eff) { 191 192 best_eff = eff; 193diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp 194index d5d5f25a2..57ce6f43d 100644 195--- a/src/cpu/x64/jit_uni_pool_kernel.hpp 196+++ b/src/cpu/x64/jit_uni_pool_kernel.hpp 197@@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator { 198 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel) 199 200 static status_t init_conf(jit_pool_conf_t &jbp, 201- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, 202- int nthreads); 203+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd); 204 205 private: 206 using Xmm = Xbyak::Xmm; 207diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp 208index b2055f2a9..29987f70c 100644 209--- a/src/cpu/x64/jit_uni_pooling.cpp 210+++ b/src/cpu/x64/jit_uni_pooling.cpp 211@@ -612,6 +612,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src, 212 (*kernel_)(&arg); 213 }; 214 215+ const int nthr = jpp.nthr; 216+ 217 if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { 218 const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); 219 parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](int n, int oh, int b2_c) { 220@@ -622,7 +624,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src, 221 } else { 222 if (trans_src || trans_dst) { 223 // ncsp format 224- parallel_nd_ext(0, jpp.mb, jpp.nb_c, 225+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, 226 [&](int ithr, int nthr, int n, int b_c) { 227 if (trans_src) 228 transpose_facade.execute_transpose_input( 229@@ -635,7 +637,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src, 230 }); 231 } else { 232 // nChw16c, nChw8c format 233- parallel(0, [&](std::size_t ithr, std::size_t nthr) { 234+ parallel(nthr, [&](int ithr, int nthr) { 235 const std::size_t work_amount 236 = static_cast<std::size_t>(jpp.mb) * jpp.nb_c * jpp.oh; 237 if (ithr >= work_amount) return; 238@@ -739,6 +741,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src, 239 (*kernel_)(&arg); 240 }; 241 242+ const int nthr = jpp.nthr; 243+ 244 if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { 245 const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); 246 parallel_nd(jpp.mb, jpp.od, nb2_c, [&](int n, int od, int b2_c) { 247@@ -757,7 +761,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src, 248 }); 249 } else { 250 if (trans_src || trans_dst) { 251- parallel_nd_ext(0, jpp.mb, jpp.nb_c, 252+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, 253 [&](int ithr, int nthr, int n, int b_c) { 254 if (trans_src) 255 transpose_facade.execute_transpose_input( 256@@ -948,7 +952,9 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward( 257 transpose_facade.execute_transpose_output(ithr, n, b_c); 258 }; 259 260- parallel(0, [&](int ithr, int nthr) { 261+ const int nthr = jpp.nthr; 262+ 263+ parallel(nthr, [&](int ithr, int nthr) { 264 const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); 265 const std::size_t work_amount 266 = static_cast<std::size_t>(jpp.mb) * nb2_c; 267@@ -1098,6 +1104,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d( 268 } 269 }; 270 271+ const int nthr = jpp.nthr; 272+ 273 if (jpp.simple_alg) { 274 if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { 275 const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); 276@@ -1109,7 +1117,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d( 277 } else { 278 assert(jpp.ur_bc == 1); 279 if (trans_src || trans_dst) { 280- parallel_nd_ext(0, jpp.mb, jpp.nb_c, 281+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, 282 [&](int ithr, int nthr, int n, int b_c) { 283 if (trans_src) 284 transpose_facade.execute_transpose_input( 285@@ -1142,7 +1150,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d( 286 if (!trans_src) { 287 const size_t chunk_size 288 = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block; 289- parallel_nd_ext(0, jpp.mb, jpp.nb_c, 290+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, 291 [&](int ithr, int nthr, int n, int b_c) { 292 const size_t offset 293 = ((size_t)n * jpp.nb_c + b_c) * chunk_size; 294@@ -1155,8 +1163,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d( 295 296 const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); 297 if (trans_src || trans_dst) { 298- parallel_nd_ext( 299- 0, jpp.mb, nb2_c, [&](int ithr, int nthr, int n, int b2_c) { 300+ parallel_nd_ext(nthr, jpp.mb, nb2_c, 301+ [&](int ithr, int nthr, int n, int b2_c) { 302 const auto b_c = b2_c * jpp.ur_bc; 303 304 if (trans_dst) { 305diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp 306index ec4b04a2b..e25d9ce05 100644 307--- a/src/cpu/x64/jit_uni_pooling.hpp 308+++ b/src/cpu/x64/jit_uni_pooling.hpp 309@@ -66,8 +66,9 @@ struct jit_uni_pooling_fwd_t : public primitive_t { 310 init_default_ws(); 311 312 auto scratchpad = scratchpad_registry().registrar(); 313- return jit_uni_pool_kernel<isa>::init_conf( 314- jpp_, scratchpad, this, dnnl_get_max_threads()); 315+ CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this)); 316+ 317+ return status::success; 318 } 319 320 jit_pool_conf_t jpp_; 321@@ -130,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t { 322 init_default_ws(); 323 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; 324 } 325+ 326 auto scratchpad = scratchpad_registry().registrar(); 327- return jit_uni_pool_kernel<isa>::init_conf( 328- jpp_, scratchpad, this, dnnl_get_max_threads()); 329+ CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this)); 330+ 331+ return status::success; 332 } 333 334 jit_pool_conf_t jpp_; 335