• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/memory/memory.h"
17 #include "absl/strings/str_split.h"
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include "llvm/Support/Regex.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
27 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
29 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
30 #include "mlir/IR/AffineMap.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/Location.h"  // from @llvm-project
34 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
35 #include "mlir/Pass/Pass.h"  // from @llvm-project
36 #include "mlir/Support/LLVM.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
38 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
40 
41 // NOLINTNEXTLINE
42 static llvm::cl::opt<std::string> quantize_stats(
43     "quant-test-stats", llvm::cl::value_desc("string"),
44     llvm::cl::desc("serialized quant info string. Only used in tests"),
45     llvm::cl::init(""));
46 
47 //===----------------------------------------------------------------------===//
48 // The Pass to import quantization stats to the ops in a function. This requires
49 // a custom method to retrieve the unique name of the operation.
50 
51 namespace mlir {
52 namespace quant {
53 
54 using QuantParamsEntry = QuantizationInfo::QuantParams;
55 
56 namespace {
57 class ImportQuantStatsPass
58     : public PassWrapper<ImportQuantStatsPass, FunctionPass> {
59  public:
ImportQuantStatsPass(OperationToName op_to_name)60   explicit ImportQuantStatsPass(OperationToName op_to_name)
61       : op_to_name_(op_to_name) {}
62 
getArgument() const63   StringRef getArgument() const final {
64     // This is the argument used to refer to the pass in
65     // the textual format (on the commandline for example).
66     return "quant-import-stats";
67   }
getDescription() const68   StringRef getDescription() const final {
69     // This is a brief description of the pass.
70     return "Import quantization stats to the model";
71   }
72 
73   void runOnFunction() override;
74 
getDependentDialects(DialectRegistry & registry) const75   void getDependentDialects(DialectRegistry &registry) const override {
76     registry.insert<quant::QuantizationDialect>();
77   }
78 
79   // Parses the serialized quant stats protobuf and initialize the internal
80   // data structure. This method must be called after the pass is created.
81   bool ParseQuantStats(const std::string &stats_str);
82 
83  private:
84   void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
85                         const QuantParamsEntry &info);
86 
87   void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
88                              ElementsAttr axis_stats, IntegerAttr axis);
89 
90   // If the index is out of range, this method returns false. Otherwise it
91   // returns true if the value is a float tensor.
IsQuantizableResult(Operation * op,int index)92   bool IsQuantizableResult(Operation *op, int index) {
93     if (index < 0 || index >= static_cast<int>(op->getNumResults()))
94       return false;
95     Value res = op->getResult(index);
96     return res.getType().isa<ShapedType>() &&
97            res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
98   }
99 
100   // A method to retrieve the name for the given op.
101   OperationToName op_to_name_;
102 
103   // We split the normal names and regex names, since the former can use hash
104   // map to lookup and the latter needs to iterate all the regex to find the
105   // match.
106   // The `int` in the following two containers are to specify the result index
107   // of the given op. -1 indicates all the floating-point results.
108   llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
109   llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
110 };
111 }  // namespace
112 
ParseQuantStats(const std::string & stats_str)113 bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
114   QuantizationInfo quant_stats;
115   if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
116     return true;
117   }
118 
119   for (const auto &entry : quant_stats.entries()) {
120     if (!entry.name().empty()) {
121       std::vector<std::string> name_and_port =
122           absl::StrSplit(entry.name(), ':');
123       int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
124       name_to_info_.insert({name_and_port[0], {port, entry}});
125     } else if (!entry.name_regex().empty()) {
126       std::vector<std::string> name_and_port =
127           absl::StrSplit(entry.name_regex(), ':');
128       int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
129       regex_to_info_.insert({name_and_port[0], {port, entry}});
130     }
131   }
132   return false;
133 }
134 
InsertStatsOpAtResult(OpBuilder b,Value res,ElementsAttr layer_stats,ElementsAttr axis_stats,IntegerAttr axis)135 void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
136                                                  ElementsAttr layer_stats,
137                                                  ElementsAttr axis_stats,
138                                                  IntegerAttr axis) {
139   auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
140                                                 layer_stats, axis_stats, axis);
141   res.replaceAllUsesWith(stats_op);
142   stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
143 }
144 
ImportAsStatsOps(OpBuilder b,Operation * op,int index,const QuantParamsEntry & info)145 void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
146                                             int index,
147                                             const QuantParamsEntry &info) {
148   if (info.params_size() == 0) return;
149 
150   SmallVector<APFloat, 4> min_maxs;
151   min_maxs.reserve(info.params_size() * 2);
152   for (const auto &param : info.params()) {
153     llvm::APFloat min(param.min_max().min());
154     llvm::APFloat max(param.min_max().max());
155     min_maxs.push_back(min);
156     min_maxs.push_back(max);
157   }
158   // The layer stats contain only the first min/max pairs.
159   ElementsAttr layer_stats = DenseFPElementsAttr::get(
160       RankedTensorType::get({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
161   ElementsAttr axis_stats;
162   IntegerAttr axis;
163 
164   if (info.params_size() > 1) {
165     SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
166     axis_stats = DenseFPElementsAttr::get(
167         RankedTensorType::get(axis_stats_shape, b.getF32Type()), min_maxs);
168     axis = b.getI64IntegerAttr(info.meta().quantize_axis());
169   }
170 
171   b.setInsertionPointAfter(op);
172   if (IsQuantizableResult(op, index)) {
173     InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
174                           axis);
175   } else {
176     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
177       if (IsQuantizableResult(op, i)) {
178         InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
179                               axis);
180       }
181     }
182   }
183 }
184 
runOnFunction()185 void ImportQuantStatsPass::runOnFunction() {
186   FuncOp func = getFunction();
187   OpBuilder builder(func);
188 
189   func.walk([&](Operation *op) {
190     if (op->hasTrait<OpTrait::IsTerminator>()) return;
191     auto op_name = op_to_name_(op);
192 
193     // Check the named info collection first.
194     auto it = name_to_info_.find(op_name);
195     if (it != name_to_info_.end()) {
196       ImportAsStatsOps(builder, op, it->second.first, it->second.second);
197       return;
198     }
199 
200     // Iterate all the regex names and matches the first one.
201     for (auto &regex : regex_to_info_) {
202       if (llvm::Regex(regex.first()).match(op_name)) {
203         ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
204         break;
205       }
206     }
207   });
208 }
209 
210 // Creates an instance of the default quant parameters pass.
CreateImportQuantStatsPass(OperationToName op_to_name,const std::string & stats_str)211 std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
212     OperationToName op_to_name, const std::string &stats_str) {
213   auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
214   if (pass->ParseQuantStats(stats_str)) return nullptr;
215   return pass;
216 }
217 
218 // Creates an instance pass to import quantization stats to the operations in
219 // the function. A custom method to get the name from the op is used because
220 // different dialect ops might have different ways to assign the name.
221 std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string & stats_str)222 CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
223   auto get_name_func = [](Operation *op) {
224     Location loc = op->getLoc();
225     if (auto name = loc.dyn_cast<NameLoc>()) {
226       return name.getName().strref();
227     } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
228       for (auto sub_loc : fused_name.getLocations()) {
229         if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
230           return named_sub_loc.getName().strref();
231         }
232       }
233     }
234     return llvm::StringRef("");
235   };
236 
237   return CreateImportQuantStatsPass(get_name_func, stats_str);
238 }
239 
240 // Registers this pass with default values, only for test
__anonde7f3bdc0402null241 static PassRegistration<ImportQuantStatsPass> pass([] {
242   return CreateImportQuantStatsPassForTFControlDialect(quantize_stats);
243 });
244 
245 }  // namespace quant
246 }  // namespace mlir
247