Searched refs:axis_type (Results 1 – 6 of 6) sorted by relevance
/external/tensorflow/tensorflow/lite/kernels/ |
D | arg_min_max_test.cc | 33 ArgBaseOpModel(TensorType input_type, int axis_value, TensorType axis_type, in ArgBaseOpModel() argument 36 axis_type_(axis_type), in ArgBaseOpModel() 40 if (axis_type == TensorType_INT64) { in ArgBaseOpModel() 42 AddConstInput(axis_type, {static_cast<int64_t>(axis_value)}, {1}); in ArgBaseOpModel() 44 axis_ = AddConstInput(axis_type, {axis_value}, {1}); in ArgBaseOpModel() 47 axis_ = AddInput(axis_type); in ArgBaseOpModel() 85 int axis_value, TensorType axis_type, bool constant_axis, in ArgMaxOpModel() argument 87 : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis, in ArgMaxOpModel() 100 int axis_value, TensorType axis_type, bool constant_axis, in ArgMinOpModel() argument 102 : ArgBaseOpModel(input_type, axis_value, axis_type, constant_axis, in ArgMinOpModel()
|
D | arg_min_max.cc | 139 #define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \ in Eval() argument 142 GetTensorData<axis_type>(axis), GetTensorShape(output), \ in Eval()
|
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | gather_op.cc | 169 DataType axis_type = context->input_type(2); in XlaGatherWithBatchDimsOpImpl() local 170 if (axis_type != DT_INT32 && axis_type != DT_INT64) { in XlaGatherWithBatchDimsOpImpl()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
D | lower_tf.cc | 1591 auto axis_type = in matchAndRewrite() local 1597 auto begin_attr = DenseIntElementsAttr::get(axis_type, begin_values); in matchAndRewrite() 1599 rewriter.create<ConstOp>(op->getLoc(), axis_type, begin_attr); in matchAndRewrite() 1604 auto size_attr = DenseIntElementsAttr::get(axis_type, output_shape); in matchAndRewrite() 1605 auto size = rewriter.create<ConstOp>(op->getLoc(), axis_type, size_attr); in matchAndRewrite()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/ |
D | tf_ops_a_m.cc | 1247 auto axis_type = RankedTensorType::get({}, getElementTypeOrSelf(axis_attr)); in matchAndRewrite() local 1249 if (axis_type.getElementType().isInteger(32)) { in matchAndRewrite() 1250 attr = DenseIntElementsAttr::get(axis_type, static_cast<int32_t>(axis)); in matchAndRewrite() 1252 assert(axis_type.getElementType().isInteger(64)); in matchAndRewrite() 1253 attr = DenseIntElementsAttr::get(axis_type, axis); in matchAndRewrite()
|
/external/tensorflow/tensorflow/compiler/tests/ |
D | randomized_tests.cc | 412 DataType axis_type; member 1131 a.axis_type = DT_INT32; in ChooseGatherArguments() 2904 .Attr("Taxis", a.axis_type) in TEST_F() 3989 auto axis_type = Choose<DataType>({DT_INT32, DT_INT64}); in TEST_F() local 3995 auto axis = RandomBoundedTensor(axis_type, 0, rank - 1, true, axis_shape); in TEST_F() 4003 .Attr("Taxis", axis_type) in TEST_F()
|