• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h"
16 
17 #include <memory>
18 #include <string>
19 #include <utility>
20 
21 #include "absl/strings/str_join.h"
22 #include "absl/types/span.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/ToolOutputFile.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
28 #include "mlir/IR/Attributes.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
36 #include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
38 #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
39 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
40 #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
41 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
43 #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
44 #include "tensorflow/compiler/xla/service/hlo.pb.h"
45 #include "tensorflow/compiler/xla/service/hlo_parser.h"
46 #include "tensorflow/core/framework/graph.pb.h"
47 #include "tensorflow/core/framework/types.pb.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/platform/errors.h"
50 #include "tensorflow/core/platform/status.h"
51 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
52 #include "tensorflow/lite/toco/model_flags.pb.h"
53 #include "tensorflow/lite/toco/toco_flags.pb.h"
54 #include "tensorflow/lite/toco/types.pb.h"
55 #include "tensorflow/stream_executor/lib/statusor.h"
56 
57 namespace tensorflow {
58 namespace {
59 
60 // Error collector that simply ignores errors reported.
61 class NoOpErrorCollector : public tensorflow::protobuf::io::ErrorCollector {
62  public:
AddError(int line,int column,const string & message)63   void AddError(int line, int column, const string& message) override {}
64 };
65 
LoadHloProto(const std::string & contents,xla::HloProto * hlo_proto)66 bool LoadHloProto(const std::string& contents, xla::HloProto* hlo_proto) {
67   tensorflow::protobuf::TextFormat::Parser parser;
68   NoOpErrorCollector collector;
69   parser.RecordErrorsTo(&collector);
70   return hlo_proto->ParseFromString(contents) ||
71          parser.ParseFromString(contents, hlo_proto) ||
72          hlo_proto->mutable_hlo_module()->ParseFromString(contents) ||
73          parser.ParseFromString(contents, hlo_proto->mutable_hlo_module());
74 }
75 
HloToMlirHloTranslateFunction(llvm::StringRef input,mlir::MLIRContext * context,bool import_all_computations)76 mlir::OwningOpRef<mlir::ModuleOp> HloToMlirHloTranslateFunction(
77     llvm::StringRef input, mlir::MLIRContext* context,
78     bool import_all_computations) {
79   xla::HloProto hlo_proto;
80   string content(input.data(), input.size());
81   if (!LoadHloProto(content, &hlo_proto)) {
82     LOG(ERROR) << "Failed to load proto";
83     return nullptr;
84   }
85 
86   mlir::OwningOpRef<mlir::ModuleOp> module =
87       mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
88   auto status = ConvertHloToMlirHlo(
89       module.get(), hlo_proto.mutable_hlo_module(), import_all_computations);
90   if (!status.ok()) {
91     LOG(ERROR) << "Hlo module import failed: " << status;
92     return nullptr;
93   }
94 
95   return module;
96 }
97 
HloTextToMlirHloTranslateFunction(llvm::StringRef input,mlir::MLIRContext * context,bool import_all_computations)98 mlir::OwningOpRef<mlir::ModuleOp> HloTextToMlirHloTranslateFunction(
99     llvm::StringRef input, mlir::MLIRContext* context,
100     bool import_all_computations) {
101   xla::HloProto hlo_proto;
102   string content(input.data(), input.size());
103 
104   auto hlo_module_error = xla::ParseAndReturnUnverifiedModule(content);
105   if (!hlo_module_error.ok()) {
106     LOG(ERROR) << "HLO Module loading failed: " << hlo_module_error.status();
107     return nullptr;
108   }
109 
110   auto hlo_module = std::move(hlo_module_error.ValueOrDie());
111   mlir::OwningOpRef<mlir::ModuleOp> module =
112       mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
113   auto status =
114       ConvertHloToMlirHlo(*module, hlo_module.get(), import_all_computations);
115   if (!status.ok()) {
116     LOG(ERROR) << "HLO Module import failed: " << status;
117     return nullptr;
118   }
119 
120   return module;
121 }
122 
123 }  // namespace
ConvertJaxToTFLiteFlatBuffer(const std::string & input,const toco::ModelFlags & model_flags,const toco::TocoFlags & toco_flags,string * result)124 Status ConvertJaxToTFLiteFlatBuffer(const std::string& input,
125                                     const toco::ModelFlags& model_flags,
126                                     const toco::TocoFlags& toco_flags,
127                                     string* result) {
128   mlir::MLIRContext context;
129   mlir::quant::QuantizationSpecs quant_specs;
130 
131   // Parse input arrays.
132   std::vector<string> node_names;
133   std::vector<string> node_dtypes;
134   std::vector<llvm::Optional<std::vector<int>>> node_shapes;
135   std::vector<llvm::Optional<double>> node_mins;
136   std::vector<llvm::Optional<double>> node_maxs;
137 
138   // Populate quantization specs.
139   TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
140       model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
141       &node_shapes, &node_mins, &node_maxs));
142 
143   internal::WarningUnusedFlags(model_flags, toco_flags);
144 
145   // Register all custom ops, including user-specified custom ops.
146   TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
147 
148   mlir::TFL::PassConfig pass_config(quant_specs);
149   bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
150   pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
151   pass_config.enable_tflite_variables =
152       toco_flags.enable_tflite_resource_variables();
153   pass_config.unfold_batch_matmul = toco_flags.unfold_batchmatmul();
154   pass_config.lower_tensor_list_ops = toco_flags.lower_tensor_list_ops();
155   // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
156   // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
157   if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
158     pass_config.unfold_batch_matmul = false;
159   }
160   pass_config.unfold_large_splat_constant =
161       toco_flags.unfold_large_splat_constant();
162   pass_config.enable_hlo_to_tf_conversion = true;
163 
164   mlir::OwningOpRef<mlir::ModuleOp> module;
165   if (model_flags.hlo_file_type() == toco::ModelFlags::HLO_TEXT) {
166     module = HloTextToMlirHloTranslateFunction(input, &context, false);
167   } else if (model_flags.hlo_file_type() == toco::ModelFlags::HLO_PROTO) {
168     module = HloToMlirHloTranslateFunction(input, &context, false);
169   } else {
170     return errors::InvalidArgument("unknown hlo format type.");
171   }
172 
173   // Set the input names.
174   auto main_func = module->lookupSymbol<mlir::func::FuncOp>("main");
175   if (!main_func) return errors::Internal("Failed to find the main function.");
176   // Retrieve input names from model flags.
177   std::vector<std::string> input_names;
178   for (const auto& input : model_flags.input_arrays()) {
179     input_names.push_back(input.name());
180   }
181 
182   const auto& inputs = absl::StrJoin(input_names, ",");
183   mlir::OpBuilder builder(*module);
184   llvm::SmallVector<mlir::NamedAttribute> attrs;
185   attrs.push_back(
186       builder.getNamedAttr("inputs", builder.getStringAttr(inputs)));
187   // Jax wrapped the output nodes in a tuple, so it's pretty hard to us
188   // to tell the output at this point, we will set the output at the export
189   // phase.
190   main_func->setAttr("tf.entry_function", builder.getDictionaryAttr(attrs));
191 
192   auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
193       model_flags, toco_flags, std::move(module), pass_config,
194       /*saved_model_tags=*/{}, result,
195       /*session=*/llvm::None);
196   return status;
197 }
198 
199 }  // namespace tensorflow
200