1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
17
18 #include <vector>
19
20 #include "absl/strings/str_cat.h"
21 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
22 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "mlir/IR/Attributes.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
30 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
35 #include "tensorflow/lite/schema/schema_generated.h"
36
37 namespace {
38
39 using ::tensorflow::Status;
40 using ::tensorflow::errors::InvalidArgument;
41 using ::xla::StatusOr;
42
GetPaddingAttr(TfLitePadding pad_params,mlir::Builder builder,mlir::Location loc)43 StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
44 mlir::Builder builder,
45 mlir::Location loc) {
46 auto padding = tflite::Padding::Padding_VALID;
47 if (pad_params == TfLitePadding::kTfLitePaddingSame) {
48 padding = tflite::Padding_SAME;
49 } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
50 padding = tflite::Padding_VALID;
51 } else {
52 return InvalidArgument(
53 absl::StrCat("Invalid padding type", std::to_string(pad_params)));
54 }
55
56 const char* option_name = tflite::EnumNamePadding(padding);
57 return builder.getStringAttr(option_name);
58 }
59
60 } // namespace
61
62 // TODO(jpienaar): This is a placeholder. This should be done in more efficient
63 // way when part of the translation of module.
ConvertTFL_AFAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)64 static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
65 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
66 return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
67 .Case("NONE", tflite::ActivationFunctionType_NONE)
68 .Case("RELU", tflite::ActivationFunctionType_RELU)
69 .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
70 .Case("RELU6", tflite::ActivationFunctionType_RELU6)
71 .Case("TANH", tflite::ActivationFunctionType_TANH)
72 .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
73 }
74
ConvertDerivedTFLiteTypeAttrForOptionWriter(tflite::TensorType type,flatbuffers::FlatBufferBuilder * builder)75 static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
76 tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
77 if (type == tflite::TensorType_INT64) {
78 return tflite::TensorType_INT64;
79 } else if (type == tflite::TensorType_INT32) {
80 return tflite::TensorType_INT32;
81 }
82 llvm_unreachable("invalid type in conversion.");
83 }
84
ConvertTFL_PaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)85 static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
86 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
87 return llvm::StringSwitch<tflite::Padding>(str)
88 .Case("SAME", tflite::Padding_SAME)
89 .Case("VALID", tflite::Padding_VALID);
90 }
91
ConvertTFL_MirrorPaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)92 static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
93 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
94 return llvm::StringSwitch<tflite::MirrorPadMode>(str)
95 .Case("REFLECT", tflite::MirrorPadMode_REFLECT)
96 .Case("SYMMETRIC", tflite::MirrorPadMode_SYMMETRIC);
97 }
98
ConvertDerivedTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)99 static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
100 mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
101 return tflite::ConvertTypeToTensorType(type);
102 }
103
104 // I32Attr already returns an int as required by flatbuffer builders.
ConvertI32AttrForOptionWriter(int i,flatbuffers::FlatBufferBuilder * builder)105 static int ConvertI32AttrForOptionWriter(
106 int i, flatbuffers::FlatBufferBuilder* builder) {
107 return i;
108 }
109
// Option writer for positive I32 attributes. Positivity is not checked here;
// like ConvertI32AttrForOptionWriter, the value is forwarded unchanged.
// `builder` is unused.
static int ConvertPositiveI32AttrForOptionWriter(
    int i, flatbuffers::FlatBufferBuilder* builder) {
  return i;
}
114
115 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)116 ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
117 flatbuffers::FlatBufferBuilder* builder) {
118 std::vector<int32_t> intVec;
119 intVec.reserve(attrArray.getValue().size());
120 for (auto attr : attrArray.getValue()) {
121 intVec.push_back(attr.cast<mlir::IntegerAttr>().getInt());
122 }
123 return builder->CreateVector(intVec);
124 }
125
126 // F32Attr already returns a float as required by flatbuffer builders.
ConvertF32AttrForOptionWriter(llvm::APFloat f,flatbuffers::FlatBufferBuilder * builder)127 static float ConvertF32AttrForOptionWriter(
128 llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
129 return f.convertToFloat();
130 }
131
132 // BoolAttr already returns a bool as required by flatbuffer builders.
ConvertBoolAttrForOptionWriter(bool b,flatbuffers::FlatBufferBuilder * builder)133 static bool ConvertBoolAttrForOptionWriter(
134 bool b, flatbuffers::FlatBufferBuilder* builder) {
135 return b;
136 }
137
ConvertStrAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)138 static flatbuffers::Offset<flatbuffers::String> ConvertStrAttrForOptionWriter(
139 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
140 return builder->CreateString(str.str());
141 }
142
ConvertTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)143 static tflite::TensorType ConvertTypeAttrForOptionWriter(
144 mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
145 return tflite::ConvertTypeToTensorType(type);
146 }
147
148 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(llvm::ArrayRef<int64_t> r,flatbuffers::FlatBufferBuilder * builder)149 ConvertDerivedShapeAttrForOptionWriter(
150 llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
151 std::vector<int> intVec(r.begin(), r.end());
152 return builder->CreateVector(intVec);
153 }
154
155 static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)156 ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
157 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
158 return llvm::StringSwitch<tflite::FullyConnectedOptionsWeightsFormat>(str)
159 .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
160 .Case("SHUFFLED4x16INT8",
161 tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
162 }
163
ConvertTFL_LSTMKernelTypeAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)164 static tflite::LSTMKernelType ConvertTFL_LSTMKernelTypeAttrForOptionWriter(
165 llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
166 return llvm::StringSwitch<tflite::LSTMKernelType>(str)
167 .Case("FULL", tflite::LSTMKernelType_FULL)
168 .Case("BASIC", tflite::LSTMKernelType_BASIC);
169 }
170
BuildBoolAttr(bool value,mlir::Builder builder)171 static mlir::Attribute BuildBoolAttr(bool value, mlir::Builder builder) {
172 return builder.getBoolAttr(value);
173 }
174
BuildStrAttr(llvm::StringRef str,mlir::Builder builder)175 static mlir::Attribute BuildStrAttr(llvm::StringRef str,
176 mlir::Builder builder) {
177 return builder.getStringAttr(str);
178 }
179
BuildF32Attr(float value,mlir::Builder builder)180 static mlir::Attribute BuildF32Attr(float value, mlir::Builder builder) {
181 return builder.getF32FloatAttr(value);
182 }
183
BuildI32Attr(int32_t value,mlir::Builder builder)184 static mlir::Attribute BuildI32Attr(int32_t value, mlir::Builder builder) {
185 return builder.getI32IntegerAttr(value);
186 }
187
BuildI64ArrayAttr(std::vector<int32_t> value,mlir::Builder builder)188 static mlir::Attribute BuildI64ArrayAttr(std::vector<int32_t> value,
189 mlir::Builder builder) {
190 std::vector<int64_t> typecast(value.begin(), value.end());
191 return builder.getI64ArrayAttr(typecast);
192 }
193
BuildPositiveI32Attr(int32_t value,mlir::Builder builder)194 static mlir::Attribute BuildPositiveI32Attr(int32_t value,
195 mlir::Builder builder) {
196 return builder.getI32IntegerAttr(value);
197 }
198
BuildTypeAttr(tflite::TensorType value,mlir::Builder builder)199 static mlir::Attribute BuildTypeAttr(tflite::TensorType value,
200 mlir::Builder builder) {
201 return mlir::TypeAttr::get(ConvertElementType(value, builder));
202 }
203
BuildTFL_AFAttr(tflite::ActivationFunctionType value,mlir::Builder builder)204 static mlir::Attribute BuildTFL_AFAttr(tflite::ActivationFunctionType value,
205 mlir::Builder builder) {
206 const char* option_name = tflite::EnumNameActivationFunctionType(value);
207 return builder.getStringAttr(option_name);
208 }
209
BuildTFL_FullyConnectedOptionsWeightFormatAttr(tflite::FullyConnectedOptionsWeightsFormat value,mlir::Builder builder)210 static mlir::Attribute BuildTFL_FullyConnectedOptionsWeightFormatAttr(
211 tflite::FullyConnectedOptionsWeightsFormat value, mlir::Builder builder) {
212 const char* option_name =
213 tflite::EnumNameFullyConnectedOptionsWeightsFormat(value);
214 return builder.getStringAttr(option_name);
215 }
216
BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,mlir::Builder builder)217 static mlir::Attribute BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,
218 mlir::Builder builder) {
219 const char* option_name = tflite::EnumNameLSTMKernelType(value);
220 return builder.getStringAttr(option_name);
221 }
222
BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,mlir::Builder builder)223 static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,
224 mlir::Builder builder) {
225 const char* option_name = tflite::EnumNameMirrorPadMode(value);
226 return builder.getStringAttr(option_name);
227 }
228
BuildTFL_PaddingAttr(tflite::Padding value,mlir::Builder builder)229 static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
230 mlir::Builder builder) {
231 const char* option_name = tflite::EnumNamePadding(value);
232 return builder.getStringAttr(option_name);
233 }
234
CustomOptionsToAttributes(const std::string & custom_code,const std::vector<uint8_t> & custom_options,mlir::Builder builder,mlir::Location loc,llvm::SmallVectorImpl<mlir::NamedAttribute> * attributes)235 Status mlir::CustomOptionsToAttributes(
236 const std::string& custom_code, const std::vector<uint8_t>& custom_options,
237 mlir::Builder builder, mlir::Location loc,
238 llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
239 attributes->emplace_back(
240 builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code)));
241 std::string content;
242 content.assign(reinterpret_cast<const char*>(custom_options.data()),
243 custom_options.size());
244 ShapedType type = RankedTensorType::get(
245 {static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
246 attributes->emplace_back(builder.getNamedAttr(
247 "custom_option",
248 OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"),
249 type, content)));
250
251 return Status::OK();
252 }
253
254 // Pull in FlatBuffer writers for TFLite generated using TableGen
255 #include "tensorflow/compiler/mlir/lite/operator_converters.inc"
256