/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "quantization-driver"

namespace mlir {
namespace quant {
namespace {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag indicating that this state (the params) shouldn't be changed after
  // it is initialized. This flag will be set to true if the quantization
  // parameters are from quantization-aware training.
  const bool immutable;

  bool IsEmpty() { return EmptyParams(params); }
};

// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or output
  // of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // Quantization parameters will be used to add the requantize ops.
  QuantParams params;
};
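// For illustration (a hypothetical sketch, not IR copied from a real run): if
// a result already carries scale 0.5 but a consumer requires scale 0.25, an
// ON_OUTPUT requantize materializes as a second cast,
//
//   %q  = "quant.qcast"(%0) : (tensor<4xf32>)
//           -> tensor<4x!quant.uniform<i8:f32, 0.5>>
//   %rq = "quant.qcast"(%q) : (tensor<4x!quant.uniform<i8:f32, 0.5>>)
//           -> tensor<4x!quant.uniform<i8:f32, 0.25>>
//
// so the consumer observes the scale it expects.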

// This is a worklist-driven driver for propagating quantization parameters
// across operations.
//
// The initial quantization parameters are extracted from the quantized type
// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
// parameters are marked as immutable because they come from quantization-aware
// training.
//
// The algorithm traverses each op and sets the quantization parameters of its
// operands and results according to its quantization specification, and then
// adds the operands and results to the worklist. If there are any conflicts
// (for example, quantization parameters propagated from a previous iteration),
// this process stops when the existing parameters are immutable; otherwise a
// `requantize` op is added to resolve the conflicts.
//
// After the algorithm has converged, pairs of tfl.quantize and tfl.dequantize
// ops are inserted at the right positions to materialize the propagation and
// requantize results.
//
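// For example (a hypothetical sketch of the seeding pattern): quantization-
// aware training leaves pairs such as
//
//   %q  = "quant.qcast"(%w) : (tensor<4xf32>)
//           -> tensor<4x!quant.uniform<u8:f32, 0.5:128>>
//   %dq = "quant.dcast"(%q) : (tensor<4x!quant.uniform<u8:f32, 0.5:128>>)
//           -> tensor<4xf32>
//
// and the quantized element type of %q seeds an immutable initial state.
//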
class QuantizationDriver {
 public:
  explicit QuantizationDriver(FuncOp fn, bool is_signed,
                              bool disable_per_channel,
                              OpQuantSpecGetter op_quant_spec_getter,
                              bool infer_tensor_range, bool legacy_float_scale)
      : fn_(fn),
        builder_(fn.getBody()),
        is_signed_(is_signed),
        disable_per_channel_(disable_per_channel),
        op_quant_spec_getter_(op_quant_spec_getter),
        infer_tensor_range_(infer_tensor_range),
        legacy_float_scale_(legacy_float_scale) {}

  // The entry point of the quantization parameters propagation.
  void Run();

 private:
  // This is used to identify an operand or result of an op. The second element
  // of this pair is the index of the operand or result.
  using OpValue = std::pair<mlir::Operation *, int>;

  // Sets up the states for all the op results in the function.
  void Initialize();

  // Propagates the quantization parameters across all the ops.
  bool PropagateParams();

  // Inserts the Quantize and Dequantize ops according to the propagation
  // result.
  void Finalize();

  // The quantization parameters of a bias operand are usually determined by
  // other operands, so if a constant is used by different ops as a bias, it
  // needs to be duplicated so that each op can assign its own quantization
  // parameters to its copy. This method also adds all the non-bias constants
  // (weights), as well as all the per-channel weights, to sets for later
  // lookup.
  void PreprocessConstantOps();

  // Sets up all the data structures for quantization propagation.
  void SetupAllStates();

  // Whether the constant is a weight, which shouldn't be shared by different
  // ops.
  bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); }

  // Returns all the related quantization constraints of the op.
  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);

  // Whether quantization parameters have been propagated to the results of
  // this op.
  bool IsQuantized(Operation *op);

  // Adds all the users of the index-th result of op to the work list.
  void AddUserToList(Operation *op, int index) {
    for (auto *user : op->getResult(index).getUsers()) {
      work_list_.push_back(user);
    }
  }

  // Adds the defining op of the index-th operand of op to the work list.
  void AddOperandToList(Operation *op, int index) {
    if (auto *inst = op->getOperand(index).getDefiningOp()) {
      work_list_.push_back(inst);
    }
  }

  // Returns the quantization params for the bias input from the non-bias
  // operands whose indexes are in the `non_biases` vector. The returned
  // parameters are calculated by `func`.
  QuantParams GetBiasParams(Operation *op, int bias,
                            const std::vector<int> &non_biases,
                            AccumulatorScaleFunc func);

  // Sets the quantization parameters of the result to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the input side of the propagated quantization.
  bool SetResultParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the operand to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the output side of the propagated quantization.
  bool SetOperandParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the constant result according to its
  // content.
  bool SetConstantResultParams(Operation *op);

  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
  // of the op.
  void QuantizeOpResult(Operation *op, int index, QuantParams params);

  void QuantizeArg(BlockArgument arg, QuantParams params);

  // Inserts the Quantize and Dequantize ops to quantize the value and returns
  // the Quantize op.
  void QuantizeValue(Value value, QuantParams params, Location loc);

  // Inserts the Quantize ops for requantizing the index-th result of the op.
  void RequantizeOpResult(Operation *op, int index, RequantizeState *state);

  void RequantizeArg(BlockArgument arg, RequantizeState *state);

  // Inserts the Quantize and Dequantize ops to requantize the value and
  // returns the Quantize op.
  void RequantizeValue(Value value, RequantizeState *state, Location loc);

  // A heuristic to get a quantization parameter that satisfies the same-scale
  // constraints for the op. Returns an empty value if no such quantization
  // parameter exists.
  QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);

  // Returns the state of the index-th operand of the op.
  QuantState &GetOperandQuantState(Operation *op, int index) {
    return states_[operand_states_[{op, index}]];
  }

  // Returns the state of the index-th result of the op.
  QuantState &GetResultQuantState(Operation *op, int index) {
    return states_[result_states_[{op, index}]];
  }

  QuantState &GetArgQuantState(BlockArgument arg) {
    return states_[arg_states_[arg]];
  }

  // Returns the requantize state of the index-th operand of the op.
  RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
    return rescale_states_[operand_states_[{op, index}]];
  }

  // Returns the requantize state of the index-th result of the op.
  RequantizeState &GetResultRequantizeState(Operation *op, int index) {
    return rescale_states_[result_states_[{op, index}]];
  }

  RequantizeState &GetArgRequantizeState(BlockArgument arg) {
    return rescale_states_[arg_states_[arg]];
  }

  // Uses the type of `val` to set the initial state of the index-th result if
  // `as_result` is true, or of the index-th operand if `as_result` is false.
  // The state is immutable if the type is a quantized type. Returns the index
  // of this new state in the state vector.
  int InitializeState(Operation *op, int index, Value val, bool as_result);

  // Sets the state of an argument. If this value is cached, uses the cached
  // result without creating a new entry in the state vector. Otherwise,
  // allocates a new entry in the state vector.
  void InitializeArgState(BlockArgument arg, Value in,
                          llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({in, 0});
    if (!cached.second) {
      arg_states_[arg] = cached.first->second;
      return;
    }
    QuantParams params =
        quant::QuantizedType::getQuantizedElementType(in.getType());
    bool immutable = !EmptyParams(params);
    int next_state_index = states_.size();
    states_.push_back({params, immutable});
    arg_states_[arg] = next_state_index;
    cached.first->second = next_state_index;
  }

  // Sets the state of the index-th operand of the op. If this operand is
  // cached, uses the cached result without creating a new entry in the state
  // vector. Otherwise, allocates a new entry in the state vector.
  void InitializeOperandState(Operation *op, int index, Value in,
                              llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({in, 0});
    if (!cached.second) {
      operand_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
  }

  // Sets the state of the index-th result of the op. If this result is cached,
  // uses the cached result without creating a new entry in the state vector.
  // Otherwise, allocates a new entry in the state vector.
  void InitializeResultState(Operation *op, int index, Value res,
                             llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({res, 0});
    if (!cached.second) {
      result_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::dbgs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      if (op->hasTrait<OpTrait::IsTerminator>() ||
          op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
          llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp, ConstantOp>(
              op))
        return;
      if (current_op == op) llvm::dbgs() << "===>>>";
      llvm::dbgs() << op->getName() << " : (";
      if (llvm::isa<FuncOp>(op)) {
        for (auto &arg : fn_.getArguments()) {
          if (auto params = GetArgQuantState(arg).params) {
            params.print(llvm::dbgs());
            auto requantize_state = GetArgRequantizeState(arg);
            if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
              llvm::dbgs() << "+";
              requantize_state.params.print(llvm::dbgs());
            }
          }
          llvm::dbgs() << ",";
        }
      }
      for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
        if (auto params = GetOperandQuantState(op, i).params) {
          params.print(llvm::dbgs());
          auto requantize_state = GetOperandRequantizeState(op, i);
          if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
            llvm::dbgs() << "+";
            requantize_state.params.print(llvm::dbgs());
          }
        } else {
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ") -> (";
      for (int i = 0, e = op->getNumResults(); i < e; ++i) {
        if (auto params = GetResultQuantState(op, i).params) {
          params.print(llvm::dbgs());
          auto requantize_state = GetResultRequantizeState(op, i);
          if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
            llvm::dbgs() << "+";
            requantize_state.params.print(llvm::dbgs());
          }
        } else {
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ")\n";
    });
  }

  FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
  bool disable_per_channel_;

  // We should distinguish weights and bias constants. Biases are specified by
  // the quantization spec or are the operands of ops with the same-scale spec.
  // The rest are weights.
  llvm::DenseSet<Operation *> weights_;

  // The weights that require narrow_range quantization. This map collects all
  // the weight operands defined by the op quant spec. If the value of an entry
  // is positive, per-channel quantization is required.
  llvm::DenseMap<Operation *, int> optimized_weights_;

  // All the ops that quantization parameters need to be propagated to.
  std::vector<Operation *> work_list_;
  std::unordered_set<Operation *> quantized_;

  // The vector contains all the quantization parameters propagated from the
  // defining operations of the value, or from quantization-aware training.
  std::vector<QuantState> states_;

  // The map contains all the quantization parameters which are required to
  // satisfy the same operands and results constraint. The keys of this map are
  // the values from `operand_states_` and `result_states_`.
  std::unordered_map<int, RequantizeState> rescale_states_;

  // Maps of indexes to the propagation state vector from the op's operands,
  // results, and arguments.
  llvm::DenseMap<OpValue, int> operand_states_;
  llvm::DenseMap<OpValue, int> result_states_;
  llvm::DenseMap<BlockArgument, int> arg_states_;

  // This vector is to preserve the argument order, so the newly inserted
  // quantized ops for the arguments are deterministically ordered.
  llvm::SmallVector<BlockArgument, 4> args_;

  OpQuantSpecGetter op_quant_spec_getter_;

  // Infer output ranges for activation ops and constants. This is usually
  // required for post-training quantization.
  bool infer_tensor_range_;

  // Calculate scales in float instead of double, so that the scales and
  // quantized values are exactly the same as with the TOCO quantizer.
  bool legacy_float_scale_;
};
}  // namespace

std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
  return op_quant_spec_getter_(op);
}

bool QuantizationDriver::IsQuantized(Operation *op) {
  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    if (GetResultQuantState(op, i).IsEmpty()) return false;
  }
  return true;
}

int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
                                        bool as_result) {
  QuantParams params =
      quant::QuantizedType::getQuantizedElementType(val.getType());
  bool immutable = !EmptyParams(params);
  int next_state_index = states_.size();
  states_.push_back({params, immutable});
  if (as_result)
    result_states_.insert({{op, index}, next_state_index});
  else
    operand_states_.insert({{op, index}, next_state_index});

  return next_state_index;
}

bool QuantizationDriver::SetConstantResultParams(Operation *op) {
  DenseFPElementsAttr attr;
  Value res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
  }
  // TODO(fengliuai): make storage_type_width and narrow_range configurable.
  Type final_type;
  auto it = optimized_weights_.find(op);
  bool is_weight = it != optimized_weights_.end();
  bool is_weight_with_per_channel_support =
      is_weight && it->second != -1 && is_signed_;

  if (is_weight_with_per_channel_support && !disable_per_channel_) {
    // When `disable_per_channel_` is false, per-channel symmetric quantization
    // parameters are created from the weights when the ops support per-channel
    // quantization. Otherwise, per-tensor asymmetric quantization with narrow
    // range is used.

    // Per-axis quantization weight, with symmetric min/max enforced.
    final_type = GetUniformQuantizedPerAxisTypeForWeight(
        attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_,
        /*narrow_range=*/true, legacy_float_scale_);
  } else {
    // Per-tensor quantization weight.
    final_type = GetUniformQuantizedTypeForWeight(
        attr, /*symmetric=*/is_weight && is_signed_,
        /*num_bits=*/8, is_signed_,
        /*narrow_range=*/is_weight, legacy_float_scale_);
  }
  if (auto quant_type = final_type.dyn_cast_or_null<quant::QuantizedType>()) {
    return SetResultParams(op, 0, quant_type);
  }
  return false;
}
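// For illustration (a hypothetical sketch of the resulting element types): a
// filter constant quantized per-axis along dimension 0 with two channels
// could get a type such as
//   !quant.uniform<i8<-127:127>:f32:0, {0.01, 0.02}>
// while the per-tensor fallback produces something like
//   !quant.uniform<u8:f32, 0.02:128>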

bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
                                         QuantParams params) {
  auto &state = GetResultQuantState(op, res_index);
  if (state.params == params) {
    return false;
  }
  if (!state.IsEmpty()) {
    auto &rescale = GetResultRequantizeState(op, res_index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_INPUT;
    return true;
  }
  state.params = params;
  AddUserToList(op, res_index);
  return true;
}

QuantParams QuantizationDriver::GetBiasParams(
    Operation *op, int bias, const std::vector<int> &non_biases,
    AccumulatorScaleFunc func) {
  auto &bias_state = GetOperandQuantState(op, bias);
  if (!bias_state.IsEmpty()) {
    return bias_state.params;
  }
  std::vector<QuantParams> op_types;
  op_types.reserve(non_biases.size());
  for (auto non_bias : non_biases) {
    auto &non_bias_type = GetOperandQuantState(op, non_bias);
    op_types.push_back(non_bias_type.params);
  }
  if (op_types.empty()) return {};
  return func(op_types, legacy_float_scale_);
}
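// Note: for typical affine ops (e.g., convolution), the accumulator scale
// used for the int32 bias is the product of the input scale and the filter
// scale, so the bias adds exactly into the accumulator. The concrete rule is
// supplied by the `AccumulatorScaleFunc` from the op's quantization spec.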

bool QuantizationDriver::SetOperandParams(Operation *op, int index,
                                          QuantParams params) {
  auto &state = GetOperandQuantState(op, index);
  if (state.params == params) {
    return false;
  }

  if (!state.IsEmpty()) {
    auto &rescale = GetOperandRequantizeState(op, index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_OUTPUT;
    return true;
  }

  state.params = params;
  AddOperandToList(op, index);
  return true;
}

void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
                                          QuantParams params) {
  builder_.setInsertionPointAfter(op);
  Value original_result = op->getResult(index);
  QuantizeValue(original_result, params, op->getLoc());
}

void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
  builder_.setInsertionPointToStart(arg.getOwner());
  QuantizeValue(arg, params, builder_.getUnknownLoc());
}

void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
                                       Location loc) {
  Type expressed_type = value.getType();
  Type new_type = params.castFromExpressedType(expressed_type);
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  auto dequantize = builder_.create<quant::DequantizeCastOp>(
      loc, expressed_type, quantize.getResult());

  // This attribute is set to distinguish the quantize ops being added by the
  // quantization pass. These ops can be removed without losing original
  // program accuracy.
  // TODO(fengliuai): make the attribute part of the op definition.
  quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());

  // `value` now has a use on `quantize`, so replace all its uses with the
  // result of `dequantize`, then reset the operand of `quantize` back to
  // `value`.
  value.replaceAllUsesWith(dequantize);
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
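// For illustration (a hypothetical sketch): after QuantizeValue runs on a
// float value %v with scale 0.5 and zero point 128, the IR contains
//
//   %q  = "quant.qcast"(%v) : (tensor<4xf32>)
//           -> tensor<4x!quant.uniform<u8:f32, 0.5:128>>
//   %dq = "quant.dcast"(%q) : (tensor<4x!quant.uniform<u8:f32, 0.5:128>>)
//           -> tensor<4xf32>
//
// and every original use of %v now consumes %dq.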

void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
                                            RequantizeState *state) {
  if (state->pos == RequantizeState::NO_REQUANTIZE) return;
  builder_.setInsertionPointAfter(op);
  Value value = op->getResult(index);
  if (state->pos == RequantizeState::ON_OUTPUT) {
    Operation *user = value.getUses().begin().getUser();
    if (llvm::isa<quant::QuantizeCastOp>(user)) {
      // The requantize op is inserted between `quantize` and `dequantize` ops.
      value = user->getResult(0);
      builder_.setInsertionPointAfter(user);
    }
  }
  RequantizeValue(value, state, op->getLoc());
}

void QuantizationDriver::RequantizeArg(BlockArgument arg,
                                       RequantizeState *state) {
  Value value = arg;
  builder_.setInsertionPointToStart(arg.getOwner());
  if (value.hasOneUse()) {
    auto user = value.use_begin().getUser();
    if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
      value = q.getResult();
      builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
    }
  }
  RequantizeValue(value, state, builder_.getUnknownLoc());
}

void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
                                         Location loc) {
  Type new_type;
  if (state->pos == RequantizeState::ON_INPUT) {
    Type expressed_type = value.getType();
    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    new_type = state->params.castFromExpressedType(expressed_type);
  } else {
    Type expressed_type =
        quant::QuantizedType::castToExpressedType(value.getType());
    if (!expressed_type) return;

    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    new_type = state->params.castFromExpressedType(expressed_type);
  }
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  auto requantize_op =
      builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  value.replaceAllUsesWith(requantize_op);
  requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}

// A heuristic to get a quantization parameter that satisfies the same-scale
// constraints:
// - If there are immutable states,
//   - use the single input, or,
//   - use the single output, or,
//   - use the first one in the collection,
// - use the single input if it is ready, or,
// - use the single output if it is ready, or,
// - use the first ready one in the collection.
QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
    Operation *op) {
  // Two vectors to collect the non-empty operand and result states.
  std::vector<QuantState *> mutable_states, immutable_states;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    auto &state = GetOperandQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_operands_num = immutable_states.size();
  int mutable_operands_num = mutable_states.size();
  // Use the operand's state if it is immutable and it is the only operand.
  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
    return immutable_states.front()->params;
  }

  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    auto &state = GetResultQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_results_num = immutable_states.size() - immutable_operands_num;
  int mutable_results_num = mutable_states.size() - mutable_operands_num;
  // Use the result's state if it is immutable and it is the only result.
  if (op->getNumResults() == 1 && immutable_results_num == 1) {
    return immutable_states.back()->params;
  }

  // Use the first immutable state to quantize the rest of the operands and
  // results.
  if (!immutable_states.empty()) return immutable_states.front()->params;

  // If there are no immutable states, use the operand's state if it is the
  // only operand and has parameters propagated.
  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
    return mutable_states.front()->params;
  }

  // If there are no immutable states, use the result's state if it is the
  // only result and has parameters propagated.
  if (op->getNumResults() == 1 && mutable_results_num == 1) {
    return mutable_states.back()->params;
  }

  // Use the first propagated state to quantize the rest of the operands and
  // results.
  if (!mutable_states.empty()) return mutable_states.front()->params;

  // No operands or results have parameters propagated; skip this node for now.
  return {};
}
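// For example: for a same-scale op such as a concatenation whose operands
// carry scales {0.5 (immutable), 0.3}, the heuristic returns 0.5; every other
// operand and result is then set to that scale, with requantize ops recorded
// wherever a conflicting scale had already been propagated.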

void QuantizationDriver::PreprocessConstantOps() {
  fn_.walk([&](ConstantOp cst) {
    // Non-float tensors are not weights and do not require quantization.
    auto type = cst.getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return;

    Value value = cst.getResult();
    builder_.setInsertionPoint(cst);

    // The following loop will change the value's uses, so we cache all the
    // uses that need to be changed.
    llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
    for (auto &use : value.getUses()) {
      uses.push_back({use.getOwner(), use.getOperandNumber()});
    }
    for (auto indexed_use : llvm::enumerate(uses)) {
      Operation *user = indexed_use.value().first;
      int operand_num = indexed_use.value().second;

      auto spec = GetQuantSpec(user);
      auto biases = spec->biases_params;

      // The quantization parameters of a `weight` shouldn't be determined by
      // other values. So any constant that is not a bias, not an operand of an
      // op with same-scale requirements, and not already quantized is a
      // weight.
      if (biases.find(operand_num) == biases.end() &&
          !llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
          !llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        // Needs to scan the content of weights to get the quantization
        // parameters if there are no quantization parameters (FakeQuant ops).
        // For this case, the weight will not be duplicated.
        weights_.insert(cst);
        auto affine_user =
            llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
        if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
            affine_user.RequiredNarrowRangeAffineOperand()) {
          optimized_weights_.insert(
              {cst, affine_user.GetQuantizationDimIndex()});
        }
      } else {
        // This is a bias or an operand of an op with same-scale requirements,
        // so the quantization parameters are propagated from or determined by
        // other values. Duplicate this constant in case it is shared by
        // different users.
        if (uses.size() > 1) {
          auto new_cst =
              builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
          user->setOperand(operand_num, new_cst);
        }
      }
    }
  });
}
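// For illustration (a hypothetical sketch): if one float constant feeds the
// bias operand of two fully-connected ops, the loop above clones it so each
// op can later receive bias parameters derived from its own input and filter
// scales, e.g.
//   %bias0 = constant dense<...> : tensor<16xf32>  // used by the first op
//   %bias1 = constant dense<...> : tensor<16xf32>  // clone for the second op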

void QuantizationDriver::SetupAllStates() {
  llvm::DenseMap<Value, int> value_to_state;

  for (auto arg : fn_.getArguments()) {
    args_.push_back(arg);
    Value value = arg;
    // If the argument is quantized, it should only have one user.
    if (arg.hasOneUse()) {
      auto user = value.use_begin().getUser();
      if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        value = q.getResult();
      }
    }
    InitializeArgState(arg, value, &value_to_state);
  }

  fn_.walk([&](Operation *op) {
    if (IsOpNotQuantizable(op)) {
      return;
    }
    work_list_.push_back(op);

    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      auto operand = op->getOperand(i);
      if (auto *inst = operand.getDefiningOp()) {
        // If the operand comes from a tfl.dequantize op, we use the quantized
        // input of this tfl.dequantize op to set the state.
        if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
          operand = dq.arg();
        }
      }
      InitializeOperandState(op, i, operand, &value_to_state);
    }

    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
      Value result = op->getResult(res);
      // If the result has been quantized, it should only be used by a
      // tfl.quantize op. For this case, we use the quantized result to
      // create the state and mark it immutable.
      if (result.hasOneUse()) {
        auto user = result.use_begin().getUser();
        if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
          result = q.getResult();
        }
      }
      InitializeResultState(op, res, result, &value_to_state);
    }
  });
}

// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
  // Duplicate the bias constants, so the states can be set up correctly.
  // TODO(fengliuai): Function definitions should also be duplicated if there
  // are multiple call sites.
  PreprocessConstantOps();

  // Set up all the internal states.
  SetupAllStates();
}

bool QuantizationDriver::PropagateParams() {
  // TODO(fengliuai): use a typed indicator instead of a bool value.
  bool changed = false;
  while (!work_list_.empty()) {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);

    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
      // If the workflow requires inferring ranges from the content
      // (post-training quantization), and this is a weight (filter) that
      // hasn't been quantized, we infer the quantization parameters from the
      // content.
      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
        // The quantization parameters are determined by the content of the
        // constant.
        changed |= SetConstantResultParams(op);
      }
      continue;
    }

    if (llvm::isa<SameScalesOpInterface>(op)) {
      auto params = GetQuantParamsForSameScaleConstraint(op);
      // The quantization parameters haven't been propagated to any operands
      // or results. Skip this node for now.
      if (!params) {
        quantized_.erase(op);
        continue;
      }

      // Use the final state to set all the operands' parameters.
      for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
        if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
          // Without this check, the quantization information would
          // accidentally be propagated through shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetOperandParams(op, i, params);
        }
      }

      // Use the final state to set all the results' parameters.
      for (int res = 0, e = op->getNumResults(); res != e; ++res)
        if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
          // Without this check, the quantization information would
          // accidentally be propagated through shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetResultParams(op, res, params);
        }
    }

    // TODO(fengliuai): make the bit width configurable.
    auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
    if (restricted && infer_tensor_range_) {
      // Infer ranges from the activation ops. This is usually required for
      // the post-training quantization workflow.
      // TODO(fengliuai): different results can have different fixed ranges.
      auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
      for (auto i = 0; i < op->getNumResults(); ++i) {
        // The range is null if the result has been quantized.
        if (params) {
          changed |= SetResultParams(op, i, params);
        }
      }
    }

    auto spec = GetQuantSpec(op);
    for (auto &it : spec->biases_params) {
      auto params =
          GetBiasParams(op, it.first, it.second.first, it.second.second);
      if (!params) {
        quantized_.erase(op);
        continue;
      }
      changed |= SetOperandParams(op, it.first, params);
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "\n\n\n");
  LLVM_DEBUG(DumpStates(nullptr));

  return changed;
}

void QuantizationDriver::Finalize() {
  for (auto arg : args_) {
    auto &state = GetArgQuantState(arg);
    auto &requantize = GetArgRequantizeState(arg);
    if (state.IsEmpty() ||
        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
      continue;
    }

    if (!state.immutable) {
      QuantizeArg(arg, state.params);
    }

    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
      RequantizeArg(arg, &requantize);
    }
  }

  for (auto it : result_states_) {
    Operation *op = it.first.first;
    int res_index = it.first.second;
    auto &state = GetResultQuantState(op, res_index);
    auto &requantize = GetResultRequantizeState(op, res_index);
    if (state.IsEmpty() ||
        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
      continue;
    }

    if (!state.immutable) {
      QuantizeOpResult(op, res_index, state.params);
    }

    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
      RequantizeOpResult(op, res_index, &requantize);
    }
  }
}

void QuantizationDriver::Run() {
  Initialize();
  if (PropagateParams()) {
    Finalize();
  }
}

void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale) {
  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
                     infer_tensor_ranges, legacy_float_scale)
      .Run();
}
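// Example usage (a sketch; `MyPropagatePass` and `GetMyOpQuantSpec` are
// hypothetical names, not part of this file):
//
//   void MyPropagatePass::runOnFunction() {
//     ApplyQuantizationParamsPropagation(
//         getFunction(), /*is_signed=*/true, /*disable_per_channel=*/false,
//         GetMyOpQuantSpec, /*infer_tensor_ranges=*/true,
//         /*legacy_float_scale=*/false);
//   }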

}  // namespace quant
}  // namespace mlir