/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "quantization-driver"

namespace mlir {
namespace quant {
namespace {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag indicating that this state (the params) shouldn't be changed after
  // it is initialized. This flag is set to true if the quantization parameters
  // are from quantization-aware training.
  const bool immutable;

  bool IsEmpty() { return EmptyParams(params); }
};

// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
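//
// For example (a sketch): if a value's propagated state already holds scale
// s1 but another op constrains it to scale s2, a RequantizeState records s2,
// and a requantize cast from the s1-typed value to an s2-typed value is
// inserted later by RequantizeValue.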
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or output
  // of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // The quantization parameters used to create the requantize ops.
  QuantParams params;
};

// This is a worklist-driven driver for propagating quantization parameters
// across operations.
//
// The initial quantization parameters are extracted from the quantized type
// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
// parameters are marked as immutable because they are from quantization-aware
// training.
//
// The algorithm traverses each op and sets the quantization parameters of its
// operands and results, according to its quantization specification, and then
// adds the operands and results to the worklist. If there are any conflicts
// (for example, there are quantization parameters propagated from the previous
// iteration), this process stops if the existing parameters are immutable, or
// a `requantize` op is added to resolve the conflicts.
//
// After the algorithm converges, pairs of tfl.quantize and tfl.dequantize are
// inserted at the right positions to materialize the propagation and
// requantize results.
//
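// For example (a sketch; types abbreviated), a function coming from
// quantization-aware training may contain
//
//   %q = "tfl.quantize"(%w) : (tensor<2xf32>) -> tensor<2x!quant.uniform<...>>
//   %dq = "tfl.dequantize"(%q) : (...) -> tensor<2xf32>
//   %r = "tfl.add"(%x, %dq) : ...
//
// The !quant.uniform<...> element type between the pair seeds an immutable
// state, which the driver then propagates through %r and its users.
//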
class QuantizationDriver {
 public:
  explicit QuantizationDriver(FuncOp fn, bool is_signed,
                              bool disable_per_channel,
                              OpQuantSpecGetter op_quant_spec_getter,
                              bool infer_tensor_range, bool legacy_float_scale)
      : fn_(fn),
        builder_(fn.getBody()),
        is_signed_(is_signed),
        disable_per_channel_(disable_per_channel),
        op_quant_spec_getter_(op_quant_spec_getter),
        infer_tensor_range_(infer_tensor_range),
        legacy_float_scale_(legacy_float_scale) {}

  // The entry point of the quantization parameters propagation.
  void Run();

 private:
  // This is used to identify an operand or result of an op. The second element
  // of this pair is the index of the operand or result.
  using OpValue = std::pair<mlir::Operation *, int>;

  // Sets up the states for all the op results in the function.
  void Initialize();

  // Propagates the quantization parameters across all the ops.
  bool PropagateParams();

  // Inserts the Quantize and Dequantize ops according to the propagation
  // result.
  void Finalize();

  // The quantization parameters of a bias operand are usually determined by
  // other operands, so if a constant is used by different ops as a bias, it
  // needs to be duplicated so that each op can assign its own quantization
  // parameters to it. This method also adds all the non-bias constants
  // (weights), and the per-channel weights among them, to sets for later
  // lookup.
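  //
  // For example (a sketch): if one float constant feeds the bias operand of
  // two different ops,
  //
  //   %bias = constant dense<...> : tensor<4xf32>
  //   %r0 = "tfl.fully_connected"(%x0, %w0, %bias) ...
  //   %r1 = "tfl.fully_connected"(%x1, %w1, %bias) ...
  //
  // the constant is cloned so each op can quantize its own copy.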
  void PreprocessConstantOps();

  // Sets up all the data structures for quantization propagation.
  void SetupAllStates();

  // Whether the constant is a weight, which shouldn't be shared by different
  // ops.
  bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); }

  // Returns all the related quantization constraints of the op.
  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);

  // Whether quantization parameters have been propagated to the results of
  // this op.
  bool IsQuantized(Operation *op);

  // Adds all the users of the index-th result of op to the work list.
  void AddUserToList(Operation *op, int index) {
    for (auto *user : op->getResult(index).getUsers()) {
      work_list_.push_back(user);
    }
  }

  // Adds the defining op of the index-th operand of op to the work list.
  void AddOperandToList(Operation *op, int index) {
    if (auto *inst = op->getOperand(index).getDefiningOp()) {
      work_list_.push_back(inst);
    }
  }

  // Returns the quantization params for the bias input from the non-bias
  // operands which have their indexes in the `non_biases` vector. The returned
  // parameters are calculated by `func`.
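  //
  // For a typical affine op this amounts to (a sketch of what common
  // AccumulatorScaleFunc implementations compute, not mandated here):
  //   bias_scale = input_scale * filter_scale, bias_zero_point = 0,
  // with a 32-bit storage type for the bias.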
  QuantParams GetBiasParams(Operation *op, int bias,
                            const std::vector<int> &non_biases,
                            AccumulatorScaleFunc func);

  // Sets the quantization parameters of the result to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the input of propagated quantization.
  bool SetResultParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the operand to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the output of propagated quantization.
  bool SetOperandParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the constant result according to its
  // content.
  bool SetConstantResultParams(Operation *op);

  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
  // of the op.
  void QuantizeOpResult(Operation *op, int index, QuantParams params);

  void QuantizeArg(BlockArgument arg, QuantParams params);

  // Inserts the Quantize and Dequantize ops to quantize the value.
  void QuantizeValue(Value value, QuantParams params, Location loc);

  // Inserts the Quantize ops for requantizing the index-th result of the op.
  void RequantizeOpResult(Operation *op, int index, RequantizeState *state);

  void RequantizeArg(BlockArgument arg, RequantizeState *state);

  // Inserts a Quantize op to requantize the value according to `state`.
  void RequantizeValue(Value value, RequantizeState *state, Location loc);

  // A heuristic to get a quantization parameter that satisfies the same-scale
  // constraints for the op. Returns an empty value if no such quantization
  // parameter exists.
  QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);

  // Returns the state of the index-th operand of the op.
  QuantState &GetOperandQuantState(Operation *op, int index) {
    return states_[operand_states_[{op, index}]];
  }

  // Returns the state of the index-th result of the op.
  QuantState &GetResultQuantState(Operation *op, int index) {
    return states_[result_states_[{op, index}]];
  }

  QuantState &GetArgQuantState(BlockArgument arg) {
    return states_[arg_states_[arg]];
  }

  // Returns the requantize state of the index-th operand of the op.
  RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
    return rescale_states_[operand_states_[{op, index}]];
  }

  // Returns the requantize state of the index-th result of the op.
  RequantizeState &GetResultRequantizeState(Operation *op, int index) {
    return rescale_states_[result_states_[{op, index}]];
  }

  RequantizeState &GetArgRequantizeState(BlockArgument arg) {
    return rescale_states_[arg_states_[arg]];
  }

  // Uses the type of `val` to set the initial state of the index-th result if
  // `as_result` is true, or of the index-th operand if `as_result` is false.
  // The state is immutable if the type is a quantized type. Returns the index
  // of this new state in the state vector.
  int InitializeState(Operation *op, int index, Value val, bool as_result);

  // Sets the state of an argument. If this value is cached, uses the cached
  // result without creating a new entry in the state vector. Otherwise,
  // allocates a new entry in the state vector.
  void InitializeArgState(BlockArgument arg, Value in,
                          llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({in, 0});
    if (!cached.second) {
      arg_states_[arg] = cached.first->second;
      return;
    }
    QuantParams params =
        quant::QuantizedType::getQuantizedElementType(in.getType());
    bool immutable = !EmptyParams(params);
    int next_state_index = states_.size();
    states_.push_back({params, immutable});
    arg_states_[arg] = next_state_index;
    cached.first->second = next_state_index;
  }

  // Sets the state of the index-th operand of the op. If this operand is
  // cached, uses the cached result without creating a new entry in the state
  // vector. Otherwise, allocates a new entry in the state vector.
  void InitializeOperandState(Operation *op, int index, Value in,
                              llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({in, 0});
    if (!cached.second) {
      operand_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
  }

  // Sets the state of the index-th result of the op. If this result is cached,
  // uses the cached result without creating a new entry in the state vector.
  // Otherwise, allocates a new entry in the state vector.
  void InitializeResultState(Operation *op, int index, Value res,
                             llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({res, 0});
    if (!cached.second) {
      result_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::dbgs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      if (op->hasTrait<OpTrait::IsTerminator>() ||
          op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
          llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp,
                    ConstantOp>(op))
        return;
      if (current_op == op) llvm::dbgs() << "===>>>";
      llvm::dbgs() << op->getName() << " : (";
      if (llvm::isa<FuncOp>(op)) {
        for (auto &arg : fn_.getArguments()) {
          if (auto params = GetArgQuantState(arg).params) {
            params.print(llvm::dbgs());
            auto requantize_state = GetArgRequantizeState(arg);
            if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
              llvm::dbgs() << "+";
              requantize_state.params.print(llvm::dbgs());
            }
          }
          llvm::dbgs() << ",";
        }
      }
      for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
        if (auto params = GetOperandQuantState(op, i).params) {
          params.print(llvm::dbgs());
          auto requantize_state = GetOperandRequantizeState(op, i);
          if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
            llvm::dbgs() << "+";
            requantize_state.params.print(llvm::dbgs());
          }
        } else {
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ") -> (";
      for (int i = 0, e = op->getNumResults(); i < e; ++i) {
        if (auto params = GetResultQuantState(op, i).params) {
          params.print(llvm::dbgs());
          auto requantize_state = GetResultRequantizeState(op, i);
          if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
            llvm::dbgs() << "+";
            requantize_state.params.print(llvm::dbgs());
          }
        } else {
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ")\n";
    });
  }

  FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
  bool disable_per_channel_;

  // We should distinguish weights and bias constants. Biases are specified by
  // the quantization spec or are the operands of ops with the same-scale spec.
  // The rest are weights.
  llvm::DenseSet<Operation *> weights_;

  // The weights requiring narrow_range quantization. This map collects all the
  // weight operands defined by the op quant spec. If the value of an entry is
  // non-negative, per-channel quantization is required and the value is the
  // quantization dimension.
  llvm::DenseMap<Operation *, int> optimized_weights_;

  // All the ops that still need the quantization parameters propagated.
  std::vector<Operation *> work_list_;
  std::unordered_set<Operation *> quantized_;

  // The vector contains all the quantization parameters propagated from the
  // defining operations of the value, or from the quantization-aware training.
  std::vector<QuantState> states_;

  // The map contains all the quantization parameters which are required to
  // satisfy the same operands and results constraint. The keys of this map are
  // the values from `operand_states_` and `result_states_`.
  std::unordered_map<int, RequantizeState> rescale_states_;

  // Maps from the op's operands, results, and arguments to indexes into the
  // propagation state vector.
  llvm::DenseMap<OpValue, int> operand_states_;
  llvm::DenseMap<OpValue, int> result_states_;
  llvm::DenseMap<BlockArgument, int> arg_states_;

  // This vector preserves the argument order, so the newly inserted quantized
  // ops for the arguments are deterministically ordered.
  llvm::SmallVector<BlockArgument, 4> args_;

  OpQuantSpecGetter op_quant_spec_getter_;

  // Infer output ranges for activation ops and constants. This is usually
  // required for post-training quantization.
  bool infer_tensor_range_;

  // Calculate scales in float instead of double, so that the scales and
  // quantized values are exactly the same as those from the TOCO quantizer.
  bool legacy_float_scale_;
};
}  // namespace

std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
  return op_quant_spec_getter_(op);
}

bool QuantizationDriver::IsQuantized(Operation *op) {
  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    if (GetResultQuantState(op, i).IsEmpty()) return false;
  }
  return true;
}

int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
                                        bool as_result) {
  QuantParams params =
      quant::QuantizedType::getQuantizedElementType(val.getType());
  bool immutable = !EmptyParams(params);
  int next_state_index = states_.size();
  states_.push_back({params, immutable});
  if (as_result)
    result_states_.insert({{op, index}, next_state_index});
  else
    operand_states_.insert({{op, index}, next_state_index});

  return next_state_index;
}

bool QuantizationDriver::SetConstantResultParams(Operation *op) {
  DenseFPElementsAttr attr;
  Value res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
  }
  // TODO(fengliuai): make storage_type_width and narrow_range configurable.
  Type final_type;
  auto it = optimized_weights_.find(op);
  bool is_weight = it != optimized_weights_.end();
  bool is_weight_with_per_channel_support =
      is_weight && it->second != -1 && is_signed_;

  if (is_weight_with_per_channel_support && !disable_per_channel_) {
    // When `disable_per_channel_` is false, per-channel symmetric quantization
    // parameters are created from the weights when the ops support per-channel
    // quantization. Otherwise, per-tensor asymmetric quantization with narrow
    // range is used.

    // Per-axis quantization of the weight, with symmetric min/max enforced.
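    // For instance (a made-up sketch): a channel whose weights span
    // [-0.5, 0.4] gets, under signed 8-bit narrow-range symmetric
    // quantization, scale = 0.5 / 127 ~= 3.94e-3 and zero point 0.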
    final_type = GetUniformQuantizedPerAxisTypeForWeight(
        attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_,
        /*narrow_range=*/true, legacy_float_scale_);
  } else {
    // Per-tensor quantization of the weight.
    final_type = GetUniformQuantizedTypeForWeight(
        attr, /*symmetric=*/is_weight && is_signed_,
        /*num_bits=*/8, is_signed_,
        /*narrow_range=*/is_weight, legacy_float_scale_);
  }
  if (auto quant_type = final_type.dyn_cast_or_null<quant::QuantizedType>()) {
    return SetResultParams(op, 0, quant_type);
  }
  return false;
}

bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
                                         QuantParams params) {
  auto &state = GetResultQuantState(op, res_index);
  if (state.params == params) {
    return false;
  }
  if (!state.IsEmpty()) {
    auto &rescale = GetResultRequantizeState(op, res_index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_INPUT;
    return true;
  }
  state.params = params;
  AddUserToList(op, res_index);
  return true;
}

QuantParams QuantizationDriver::GetBiasParams(
    Operation *op, int bias, const std::vector<int> &non_biases,
    AccumulatorScaleFunc func) {
  auto &bias_state = GetOperandQuantState(op, bias);
  if (!bias_state.IsEmpty()) {
    return bias_state.params;
  }
  std::vector<QuantParams> op_types;
  op_types.reserve(non_biases.size());
  for (auto non_bias : non_biases) {
    auto &non_bias_type = GetOperandQuantState(op, non_bias);
    op_types.push_back(non_bias_type.params);
  }
  if (op_types.empty()) return {};
  return func(op_types, legacy_float_scale_);
}

bool QuantizationDriver::SetOperandParams(Operation *op, int index,
                                          QuantParams params) {
  auto &state = GetOperandQuantState(op, index);
  if (state.params == params) {
    return false;
  }

  if (!state.IsEmpty()) {
    auto &rescale = GetOperandRequantizeState(op, index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_OUTPUT;
    return true;
  }

  state.params = params;
  AddOperandToList(op, index);
  return true;
}

void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
                                          QuantParams params) {
  builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
  Value original_result = op->getResult(index);
  QuantizeValue(original_result, params, op->getLoc());
}

void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
  builder_.setInsertionPointToStart(arg.getOwner());
  QuantizeValue(arg, params, builder_.getUnknownLoc());
}

void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
                                       Location loc) {
  Type expressed_type = value.getType();
  Type new_type = params.castFromExpressedType(expressed_type);
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  auto dequantize = builder_.create<quant::DequantizeCastOp>(
      loc, expressed_type, quantize.getResult());

  // This attribute is set to distinguish the quantize ops being added by the
  // quantization pass. These ops can be removed without losing original
  // program accuracy.
  // TODO(fengliuai): make the attribute part of the op definition.
  quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());

  // `value` now has a use in `quantize`, so replaceAllUsesWith would also
  // redirect that use to `dequantize`; reset the quantize op's operand back
  // to the original value afterwards.
  value.replaceAllUsesWith(dequantize);
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}

void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
                                            RequantizeState *state) {
  if (state->pos == RequantizeState::NO_REQUANTIZE) return;
  builder_.setInsertionPointAfter(op);
  Value value = op->getResult(index);
  if (state->pos == RequantizeState::ON_OUTPUT) {
    Operation *user = value.getUses().begin().getUser();
    if (llvm::isa<quant::QuantizeCastOp>(user)) {
      // The requantize op is inserted between `quantize` and `dequantize` ops.
      value = user->getResult(0);
      builder_.setInsertionPointAfter(user);
    }
  }
  RequantizeValue(value, state, op->getLoc());
}

void QuantizationDriver::RequantizeArg(BlockArgument arg,
                                       RequantizeState *state) {
  Value value = arg;
  builder_.setInsertionPointToStart(arg.getOwner());
  if (value.hasOneUse()) {
    auto user = value.use_begin().getUser();
    if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
      value = q.getResult();
      builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
    }
  }
  RequantizeValue(value, state, builder_.getUnknownLoc());
}

void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
                                         Location loc) {
  Type new_type;
  if (state->pos == RequantizeState::ON_INPUT) {
    Type expressed_type = value.getType();
    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    new_type = state->params.castFromExpressedType(expressed_type);
  } else {
    Type expressed_type =
        quant::QuantizedType::castToExpressedType(value.getType());
    if (!expressed_type) return;

    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    new_type = state->params.castFromExpressedType(expressed_type);
  }
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  auto requantize_op =
      builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  value.replaceAllUsesWith(requantize_op);
  requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}

// A heuristic to get a quantization parameter that satisfies the same-scale
// constraints:
// - If there are immutable states,
//   - use the single input, or,
//   - use the single output, or,
//   - use the first one in the collection;
// - otherwise,
//   - use the single input if it is ready, or,
//   - use the single output if it is ready, or,
//   - use the first ready one in the collection.
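//
// For example (a sketch): for a same-scale op such as tfl.concatenation with
// operand states {s1 (immutable), s2 (mutable)} and an empty result state,
// s1 is returned; the driver then sets s1 on the second operand and on the
// result, recording a requantize for the operand that already carried s2.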
QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
    Operation *op) {
  // Two vectors to collect the non-empty operand and result states.
  std::vector<QuantState *> mutable_states, immutable_states;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    auto &state = GetOperandQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_operands_num = immutable_states.size();
  int mutable_operands_num = mutable_states.size();
  // Use the operand's state if it is immutable and it is the only operand.
  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
    return immutable_states.front()->params;
  }

  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    auto &state = GetResultQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_results_num = immutable_states.size() - immutable_operands_num;
  int mutable_results_num = mutable_states.size() - mutable_operands_num;
  // Use the result's state if it is immutable and it is the only result.
  if (op->getNumResults() == 1 && immutable_results_num == 1) {
    return immutable_states.back()->params;
  }

  // Use the first immutable state to quantize the rest of the operands and
  // results.
  if (!immutable_states.empty()) return immutable_states.front()->params;

  // If there are no immutable states, use the operand's state if it is the
  // only operand and has parameters propagated.
  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
    return mutable_states.front()->params;
  }

  // If there are no immutable states, use the result's state if it is the
  // only result and has parameters propagated.
  if (op->getNumResults() == 1 && mutable_results_num == 1) {
    return mutable_states.back()->params;
  }

  // Use the first propagated state to quantize the rest of the operands and
  // results.
  if (!mutable_states.empty()) return mutable_states.front()->params;

  // No operands or results have parameters propagated; skip this node for now.
  return {};
}

void QuantizationDriver::PreprocessConstantOps() {
  fn_.walk([&](ConstantOp cst) {
    // Non-float tensors are neither weights nor do they require quantization.
    auto type = cst.getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return;

    Value value = cst.getResult();
    builder_.setInsertionPoint(cst);

    // The following loop will change the value's uses, so we cache all the
    // uses that need to be changed.
    llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
    for (auto &use : value.getUses()) {
      uses.push_back({use.getOwner(), use.getOperandNumber()});
    }
    for (auto indexed_use : llvm::enumerate(uses)) {
      Operation *user = indexed_use.value().first;
      int operand_num = indexed_use.value().second;

      auto spec = GetQuantSpec(user);
      auto biases = spec->biases_params;

      // The quantization parameters of a `weight` shouldn't be determined by
      // other values. So any constant that is not a bias, not an operand of
      // an op with same-scale requirements, and not already quantized is a
      // weight.
      if (biases.find(operand_num) == biases.end() &&
          !llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
          !llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        // Needs to scan the content of weights to get the quantization
        // parameters if there are no quantization parameters (FakeQuant ops).
        // For this case, the weight will not be duplicated.
        weights_.insert(cst);
        auto affine_user =
            llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
        if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
            affine_user.RequiredNarrowRangeAffineOperand()) {
          optimized_weights_.insert(
              {cst, affine_user.GetQuantizationDimIndex()});
        }
      } else {
        // This is a bias or an operand of an op with same-scale requirements,
        // so the quantization parameters are propagated from or determined by
        // other values. Duplicate this constant in case it is shared by
        // different users.
        if (uses.size() > 1) {
          auto new_cst =
              builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
          user->setOperand(operand_num, new_cst);
        }
      }
    }
  });
}

void QuantizationDriver::SetupAllStates() {
  llvm::DenseMap<Value, int> value_to_state;

  for (auto arg : fn_.getArguments()) {
    args_.push_back(arg);
    Value value = arg;
    // If the argument is quantized, it should only have one user.
    if (arg.hasOneUse()) {
      auto user = value.use_begin().getUser();
      if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        value = q.getResult();
      }
    }
    InitializeArgState(arg, value, &value_to_state);
  }

  fn_.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::IsTerminator>() ||
        op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
      return;
    work_list_.push_back(op);

    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      auto operand = op->getOperand(i);
      if (auto *inst = operand.getDefiningOp()) {
        // If the operand comes from a tfl.dequantize op, we use the quantized
        // input of this tfl.dequantize op to set the state.
        if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
          operand = dq.arg();
        }
      }
      InitializeOperandState(op, i, operand, &value_to_state);
    }

    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
      Value result = op->getResult(res);
      // If the result has been quantized, it should only be used by a
      // tfl.quantize op. For this case, we use the quantized result to
      // create the state and mark it immutable.
      if (result.hasOneUse()) {
        auto user = result.use_begin().getUser();
        if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
          result = q.getResult();
        }
      }
      InitializeResultState(op, res, result, &value_to_state);
    }
  });
}

// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
  // Duplicate the bias constant, so the states can be set up correctly.
  // TODO(fengliuai): Function definition should also be duplicated if there
  // are multiple call sites.
  PreprocessConstantOps();

  // Set up all the internal states.
  SetupAllStates();
}

bool QuantizationDriver::PropagateParams() {
  // TODO(fengliuai): use a typed indicator instead of a bool value.
  bool changed = false;
  while (!work_list_.empty()) {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);

    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
      // If the workflow requires inferring ranges from the content
      // (post-training quantization), the constant is a weight (filter), and
      // it hasn't been quantized, we infer the quantization parameters from
      // the content.
      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
        // The quantization parameters are determined by the content of the
        // constant.
        changed |= SetConstantResultParams(op);
      }
      continue;
    }

    if (llvm::isa<SameScalesOpInterface>(op)) {
      auto params = GetQuantParamsForSameScaleConstraint(op);
      // The quantization parameters haven't been propagated to any operands
      // or results. Skip this node for now.
      if (!params) {
        quantized_.erase(op);
        continue;
      }

      // Use the final state to set all the operands' parameters.
      for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
        if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
          // Without this check, the quantization information would
          // accidentally be propagated through shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetOperandParams(op, i, params);
        }
      }

      // Use the final state to set all the results' parameters.
      for (int res = 0, e = op->getNumResults(); res != e; ++res)
        if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
          // Without this check, the quantization information would
          // accidentally be propagated through shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetResultParams(op, res, params);
        }
    }

    // TODO(fengliuai): make the bit width configurable.
    auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
    if (restricted && infer_tensor_range_) {
      // Infer ranges from the activation ops. This is usually required for
      // the post-training quantization workflow.
      // TODO(fengliuai): different results can have different fixed ranges.
      auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
      for (int i = 0, e = op->getNumResults(); i != e; ++i) {
        // The range is null if the result has been quantized.
        if (params) {
          changed |= SetResultParams(op, i, params);
        }
      }
    }

    auto spec = GetQuantSpec(op);
    for (auto &it : spec->biases_params) {
      auto params =
          GetBiasParams(op, it.first, it.second.first, it.second.second);
      if (!params) {
        quantized_.erase(op);
        continue;
      }
      changed |= SetOperandParams(op, it.first, params);
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "\n\n\n");
  LLVM_DEBUG(DumpStates(nullptr));

  return changed;
}

void QuantizationDriver::Finalize() {
  for (auto arg : args_) {
    auto &state = GetArgQuantState(arg);
    auto &requantize = GetArgRequantizeState(arg);
    if (state.IsEmpty() ||
        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
      continue;
    }

    if (!state.immutable) {
      QuantizeArg(arg, state.params);
    }

    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
      RequantizeArg(arg, &requantize);
    }
  }

  for (auto it : result_states_) {
    Operation *op = it.first.first;
    int res_index = it.first.second;
    auto &state = GetResultQuantState(op, res_index);
    auto &requantize = GetResultRequantizeState(op, res_index);
    if (state.IsEmpty() ||
        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
      continue;
    }

    if (!state.immutable) {
      QuantizeOpResult(op, res_index, state.params);
    }

    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
      RequantizeOpResult(op, res_index, &requantize);
    }
  }
}

void QuantizationDriver::Run() {
  Initialize();
  if (PropagateParams()) {
    Finalize();
  }
}

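// Runs the quantization parameter propagation on `func`. A typical call site
// looks like the following sketch, where `GetOpQuantSpec` stands for a
// caller-provided OpQuantSpecGetter and is not defined in this file:
//
//   ApplyQuantizationParamsPropagation(
//       func, /*is_signed=*/true, /*disable_per_channel=*/false,
//       GetOpQuantSpec, /*infer_tensor_ranges=*/true,
//       /*legacy_float_scale=*/false);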
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale) {
  QuantizationDriver(func, is_signed, disable_per_channel,
                     op_quant_spec_getter, infer_tensor_ranges,
                     legacy_float_scale)
      .Run();
}

}  // namespace quant
}  // namespace mlir