• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This pass converts a TFLite uint8 graph to the int8 domain, with adaptors at
// input and output tensors. This is needed because TOSA precision is
// implemented in the int8 domain. This pass does:
// 1. match TFL::QConst with uint8, generate TFL::QConst with int8 whose stored
// values are remapped by subtracting 128.
// 2. insert tosa.RESCALE uint8 -> int8 if block argument (placeholder of graph)
// is uint8 typed.
// 3. retype every remaining uint8 quantized intermediate tensor to int8.
// 4. insert tosa.RESCALE int8 -> uint8 if original returned tensor is uint8
// typed.
25 
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31 
32 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
33 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
34 #include "mlir/IR/Builders.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
37 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
38 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
39 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
40 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
44 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
45 
46 #define PASS_NAME "tosa-convert-tfl-uint8"
47 #define DEBUG_TYPE PASS_NAME
48 
49 namespace mlir {
50 namespace tosa {
51 namespace {
52 #define GEN_PASS_CLASSES
53 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc"
54 
55 // Performs lowering to TOSA dialect.
56 class ConvertUint8ToInt8
57     : public TosaConvertTFLUint8PassBase<ConvertUint8ToInt8> {
58  public:
ConvertUint8ToInt8()59   explicit ConvertUint8ToInt8() {}
60   void runOnFunction() override;
61 };
62 
63 struct ConvertUint8QConstOp : public RewritePattern {
ConvertUint8QConstOpmlir::tosa::__anon7fbb95430111::ConvertUint8QConstOp64   explicit ConvertUint8QConstOp(MLIRContext *context)
65       : RewritePattern(TFL::QConstOp::getOperationName(), 1, context) {}
66 
matchAndRewritemlir::tosa::__anon7fbb95430111::ConvertUint8QConstOp67   LogicalResult matchAndRewrite(Operation *op,
68                                 PatternRewriter &builder) const override {
69     auto tfl_qconst_op = cast<TFL::QConstOp>(op);
70 
71     // Skip if it's not ranked tensor type.
72     auto output_type =
73         tfl_qconst_op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
74     if (!output_type)
75       return builder.notifyMatchFailure(op, "not ranked tensor");
76 
77     // Skip if output is not per-tensor quantized type.
78     auto output_element_type =
79         output_type.getElementType()
80             .dyn_cast<mlir::quant::UniformQuantizedType>();
81     if (!output_element_type) return failure();
82 
83     // Skip if output is not uint8.
84     if (output_element_type.isSigned() ||
85         output_element_type.getStorageTypeIntegralWidth() != 8) {
86       return failure();
87     }
88 
89     mlir::DenseElementsAttr src_dense_attr =
90         tfl_qconst_op.value().cast<DenseElementsAttr>();
91 
92     double type_range_min =
93         static_cast<double>(output_element_type.getStorageTypeMin() -
94                             output_element_type.getZeroPoint()) *
95         output_element_type.getScale();
96     double type_range_max =
97         static_cast<double>(output_element_type.getStorageTypeMax() -
98                             output_element_type.getZeroPoint()) *
99         output_element_type.getScale();
100     bool narrow_range =
101         output_element_type.getStorageTypeMin() == 1 ? true : false;
102 
103     auto dst_qconst_type = TypeAttr::get(RankedTensorType::get(
104         output_type.getShape(),
105         buildQTypeFromMinMax(
106             builder, output_element_type.getExpressedType(),
107             builder.getF64FloatAttr(type_range_min),
108             builder.getF64FloatAttr(type_range_max),
109             builder.getI32IntegerAttr(
110                 output_element_type.getStorageTypeIntegralWidth()),
111             0, true /* signed */, builder.getBoolAttr(narrow_range))));
112 
113     Type dst_dense_element_type = builder.getIntegerType(8);
114     llvm::function_ref<APInt(const APInt &)> mapping =
115         [](const APInt &in) -> APInt {
116       int64_t in_i64 = in.getLimitedValue();
117       int64_t out_i64 = in_i64 - 128;
118       return APInt(8, out_i64, true);
119     };
120 
121     auto dst_dense_attr =
122         src_dense_attr.mapValues(dst_dense_element_type, mapping);
123 
124     builder.replaceOpWithNewOp<TFL::QConstOp>(op, dst_qconst_type,
125                                               dst_dense_attr);
126 
127     return success();
128   }
129 };
130 
convert_graph_uint8_tensor(mlir::MLIRContext & context,mlir::FuncOp & function)131 LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context,
132                                          mlir::FuncOp &function) {
133   size_t num_blocks_in_main = 0;
134   mlir::Region *region = function.getCallableRegion();
135   OpBuilder builder(&context);
136 
137   auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8));
138   auto tmp_const_attr =
139       DenseElementsAttr::get(tmp_const_type, {static_cast<uint8_t>(0)});
140 
141   for (mlir::Block &bb : region->getBlocks()) {
142     // Always have one block for each region right now.
143     num_blocks_in_main++;
144     if (num_blocks_in_main > 1) {
145       return function.emitError("Invalid MLIR: multiple blocks in a region");
146     }
147 
148     if (!bb.isEntryBlock()) {
149       return function.emitError("Invalid MLIR: block must be entry block");
150     }
151 
152     // Insert rescale uint8->int8 after placeholders.
153     for (Value arg : bb.getArguments()) {
154       auto uint8_type = arg.getType().dyn_cast<mlir::RankedTensorType>();
155       if (!uint8_type) continue;
156 
157       auto uint8_element_type =
158           uint8_type.getElementType()
159               .dyn_cast<mlir::quant::UniformQuantizedType>();
160       if (!uint8_element_type) continue;
161 
162       if (uint8_element_type.isSigned() ||
163           uint8_element_type.getStorageTypeIntegralWidth() != 8)
164         continue;
165 
166       double type_range_min =
167           static_cast<double>(uint8_element_type.getStorageTypeMin() -
168                               uint8_element_type.getZeroPoint()) *
169           uint8_element_type.getScale();
170       double type_range_max =
171           static_cast<double>(uint8_element_type.getStorageTypeMax() -
172                               uint8_element_type.getZeroPoint()) *
173           uint8_element_type.getScale();
174       bool narrow_range =
175           uint8_element_type.getStorageTypeMin() == 1 ? true : false;
176 
177       Type int8_type = RankedTensorType::get(
178           uint8_type.getShape(),
179           buildQTypeFromMinMax(
180               builder, uint8_element_type.getExpressedType(),
181               builder.getF64FloatAttr(type_range_min),
182               builder.getF64FloatAttr(type_range_max),
183               builder.getI32IntegerAttr(
184                   uint8_element_type.getStorageTypeIntegralWidth()),
185               0, true /* signed */, builder.getBoolAttr(narrow_range)));
186 
187       int32_t uint8_zp = uint8_element_type.getZeroPoint();
188       int32_t int8_zp = uint8_zp - 128;
189 
190       // Keep original input_val use with tmp_val.
191       Value tmp_val = builder.create<TFL::ConstOp>(
192           function.getLoc(), tmp_const_type, tmp_const_attr);
193       arg.replaceAllUsesWith(tmp_val);
194       auto rescale_op = builder.create<tosa::RescaleOp>(
195           function.getLoc(), int8_type, arg,
196           builder.getI32IntegerAttr(uint8_zp),
197           builder.getI32IntegerAttr(int8_zp),
198           builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
199           builder.getBoolAttr(true), builder.getBoolAttr(false),
200           builder.getBoolAttr(false));
201 
202       Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
203       bb.push_front(op_rescale_op);
204       tmp_val.replaceAllUsesWith(rescale_op.getResult());
205       tmp_val.getDefiningOp()->erase();
206     }
207 
208     // Record types of original graph output before we convert intermediate
209     // tensor.
210     auto terminator = bb.getTerminator();
211     SmallVector<Type, 4> output_types;
212     for (Value val : terminator->getOperands()) {
213       output_types.push_back(val.getType());
214     }
215 
216     // Convert intermediate tensor.
217     for (auto &op : bb) {
218       for (Value output_val : op.getResults()) {
219         // Skip if output value is not RankedTensorType.
220         auto output_type =
221             output_val.getType().dyn_cast<mlir::RankedTensorType>();
222         if (!output_type) continue;
223 
224         // Skip if output value is not per-tensor quantized element type.
225         auto output_element_type =
226             output_type.getElementType()
227                 .dyn_cast<mlir::quant::UniformQuantizedType>();
228         if (!output_element_type) continue;
229 
230         // Skip if output is not uint8.
231         if (output_element_type.isSigned() ||
232             output_element_type.getStorageTypeIntegralWidth() != 8)
233           continue;
234 
235         double type_range_min =
236             static_cast<double>(output_element_type.getStorageTypeMin() -
237                                 output_element_type.getZeroPoint()) *
238             output_element_type.getScale();
239         double type_range_max =
240             static_cast<double>(output_element_type.getStorageTypeMax() -
241                                 output_element_type.getZeroPoint()) *
242             output_element_type.getScale();
243         bool narrow_range =
244             output_element_type.getStorageTypeMin() == 1 ? true : false;
245 
246         Type new_type = RankedTensorType::get(
247             output_type.getShape(),
248             buildQTypeFromMinMax(
249                 builder, output_element_type.getExpressedType(),
250                 builder.getF64FloatAttr(type_range_min),
251                 builder.getF64FloatAttr(type_range_max),
252                 builder.getI32IntegerAttr(
253                     output_element_type.getStorageTypeIntegralWidth()),
254                 0, true /* signed */, builder.getBoolAttr(narrow_range)));
255 
256         output_val.setType(new_type);
257       }
258     }
259 
260     if (terminator->getNumOperands() != output_types.size()) {
261       return function.emitError(
262           "Terminator's operand mismatch with number of outputs in graph");
263     }
264 
265     // Insert int8->uint8 rescale before all terminator's operand.
266     for (int32_t i = 0; i < terminator->getNumOperands(); i++) {
267       auto defining_op = terminator->getOperand(i).getDefiningOp();
268       // skip if operand of terminator is block arg (nullptr in this case) or
269       // not
270       if (!defining_op) continue;
271       Value input_val = defining_op->getResult(0);
272 
273       // Check if graph output is uint8 type.
274       auto uint8_output_type =
275           output_types[i].dyn_cast<mlir::RankedTensorType>();
276       if (!uint8_output_type) continue;
277 
278       auto uint8_output_element_type =
279           uint8_output_type.getElementType()
280               .dyn_cast<mlir::quant::UniformQuantizedType>();
281       if (!uint8_output_element_type) continue;
282 
283       if (uint8_output_element_type.isSigned() ||
284           uint8_output_element_type.getStorageTypeIntegralWidth() != 8)
285         continue;
286 
287       // Check if output coming into terminator is int8 type.
288       auto int8_output_type = terminator->getOperand(i)
289                                   .getType()
290                                   .dyn_cast<mlir::RankedTensorType>();
291       if (!int8_output_type) continue;
292 
293       auto int8_output_element_type =
294           int8_output_type.getElementType()
295               .dyn_cast<mlir::quant::UniformQuantizedType>();
296       if (!int8_output_element_type) continue;
297 
298       if (!int8_output_element_type.isSigned() ||
299           int8_output_element_type.getStorageTypeIntegralWidth() != 8)
300         continue;
301 
302       int32_t int8_zp = int8_output_element_type.getZeroPoint();
303       int32_t uint8_zp = uint8_output_element_type.getZeroPoint();
304 
305       // Sanity check if uint8/int8's scale and zeropoint match.
306       if (((uint8_zp - int8_zp) != 128) ||
307           (int8_output_element_type.getScale() !=
308            uint8_output_element_type.getScale())) {
309         return terminator->emitError(
310             "convert_uint8_to_int8: scale mismatch at the output tensors");
311       }
312 
313       // Keep original input_val use with tmp_val.
314       Value tmp_val = builder.create<TFL::ConstOp>(
315           function.getLoc(), tmp_const_type, tmp_const_attr);
316       input_val.replaceAllUsesWith(tmp_val);
317       auto rescale_op = builder.create<tosa::RescaleOp>(
318           function.getLoc(), uint8_output_type, input_val,
319           builder.getI32IntegerAttr(int8_zp),
320           builder.getI32IntegerAttr(uint8_zp),
321           builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
322           builder.getBoolAttr(true), builder.getBoolAttr(false),
323           builder.getBoolAttr(false));
324 
325       Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
326       bb.push_back(op_rescale_op);
327       op_rescale_op->moveBefore(terminator);
328       tmp_val.replaceAllUsesWith(rescale_op.getResult());
329       tmp_val.getDefiningOp()->erase();
330     }
331   }
332 
333   return success();
334 }
335 
runOnFunction()336 void ConvertUint8ToInt8::runOnFunction() {
337   OwningRewritePatternList patterns(&getContext());
338   auto &ctx = getContext();
339   auto func = getFunction();
340 
341   // Convert uint8 const tensor. const needs to be handled specifically.
342   patterns.insert<ConvertUint8QConstOp>(&ctx);
343   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
344 
345   // Replace uint8 tensor in the graph and insert rescale as needed.
346   (void)convert_graph_uint8_tensor(ctx, func);
347 }
348 
349 }  // anonymous namespace
350 
createConvertTFLUint8Pass()351 std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass() {
352   return std::make_unique<ConvertUint8ToInt8>();
353 }
354 
355 static PassRegistration<ConvertUint8ToInt8> pass(
356     PASS_NAME, "Convert uint8 graph to int8.");
357 
358 }  // namespace tosa
359 
360 }  // namespace mlir
361