Home
last modified time | relevance | path

Searched refs:updates_shape (Results 1 – 12 of 12) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dscatter_nd_op.cc34 const TensorShape& updates_shape) { in ValidateUpdateShape() argument
48 updates_shape.DebugString(), in ValidateUpdateShape()
54 if (updates_shape.dims() < batch_dim) return shape_err(); in ValidateUpdateShape()
56 num_index_dims + (updates_shape.dims() - batch_dim)) { in ValidateUpdateShape()
59 if (updates_shape.dims() != in ValidateUpdateShape()
64 if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) { in ValidateUpdateShape()
68 for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) { in ValidateUpdateShape()
69 if (updates_shape.dim_size(d + batch_dim) != in ValidateUpdateShape()
85 TensorShape updates_shape = context->InputShape(1); in Compile() local
98 updates_shape.num_elements() == 0), in Compile()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
Dscatter.cc40 TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates)); in XlaScatter()
143 int64 updates_rank = updates_shape.rank(); in XlaScatter()
158 TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); in XlaScatter()
159 updates_rank = updates_shape.rank(); in XlaScatter()
191 VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape); in XlaScatter()
/external/tensorflow/tensorflow/cc/gradients/
Darray_grad_test.cc225 TensorShape updates_shape({4}); in TEST_F() local
227 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F()
231 RunTest(updates, updates_shape, y, y_shape); in TEST_F()
235 TensorShape updates_shape({2, 4, 4}); in TEST_F() local
237 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F()
241 RunTest(updates, updates_shape, y, y_shape); in TEST_F()
245 TensorShape updates_shape({4}); in TEST_F() local
249 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F()
252 RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); in TEST_F()
256 TensorShape updates_shape({2, 4, 4}); in TEST_F() local
[all …]
/external/tensorflow/tensorflow/compiler/tests/
Dscatter_nd_op_test.py104 updates_shape = (num_updates,)
106 updates_shape += (ref_shape[i],)
107 updates = _AsType(np.random.randn(*(updates_shape)), vtype)
/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_evaluator_typed_visitor.h2059 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in IterationSpaceForUpdateScatterIndices() argument
2060 int64 updates_rank = updates_shape.dimensions_size(); in IterationSpaceForUpdateScatterIndices()
2067 index_count[i] = updates_shape.dimensions(i); in IterationSpaceForUpdateScatterIndices()
2078 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in IterationSpaceForUpdateWindowIndices() argument
2079 int64 updates_rank = updates_shape.dimensions_size(); in IterationSpaceForUpdateWindowIndices()
2086 index_count[i] = updates_shape.dimensions(i); in IterationSpaceForUpdateWindowIndices()
2107 const Shape& updates_shape, const Literal* scatter_indices) in UpdateScatterIndexToInputIndex() argument
2109 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { in UpdateScatterIndexToInputIndex()
2237 const Shape& updates_shape) { in UpdateWindowIndexToInputIndex() argument
2240 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { in UpdateWindowIndexToInputIndex()
[all …]
Dshape_inference.cc3071 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in ValidateScatterDimensionNumbers() argument
3084 const int64 updates_rank = updates_shape.rank(); in ValidateScatterDimensionNumbers()
3165 const Shape& updates_shape, const ProgramShape& to_apply_shape, in InferScatterShape() argument
3171 TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); in InferScatterShape()
3194 {updates_shape.element_type()}, in InferScatterShape()
3206 if (updates_shape.rank() != expected_updates_rank) { in InferScatterShape()
3208 expected_updates_rank, updates_shape.rank()); in InferScatterShape()
3212 operand_shape, expanded_scatter_indices_shape, updates_shape, in InferScatterShape()
3227 if (updates_shape.dimensions(update_window_dim) > in InferScatterShape()
3233 update_window_dim, updates_shape.dimensions(update_window_dim), in InferScatterShape()
[all …]
Dshape_inference.h292 const Shape& updates_shape, const ProgramShape& to_apply_shape,
/external/tensorflow/tensorflow/core/ops/
Darray_ops.cc2875 ShapeHandle updates_shape, in ScatterNdShapeHelper() argument
2879 c->Value(c->NumElements(updates_shape)) > 0)) { in ScatterNdShapeHelper()
2884 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { in ScatterNdShapeHelper()
2898 c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); in ScatterNdShapeHelper()
2906 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdShapeHelper()
2914 c->Subshape(updates_shape, outer_dims, &suffix_updates)); in ScatterNdShapeHelper()
2920 " must match the inner ", c->Rank(updates_shape) - outer_dims, in ScatterNdShapeHelper()
2921 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdShapeHelper()
2934 ShapeHandle updates_shape; in ScatterNdShape() local
2935 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); in ScatterNdShape()
[all …]
/external/tensorflow/tensorflow/core/framework/
Dcommon_shape_fns.cc1546 ShapeHandle updates_shape; in ScatterNdUpdateShape() local
1547 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); in ScatterNdUpdateShape()
1551 c->Value(c->NumElements(updates_shape)) > 0)) { in ScatterNdUpdateShape()
1556 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { in ScatterNdUpdateShape()
1570 c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); in ScatterNdUpdateShape()
1578 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdUpdateShape()
1586 c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); in ScatterNdUpdateShape()
1592 " must match the inner ", c->Rank(updates_shape) - num_outer_dims, in ScatterNdUpdateShape()
1593 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdUpdateShape()
/external/tensorflow/tensorflow/core/kernels/
Dscatter_nd_op.cc559 const TensorShape& updates_shape(updates.shape()); in PrepareAndValidateInputs() local
568 updates_shape.num_elements())) { in PrepareAndValidateInputs()
578 ", updates.shape ", updates_shape.DebugString()); in PrepareAndValidateInputs()
/external/tensorflow/tensorflow/python/kernel_tests/
Dscatter_nd_ops_test.py134 updates_shape = (num_updates,)
136 updates_shape += (ref_shape[i],)
137 updates = _AsType(np.random.randn(*(updates_shape)), vtype)
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc1874 TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); in Scatter()
1879 input_shape, scatter_indices_shape, updates_shape, in Scatter()