• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
17 
18 #include <vector>
19 
20 #include "absl/strings/str_cat.h"
21 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
22 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
30 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
35 #include "tensorflow/lite/schema/schema_generated.h"
36 
37 namespace {
38 
39 using ::tensorflow::Status;
40 using ::tensorflow::errors::InvalidArgument;
41 using ::xla::StatusOr;
42 
GetPaddingAttr(TfLitePadding pad_params,mlir::Builder builder,mlir::Location loc)43 StatusOr<mlir::StringAttr> GetPaddingAttr(TfLitePadding pad_params,
44                                           mlir::Builder builder,
45                                           mlir::Location loc) {
46   auto padding = tflite::Padding::Padding_VALID;
47   if (pad_params == TfLitePadding::kTfLitePaddingSame) {
48     padding = tflite::Padding_SAME;
49   } else if (pad_params == TfLitePadding::kTfLitePaddingValid) {
50     padding = tflite::Padding_VALID;
51   } else {
52     return InvalidArgument(
53         absl::StrCat("Invalid padding type", std::to_string(pad_params)));
54   }
55 
56   const char* option_name = tflite::EnumNamePadding(padding);
57   return builder.getStringAttr(option_name);
58 }
59 
60 }  // namespace
61 
62 // TODO(jpienaar): This is a placeholder. This should be done in more efficient
63 // way when part of the translation of module.
ConvertTFL_AFAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)64 static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
65     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
66   return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
67       .Case("NONE", tflite::ActivationFunctionType_NONE)
68       .Case("RELU", tflite::ActivationFunctionType_RELU)
69       .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
70       .Case("RELU6", tflite::ActivationFunctionType_RELU6)
71       .Case("TANH", tflite::ActivationFunctionType_TANH)
72       .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
73 }
74 
ConvertDerivedTFLiteTypeAttrForOptionWriter(tflite::TensorType type,flatbuffers::FlatBufferBuilder * builder)75 static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
76     tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
77   if (type == tflite::TensorType_INT64) {
78     return tflite::TensorType_INT64;
79   } else if (type == tflite::TensorType_INT32) {
80     return tflite::TensorType_INT32;
81   }
82   llvm_unreachable("invalid type in conversion.");
83 }
84 
ConvertTFL_PaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)85 static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
86     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
87   return llvm::StringSwitch<tflite::Padding>(str)
88       .Case("SAME", tflite::Padding_SAME)
89       .Case("VALID", tflite::Padding_VALID);
90 }
91 
ConvertTFL_MirrorPaddingAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)92 static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
93     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
94   return llvm::StringSwitch<tflite::MirrorPadMode>(str)
95       .Case("REFLECT", tflite::MirrorPadMode_REFLECT)
96       .Case("SYMMETRIC", tflite::MirrorPadMode_SYMMETRIC);
97 }
98 
ConvertDerivedTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)99 static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
100     mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
101   return tflite::ConvertTypeToTensorType(type);
102 }
103 
104 // I32Attr already returns an int as required by flatbuffer builders.
ConvertI32AttrForOptionWriter(int i,flatbuffers::FlatBufferBuilder * builder)105 static int ConvertI32AttrForOptionWriter(
106     int i, flatbuffers::FlatBufferBuilder* builder) {
107   return i;
108 }
109 
// Forwards a positive-I32 attribute value unchanged, exactly like the plain
// I32 writer. No positivity check is performed here. `builder` is unused.
static int ConvertPositiveI32AttrForOptionWriter(
    int i, flatbuffers::FlatBufferBuilder* builder) {
  return i;
}
114 
115 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,flatbuffers::FlatBufferBuilder * builder)116 ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray,
117                                    flatbuffers::FlatBufferBuilder* builder) {
118   std::vector<int32_t> intVec;
119   intVec.reserve(attrArray.getValue().size());
120   for (auto attr : attrArray.getValue()) {
121     intVec.push_back(attr.cast<mlir::IntegerAttr>().getInt());
122   }
123   return builder->CreateVector(intVec);
124 }
125 
126 // F32Attr already returns a float as required by flatbuffer builders.
ConvertF32AttrForOptionWriter(llvm::APFloat f,flatbuffers::FlatBufferBuilder * builder)127 static float ConvertF32AttrForOptionWriter(
128     llvm::APFloat f, flatbuffers::FlatBufferBuilder* builder) {
129   return f.convertToFloat();
130 }
131 
132 // BoolAttr already returns a bool as required by flatbuffer builders.
ConvertBoolAttrForOptionWriter(bool b,flatbuffers::FlatBufferBuilder * builder)133 static bool ConvertBoolAttrForOptionWriter(
134     bool b, flatbuffers::FlatBufferBuilder* builder) {
135   return b;
136 }
137 
ConvertStrAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)138 static flatbuffers::Offset<flatbuffers::String> ConvertStrAttrForOptionWriter(
139     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
140   return builder->CreateString(str.str());
141 }
142 
ConvertTypeAttrForOptionWriter(mlir::Type type,flatbuffers::FlatBufferBuilder * builder)143 static tflite::TensorType ConvertTypeAttrForOptionWriter(
144     mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
145   return tflite::ConvertTypeToTensorType(type);
146 }
147 
148 static flatbuffers::Offset<flatbuffers::Vector<int32_t>>
ConvertDerivedShapeAttrForOptionWriter(llvm::ArrayRef<int64_t> r,flatbuffers::FlatBufferBuilder * builder)149 ConvertDerivedShapeAttrForOptionWriter(
150     llvm::ArrayRef<int64_t> r, flatbuffers::FlatBufferBuilder* builder) {
151   std::vector<int> intVec(r.begin(), r.end());
152   return builder->CreateVector(intVec);
153 }
154 
155 static tflite::FullyConnectedOptionsWeightsFormat
ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)156 ConvertTFL_FullyConnectedOptionsWeightFormatAttrForOptionWriter(
157     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
158   return llvm::StringSwitch<tflite::FullyConnectedOptionsWeightsFormat>(str)
159       .Case("DEFAULT", tflite::FullyConnectedOptionsWeightsFormat_DEFAULT)
160       .Case("SHUFFLED4x16INT8",
161             tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
162 }
163 
ConvertTFL_LSTMKernelTypeAttrForOptionWriter(llvm::StringRef str,flatbuffers::FlatBufferBuilder * builder)164 static tflite::LSTMKernelType ConvertTFL_LSTMKernelTypeAttrForOptionWriter(
165     llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
166   return llvm::StringSwitch<tflite::LSTMKernelType>(str)
167       .Case("FULL", tflite::LSTMKernelType_FULL)
168       .Case("BASIC", tflite::LSTMKernelType_BASIC);
169 }
170 
BuildBoolAttr(bool value,mlir::Builder builder)171 static mlir::Attribute BuildBoolAttr(bool value, mlir::Builder builder) {
172   return builder.getBoolAttr(value);
173 }
174 
BuildStrAttr(llvm::StringRef str,mlir::Builder builder)175 static mlir::Attribute BuildStrAttr(llvm::StringRef str,
176                                     mlir::Builder builder) {
177   return builder.getStringAttr(str);
178 }
179 
BuildF32Attr(float value,mlir::Builder builder)180 static mlir::Attribute BuildF32Attr(float value, mlir::Builder builder) {
181   return builder.getF32FloatAttr(value);
182 }
183 
BuildI32Attr(int32_t value,mlir::Builder builder)184 static mlir::Attribute BuildI32Attr(int32_t value, mlir::Builder builder) {
185   return builder.getI32IntegerAttr(value);
186 }
187 
BuildI64ArrayAttr(std::vector<int32_t> value,mlir::Builder builder)188 static mlir::Attribute BuildI64ArrayAttr(std::vector<int32_t> value,
189                                          mlir::Builder builder) {
190   std::vector<int64_t> typecast(value.begin(), value.end());
191   return builder.getI64ArrayAttr(typecast);
192 }
193 
BuildPositiveI32Attr(int32_t value,mlir::Builder builder)194 static mlir::Attribute BuildPositiveI32Attr(int32_t value,
195                                             mlir::Builder builder) {
196   return builder.getI32IntegerAttr(value);
197 }
198 
BuildTypeAttr(tflite::TensorType value,mlir::Builder builder)199 static mlir::Attribute BuildTypeAttr(tflite::TensorType value,
200                                      mlir::Builder builder) {
201   return mlir::TypeAttr::get(ConvertElementType(value, builder));
202 }
203 
BuildTFL_AFAttr(tflite::ActivationFunctionType value,mlir::Builder builder)204 static mlir::Attribute BuildTFL_AFAttr(tflite::ActivationFunctionType value,
205                                        mlir::Builder builder) {
206   const char* option_name = tflite::EnumNameActivationFunctionType(value);
207   return builder.getStringAttr(option_name);
208 }
209 
BuildTFL_FullyConnectedOptionsWeightFormatAttr(tflite::FullyConnectedOptionsWeightsFormat value,mlir::Builder builder)210 static mlir::Attribute BuildTFL_FullyConnectedOptionsWeightFormatAttr(
211     tflite::FullyConnectedOptionsWeightsFormat value, mlir::Builder builder) {
212   const char* option_name =
213       tflite::EnumNameFullyConnectedOptionsWeightsFormat(value);
214   return builder.getStringAttr(option_name);
215 }
216 
BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,mlir::Builder builder)217 static mlir::Attribute BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value,
218                                                    mlir::Builder builder) {
219   const char* option_name = tflite::EnumNameLSTMKernelType(value);
220   return builder.getStringAttr(option_name);
221 }
222 
BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,mlir::Builder builder)223 static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,
224                                                   mlir::Builder builder) {
225   const char* option_name = tflite::EnumNameMirrorPadMode(value);
226   return builder.getStringAttr(option_name);
227 }
228 
BuildTFL_PaddingAttr(tflite::Padding value,mlir::Builder builder)229 static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
230                                             mlir::Builder builder) {
231   const char* option_name = tflite::EnumNamePadding(value);
232   return builder.getStringAttr(option_name);
233 }
234 
CustomOptionsToAttributes(const std::string & custom_code,const std::vector<uint8_t> & custom_options,mlir::Builder builder,mlir::Location loc,llvm::SmallVectorImpl<mlir::NamedAttribute> * attributes)235 Status mlir::CustomOptionsToAttributes(
236     const std::string& custom_code, const std::vector<uint8_t>& custom_options,
237     mlir::Builder builder, mlir::Location loc,
238     llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
239   attributes->emplace_back(
240       builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code)));
241   std::string content;
242   content.assign(reinterpret_cast<const char*>(custom_options.data()),
243                  custom_options.size());
244   ShapedType type = RankedTensorType::get(
245       {static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
246   attributes->emplace_back(builder.getNamedAttr(
247       "custom_option",
248       OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"),
249                               type, content)));
250 
251   return Status::OK();
252 }
253 
254 // Pull in FlatBuffer writers for TFLite generated using TableGen
255 #include "tensorflow/compiler/mlir/lite/operator_converters.inc"
256