Home
last modified time | relevance | path

Searched refs:indices_rank (Results 1 – 14 of 14) sorted by relevance

/external/tensorflow/tensorflow/dtensor/mlir/expansions/
Dgather_spmd_expander.cc45 const auto indices_rank = ValueRank(gather_op.indices()); in ExpandOp() local
48 if (indices_rank == -1) in ExpandOp()
94 output_layout->sharding_spec(i + indices_rank - batch_dims - 1)) { in ExpandOp()
121 for (int i = 0; i < indices_rank; ++i) { in ExpandOp()
184 const int indices_rank = ValueRank(gather_op.indices()); in ComputeLayoutForward() local
187 if (indices_rank == -1) in ComputeLayoutForward()
224 for (int i = batch_dims; i < indices_rank; ++i) in ComputeLayoutForward()
256 const int indices_rank = ValueRank(gather_op.indices()); in ComputeLayoutBackward() local
259 if (indices_rank == -1) in ComputeLayoutBackward()
268 indices_layout_specs.reserve(indices_rank); in ComputeLayoutBackward()
[all …]
Dscatter_spmd_expander.cc188 const int indices_rank = ValueRank(scatter_op.indices()); in TensorScatterOpComputeLayoutBackward() local
190 if (tensor_rank == -1 || indices_rank == -1 || updates_rank == -1) in TensorScatterOpComputeLayoutBackward()
200 const Layout indices_layout = Layout::ReplicatedOnMesh(mesh, indices_rank); in TensorScatterOpComputeLayoutBackward()
/external/tensorflow/tensorflow/core/kernels/fuzzing/
Dscatter_nd_fuzz.cc76 size_t indices_rank = 1 + (data[data_ix++] % kMaxIndicesRank); in FuzzImpl() local
79 if (data_ix + indices_rank >= size) { in FuzzImpl()
84 for (i = 0; i < indices_rank; i++) { in FuzzImpl()
105 for (i = 0; i < indices_rank - 1; i++) { in FuzzImpl()
109 int64_t last = indices_dims[indices_rank - 1]; in FuzzImpl()
/external/tensorflow/tensorflow/lite/kernels/
Dgather_nd.cc72 const int indices_rank = NumDimensions(indices); in Prepare() local
73 const int indices_nd = SizeOfDimension(indices, indices_rank - 1); in Prepare()
78 if (indices_rank < 1) { in Prepare()
93 const int output_rank = indices_rank + params_rank - indices_nd - 1; in Prepare()
96 for (int i = 0; i < indices_rank - 1; ++i) { in Prepare()
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dgather_op.cc140 int64_t indices_rank = in XlaGather() local
142 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); in XlaGather()
/external/tensorflow/tensorflow/compiler/mlir/lite/ir/
Dtfl_ops.cc1337 int64_t indices_rank = indices_type.getRank(); in BuildGatherOp() local
1357 batch_dims_i += indices_rank; in BuildGatherOp()
1364 if (batch_dims_i >= params_rank || batch_dims_i > indices_rank) { in BuildGatherOp()
1377 if ((indices_rank == 0) || (indices_rank == batch_dims_i)) { in BuildGatherOp()
1381 } else if (indices_rank == 1) { in BuildGatherOp()
1388 shape.resize(params_rank + indices_rank - 1 - batch_dims_i); in BuildGatherOp()
1392 std::begin(shape) + axis_i + indices_rank - batch_dims_i); in BuildGatherOp()
/external/libpalmrejection/ui/events/ozone/evdev/touch_filter/palm_model/
Donedevice_train_palm_detection_filter_inference_beta.cc550 int32_t indices_rank, in Gather() argument
555 const int32_t num_indices = ShapeSize(indices_rank, indices_shape); in Gather()
1273 int32_t indices_rank, in BroadcastOffset() argument
1279 input_shape[i] == 1 ? 0 : indices[i + indices_rank - input_rank]; in BroadcastOffset()
Donedevice_train_palm_detection_filter_inference.cc551 int32_t indices_rank, in Gather() argument
556 const int32_t num_indices = ShapeSize(indices_rank, indices_shape); in Gather()
1274 int32_t indices_rank, in BroadcastOffset() argument
1280 input_shape[i] == 1 ? 0 : indices[i + indices_rank - input_rank]; in BroadcastOffset()
Donedevice_train_palm_detection_filter_inference_v2.cc550 int32_t indices_rank, in Gather() argument
555 const int32_t num_indices = ShapeSize(indices_rank, indices_shape); in Gather()
1274 int32_t indices_rank, in BroadcastOffset() argument
1280 input_shape[i] == 1 ? 0 : indices[i + indices_rank - input_rank]; in BroadcastOffset()
/external/tensorflow/tensorflow/compiler/mlir/tosa/transforms/
Dlegalize_common.cc3464 int indices_rank = indices_type.getShape().size(); in convertGatherOp() local
3466 if (!(batch_dims <= indices_rank)) { in convertGatherOp()
3506 indices_type.getShape().slice(batch_dims, indices_rank - batch_dims)); in convertGatherOp()
3740 int indices_rank = indices_type.getShape().size(); in convertGatherNdOp() local
3742 ND = indices_type.getShape()[indices_rank - 1]; in convertGatherNdOp()
3750 for (int i = 0; i < (indices_rank - 1); i++) { in convertGatherNdOp()
/external/tensorflow/tensorflow/core/ops/
Ddata_flow_ops.cc109 const int64_t indices_rank = c->Rank(indices_shape); in DynamicStitchShapeFunction() local
118 TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest)); in DynamicStitchShapeFunction()
/external/tensorflow/tensorflow/compiler/xla/service/
Ddynamic_dimension_inference.cc1398 int64_t indices_rank = hlo->operand(1)->shape().rank(); in HandleGather() local
1416 CHECK(indices_dim == indices_rank); in HandleGather()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc1495 auto indices_rank = indices_ty.getRank(); in matchAndRewrite() local
1496 int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); in matchAndRewrite()
1544 offset_dims.push_back(i + indices_rank - 1 - num_index_dims); in matchAndRewrite()
1553 int64_t index_vector_dim = indices_rank - 1; in matchAndRewrite()
4511 int64_t indices_rank = indices_ty.getRank(); in matchAndRewrite() local
4522 indices_rank - 1); in matchAndRewrite()
/external/tensorflow/tensorflow/python/ops/parallel_for/
Dpfor.py4265 indices_rank = array_ops.rank(indices)
4270 array_ops.ones([indices_rank], dtype=dtypes.int32), [[0]], [loop_length])
4277 indices_shape, [[0], [indices_rank - 1]], [1, 1])