Searched refs:lhs_rank (Results 1 – 5 of 5) sorted by relevance
173 const int64 lhs_rank = lhs_shape.rank(); in CanonicalizeDot() local175 lhs_rank - num_batch_dims - num_contracting_dims; in CanonicalizeDot()183 for (int64 i = 0; i < lhs_rank; ++i) { in CanonicalizeDot()197 lhs_transpose.reserve(lhs_rank); in CanonicalizeDot()
1045 int64 lhs_rank = lhs->shape().rank(); in ComputeArrayForDotWithIndexedLhs() local1048 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); in ComputeArrayForDotWithIndexedLhs()
1145 const int64 lhs_rank = lhs->shape().rank(); in HandleDotStrengthReduction() local1150 if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { in HandleDotStrengthReduction()1154 int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, in HandleDotStrengthReduction()1157 if (lhs_kept_dim == -1 && lhs_rank > 1) { in HandleDotStrengthReduction()1209 if (rhs_rank == 1 && lhs_rank == 1) { in HandleDotStrengthReduction()1239 if (lhs_rank == 1 || in HandleDotStrengthReduction()1240 (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { in HandleDotStrengthReduction()1275 CHECK_EQ(rhs_rank, lhs_rank); in HandleDotStrengthReduction()1276 CHECK_EQ(dot_rank, lhs_rank); in HandleDotStrengthReduction()1353 broadcast_dims(lhs_rank, lhs_kept_dim)); in HandleDotStrengthReduction()
1031 const auto lhs_rank = lhs_shape.rank(); in HandleConvolution() local1034 CHECK_EQ(num_spatial_dims + 2, lhs_rank); in HandleConvolution()1223 const int64 lhs_rank = lhs->shape().rank(); in HandleDot() local1241 if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && in HandleDot()1284 const auto lhs_rank = lhs->shape().rank(); in HandleDotSlowPath() local1296 DimensionVector lhs_index(lhs_rank); in HandleDotSlowPath()1304 (lhs_rank - dnums.lhs_contracting_dimensions_size()) + in HandleDotSlowPath()1316 for (int64 i = 0; i < lhs_rank; i++) { in HandleDotSlowPath()
505 const int64 lhs_rank = lhs_shape.rank(); in BinaryOp() local511 if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { in BinaryOp()512 const bool should_broadcast_lhs = lhs_rank < rhs_rank; in BinaryOp()