/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"

// NOLINTNEXTLINE
static llvm::cl::opt<std::string> quantize_stats(
    "quant-test-stats", llvm::cl::value_desc("string"),
    llvm::cl::desc("serialized quant info string. Only used in tests"),
    llvm::cl::init(""));
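
// For reference, the flag above can be populated with a QuantizationInfo text
// proto. The sample below is only an illustrative sketch (the field names
// follow the accessors used in this file; the op names are hypothetical):
//
//   entries { name: "conv_1:0" params { min_max { min: -1.0 max: 1.0 } } }
//   entries { name_regex: "conv_.*" params { min_max { min: -2.0 max: 2.0 } } }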

//===----------------------------------------------------------------------===//
// The pass to import quantization stats to the ops in a function. This requires
// a custom method to retrieve the unique name of the operation.

namespace mlir {
namespace quant {

using QuantParamsEntry = QuantizationInfo::QuantParams;

namespace {
class ImportQuantStatsPass
    : public PassWrapper<ImportQuantStatsPass, FunctionPass> {
 public:
  explicit ImportQuantStatsPass(OperationToName op_to_name)
      : op_to_name_(op_to_name) {}

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "quant-import-stats";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Import quantization stats to the model";
  }

  void runOnFunction() override;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<quant::QuantizationDialect>();
  }

  // Parses the serialized quant stats protobuf and initializes the internal
  // data structure. This method must be called after the pass is created.
  bool ParseQuantStats(const std::string &stats_str);

 private:
  void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
                        const QuantParamsEntry &info);

  void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
                             ElementsAttr axis_stats, IntegerAttr axis);

  // If the index is out of range, this method returns false. Otherwise it
  // returns true if the value is a float tensor.
  bool IsQuantizableResult(Operation *op, int index) {
    if (index < 0 || index >= static_cast<int>(op->getNumResults()))
      return false;
    Value res = op->getResult(index);
    return res.getType().isa<ShapedType>() &&
           res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
  }

  // A method to retrieve the name for the given op.
  OperationToName op_to_name_;

  // The normal names and the regex names are kept separately, since the former
  // can be looked up in a hash map while the latter requires iterating over
  // all the regexes to find a match.
  // The `int` in the following two containers specifies the result index of
  // the given op; -1 indicates all the floating-point results.
  llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
  llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
};
}  // namespace

bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
  QuantizationInfo quant_stats;
  if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
    return true;
  }

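  // An entry name may carry an optional result index after a colon, e.g.
  // "op_name:1". Without one, the port defaults to -1, which means all the
  // floating-point results of the matched op.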
  for (const auto &entry : quant_stats.entries()) {
    if (!entry.name().empty()) {
      std::vector<std::string> name_and_port =
          absl::StrSplit(entry.name(), ':');
      int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
      name_to_info_.insert({name_and_port[0], {port, entry}});
    } else if (!entry.name_regex().empty()) {
      std::vector<std::string> name_and_port =
          absl::StrSplit(entry.name_regex(), ':');
      int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
      regex_to_info_.insert({name_and_port[0], {port, entry}});
    }
  }
  return false;
}

void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
                                                 ElementsAttr layer_stats,
                                                 ElementsAttr axis_stats,
                                                 IntegerAttr axis) {
  auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
                                                layer_stats, axis_stats, axis);
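  // Redirect all existing users of `res` to the new stats op, then restore the
  // stats op's own operand back to `res` so the op does not consume itself.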
  res.replaceAllUsesWith(stats_op);
  stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
}

void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
                                            int index,
                                            const QuantParamsEntry &info) {
  if (info.params_size() == 0) return;

  SmallVector<APFloat, 4> min_maxs;
  min_maxs.reserve(info.params_size() * 2);
  for (const auto &param : info.params()) {
    llvm::APFloat min(param.min_max().min());
    llvm::APFloat max(param.min_max().max());
    min_maxs.push_back(min);
    min_maxs.push_back(max);
  }
  // The layer stats contain only the first min/max pair.
  ElementsAttr layer_stats = DenseFPElementsAttr::get(
      RankedTensorType::get({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
  ElementsAttr axis_stats;
  IntegerAttr axis;

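  // With more than one min/max pair, the pairs are also recorded as per-axis
  // statistics of shape [params_size, 2], along the quantization axis given by
  // the metadata.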
  if (info.params_size() > 1) {
    SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
    axis_stats = DenseFPElementsAttr::get(
        RankedTensorType::get(axis_stats_shape, b.getF32Type()), min_maxs);
    axis = b.getI64IntegerAttr(info.meta().quantize_axis());
  }

  b.setInsertionPointAfter(op);
  if (IsQuantizableResult(op, index)) {
    InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
                          axis);
  } else {
    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      if (IsQuantizableResult(op, i)) {
        InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
                              axis);
      }
    }
  }
}

void ImportQuantStatsPass::runOnFunction() {
  FuncOp func = getFunction();
  OpBuilder builder(func);

  func.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::IsTerminator>()) return;
    auto op_name = op_to_name_(op);

    // Check the named info collection first.
    auto it = name_to_info_.find(op_name);
    if (it != name_to_info_.end()) {
      ImportAsStatsOps(builder, op, it->second.first, it->second.second);
      return;
    }

    // Iterate over all the regex names and match the first one.
    for (auto &regex : regex_to_info_) {
      if (llvm::Regex(regex.first()).match(op_name)) {
        ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
        break;
      }
    }
  });
}

// Creates an instance of the import quantization stats pass.
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
    OperationToName op_to_name, const std::string &stats_str) {
  auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
  if (pass->ParseQuantStats(stats_str)) return nullptr;
  return pass;
}
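
// A minimal usage sketch (illustrative only; `GetOpName` and `stats` are
// hypothetical placeholders supplied by the caller):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(
//       mlir::quant::CreateImportQuantStatsPass(GetOpName, stats));
//   (void)pm.run(module);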

// Creates an instance of the pass that imports quantization stats into the
// operations in the function. A custom method to get the name from the op is
// used because different dialect ops might have different ways to assign the
// name.
std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
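  // The op name is taken from the location info: a NameLoc directly, or the
  // first NameLoc found inside a FusedLoc; otherwise an empty name is returned
  // and no stats entry will match.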
  auto get_name_func = [](Operation *op) {
    Location loc = op->getLoc();
    if (auto name = loc.dyn_cast<NameLoc>()) {
      return name.getName().strref();
    } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
      for (auto sub_loc : fused_name.getLocations()) {
        if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
          return named_sub_loc.getName().strref();
        }
      }
    }
    return llvm::StringRef("");
  };

  return CreateImportQuantStatsPass(get_name_func, stats_str);
}

// Registers this pass with default values, only for test.
static PassRegistration<ImportQuantStatsPass> pass([] {
  return CreateImportQuantStatsPassForTFControlDialect(quantize_stats);
});

}  // namespace quant
}  // namespace mlir