• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iostream>
17 #include <memory>
18 
19 #include "absl/strings/string_view.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/InitLLVM.h"
22 #include "llvm/Support/MemoryBuffer.h"
23 #include "llvm/Support/PrettyStackTrace.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "tensorflow/lite/model.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 #include "tensorflow/lite/schema/schema_utils.h"
28 
29 using llvm::Optional;
30 using llvm::cl::opt;
31 
32 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
33 // RUN:   | %p/importer_test_min_max - \
34 // RUN:   | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - \
35 // RUN:   | FileCheck %s
36 
37 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \
38 // RUN:   | %p/importer_test_min_max - \
39 // RUN:   | flatbuffer_to_string - \
40 // RUN:   | FileCheck --check-prefix=FB %s
41 
42 // Tests for verifying the tflite model with min/max can be imported
43 // correctly.
44 
// Positional command-line argument naming the input TFLite flatbuffer file.
// Defaults to "-", which makes MemoryBuffer::getFileOrSTDIN (see main) read
// from standard input — matching the piped usage in the RUN lines above.
// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
                                      llvm::cl::desc("<input file>"),
                                      llvm::cl::init("-"));
49 
50 namespace mlir {
51 namespace {
// Parses `buffer` as a serialized TFLite flatbuffer and, for every
// FULLY_CONNECTED operator in every subgraph, rewrites the quantization
// stats of its first input tensor (min/max = [-1.0, 1.0]) and first output
// tensor (per-channel min/max of [-i, i] for channel i along the last
// dimension). Returns llvm::None when flatbuffer verification fails;
// otherwise returns the modified, unpacked model for re-serialization.
//
// The directive-style comments inside the body are FileCheck expectations
// exercised by the RUN lines at the top of this file; edits to them change
// what the test asserts.
Optional<std::unique_ptr<tflite::ModelT>> InjectStatsToFullyConnected(
    llvm::StringRef buffer) {
  // Verify + mmap-style load; returns nullptr on a malformed buffer.
  auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
      buffer.data(), buffer.size());
  if (nullptr == model_ptr) {
    return llvm::None;
  }
  // Unpack into the mutable object API (ModelT) so fields can be edited.
  std::unique_ptr<tflite::ModelT> model(model_ptr->GetModel()->UnPack());

  // FB-LABEL:     name: "arg0",
  // FB-NEXT:      quantization: {
  // FB-NEXT:              min: [ -1.0 ],
  // FB-NEXT:              max: [ 1.0 ]
  // FB-NEXT:      }

  // FB-LABEL:     name: "arg1",
  // FB-NEXT:            quantization: {
  // FB-EMPTY:
  // FB-NEXT:            }

  // FB-LABEL:     name: "tfl.fully_connected",
  // FB-NEXT:      quantization: {
  // FB-NEXT:        min: [ -0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0,
  // FB-SAME:  -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0, -16.0,
  // FB-SAME:  -17.0, -18.0, -19.0, -20.0, -21.0, -22.0, -23.0, -24.0, -25.0,
  // FB-SAME:  -26.0, -27.0, -28.0, -29.0, -30.0, -31.0, -32.0, -33.0, -34.0,
  // FB-SAME:  -35.0, -36.0, -37.0, -38.0, -39.0 ],
  // FB-NEXT:        max: [ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
  // FB-SAME:  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0,
  // FB-SAME:  21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
  // FB-SAME:  32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0 ],
  // FB-NEXT:        quantized_dimension: 1
  // FB-NEXT:      }

  // FB-LABEL:     name: "tfl.fully_connected:1",
  // FB-NEXT:      quantization: {
  // FB-EMPTY:
  // FB-NEXT:      }

  // FB-LABEL:      operators: [ {
  // FB-NEXT:             inputs: [ 0, 1, 2 ],
  // FB-NEXT:             outputs: [ 3, 4 ],
  // FB-NEXT:             builtin_options_type: FullyConnectedOptions,
  // FB-NEXT:             builtin_options: {
  // FB-EMPTY:
  // FB-NEXT:             }
  // FB-NEXT:       } ],

  // CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>)
  // CHECK-SAME:      -> tensor<40x40xf32>
  // CHECK:         %[[stat:.*]] = "quant.stats"(%arg0) {layerStats = dense<
  // CHECK-SAME:      [-1.000000e+00, 1.000000e+00]> : tensor<2xf32>}
  // CHECK-SAME:      : (tensor<40x37xf32>) -> tensor<40x37xf32>
  // CHECK-NEXT:    %[[cst:.*]] = "tfl.pseudo_const"() {value = dense<
  // CHECK-SAME:      1.000000e+00> : tensor<40xf32>} : () -> tensor<40xf32>
  // CHECK-NEXT:    %[[fc:.*]]:2 = "tfl.fully_connected"(%[[stat]], %arg1,
  // CHECK-NEXT:    %[[stat1:.*]] = "quant.stats"(%[[fc]]#0) {axis = 1 : i64,
  // CHECK-SAME:      axisStats = dense<{{\[}}[-0.000000e+00, 0.000000e+00],
  // CHECK-SAME:      [-1.000000e+00, 1.000000e+00],
  // CHECK-SAME:      [-2.000000e+00, 2.000000e+00]
  // CHECK-NEXT:    return %[[stat1]] : tensor<40x40xf32>
  // CHECK-NEXT:  }

  // Find the tensors and inject the min and max to the input and output
  for (auto& sub_graph : model->subgraphs) {
    for (auto& op : sub_graph->operators) {
      if (tflite::GetBuiltinCode(
              model->operator_codes[op->opcode_index].get()) ==
          tflite::BuiltinOperator_FULLY_CONNECTED) {
        // inject min/max to the input and output tensors
        // NOTE(review): assumes every touched tensor already carries a
        // quantization table (non-null unique_ptr after UnPack) — confirm
        // the upstream exporter always emits one before relying on this.
        auto& input_tensor = sub_graph->tensors[op->inputs[0]];
        // Clear any quantized parameters so only the min/max stats remain.
        input_tensor->quantization->scale.clear();
        input_tensor->quantization->zero_point.clear();
        input_tensor->quantization->min.push_back(-1.0);
        input_tensor->quantization->max.push_back(1.0);

        auto& output_tensor = sub_graph->tensors[op->outputs[0]];
        auto shape = output_tensor->shape;
        output_tensor->quantization->scale.clear();
        output_tensor->quantization->zero_point.clear();
        // Channel i gets range [-i, i]; i == 0 yields IEEE -0.0 for min,
        // which is exactly what the FileCheck expectations above assert.
        for (int i = 0; i < shape.back(); ++i) {
          output_tensor->quantization->min.push_back(-1.0 * i);
          output_tensor->quantization->max.push_back(1.0 * i);
        }
        // Mark the stats as per-channel along the output's last dimension.
        output_tensor->quantization->quantized_dimension = shape.size() - 1;
      }
    }
  }
  return model;
}
142 
143 }  // namespace
144 }  // namespace mlir
145 
main(int argc,char ** argv)146 int main(int argc, char** argv) {
147   llvm::InitLLVM y(argc, argv);
148   llvm::cl::ParseCommandLineOptions(argc, argv);
149   auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
150   if (std::error_code error = file_or_err.getError()) {
151     llvm::errs() << argv[0] << ": could not open input file '" << inputFileName
152                  << "': " << error.message() << "\n";
153     return 1;
154   }
155   auto buffer = file_or_err->get();
156   auto maybe_module =
157       mlir::InjectStatsToFullyConnected(buffer->getBuffer().str());
158   if (!maybe_module.hasValue()) {
159     return 1;
160   }
161   flatbuffers::FlatBufferBuilder builder;
162   flatbuffers::Offset<tflite::Model> output_model_location =
163       tflite::Model::Pack(builder, maybe_module.getValue().get());
164   tflite::FinishModelBuffer(builder, output_model_location);
165   std::string output_model_content(
166       reinterpret_cast<const char*>(builder.GetBufferPointer()),
167       builder.GetSize());
168   std::cout << output_model_content << "\n";
169   return 0;
170 }
171