• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
18 
19 #include "llvm/ADT/DenseMap.h"
20 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
21 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/Value.h"  // from @llvm-project
26 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
28 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
29 
30 namespace mlir {
31 namespace quant {
32 
EmptyParams(QuantParams p)33 static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
34 
35 // The state for each op result during the quantization parameters propagation.
36 struct QuantState {
37   // Quantization parameters propagated to an op result.
38   QuantParams params;
39   // A flag indicates this state (the params) shouldn't be changed after it is
40   // initialized. This flag will be set to true if the quantization parameters
41   // are from the quantization-aware training.
42   const bool immutable;
43 
IsEmptyQuantState44   bool IsEmpty() { return EmptyParams(params); }
45 };
46 
47 // The state for rescaling the propagated quantization parameters. This can be
48 // on the input side to satisfy the constraint of previous operation, or on the
49 // output side to satisfy the constraint of the next operation.
50 struct RequantizeState {
51   // Sometimes, we have to "requantize" the quantization result to satisfy all
52   // the constraints. The "requantize" can happen either on the input or output
53   // of the quantization result.
54   enum RequantizePosition {
55     NO_REQUANTIZE,
56     ON_INPUT,
57     ON_OUTPUT
58   } pos = NO_REQUANTIZE;
59 
60   // Quantization parameters will be used to add the requantize ops.
61   QuantParams params;
62 };
63 
64 // This class manages all the intermediate quantization states.
65 class QuantizeContext {
66  public:
67   QuantizeContext(FuncOp func, const DeviceTarget &spec);
68 
69   // Returns all the quant region ops.
70   std::vector<quant::QuantizeRegionOp> GetAllOps();
71 
72   // For each quant region op, propagates its quantization parameters according
73   // to the kernel specification and also returns the adjacent quant region ops
74   // which get the new quantization parameters propagated.
75   LogicalResult Handle(quant::QuantizeRegionOp op,
76                        llvm::SmallVectorImpl<Operation *> *new_items,
77                        bool *changed);
78 
79   // Updates the port quantization specifications of all the quant region ops
80   // with the propagation results.
81   LogicalResult Finalize();
82 
83   // Dumps the states stores in the state manager.
84   void DumpStates(QuantizeRegionOp current_op = {});
85 
86   // Update the quantization parameter for certain result of the op. By this
87   // method, the quantization parameter is propagated to all the users of the
88   // result as well.
SetResultParams(Operation * op,int index,QuantParams params)89   bool SetResultParams(Operation *op, int index, QuantParams params) {
90     return states_manager_.SetResultParams(op, index, params);
91   }
92 
93   // Update the quantization parameter for certain operand of the op. By this
94   // method, the quantization parameter is propagated to the defining op of
95   // operand as well.
SetOperandParams(Operation * op,int index,QuantParams params)96   bool SetOperandParams(Operation *op, int index, QuantParams params) {
97     return states_manager_.SetOperandParams(op, index, params);
98   }
99 
100   // Return the quantization parameter of certain result of the op.
GetResultParams(Operation * op,int index)101   QuantParams GetResultParams(Operation *op, int index) {
102     return states_manager_.GetResultParams(op, index);
103   }
104 
105   // Return the quantization parameter of certain operand of the op.
GetOperandParams(Operation * op,int index)106   QuantParams GetOperandParams(Operation *op, int index) {
107     return states_manager_.GetOperandParams(op, index);
108   }
109 
110   // Return the signature of the op.
111   KernelSpecs::Signature GetSignature(QuantizeRegionOp op);
112 
113   // A heuristic to get quantization parameters satisfies the same scale
114   // constraints:
115   // - If there are immutable states,
116   //   - use the single input, or,
117   //   - use the single output, or,
118   //   - use the first one in the collection,
119   // - use the single input if it is ready, or,
120   // - use the single output if it is ready, or,
121   // - use the first ready one in the collection.
122   QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);
123 
124   // Propagate `params` to all the quantizable port of the `op`. The adjacent
125   // ops, which have the parameters propagated to, are collected by `new_items`,
126   // so they can be added to the working queue. `changed` is set to true if
127   // there are any new elements being added to `new_items`.
128   LogicalResult PropagateQuantParams(Operation *op, const QuantParams params,
129                                      AdjacentOperations *new_items,
130                                      bool *changed);
131 
132  private:
133   class StatesManager {
134    public:
135     // Sets the quantization parameters of the constant result according to its
136     // content.
137     //
138     // Always returns true.
139     bool SetConstantResultParams(Operation *op);
140 
141     // Sets the quantization parameters of the result to a fixed value. If any
142     // quantization parameters have been propagated, a `requantize` will happen
143     // on the input of propagated quantization.
144     //
145     // Returns true, if the users of the result needs to be added to the
146     // worklist.
147     bool SetResultParams(Operation *op, int index, QuantParams params);
148 
149     // Sets the quantization parameters of the operand to a fixed value. If any
150     // quantization parameters have been propagated, a `requantize` will happen
151     // on the output of propagated quantization.
152     //
153     // Returns true, if the defining op of the operand needs to be added to the
154     // worklist.
155     bool SetOperandParams(Operation *op, int index, QuantParams params);
156 
157     // Returns the quantization parameters of the index-th result of the op.
GetResultParams(Operation * op,int index)158     QuantParams GetResultParams(Operation *op, int index) {
159       return states_[result_states_[{op, index}]].params;
160     }
161 
162     // Returns the quantization parameters of the index-th operand of the op.
GetOperandParams(Operation * op,int index)163     QuantParams GetOperandParams(Operation *op, int index) {
164       return states_[operand_states_[{op, index}]].params;
165     }
166 
167    private:
168     friend class QuantizeContext;
169 
170     // Uses the type of `val` to set the initial state of the index-th result if
171     // `as_result` is true or index-th operand if `as_result` is false. The
172     // state is immutable if the type is a quantized type. Returns the index of
173     // this new state in the state vector.
174     int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);
175 
176     // Sets the state of the index-th operand of the op. If this operand is
177     // cached, uses the cached result without creating new entry in the state
178     // vector. Otherwise, allocate a new entry in the state vector.
179     void InitializeOperandState(quant::QuantizeRegionOp op, int index,
180                                 llvm::DenseMap<Value, int> *cache);
181 
182     // Sets the state of the index-th result of the op. If this result is
183     // cached, uses the cached result without creating new entry in the state
184     // vector. Otherwise, allocate a new entry in the state vector.
185     void InitializeResultState(quant::QuantizeRegionOp op, int index,
186                                llvm::DenseMap<Value, int> *cache);
187 
188     // Returns the state of the index-th operand of the op.
GetOperandQuantState(Operation * op,int index)189     QuantState &GetOperandQuantState(Operation *op, int index) {
190       return states_[operand_states_[{op, index}]];
191     }
192 
193     // Returns the state of the index-th result of the op.
GetResultQuantState(Operation * op,int index)194     QuantState &GetResultQuantState(Operation *op, int index) {
195       return states_[result_states_[{op, index}]];
196     }
197 
198     // Returns the state of the index-th operand of the op.
GetOperandRequantizeState(Operation * op,int index)199     RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
200       return rescale_states_[operand_states_[{op, index}]];
201     }
202 
203     // Returns the state of the index-th result of the op.
GetResultRequantizeState(Operation * op,int index)204     RequantizeState &GetResultRequantizeState(Operation *op, int index) {
205       return rescale_states_[result_states_[{op, index}]];
206     }
207 
208    private:
209     // This is used to identify an operand or result of an op. The second
210     // element of this pair is the index of the operand or result.
211     using OpValue = std::pair<mlir::Operation *, int>;
212 
213     // The vector contains all the quantization parameters propagated from the
214     // defining operations of the value, or from the quantization aware
215     // training.
216     std::vector<QuantState> states_;
217 
218     // The map contains all the quantization parameters which are required to
219     // satisfy the same operands and results constraint. The keys of this map
220     // are the values from `operand_states_` and `result_state_`.
221     std::unordered_map<int, RequantizeState> rescale_states_;
222 
223     // Maps of indexes to the propagation state vector from the ops operands,
224     // results and arguments.
225     llvm::DenseMap<OpValue, int> operand_states_;
226     llvm::DenseMap<OpValue, int> result_states_;
227   };
228 
229   FuncOp func_;
230 
231   DeviceTarget target_spec_;
232 
233   StatesManager states_manager_;
234 };
235 
236 }  // namespace quant
237 }  // namespace mlir
238 
239 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
240