Searched refs:iota_shape (Results 1 – 9 of 9) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | where_op.cc | 44 auto iota_shape = input_shape.ValueOrDie(); in Compile() local 45 iota_shape.set_element_type(xla::S32); in Compile() 47 int64 flattened_size = xla::Product(iota_shape.dimensions()); in Compile() 63 for (int64 axis = 0; axis < iota_shape.rank(); ++axis) { in Compile() 64 xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis); in Compile() 75 for (int64 i = 0; i < iota_shape.rank(); ++i) { in Compile()
|
D | matrix_band_part_op.cc | 64 xla::Shape iota_shape = xla::ShapeUtil::MakeShape(index_xla_type, {m, n}); in Compile() local 65 xla::XlaOp iota_m = xla::Iota(builder, iota_shape, /*iota_dimension=*/0); in Compile() 66 xla::XlaOp iota_n = xla::Iota(builder, iota_shape, /*iota_dimension=*/1); in Compile()
|
D | unique_op.cc | 151 auto iota_shape = input_shape; in DataOutputFastPath() local 152 iota_shape.set_element_type(xla::S32); in DataOutputFastPath() 154 xla::XlaOp iota = xla::Iota(ctx->builder(), iota_shape, 0); in DataOutputFastPath()
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | stable_sort_expander.cc | 64 Shape iota_shape = sort->operand(0)->shape(); in ExpandInstruction() local 69 if (iota_shape.dimensions(sort->sort_dimension()) > in ExpandInstruction() 74 iota_shape.set_element_type(S32); in ExpandInstruction() 76 HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); in ExpandInstruction() 104 new_shapes.push_back(iota_shape); in ExpandInstruction()
|
/external/tensorflow/tensorflow/compiler/xla/client/lib/ |
D | sorting.cc | 48 Shape iota_shape = in TopK() local 50 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); in TopK() 90 Shape iota_shape = in TopKWithPartitions() local 92 XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); in TopKWithPartitions()
|
D | arithmetic.cc | 157 auto iota_shape = input_shape; in ArgMinMax() local 158 iota_shape.set_element_type(index_type); in ArgMinMax() 159 XlaOp iota = Iota(builder, iota_shape, axis); in ArgMinMax()
|
D | matrix.cc | 267 Shape iota_shape = x_shape; in EinsumDiagonalMask() local 268 iota_shape.set_element_type(S32); in EinsumDiagonalMask() 276 mask = And(mask, Eq(Iota(builder, iota_shape, first_dim), in EinsumDiagonalMask() 277 Iota(builder, iota_shape, dim))); in EinsumDiagonalMask()
|
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | xla_helpers.cc | 94 xla::Shape iota_shape; in OneHot() local 96 TensorShapeToXLAShape(index_type, output_shape, &iota_shape)); in OneHot() 100 xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), in OneHot()
|
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/ |
D | legalize_tf.cc | 5934 auto iota_shape = llvm::to_vector<4>(batch_dims); in QRBlock() local 5935 iota_shape.push_back(n); in QRBlock() 5937 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), in QRBlock() 6017 auto iota_shape = llvm::to_vector<4>(batch_dims); in ComputeWYRepresentation() local 6018 iota_shape.append({m, n}); in ComputeWYRepresentation() 6020 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), in ComputeWYRepresentation()
|