Searched refs:num_index_dims (Results 1 – 6 of 6) sorted by relevance
53 int64_t num_index_dims; in XlaGather() local57 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); in XlaGather()62 num_index_dims = 1; in XlaGather()73 input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); in XlaGather()90 for (int64_t i = 0; i < num_index_dims; ++i) { in XlaGather()128 if (axis <= i && i < (axis + num_index_dims)) { in XlaGather()139 } else if (i >= (axis + num_index_dims)) { in XlaGather()142 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); in XlaGather()148 for (int64_t i = axis; i < axis + num_index_dims; i++) { in XlaGather()285 const int64_t num_index_dims = in Compile() local[all …]
44 const int64_t num_index_dims = in ValidateUpdateShape() local55 ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); in ValidateUpdateShape()64 num_index_dims + (updates_shape.dims() - batch_dim)) { in ValidateUpdateShape()68 batch_dim + buffer_shape.dims() - num_index_dims) { in ValidateUpdateShape()78 buffer_shape.dim_size(d + num_index_dims)) { in ValidateUpdateShape()
47 int64_t num_index_dims = 1; in XlaScatter() local50 num_index_dims = indices_dims.back(); in XlaScatter()51 if (num_index_dims > buffer_shape.rank()) { in XlaScatter()73 for (int64_t i = 0; i < num_index_dims; ++i) { in XlaScatter()145 int64_t num_window_dims_in_updates = buffer_rank - num_index_dims; in XlaScatter()152 for (int64_t dim = num_index_dims; dim < buffer_rank; ++dim) { in XlaScatter()169 for (int64_t i = 0; i < num_index_dims; ++i) { in XlaScatter()
125 const int64_t num_index_dims = 2; in GetMatrixDiagonalViaGather() local126 const int64_t axis = n_dims - num_index_dims; in GetMatrixDiagonalViaGather()132 {diag_len, num_index_dims}, {0}); in GetMatrixDiagonalViaGather()
1496 int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); in matchAndRewrite() local1499 if (num_index_dims == ShapedType::kDynamicSize) return failure(); in matchAndRewrite()1504 if (i < num_index_dims) { in matchAndRewrite()1515 if (i < num_index_dims) { in matchAndRewrite()1536 collapsed_slice_dims.reserve(num_index_dims); in matchAndRewrite()1537 for (int64_t i = 0; i < num_index_dims; ++i) { in matchAndRewrite()1542 offset_dims.reserve(params_rank - num_index_dims); in matchAndRewrite()1543 for (int64_t i = num_index_dims; i < params_rank; i++) { in matchAndRewrite()1544 offset_dims.push_back(i + indices_rank - 1 - num_index_dims); in matchAndRewrite()1548 offset_dims.reserve(num_index_dims); in matchAndRewrite()[all …]
2414 int64_t num_index_dims = indices_ty.getShape().back(); in verify() local2415 if (ShapedType::isDynamic(num_index_dims)) return success(); in verify()2417 if (num_index_dims > tensor_ty.getRank()) in verify()