Home
last modified time | relevance | path

Searched refs:indices_type (Results 1 – 15 of 15) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/
Dunsorted_segment_join_op.cc152 #define REGISTER_CPU_KERNEL(indices_type, num_segments_type) \ argument
156 .TypeConstraint<indices_type>("Tindices") \
158 UnsortedSegmentJoinOp<indices_type, num_segments_type>);
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dlegalize_hlo.cc972 ShapedType &indices_type, in NormalizeIndexVector() argument
975 if (index_vector_dim == indices_type.getRank()) { in NormalizeIndexVector()
977 indices_type.getShape().begin(), indices_type.getShape().end()); in NormalizeIndexVector()
979 indices_type = RankedTensorType::get(new_start_indices_shape, in NormalizeIndexVector()
980 indices_type.getElementType()); in NormalizeIndexVector()
982 indices_type, indices); in NormalizeIndexVector()
983 } else if (index_vector_dim != indices_type.getRank() - 1) { in NormalizeIndexVector()
1098 ShapedType indices_type = indices.getType().cast<ShapedType>(); in matchAndRewrite() local
1102 if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() || in matchAndRewrite()
1110 if (failed(NormalizeIndexVector(scatter_op, indices, indices_type, in matchAndRewrite()
[all …]
Dtensor_list_ops_decomposition.cc796 auto indices_type = scatter.indices().getType().cast<RankedTensorType>(); in HandleTensorListScatterIntoExistingListOp() local
797 if (!indices_type) return scatter.emitOpError("unranked indices shape"); in HandleTensorListScatterIntoExistingListOp()
802 shape_type, {static_cast<int>(indices_type.getDimSize(0)), 1})); in HandleTensorListScatterIntoExistingListOp()
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dsoftmax_op.cc224 const DataType indices_type = input_type(1); in Compile() local
243 xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices), in Compile()
245 builder, indices_type, depth))), in Compile()
Dgather_op.cc273 DataType indices_type = context->input_type(1); in Compile() local
297 indices_type, builder, &gather)); in Compile()
Dtensor_list_ops.cc398 DataType indices_type = ctx->input_type(1); in Compile() local
417 /*indices_are_nd=*/false, dtype_, indices_type, in Compile()
/external/tensorflow/tensorflow/lite/kernels/
Done_hot_test.cc37 T off_value = 0, TensorType indices_type = TensorType_INT32) { in OneHotOpModel() argument
38 indices_ = AddInput(indices_type); in OneHotOpModel()
/external/tensorflow/tensorflow/compiler/mlir/lite/utils/
Dperception_ops_utils_test.cc55 auto indices_type = RankedTensorType::get(input_shape, builder->getI64Type()); in createMaxUnpoolingFunc() local
57 SmallVector<mlir::Type, 2> input_types{input_type, indices_type}; in createMaxUnpoolingFunc()
/external/tensorflow/tensorflow/compiler/mlir/lite/ir/
Dtfl_ops.cc969 auto indices_type = indices.getType().cast<TensorType>(); in BuildGatherOp() local
972 if (!params_type.hasRank() || !indices_type.hasRank()) in BuildGatherOp()
978 int64_t indices_rank = indices_type.getRank(); in BuildGatherOp()
1003 std::copy(std::begin(indices_type.getShape()), in BuildGatherOp()
1004 std::end(indices_type.getShape()), std::begin(shape) + axis_i); in BuildGatherOp()
1014 std::copy(std::begin(indices_type.getShape()), in BuildGatherOp()
1015 std::end(indices_type.getShape()), std::begin(shape) + axis_i); in BuildGatherOp()
1035 auto indices_type = indices.getType().cast<ShapedType>(); in Verify() local
1037 if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) { in Verify()
1045 auto outer_dims = indices_type.getRank() - 1; in Verify()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/
Doptimize.cc182 auto indices_type = indices.getType().dyn_cast<RankedTensorType>(); in CanOptimizeIdentityGatherNdOrScatterNdOp() local
186 if (!params_type || !indices_type || indices_type.getRank() != 2 || in CanOptimizeIdentityGatherNdOrScatterNdOp()
187 indices_type.getDimSize(0) != params_type.getDimSize(0) || in CanOptimizeIdentityGatherNdOrScatterNdOp()
188 indices_type.getDimSize(1) != 1) in CanOptimizeIdentityGatherNdOrScatterNdOp()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc238 auto indices_type = value.getType().cast<RankedTensorType>(); in UnpackTensorAlongZeroDim() local
239 int num_outputs = indices_type.getShape().front(); in UnpackTensorAlongZeroDim()
241 num_outputs, RankedTensorType::get({}, indices_type.getElementType())); in UnpackTensorAlongZeroDim()
5068 auto indices_type = in matchAndRewrite() local
5071 op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); in matchAndRewrite()
5180 auto indices_type = indices.getType().cast<ShapedType>(); in matchAndRewrite() local
5181 if (!indices_type.hasStaticShape()) return failure(); in matchAndRewrite()
5183 if (indices_type.getRank() != 1) return failure(); in matchAndRewrite()
5186 indices_type.getDimSize(0), in matchAndRewrite()
5187 RankedTensorType::get({}, indices_type.getElementType())); in matchAndRewrite()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/lite/
Dflatbuffer_export.cc1902 tflite::SparseIndexVector indices_type; in BuildSparsityParameters() local
1905 indices_type = tflite::SparseIndexVector_Uint8Vector; in BuildSparsityParameters()
1912 indices_type = tflite::SparseIndexVector_Uint16Vector; in BuildSparsityParameters()
1919 indices_type = tflite::SparseIndexVector_Int32Vector; in BuildSparsityParameters()
1927 array_segments, indices_type, array_indices); in BuildSparsityParameters()
/external/tensorflow/tensorflow/lite/toco/
Dexport_tensorflow.cc1825 const tensorflow::DataType indices_type = in ConvertReduceOperator() local
1827 (*new_op->mutable_attr())["Tidx"].set_type(indices_type); in ConvertReduceOperator()
/external/tensorflow/tensorflow/compiler/xla/service/
Dalgebraic_simplifier.cc2061 auto indices_type = dynamic_slice->operand(1)->shape().element_type(); in OptimizeDotOfGather() local
2062 Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); in OptimizeDotOfGather()
2064 Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); in OptimizeDotOfGather()
/external/tensorflow/tensorflow/core/grappler/optimizers/
Darithmetic_optimizer_test.cc4217 for (DataType indices_type : {DT_INT32, DT_INT64}) { in TEST_F()
4225 indices_type); in TEST_F()