1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
17
18 #include <vector>
19
20 #include "absl/strings/str_cat.h"
21 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/core/platform/errors.h"
31 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
32 #include "tensorflow/lite/schema/schema_generated.h"
33
34 namespace {
35
36 using ::tensorflow::Status;
37 using ::tensorflow::errors::InvalidArgument;
38 using ::xla::StatusOr;
39
GetPaddingAttr(TfLitePadding pad_params,mlir::Builder builder,mlir::Location loc)40 StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
41 mlir::Builder builder,
42 mlir::Location loc) {
43 auto padding = tflite::Padding::Padding_VALID;
44 if (pad_params == TfLitePadding::kTfLitePaddingSame) {
45 padding = tflite::Padding_SAME;
46 } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
47 padding = tflite::Padding_VALID;
48 } else {
49 return InvalidArgument(
50 absl::StrCat("Invalid padding type", std::to_string(pad_params)));
51 }
52
53 const char* option_name = tflite::EnumNamePadding(padding);
54 return builder.getStringAttr(option_name);
55 }
56
57 } // namespace
58
59 // TODO(jpienaar): This is a placeholder. This should be done in more efficient
60 // way when part of the translation of module.
ConvertTFL_AFAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)61 static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
62 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
63 return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
64 .Case("NONE", tflite::ActivationFunctionType_NONE)
65 .Case("RELU", tflite::ActivationFunctionType_RELU)
66 .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
67 .Case("RELU6", tflite::ActivationFunctionType_RELU6)
68 .Case("TANH", tflite::ActivationFunctionType_TANH)
69 .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
70 }
71
ConvertDerivedTFLiteTypeAttrForOptionWriter(tflite::TensorType type,flatbuffers::FlatBufferBuilder * builder)72 static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
73 tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
74 if (type == tflite::TensorType_INT64) {
75 return tflite::TensorType_INT64;
76 } else if (type == tflite::TensorType_INT32) {
77 return tflite::TensorType_INT32;
78 }
79 llvm_unreachable("invalid type in conversion.");
80 }
81
ConvertTFL_PaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)82 static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
83 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
84 return llvm::StringSwitch<tflite::Padding>(str)
85 .Case("SAME", tflite::Padding_SAME)
86 .Case("VALID", tflite::Padding_VALID);
87 }
88
ConvertTFL_MirrorPaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)89 static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
90 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
91 return llvm::StringSwitch<tflite::MirrorPadMode>(str)
92 .Case("REFLECT", tflite::MirrorPadMode_REFLECT)
93 .Case("SYMMETRIC", tflite::MirrorPadMode_SYMMETRIC);
94 }
95
ConvertDerivedTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)96 static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
97 mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
98 if (type.isF16()) {
99 return tflite::TensorType_FLOAT16;
100 } else if (type.isF32()) {
101 return tflite::TensorType_FLOAT32;
102 } else if (type.isa<mlir::TF::StringType>()) {
103 return tflite::TensorType_STRING;
104 } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
105 if (complex_type.getElementType().isF32()) {
106 return tflite::TensorType_COMPLEX64;
107 }
108 llvm_unreachable("invalid complex Type in conversion");
109 } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
110 switch (itype.getWidth()) {
111 case 1:
112 return tflite::TensorType_BOOL;
113 case 8:
114 return tflite::TensorType_INT8;
115 case 16:
116 return tflite::TensorType_INT16;
117 case 32:
118 return tflite::TensorType_INT32;
119 case 64:
120 return tflite::TensorType_INT64;
121 default:
122 llvm_unreachable("invalid integer Type in conversion");
123 }
124 }
125 llvm_unreachable("invalid Type in conversion");
126 }
127
128 // I32Attr already returns an int as required by flatbuffer builders.
ConvertI32AttrForOptionWriter(int i,flatbuffers::FlatBufferBuilder * builder)129 static int ConvertI32AttrForOptionWriter(
130 int i, flatbuffers::FlatBufferBuilder* builder) {
131 return i;
132 }
133
// Positive I32 attributes are written exactly like plain I32 attributes: the
// value is forwarded unchanged. No positivity check is performed here.
// `builder` is unused.
static int ConvertPositiveI32AttrForOptionWriter(
    int i, flatbuffers::FlatBufferBuilder* builder) {
  return i;
}
138
139 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)140 ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
141 flatbuffers::FlatBufferBuilder* builder) {
142 std::vector<int32_t> intVec;
143 intVec.reserve(attrArray.getValue().size());
144 for (auto attr : attrArray.getValue()) {
145 intVec.push_back(attr.cast<mlir::IntegerAttr>().getInt());
146 }
147 return builder->CreateVector(intVec);
148 }
149
150 // F32Attr already returns a float as required by flatbuffer builders.
ConvertF32AttrForOptionWriter(llvm::APFloat f,flatbuffers::FlatBufferBuilder * builder)151 static float ConvertF32AttrForOptionWriter(
152 llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
153 return f.convertToFloat();
154 }
155
156 // BoolAttr already returns a bool as required by flatbuffer builders.
ConvertBoolAttrForOptionWriter(bool b,flatbuffers::FlatBufferBuilder * builder)157 static bool ConvertBoolAttrForOptionWriter(
158 bool b, flatbuffers::FlatBufferBuilder* builder) {
159 return b;
160 }
161
162 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(llvm::ArrayRef<int64_t> r,flatbuffers::FlatBufferBuilder * builder)163 ConvertDerivedShapeAttrForOptionWriter(
164 llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
165 std::vector<int> intVec(r.begin(), r.end());
166 return builder->CreateVector(intVec);
167 }
168
169 static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)170 ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
171 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
172 return llvm::StringSwitch<tflite::FullyConnectedOptionsWeightsFormat>(str)
173 .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
174 .Case("SHUFFLED4x16INT8",
175 tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
176 }
177
ConvertTFL_LSTMKernelTypeAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)178 static tflite::LSTMKernelType ConvertTFL_LSTMKernelTypeAttrForOptionWriter(
179 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
180 return llvm::StringSwitch<tflite::LSTMKernelType>(str)
181 .Case("FULL", tflite::LSTMKernelType_FULL)
182 .Case("BASIC", tflite::LSTMKernelType_BASIC);
183 }
184
BuildBoolAttr(bool value,mlir::Builder builder)185 static mlir::Attribute BuildBoolAttr(bool value, mlir::Builder builder) {
186 return builder.getBoolAttr(value);
187 }
188
BuildF32Attr(float value,mlir::Builder builder)189 static mlir::Attribute BuildF32Attr(float value, mlir::Builder builder) {
190 return builder.getF32FloatAttr(value);
191 }
192
BuildI32Attr(int32_t value,mlir::Builder builder)193 static mlir::Attribute BuildI32Attr(int32_t value, mlir::Builder builder) {
194 return builder.getI32IntegerAttr(value);
195 }
196
BuildI64ArrayAttr(std::vector<int32_t> value,mlir::Builder builder)197 static mlir::Attribute BuildI64ArrayAttr(std::vector<int32_t> value,
198 mlir::Builder builder) {
199 std::vector<int64_t> typecast(value.begin(), value.end());
200 return builder.getI64ArrayAttr(typecast);
201 }
202
BuildPositiveI32Attr(int32_t value,mlir::Builder builder)203 static mlir::Attribute BuildPositiveI32Attr(int32_t value,
204 mlir::Builder builder) {
205 return builder.getI32IntegerAttr(value);
206 }
207
BuildTFL_AFAttr(tflite::ActivationFunctionType value,mlir::Builder builder)208 static mlir::Attribute BuildTFL_AFAttr(tflite::ActivationFunctionType value,
209 mlir::Builder builder) {
210 const char* option_name = tflite::EnumNameActivationFunctionType(value);
211 return builder.getStringAttr(option_name);
212 }
213
BuildTFL_FullyConnectedOptionsWeightFormatAttr(tflite::FullyConnectedOptionsWeightsFormat value,mlir::Builder builder)214 static mlir::Attribute BuildTFL_FullyConnectedOptionsWeightFormatAttr(
215 tflite::FullyConnectedOptionsWeightsFormat value, mlir::Builder builder) {
216 const char* option_name =
217 tflite::EnumNameFullyConnectedOptionsWeightsFormat(value);
218 return builder.getStringAttr(option_name);
219 }
220
BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,mlir::Builder builder)221 static mlir::Attribute BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,
222 mlir::Builder builder) {
223 const char* option_name = tflite::EnumNameLSTMKernelType(value);
224 return builder.getStringAttr(option_name);
225 }
226
BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,mlir::Builder builder)227 static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,
228 mlir::Builder builder) {
229 const char* option_name = tflite::EnumNameMirrorPadMode(value);
230 return builder.getStringAttr(option_name);
231 }
232
BuildTFL_PaddingAttr(tflite::Padding value,mlir::Builder builder)233 static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
234 mlir::Builder builder) {
235 const char* option_name = tflite::EnumNamePadding(value);
236 return builder.getStringAttr(option_name);
237 }
238
CustomOptionsToAttributes(const std::string & custom_code,const std::vector<uint8_t> & custom_options,mlir::Builder builder,mlir::Location loc,llvm::SmallVectorImpl<mlir::NamedAttribute> * attributes)239 Status mlir::CustomOptionsToAttributes(
240 const std::string& custom_code, const std::vector<uint8_t>& custom_options,
241 mlir::Builder builder, mlir::Location loc,
242 llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
243 attributes->emplace_back(
244 builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code)));
245 std::string content;
246 content.assign(reinterpret_cast<const char*>(custom_options.data()),
247 custom_options.size());
248 ShapedType type = RankedTensorType::get(
249 {static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
250 attributes->emplace_back(builder.getNamedAttr(
251 "custom_option",
252 OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"),
253 type, content)));
254
255 return Status::OK();
256 }
257
258 // Pull in FlatBuffer writers for TFLite generated using TableGen
259 #include "tensorflow/compiler/mlir/lite/operator_converters.inc"
260