Home
last modified time | relevance | path

Searched refs:iota_shape (Results 1 – 11 of 11) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dwhere_op.cc156 auto iota_shape = input_shape; in CompileWhereWithSort() local
157 iota_shape.set_element_type(xla::S32); in CompileWhereWithSort()
159 int64_t flattened_size = xla::Product(iota_shape.dimensions()); in CompileWhereWithSort()
168 for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) { in CompileWhereWithSort()
169 XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis); in CompileWhereWithSort()
179 for (int64_t i = 0; i < iota_shape.rank(); ++i) { in CompileWhereWithSort()
271 auto iota_shape = input_shape; in CompileWhereWithPrefixSum() local
272 iota_shape.set_element_type(S32); in CompileWhereWithPrefixSum()
273 for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) { in CompileWhereWithPrefixSum()
275 xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1})); in CompileWhereWithPrefixSum()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/
Dstable_sort_expander.cc64 Shape iota_shape = sort->operand(0)->shape(); in ExpandInstruction() local
69 if (iota_shape.dimensions(sort->sort_dimension()) > in ExpandInstruction()
74 iota_shape.set_element_type(S32); in ExpandInstruction()
76 HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); in ExpandInstruction()
104 new_shapes.push_back(iota_shape); in ExpandInstruction()
Dselect_and_scatter_expander.cc41 const auto iota_shape = ShapeUtil::ChangeElementType(operand_shape, S32); in ExpandInstruction() local
45 ShapeUtil::MakeScalarShape(iota_shape.element_type()); in ExpandInstruction()
56 computation->AddInstruction(HloInstruction::CreateIota(iota_shape, i))); in ExpandInstruction()
178 ShapeUtil::MakeShape(iota_shape.element_type(), concatenated_iotas_dims), in ExpandInstruction()
179 iota_indices, iota_shape.rank())); in ExpandInstruction()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dsorting.cc48 Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions()); in TopK() local
49 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); in TopK()
168 Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions()); in TopKWithPartitions() local
169 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); in TopKWithPartitions()
Darithmetic.cc151 auto iota_shape = input_shape; in ArgMinMax() local
152 iota_shape.set_element_type(index_type); in ArgMinMax()
153 XlaOp iota = Iota(builder, iota_shape, axis); in ArgMinMax()
Dmatrix.cc298 Shape iota_shape = x_shape; in EinsumDiagonalMask() local
299 iota_shape.set_element_type(S32); in EinsumDiagonalMask()
307 mask = And(mask, Eq(Iota(builder, iota_shape, first_dim), in EinsumDiagonalMask()
308 Iota(builder, iota_shape, dim))); in EinsumDiagonalMask()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_helpers.cc102 xla::Shape iota_shape; in OneHot() local
104 TensorShapeToXLAShape(index_type, output_shape, &iota_shape)); in OneHot()
108 xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), in OneHot()
/external/tensorflow/tensorflow/python/ops/
Dnn_grad.py1161 iota_shape = list(itertools.repeat(1, rank + 1))
1162 iota_shape[d] = iota_len
1163 iota = array_ops.reshape(math_ops.range(iota_len), iota_shape)
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dlegalize_hlo.cc1506 auto iota_shape = iota_type.getShape(); in MatchIotaConst() local
1514 iota_const_attr, iota_shape, *index, reduce_dim); in MatchIotaConst()
1519 std::move(*index), iota_shape, reduce_dim); in MatchIotaConst()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc6847 auto iota_shape = llvm::to_vector<4>(batch_dims); in QRBlock() local
6848 iota_shape.push_back(n); in QRBlock()
6850 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), in QRBlock()
6929 auto iota_shape = llvm::to_vector<4>(batch_dims); in ComputeWYRepresentation() local
6930 iota_shape.append({m, n}); in ComputeWYRepresentation()
6932 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), in ComputeWYRepresentation()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner.cc4835 const Shape iota_shape = in PreprocessHlos() local
4838 HloInstruction::CreateIota(iota_shape, dim)); in PreprocessHlos()
4863 HloInstruction::CreateBroadcast(iota_shape, limit, {})); in PreprocessHlos()