1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
17
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/platform/status.h"
23
24 namespace mlir {
25 namespace TFL {
26
CreateConstOpWithSingleValue(PatternRewriter * rewriter,Location loc,ShapedType shaped_type,int value)27 stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
28 PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
29 int value) {
30 Type element_type = shaped_type.getElementType();
31 ShapedType scalar_type = RankedTensorType::get({}, element_type);
32 Attribute attr;
33 if (element_type.isF16()) {
34 auto floatType = mlir::FloatType::getF16(element_type.getContext());
35 auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
36 std::vector<Attribute> floatValues({floatAttr});
37 attr = DenseElementsAttr::get(scalar_type, floatValues);
38 } else if (element_type.isBF16()) {
39 auto floatType = mlir::FloatType::getBF16(element_type.getContext());
40 auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
41 std::vector<Attribute> floatValues({floatAttr});
42 attr = DenseElementsAttr::get(scalar_type, floatValues);
43 } else if (element_type.isF32()) {
44 attr =
45 DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
46 } else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
47 auto etype = complex_type.getElementType();
48 if (etype.isF32()) {
49 auto dialect = etype.getContext()->getLoadedDialect("tf");
50 tensorflow::TensorProto repr;
51 repr.set_dtype(tensorflow::DT_COMPLEX64);
52
53 tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
54 shape->set_unknown_rank(false);
55 shape->add_dim()->set_size(int64_t{1});
56 std::string content;
57 auto complex_value = std::complex<float>(static_cast<float>(value), 0.0f);
58 content.assign(reinterpret_cast<const char*>(&complex_value),
59 sizeof(complex_value));
60 repr.set_tensor_content(content);
61 std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
62
63 attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
64 } else {
65 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
66 "Unsupported type");
67 }
68 } else if (auto itype = element_type.dyn_cast<mlir::IntegerType>()) {
69 if (element_type.isSignedInteger()) {
70 switch (itype.getWidth()) {
71 case 8:
72 attr = DenseElementsAttr::get<int8_t>(scalar_type,
73 static_cast<int8_t>(value));
74 break;
75 case 16:
76 attr = DenseElementsAttr::get<int16_t>(scalar_type,
77 static_cast<int16_t>(value));
78 break;
79 case 32:
80 attr = DenseElementsAttr::get<int32_t>(scalar_type,
81 static_cast<int32_t>(value));
82 break;
83 case 64:
84 attr = DenseElementsAttr::get<int64_t>(scalar_type,
85 static_cast<int64_t>(value));
86 break;
87 default:
88 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
89 "Unsupported type");
90 }
91 } else {
92 switch (itype.getWidth()) {
93 case 8:
94 attr = DenseElementsAttr::get<uint8_t>(scalar_type,
95 static_cast<uint8_t>(value));
96 break;
97 case 16:
98 attr = DenseElementsAttr::get<uint16_t>(scalar_type,
99 static_cast<uint16_t>(value));
100 break;
101 case 32:
102 attr = DenseElementsAttr::get<uint32_t>(scalar_type,
103 static_cast<uint32_t>(value));
104 break;
105 case 64:
106 attr = DenseElementsAttr::get<uint64_t>(scalar_type,
107 static_cast<uint64_t>(value));
108 break;
109 default:
110 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
111 "Unsupported type");
112 }
113 }
114 } else {
115 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
116 "Unsupported type");
117 }
118 return rewriter->create<ConstantOp>(loc, scalar_type, attr);
119 }
120
121 } // namespace TFL
122 } // namespace mlir
123