Searched refs:buffer_shape (Results 1 – 7 of 7) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | scatter_nd_op.cc | 32 Status ValidateUpdateShape(const TensorShape& buffer_shape, in ValidateUpdateShape() argument 50 ", buffer_shape: ", buffer_shape.DebugString(), in ValidateUpdateShape() 55 if (buffer_shape.dims() < in ValidateUpdateShape() 60 batch_dim + buffer_shape.dims() - num_index_dims) { in ValidateUpdateShape() 70 buffer_shape.dim_size(d + num_index_dims)) { in ValidateUpdateShape() 87 TensorShape buffer_shape; in Compile() local 88 OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape)); in Compile() 91 context, TensorShapeUtils::IsVectorOrHigher(buffer_shape), in Compile() 93 "got shape: ", buffer_shape.DebugString())); in Compile() 97 buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 && in Compile() [all …]
|
D | tensor_list_utils.cc | 51 TensorShape* buffer_shape) { in GetTensorListBufferShape() argument 56 return GetTensorListBufferShape(list_tuple_shape, buffer_shape); in GetTensorListBufferShape() 60 TensorShape* buffer_shape) { in GetTensorListBufferShape() argument 63 xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape)); in GetTensorListBufferShape() 75 const TensorShape& buffer_shape, in InitializeTensorList() argument 80 if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) { in InitializeTensorList() 84 "buffer size: ", buffer_shape.dim_size(0)); in InitializeTensorList() 94 buffer_shape.dim_sizes()); in InitializeTensorList()
|
D | tensor_list_utils.h | 46 TensorShape* buffer_shape); 50 TensorShape* buffer_shape); 62 const TensorShape& buffer_shape,
|
D | segment_reduction_ops.cc | 77 TensorShape buffer_shape = data_shape; in Compile() local 78 buffer_shape.RemoveDimRange(0, indices_shape.dims()); in Compile() 79 buffer_shape.InsertDim(0, num_segments); in Compile() 81 xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); in Compile()
|
/external/tensorflow/tensorflow/compiler/tf2xla/lib/ |
D | scatter.cc | 39 TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); in XlaScatter() 51 if (num_index_dims > buffer_shape.rank()) { in XlaScatter() 56 xla::ShapeUtil::HumanString(buffer_shape), ")"); in XlaScatter() 74 if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) { in XlaScatter() 77 xla::ShapeUtil::HumanString(buffer_shape)); in XlaScatter() 144 int64 buffer_rank = buffer_shape.rank(); in XlaScatter() 153 expected_updates_dims.push_back(buffer_shape.dimensions(dim)); in XlaScatter() 179 xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {}); in XlaScatter() 189 VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape); in XlaScatter()
|
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | xla_compiler.cc | 678 xla::Shape buffer_shape; in XLAShapeForArgument() local 680 TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); in XLAShapeForArgument() 682 {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); in XLAShapeForArgument()
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | layout_assignment.cc | 518 const Shape& buffer_shape = instruction->operand(0)->shape(); in AddMandatoryConstraints() local 519 TF_RET_CHECK(buffer_shape.IsArray()); in AddMandatoryConstraints() 522 ->LayoutShapeForChannel(buffer_shape, all_reduce_id); in AddMandatoryConstraints()
|