1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
26 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
27 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
28 #include "mlir/IR/Attributes.h" // from @llvm-project
29 #include "mlir/IR/Builders.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/MLIRContext.h" // from @llvm-project
33 #include "mlir/IR/Matchers.h" // from @llvm-project
34 #include "mlir/IR/Operation.h" // from @llvm-project
35 #include "mlir/IR/Value.h" // from @llvm-project
36 #include "mlir/Support/LLVM.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
39 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
40
41 #define DEBUG_TYPE "quantization-context"
42
43 namespace mlir {
44 namespace quant {
45
QuantizeContext(FuncOp func,const DeviceTarget & spec)46 QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
47 : func_(func), target_spec_(spec) {
48 llvm::DenseMap<Value, int> value_to_state;
49 func.walk([&](quant::QuantizeRegionOp op) {
50 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
51 states_manager_.InitializeOperandState(op, i, &value_to_state);
52 }
53
54 for (int res = 0, e = op.getNumResults(); res != e; ++res) {
55 states_manager_.InitializeResultState(op, res, &value_to_state);
56 }
57 });
58 }
59
GetAllOps()60 std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
61 std::vector<quant::QuantizeRegionOp> all_ops;
62 all_ops.reserve(128);
63 func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
64 return all_ops;
65 }
66
GetSignature(QuantizeRegionOp op)67 KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) {
68 KernelSpecs::Signature signature;
69 signature.reserve(op.input_specs().size() + op.output_specs().size());
70 for (int i = 0; i < op.getNumOperands(); ++i) {
71 DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature);
72 }
73 for (int i = 0; i < op.getNumResults(); ++i) {
74 DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature);
75 }
76 return signature;
77 }
78
Handle(quant::QuantizeRegionOp op,llvm::SmallVectorImpl<Operation * > * new_items,bool * changed)79 LogicalResult QuantizeContext::Handle(
80 quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
81 bool *changed) {
82 auto signature = GetSignature(op);
83 auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature);
84 if (!spec.hasValue()) {
85 op.emitWarning(
86 "Couldn't find kernel from the registration for quantization.");
87 return success();
88 }
89 switch (spec->type) {
90 case ScaleConstraintType::OutputInputFreeScale: {
91 // no propagation.
92 *changed |= false;
93 break;
94 }
95 case ScaleConstraintType::CustomScale: {
96 if (failed(spec->scale_fn(this, op, new_items, changed))) {
97 return failure();
98 }
99 break;
100 }
101 case ScaleConstraintType::OutputInputSameScale: {
102 auto params = GetQuantParamsForSameScaleConstraint(op);
103 if (EmptyParams(params)) {
104 *changed |= false;
105 break;
106 }
107 // propagate this params to all the quantizable ports.
108 if (failed(PropagateQuantParams(op, params, new_items, changed))) {
109 return failure();
110 }
111 break;
112 }
113 default: {
114 // TODO(fengliuai): implement the other types.
115 llvm_unreachable("no implementation.");
116 return failure();
117 }
118 }
119 return success();
120 }
121
Finalize()122 LogicalResult QuantizeContext::Finalize() {
123 MLIRContext *context = func_.getContext();
124 func_.walk([&](quant::QuantizeRegionOp op) {
125 llvm::SmallVector<Attribute, 4> input_specs;
126 auto original_input_specs = op.input_specs().getValue();
127 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
128 auto &state = states_manager_.GetOperandQuantState(op, i);
129 auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
130 if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
131 input_specs.push_back(original_input_specs[i]);
132 } else if (requantize.pos == RequantizeState::ON_OUTPUT) {
133 input_specs.push_back(TypeAttr::get(requantize.params));
134 } else {
135 input_specs.push_back(TypeAttr::get(state.params));
136 }
137 }
138 op->setAttr("input_specs", ArrayAttr::get(context, input_specs));
139
140 llvm::SmallVector<Attribute, 4> output_specs;
141 auto original_output_specs = op.output_specs().getValue();
142 for (int res = 0, e = op.getNumResults(); res != e; ++res) {
143 auto &state = states_manager_.GetResultQuantState(op, res);
144 auto &requantize = states_manager_.GetResultRequantizeState(op, res);
145 if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
146 output_specs.push_back(original_output_specs[res]);
147 } else if (requantize.pos == RequantizeState::ON_INPUT) {
148 output_specs.push_back(TypeAttr::get(requantize.params));
149 } else {
150 output_specs.push_back(TypeAttr::get(state.params));
151 }
152 }
153 op->setAttr("output_specs", ArrayAttr::get(context, output_specs));
154 });
155 return success();
156 }
157
DumpStates(QuantizeRegionOp current_op)158 void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
159 if (current_op) {
160 llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
161 }
162 func_.walk([&](QuantizeRegionOp op) {
163 if (current_op == op) llvm::errs() << "===>>>";
164 llvm::errs() << op.logical_kernel() << " : (";
165 for (auto i = 0; i < op.getNumOperands(); ++i) {
166 if (auto params = GetOperandParams(op, i))
167 params.print(llvm::errs());
168 else
169 llvm::errs() << "_";
170 llvm::errs() << ",";
171 }
172 llvm::errs() << ") -> (";
173 for (auto i = 0; i < op.getNumResults(); ++i) {
174 if (auto params = GetResultParams(op, i))
175 params.print(llvm::errs());
176 else
177 llvm::errs() << "_";
178 llvm::errs() << ",";
179 }
180 llvm::errs() << ")\n";
181 });
182 }
183
// A heuristic to get quantization parameters satisfies the same scale
// constraints:
// - If there are immutable states,
//   - use the single input, or,
//   - use the single output, or,
//   - use the first one in the collection,
// - use the single input if it is ready, or,
// - use the single output if it is ready, or,
// - use the first ready one in the collection.
//
// Returns empty params (see EmptyParams) when no port has parameters yet.
// NOTE: the code below relies on operand states being appended to the two
// vectors BEFORE result states — front() then refers to an operand-side
// state and back() (right after the result loop) to a result-side state.
QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint(
    Operation *op) {
  // Two vector to collect Non-empty operands and results states.
  std::vector<quant::QuantState *> mutable_states, immutable_states;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    auto &state = states_manager_.GetOperandQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  // Snapshot the counts now: after the result loop below, the vectors hold
  // operand states first, then result states.
  int immutable_operands_num = immutable_states.size();
  int mutable_operands_num = mutable_states.size();
  // Use the operand's state if it is immutable and it is the only one
  // operand.
  if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
    return immutable_states.front()->params;
  }

  for (int i = 0, e = op->getNumResults(); i != e; ++i) {
    auto &state = states_manager_.GetResultQuantState(op, i);
    if (state.immutable) {
      immutable_states.push_back(&state);
    } else if (!state.IsEmpty()) {
      mutable_states.push_back(&state);
    }
  }

  // Result-side counts are the totals minus the operand-side snapshots.
  int immutable_results_num = immutable_states.size() - immutable_operands_num;
  int mutable_results_num = mutable_states.size() - mutable_operands_num;
  // Use the result's state if it is immutable and it is the only one result.
  // back() is that result state because results were appended last.
  if (op->getNumResults() == 1 && immutable_results_num == 1) {
    return immutable_states.back()->params;
  }

  LLVM_DEBUG(llvm::dbgs()
             << "Quantization parameters are not collected in an ideal place. "
                "Has to fallback values which might introduce errors.\n");

  // Use the first immutable state to quantize the rest operands and results.
  if (!immutable_states.empty()) return immutable_states.front()->params;

  // If there are no immutable states, use the operand's state if it is the
  // only one operand and has parameters propagated.
  if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
    return mutable_states.front()->params;
  }

  // If there are no immutable states, use the result's state if it is the
  // only one result and has parameters propagated.
  if (op->getNumResults() == 1 && mutable_results_num == 1) {
    return mutable_states.back()->params;
  }

  // Use the first propagated state to quantize the rest operands and results.
  if (!mutable_states.empty()) return mutable_states.front()->params;

  // None operands/results have parameters propagated, skip this node for now.
  return {};
}
255
PropagateQuantParams(Operation * op,const QuantParams params,quant::AdjacentOperations * new_items,bool * changed)256 LogicalResult QuantizeContext::PropagateQuantParams(
257 Operation *op, const QuantParams params,
258 quant::AdjacentOperations *new_items, bool *changed) {
259 // Use the final state to set all the operands' parameters.
260 for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
261 auto ele = op->getOperand(i).getType().cast<ShapedType>().getElementType();
262 if (ele.isa<FloatType>() && SetOperandParams(op, i, params)) {
263 *changed |= true;
264 new_items->push_back(op->getOperand(i).getDefiningOp());
265 }
266 }
267
268 // Use the final state to set all the results' parameters.
269 for (int res = 0, e = op->getNumResults(); res != e; ++res) {
270 auto ele = op->getResult(res).getType().cast<ShapedType>().getElementType();
271 if (ele.isa<FloatType>() && SetResultParams(op, res, params)) {
272 auto users = op->getResult(res).getUsers();
273 *changed |= !users.empty();
274 new_items->append(users.begin(), users.end());
275 }
276 }
277 return success();
278 }
279
InitializeState(quant::QuantizeRegionOp op,int index,bool as_result)280 int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
281 int index, bool as_result) {
282 Attribute params_attr;
283 if (as_result) {
284 params_attr = op.output_specs()[index];
285 } else {
286 params_attr = op.input_specs()[index];
287 }
288 QuantParams params =
289 params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
290 bool immutable = !EmptyParams(params);
291 int next_state_index = states_.size();
292 states_.push_back({params, immutable});
293 if (as_result) {
294 result_states_.insert({{op, index}, next_state_index});
295 } else {
296 operand_states_.insert({{op, index}, next_state_index});
297 }
298 return next_state_index;
299 }
300
InitializeOperandState(quant::QuantizeRegionOp op,int index,llvm::DenseMap<Value,int> * cache)301 void QuantizeContext::StatesManager::InitializeOperandState(
302 quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
303 Value in = op.getOperand(index);
304 auto cached = cache->insert({in, 0});
305 if (!cached.second) {
306 operand_states_.insert({{op, index}, cached.first->second});
307 return;
308 }
309 cached.first->second = InitializeState(op, index, /*as_result=*/false);
310 }
311
InitializeResultState(quant::QuantizeRegionOp op,int index,llvm::DenseMap<Value,int> * cache)312 void QuantizeContext::StatesManager::InitializeResultState(
313 quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
314 auto res = op.getResult(index);
315 auto cached = cache->insert({res, 0});
316 if (!cached.second) {
317 result_states_.insert({{op, index}, cached.first->second});
318 return;
319 }
320 cached.first->second = InitializeState(op, index, /*as_result=*/true);
321 }
322
// Not implemented: setting params for a constant-producing op's result.
// Deliberately aborts if reached; the `return` only silences the compiler.
bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
  llvm_unreachable("no implementation.");
  return false;
}
327
SetResultParams(Operation * op,int res_index,QuantParams params)328 bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
329 int res_index,
330 QuantParams params) {
331 auto &state = GetResultQuantState(op, res_index);
332 if (state.params == params) {
333 return false;
334 }
335 if (!state.IsEmpty()) {
336 auto &rescale = GetResultRequantizeState(op, res_index);
337 rescale.params = params;
338 rescale.pos = RequantizeState::ON_INPUT;
339 return false;
340 }
341 state.params = params;
342 return true;
343 }
344
SetOperandParams(Operation * op,int index,QuantParams params)345 bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
346 QuantParams params) {
347 auto &state = GetOperandQuantState(op, index);
348 if (state.params == params) {
349 return false;
350 }
351
352 if (!state.IsEmpty()) {
353 auto &rescale = GetOperandRequantizeState(op, index);
354 rescale.params = params;
355 rescale.pos = RequantizeState::ON_OUTPUT;
356 return false;
357 }
358 state.params = params;
359 return true;
360 }
361 } // namespace quant
362 } // namespace mlir
363