Home
last modified time | relevance | path

Searched refs:xla_shape (Results 1 – 25 of 31) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/tf2xla/
Dliteral_util.cc27 xla::Shape xla_shape; in HostTensorToBorrowingLiteral() local
29 host_tensor.shape(), &xla_shape)); in HostTensorToBorrowingLiteral()
30 return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal); in HostTensorToBorrowingLiteral()
33 Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, in HostTensorToBorrowingLiteral() argument
38 tshape.dims() == xla_shape.dimensions_size() && in HostTensorToBorrowingLiteral()
39 tshape.dim_sizes() == xla_shape.dimensions()) in HostTensorToBorrowingLiteral()
42 static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape); in HostTensorToBorrowingLiteral()
54 xla::Shape xla_shape; in HostTensorToMutableBorrowingLiteral() local
56 host_tensor->shape(), &xla_shape)); in HostTensorToMutableBorrowingLiteral()
57 return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); in HostTensorToMutableBorrowingLiteral()
[all …]
Dxla_helpers.cc132 xla::Shape xla_shape; in IdentityShapeRepresentationFn() local
133 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); in IdentityShapeRepresentationFn()
134 return xla_shape; in IdentityShapeRepresentationFn()
142 xla::Shape* xla_shape) { in RewriteLayoutWithShardedShape() argument
157 sharding->TileOffsetForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
158 std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
159 std::vector<int64> dimensions(xla_shape->rank()); in RewriteLayoutWithShardedShape()
160 for (int64 i = 0; i < xla_shape->rank(); ++i) { in RewriteLayoutWithShardedShape()
164 xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); in RewriteLayoutWithShardedShape()
169 xla_shape->element_type())); in RewriteLayoutWithShardedShape()
[all …]
Dliteral_util.h36 Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
52 const xla::Shape& xla_shape, Tensor* host_tensor,
Dxla_compiler.cc761 xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape); in CompileFunction() local
765 if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() && in CompileFunction()
766 xla_shape.is_static()) { in CompileFunction()
836 xla::Shape* xla_shape) const { in XLAShapeForArgument()
849 TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( in XLAShapeForArgument()
854 options_.shape_representation_fn, xla_shape)); in XLAShapeForArgument()
857 *xla_shape = absl::get<xla::Shape>(arg.shape); in XLAShapeForArgument()
860 arg.type, absl::get<TensorShape>(arg.shape), xla_shape)); in XLAShapeForArgument()
867 *xla_shape = absl::get<xla::Shape>(arg.shape); in XLAShapeForArgument()
877 TF_ASSIGN_OR_RETURN(*xla_shape, in XLAShapeForArgument()
[all …]
Dxla_compiler_test.cc309 xla::Shape xla_shape; in TEST_F() local
310 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); in TEST_F()
311 *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); in TEST_F()
312 return xla_shape; in TEST_F()
352 xla::Shape xla_shape; in TEST_F() local
353 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); in TEST_F()
354 *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); in TEST_F()
358 return xla_shape; in TEST_F()
405 xla::Shape xla_shape; in TEST_F() local
406 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); in TEST_F()
[all …]
Dxla_op_kernel.cc497 xla::Shape xla_shape; in ReadVariableInputTensor() local
499 TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); in ReadVariableInputTensor()
500 if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { in ReadVariableInputTensor()
640 xla::Shape xla_shape; in AssignVariableTensor() local
641 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); in AssignVariableTensor()
642 if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { in AssignVariableTensor()
Dxla_expression.cc208 TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, in GetShape()
211 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); in GetShape()
Dxla_helpers.h89 xla::Shape* xla_shape);
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Drandom_ops.cc49 xla::Shape xla_shape; in Compile() local
50 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); in Compile()
58 XlaHelpers::One(b, dtype), xla_shape); in Compile()
226 xla::Shape xla_shape; in Compile() local
228 TensorShapeToXLAShape(input_type(1), shape, &xla_shape)); in Compile()
245 ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape)); in Compile()
265 xla::Shape xla_shape; in Compile() local
266 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); in Compile()
272 XlaHelpers::One(b, dtype), xla_shape); in Compile()
293 xla::Shape xla_shape; in Compile() local
[all …]
Dstateless_random_ops.cc162 xla::Shape xla_shape; in Compile() local
163 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
164 xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); in Compile()
167 device_type_string_, seed, xla_shape, in Compile()
217 xla::Shape xla_shape; in Compile() local
218 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
220 xla_shape, minval, maxval); in Compile()
258 xla::Shape xla_shape; in Compile() local
259 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
261 StatelessRngUniformFullInt(device_type_string_, seed, xla_shape); in Compile()
[all …]
Dstateful_random_ops.cc227 xla::Shape xla_shape; in Compile() local
229 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
230 xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); in Compile()
232 alg, key, state, xla_shape, in Compile()
271 xla::Shape xla_shape; in Compile() local
273 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
275 key, state, BitGen(alg), xla_shape); in Compile()
311 xla::Shape xla_shape; in Compile() local
313 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
316 alg, key, state, xla_shape, in Compile()
[all …]
Dstateless_random_ops_v2.cc223 xla::Shape xla_shape; in Compile() local
224 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
225 xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); in Compile()
230 alg, key, counter, xla_shape, in Compile()
289 xla::Shape xla_shape; in Compile() local
290 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
294 StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval); in Compile()
336 xla::Shape xla_shape; in Compile() local
337 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
340 auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape); in Compile()
[all …]
/external/tensorflow/tensorflow/core/tpu/kernels/xla/
Doutfeed_ops.cc37 xla::Shape xla_shape; in Compile() local
39 ctx, TensorShapeToXLAShape(dtype_, ctx->InputShape(0), &xla_shape)); in Compile()
42 xla::Outfeed(ctx->Input(0), xla_shape, outfeed_config); in Compile()
67 xla::Shape xla_shape; in Compile() local
69 TensorShapeToXLAShape(dtypes_[i], shapes[i], &xla_shape)); in Compile()
70 xla_shapes.push_back(xla_shape); in Compile()
Dhost_compute_ops.cc52 xla::Shape* xla_shape) { in MakeXlaShapes() argument
63 *xla_shape = xla::ShapeUtil::MakeTupleShape(*xla_shapes); in MakeXlaShapes()
151 xla::Shape xla_shape; in Compile() local
153 input_shapes[i], &xla_shape)); in Compile()
160 xla_shape.element_type()); in Compile()
166 xla::SendToHost(input_handles[i], token, xla_shape, channel)); in Compile()
417 xla::Shape xla_shape; in Compile() local
419 &xla_shape)); in Compile()
425 xla_shape.element_type()); in Compile()
430 xla::SendToHost(operand, token, xla_shape, channel); in Compile()
[all …]
Dinfeed_op.cc119 xla::Shape xla_shape; in InfeedDequeueTupleOp() local
121 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in InfeedDequeueTupleOp()
122 xla_shapes_.push_back(xla_shape); in InfeedDequeueTupleOp()
/external/tensorflow/tensorflow/compiler/xla/python_api/
DBUILD22 name = "xla_shape",
23 srcs = ["xla_shape.py"],
39 ":xla_shape",
Dxla_literal.py25 from tensorflow.compiler.xla.python_api import xla_shape
68 literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message)
89 literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message)
/external/tensorflow/tensorflow/stream_executor/tpu/
Dc_api_conversions.cc212 void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { in ToC() argument
213 c_shape->element_type = xla_shape.element_type(); in ToC()
215 CopyVector(xla_shape.dimensions(), &c_shape->dimensions); in ToC()
216 CopyVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions); in ToC()
218 c_shape->ntuple_shapes = xla_shape.tuple_shapes_size(); in ToC()
222 ToC(xla_shape.tuple_shapes(i), &c_shape->tuple_shapes[i]); in ToC()
226 if (xla_shape.has_layout()) { in ToC()
227 ToC(xla_shape.layout(), &c_shape->layout); in ToC()
317 XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape) { in ToC() argument
319 CHECK_LT(xla_shape.size(), 8); in ToC()
[all …]
Dc_api_conversions.h47 void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
61 XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape);
/external/tensorflow/tensorflow/core/tpu/kernels/
Dinfeed_ops.cc77 const xla::Shape& xla_shape) { in TransposeTensor() argument
79 const int64 rank = xla_shape.rank(); in TransposeTensor()
83 permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i); in TransposeTensor()
84 transposed_shapes[i] = xla_shape.dimensions(permutation[i]); in TransposeTensor()
92 xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) { in TransposeTensor()
307 xla::Shape xla_shape; in PrelinearizeTupleOp() local
309 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in PrelinearizeTupleOp()
310 xla_shapes.push_back(xla_shape); in PrelinearizeTupleOp()
460 xla::Shape xla_shape; in TpuInfeedEnqueueTupleOp() local
462 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in TpuInfeedEnqueueTupleOp()
[all …]
Doutfeed_ops.cc77 xla::Shape xla_shape; in TpuOutfeedDequeueTupleOp() local
79 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in TpuOutfeedDequeueTupleOp()
80 xla_shapes_.push_back(xla_shape); in TpuOutfeedDequeueTupleOp()
Dtpu_execute_op.cc231 const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape(); in BuildComputationInputs() local
232 if (!xla::ShapeUtil::Compatible(expected, xla_shape)) { in BuildComputationInputs()
236 expected.DebugString(), "; got ", xla_shape.DebugString()); in BuildComputationInputs()
421 const xla::Shape& xla_shape = in AllocateOutputTensors() local
423 if (!xla_shape.IsArray() || in AllocateOutputTensors()
424 xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) { in AllocateOutputTensors()
427 xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString()); in AllocateOutputTensors()
/external/tensorflow/tensorflow/compiler/jit/
Dxla_device_context.cc99 xla::Shape xla_shape; in XlaDeviceContext() local
100 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); in XlaDeviceContext()
101 return xla_shape; in XlaDeviceContext()
Dxla_tpu_device.cc56 xla::Shape xla_shape; in TpuShapeRepresentation() local
58 tensorflow::TensorShapeToXLAShape(type, shape, &xla_shape)); in TpuShapeRepresentation()
59 ApiConverter::StackHelper<XLA_Shape> se_shape(xla_shape); in TpuShapeRepresentation()
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
DBUILD18 "//tensorflow/compiler/xla/python_api:xla_shape",

12