/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "quantization-driver"

namespace mlir {
namespace quant {
namespace {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }

// The state for each op result during the quantization parameters propagation.
struct QuantState {
  // Quantization parameters propagated to an op result.
  QuantParams params;
  // A flag indicating that this state (the params) shouldn't be changed after
  // it is initialized. This flag is set to true if the quantization parameters
  // come from quantization-aware training.
  const bool immutable;

  bool IsEmpty() { return EmptyParams(params); }
};

// The state for rescaling the propagated quantization parameters. This can be
// on the input side to satisfy the constraint of the previous operation, or on
// the output side to satisfy the constraint of the next operation.
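// As a concrete scenario (described, not verbatim from the code below): if a
// result already carries propagated parameters and a later constraint demands
// different ones, the existing state is kept and the new parameters are
// recorded here, so that Finalize() can materialize an extra rescaling op
// instead of overwriting the state.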
struct RequantizeState {
  // Sometimes, we have to "requantize" the quantization result to satisfy all
  // the constraints. The "requantize" can happen either on the input or output
  // of the quantization result.
  enum RequantizePosition {
    NO_REQUANTIZE,
    ON_INPUT,
    ON_OUTPUT
  } pos = NO_REQUANTIZE;

  // The quantization parameters used to add the requantize ops.
  QuantParams params;
};

// This is a worklist-driven driver for propagating quantization parameters
// across operations.
//
// The initial quantization parameters are extracted from the quantized types
// between adjacent tfl.quantize and tfl.dequantize ops. All these initial
// parameters are marked as immutable because they come from quantization-aware
// training.
//
// The algorithm traverses each op and sets the quantization parameters of its
// operands and results, according to its quantization specification, and then
// adds the operands and results to the worklist. If there is a conflict (for
// example, quantization parameters were already propagated in a previous
// iteration), this process either stops, if the existing parameters are
// immutable, or adds a `requantize` op to resolve the conflict.
//
// After the algorithm converges, pairs of tfl.quantize and tfl.dequantize ops
// are inserted at the right positions to materialize the propagation and
// requantize results.
//
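// For example (an illustrative sketch; the exact op syntax may differ from
// what this pass sees in practice), given IR such as
//
//   %q   = "tfl.quantize"(%arg) {qtype = ...quantized type...}
//   %dq  = "tfl.dequantize"(%q)
//   %out = "tfl.reshape"(%dq)
//
// the parameters in the quantized type seed an immutable state, and because
// a reshape must keep the same scale on its operand and result, the same
// parameters are propagated to %out, where a new tfl.quantize/tfl.dequantize
// pair is materialized when the driver finalizes.
//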
class QuantizationDriver {
 public:
  explicit QuantizationDriver(FuncOp fn, bool is_signed,
                              bool disable_per_channel,
                              OpQuantSpecGetter op_quant_spec_getter,
                              bool infer_tensor_range, bool legacy_float_scale)
      : fn_(fn),
        builder_(fn.getBody()),
        is_signed_(is_signed),
        disable_per_channel_(disable_per_channel),
        op_quant_spec_getter_(op_quant_spec_getter),
        infer_tensor_range_(infer_tensor_range),
        legacy_float_scale_(legacy_float_scale) {}

  // The entry point of the quantization parameters propagation.
  void Run();

 private:
  // This is used to identify an operand or result of an op. The second element
  // of this pair is the index of the operand or result.
  using OpValue = std::pair<mlir::Operation *, int>;

  // Sets up the states for all the op results in the function.
  void Initialize();

  // Propagates the quantization parameters across all the ops.
  bool PropagateParams();

  // Inserts the Quantize and Dequantize ops according to the propagation
  // result.
  void Finalize();

  // The quantization parameters of a bias operand are usually determined by
  // other operands, so if a constant is used by different ops as a bias, it
  // needs to be duplicated so that each op can assign its own quantization
  // parameters to its bias. This method also records all the non-bias
  // constants (weights), and the per-channel weights, in sets for later
  // lookup.
  void PreprocessConstantOps();

  // Sets up all the data structures for quantization propagation.
  void SetupAllStates();

  // Whether the constant is a weight, which shouldn't be shared by different
  // ops.
  bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); }

  // Returns all the related quantization constraints of the op.
  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation *op);

  // Whether quantization parameters have been propagated to the results of
  // this op.
  bool IsQuantized(Operation *op);

  // Adds all the users of the index-th result of op to the work list.
  void AddUserToList(Operation *op, int index) {
    for (auto *user : op->getResult(index).getUsers()) {
      work_list_.push_back(user);
    }
  }

  // Adds the defining op of the index-th operand of op to the work list.
  void AddOperandToList(Operation *op, int index) {
    if (auto *inst = op->getOperand(index).getDefiningOp()) {
      work_list_.push_back(inst);
    }
  }

  // Returns the quantization params for the bias input from the non-bias
  // operands, whose indexes are in the `non_biases` vector. The returned
  // parameters are calculated by `func`.
  QuantParams GetBiasParams(Operation *op, int bias,
                            const std::vector<int> &non_biases,
                            AccumulatorScaleFunc func);

  // Sets the quantization parameters of the result to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the input of the propagated quantization.
  bool SetResultParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the operand to a fixed value. If any
  // quantization parameters have been propagated, a `requantize` will happen
  // on the output of the propagated quantization.
  bool SetOperandParams(Operation *op, int index, QuantParams params);

  // Sets the quantization parameters of the constant result according to its
  // content.
  bool SetConstantResultParams(Operation *op);

  // Inserts the Quantize and Dequantize ops for quantizing the index-th result
  // of the op.
  void QuantizeOpResult(Operation *op, int index, QuantParams params);

  void QuantizeArg(BlockArgument arg, QuantParams params);

  // Inserts the Quantize and Dequantize ops to quantize the value.
  void QuantizeValue(Value value, QuantParams params, Location loc);

  // Inserts the Quantize ops for requantizing the index-th result of the op.
  void RequantizeOpResult(Operation *op, int index, RequantizeState *state);

  void RequantizeArg(BlockArgument arg, RequantizeState *state);

  // Inserts a Quantize op to requantize the value.
  void RequantizeValue(Value value, RequantizeState *state, Location loc);

  // A heuristic to get a quantization parameter that satisfies the same-scale
  // constraints for the op. Returns empty parameters if no such quantization
  // parameter exists.
  QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);

  // Returns the state of the index-th operand of the op.
  QuantState &GetOperandQuantState(Operation *op, int index) {
    return states_[operand_states_[{op, index}]];
  }

  // Returns the state of the index-th result of the op.
  QuantState &GetResultQuantState(Operation *op, int index) {
    return states_[result_states_[{op, index}]];
  }

  QuantState &GetArgQuantState(BlockArgument arg) {
    return states_[arg_states_[arg]];
  }

  // Returns the requantize state of the index-th operand of the op.
  RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
    return rescale_states_[operand_states_[{op, index}]];
  }

  // Returns the requantize state of the index-th result of the op.
  RequantizeState &GetResultRequantizeState(Operation *op, int index) {
    return rescale_states_[result_states_[{op, index}]];
  }

  RequantizeState &GetArgRequantizeState(BlockArgument arg) {
    return rescale_states_[arg_states_[arg]];
  }

  // Uses the type of `val` to set the initial state of the index-th result if
  // `as_result` is true or the index-th operand if `as_result` is false. The
  // state is immutable if the type is a quantized type. Returns the index of
  // this new state in the state vector.
  int InitializeState(Operation *op, int index, Value val, bool as_result);

  // Sets the state of an argument. If this value is already cached, uses the
  // cached result without creating a new entry in the state vector. Otherwise,
  // allocates a new entry in the state vector.
  void InitializeArgState(BlockArgument arg, Value in,
                          llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({in, 0});
    if (!cached.second) {
      arg_states_[arg] = cached.first->second;
      return;
    }
    QuantParams params =
        quant::QuantizedType::getQuantizedElementType(in.getType());
    bool immutable = !EmptyParams(params);
    int next_state_index = states_.size();
    states_.push_back({params, immutable});
    arg_states_[arg] = next_state_index;
    cached.first->second = next_state_index;
  }

  // Sets the state of the index-th operand of the op. If this operand is
  // already cached, uses the cached result without creating a new entry in the
  // state vector. Otherwise, allocates a new entry in the state vector.
  void InitializeOperandState(Operation *op, int index, Value in,
                              llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({in, 0});
    if (!cached.second) {
      operand_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
  }

  // Sets the state of the index-th result of the op. If this result is already
  // cached, uses the cached result without creating a new entry in the state
  // vector. Otherwise, allocates a new entry in the state vector.
  void InitializeResultState(Operation *op, int index, Value res,
                             llvm::DenseMap<Value, int> *cache) {
    auto cached = cache->insert({res, 0});
    if (!cached.second) {
      result_states_.insert({{op, index}, cached.first->second});
      return;
    }
    cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
  }

  void DumpStates(Operation *current_op) {
    if (current_op) {
      llvm::dbgs() << "\n\n\n" << current_op->getName() << "\n";
    }
    fn_.walk([&](Operation *op) {
      if (op->hasTrait<OpTrait::IsTerminator>() ||
          op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
          llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp, ConstantOp>(
              op))
        return;
      if (current_op == op) llvm::dbgs() << "===>>>";
      llvm::dbgs() << op->getName() << " : (";
      if (llvm::isa<FuncOp>(op)) {
        for (auto &arg : fn_.getArguments()) {
          if (auto params = GetArgQuantState(arg).params) {
            params.print(llvm::dbgs());
            auto requantize_state = GetArgRequantizeState(arg);
            if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
              llvm::dbgs() << "+";
              requantize_state.params.print(llvm::dbgs());
            }
          }
          llvm::dbgs() << ",";
        }
      }
      for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
        if (auto params = GetOperandQuantState(op, i).params) {
          params.print(llvm::dbgs());
          auto requantize_state = GetOperandRequantizeState(op, i);
          if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
            llvm::dbgs() << "+";
            requantize_state.params.print(llvm::dbgs());
          }
        } else {
          op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ") -> (";
      for (int i = 0, e = op->getNumResults(); i < e; ++i) {
        if (auto params = GetResultQuantState(op, i).params) {
          params.print(llvm::dbgs());
          auto requantize_state = GetResultRequantizeState(op, i);
          if (requantize_state.pos != RequantizeState::NO_REQUANTIZE) {
            llvm::dbgs() << "+";
            requantize_state.params.print(llvm::dbgs());
          }
        } else {
          op->getResult(i).getType().cast<ShapedType>().getElementType().print(
              llvm::dbgs());
        }
        llvm::dbgs() << ",";
      }
      llvm::dbgs() << ")\n";
    });
  }

  FuncOp fn_;
  OpBuilder builder_;
  bool is_signed_;
  bool disable_per_channel_;

  // We should distinguish weights and bias constants. Biases are specified by
  // the quantization spec or are the operands of ops with same scale spec. The
  // rest are weights.
  llvm::DenseSet<Operation *> weights_;

  // The weights that require narrow_range quantization. This map collects all
  // the weight operands defined by the op quant spec. If the value of an entry
  // is not -1, it is the quantization dimension and per-channel quantization
  // is required.
  llvm::DenseMap<Operation *, int> optimized_weights_;

  // All the ops whose quantization parameters still need to be propagated.
  std::vector<Operation *> work_list_;
  std::unordered_set<Operation *> quantized_;

  // The vector contains all the quantization parameters propagated from the
  // defining operations of the values, or from quantization-aware training.
  std::vector<QuantState> states_;

  // The map contains all the quantization parameters which are required to
  // satisfy the same operands and results constraint. The keys of this map are
  // the values from `operand_states_` and `result_states_`.
  std::unordered_map<int, RequantizeState> rescale_states_;

  // Maps from an op's operands, results, and the function arguments to indexes
  // into the propagation state vector.
  llvm::DenseMap<OpValue, int> operand_states_;
  llvm::DenseMap<OpValue, int> result_states_;
  llvm::DenseMap<BlockArgument, int> arg_states_;

  // This vector preserves the argument order, so the newly inserted quantized
  // ops for the arguments are deterministically ordered.
  llvm::SmallVector<BlockArgument, 4> args_;

  OpQuantSpecGetter op_quant_spec_getter_;

  // Infer output ranges for activation ops and constants. This is usually
  // required for post-training quantization.
  bool infer_tensor_range_;

  // Calculate scales in float instead of double, so that the scales and
  // quantized values are exactly the same as with the TOCO quantizer.
  bool legacy_float_scale_;
};
}  // namespace

std::unique_ptr<OpQuantSpec> QuantizationDriver::GetQuantSpec(Operation *op) {
  return op_quant_spec_getter_(op);
}

bool QuantizationDriver::IsQuantized(Operation *op) {
  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    if (GetResultQuantState(op, i).IsEmpty()) return false;
  }
  return true;
}

int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
                                        bool as_result) {
  QuantParams params =
      quant::QuantizedType::getQuantizedElementType(val.getType());
  bool immutable = !EmptyParams(params);
  int next_state_index = states_.size();
  states_.push_back({params, immutable});
  if (as_result)
    result_states_.insert({{op, index}, next_state_index});
  else
    operand_states_.insert({{op, index}, next_state_index});

  return next_state_index;
}

bool QuantizationDriver::SetConstantResultParams(Operation *op) {
  DenseFPElementsAttr attr;
  Value res = op->getResult(0);
  if (!matchPattern(res, m_Constant(&attr))) {
    return false;
  }
  // TODO(fengliuai): make storage_type_width and narrow_range configurable.
  Type final_type;
  auto it = optimized_weights_.find(op);
  bool is_weight = it != optimized_weights_.end();
  bool is_weight_with_per_channel_support =
      is_weight && it->second != -1 && is_signed_;

  if (is_weight_with_per_channel_support && !disable_per_channel_) {
    // When `disable_per_channel_` is false, per-channel symmetric quantization
    // parameters are created from the weights when the ops support per-channel
    // quantization. Otherwise, per-tensor asymmetric quantization with narrow
    // range is used.

    // Per-axis quantization weight, with symmetric min/max enforced.
    final_type = GetUniformQuantizedPerAxisTypeForWeight(
        attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_,
        /*narrow_range=*/true, legacy_float_scale_);
  } else {
    // Per-tensor quantization weight.
    final_type = GetUniformQuantizedTypeForWeight(
        attr, /*symmetric=*/is_weight && is_signed_,
        /*num_bits=*/8, is_signed_,
        /*narrow_range=*/is_weight, legacy_float_scale_);
  }
  if (auto quant_type = final_type.dyn_cast_or_null<quant::QuantizedType>()) {
    return SetResultParams(op, 0, quant_type);
  }
  return false;
}

bool QuantizationDriver::SetResultParams(Operation *op, int res_index,
                                         QuantParams params) {
  auto &state = GetResultQuantState(op, res_index);
  if (state.params == params) {
    return false;
  }
  if (!state.IsEmpty()) {
    auto &rescale = GetResultRequantizeState(op, res_index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_INPUT;
    return true;
  }
  state.params = params;
  AddUserToList(op, res_index);
  return true;
}

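// A worked example of the usual accumulator-scale convention (the actual
// arithmetic lives in the supplied AccumulatorScaleFunc, so this is only
// illustrative): for a conv with input scale 0.5 and weight scale 0.02, the
// bias is typically quantized to 32 bits with scale 0.5 * 0.02 = 0.01 and
// zero point 0.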
QuantParams QuantizationDriver::GetBiasParams(
    Operation *op, int bias, const std::vector<int> &non_biases,
    AccumulatorScaleFunc func) {
  auto &bias_state = GetOperandQuantState(op, bias);
  if (!bias_state.IsEmpty()) {
    return bias_state.params;
  }
  std::vector<QuantParams> op_types;
  op_types.reserve(non_biases.size());
  for (auto non_bias : non_biases) {
    auto &non_bias_type = GetOperandQuantState(op, non_bias);
    op_types.push_back(non_bias_type.params);
  }
  if (op_types.empty()) return {};
  return func(op_types, legacy_float_scale_);
}

bool QuantizationDriver::SetOperandParams(Operation *op, int index,
                                          QuantParams params) {
  auto &state = GetOperandQuantState(op, index);
  if (state.params == params) {
    return false;
  }

  if (!state.IsEmpty()) {
    auto &rescale = GetOperandRequantizeState(op, index);
    rescale.params = params;
    rescale.pos = RequantizeState::ON_OUTPUT;
    return true;
  }

  state.params = params;
  AddOperandToList(op, index);
  return true;
}

void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
                                          QuantParams params) {
  builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
  Value original_result = op->getResult(index);
  QuantizeValue(original_result, params, op->getLoc());
}

void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
  builder_.setInsertionPointToStart(arg.getOwner());
  QuantizeValue(arg, params, builder_.getUnknownLoc());
}

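// A sketch of what QuantizeValue produces (illustrative; the op spellings and
// value names are hypothetical): quantizing a float value %v with params P
// rewrites
//
//   %use = "some.op"(%v)
//
// into
//
//   %q   = "quant.qcast"(%v)   // quantized type carrying P
//   %dq  = "quant.dcast"(%q)   // back to the original float type
//   %use = "some.op"(%dq)
//
// so consumers see %dq while %v still feeds the quantize cast.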
void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
                                       Location loc) {
  Type expressed_type = value.getType();
  Type new_type = params.castFromExpressedType(expressed_type);
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  auto dequantize = builder_.create<quant::DequantizeCastOp>(
      loc, expressed_type, quantize.getResult());

  // This attribute is set to distinguish the quantize ops being added by the
  // quantization pass. These ops can be removed without losing original
  // program accuracy.
  // TODO(fengliuai): make the attribute part of the op definition.
  quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());

  // `value` is used by `quantize`, so replacing all uses below would also
  // replace that use with the result of `dequantize`; the last line restores
  // it.
  value.replaceAllUsesWith(dequantize);
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}

void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
                                            RequantizeState *state) {
  if (state->pos == RequantizeState::NO_REQUANTIZE) return;
  builder_.setInsertionPointAfter(op);
  Value value = op->getResult(index);
  if (state->pos == RequantizeState::ON_OUTPUT) {
    Operation *user = value.getUses().begin().getUser();
    if (llvm::isa<quant::QuantizeCastOp>(user)) {
      // The requantize op is inserted between `quantize` and `dequantize` ops.
      value = user->getResult(0);
      builder_.setInsertionPointAfter(user);
    }
  }
  RequantizeValue(value, state, op->getLoc());
}

void QuantizationDriver::RequantizeArg(BlockArgument arg,
                                       RequantizeState *state) {
  Value value = arg;
  builder_.setInsertionPointToStart(arg.getOwner());
  if (value.hasOneUse()) {
    auto user = value.use_begin().getUser();
    if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
      value = q.getResult();
      builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
    }
  }
  RequantizeValue(value, state, builder_.getUnknownLoc());
}

void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
                                         Location loc) {
  Type new_type;
  if (state->pos == RequantizeState::ON_INPUT) {
    Type expressed_type = value.getType();
    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    new_type = state->params.castFromExpressedType(expressed_type);
  } else {
    Type expressed_type =
        quant::QuantizedType::castToExpressedType(value.getType());
    if (!expressed_type) return;

    // The value needs to be requantized. A Quantize op will be created to use
    // it as the operand and replace its uses.
    new_type = state->params.castFromExpressedType(expressed_type);
  }
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  auto requantize_op =
      builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  value.replaceAllUsesWith(requantize_op);
  requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}

// A heuristic to get quantization parameters that satisfy the same-scale
// constraints:
// - If there are immutable states,
//   - use the single input, or,
//   - use the single output, or,
//   - use the first one in the collection;
// - otherwise, use the single input if it is ready, or,
// - use the single output if it is ready, or,
// - use the first ready one in the collection.
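// For example (illustrative, assuming a concatenation implements the
// same-scale interface): if it has two operands and only one carries
// immutable parameters from training, those parameters are chosen for the
// other operand and for the result as well.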
QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint(
    Operation *op) {
  // Two vectors to collect the non-empty operand and result states.
  std::vector<QuantState *> mutable_states, immutable_states;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    auto &state = GetOperandQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_operands_num = immutable_states.size();
  int mutable_operands_num = mutable_states.size();
  // Use the operand's state if it is immutable and is the only operand.
  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
    return immutable_states.front()->params;
  }

  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    auto &state = GetResultQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  int immutable_results_num = immutable_states.size() - immutable_operands_num;
  int mutable_results_num = mutable_states.size() - mutable_operands_num;
  // Use the result's state if it is immutable and is the only result.
  if (op->getNumResults() == 1 && immutable_results_num == 1) {
    return immutable_states.back()->params;
  }

  // Use the first immutable state to quantize the other operands and results.
  if (!immutable_states.empty()) return immutable_states.front()->params;

  // If there are no immutable states, use the operand's state if it is the
  // only operand and has parameters propagated.
  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
    return mutable_states.front()->params;
  }

  // If there are no immutable states, use the result's state if it is the
  // only result and has parameters propagated.
  if (op->getNumResults() == 1 && mutable_results_num == 1) {
    return mutable_states.back()->params;
  }

  // Use the first propagated state to quantize the other operands and results.
  if (!mutable_states.empty()) return mutable_states.front()->params;

  // No operands/results have parameters propagated; skip this node for now.
  return {};
}

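// An illustration of the duplication below (hypothetical IR): a constant
// shared as the bias of two convolutions,
//
//   %bias = constant dense<...>
//   %0 = "tfl.conv_2d"(%in0, %w0, %bias)
//   %1 = "tfl.conv_2d"(%in1, %w1, %bias)
//
// is cloned so that each conv can carry its own bias quantization parameters,
// since those are derived from the other operands of each individual use.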
void QuantizationDriver::PreprocessConstantOps() {
  fn_.walk([&](ConstantOp cst) {
    // Non-float tensors are not weights and do not require quantization.
    auto type = cst.getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return;

    Value value = cst.getResult();
    builder_.setInsertionPoint(cst);

    // The following loop will change the value uses, so we cache all the uses
    // that need to be changed.
    llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
    for (auto &use : value.getUses()) {
      uses.push_back({use.getOwner(), use.getOperandNumber()});
    }
    for (auto indexed_use : llvm::enumerate(uses)) {
      Operation *user = indexed_use.value().first;
      int operand_num = indexed_use.value().second;

      auto spec = GetQuantSpec(user);
      auto biases = spec->biases_params;

      // The quantization parameters of a `weight` shouldn't be determined by
      // other values. So any constant that is not a bias, is not an operand of
      // an op with same-scale requirements, and hasn't been quantized is a
      // weight.
      if (biases.find(operand_num) == biases.end() &&
          !llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
          !llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        // We need to scan the content of weights to get the quantization
        // parameters if there are no quantization parameters (FakeQuant ops).
        // For this case, the weight will not be duplicated.
        weights_.insert(cst);
        auto affine_user =
            llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
        if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
            affine_user.RequiredNarrowRangeAffineOperand()) {
          optimized_weights_.insert(
              {cst, affine_user.GetQuantizationDimIndex()});
        }
      } else {
        // This is a bias or an operand of an op with same-scale requirements,
        // so the quantization parameters are propagated from, or determined
        // by, other values. Duplicate this constant in case it is shared by
        // different users.
        if (uses.size() > 1) {
          auto new_cst =
              builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
          user->setOperand(operand_num, new_cst);
        }
      }
    }
  });
}

void QuantizationDriver::SetupAllStates() {
  llvm::DenseMap<Value, int> value_to_state;

  for (auto arg : fn_.getArguments()) {
    args_.push_back(arg);
    Value value = arg;
    // If the argument is quantized, it should only have one user.
    if (arg.hasOneUse()) {
      auto user = value.use_begin().getUser();
      if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        value = q.getResult();
      }
    }
    InitializeArgState(arg, value, &value_to_state);
  }

  fn_.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::IsTerminator>() ||
        op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
      return;
    work_list_.push_back(op);

    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      auto operand = op->getOperand(i);
      if (auto *inst = operand.getDefiningOp()) {
        // If the operand comes from a tfl.dequantize op, we use the quantized
        // input of this tfl.dequantize op to set the state.
        if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
          operand = dq.arg();
        }
      }
      InitializeOperandState(op, i, operand, &value_to_state);
    }

    for (int res = 0, e = op->getNumResults(); res != e; ++res) {
      Value result = op->getResult(res);
      // If the result has been quantized, it should only be used by a
      // tfl.quantize op. For this case, we use the quantized result to
      // create the state and mark it immutable.
      if (result.hasOneUse()) {
        auto user = result.use_begin().getUser();
        if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
          result = q.getResult();
        }
      }
      InitializeResultState(op, res, result, &value_to_state);
    }
  });
}

// This method scans the operations in the function to set up the initial
// states for quantization parameter propagation.
// TODO(fengliuai): This algorithm assumes there is only one pair of
// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity
// check should be applied.
void QuantizationDriver::Initialize() {
  // Duplicate the bias constants, so the states can be set up correctly.
  // TODO(fengliuai): The function definition should also be duplicated if
  // there are multiple call sites.
  PreprocessConstantOps();

  // Set up all the internal states.
  SetupAllStates();
}

bool QuantizationDriver::PropagateParams() {
  // TODO(fengliuai): use a typed indicator instead of a bool value.
  bool changed = false;
  while (!work_list_.empty()) {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);

    if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
      // If the workflow requires inferring ranges from the content
      // (post-training quantization) and it is a weight (filter) that hasn't
      // been quantized, we infer the quantization parameters from the content.
      if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
        // The quantization parameters are determined by the content of the
        // constant.
        changed |= SetConstantResultParams(op);
      }
      continue;
    }

    if (llvm::isa<SameScalesOpInterface>(op)) {
      auto params = GetQuantParamsForSameScaleConstraint(op);
      // The quantization parameters haven't been propagated to any operands
      // or results. Skip this node for now.
      if (!params) {
        quantized_.erase(op);
        continue;
      }

      // Use the final state to set all the operands' parameters.
      for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
        if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
          // Without this check, it will accidentally propagate the
          // quantization information through the shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetOperandParams(op, i, params);
        }
      }

      // Use the final state to set all the results' parameters.
      for (int res = 0, e = op->getNumResults(); res != e; ++res)
        if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
          // Without this check, it will accidentally propagate the
          // quantization information through the shared non-float tensors.
          if (type.getElementType().isa<FloatType>())
            changed |= SetResultParams(op, res, params);
        }
    }

    // TODO(fengliuai): make the bit width configurable.
    auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
    if (restricted && infer_tensor_range_) {
      // Infer ranges from the activation ops. This is usually required for
      // the post-training quantization workflow.
      // TODO(fengliuai): different results can have different fixed ranges.
      auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
      for (auto i = 0; i < op->getNumResults(); ++i) {
        // The range is null if the result has been quantized.
        if (params) {
          changed |= SetResultParams(op, i, params);
        }
      }
    }

    auto spec = GetQuantSpec(op);
    for (auto &it : spec->biases_params) {
      auto params =
          GetBiasParams(op, it.first, it.second.first, it.second.second);
      if (!params) {
        quantized_.erase(op);
        continue;
      }
      changed |= SetOperandParams(op, it.first, params);
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "\n\n\n");
  LLVM_DEBUG(DumpStates(nullptr));

  return changed;
}

void QuantizationDriver::Finalize() {
  for (auto arg : args_) {
    auto &state = GetArgQuantState(arg);
    auto &requantize = GetArgRequantizeState(arg);
    if (state.IsEmpty() ||
        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
      continue;
    }

    if (!state.immutable) {
      QuantizeArg(arg, state.params);
    }

    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
      RequantizeArg(arg, &requantize);
    }
  }

  for (auto it : result_states_) {
    Operation *op = it.first.first;
    int res_index = it.first.second;
    auto &state = GetResultQuantState(op, res_index);
    auto &requantize = GetResultRequantizeState(op, res_index);
    if (state.IsEmpty() ||
        (state.immutable && requantize.pos == RequantizeState::NO_REQUANTIZE)) {
      continue;
    }

    if (!state.immutable) {
      QuantizeOpResult(op, res_index, state.params);
    }

    if (requantize.pos != RequantizeState::NO_REQUANTIZE) {
      RequantizeOpResult(op, res_index, &requantize);
    }
  }
}

void QuantizationDriver::Run() {
  Initialize();
  if (PropagateParams()) {
    Finalize();
  }
}

void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                        bool disable_per_channel,
                                        OpQuantSpecGetter op_quant_spec_getter,
                                        bool infer_tensor_ranges,
                                        bool legacy_float_scale) {
  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
                     infer_tensor_ranges, legacy_float_scale)
      .Run();
}

}  // namespace quant
}  // namespace mlir