Home
last modified time | relevance | path

Searched refs:indices_tensor (Results 1 – 10 of 10) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/
Dunravel_index_op.cc41 const Tensor& indices_tensor = ctx->input(0); in Compute() local
43 TensorShapeUtils::IsVector(indices_tensor.shape()) || in Compute()
44 TensorShapeUtils::IsScalar(indices_tensor.shape()), in Compute()
47 indices_tensor.shape().DebugString(), "\"")); in Compute()
68 const Tidx* indices = indices_tensor.flat<Tidx>().data(); in Compute()
69 int64_t size = indices_tensor.NumElements(); in Compute()
100 if (TensorShapeUtils::IsScalar(indices_tensor.shape())) { in Compute()
107 output = output.constant(indices_tensor.scalar<Tidx>()()); in Compute()
113 indices_tensor.NumElements()}), in Compute()
121 {1, static_cast<Eigen::Index>(indices_tensor.NumElements())}); in Compute()
[all …]
Dmap_stage_op.cc527 const Tensor* indices_tensor; in Compute() local
531 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor)); in Compute()
545 OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple)); in Compute()
579 const Tensor* indices_tensor; in Compute() local
582 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor)); in Compute()
583 OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple)); in Compute()
586 ctx, tuple.size() == indices_tensor->NumElements(), in Compute()
588 " vs. ", indices_tensor->NumElements())); in Compute()
628 const Tensor* indices_tensor; in Compute() local
631 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor)); in Compute()
[all …]
/external/ComputeLibrary/tests/validation/fixtures/
DGatherFixture.h85 TensorType indices_tensor = create_tensor<TensorType>(indices_shape, DataType::U32); in compute_target() local
92 gather.configure(&src, &indices_tensor, &dst, axis); in compute_target()
95 ARM_COMPUTE_EXPECT(indices_tensor.info()->is_resizable(), framework::LogLevel::ERRORS); in compute_target()
100 indices_tensor.allocator()->allocate(); in compute_target()
104 ARM_COMPUTE_EXPECT(!indices_tensor.info()->is_resizable(), framework::LogLevel::ERRORS); in compute_target()
109 generate_indices(AccessorType(indices_tensor), input_shape, actual_axis, indices_shape); in compute_target()
124 SimpleTensor<uint32_t> indices_tensor{ indices_shape, DataType::U32 }; in compute_reference()
129 generate_indices(indices_tensor, input_shape, actual_axis, indices_shape); in compute_reference()
131 return reference::gather(src, indices_tensor, actual_axis); in compute_reference()
/external/tensorflow/tensorflow/lite/experimental/mlir/testing/op_tests/
Dtensor_scatter_update.py44 indices_tensor = tf.compat.v1.placeholder(
55 out = tf.tensor_scatter_nd_update(input_tensor, indices_tensor,
57 return [input_tensor, indices_tensor, updates_tensors], [out]
Dtensor_scatter_add.py44 indices_tensor = tf.compat.v1.placeholder(
55 out = tf.tensor_scatter_nd_add(input_tensor, indices_tensor, adds_tensors)
56 return [input_tensor, indices_tensor, adds_tensors], [out]
/external/tensorflow/tensorflow/core/kernels/fuzzing/
Dscatter_nd_fuzz.cc90 Tensor indices_tensor(tensorflow::DT_INT32, TensorShape(indices_dims)); in FuzzImpl() local
93 auto flat_indices = indices_tensor.flat<int32>(); in FuzzImpl()
122 RunInputs({{"indices", indices_tensor}, in FuzzImpl()
/external/tensorflow/tensorflow/python/kernel_tests/
Dsparsemask_op_test.py41 indices_tensor = ops.convert_to_tensor(indices)
44 t = ops.IndexedSlices(values_tensor, indices_tensor)
/external/tensorflow/tensorflow/tools/graph_transforms/
Dsparsify_gather.cc40 Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor, in SparsifyWeights() argument
68 *indices_tensor = Tensor(DataTypeToEnum<int64>::value, in SparsifyWeights()
71 indices_tensor->flat<int64>().data()); in SparsifyWeights()
350 Tensor indices_tensor; in SparsifyGatherInternal() local
353 SparsifyWeights(weight, &indices_tensor, &values_tensor)); in SparsifyGatherInternal()
358 CreateConstNode(indices_tensor, in SparsifyGatherInternal()
/external/tensorflow/tensorflow/python/ops/ragged/
Dragged_gather_op_test.py425 for indices_tensor in indices_tensors:
427 params_tensor, indices_tensor, axis=axis, batch_dims=batch_dims)
433 0), getattr(indices_tensor, 'ragged_rank', 0)))
/external/tensorflow/tensorflow/core/grappler/
Dgrappler_item_builder.cc346 Tensor indices_tensor(DT_INT64, shape_2d); in GrapplerItemFromMetaGraphDef() local
347 InitializeTensor(input.dtype(), &indices_tensor); in GrapplerItemFromMetaGraphDef()
350 indices_tensor); in GrapplerItemFromMetaGraphDef()