/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | scatter_nd_op.cc | 34 const TensorShape& updates_shape) { in ValidateUpdateShape() argument 48 updates_shape.DebugString(), in ValidateUpdateShape() 54 if (updates_shape.dims() < batch_dim) return shape_err(); in ValidateUpdateShape() 56 num_index_dims + (updates_shape.dims() - batch_dim)) { in ValidateUpdateShape() 59 if (updates_shape.dims() != in ValidateUpdateShape() 64 if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) { in ValidateUpdateShape() 68 for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) { in ValidateUpdateShape() 69 if (updates_shape.dim_size(d + batch_dim) != in ValidateUpdateShape() 85 TensorShape updates_shape = context->InputShape(1); in Compile() local 98 updates_shape.num_elements() == 0), in Compile() [all …]
|
/external/tensorflow/tensorflow/compiler/tf2xla/lib/ |
D | scatter.cc | 40 TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates)); in XlaScatter() 143 int64 updates_rank = updates_shape.rank(); in XlaScatter() 158 TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); in XlaScatter() 159 updates_rank = updates_shape.rank(); in XlaScatter() 191 VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape); in XlaScatter()
|
/external/tensorflow/tensorflow/cc/gradients/ |
D | array_grad_test.cc | 225 TensorShape updates_shape({4}); in TEST_F() local 227 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F() 231 RunTest(updates, updates_shape, y, y_shape); in TEST_F() 235 TensorShape updates_shape({2, 4, 4}); in TEST_F() local 237 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F() 241 RunTest(updates, updates_shape, y, y_shape); in TEST_F() 245 TensorShape updates_shape({4}); in TEST_F() local 249 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F() 252 RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); in TEST_F() 256 TensorShape updates_shape({2, 4, 4}); in TEST_F() local [all …]
|
/external/tensorflow/tensorflow/compiler/tests/ |
D | scatter_nd_op_test.py | 104 updates_shape = (num_updates,) 106 updates_shape += (ref_shape[i],) 107 updates = _AsType(np.random.randn(*(updates_shape)), vtype)
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | hlo_evaluator_typed_visitor.h | 2059 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in IterationSpaceForUpdateScatterIndices() argument 2060 int64 updates_rank = updates_shape.dimensions_size(); in IterationSpaceForUpdateScatterIndices() 2067 index_count[i] = updates_shape.dimensions(i); in IterationSpaceForUpdateScatterIndices() 2078 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in IterationSpaceForUpdateWindowIndices() argument 2079 int64 updates_rank = updates_shape.dimensions_size(); in IterationSpaceForUpdateWindowIndices() 2086 index_count[i] = updates_shape.dimensions(i); in IterationSpaceForUpdateWindowIndices() 2107 const Shape& updates_shape, const Literal* scatter_indices) in UpdateScatterIndexToInputIndex() argument 2109 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { in UpdateScatterIndexToInputIndex() 2237 const Shape& updates_shape) { in UpdateWindowIndexToInputIndex() argument 2240 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { in UpdateWindowIndexToInputIndex() [all …]
|
D | shape_inference.cc | 3071 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in ValidateScatterDimensionNumbers() argument 3084 const int64 updates_rank = updates_shape.rank(); in ValidateScatterDimensionNumbers() 3165 const Shape& updates_shape, const ProgramShape& to_apply_shape, in InferScatterShape() argument 3171 TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); in InferScatterShape() 3194 {updates_shape.element_type()}, in InferScatterShape() 3206 if (updates_shape.rank() != expected_updates_rank) { in InferScatterShape() 3208 expected_updates_rank, updates_shape.rank()); in InferScatterShape() 3212 operand_shape, expanded_scatter_indices_shape, updates_shape, in InferScatterShape() 3227 if (updates_shape.dimensions(update_window_dim) > in InferScatterShape() 3233 update_window_dim, updates_shape.dimensions(update_window_dim), in InferScatterShape() [all …]
|
D | shape_inference.h | 292 const Shape& updates_shape, const ProgramShape& to_apply_shape,
|
/external/tensorflow/tensorflow/core/ops/ |
D | array_ops.cc | 2875 ShapeHandle updates_shape, in ScatterNdShapeHelper() argument 2879 c->Value(c->NumElements(updates_shape)) > 0)) { in ScatterNdShapeHelper() 2884 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { in ScatterNdShapeHelper() 2898 c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); in ScatterNdShapeHelper() 2906 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdShapeHelper() 2914 c->Subshape(updates_shape, outer_dims, &suffix_updates)); in ScatterNdShapeHelper() 2920 " must match the inner ", c->Rank(updates_shape) - outer_dims, in ScatterNdShapeHelper() 2921 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdShapeHelper() 2934 ShapeHandle updates_shape; in ScatterNdShape() local 2935 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); in ScatterNdShape() [all …]
|
/external/tensorflow/tensorflow/core/framework/ |
D | common_shape_fns.cc | 1546 ShapeHandle updates_shape; in ScatterNdUpdateShape() local 1547 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); in ScatterNdUpdateShape() 1551 c->Value(c->NumElements(updates_shape)) > 0)) { in ScatterNdUpdateShape() 1556 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { in ScatterNdUpdateShape() 1570 c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); in ScatterNdUpdateShape() 1578 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdUpdateShape() 1586 c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); in ScatterNdUpdateShape() 1592 " must match the inner ", c->Rank(updates_shape) - num_outer_dims, in ScatterNdUpdateShape() 1593 " dimensions of updates.shape=", c->DebugString(updates_shape), in ScatterNdUpdateShape()
|
/external/tensorflow/tensorflow/core/kernels/ |
D | scatter_nd_op.cc | 559 const TensorShape& updates_shape(updates.shape()); in PrepareAndValidateInputs() local 568 updates_shape.num_elements())) { in PrepareAndValidateInputs() 578 ", updates.shape ", updates_shape.DebugString()); in PrepareAndValidateInputs()
|
/external/tensorflow/tensorflow/python/kernel_tests/ |
D | scatter_nd_ops_test.py | 134 updates_shape = (num_updates,) 136 updates_shape += (ref_shape[i],) 137 updates = _AsType(np.random.randn(*(updates_shape)), vtype)
|
/external/tensorflow/tensorflow/compiler/xla/client/ |
D | xla_builder.cc | 1874 TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); in Scatter() 1879 input_shape, scatter_indices_shape, updates_shape, in Scatter()
|