1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
18
19 #include "llvm/ADT/DenseMap.h"
20 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
21 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/IR/Value.h" // from @llvm-project
26 #include "mlir/Support/LogicalResult.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
28 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
29
30 namespace mlir {
31 namespace quant {
32
EmptyParams(QuantParams p)33 static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
34
35 // The state for each op result during the quantization parameters propagation.
36 struct QuantState {
37 // Quantization parameters propagated to an op result.
38 QuantParams params;
39 // A flag indicates this state (the params) shouldn't be changed after it is
40 // initialized. This flag will be set to true if the quantization parameters
41 // are from the quantization-aware training.
42 const bool immutable;
43
IsEmptyQuantState44 bool IsEmpty() { return EmptyParams(params); }
45 };
46
47 // The state for rescaling the propagated quantization parameters. This can be
48 // on the input side to satisfy the constraint of previous operation, or on the
49 // output side to satisfy the constraint of the next operation.
50 struct RequantizeState {
51 // Sometimes, we have to "requantize" the quantization result to satisfy all
52 // the constraints. The "requantize" can happen either on the input or output
53 // of the quantization result.
54 enum RequantizePosition {
55 NO_REQUANTIZE,
56 ON_INPUT,
57 ON_OUTPUT
58 } pos = NO_REQUANTIZE;
59
60 // Quantization parameters will be used to add the requantize ops.
61 QuantParams params;
62 };
63
64 // This class manages all the intermediate quantization states.
65 class QuantizeContext {
66 public:
67 QuantizeContext(FuncOp func, const DeviceTarget &spec);
68
69 // Returns all the quant region ops.
70 std::vector<quant::QuantizeRegionOp> GetAllOps();
71
72 // For each quant region op, propagates its quantization parameters according
73 // to the kernel specification and also returns the adjacent quant region ops
74 // which get the new quantization parameters propagated.
75 LogicalResult Handle(quant::QuantizeRegionOp op,
76 llvm::SmallVectorImpl<Operation *> *new_items,
77 bool *changed);
78
79 // Updates the port quantization specifications of all the quant region ops
80 // with the propagation results.
81 LogicalResult Finalize();
82
83 // Dumps the states stores in the state manager.
84 void DumpStates(QuantizeRegionOp current_op = {});
85
86 // Update the quantization parameter for certain result of the op. By this
87 // method, the quantization parameter is propagated to all the users of the
88 // result as well.
SetResultParams(Operation * op,int index,QuantParams params)89 bool SetResultParams(Operation *op, int index, QuantParams params) {
90 return states_manager_.SetResultParams(op, index, params);
91 }
92
93 // Update the quantization parameter for certain operand of the op. By this
94 // method, the quantization parameter is propagated to the defining op of
95 // operand as well.
SetOperandParams(Operation * op,int index,QuantParams params)96 bool SetOperandParams(Operation *op, int index, QuantParams params) {
97 return states_manager_.SetOperandParams(op, index, params);
98 }
99
100 // Return the quantization parameter of certain result of the op.
GetResultParams(Operation * op,int index)101 QuantParams GetResultParams(Operation *op, int index) {
102 return states_manager_.GetResultParams(op, index);
103 }
104
105 // Return the quantization parameter of certain operand of the op.
GetOperandParams(Operation * op,int index)106 QuantParams GetOperandParams(Operation *op, int index) {
107 return states_manager_.GetOperandParams(op, index);
108 }
109
110 // Return the signature of the op.
111 KernelSpecs::Signature GetSignature(QuantizeRegionOp op);
112
113 // A heuristic to get quantization parameters satisfies the same scale
114 // constraints:
115 // - If there are immutable states,
116 // - use the single input, or,
117 // - use the single output, or,
118 // - use the first one in the collection,
119 // - use the single input if it is ready, or,
120 // - use the single output if it is ready, or,
121 // - use the first ready one in the collection.
122 QuantParams GetQuantParamsForSameScaleConstraint(Operation *op);
123
124 // Propagate `params` to all the quantizable port of the `op`. The adjacent
125 // ops, which have the parameters propagated to, are collected by `new_items`,
126 // so they can be added to the working queue. `changed` is set to true if
127 // there are any new elements being added to `new_items`.
128 LogicalResult PropagateQuantParams(Operation *op, const QuantParams params,
129 AdjacentOperations *new_items,
130 bool *changed);
131
132 private:
133 class StatesManager {
134 public:
135 // Sets the quantization parameters of the constant result according to its
136 // content.
137 //
138 // Always returns true.
139 bool SetConstantResultParams(Operation *op);
140
141 // Sets the quantization parameters of the result to a fixed value. If any
142 // quantization parameters have been propagated, a `requantize` will happen
143 // on the input of propagated quantization.
144 //
145 // Returns true, if the users of the result needs to be added to the
146 // worklist.
147 bool SetResultParams(Operation *op, int index, QuantParams params);
148
149 // Sets the quantization parameters of the operand to a fixed value. If any
150 // quantization parameters have been propagated, a `requantize` will happen
151 // on the output of propagated quantization.
152 //
153 // Returns true, if the defining op of the operand needs to be added to the
154 // worklist.
155 bool SetOperandParams(Operation *op, int index, QuantParams params);
156
157 // Returns the quantization parameters of the index-th result of the op.
GetResultParams(Operation * op,int index)158 QuantParams GetResultParams(Operation *op, int index) {
159 return states_[result_states_[{op, index}]].params;
160 }
161
162 // Returns the quantization parameters of the index-th operand of the op.
GetOperandParams(Operation * op,int index)163 QuantParams GetOperandParams(Operation *op, int index) {
164 return states_[operand_states_[{op, index}]].params;
165 }
166
167 private:
168 friend class QuantizeContext;
169
170 // Uses the type of `val` to set the initial state of the index-th result if
171 // `as_result` is true or index-th operand if `as_result` is false. The
172 // state is immutable if the type is a quantized type. Returns the index of
173 // this new state in the state vector.
174 int InitializeState(quant::QuantizeRegionOp op, int index, bool as_result);
175
176 // Sets the state of the index-th operand of the op. If this operand is
177 // cached, uses the cached result without creating new entry in the state
178 // vector. Otherwise, allocate a new entry in the state vector.
179 void InitializeOperandState(quant::QuantizeRegionOp op, int index,
180 llvm::DenseMap<Value, int> *cache);
181
182 // Sets the state of the index-th result of the op. If this result is
183 // cached, uses the cached result without creating new entry in the state
184 // vector. Otherwise, allocate a new entry in the state vector.
185 void InitializeResultState(quant::QuantizeRegionOp op, int index,
186 llvm::DenseMap<Value, int> *cache);
187
188 // Returns the state of the index-th operand of the op.
GetOperandQuantState(Operation * op,int index)189 QuantState &GetOperandQuantState(Operation *op, int index) {
190 return states_[operand_states_[{op, index}]];
191 }
192
193 // Returns the state of the index-th result of the op.
GetResultQuantState(Operation * op,int index)194 QuantState &GetResultQuantState(Operation *op, int index) {
195 return states_[result_states_[{op, index}]];
196 }
197
198 // Returns the state of the index-th operand of the op.
GetOperandRequantizeState(Operation * op,int index)199 RequantizeState &GetOperandRequantizeState(Operation *op, int index) {
200 return rescale_states_[operand_states_[{op, index}]];
201 }
202
203 // Returns the state of the index-th result of the op.
GetResultRequantizeState(Operation * op,int index)204 RequantizeState &GetResultRequantizeState(Operation *op, int index) {
205 return rescale_states_[result_states_[{op, index}]];
206 }
207
208 private:
209 // This is used to identify an operand or result of an op. The second
210 // element of this pair is the index of the operand or result.
211 using OpValue = std::pair<mlir::Operation *, int>;
212
213 // The vector contains all the quantization parameters propagated from the
214 // defining operations of the value, or from the quantization aware
215 // training.
216 std::vector<QuantState> states_;
217
218 // The map contains all the quantization parameters which are required to
219 // satisfy the same operands and results constraint. The keys of this map
220 // are the values from `operand_states_` and `result_state_`.
221 std::unordered_map<int, RequantizeState> rescale_states_;
222
223 // Maps of indexes to the propagation state vector from the ops operands,
224 // results and arguments.
225 llvm::DenseMap<OpValue, int> operand_states_;
226 llvm::DenseMap<OpValue, int> result_states_;
227 };
228
229 FuncOp func_;
230
231 DeviceTarget target_spec_;
232
233 StatesManager states_manager_;
234 };
235
236 } // namespace quant
237 } // namespace mlir
238
239 #endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_
240