Home
last modified time | relevance | path

Searched refs:updates_shape (Results 1 – 19 of 19) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dscatter_nd_op.cc36 const TensorShape& updates_shape) { in ValidateUpdateShape() argument
50 updates_shape.DebugString(), in ValidateUpdateShape()
56 if (updates_shape.dims() < batch_dim) return shape_err(); in ValidateUpdateShape()
58 num_index_dims + (updates_shape.dims() - batch_dim)) { in ValidateUpdateShape()
61 if (updates_shape.dims() != in ValidateUpdateShape()
66 if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) { in ValidateUpdateShape()
70 for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) { in ValidateUpdateShape()
71 if (updates_shape.dim_size(d + batch_dim) != in ValidateUpdateShape()
87 TensorShape updates_shape = context->InputShape(1); in Compile() local
100 updates_shape.num_elements() == 0), in Compile()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
Dscatter.cc40 TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates)); in XlaScatter()
143 int64 updates_rank = updates_shape.rank(); in XlaScatter()
158 TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); in XlaScatter()
159 updates_rank = updates_shape.rank(); in XlaScatter()
191 VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape); in XlaScatter()
/external/tensorflow/tensorflow/cc/gradients/
Darray_grad_test.cc225 TensorShape updates_shape({4}); in TEST_F() local
227 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F()
231 RunTest(updates, updates_shape, y, y_shape); in TEST_F()
235 TensorShape updates_shape({2, 4, 4}); in TEST_F() local
237 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F()
241 RunTest(updates, updates_shape, y, y_shape); in TEST_F()
245 TensorShape updates_shape({4}); in TEST_F() local
249 Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); in TEST_F()
252 RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); in TEST_F()
256 TensorShape updates_shape({2, 4, 4}); in TEST_F() local
[all …]
/external/tensorflow/tensorflow/python/kernel_tests/v1_compat_tests/
Dscatter_nd_ops_test.py108 updates_shape = (num_updates,)
110 updates_shape += (ref_shape[i],)
111 updates = _AsType(np.random.randn(*(updates_shape)), vtype)
/external/tensorflow/tensorflow/core/ops/
Dstate_ops.cc144 ShapeHandle updates_shape; in ScatterNdUpdateShape() local
145 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); in ScatterNdUpdateShape()
146 return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, in ScatterNdUpdateShape()
Darray_ops.cc3096 ShapeHandle updates_shape; in ScatterNdTensorShape() local
3097 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); in ScatterNdTensorShape()
3098 return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, in ScatterNdTensorShape()
3142 ShapeHandle updates_shape; in __anon42d741194802() local
3143 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &updates_shape)); in __anon42d741194802()
3147 updates_shape, output_shape); in __anon42d741194802()
/external/tensorflow/tensorflow/compiler/tests/
Dscatter_nd_op_test.py105 updates_shape = (num_updates,)
107 updates_shape += (ref_shape[i],)
108 updates = _AsType(np.random.randn(*(updates_shape)), vtype)
/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_evaluator_typed_visitor.h2082 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
2083 int64 updates_rank = updates_shape.dimensions_size();
2090 index_count[i] = updates_shape.dimensions(i);
2101 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
2102 int64 updates_rank = updates_shape.dimensions_size();
2109 index_count[i] = updates_shape.dimensions(i);
2130 const Shape& updates_shape, const Literal* scatter_indices)
2132 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
2260 const Shape& updates_shape) {
2263 for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
[all …]
Dshape_inference.cc3446 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in ValidateScatterDimensionNumbers() argument
3459 const int64 updates_rank = updates_shape.rank(); in ValidateScatterDimensionNumbers()
3540 const Shape& updates_shape, const ProgramShape& to_apply_shape, in InferScatterShape() argument
3546 TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); in InferScatterShape()
3569 {updates_shape.element_type()}, in InferScatterShape()
3581 if (updates_shape.rank() != expected_updates_rank) { in InferScatterShape()
3583 expected_updates_rank, updates_shape.rank()); in InferScatterShape()
3587 operand_shape, expanded_scatter_indices_shape, updates_shape, in InferScatterShape()
3602 if (updates_shape.dimensions(update_window_dim) > in InferScatterShape()
3608 update_window_dim, updates_shape.dimensions(update_window_dim), in InferScatterShape()
[all …]
Dshape_inference.h323 const Shape& updates_shape, const ProgramShape& to_apply_shape,
/external/tensorflow/tensorflow/core/framework/
Dcommon_shape_fns.h262 ShapeHandle updates_shape, ShapeHandle input_shape);
Dcommon_shape_fns.cc2325 ShapeHandle updates_shape, in ScatterNdShapeHelper() argument
2329 c->Value(c->NumElements(updates_shape)) > 0)) { in ScatterNdShapeHelper()
2334 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { in ScatterNdShapeHelper()
2348 c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); in ScatterNdShapeHelper()
2357 ") of updates[shape=", c->DebugString(updates_shape), in ScatterNdShapeHelper()
2365 c->Subshape(updates_shape, outer_dims, &suffix_updates)); in ScatterNdShapeHelper()
2372 outer_dims, ",", c->Rank(updates_shape), in ScatterNdShapeHelper()
2373 ") of updates[shape=", c->DebugString(updates_shape), in ScatterNdShapeHelper()
/external/tensorflow/tensorflow/core/kernels/
Dscatter_nd_op.cc823 const TensorShape& updates_shape(updates.shape()); in PrepareAndValidateInputs() local
832 updates_shape.num_elements())) { in PrepareAndValidateInputs()
842 "shape=", updates_shape.DebugString(), "] = ", updates.dim_size(0)); in PrepareAndValidateInputs()
/external/tensorflow/tensorflow/python/kernel_tests/array_ops/
Dscatter_nd_ops_test.py143 updates_shape = (num_updates,)
145 updates_shape += (ref_shape[i],)
146 updates = _AsType(np.random.randn(*(updates_shape)), vtype)
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dir_emitter_unnested.h481 Shape updates_shape; member
Dir_emitter_unnested.cc1964 desc.updates_shape = root->operand(2)->shape(); in HandleFusion()
2583 desc.updates_shape = TypeToShape(scatter.updates().getType()); in EmitScatter()
2611 raw_window_bounds.push_back(desc.updates_shape.dimensions(i)); in EmitScatter()
2702 llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(), in EmitScatter()
2723 desc.updates_shape, ir_emitter_context_->gpu_device_info()); in EmitScatter()
2727 return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape, in EmitScatter()
/external/tensorflow/tensorflow/lite/kernels/internal/reference/
Dreference_ops.h1140 const RuntimeShape& updates_shape, in ScatterNd() argument
1149 const int updates_dims = updates_shape.DimensionsCount(); in ScatterNd()
1154 slice_size *= updates_shape.Dims(i); in ScatterNd()
/external/tensorflow/tensorflow/python/kernel_tests/
Darray_ops_test.py1276 self, shape, begin, end, strides, updates_shape, *args): argument
1282 f, [array_ops.zeros(shape), array_ops.ones(updates_shape)], delta=1.0)
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc2310 TF_ASSIGN_OR_RETURN(const Shape* updates_shape, GetShapePtr(updates)); in Scatter()
2315 *input_shape, *scatter_indices_shape, *updates_shape, in Scatter()