Searched refs:scatter_indices_shape (Results 1 – 6 of 6) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | scatter_expander.cc | 33 const Shape& scatter_indices_shape = scatter_indices->shape(); in TransposeIndexVectorDimToLast() local 35 if (scatter_indices_shape.dimensions_size() == index_vector_dim) { in TransposeIndexVectorDimToLast() 39 if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { in TransposeIndexVectorDimToLast() 44 permutation.reserve(scatter_indices_shape.dimensions_size()); in TransposeIndexVectorDimToLast() 45 for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { in TransposeIndexVectorDimToLast() 115 const Shape& scatter_indices_shape, HloInstruction* updates, in AdjustScatterDims() argument 117 int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); in AdjustScatterDims() 118 if (index_vector_dim < scatter_indices_shape.dimensions_size()) { in AdjustScatterDims() 332 const Shape& scatter_indices_shape = scatter_indices->shape(); in ScatterTripCount() local 336 for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { in ScatterTripCount() [all …]
|
D | shape_inference.cc | 3445 const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape, in ValidateScatterDimensionNumbers() argument 3501 scatter_indices_shape[dim_numbers.index_vector_dim()]) { in ValidateScatterDimensionNumbers() 3508 scatter_indices_shape[dim_numbers.index_vector_dim()]); in ValidateScatterDimensionNumbers() 3539 const Shape& operand_shape, const Shape& scatter_indices_shape, in InferScatterShape() argument 3545 ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); in InferScatterShape() 3548 if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { in InferScatterShape() 3551 ShapeUtil::HumanString(scatter_indices_shape)); in InferScatterShape() 3554 if (scatter_indices_shape.dimensions_size() < in InferScatterShape() 3561 scatter_indices_shape.dimensions_size(), in InferScatterShape() 3573 SpanToVector(scatter_indices_shape.dimensions()); in InferScatterShape()
|
D | shape_inference.h | 322 const Shape& operand_shape, const Shape& scatter_indices_shape,
|
/external/tensorflow/tensorflow/compiler/xla/service/gpu/ |
D | ir_emitter_unnested.h | 480 Shape scatter_indices_shape; member
|
D | ir_emitter_unnested.cc | 1963 desc.scatter_indices_shape = root->operand(1)->shape(); in HandleFusion() 2582 desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType()); in EmitScatter() 2640 Shape scatter_indices_shape_fixed = desc.scatter_indices_shape; in EmitScatter() 2642 desc.scatter_indices_shape.rank()) { in EmitScatter() 2673 scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_))); in EmitScatter()
|
/external/tensorflow/tensorflow/compiler/xla/client/ |
D | xla_builder.cc | 2308 TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape, in Scatter() 2315 *input_shape, *scatter_indices_shape, *updates_shape, in Scatter()
|