Home
last modified time | relevance | path

Searched refs:shaped_type (Results 1 – 25 of 30) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/mlir/tfrt/jit/transforms/
Dtf_jitrt_legalize_i1_type.cc97 ShapedType shaped_type = int_attr.getType(); in isLegalAttribute() local
98 if (!shaped_type.isa<RankedTensorType>()) return true; in isLegalAttribute()
99 return !shaped_type.getElementType().isInteger(/*width=*/1); in isLegalAttribute()
110 ShapedType shaped_type = int_attr.getType(); in convertAttribute() local
112 if (!shaped_type.isa<RankedTensorType>()) return attr; in convertAttribute()
113 if (!shaped_type.getElementType().isInteger(/*width=*/1)) return attr; in convertAttribute()
124 RankedTensorType::get(shaped_type.getShape(), rewriter.getI8Type()); in convertAttribute()
/external/tensorflow/tensorflow/dtensor/mlir/
Dgroup_assignment.cc84 auto shaped_type = mlir::RankedTensorType::get( in ToMLIR() local
93 return mlir::DenseIntElementsAttr::get(shaped_type, flat_replica_ids); in ToMLIR()
110 mlir::ShapedType shaped_type = group_assignment_attr.getType(); in FromMLIR() local
111 if (!shaped_type.hasRank()) { in FromMLIR()
114 if (shaped_type.getRank() != 2) { in FromMLIR()
117 shaped_type.getRank()); in FromMLIR()
119 llvm::ArrayRef<int64_t> shape = shaped_type.getShape(); in FromMLIR()
Dshape_utils.cc143 if (auto shaped_type = in InferShapeOfTFOpWithCustomOperandConstantFn() local
145 if (shaped_type.hasRank()) { in InferShapeOfTFOpWithCustomOperandConstantFn()
147 mlir::ShapedTypeComponents(shaped_type.getShape(), in InferShapeOfTFOpWithCustomOperandConstantFn()
148 shaped_type.getElementType()); in InferShapeOfTFOpWithCustomOperandConstantFn()
151 mlir::ShapedTypeComponents(shaped_type.getElementType()); in InferShapeOfTFOpWithCustomOperandConstantFn()
Dgroup_assignment_test.cc49 auto shaped_type = mlir::RankedTensorType::get( in CreateGroupAssignmentAttr() local
51 return mlir::DenseIntElementsAttr::get(shaped_type, flat_replica_ids); in CreateGroupAssignmentAttr()
Dcollectives.cc312 auto shaped_type = mlir::RankedTensorType::get( in EmitAllReduce() local
316 mlir::DenseIntElementsAttr::get(shaped_type, partitions_flat); in EmitAllReduce()
420 auto shaped_type = mlir::RankedTensorType::get( in CreateConstSrcTargetPair() local
424 mlir::DenseIntElementsAttr::get(shaped_type, src_target_pair_flat); in CreateConstSrcTargetPair()
/external/tensorflow/tensorflow/compiler/mlir/lite/
Dflatbuffer_import.cc379 mlir::RankedTensorType shaped_type, mlir::FloatType elem_type, in ConvertFloatBuffer() argument
402 return mlir::ElementsAttr(DenseElementsAttr::get(shaped_type, values)); in ConvertFloatBuffer()
419 DenseElementsAttr::get(shaped_type, ArrayRef<float>(values))); in ConvertFloatBuffer()
436 DenseElementsAttr::get(shaped_type, ArrayRef<double>(values))); in ConvertFloatBuffer()
443 mlir::RankedTensorType shaped_type, mlir::Type elem_type, in ConvertIntBuffer() argument
450 shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(), in ConvertIntBuffer()
465 DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values))); in ConvertIntBuffer()
469 DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer))); in ConvertIntBuffer()
474 DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values))); in ConvertIntBuffer()
479 DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values))); in ConvertIntBuffer()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tools/kernel_gen/transforms/
Dshape_simplification.cc197 auto shaped_type = in matchAndRewrite() local
201 if (!shaped_type.hasRank()) return failure(); in matchAndRewrite()
202 if (shaped_type.getRank() <= idx) continue; in matchAndRewrite()
206 if (shaped_type.isDynamicDim(idx)) { in matchAndRewrite()
212 if (shaped_type.getDimSize(idx) == 1) continue; in matchAndRewrite()
216 shaped_type.getDimSize(idx)); in matchAndRewrite()
/external/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/
Dutils.cc44 auto shaped_type = value.getType().dyn_cast<ShapedType>(); in HasStaticShape() local
45 if (!shaped_type) return false; in HasStaticShape()
47 return shaped_type.hasStaticShape(); in HasStaticShape()
51 auto shaped_type = value.getType().dyn_cast<ShapedType>(); in HasStaticShapeAtDims() local
52 if (!shaped_type) return false; in HasStaticShapeAtDims()
55 if (shaped_type.isDynamicDim(dim)) return false; in HasStaticShapeAtDims()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Densure_static_shapes_pass.cc32 if (ShapedType shaped_type = type.dyn_cast<ShapedType>()) { in runOnFunction() local
33 return !shaped_type.hasStaticShape(); in runOnFunction()
Dfreeze_saved_model_assets.cc92 ShapedType shaped_type = in runOnOperation() local
96 DenseStringElementsAttr::get(shaped_type, {filename})); in runOnOperation()
Dinit_text_file_to_import_test_pass.cc82 ShapedType shaped_type = in runOnOperation() local
95 op.setValueAttr(DenseStringElementsAttr::get(shaped_type, {filename})); in runOnOperation()
Dreadonly_references_to_resources.cc150 ShapedType shaped_type = in runOnOperation() local
152 TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>(); in runOnOperation()
/external/tensorflow/tensorflow/compiler/mlir/lite/experimental/tac/common/
Dutils.h49 if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) { in IsF32ShapedType() local
50 return shaped_type.getElementType().isF32(); in IsF32ShapedType()
/external/tensorflow/tensorflow/compiler/mlir/lite/utils/
Dconvert_type.cc194 auto shaped_type = type.dyn_cast<mlir::ShapedType>(); in GetShapeStrippedType() local
195 if (shaped_type) { in GetShapeStrippedType()
196 return shaped_type.getElementType(); in GetShapeStrippedType()
Dconstant_utils.cc32 PatternRewriter* rewriter, Location loc, ShapedType shaped_type, in CreateConstOpWithSingleValue() argument
34 Type element_type = shaped_type.getElementType(); in CreateConstOpWithSingleValue()
Dconstant_utils.h32 PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/
Dlegalize_tf.cc100 auto shaped_type = value.getType().dyn_cast<ShapedType>(); in HasSameStaticShapes() local
101 if (!shaped_type || !shaped_type.hasStaticShape()) { in HasSameStaticShapes()
105 shape = shaped_type.getShape(); in HasSameStaticShapes()
107 if (shape != shaped_type.getShape()) { in HasSameStaticShapes()
119 if (auto shaped_type = val.getType().dyn_cast<RankedTensorType>()) { in CreateCastToInt32() local
121 RankedTensorType::get(shaped_type.getShape(), new_ele_type); in CreateCastToInt32()
132 auto shaped_type = input.getType().cast<ShapedType>(); in GetShape() local
133 if (shaped_type.hasStaticShape()) { in GetShape()
134 auto static_shape = shaped_type.getShape(); in GetShape()
Dlower_static_tensor_list.cc481 if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) { in matchAndRewrite() local
482 if (shaped_type.hasRank() && shaped_type.getRank() == 0) { in matchAndRewrite()
503 if (auto shaped_type = in matchAndRewrite() local
505 if (shaped_type.hasStaticShape()) { in matchAndRewrite()
507 {shaped_type.getRank()}, rewriter.getIntegerType(32)); in matchAndRewrite()
509 for (int64_t dim : shaped_type.getShape()) { in matchAndRewrite()
Doptimize.cc316 auto shaped_type = type.dyn_cast<ShapedType>(); in GetShapeStrippedType() local
317 if (shaped_type) { in GetShapeStrippedType()
318 return shaped_type.getElementType(); in GetShapeStrippedType()
1257 if (auto shaped_type = t.dyn_cast<ShapedType>()) { in CanFuseAffineOp() local
1258 element_type = shaped_type.getElementType(); in CanFuseAffineOp()
/external/tensorflow/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/
Dcompile_and_run_test.py132 shaped_type = ir.ShapedType(static_type)
134 shaped_type.element_type)
136 -10000.0, 10000.0, size=shaped_type.shape).astype(np_element_type)
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dexport_utils.cc230 if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) { in ConvertAttribute() local
231 elt_type = shaped_type.getElementType(); in ConvertAttribute()
443 Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, in SetShapeAttribute() argument
446 SetTensorShapeProto(shaped_type, value.mutable_list()->add_shape()); in SetShapeAttribute()
/external/tensorflow/tensorflow/core/ir/types/
Ddialect.cc441 ShapeAttr ShapeAttr::get(MLIRContext *context, ShapedType shaped_type) { in get() argument
442 if (shaped_type.hasRank()) in get()
443 return Base::get(context, shaped_type.getShape(), /*unranked=*/false); in get()
474 auto shaped_type = value.getType().cast<ShapedType>(); in GetShape() local
475 if (shaped_type.hasRank()) return shaped_type.getShape(); in GetShape()
/external/tensorflow/tensorflow/core/ir/importexport/
Dconvert_attributes.cc181 if (auto shaped_type = elt_type.dyn_cast<ShapedType>()) { in ConvertAttribute() local
182 elt_type = shaped_type.getElementType(); in ConvertAttribute()
280 Status SetShapeAttribute(absl::string_view name, ShapedType shaped_type, in SetShapeAttribute() argument
283 SetTensorShapeProto(shaped_type, value.mutable_shape()); in SetShapeAttribute()
Dconvert_attributes.h45 ShapedType shaped_type,
/external/tensorflow/tensorflow/compiler/mlir/lite/ir/
Dtfl_ops.cc165 ShapedType shaped_type = in VerifyOperandsHaveSameShapesOrBroadcastableShape() local
167 if (!shaped_type || !shaped_type.hasRank()) { in VerifyOperandsHaveSameShapesOrBroadcastableShape()
172 max_rank = std::max(max_rank, shaped_type.getRank()); in VerifyOperandsHaveSameShapesOrBroadcastableShape()
173 if (!shaped_type.hasStaticShape()) { in VerifyOperandsHaveSameShapesOrBroadcastableShape()
179 ArrayRef<int64_t> shape = shaped_type.getShape(); in VerifyOperandsHaveSameShapesOrBroadcastableShape()
468 if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) { in IsF32ShapedType() local
469 return shaped_type.getElementType().isF32(); in IsF32ShapedType()
476 if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) { in IsBF16ShapedType() local
477 return shaped_type.getElementType().isBF16(); in IsBF16ShapedType()
670 if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) { in buildComparisonBinOp() local
[all …]

12