1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iostream>
17 #include <memory>
18
19 #include "absl/strings/string_view.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/InitLLVM.h"
22 #include "llvm/Support/MemoryBuffer.h"
23 #include "llvm/Support/PrettyStackTrace.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "tensorflow/lite/model.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 #include "tensorflow/lite/schema/schema_utils.h"
28
29 using llvm::Optional;
30 using llvm::cl::opt;
31
32 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
33 // RUN: | %p/importer_test_min_max - \
34 // RUN: | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - \
35 // RUN: | FileCheck %s
36
37 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
38 // RUN: | %p/importer_test_min_max - \
39 // RUN: | flatbuffer_to_string - \
40 // RUN: | FileCheck --check-prefix=FB %s
41
42 // Tests for verifying the tflite model with min/max can be imported
43 // correctly.
44
// Positional input file argument; the default "-" reads the flatbuffer from
// stdin, which is how the lit RUN pipelines above invoke this tool.
// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
                                      llvm::cl::desc("<input file>"),
                                      llvm::cl::init("-"));
49
50 namespace mlir {
51 namespace {
InjectStatsToFullyConnected(llvm::StringRef buffer)52 Optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
53 llvm::StringRef buffer) {
54 auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
55 buffer.data(), buffer.size());
56 if (nullptr == model_ptr) {
57 return llvm::None;
58 }
59 std::unique_ptr<tflite::ModelT> model(model_ptr->GetModel()->UnPack());
60
61 // FB-LABEL: name: "arg0",
62 // FB-NEXT: quantization: {
63 // FB-NEXT: min: [ -1.0 ],
64 // FB-NEXT: max: [ 1.0 ]
65 // FB-NEXT: }
66
67 // FB-LABEL: name: "arg1",
68 // FB-NEXT: quantization: {
69 // FB-EMPTY:
70 // FB-NEXT: }
71
72 // FB-LABEL: name: "tfl.fully_connected",
73 // FB-NEXT: quantization: {
74 // FB-NEXT: min: [ -0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0,
75 // FB-SAME: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0, -16.0,
76 // FB-SAME: -17.0, -18.0, -19.0, -20.0, -21.0, -22.0, -23.0, -24.0, -25.0,
77 // FB-SAME: -26.0, -27.0, -28.0, -29.0, -30.0, -31.0, -32.0, -33.0, -34.0,
78 // FB-SAME: -35.0, -36.0, -37.0, -38.0, -39.0 ],
79 // FB-NEXT: max: [ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
80 // FB-SAME: 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0,
81 // FB-SAME: 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
82 // FB-SAME: 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0 ],
83 // FB-NEXT: quantized_dimension: 1
84 // FB-NEXT: }
85
86 // FB-LABEL: name: "tfl.fully_connected:1",
87 // FB-NEXT: quantization: {
88 // FB-EMPTY:
89 // FB-NEXT: }
90
91 // FB-LABEL: operators: [ {
92 // FB-NEXT: inputs: [ 0, 1, 2 ],
93 // FB-NEXT: outputs: [ 3, 4 ],
94 // FB-NEXT: builtin_options_type: FullyConnectedOptions,
95 // FB-NEXT: builtin_options: {
96 // FB-EMPTY:
97 // FB-NEXT: }
98 // FB-NEXT: } ],
99
100 // CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>)
101 // CHECK-SAME: -> tensor<40x40xf32>
102 // CHECK: %[[stat:.*]] = "quant.stats"(%arg0) {layerStats = dense<
103 // CHECK-SAME: [-1.000000e+00, 1.000000e+00]> : tensor<2xf32>}
104 // CHECK-SAME: : (tensor<40x37xf32>) -> tensor<40x37xf32>
105 // CHECK-NEXT: %[[cst:.*]] = "tfl.pseudo_const"() {value = dense<
106 // CHECK-SAME: 1.000000e+00> : tensor<40xf32>} : () -> tensor<40xf32>
107 // CHECK-NEXT: %[[fc:.*]]:2 = "tfl.fully_connected"(%[[stat]], %arg1,
108 // CHECK-NEXT: %[[stat1:.*]] = "quant.stats"(%[[fc]]#0) {axis = 1 : i64,
109 // CHECK-SAME: axisStats = dense<{{\[}}[-0.000000e+00, 0.000000e+00],
110 // CHECK-SAME: [-1.000000e+00, 1.000000e+00],
111 // CHECK-SAME: [-2.000000e+00, 2.000000e+00]
112 // CHECK-NEXT: return %[[stat1]] : tensor<40x40xf32>
113 // CHECK-NEXT: }
114
115 // Find the tensors and inject the min and max to the input and output
116 for (auto& sub_graph : model->subgraphs) {
117 for (auto& op : sub_graph->operators) {
118 if (tflite::GetBuiltinCode(
119 model->operator_codes[op->opcode_index].get()) ==
120 tflite::BuiltinOperator_FULLY_CONNECTED) {
121 // inject min/max to the input and output tensors
122 auto& input_tensor = sub_graph->tensors[op->inputs[0]];
123 input_tensor->quantization->scale.clear();
124 input_tensor->quantization->zero_point.clear();
125 input_tensor->quantization->min.push_back(-1.0);
126 input_tensor->quantization->max.push_back(1.0);
127
128 auto& output_tensor = sub_graph->tensors[op->outputs[0]];
129 auto shape = output_tensor->shape;
130 output_tensor->quantization->scale.clear();
131 output_tensor->quantization->zero_point.clear();
132 for (int i = 0; i < shape.back(); ++i) {
133 output_tensor->quantization->min.push_back(-1.0 * i);
134 output_tensor->quantization->max.push_back(1.0 * i);
135 }
136 output_tensor->quantization->quantized_dimension = shape.size() - 1;
137 }
138 }
139 }
140 return model;
141 }
142
143 } // namespace
144 } // namespace mlir
145
main(int argc,char ** argv)146 int main(int argc, char** argv) {
147 llvm::InitLLVM y(argc, argv);
148 llvm::cl::ParseCommandLineOptions(argc, argv);
149 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
150 if (std::error_code error = file_or_err.getError()) {
151 llvm::errs() << argv[0] << ": could not open input file '" << inputFileName
152 << "': " << error.message() << "\n";
153 return 1;
154 }
155 auto buffer = file_or_err->get();
156 auto maybe_module =
157 mlir::InjectStatsToFullyConnected(buffer->getBuffer().str());
158 if (!maybe_module.hasValue()) {
159 return 1;
160 }
161 flatbuffers::FlatBufferBuilder builder;
162 flatbuffers::Offset<tflite::Model> output_model_location =
163 tflite::Model::Pack(builder, maybe_module.getValue().get());
164 tflite::FinishModelBuffer(builder, output_model_location);
165 std::string output_model_content(
166 reinterpret_cast<const char*>(builder.GetBufferPointer()),
167 builder.GetSize());
168 std::cout << output_model_content << "\n";
169 return 0;
170 }
171