1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h"
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20
21 #include "absl/strings/str_join.h"
22 #include "absl/types/span.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/ToolOutputFile.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
28 #include "mlir/IR/Attributes.h" // from @llvm-project
29 #include "mlir/IR/Builders.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/MLIRContext.h" // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
34 #include "mlir/Pass/Pass.h" // from @llvm-project
35 #include "mlir/Support/FileUtilities.h" // from @llvm-project
36 #include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
38 #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
39 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
40 #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
41 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
43 #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
44 #include "tensorflow/compiler/xla/service/hlo.pb.h"
45 #include "tensorflow/compiler/xla/service/hlo_parser.h"
46 #include "tensorflow/core/framework/graph.pb.h"
47 #include "tensorflow/core/framework/types.pb.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/platform/errors.h"
50 #include "tensorflow/core/platform/status.h"
51 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
52 #include "tensorflow/lite/toco/model_flags.pb.h"
53 #include "tensorflow/lite/toco/toco_flags.pb.h"
54 #include "tensorflow/lite/toco/types.pb.h"
55 #include "tensorflow/stream_executor/lib/statusor.h"
56
57 namespace tensorflow {
58 namespace {
59
// Error collector that silently discards every reported error. LoadHloProto()
// below speculatively tries several parse formats, so parse errors are
// expected and must not spam the logs.
class NoOpErrorCollector : public tensorflow::protobuf::io::ErrorCollector {
 public:
  // Intentionally a no-op: errors during format probing are not failures.
  void AddError(int line, int column, const string& message) override {}
};
65
LoadHloProto(const std::string & contents,xla::HloProto * hlo_proto)66 bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) {
67 tensorflow::protobuf::TextFormat::Parser parser;
68 NoOpErrorCollector collector;
69 parser.RecordErrorsTo(&collector);
70 return hlo_proto->ParseFromString(contents) ||
71 parser.ParseFromString(contents, hlo_proto) ||
72 hlo_proto->mutable_hlo_module()->ParseFromString(contents) ||
73 parser.ParseFromString(contents, hlo_proto->mutable_hlo_module());
74 }
75
HloToMlirHloTranslateFunction(llvm::StringRef input,mlir::MLIRContext * context,bool import_all_computations)76 mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction(
77 llvm::StringRef input, mlir::MLIRContext* context,
78 bool import_all_computations) {
79 xla::HloProto hlo_proto;
80 string content(input.data(), input.size());
81 if (!LoadHloProto(content, &hlo_proto)) {
82 LOG(ERROR) << "Failed to load proto";
83 return nullptr;
84 }
85
86 mlir::OwningOpRef<mlir::ModuleOp> module =
87 mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
88 auto status = ConvertHloToMlirHlo(
89 module.get(), hlo_proto.mutable_hlo_module(), import_all_computations);
90 if (!status.ok()) {
91 LOG(ERROR) << "Hlo module import failed: " << status;
92 return nullptr;
93 }
94
95 return module;
96 }
97
HloTextToMlirHloTranslateFunction(llvm::StringRef input,mlir::MLIRContext * context,bool import_all_computations)98 mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslateFunction(
99 llvm::StringRef input, mlir::MLIRContext* context,
100 bool import_all_computations) {
101 xla::HloProto hlo_proto;
102 string content(input.data(), input.size());
103
104 auto hlo_module_error = xla::ParseAndReturnUnverifiedModule(content);
105 if (!hlo_module_error.ok()) {
106 LOG(ERROR) << "HLO Module loading failed: " << hlo_module_error.status();
107 return nullptr;
108 }
109
110 auto hlo_module = std::move(hlo_module_error.ValueOrDie());
111 mlir::OwningOpRef<mlir::ModuleOp> module =
112 mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
113 auto status =
114 ConvertHloToMlirHlo(*module, hlo_module.get(), import_all_computations);
115 if (!status.ok()) {
116 LOG(ERROR) << "HLO Module import failed: " << status;
117 return nullptr;
118 }
119
120 return module;
121 }
122
123 } // namespace
ConvertJaxToTFLiteFlatBuffer(const std::string & input,const toco::ModelFlags & model_flags,const toco::TocoFlags & toco_flags,string * result)124 Status ConvertJaxToTFLiteFlatBuffer(const std::string& input,
125 const toco::ModelFlags& model_flags,
126 const toco::TocoFlags& toco_flags,
127 string* result) {
128 mlir::MLIRContext context;
129 mlir::quant::QuantizationSpecs quant_specs;
130
131 // Parse input arrays.
132 std::vector<string> node_names;
133 std::vector<string> node_dtypes;
134 std::vector<llvm::Optional<std::vector<int>>> node_shapes;
135 std::vector<llvm::Optional<double>> node_mins;
136 std::vector<llvm::Optional<double>> node_maxs;
137
138 // Populate quantization specs.
139 TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
140 model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
141 &node_shapes, &node_mins, &node_maxs));
142
143 internal::WarningUnusedFlags(model_flags, toco_flags);
144
145 // Register all custom ops, including user-specified custom ops.
146 TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
147
148 mlir::TFL::PassConfig pass_config(quant_specs);
149 bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
150 pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
151 pass_config.enable_tflite_variables =
152 toco_flags.enable_tflite_resource_variables();
153 pass_config.unfold_batch_matmul = toco_flags.unfold_batchmatmul();
154 pass_config.lower_tensor_list_ops = toco_flags.lower_tensor_list_ops();
155 // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
156 // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
157 if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
158 pass_config.unfold_batch_matmul = false;
159 }
160 pass_config.unfold_large_splat_constant =
161 toco_flags.unfold_large_splat_constant();
162 pass_config.enable_hlo_to_tf_conversion = true;
163
164 mlir::OwningOpRef<mlir::ModuleOp> module;
165 if (model_flags.hlo_file_type() == toco::ModelFlags::HLO_TEXT) {
166 module = HloTextToMlirHloTranslateFunction(input, &context, false);
167 } else if (model_flags.hlo_file_type() == toco::ModelFlags::HLO_PROTO) {
168 module = HloToMlirHloTranslateFunction(input, &context, false);
169 } else {
170 return errors::InvalidArgument("unknown hlo format type.");
171 }
172
173 // Set the input names.
174 auto main_func = module->lookupSymbol<mlir::func::FuncOp>("main");
175 if (!main_func) return errors::Internal("Failed to find the main function.");
176 // Retrieve input names from model flags.
177 std::vector<std::string> input_names;
178 for (const auto& input : model_flags.input_arrays()) {
179 input_names.push_back(input.name());
180 }
181
182 const auto& inputs = absl::StrJoin(input_names, ",");
183 mlir::OpBuilder builder(*module);
184 llvm::SmallVector<mlir::NamedAttribute> attrs;
185 attrs.push_back(
186 builder.getNamedAttr("inputs", builder.getStringAttr(inputs)));
187 // Jax wrapped the output nodes in a tuple, so it's pretty hard to us
188 // to tell the output at this point, we will set the output at the export
189 // phase.
190 main_func->setAttr("tf.entry_function", builder.getDictionaryAttr(attrs));
191
192 auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
193 model_flags, toco_flags, std::move(module), pass_config,
194 /*saved_model_tags=*/{}, result,
195 /*session=*/llvm::None);
196 return status;
197 }
198
199 } // namespace tensorflow
200