Home
last modified time | relevance | path

Searched refs:scatter_indices_shape (Results 1 – 6 of 6) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dscatter_expander.cc33 const Shape& scatter_indices_shape = scatter_indices->shape(); in TransposeIndexVectorDimToLast() local
35 if (scatter_indices_shape.dimensions_size() == index_vector_dim) { in TransposeIndexVectorDimToLast()
39 if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { in TransposeIndexVectorDimToLast()
44 permutation.reserve(scatter_indices_shape.dimensions_size()); in TransposeIndexVectorDimToLast()
45 for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { in TransposeIndexVectorDimToLast()
115 const Shape& scatter_indices_shape, HloInstruction* updates, in AdjustScatterDims() argument
117 int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); in AdjustScatterDims()
118 if (index_vector_dim < scatter_indices_shape.dimensions_size()) { in AdjustScatterDims()
332 const Shape& scatter_indices_shape = scatter_indices->shape(); in ScatterTripCount() local
336 for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { in ScatterTripCount()
[all …]
Dshape_inference.cc3445 const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape, in ValidateScatterDimensionNumbers() argument
3501 scatter_indices_shape[dim_numbers.index_vector_dim()]) { in ValidateScatterDimensionNumbers()
3508 scatter_indices_shape[dim_numbers.index_vector_dim()]); in ValidateScatterDimensionNumbers()
3539 const Shape& operand_shape, const Shape& scatter_indices_shape, in InferScatterShape() argument
3545 ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); in InferScatterShape()
3548 if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { in InferScatterShape()
3551 ShapeUtil::HumanString(scatter_indices_shape)); in InferScatterShape()
3554 if (scatter_indices_shape.dimensions_size() < in InferScatterShape()
3561 scatter_indices_shape.dimensions_size(), in InferScatterShape()
3573 SpanToVector(scatter_indices_shape.dimensions()); in InferScatterShape()
Dshape_inference.h322 const Shape& operand_shape, const Shape& scatter_indices_shape,
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dir_emitter_unnested.h480 Shape scatter_indices_shape; member
Dir_emitter_unnested.cc1963 desc.scatter_indices_shape = root->operand(1)->shape(); in HandleFusion()
2582 desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType()); in EmitScatter()
2640 Shape scatter_indices_shape_fixed = desc.scatter_indices_shape; in EmitScatter()
2642 desc.scatter_indices_shape.rank()) { in EmitScatter()
2673 scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_))); in EmitScatter()
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc2308 TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape, in Scatter()
2315 *input_shape, *scatter_indices_shape, *updates_shape, in Scatter()