Home
last modified time | relevance | path

Searched refs:num_index_dims (Results 1 – 6 of 6) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
gather_op.cc
52 int64 num_index_dims; in XlaGather() local
56 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); in XlaGather()
61 num_index_dims = 1; in XlaGather()
72 input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); in XlaGather()
89 for (int64 i = 0; i < num_index_dims; ++i) { in XlaGather()
127 if (axis <= i && i < (axis + num_index_dims)) { in XlaGather()
138 } else if (i >= (axis + num_index_dims)) { in XlaGather()
141 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); in XlaGather()
147 for (int64 i = axis; i < axis + num_index_dims; i++) { in XlaGather()
281 const int64 num_index_dims = in Compile() local
[all …]
scatter_nd_op.cc
43 const int64 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); in ValidateUpdateShape() local
53 ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); in ValidateUpdateShape()
58 num_index_dims + (updates_shape.dims() - batch_dim)) { in ValidateUpdateShape()
62 batch_dim + buffer_shape.dims() - num_index_dims) { in ValidateUpdateShape()
72 buffer_shape.dim_size(d + num_index_dims)) { in ValidateUpdateShape()
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
scatter.cc
47 int64 num_index_dims = 1; in XlaScatter() local
50 num_index_dims = indices_dims.back(); in XlaScatter()
51 if (num_index_dims > buffer_shape.rank()) { in XlaScatter()
73 for (int64 i = 0; i < num_index_dims; ++i) { in XlaScatter()
145 int64 num_window_dims_in_updates = buffer_rank - num_index_dims; in XlaScatter()
152 for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) { in XlaScatter()
169 for (int64 i = 0; i < num_index_dims; ++i) { in XlaScatter()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
matrix.cc
124 const int64 num_index_dims = 2; in GetMatrixDiagonalViaGather() local
125 const int64 axis = n_dims - num_index_dims; in GetMatrixDiagonalViaGather()
131 {diag_len, num_index_dims}, {0}); in GetMatrixDiagonalViaGather()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/
tf_ops_n_z.cc
2233 int64_t num_index_dims = indices_ty.getShape().back(); in Verify() local
2234 if (ShapedType::isDynamic(num_index_dims)) return success(); in Verify()
2236 if (num_index_dims > tensor_ty.getRank()) in Verify()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
legalize_tf.cc
3931 int64_t num_index_dims = indices_ty.getShape().back(); in matchAndRewrite() local
3932 if (ShapedType::isDynamic(num_index_dims)) return failure(); in matchAndRewrite()
3938 int64_t window_dims = tensor_rank - num_index_dims; in matchAndRewrite()
3942 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter), in matchAndRewrite()
3943 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter), in matchAndRewrite()