Searched refs:weight_tf_shape (Results 1 – 3 of 3) sorted by relevance
73 auto weight_tf_shape = weight_tensor.shape(); in Compute() local78 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(weight_tf_shape), in Compute()91 const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]); in Compute()94 ctx, k == weight_tf_shape.dim_size(dim_pair[1]), in Compute()97 ", In[1]: ", weight_tf_shape.DebugString())); in Compute()
184 auto weight_tf_shape = weight_mkl_shape.IsMklTensor() in Compute() local189 weight_dims = TFShapeToMklDnnDims(weight_tf_shape); in Compute()191 static_cast<int>(weight_tf_shape.dim_size(1))}; in Compute()195 weight_dims = {static_cast<int>(weight_tf_shape.dim_size(1)), in Compute()196 static_cast<int>(weight_tf_shape.dim_size(0))}; in Compute()
447 TensorShape weight_tf_shape; in CacheWeight() local448 weight_tf_shape.AddDim(weight_size / sizeof(Tweight)); in CacheWeight()451 DataTypeToEnum<Tweight>::value, weight_tf_shape, in CacheWeight()