• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This pass converts a TFLite uint8 graph to the int8 domain, inserting
// adaptors at the graph's input and output tensors. This is needed because
// TOSA quantized precision is implemented in the int8 domain. This pass:
// 1. matches uint8 TFL::QConst ops and replaces them with int8 TFL::QConst
//    ops whose stored values are remapped (shifted down by 128).
// 2. inserts a tosa.RESCALE uint8 -> int8 after each uint8-typed block
//    argument (graph placeholder).
// 3. retypes every remaining uint8 intermediate tensor to its int8
//    equivalent.
// 4. inserts a tosa.RESCALE int8 -> uint8 in front of the terminator for each
//    returned tensor that was originally uint8-typed.
25 
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31 
32 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
33 #include "mlir/IR/Builders.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
36 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
37 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
38 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
39 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
40 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
41 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
42 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
44 
45 #define PASS_NAME "tosa-convert-tfl-uint8"
46 #define DEBUG_TYPE PASS_NAME
47 
48 namespace mlir {
49 
50 namespace tosa {
51 
52 namespace {
53 // Performs lowering to TOSA dialect.
54 class ConvertUint8ToInt8
55     : public PassWrapper<ConvertUint8ToInt8, FunctionPass> {
56  public:
ConvertUint8ToInt8()57   explicit ConvertUint8ToInt8() {}
58   void runOnFunction() override;
59 };
60 
61 struct ConvertUint8QConstOp : public RewritePattern {
ConvertUint8QConstOpmlir::tosa::__anon3584545e0111::ConvertUint8QConstOp62   explicit ConvertUint8QConstOp(MLIRContext *context)
63       : RewritePattern(TFL::QConstOp::getOperationName(), 1, context) {}
64 
matchAndRewritemlir::tosa::__anon3584545e0111::ConvertUint8QConstOp65   LogicalResult matchAndRewrite(Operation *op,
66                                 PatternRewriter &builder) const override {
67     auto tfl_qconst_op = cast<TFL::QConstOp>(op);
68 
69     // Skip if it's not ranked tensor type.
70     auto output_type =
71         tfl_qconst_op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
72     if (!output_type)
73       return builder.notifyMatchFailure(op, "not ranked tensor");
74 
75     // Skip if output is not per-tensor quantized type.
76     auto output_element_type =
77         output_type.getElementType()
78             .dyn_cast<mlir::quant::UniformQuantizedType>();
79     if (!output_element_type) return failure();
80 
81     // Skip if output is not uint8.
82     if (output_element_type.isSigned() ||
83         output_element_type.getStorageTypeIntegralWidth() != 8) {
84       return failure();
85     }
86 
87     mlir::DenseElementsAttr src_dense_attr =
88         tfl_qconst_op.value().cast<DenseElementsAttr>();
89 
90     double type_range_min =
91         static_cast<double>(output_element_type.getStorageTypeMin() -
92                             output_element_type.getZeroPoint()) *
93         output_element_type.getScale();
94     double type_range_max =
95         static_cast<double>(output_element_type.getStorageTypeMax() -
96                             output_element_type.getZeroPoint()) *
97         output_element_type.getScale();
98     bool narrow_range =
99         output_element_type.getStorageTypeMin() == 1 ? true : false;
100 
101     auto dst_qconst_type = TypeAttr::get(RankedTensorType::get(
102         output_type.getShape(),
103         buildQTypeFromMinMax(
104             builder, output_element_type.getExpressedType(),
105             builder.getF64FloatAttr(type_range_min),
106             builder.getF64FloatAttr(type_range_max),
107             builder.getI32IntegerAttr(
108                 output_element_type.getStorageTypeIntegralWidth()),
109             0, true /* signed */, builder.getBoolAttr(narrow_range))));
110 
111     Type dst_dense_element_type = builder.getIntegerType(8);
112     llvm::function_ref<APInt(const APInt &)> mapping =
113         [](const APInt &in) -> APInt {
114       int64_t in_i64 = in.getLimitedValue();
115       int64_t out_i64 = in_i64 - 128;
116       return APInt(8, out_i64, true);
117     };
118 
119     auto dst_dense_attr =
120         src_dense_attr.mapValues(dst_dense_element_type, mapping);
121 
122     builder.replaceOpWithNewOp<TFL::QConstOp>(op, dst_qconst_type,
123                                               dst_dense_attr);
124 
125     return success();
126   }
127 };
128 
convert_graph_uint8_tensor(mlir::MLIRContext & context,mlir::FuncOp & function)129 LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context,
130                                          mlir::FuncOp &function) {
131   size_t num_blocks_in_main = 0;
132   mlir::Region *region = function.getCallableRegion();
133   OpBuilder builder(&context);
134 
135   auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8));
136   auto tmp_const_attr =
137       DenseElementsAttr::get(tmp_const_type, {static_cast<uint8_t>(0)});
138 
139   for (mlir::Block &bb : region->getBlocks()) {
140     // Always have one block for each region right now.
141     num_blocks_in_main++;
142     if (num_blocks_in_main > 1) {
143       return function.emitError("Invalid MLIR: multiple blocks in a region");
144     }
145 
146     if (!bb.isEntryBlock()) {
147       return function.emitError("Invalid MLIR: block must be entry block");
148     }
149 
150     // Insert rescale uint8->int8 after placeholders.
151     for (Value arg : bb.getArguments()) {
152       auto uint8_type = arg.getType().dyn_cast<mlir::RankedTensorType>();
153       if (!uint8_type) continue;
154 
155       auto uint8_element_type =
156           uint8_type.getElementType()
157               .dyn_cast<mlir::quant::UniformQuantizedType>();
158       if (!uint8_element_type) continue;
159 
160       if (uint8_element_type.isSigned() ||
161           uint8_element_type.getStorageTypeIntegralWidth() != 8)
162         continue;
163 
164       double type_range_min =
165           static_cast<double>(uint8_element_type.getStorageTypeMin() -
166                               uint8_element_type.getZeroPoint()) *
167           uint8_element_type.getScale();
168       double type_range_max =
169           static_cast<double>(uint8_element_type.getStorageTypeMax() -
170                               uint8_element_type.getZeroPoint()) *
171           uint8_element_type.getScale();
172       bool narrow_range =
173           uint8_element_type.getStorageTypeMin() == 1 ? true : false;
174 
175       Type int8_type = RankedTensorType::get(
176           uint8_type.getShape(),
177           buildQTypeFromMinMax(
178               builder, uint8_element_type.getExpressedType(),
179               builder.getF64FloatAttr(type_range_min),
180               builder.getF64FloatAttr(type_range_max),
181               builder.getI32IntegerAttr(
182                   uint8_element_type.getStorageTypeIntegralWidth()),
183               0, true /* signed */, builder.getBoolAttr(narrow_range)));
184 
185       int32_t uint8_zp = uint8_element_type.getZeroPoint();
186       int32_t int8_zp = uint8_zp - 128;
187 
188       // Keep original input_val use with tmp_val.
189       Value tmp_val = builder.create<TFL::ConstOp>(
190           function.getLoc(), tmp_const_type, tmp_const_attr);
191       arg.replaceAllUsesWith(tmp_val);
192       auto rescale_op = builder.create<tosa::RescaleOp>(
193           function.getLoc(), int8_type, arg,
194           builder.getI32IntegerAttr(uint8_zp),
195           builder.getI32IntegerAttr(int8_zp),
196           builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
197           builder.getBoolAttr(true), builder.getBoolAttr(false),
198           builder.getBoolAttr(false));
199 
200       Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
201       bb.push_front(op_rescale_op);
202       tmp_val.replaceAllUsesWith(rescale_op.getResult());
203       tmp_val.getDefiningOp()->erase();
204     }
205 
206     // Record types of original graph output before we convert intermediate
207     // tensor.
208     auto terminator = bb.getTerminator();
209     SmallVector<Type, 4> output_types;
210     for (Value val : terminator->getOperands()) {
211       output_types.push_back(val.getType());
212     }
213 
214     // Convert intermediate tensor.
215     for (auto &op : bb) {
216       for (Value output_val : op.getResults()) {
217         // Skip if output value is not RankedTensorType.
218         auto output_type =
219             output_val.getType().dyn_cast<mlir::RankedTensorType>();
220         if (!output_type) continue;
221 
222         // Skip if output value is not per-tensor quantized element type.
223         auto output_element_type =
224             output_type.getElementType()
225                 .dyn_cast<mlir::quant::UniformQuantizedType>();
226         if (!output_element_type) continue;
227 
228         // Skip if output is not uint8.
229         if (output_element_type.isSigned() ||
230             output_element_type.getStorageTypeIntegralWidth() != 8)
231           continue;
232 
233         double type_range_min =
234             static_cast<double>(output_element_type.getStorageTypeMin() -
235                                 output_element_type.getZeroPoint()) *
236             output_element_type.getScale();
237         double type_range_max =
238             static_cast<double>(output_element_type.getStorageTypeMax() -
239                                 output_element_type.getZeroPoint()) *
240             output_element_type.getScale();
241         bool narrow_range =
242             output_element_type.getStorageTypeMin() == 1 ? true : false;
243 
244         Type new_type = RankedTensorType::get(
245             output_type.getShape(),
246             buildQTypeFromMinMax(
247                 builder, output_element_type.getExpressedType(),
248                 builder.getF64FloatAttr(type_range_min),
249                 builder.getF64FloatAttr(type_range_max),
250                 builder.getI32IntegerAttr(
251                     output_element_type.getStorageTypeIntegralWidth()),
252                 0, true /* signed */, builder.getBoolAttr(narrow_range)));
253 
254         output_val.setType(new_type);
255       }
256     }
257 
258     if (terminator->getNumOperands() != output_types.size()) {
259       return function.emitError(
260           "Terminator's operand mismatch with number of outputs in graph");
261     }
262 
263     // Insert int8->uint8 rescale before all terminator's operand.
264     for (int32_t i = 0; i < terminator->getNumOperands(); i++) {
265       auto defining_op = terminator->getOperand(i).getDefiningOp();
266       // skip if operand of terminator is block arg (nullptr in this case) or
267       // not
268       if (!defining_op) continue;
269       Value input_val = defining_op->getResult(0);
270 
271       // Check if graph output is uint8 type.
272       auto uint8_output_type =
273           output_types[i].dyn_cast<mlir::RankedTensorType>();
274       if (!uint8_output_type) continue;
275 
276       auto uint8_output_element_type =
277           uint8_output_type.getElementType()
278               .dyn_cast<mlir::quant::UniformQuantizedType>();
279       if (!uint8_output_element_type) continue;
280 
281       if (uint8_output_element_type.isSigned() ||
282           uint8_output_element_type.getStorageTypeIntegralWidth() != 8)
283         continue;
284 
285       // Check if output coming into terminator is int8 type.
286       auto int8_output_type = terminator->getOperand(i)
287                                   .getType()
288                                   .dyn_cast<mlir::RankedTensorType>();
289       if (!int8_output_type) continue;
290 
291       auto int8_output_element_type =
292           int8_output_type.getElementType()
293               .dyn_cast<mlir::quant::UniformQuantizedType>();
294       if (!int8_output_element_type) continue;
295 
296       if (!int8_output_element_type.isSigned() ||
297           int8_output_element_type.getStorageTypeIntegralWidth() != 8)
298         continue;
299 
300       int32_t int8_zp = int8_output_element_type.getZeroPoint();
301       int32_t uint8_zp = uint8_output_element_type.getZeroPoint();
302 
303       // Sanity check if uint8/int8's scale and zeropoint match.
304       if (((uint8_zp - int8_zp) != 128) ||
305           (int8_output_element_type.getScale() !=
306            uint8_output_element_type.getScale())) {
307         return terminator->emitError(
308             "convert_uint8_to_int8: scale mismatch at the output tensors");
309       }
310 
311       // Keep original input_val use with tmp_val.
312       Value tmp_val = builder.create<TFL::ConstOp>(
313           function.getLoc(), tmp_const_type, tmp_const_attr);
314       input_val.replaceAllUsesWith(tmp_val);
315       auto rescale_op = builder.create<tosa::RescaleOp>(
316           function.getLoc(), uint8_output_type, input_val,
317           builder.getI32IntegerAttr(int8_zp),
318           builder.getI32IntegerAttr(uint8_zp),
319           builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
320           builder.getBoolAttr(true), builder.getBoolAttr(false),
321           builder.getBoolAttr(false));
322 
323       Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
324       bb.push_back(op_rescale_op);
325       op_rescale_op->moveBefore(terminator);
326       tmp_val.replaceAllUsesWith(rescale_op.getResult());
327       tmp_val.getDefiningOp()->erase();
328     }
329   }
330 
331   return success();
332 }
333 
runOnFunction()334 void ConvertUint8ToInt8::runOnFunction() {
335   OwningRewritePatternList patterns;
336   auto &ctx = getContext();
337   auto func = getFunction();
338 
339   // Convert uint8 const tensor. const needs to be handled specifically.
340   patterns.insert<ConvertUint8QConstOp>(&ctx);
341   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
342 
343   // Replace uint8 tensor in the graph and insert rescale as needed.
344   (void)convert_graph_uint8_tensor(ctx, func);
345 }
346 
347 }  // anonymous namespace
348 
createConvertTFLUint8Pass()349 std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass() {
350   return std::make_unique<ConvertUint8ToInt8>();
351 }
352 
353 static PassRegistration<ConvertUint8ToInt8> pass(
354     PASS_NAME, "Convert uint8 graph to int8.");
355 
356 }  // namespace tosa
357 
358 }  // namespace mlir
359