Home
last modified time | relevance | path

Searched refs:iota_shape (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dwhere_op.cc44 auto iota_shape = input_shape.ValueOrDie(); in Compile() local
45 iota_shape.set_element_type(xla::S32); in Compile()
47 int64 flattened_size = xla::Product(iota_shape.dimensions()); in Compile()
63 for (int64 axis = 0; axis < iota_shape.rank(); ++axis) { in Compile()
64 xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis); in Compile()
75 for (int64 i = 0; i < iota_shape.rank(); ++i) { in Compile()
Dmatrix_band_part_op.cc64 xla::Shape iota_shape = xla::ShapeUtil::MakeShape(index_xla_type, {m, n}); in Compile() local
65 xla::XlaOp iota_m = xla::Iota(builder, iota_shape, /*iota_dimension=*/0); in Compile()
66 xla::XlaOp iota_n = xla::Iota(builder, iota_shape, /*iota_dimension=*/1); in Compile()
Dunique_op.cc151 auto iota_shape = input_shape; in DataOutputFastPath() local
152 iota_shape.set_element_type(xla::S32); in DataOutputFastPath()
154 xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, 0); in DataOutputFastPath()
/external/tensorflow/tensorflow/compiler/xla/service/
Dstable_sort_expander.cc64 Shape iota_shape = sort->operand(0)->shape(); in ExpandInstruction() local
69 if (iota_shape.dimensions(sort->sort_dimension()) > in ExpandInstruction()
74 iota_shape.set_element_type(S32); in ExpandInstruction()
76 HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); in ExpandInstruction()
104 new_shapes.push_back(iota_shape); in ExpandInstruction()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dsorting.cc48 Shape iota_shape = in TopK() local
50 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); in TopK()
90 Shape iota_shape = in TopKWithPartitions() local
92 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); in TopKWithPartitions()
Darithmetic.cc157 auto iota_shape = input_shape; in ArgMinMax() local
158 iota_shape.set_element_type(index_type); in ArgMinMax()
159 XlaOp iota = Iota(builder, iota_shape, axis); in ArgMinMax()
Dmatrix.cc267 Shape iota_shape = x_shape; in EinsumDiagonalMask() local
268 iota_shape.set_element_type(S32); in EinsumDiagonalMask()
276 mask = And(mask, Eq(Iota(builder, iota_shape, first_dim), in EinsumDiagonalMask()
277 Iota(builder, iota_shape, dim))); in EinsumDiagonalMask()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_helpers.cc94 xla::Shape iota_shape; in OneHot() local
96 TensorShapeToXLAShape(index_type, output_shape, &iota_shape)); in OneHot()
100 xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), in OneHot()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc5934 auto iota_shape = llvm::to_vector<4>(batch_dims); in QRBlock() local
5935 iota_shape.push_back(n); in QRBlock()
5937 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), in QRBlock()
6017 auto iota_shape = llvm::to_vector<4>(batch_dims); in ComputeWYRepresentation() local
6018 iota_shape.append({m, n}); in ComputeWYRepresentation()
6020 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), in ComputeWYRepresentation()