Home
last modified time | relevance | path

Searched refs:xla_shape (Results 1 – 25 of 47) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Drandom_ops.cc53 xla::Shape xla_shape; in Compile() local
54 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); in Compile()
63 XlaHelpers::One(b, dtype), xla_shape); in Compile()
87 xla::Shape xla_shape; in Compile() local
89 TensorShapeToXLAShape(input_type(1), shape, &xla_shape)); in Compile()
107 ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape)); in Compile()
128 xla::Shape xla_shape; in Compile() local
129 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); in Compile()
135 XlaHelpers::One(b, dtype), xla_shape); in Compile()
159 xla::Shape xla_shape; in Compile() local
[all …]
Dstateless_random_ops.cc163 xla::Shape xla_shape; in Compile() local
164 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
165 xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); in Compile()
168 device_type_string_, seed, xla_shape, in Compile()
218 xla::Shape xla_shape; in Compile() local
219 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
221 xla_shape, minval, maxval); in Compile()
259 xla::Shape xla_shape; in Compile() local
260 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
262 StatelessRngUniformFullInt(device_type_string_, seed, xla_shape); in Compile()
[all …]
Dstateful_random_ops.cc228 xla::Shape xla_shape; in Compile() local
230 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
231 xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); in Compile()
233 alg, key, state, xla_shape, in Compile()
272 xla::Shape xla_shape; in Compile() local
274 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
276 key, state, BitGen(alg), xla_shape); in Compile()
312 xla::Shape xla_shape; in Compile() local
314 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
317 alg, key, state, xla_shape, in Compile()
[all …]
Dstateless_random_ops_v2.cc237 xla::Shape xla_shape; in Compile() local
238 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); in Compile()
239 xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); in Compile()
244 alg, key, counter, xla_shape, in Compile()
319 xla::Shape xla_shape; in Compile() local
320 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
324 StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval); in Compile()
369 xla::Shape xla_shape; in Compile() local
370 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); in Compile()
373 auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape); in Compile()
[all …]
Dxla_call_module_op.cc161 OP_REQUIRES_VALUE(xla::Shape xla_shape, ctx, ctx->InputXlaShape(i)); in RefineDynamicShapes()
162 std::vector<int64_t> xla_dimensions(xla_shape.dimensions().begin(), in RefineDynamicShapes()
163 xla_shape.dimensions().end()); in RefineDynamicShapes()
166 ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder)); in RefineDynamicShapes()
219 OP_REQUIRES_VALUE(xla::Shape xla_shape, ctx, ctx->InputXlaShape(arg_idx)); in PopulateDimArgInputs()
221 int64_t dim_arg_val = xla_shape.dimensions()[arg_axis_idx]; in PopulateDimArgInputs()
Dlight_outside_compilation.cc106 TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, ctx->InputXlaShape(i)); in CompileToCustomCallCallingTfKernel()
107 if (absl::c_any_of(xla_shape.dynamic_dimensions(), in CompileToCustomCallCallingTfKernel()
132 operand_shapes_with_layout.push_back(xla_shape); in CompileToCustomCallCallingTfKernel()
385 xla::Shape xla_shape, in PopulateMetadataBufferIfNeeded()
390 xla::ShapeUtil::ByteSizeOf(xla_shape); in PopulateMetadataBufferIfNeeded()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dliteral_util.cc27 xla::Shape xla_shape; in HostTensorToBorrowingLiteral() local
29 host_tensor.shape(), &xla_shape)); in HostTensorToBorrowingLiteral()
30 return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal); in HostTensorToBorrowingLiteral()
33 Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, in HostTensorToBorrowingLiteral() argument
38 tshape.dims() == xla_shape.dimensions_size() && in HostTensorToBorrowingLiteral()
39 tshape.dim_sizes() == xla_shape.dimensions()) in HostTensorToBorrowingLiteral()
42 static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape); in HostTensorToBorrowingLiteral()
54 xla::Shape xla_shape; in HostTensorToMutableBorrowingLiteral() local
56 host_tensor->shape(), &xla_shape)); in HostTensorToMutableBorrowingLiteral()
57 return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); in HostTensorToMutableBorrowingLiteral()
[all …]
Dlayout_util.cc40 xla::Shape* xla_shape) { in RewriteLayoutWithShardedShape() argument
55 sharding->TileOffsetForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
57 sharding->TileLimitForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
58 std::vector<int64_t> dimensions(xla_shape->rank()); in RewriteLayoutWithShardedShape()
59 for (int64_t i = 0; i < xla_shape->rank(); ++i) { in RewriteLayoutWithShardedShape()
63 xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); in RewriteLayoutWithShardedShape()
68 xla_shape->element_type())); in RewriteLayoutWithShardedShape()
75 *xla_shape->mutable_layout() = per_device_xla_shape.layout(); in RewriteLayoutWithShardedShape()
Dliteral_util.h36 Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
52 const xla::Shape& xla_shape, Tensor* host_tensor,
Dxla_compiler.cc770 xla::Shape xla_shape = std::get<xla::Shape>(args[i].shape); in CompileFunction() local
774 if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() && in CompileFunction()
775 xla_shape.is_static()) { in CompileFunction()
853 xla::Shape* xla_shape) const { in XLAShapeForArgument()
870 *xla_shape, in XLAShapeForArgument()
876 options_.shape_determination_fns, xla_shape)); in XLAShapeForArgument()
879 *xla_shape = std::get<xla::Shape>(arg.shape); in XLAShapeForArgument()
882 arg.type, std::get<TensorShape>(arg.shape), xla_shape)); in XLAShapeForArgument()
889 *xla_shape = std::get<xla::Shape>(arg.shape); in XLAShapeForArgument()
903 *xla_shape, in XLAShapeForArgument()
[all …]
Dxla_helpers.cc140 xla::Shape xla_shape; in IdentityShapeRepresentationFn() local
141 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); in IdentityShapeRepresentationFn()
142 return xla_shape; in IdentityShapeRepresentationFn()
/external/tensorflow/tensorflow/core/tpu/kernels/xla/
Doutfeed_ops.cc37 xla::Shape xla_shape; in Compile() local
39 ctx, TensorShapeToXLAShape(dtype_, ctx->InputShape(0), &xla_shape)); in Compile()
42 xla::Outfeed(ctx->Input(0), xla_shape, outfeed_config); in Compile()
67 xla::Shape xla_shape; in Compile() local
69 TensorShapeToXLAShape(dtypes_[i], shapes[i], &xla_shape)); in Compile()
70 xla_shapes.push_back(xla_shape); in Compile()
Dhost_compute_ops.cc52 xla::Shape* xla_shape) { in MakeXlaShapes() argument
63 *xla_shape = xla::ShapeUtil::MakeTupleShape(*xla_shapes); in MakeXlaShapes()
163 xla::Shape xla_shape; in Compile() local
165 input_shapes[i], &xla_shape)); in Compile()
174 xla_shape.element_type()); in Compile()
180 xla::SendToHost(input_handles[i], token, xla_shape, channel)); in Compile()
439 xla::Shape xla_shape; in Compile() local
441 &xla_shape)); in Compile()
449 xla_shape.element_type()); in Compile()
454 xla::SendToHost(operand, token, xla_shape, channel); in Compile()
[all …]
Dinfeed_op.cc119 xla::Shape xla_shape; in InfeedDequeueTupleOp() local
121 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in InfeedDequeueTupleOp()
122 xla_shapes_.push_back(xla_shape); in InfeedDequeueTupleOp()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dlayout_util.cc25 xla::Shape* xla_shape) { in RewriteLayoutWithShardedShape() argument
40 sharding->TileOffsetForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
42 sharding->TileLimitForDevice(*xla_shape, device); in RewriteLayoutWithShardedShape()
43 std::vector<int64_t> dimensions(xla_shape->rank()); in RewriteLayoutWithShardedShape()
44 for (int64_t i = 0; i < xla_shape->rank(); ++i) { in RewriteLayoutWithShardedShape()
48 xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); in RewriteLayoutWithShardedShape()
59 *xla_shape->mutable_layout() = per_device_xla_shape.layout(); in RewriteLayoutWithShardedShape()
/external/tensorflow/tensorflow/compiler/xla/python_api/
DBUILD20 name = "xla_shape",
21 srcs = ["xla_shape.py"],
37 ":xla_shape",
Dxla_literal.py21 from tensorflow.compiler.xla.python_api import xla_shape
64 literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message)
85 literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message)
/external/tensorflow/tensorflow/core/tpu/kernels/
Dinfeed_ops.cc77 const xla::Shape& xla_shape) { in TransposeTensor() argument
79 const int64_t rank = xla_shape.rank(); in TransposeTensor()
83 permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i); in TransposeTensor()
84 transposed_shapes[i] = xla_shape.dimensions(permutation[i]); in TransposeTensor()
92 xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) { in TransposeTensor()
307 xla::Shape xla_shape; in PrelinearizeTupleOp() local
309 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in PrelinearizeTupleOp()
310 xla_shapes.push_back(xla_shape); in PrelinearizeTupleOp()
480 xla::Shape xla_shape; in TpuInfeedEnqueueTupleOp() local
482 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in TpuInfeedEnqueueTupleOp()
[all …]
Dtpu_reshard_variables_op_util.cc132 const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape(); in BuildInputBuffers() local
133 if (!xla::ShapeUtil::Compatible(expected, xla_shape)) { in BuildInputBuffers()
137 expected.DebugString(), "; got ", xla_shape.DebugString()); in BuildInputBuffers()
252 const xla::Shape& xla_shape = in UpdateOutputVariables() local
254 if (!xla_shape.IsArray() || in UpdateOutputVariables()
255 xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) { in UpdateOutputVariables()
258 xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString()); in UpdateOutputVariables()
Doutfeed_ops.h97 xla::Shape xla_shape; in TpuOutfeedDequeueTupleOp() local
99 TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); in TpuOutfeedDequeueTupleOp()
100 xla_shapes_.push_back(xla_shape); in TpuOutfeedDequeueTupleOp()
Dtpu_execute_op.cc230 const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape(); in BuildComputationInputs() local
231 if (!xla::ShapeUtil::Compatible(expected, xla_shape)) { in BuildComputationInputs()
235 expected.DebugString(), "; got ", xla_shape.DebugString()); in BuildComputationInputs()
420 const xla::Shape& xla_shape = in AllocateOutputTensors() local
422 if (!xla_shape.IsArray() || in AllocateOutputTensors()
423 xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) { in AllocateOutputTensors()
426 xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString()); in AllocateOutputTensors()
/external/tensorflow/tensorflow/compiler/xla/stream_executor/tpu/
Dc_api_conversions.cc233 void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape) { in ToC() argument
234 c_shape->element_type = xla_shape.element_type(); in ToC()
236 CreateVector(xla_shape.dimensions(), &c_shape->dimensions); in ToC()
237 CreateVector(xla_shape.dynamic_dimensions(), &c_shape->dynamic_dimensions); in ToC()
239 c_shape->ntuple_shapes = xla_shape.tuple_shapes_size(); in ToC()
243 ToC(xla_shape.tuple_shapes(i), &c_shape->tuple_shapes[i]); in ToC()
247 if (xla_shape.has_layout()) { in ToC()
249 ToC(xla_shape.layout(), &c_shape->layout); in ToC()
347 XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape) { in ToC() argument
349 CHECK_LT(xla_shape.size(), 8); in ToC()
[all …]
Dc_api_conversions.h57 void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
71 XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape);
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dcompile_mlir_util.cc134 xla::Shape& xla_shape = individual_arg_shapes.back(); in GetXlaInputShapes() local
141 TF_ASSIGN_OR_RETURN(xla_shape, in GetXlaInputShapes()
150 sharding, shape_determination_fns, &xla_shape)); in GetXlaInputShapes()
192 const xla::Shape& xla_shape) -> StatusOr<xla::Shape> { in GetOutputInfo() argument
194 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); in GetOutputInfo()
196 xla_shape.element_type())); in GetOutputInfo()
528 [&](const xla::Shape& xla_shape) -> StatusOr<mlir::XlaLayoutPreference> { in ConvertMLIRToXlaComputation() argument
530 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); in ConvertMLIRToXlaComputation()
532 xla_shape.element_type())); in ConvertMLIRToXlaComputation()
537 [&](const xla::Shape& xla_shape, bool fast_mem, in ConvertMLIRToXlaComputation()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/xla/tests/translate/
Dlayouts_and_names.mlir23 xla_shape = "f16[128,64,112,112]{1,3,2,0}",
31 …%cst_1 = "arith.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, xla_shape = "s32[1,1]{0,1}"}…
78 …%0 = "mhlo.infeed"(%arg0) {infeed_config = "foobar", layout = [], xla_shape = "((), token[])"} : (…

12