• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
26 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
27 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
28 #include "mlir/IR/Attributes.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
33 #include "mlir/IR/Matchers.h"  // from @llvm-project
34 #include "mlir/IR/Operation.h"  // from @llvm-project
35 #include "mlir/IR/Value.h"  // from @llvm-project
36 #include "mlir/Support/LLVM.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
39 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
40 
41 #define DEBUG_TYPE "quantization-context"
42 
43 namespace mlir {
44 namespace quant {
45 
QuantizeContext(FuncOp func,const DeviceTarget & spec)46 QuantizeContext::QuantizeContext(FuncOp func, const DeviceTarget &spec)
47     : func_(func), target_spec_(spec) {
48   llvm::DenseMap<Value, int> value_to_state;
49   func.walk([&](quant::QuantizeRegionOp op) {
50     for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
51       states_manager_.InitializeOperandState(op, i, &value_to_state);
52     }
53 
54     for (int res = 0, e = op.getNumResults(); res != e; ++res) {
55       states_manager_.InitializeResultState(op, res, &value_to_state);
56     }
57   });
58 }
59 
GetAllOps()60 std::vector<quant::QuantizeRegionOp> QuantizeContext::GetAllOps() {
61   std::vector<quant::QuantizeRegionOp> all_ops;
62   all_ops.reserve(128);
63   func_.walk([&](quant::QuantizeRegionOp op) { all_ops.push_back(op); });
64   return all_ops;
65 }
66 
GetSignature(QuantizeRegionOp op)67 KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) {
68   KernelSpecs::Signature signature;
69   signature.reserve(op.input_specs().size() + op.output_specs().size());
70   for (int i = 0; i < op.getNumOperands(); ++i) {
71     DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature);
72   }
73   for (int i = 0; i < op.getNumResults(); ++i) {
74     DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature);
75   }
76   return signature;
77 }
78 
Handle(quant::QuantizeRegionOp op,llvm::SmallVectorImpl<Operation * > * new_items,bool * changed)79 LogicalResult QuantizeContext::Handle(
80     quant::QuantizeRegionOp op, llvm::SmallVectorImpl<Operation *> *new_items,
81     bool *changed) {
82   auto signature = GetSignature(op);
83   auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature);
84   if (!spec.hasValue()) {
85     op.emitWarning(
86         "Couldn't find kernel from the registration for quantization.");
87     return success();
88   }
89   switch (spec->type) {
90     case ScaleConstraintType::OutputInputFreeScale: {
91       // no propagation.
92       *changed |= false;
93       break;
94     }
95     case ScaleConstraintType::CustomScale: {
96       if (failed(spec->scale_fn(this, op, new_items, changed))) {
97         return failure();
98       }
99       break;
100     }
101     case ScaleConstraintType::OutputInputSameScale: {
102       auto params = GetQuantParamsForSameScaleConstraint(op);
103       if (EmptyParams(params)) {
104         *changed |= false;
105         break;
106       }
107       // propagate this params to all the quantizable ports.
108       if (failed(PropagateQuantParams(op, params, new_items, changed))) {
109         return failure();
110       }
111       break;
112     }
113     default: {
114       // TODO(fengliuai): implement the other types.
115       llvm_unreachable("no implementation.");
116       return failure();
117     }
118   }
119   return success();
120 }
121 
Finalize()122 LogicalResult QuantizeContext::Finalize() {
123   MLIRContext *context = func_.getContext();
124   func_.walk([&](quant::QuantizeRegionOp op) {
125     llvm::SmallVector<Attribute, 4> input_specs;
126     auto original_input_specs = op.input_specs().getValue();
127     for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
128       auto &state = states_manager_.GetOperandQuantState(op, i);
129       auto &requantize = states_manager_.GetOperandRequantizeState(op, i);
130       if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
131         input_specs.push_back(original_input_specs[i]);
132       } else if (requantize.pos == RequantizeState::ON_OUTPUT) {
133         input_specs.push_back(TypeAttr::get(requantize.params));
134       } else {
135         input_specs.push_back(TypeAttr::get(state.params));
136       }
137     }
138     op->setAttr("input_specs", ArrayAttr::get(context, input_specs));
139 
140     llvm::SmallVector<Attribute, 4> output_specs;
141     auto original_output_specs = op.output_specs().getValue();
142     for (int res = 0, e = op.getNumResults(); res != e; ++res) {
143       auto &state = states_manager_.GetResultQuantState(op, res);
144       auto &requantize = states_manager_.GetResultRequantizeState(op, res);
145       if (state.IsEmpty() && requantize.pos == RequantizeState::NO_REQUANTIZE) {
146         output_specs.push_back(original_output_specs[res]);
147       } else if (requantize.pos == RequantizeState::ON_INPUT) {
148         output_specs.push_back(TypeAttr::get(requantize.params));
149       } else {
150         output_specs.push_back(TypeAttr::get(state.params));
151       }
152     }
153     op->setAttr("output_specs", ArrayAttr::get(context, output_specs));
154   });
155   return success();
156 }
157 
DumpStates(QuantizeRegionOp current_op)158 void QuantizeContext::DumpStates(QuantizeRegionOp current_op) {
159   if (current_op) {
160     llvm::errs() << "\n\n\n" << current_op.logical_kernel() << "\n";
161   }
162   func_.walk([&](QuantizeRegionOp op) {
163     if (current_op == op) llvm::errs() << "===>>>";
164     llvm::errs() << op.logical_kernel() << " : (";
165     for (auto i = 0; i < op.getNumOperands(); ++i) {
166       if (auto params = GetOperandParams(op, i))
167         params.print(llvm::errs());
168       else
169         llvm::errs() << "_";
170       llvm::errs() << ",";
171     }
172     llvm::errs() << ") -> (";
173     for (auto i = 0; i < op.getNumResults(); ++i) {
174       if (auto params = GetResultParams(op, i))
175         params.print(llvm::errs());
176       else
177         llvm::errs() << "_";
178       llvm::errs() << ",";
179     }
180     llvm::errs() << ")\n";
181   });
182 }
183 
184 // A heuristic to get quantization parameters satisfies the same scale
185 // constraints:
186 // - If there are immutable states,
187 //   - use the single input, or,
188 //   - use the single output, or,
189 //   - use the first one in the collection,
190 // - use the single input if it is ready, or,
191 // - use the single output if it is ready, or,
192 // - use the first ready one in the collection.
GetQuantParamsForSameScaleConstraint(Operation * op)193 QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint(
194     Operation *op) {
195   // Two vector to collect Non-empty operands and results states.
196   std::vector<quant::QuantState *> mutable_states, immutable_states;
197   for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
198     auto &state = states_manager_.GetOperandQuantState(op, i);
199     if (state.immutable) {
200       immutable_states.push_back(&state);
201     } else if (!state.IsEmpty()) {
202       mutable_states.push_back(&state);
203     }
204   }
205 
206   int immutable_operands_num = immutable_states.size();
207   int mutable_operands_num = mutable_states.size();
208   // Use the operand's state if it is immutable and it is the only one
209   // operand.
210   if (op->getNumOperands() == 1 && immutable_operands_num == 1) {
211     return immutable_states.front()->params;
212   }
213 
214   for (int i = 0, e = op->getNumResults(); i != e; ++i) {
215     auto &state = states_manager_.GetResultQuantState(op, i);
216     if (state.immutable) {
217       immutable_states.push_back(&state);
218     } else if (!state.IsEmpty()) {
219       mutable_states.push_back(&state);
220     }
221   }
222 
223   int immutable_results_num = immutable_states.size() - immutable_operands_num;
224   int mutable_results_num = mutable_states.size() - mutable_operands_num;
225   // Use the result's state if it is immutable and it is the only one result.
226   if (op->getNumResults() == 1 && immutable_results_num == 1) {
227     return immutable_states.back()->params;
228   }
229 
230   LLVM_DEBUG(llvm::dbgs()
231              << "Quantization parameters are not collected in an ideal place. "
232                 "Has to fallback values which might introduce errors.\n");
233 
234   // Use the first immutable state to quantize the rest operands and results.
235   if (!immutable_states.empty()) return immutable_states.front()->params;
236 
237   // If there are no immutable states, use the operand's state if it is the
238   // only one operand and has parameters propagated.
239   if (op->getNumOperands() == 1 && mutable_operands_num == 1) {
240     return mutable_states.front()->params;
241   }
242 
243   // If there are no immutable states, use the result's state if it is the
244   // only one result and has parameters propagated.
245   if (op->getNumResults() == 1 && mutable_results_num == 1) {
246     return mutable_states.back()->params;
247   }
248 
249   // Use the first propagated state to quantize the rest operands and results.
250   if (!mutable_states.empty()) return mutable_states.front()->params;
251 
252   // None operands/results have parameters propagated, skip this node for now.
253   return {};
254 }
255 
PropagateQuantParams(Operation * op,const QuantParams params,quant::AdjacentOperations * new_items,bool * changed)256 LogicalResult QuantizeContext::PropagateQuantParams(
257     Operation *op, const QuantParams params,
258     quant::AdjacentOperations *new_items, bool *changed) {
259   // Use the final state to set all the operands' parameters.
260   for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
261     auto ele = op->getOperand(i).getType().cast<ShapedType>().getElementType();
262     if (ele.isa<FloatType>() && SetOperandParams(op, i, params)) {
263       *changed |= true;
264       new_items->push_back(op->getOperand(i).getDefiningOp());
265     }
266   }
267 
268   // Use the final state to set all the results' parameters.
269   for (int res = 0, e = op->getNumResults(); res != e; ++res) {
270     auto ele = op->getResult(res).getType().cast<ShapedType>().getElementType();
271     if (ele.isa<FloatType>() && SetResultParams(op, res, params)) {
272       auto users = op->getResult(res).getUsers();
273       *changed |= !users.empty();
274       new_items->append(users.begin(), users.end());
275     }
276   }
277   return success();
278 }
279 
InitializeState(quant::QuantizeRegionOp op,int index,bool as_result)280 int QuantizeContext::StatesManager::InitializeState(quant::QuantizeRegionOp op,
281                                                     int index, bool as_result) {
282   Attribute params_attr;
283   if (as_result) {
284     params_attr = op.output_specs()[index];
285   } else {
286     params_attr = op.input_specs()[index];
287   }
288   QuantParams params =
289       params_attr.cast<TypeAttr>().getValue().dyn_cast<QuantParams>();
290   bool immutable = !EmptyParams(params);
291   int next_state_index = states_.size();
292   states_.push_back({params, immutable});
293   if (as_result) {
294     result_states_.insert({{op, index}, next_state_index});
295   } else {
296     operand_states_.insert({{op, index}, next_state_index});
297   }
298   return next_state_index;
299 }
300 
InitializeOperandState(quant::QuantizeRegionOp op,int index,llvm::DenseMap<Value,int> * cache)301 void QuantizeContext::StatesManager::InitializeOperandState(
302     quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
303   Value in = op.getOperand(index);
304   auto cached = cache->insert({in, 0});
305   if (!cached.second) {
306     operand_states_.insert({{op, index}, cached.first->second});
307     return;
308   }
309   cached.first->second = InitializeState(op, index, /*as_result=*/false);
310 }
311 
InitializeResultState(quant::QuantizeRegionOp op,int index,llvm::DenseMap<Value,int> * cache)312 void QuantizeContext::StatesManager::InitializeResultState(
313     quant::QuantizeRegionOp op, int index, llvm::DenseMap<Value, int> *cache) {
314   auto res = op.getResult(index);
315   auto cached = cache->insert({res, 0});
316   if (!cached.second) {
317     result_states_.insert({{op, index}, cached.first->second});
318     return;
319   }
320   cached.first->second = InitializeState(op, index, /*as_result=*/true);
321 }
322 
SetConstantResultParams(Operation * op)323 bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) {
324   llvm_unreachable("no implementation.");
325   return false;
326 }
327 
SetResultParams(Operation * op,int res_index,QuantParams params)328 bool QuantizeContext::StatesManager::SetResultParams(Operation *op,
329                                                      int res_index,
330                                                      QuantParams params) {
331   auto &state = GetResultQuantState(op, res_index);
332   if (state.params == params) {
333     return false;
334   }
335   if (!state.IsEmpty()) {
336     auto &rescale = GetResultRequantizeState(op, res_index);
337     rescale.params = params;
338     rescale.pos = RequantizeState::ON_INPUT;
339     return false;
340   }
341   state.params = params;
342   return true;
343 }
344 
SetOperandParams(Operation * op,int index,QuantParams params)345 bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index,
346                                                       QuantParams params) {
347   auto &state = GetOperandQuantState(op, index);
348   if (state.params == params) {
349     return false;
350   }
351 
352   if (!state.IsEmpty()) {
353     auto &rescale = GetOperandRequantizeState(op, index);
354     rescale.params = params;
355     rescale.pos = RequantizeState::ON_OUTPUT;
356     return false;
357   }
358   state.params = params;
359   return true;
360 }
361 }  //  namespace quant
362 }  // namespace mlir
363