Home
last modified time | relevance | path

Searched refs:element_ty (Results 1 – 12 of 12) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/
Dhlo_utils.h41 Type element_ty = getElementTypeOrSelf(ty); in getSplat() local
43 if (element_ty.isSignlessInteger()) in getSplat()
44 return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant)); in getSplat()
46 if (element_ty.isa<FloatType>()) in getSplat()
47 return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant)); in getSplat()
49 if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) { in getSplat()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dtype_to_shape.cc55 mlir::Type element_ty = complex_type.getElementType(); in TypeToPrimitiveType() local
56 if (element_ty.isF32()) { in TypeToPrimitiveType()
59 } else if (element_ty.isF64()) { in TypeToPrimitiveType()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dconstant_fold.cc99 Type element_ty = shaped_ty.getElementType(); in ConstantFoldFallbackHook() local
101 element_ty.isIntOrFloat(); in ConstantFoldFallbackHook()
Dlower_tf.cc342 auto element_ty = input_ty.getElementType(); in matchAndRewrite() local
343 auto scalar_ty = RankedTensorType::get({}, element_ty); in matchAndRewrite()
358 scalar_ty, ConvertToAPFloat(bits_min, element_ty))); in matchAndRewrite()
362 scalar_ty, ConvertToAPFloat(bits_max, element_ty))); in matchAndRewrite()
367 scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty))); in matchAndRewrite()
414 DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty))); in matchAndRewrite()
1035 auto element_ty = input_ty.getElementType(); in matchAndRewrite() local
1088 op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input, in matchAndRewrite()
1115 op.getLoc(), RankedTensorType::get(transpose_shape, element_ty), in matchAndRewrite()
1141 op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty), in matchAndRewrite()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/
Dtf_ops_helpers.inc102 Type element_ty = getElementTypeOrSelf(input_ty);
106 if (!ranked_ty) return UnrankedTensorType::get(element_ty);
113 if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty);
116 return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty);
124 if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);
141 return RankedTensorType::get(out_shape, element_ty);
494 Type element_ty = lhs_type.getElementType();
496 if (auto ty = element_ty.template dyn_cast<FloatType>()) {
498 } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
Dtf_ops_n_z.cc149 Type element_ty = on_value.getType().cast<TensorType>().getElementType(); in InferOneHotOpType() local
150 auto unranked_ty = UnrankedTensorType::get(element_ty); in InferOneHotOpType()
165 return RankedTensorType::get(shape, element_ty); in InferOneHotOpType()
581 auto element_ty = tensor_ty.getElementType(); in GetReshapeOutputType() local
582 output_ty = UnrankedTensorType::get(element_ty); in GetReshapeOutputType()
597 output_ty = RankedTensorType::get(dynamic_shape, element_ty); in GetReshapeOutputType()
633 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType()
662 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType()
811 Type element_ty = e.getType().cast<TensorType>().getElementType(); in InferSelectV2OpType() local
812 auto unranked_ty = UnrankedTensorType::get(element_ty); in InferSelectV2OpType()
[all …]
Dtf_types.cc377 Type element_ty = getElementTypeOrSelf(ty); in DropTypeHelper() local
378 auto composed_type = element_ty.dyn_cast<ComposedType>(); in DropTypeHelper()
Dtf_ops_a_m.cc2127 Type element_ty = input.getType().cast<TensorType>().getElementType(); in InferExpandDimsOpType() local
2128 auto unranked_ty = UnrankedTensorType::get(element_ty); in InferExpandDimsOpType()
2145 return RankedTensorType::get(shape, element_ty); in InferExpandDimsOpType()
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
Dhlo_ops.cc462 Type element_ty = operand_ty.getElementType(); in inferReturnTypes() local
463 if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) { in inferReturnTypes()
464 element_ty = complex_ty.getElementType(); in inferReturnTypes()
469 result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty); in inferReturnTypes()
471 result_ty = UnrankedTensorType::get(element_ty); in inferReturnTypes()
964 auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); in inferReturnTypes() local
967 result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty); in inferReturnTypes()
969 result_ty = UnrankedTensorType::get(element_ty); in inferReturnTypes()
971 result_ty = element_ty; in inferReturnTypes()
993 auto element_ty = getElementTypeOrSelf(type); in CreateRealType() local
[all …]
/external/tensorflow/tensorflow/compiler/mlir/lite/
Dflatbuffer_import.cc476 mlir::Type element_ty = getElementTypeOrSelf(type); in GetSplat() local
478 if (element_ty.isSignlessInteger()) in GetSplat()
480 type, builder.getIntegerAttr(element_ty, unique_index)); in GetSplat()
482 if (element_ty.isa<mlir::FloatType>()) in GetSplat()
484 type, builder.getFloatAttr(element_ty, unique_index)); in GetSplat()
486 if (auto qtype = element_ty.dyn_cast<QuantizedType>()) { in GetSplat()
/external/tensorflow/tensorflow/compiler/mlir/lite/ir/
Dtfl_ops.cc1353 auto element_ty = input_ty.getElementType(); in GetReshapeOutputType() local
1354 output_ty = UnrankedTensorType::get(element_ty); in GetReshapeOutputType()
1369 output_ty = RankedTensorType::get(dynamic_shape, element_ty); in GetReshapeOutputType()
1408 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType()
1437 output_ty = RankedTensorType::get(output_ty_shape, element_ty); in GetReshapeOutputType()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc908 auto element_ty = ty.cast<TensorType>().getElementType(); in GetEpsilonValue() local
909 auto scalar_ty = RankedTensorType::get({}, element_ty); in GetEpsilonValue()
910 if (element_ty.isF16()) { in GetEpsilonValue()
914 } else if (element_ty.isBF16()) { in GetEpsilonValue()
918 } else if (element_ty.isF32()) { in GetEpsilonValue()
921 } else if (element_ty.isF64()) { in GetEpsilonValue()