diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp index b678200a1..09736ccae 100644 --- a/src/cpu/nchw_pooling.cpp +++ b/src/cpu/nchw_pooling.cpp @@ -609,10 +609,12 @@ status_t nchw_pooling_bwd_t::execute_backward( int od_end = min(OD, 1 + (padF + ID - 1) / SD); dim_t c_blk = pd()->channel_block_size_; - int c_blk_tail = C % c_blk; + dim_t c_blk_tail = C % c_blk; + const int nthr = pd()->nthr_; + if (alg == alg_kind::pooling_max) { - parallel_nd_ext(0, MB, utils::div_up(C, c_blk), - [&](int ithr, int, int mb, int cb) { + parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), + [&](int ithr, int, dim_t mb, dim_t cb) { bool is_last_c_block = c_blk_tail > 0 && (cb + 1) * c_blk > C; int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; @@ -649,8 +651,8 @@ status_t nchw_pooling_bwd_t::execute_backward( diff_src_fp32, src_sp_size * curr_c_block); }); } else { - parallel_nd_ext(0, MB, utils::div_up(C, c_blk), - [&](int ithr, int, int mb, int cb) { + parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), + [&](int ithr, int, dim_t mb, dim_t cb) { bool is_last_c_block = c_blk_tail > 0 && (cb + 1) * c_blk > C; int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp index 9d649f3f5..2a73f6ae6 100644 --- a/src/cpu/nchw_pooling.hpp +++ b/src/cpu/nchw_pooling.hpp @@ -139,6 +139,7 @@ struct nchw_pooling_bwd_t : public primitive_t { ws_md_ = *hint_fwd_pd_->workspace_md(); } + nthr_ = dnnl_get_max_threads(); calculate_channel_block_size(); init_scratchpad(); @@ -146,6 +147,7 @@ struct nchw_pooling_bwd_t : public primitive_t { } dim_t channel_block_size_; + int nthr_; // To not exceed the limit in execute used for set up. private: void init_scratchpad() { @@ -153,13 +155,12 @@ struct nchw_pooling_bwd_t : public primitive_t { if (diff_dst_md()->data_type == data_type::bf16) { size_t dst_sz_ = OD() * OH() * OW(); size_t src_sz_ = ID() * IH() * IW(); - size_t nthrs = dnnl_get_max_threads(); auto scratchpad = scratchpad_registry().registrar(); scratchpad.template book(key_pool_src_bf16cvt, - src_sz_ * nthrs * channel_block_size_); + src_sz_ * nthr_ * channel_block_size_); scratchpad.template book(key_pool_dst_bf16cvt, - dst_sz_ * nthrs * channel_block_size_); + dst_sz_ * nthr_ * channel_block_size_); } } @@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t { // spatial dim_t dst_sz_ = OD() * OH() * OW(); dim_t src_sz_ = ID() * IH() * IW(); - dim_t nthrs = dnnl_get_max_threads(); - dim_t C_per_thr = nstl::min(MB() * C() / nthrs, C()); + dim_t C_per_thr = nstl::min(MB() * C() / nthr_, C()); const dim_t max_block_size = platform::get_per_core_cache_size(1) / 2; dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16 diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp index 48d9e1240..efe3083f7 100644 --- a/src/cpu/nhwc_pooling.cpp +++ b/src/cpu/nhwc_pooling.cpp @@ -378,8 +378,9 @@ status_t nhwc_pooling_fwd_t::execute_forward( return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow; }; const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty()); + const int nthr = pd()->nthr_; - parallel_nd_ext(0, MB, OD, OH, OW, + parallel_nd_ext(nthr, MB, OD, OH, OW, [&](int ithr, int, int mb, int od, int oh, int ow) { const size_t dst_offset_init = strided_offset(mb, dst_n_stride, od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride); @@ -682,8 +683,9 @@ status_t nhwc_pooling_bwd_t::execute_backward( auto apply_offset = [=](int index, int offset) { return (index > offset) ? index - offset : 0; }; + const int nthr = pd()->nthr_; - parallel_nd_ext(0, MB, ID, IH, IW, + parallel_nd_ext(nthr, MB, ID, IH, IW, [&](int ithr, int, int mb, int id, int ih, int iw) { size_t src_offset_init = strided_offset(mb, diff_src_n_stride, id, diff_src_d_stride, ih, diff_src_h_stride, iw, diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp index c65196a94..c16e840a2 100644 --- a/src/cpu/nhwc_pooling.hpp +++ b/src/cpu/nhwc_pooling.hpp @@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t { init_default_ws(); } + nthr_ = dnnl_get_max_threads(); init_scratchpad(); return status::success; } + int nthr_; // To not exceed the limit in execute used for set up. + private: void init_scratchpad() { using namespace memory_tracking::names; if (src_md()->data_type == data_type::bf16) { - const size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); + const size_t bf16cvt_sz_ = C() * nthr_; auto scratchpad = scratchpad_registry().registrar(); scratchpad.template book( key_pool_src_bf16cvt, bf16cvt_sz_); @@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t { if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; } + nthr_ = dnnl_get_max_threads(); init_scratchpad(); return status::success; } + int nthr_; // To not exceed the limit in execute used for set up. + private: void init_scratchpad() { using namespace memory_tracking::names; if (diff_src_md()->data_type == data_type::bf16) { - size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); + size_t bf16cvt_sz_ = C() * nthr_; auto scratchpad = scratchpad_registry().registrar(); scratchpad.template book( key_pool_src_bf16cvt, bf16cvt_sz_); diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp index a2a181cfa..5befb81ac 100644 --- a/src/cpu/x64/jit_primitive_conf.hpp +++ b/src/cpu/x64/jit_primitive_conf.hpp @@ -672,6 +672,7 @@ struct jit_pool_conf_t { bool with_postops; bool with_eltwise; bool with_binary; + int nthr; }; struct jit_pool_call_s { diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp index 36d129e6d..ebd4f3af1 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.cpp +++ b/src/cpu/x64/jit_uni_pool_kernel.cpp @@ -76,8 +76,7 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( template status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, - memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, - int nthreads) { + memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) { const auto &pd = *ppd->desc(); const memory_desc_wrapper src_d( @@ -87,6 +86,7 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, const int ndims = src_d.ndims(); + jpp.nthr = dnnl_get_max_threads(); jpp.is_training = pd.prop_kind == prop_kind::forward_training; jpp.is_backward = pd.prop_kind == prop_kind::backward_data; @@ -248,7 +248,7 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1) : (ndims == 5 ? jpp.od : jpp.oh); work *= jpp.mb * nb2_c; - auto eff = (float)work / utils::rnd_up(work, nthreads); + auto eff = (float)work / utils::rnd_up(work, jpp.nthr); if (eff > best_eff) { best_eff = eff; diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp index d5d5f25a2..57ce6f43d 100644 --- a/src/cpu/x64/jit_uni_pool_kernel.hpp +++ b/src/cpu/x64/jit_uni_pool_kernel.hpp @@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel) static status_t init_conf(jit_pool_conf_t &jbp, - memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, - int nthreads); + memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd); private: using Xmm = Xbyak::Xmm; diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp index b2055f2a9..29987f70c 100644 --- a/src/cpu/x64/jit_uni_pooling.cpp +++ b/src/cpu/x64/jit_uni_pooling.cpp @@ -612,6 +612,8 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, (*kernel_)(&arg); }; + const int nthr = jpp.nthr; + if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](int n, int oh, int b2_c) { @@ -622,7 +624,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, } else { if (trans_src || trans_dst) { // ncsp format - parallel_nd_ext(0, jpp.mb, jpp.nb_c, + parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, [&](int ithr, int nthr, int n, int b_c) { if (trans_src) transpose_facade.execute_transpose_input( @@ -635,7 +637,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, }); } else { // nChw16c, nChw8c format - parallel(0, [&](std::size_t ithr, std::size_t nthr) { + parallel(nthr, [&](int ithr, int nthr) { const std::size_t work_amount = static_cast(jpp.mb) * jpp.nb_c * jpp.oh; if (ithr >= work_amount) return; @@ -739,6 +741,8 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, (*kernel_)(&arg); }; + const int nthr = jpp.nthr; + if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); parallel_nd(jpp.mb, jpp.od, nb2_c, [&](int n, int od, int b2_c) { @@ -757,7 +761,7 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, }); } else { if (trans_src || trans_dst) { - parallel_nd_ext(0, jpp.mb, jpp.nb_c, + parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, [&](int ithr, int nthr, int n, int b_c) { if (trans_src) transpose_facade.execute_transpose_input( @@ -948,7 +952,9 @@ void jit_uni_pooling_bwd_t::execute_backward( transpose_facade.execute_transpose_output(ithr, n, b_c); }; - parallel(0, [&](int ithr, int nthr) { + const int nthr = jpp.nthr; + + parallel(nthr, [&](int ithr, int nthr) { const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); const std::size_t work_amount = static_cast(jpp.mb) * nb2_c; @@ -1098,6 +1104,8 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( } }; + const int nthr = jpp.nthr; + if (jpp.simple_alg) { if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); @@ -1109,7 +1117,7 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( } else { assert(jpp.ur_bc == 1); if (trans_src || trans_dst) { - parallel_nd_ext(0, jpp.mb, jpp.nb_c, + parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, [&](int ithr, int nthr, int n, int b_c) { if (trans_src) transpose_facade.execute_transpose_input( @@ -1142,7 +1150,7 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( if (!trans_src) { const size_t chunk_size = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block; - parallel_nd_ext(0, jpp.mb, jpp.nb_c, + parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, [&](int ithr, int nthr, int n, int b_c) { const size_t offset = ((size_t)n * jpp.nb_c + b_c) * chunk_size; @@ -1155,8 +1163,8 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); if (trans_src || trans_dst) { - parallel_nd_ext( - 0, jpp.mb, nb2_c, [&](int ithr, int nthr, int n, int b2_c) { + parallel_nd_ext(nthr, jpp.mb, nb2_c, + [&](int ithr, int nthr, int n, int b2_c) { const auto b_c = b2_c * jpp.ur_bc; if (trans_dst) { diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp index ec4b04a2b..e25d9ce05 100644 --- a/src/cpu/x64/jit_uni_pooling.hpp +++ b/src/cpu/x64/jit_uni_pooling.hpp @@ -66,8 +66,9 @@ struct jit_uni_pooling_fwd_t : public primitive_t { init_default_ws(); auto scratchpad = scratchpad_registry().registrar(); - return jit_uni_pool_kernel::init_conf( - jpp_, scratchpad, this, dnnl_get_max_threads()); + CHECK(jit_uni_pool_kernel::init_conf(jpp_, scratchpad, this)); + + return status::success; } jit_pool_conf_t jpp_; @@ -130,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t { init_default_ws(); if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; } + auto scratchpad = scratchpad_registry().registrar(); - return jit_uni_pool_kernel::init_conf( - jpp_, scratchpad, this, dnnl_get_max_threads()); + CHECK(jit_uni_pool_kernel::init_conf(jpp_, scratchpad, this)); + + return status::success; } jit_pool_conf_t jpp_;