/external/tensorflow/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/ |
D | hlo_utils.h | 41 Type element_ty = getElementTypeOrSelf(ty); in getSplat() local 43 if (element_ty.isSignlessInteger()) in getSplat() 44 return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant)); in getSplat() 46 if (element_ty.isa<FloatType>()) in getSplat() 47 return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant)); in getSplat() 49 if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) { in getSplat()
|
/external/tensorflow/tensorflow/compiler/mlir/xla/ |
D | type_to_shape.cc | 55 mlir::Type element_ty = complex_type.getElementType(); in TypeToPrimitiveType() local 56 if (element_ty.isF32()) { in TypeToPrimitiveType() 59 } else if (element_ty.isF64()) { in TypeToPrimitiveType()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
D | constant_fold.cc | 99 Type element_ty = shaped_ty.getElementType(); in ConstantFoldFallbackHook() local 101 element_ty.isIntOrFloat(); in ConstantFoldFallbackHook()
|
D | lower_tf.cc | 342 auto element_ty = input_ty.getElementType(); in matchAndRewrite() local 343 auto scalar_ty = RankedTensorType::get({}, element_ty); in matchAndRewrite() 358 scalar_ty, ConvertToAPFloat(bits_min, element_ty))); in matchAndRewrite() 362 scalar_ty, ConvertToAPFloat(bits_max, element_ty))); in matchAndRewrite() 367 scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty))); in matchAndRewrite() 414 DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty))); in matchAndRewrite() 1035 auto element_ty = input_ty.getElementType(); in matchAndRewrite() local 1088 op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input, in matchAndRewrite() 1115 op.getLoc(), RankedTensorType::get(transpose_shape, element_ty), in matchAndRewrite() 1141 op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty), in matchAndRewrite() [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/ |
D | tf_ops_helpers.inc | 102 Type element_ty = getElementTypeOrSelf(input_ty); 106 if (!ranked_ty) return UnrankedTensorType::get(element_ty); 113 if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); 116 return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty); 124 if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty); 141 return RankedTensorType::get(out_shape, element_ty); 494 Type element_ty = lhs_type.getElementType(); 496 if (auto ty = element_ty.template dyn_cast<FloatType>()) { 498 } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
|
D | tf_ops_n_z.cc | 149 Type element_ty = on_value.getType().cast<TensorType>().getElementType(); in InferOneHotOpType() local 150 auto unranked_ty = UnrankedTensorType::get(element_ty); in InferOneHotOpType() 165 return RankedTensorType::get(shape, element_ty); in InferOneHotOpType() 581 auto element_ty = tensor_ty.getElementType(); in GetReshapeOutputType() local 582 output_ty = UnrankedTensorType::get(element_ty); in GetReshapeOutputType() 597 output_ty = RankedTensorType::get(dynamic_shape, element_ty); in GetReshapeOutputType() 633 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType() 662 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType() 811 Type element_ty = e.getType().cast<TensorType>().getElementType(); in InferSelectV2OpType() local 812 auto unranked_ty = UnrankedTensorType::get(element_ty); in InferSelectV2OpType() [all …]
|
D | tf_types.cc | 377 Type element_ty = getElementTypeOrSelf(ty); in DropTypeHelper() local 378 auto composed_type = element_ty.dyn_cast<ComposedType>(); in DropTypeHelper()
|
D | tf_ops_a_m.cc | 2127 Type element_ty = input.getType().cast<TensorType>().getElementType(); in InferExpandDimsOpType() local 2128 auto unranked_ty = UnrankedTensorType::get(element_ty); in InferExpandDimsOpType() 2145 return RankedTensorType::get(shape, element_ty); in InferExpandDimsOpType()
|
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/ |
D | hlo_ops.cc | 462 Type element_ty = operand_ty.getElementType(); in inferReturnTypes() local 463 if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) { in inferReturnTypes() 464 element_ty = complex_ty.getElementType(); in inferReturnTypes() 469 result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty); in inferReturnTypes() 471 result_ty = UnrankedTensorType::get(element_ty); in inferReturnTypes() 964 auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); in inferReturnTypes() local 967 result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty); in inferReturnTypes() 969 result_ty = UnrankedTensorType::get(element_ty); in inferReturnTypes() 971 result_ty = element_ty; in inferReturnTypes() 993 auto element_ty = getElementTypeOrSelf(type); in CreateRealType() local [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/lite/ |
D | flatbuffer_import.cc | 476 mlir::Type element_ty = getElementTypeOrSelf(type); in GetSplat() local 478 if (element_ty.isSignlessInteger()) in GetSplat() 480 type, builder.getIntegerAttr(element_ty, unique_index)); in GetSplat() 482 if (element_ty.isa<mlir::FloatType>()) in GetSplat() 484 type, builder.getFloatAttr(element_ty, unique_index)); in GetSplat() 486 if (auto qtype = element_ty.dyn_cast<QuantizedType>()) { in GetSplat()
|
/external/tensorflow/tensorflow/compiler/mlir/lite/ir/ |
D | tfl_ops.cc | 1353 auto element_ty = input_ty.getElementType(); in GetReshapeOutputType() local 1354 output_ty = UnrankedTensorType::get(element_ty); in GetReshapeOutputType() 1369 output_ty = RankedTensorType::get(dynamic_shape, element_ty); in GetReshapeOutputType() 1408 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType() 1437 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType()
|
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/ |
D | legalize_tf.cc | 908 auto element_ty = ty.cast<TensorType>().getElementType(); in GetEpsilonValue() local 909 auto scalar_ty = RankedTensorType::get({}, element_ty); in GetEpsilonValue() 910 if (element_ty.isF16()) { in GetEpsilonValue() 914 } else if (element_ty.isBF16()) { in GetEpsilonValue() 918 } else if (element_ty.isF32()) { in GetEpsilonValue() 921 } else if (element_ty.isF64()) { in GetEpsilonValue()
|