Home
last modified time | relevance | path

Searched refs:num_index_dims (Results 1 – 6 of 6) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
gather_op.cc:53 int64_t num_index_dims; in XlaGather() local
57 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); in XlaGather()
62 num_index_dims = 1; in XlaGather()
73 input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); in XlaGather()
90 for (int64_t i = 0; i < num_index_dims; ++i) { in XlaGather()
128 if (axis <= i && i < (axis + num_index_dims)) { in XlaGather()
139 } else if (i >= (axis + num_index_dims)) { in XlaGather()
142 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); in XlaGather()
148 for (int64_t i = axis; i < axis + num_index_dims; i++) { in XlaGather()
285 const int64_t num_index_dims = in Compile() local
[all …]
scatter_nd_op.cc:44 const int64_t num_index_dims = in ValidateUpdateShape() local
55 ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); in ValidateUpdateShape()
64 num_index_dims + (updates_shape.dims() - batch_dim)) { in ValidateUpdateShape()
68 batch_dim + buffer_shape.dims() - num_index_dims) { in ValidateUpdateShape()
78 buffer_shape.dim_size(d + num_index_dims)) { in ValidateUpdateShape()
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
scatter.cc:47 int64_t num_index_dims = 1; in XlaScatter() local
50 num_index_dims = indices_dims.back(); in XlaScatter()
51 if (num_index_dims > buffer_shape.rank()) { in XlaScatter()
73 for (int64_t i = 0; i < num_index_dims; ++i) { in XlaScatter()
145 int64_t num_window_dims_in_updates = buffer_rank - num_index_dims; in XlaScatter()
152 for (int64_t dim = num_index_dims; dim < buffer_rank; ++dim) { in XlaScatter()
169 for (int64_t i = 0; i < num_index_dims; ++i) { in XlaScatter()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
matrix.cc:125 const int64_t num_index_dims = 2; in GetMatrixDiagonalViaGather() local
126 const int64_t axis = n_dims - num_index_dims; in GetMatrixDiagonalViaGather()
132 {diag_len, num_index_dims}, {0}); in GetMatrixDiagonalViaGather()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
legalize_tf.cc:1496 int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); in matchAndRewrite() local
1499 if (num_index_dims == ShapedType::kDynamicSize) return failure(); in matchAndRewrite()
1504 if (i < num_index_dims) { in matchAndRewrite()
1515 if (i < num_index_dims) { in matchAndRewrite()
1536 collapsed_slice_dims.reserve(num_index_dims); in matchAndRewrite()
1537 for (int64_t i = 0; i < num_index_dims; ++i) { in matchAndRewrite()
1542 offset_dims.reserve(params_rank - num_index_dims); in matchAndRewrite()
1543 for (int64_t i = num_index_dims; i < params_rank; i++) { in matchAndRewrite()
1544 offset_dims.push_back(i + indices_rank - 1 - num_index_dims); in matchAndRewrite()
1548 offset_dims.reserve(num_index_dims); in matchAndRewrite()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/
tf_ops_n_z.cc:2414 int64_t num_index_dims = indices_ty.getShape().back(); in verify() local
2415 if (ShapedType::isDynamic(num_index_dims)) return success(); in verify()
2417 if (num_index_dims > tensor_ty.getRank()) in verify()