Searched refs:num_index_dims (Results 1 – 6 of 6) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | gather_op.cc | 52 int64 num_index_dims; in XlaGather() local 56 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); in XlaGather() 61 num_index_dims = 1; in XlaGather() 72 input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); in XlaGather() 89 for (int64 i = 0; i < num_index_dims; ++i) { in XlaGather() 127 if (axis <= i && i < (axis + num_index_dims)) { in XlaGather() 138 } else if (i >= (axis + num_index_dims)) { in XlaGather() 141 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); in XlaGather() 147 for (int64 i = axis; i < axis + num_index_dims; i++) { in XlaGather() 281 const int64 num_index_dims = in Compile() local [all …]
|
D | scatter_nd_op.cc | 43 const int64 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); in ValidateUpdateShape() local 53 ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); in ValidateUpdateShape() 58 num_index_dims + (updates_shape.dims() - batch_dim)) { in ValidateUpdateShape() 62 batch_dim + buffer_shape.dims() - num_index_dims) { in ValidateUpdateShape() 72 buffer_shape.dim_size(d + num_index_dims)) { in ValidateUpdateShape()
|
/external/tensorflow/tensorflow/compiler/tf2xla/lib/ |
D | scatter.cc | 47 int64 num_index_dims = 1; in XlaScatter() local 50 num_index_dims = indices_dims.back(); in XlaScatter() 51 if (num_index_dims > buffer_shape.rank()) { in XlaScatter() 73 for (int64 i = 0; i < num_index_dims; ++i) { in XlaScatter() 145 int64 num_window_dims_in_updates = buffer_rank - num_index_dims; in XlaScatter() 152 for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) { in XlaScatter() 169 for (int64 i = 0; i < num_index_dims; ++i) { in XlaScatter()
|
/external/tensorflow/tensorflow/compiler/xla/client/lib/ |
D | matrix.cc | 124 const int64 num_index_dims = 2; in GetMatrixDiagonalViaGather() local 125 const int64 axis = n_dims - num_index_dims; in GetMatrixDiagonalViaGather() 131 {diag_len, num_index_dims}, {0}); in GetMatrixDiagonalViaGather()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/ |
D | tf_ops_n_z.cc | 2233 int64_t num_index_dims = indices_ty.getShape().back(); in Verify() local 2234 if (ShapedType::isDynamic(num_index_dims)) return success(); in Verify() 2236 if (num_index_dims > tensor_ty.getRank()) in Verify()
|
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/ |
D | legalize_tf.cc | 3931 int64_t num_index_dims = indices_ty.getShape().back(); in matchAndRewrite() local 3932 if (ShapedType::isDynamic(num_index_dims)) return failure(); in matchAndRewrite() 3938 int64_t window_dims = tensor_rank - num_index_dims; in matchAndRewrite() 3942 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter), in matchAndRewrite() 3943 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter), in matchAndRewrite()
|