Searched refs:weight_md (Results 1 – 3 of 3) sorted by relevance
/external/tensorflow/tensorflow/core/kernels/mkl/ |
D | mkl_matmul_op_fused.cc | 224 const memory::desc weight_md = in Compute() local 226 if (weight_md != matmul_pd->weights_desc()) { in Compute() 232 weight_mkl, weight_md); in Compute() 244 weight_mkl.SetUsrMem(weight_md, weight_data); in Compute()
|
D | mkl_qmatmul_op.cc | 217 auto weight_md = weight_mkl_shape.IsMklTensor() in Compute() local 221 weight.SetUsrMem(weight_md, &weight_tensor); in Compute() 260 if (weight_md != matmul_fwd_pd->weights_desc()) { in Compute() 271 weight_tensor, weight, weight_md); in Compute() 279 weight.SetUsrMem(weight_md, &weight_tensor); in Compute()
|
D | mkl_matmul_ops_common.h | 152 std::shared_ptr<mkldnn::memory::desc> weight_md; member 170 weight_md(nullptr), in MklDnnMatMulFwdContext() 183 context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims}, in Setup() 196 prop_kind::forward_inference, *context_.src_md, *context_.weight_md, in Setup() 428 MklDnnData<Tweight>& weight, const memory::desc& weight_md) in CacheWeight() argument 439 weight.SetUsrMem(weight_md, &weight_tensor); in CacheWeight()
|