• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
17 
18 #include <vector>
19 
20 #include "absl/strings/str_cat.h"
21 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/Builders.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/core/platform/errors.h"
31 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
32 #include "tensorflow/lite/schema/schema_generated.h"
33 
34 namespace {
35 
36 using ::tensorflow::Status;
37 using ::tensorflow::errors::InvalidArgument;
38 using ::xla::StatusOr;
39 
GetPaddingAttr(TfLitePadding pad_params,mlir::Builder builder,mlir::Location loc)40 StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
41                                           mlir::Builder builder,
42                                           mlir::Location loc) {
43   auto padding = tflite::Padding::Padding_VALID;
44   if (pad_params == TfLitePadding::kTfLitePaddingSame) {
45     padding = tflite::Padding_SAME;
46   } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
47     padding = tflite::Padding_VALID;
48   } else {
49     return InvalidArgument(
50         absl::StrCat("Invalid padding type", std::to_string(pad_params)));
51   }
52 
53   const char* option_name = tflite::EnumNamePadding(padding);
54   return builder.getStringAttr(option_name);
55 }
56 
57 }  // namespace
58 
59 // TODO(jpienaar): This is a placeholder. This should be done in more efficient
60 // way when part of the translation of module.
ConvertTFL_AFAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)61 static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
62     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
63   return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
64       .Case("NONE", tflite::ActivationFunctionType_NONE)
65       .Case("RELU", tflite::ActivationFunctionType_RELU)
66       .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
67       .Case("RELU6", tflite::ActivationFunctionType_RELU6)
68       .Case("TANH", tflite::ActivationFunctionType_TANH)
69       .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
70 }
71 
ConvertDerivedTFLiteTypeAttrForOptionWriter(tflite::TensorType type,flatbuffers::FlatBufferBuilder * builder)72 static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
73     tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
74   if (type == tflite::TensorType_INT64) {
75     return tflite::TensorType_INT64;
76   } else if (type == tflite::TensorType_INT32) {
77     return tflite::TensorType_INT32;
78   }
79   llvm_unreachable("invalid type in conversion.");
80 }
81 
ConvertTFL_PaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)82 static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
83     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
84   return llvm::StringSwitch<tflite::Padding>(str)
85       .Case("SAME", tflite::Padding_SAME)
86       .Case("VALID", tflite::Padding_VALID);
87 }
88 
ConvertTFL_MirrorPaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)89 static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
90     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
91   return llvm::StringSwitch<tflite::MirrorPadMode>(str)
92       .Case("REFLECT", tflite::MirrorPadMode_REFLECT)
93       .Case("SYMMETRIC", tflite::MirrorPadMode_SYMMETRIC);
94 }
95 
ConvertDerivedTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)96 static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
97     mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
98   if (type.isF16()) {
99     return tflite::TensorType_FLOAT16;
100   } else if (type.isF32()) {
101     return tflite::TensorType_FLOAT32;
102   } else if (type.isa<mlir::TF::StringType>()) {
103     return tflite::TensorType_STRING;
104   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
105     if (complex_type.getElementType().isF32()) {
106       return tflite::TensorType_COMPLEX64;
107     }
108     llvm_unreachable("invalid complex Type in conversion");
109   } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
110     switch (itype.getWidth()) {
111       case 1:
112         return tflite::TensorType_BOOL;
113       case 8:
114         return tflite::TensorType_INT8;
115       case 16:
116         return tflite::TensorType_INT16;
117       case 32:
118         return tflite::TensorType_INT32;
119       case 64:
120         return tflite::TensorType_INT64;
121       default:
122         llvm_unreachable("invalid integer Type in conversion");
123     }
124   }
125   llvm_unreachable("invalid Type in conversion");
126 }
127 
128 // I32Attr already returns an int as required by flatbuffer builders.
ConvertI32AttrForOptionWriter(int i,flatbuffers::FlatBufferBuilder * builder)129 static int ConvertI32AttrForOptionWriter(
130     int i, flatbuffers::FlatBufferBuilder* builder) {
131   return i;
132 }
133 
// Forwards a PositiveI32Attr value unchanged, exactly like the plain I32Attr
// writer. This function does not itself check that `i` is positive.
// `builder` is unused.
static int ConvertPositiveI32AttrForOptionWriter(
    int i, flatbuffers::FlatBufferBuilder* builder) {
  return i;
}
138 
139 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)140 ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
141                                    flatbuffers::FlatBufferBuilder* builder) {
142   std::vector<int32_t> intVec;
143   intVec.reserve(attrArray.getValue().size());
144   for (auto attr : attrArray.getValue()) {
145     intVec.push_back(attr.cast<mlir::IntegerAttr>().getInt());
146   }
147   return builder->CreateVector(intVec);
148 }
149 
150 // F32Attr already returns a float as required by flatbuffer builders.
ConvertF32AttrForOptionWriter(llvm::APFloat f,flatbuffers::FlatBufferBuilder * builder)151 static float ConvertF32AttrForOptionWriter(
152     llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
153   return f.convertToFloat();
154 }
155 
156 // BoolAttr already returns a bool as required by flatbuffer builders.
ConvertBoolAttrForOptionWriter(bool b,flatbuffers::FlatBufferBuilder * builder)157 static bool ConvertBoolAttrForOptionWriter(
158     bool b, flatbuffers::FlatBufferBuilder* builder) {
159   return b;
160 }
161 
162 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(llvm::ArrayRef<int64_t> r,flatbuffers::FlatBufferBuilder * builder)163 ConvertDerivedShapeAttrForOptionWriter(
164     llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
165   std::vector<int> intVec(r.begin(), r.end());
166   return builder->CreateVector(intVec);
167 }
168 
169 static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)170 ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
171     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
172   return llvm::StringSwitch<tflite::FullyConnectedOptionsWeightsFormat>(str)
173       .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
174       .Case("SHUFFLED4x16INT8",
175             tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
176 }
177 
ConvertTFL_LSTMKernelTypeAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)178 static tflite::LSTMKernelType ConvertTFL_LSTMKernelTypeAttrForOptionWriter(
179     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
180   return llvm::StringSwitch<tflite::LSTMKernelType>(str)
181       .Case("FULL", tflite::LSTMKernelType_FULL)
182       .Case("BASIC", tflite::LSTMKernelType_BASIC);
183 }
184 
BuildBoolAttr(bool value,mlir::Builder builder)185 static mlir::Attribute BuildBoolAttr(bool value, mlir::Builder builder) {
186   return builder.getBoolAttr(value);
187 }
188 
BuildF32Attr(float value,mlir::Builder builder)189 static mlir::Attribute BuildF32Attr(float value, mlir::Builder builder) {
190   return builder.getF32FloatAttr(value);
191 }
192 
BuildI32Attr(int32_t value,mlir::Builder builder)193 static mlir::Attribute BuildI32Attr(int32_t value, mlir::Builder builder) {
194   return builder.getI32IntegerAttr(value);
195 }
196 
BuildI64ArrayAttr(std::vector<int32_t> value,mlir::Builder builder)197 static mlir::Attribute BuildI64ArrayAttr(std::vector<int32_t> value,
198                                          mlir::Builder builder) {
199   std::vector<int64_t> typecast(value.begin(), value.end());
200   return builder.getI64ArrayAttr(typecast);
201 }
202 
BuildPositiveI32Attr(int32_t value,mlir::Builder builder)203 static mlir::Attribute BuildPositiveI32Attr(int32_t value,
204                                             mlir::Builder builder) {
205   return builder.getI32IntegerAttr(value);
206 }
207 
BuildTFL_AFAttr(tflite::ActivationFunctionType value,mlir::Builder builder)208 static mlir::Attribute BuildTFL_AFAttr(tflite::ActivationFunctionType value,
209                                        mlir::Builder builder) {
210   const char* option_name = tflite::EnumNameActivationFunctionType(value);
211   return builder.getStringAttr(option_name);
212 }
213 
BuildTFL_FullyConnectedOptionsWeightFormatAttr(tflite::FullyConnectedOptionsWeightsFormat value,mlir::Builder builder)214 static mlir::Attribute BuildTFL_FullyConnectedOptionsWeightFormatAttr(
215     tflite::FullyConnectedOptionsWeightsFormat value, mlir::Builder builder) {
216   const char* option_name =
217       tflite::EnumNameFullyConnectedOptionsWeightsFormat(value);
218   return builder.getStringAttr(option_name);
219 }
220 
BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,mlir::Builder builder)221 static mlir::Attribute BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,
222                                                    mlir::Builder builder) {
223   const char* option_name = tflite::EnumNameLSTMKernelType(value);
224   return builder.getStringAttr(option_name);
225 }
226 
BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,mlir::Builder builder)227 static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,
228                                                   mlir::Builder builder) {
229   const char* option_name = tflite::EnumNameMirrorPadMode(value);
230   return builder.getStringAttr(option_name);
231 }
232 
BuildTFL_PaddingAttr(tflite::Padding value,mlir::Builder builder)233 static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
234                                             mlir::Builder builder) {
235   const char* option_name = tflite::EnumNamePadding(value);
236   return builder.getStringAttr(option_name);
237 }
238 
CustomOptionsToAttributes(const std::string & custom_code,const std::vector<uint8_t> & custom_options,mlir::Builder builder,mlir::Location loc,llvm::SmallVectorImpl<mlir::NamedAttribute> * attributes)239 Status mlir::CustomOptionsToAttributes(
240     const std::string& custom_code, const std::vector<uint8_t>& custom_options,
241     mlir::Builder builder, mlir::Location loc,
242     llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
243   attributes->emplace_back(
244       builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code)));
245   std::string content;
246   content.assign(reinterpret_cast<const char*>(custom_options.data()),
247                  custom_options.size());
248   ShapedType type = RankedTensorType::get(
249       {static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
250   attributes->emplace_back(builder.getNamedAttr(
251       "custom_option",
252       OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"),
253                               type, content)));
254 
255   return Status::OK();
256 }
257 
258 // Pull in FlatBuffer writers for TFLite generated using TableGen
259 #include "tensorflow/compiler/mlir/lite/operator_converters.inc"
260