1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass converts a TFLite uint8 graph to the int8 domain, with adaptors at
17 // input and output tensors. This is needed because TOSA precision is
18 // implemented in the int8 domain. This pass does:
19 // 1. match TFL::QConst with uint8, generate TFL::QConst with int8 with value
20 // remapped.
21 // 2. insert tosa.RESCALE uint8 -> int8 if block argument (placeholder of graph)
22 // is uint8 typed.
23 // 3. insert tosa.RESCALE int8 -> uint8 if original returned tensor is uint8
24 // typed.
25
26 #include <climits>
27 #include <cstddef>
28 #include <cstdint>
29 #include <iterator>
30 #include <numeric>
31
32 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
33 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
34 #include "mlir/IR/Builders.h" // from @llvm-project
35 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
37 #include "mlir/IR/PatternMatch.h" // from @llvm-project
38 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
39 #include "mlir/Support/LogicalResult.h" // from @llvm-project
40 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
43 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
44 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
45
46 #define PASS_NAME "tosa-convert-tfl-uint8"
47 #define DEBUG_TYPE PASS_NAME
48
49 namespace mlir {
50 namespace tosa {
51 namespace {
52 #define GEN_PASS_CLASSES
53 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc"
54
55 // Performs lowering to TOSA dialect.
56 class ConvertUint8ToInt8
57 : public TosaConvertTFLUint8PassBase<ConvertUint8ToInt8> {
58 public:
ConvertUint8ToInt8()59 explicit ConvertUint8ToInt8() {}
60 void runOnFunction() override;
61 };
62
63 struct ConvertUint8QConstOp : public RewritePattern {
ConvertUint8QConstOpmlir::tosa::__anon7fbb95430111::ConvertUint8QConstOp64 explicit ConvertUint8QConstOp(MLIRContext *context)
65 : RewritePattern(TFL::QConstOp::getOperationName(), 1, context) {}
66
matchAndRewritemlir::tosa::__anon7fbb95430111::ConvertUint8QConstOp67 LogicalResult matchAndRewrite(Operation *op,
68 PatternRewriter &builder) const override {
69 auto tfl_qconst_op = cast<TFL::QConstOp>(op);
70
71 // Skip if it's not ranked tensor type.
72 auto output_type =
73 tfl_qconst_op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
74 if (!output_type)
75 return builder.notifyMatchFailure(op, "not ranked tensor");
76
77 // Skip if output is not per-tensor quantized type.
78 auto output_element_type =
79 output_type.getElementType()
80 .dyn_cast<mlir::quant::UniformQuantizedType>();
81 if (!output_element_type) return failure();
82
83 // Skip if output is not uint8.
84 if (output_element_type.isSigned() ||
85 output_element_type.getStorageTypeIntegralWidth() != 8) {
86 return failure();
87 }
88
89 mlir::DenseElementsAttr src_dense_attr =
90 tfl_qconst_op.value().cast<DenseElementsAttr>();
91
92 double type_range_min =
93 static_cast<double>(output_element_type.getStorageTypeMin() -
94 output_element_type.getZeroPoint()) *
95 output_element_type.getScale();
96 double type_range_max =
97 static_cast<double>(output_element_type.getStorageTypeMax() -
98 output_element_type.getZeroPoint()) *
99 output_element_type.getScale();
100 bool narrow_range =
101 output_element_type.getStorageTypeMin() == 1 ? true : false;
102
103 auto dst_qconst_type = TypeAttr::get(RankedTensorType::get(
104 output_type.getShape(),
105 buildQTypeFromMinMax(
106 builder, output_element_type.getExpressedType(),
107 builder.getF64FloatAttr(type_range_min),
108 builder.getF64FloatAttr(type_range_max),
109 builder.getI32IntegerAttr(
110 output_element_type.getStorageTypeIntegralWidth()),
111 0, true /* signed */, builder.getBoolAttr(narrow_range))));
112
113 Type dst_dense_element_type = builder.getIntegerType(8);
114 llvm::function_ref<APInt(const APInt &)> mapping =
115 [](const APInt &in) -> APInt {
116 int64_t in_i64 = in.getLimitedValue();
117 int64_t out_i64 = in_i64 - 128;
118 return APInt(8, out_i64, true);
119 };
120
121 auto dst_dense_attr =
122 src_dense_attr.mapValues(dst_dense_element_type, mapping);
123
124 builder.replaceOpWithNewOp<TFL::QConstOp>(op, dst_qconst_type,
125 dst_dense_attr);
126
127 return success();
128 }
129 };
130
// Walks the (single-block) function body and moves all uint8-quantized
// tensors into the int8 domain:
//   1. block arguments: a tosa.RESCALE uint8->int8 is inserted at the top of
//      the block so every internal use sees int8, while the argument itself
//      keeps its original uint8 type;
//   2. intermediate op results: retyped in place to the equivalent int8
//      quantized type (same scale; zero point shifted by -128);
//   3. graph outputs that were originally uint8: a tosa.RESCALE int8->uint8
//      is inserted just before the terminator so the external signature keeps
//      its original uint8 type.
// Returns failure() on malformed input: multiple blocks, terminator operand
// count mismatch, or a scale mismatch at an output tensor.
LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context,
                                         mlir::FuncOp &function) {
  size_t num_blocks_in_main = 0;
  mlir::Region *region = function.getCallableRegion();
  OpBuilder builder(&context);

  // Throwaway 1-element i8 constant used to park uses of a value while it is
  // being rewired (see the replaceAllUsesWith sequences below).
  auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8));
  auto tmp_const_attr =
      DenseElementsAttr::get(tmp_const_type, {static_cast<uint8_t>(0)});

  for (mlir::Block &bb : region->getBlocks()) {
    // Always have one block for each region right now.
    num_blocks_in_main++;
    if (num_blocks_in_main > 1) {
      return function.emitError("Invalid MLIR: multiple blocks in a region");
    }

    if (!bb.isEntryBlock()) {
      return function.emitError("Invalid MLIR: block must be entry block");
    }

    // Insert rescale uint8->int8 after placeholders.
    for (Value arg : bb.getArguments()) {
      // Only ranked tensors with per-tensor uniform-quantized uint8 elements
      // are converted; anything else passes through untouched.
      auto uint8_type = arg.getType().dyn_cast<mlir::RankedTensorType>();
      if (!uint8_type) continue;

      auto uint8_element_type =
          uint8_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>();
      if (!uint8_element_type) continue;

      if (uint8_element_type.isSigned() ||
          uint8_element_type.getStorageTypeIntegralWidth() != 8)
        continue;

      // Real-valued range of the uint8 type; the int8 replacement type is
      // built from the same min/max so it represents the same values.
      double type_range_min =
          static_cast<double>(uint8_element_type.getStorageTypeMin() -
                              uint8_element_type.getZeroPoint()) *
          uint8_element_type.getScale();
      double type_range_max =
          static_cast<double>(uint8_element_type.getStorageTypeMax() -
                              uint8_element_type.getZeroPoint()) *
          uint8_element_type.getScale();
      // Storage minimum of 1 indicates a narrow-range quantized type.
      bool narrow_range =
          uint8_element_type.getStorageTypeMin() == 1 ? true : false;

      Type int8_type = RankedTensorType::get(
          uint8_type.getShape(),
          buildQTypeFromMinMax(
              builder, uint8_element_type.getExpressedType(),
              builder.getF64FloatAttr(type_range_min),
              builder.getF64FloatAttr(type_range_max),
              builder.getI32IntegerAttr(
                  uint8_element_type.getStorageTypeIntegralWidth()),
              0, true /* signed */, builder.getBoolAttr(narrow_range)));

      // Same real values, storage shifted by -128.
      int32_t uint8_zp = uint8_element_type.getZeroPoint();
      int32_t int8_zp = uint8_zp - 128;

      // Keep original input_val use with tmp_val.
      // (Replacing arg's uses directly would also capture the rescale's own
      // operand, so uses are first moved to a throwaway const, the rescale is
      // created consuming arg, then uses move to the rescale result.)
      Value tmp_val = builder.create<TFL::ConstOp>(
          function.getLoc(), tmp_const_type, tmp_const_attr);
      arg.replaceAllUsesWith(tmp_val);
      // Multiplier/shift {1<<30, 30} encodes a scale factor of exactly 1.0;
      // only the zero points change across this rescale.
      auto rescale_op = builder.create<tosa::RescaleOp>(
          function.getLoc(), int8_type, arg,
          builder.getI32IntegerAttr(uint8_zp),
          builder.getI32IntegerAttr(int8_zp),
          builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
          builder.getBoolAttr(true), builder.getBoolAttr(false),
          builder.getBoolAttr(false));

      Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
      bb.push_front(op_rescale_op);
      tmp_val.replaceAllUsesWith(rescale_op.getResult());
      tmp_val.getDefiningOp()->erase();
    }

    // Record types of original graph output before we convert intermediate
    // tensor.
    auto terminator = bb.getTerminator();
    SmallVector<Type, 4> output_types;
    for (Value val : terminator->getOperands()) {
      output_types.push_back(val.getType());
    }

    // Convert intermediate tensor.
    for (auto &op : bb) {
      for (Value output_val : op.getResults()) {
        // Skip if output value is not RankedTensorType.
        auto output_type =
            output_val.getType().dyn_cast<mlir::RankedTensorType>();
        if (!output_type) continue;

        // Skip if output value is not per-tensor quantized element type.
        auto output_element_type =
            output_type.getElementType()
                .dyn_cast<mlir::quant::UniformQuantizedType>();
        if (!output_element_type) continue;

        // Skip if output is not uint8.
        if (output_element_type.isSigned() ||
            output_element_type.getStorageTypeIntegralWidth() != 8)
          continue;

        double type_range_min =
            static_cast<double>(output_element_type.getStorageTypeMin() -
                                output_element_type.getZeroPoint()) *
            output_element_type.getScale();
        double type_range_max =
            static_cast<double>(output_element_type.getStorageTypeMax() -
                                output_element_type.getZeroPoint()) *
            output_element_type.getScale();
        bool narrow_range =
            output_element_type.getStorageTypeMin() == 1 ? true : false;

        Type new_type = RankedTensorType::get(
            output_type.getShape(),
            buildQTypeFromMinMax(
                builder, output_element_type.getExpressedType(),
                builder.getF64FloatAttr(type_range_min),
                builder.getF64FloatAttr(type_range_max),
                builder.getI32IntegerAttr(
                    output_element_type.getStorageTypeIntegralWidth()),
                0, true /* signed */, builder.getBoolAttr(narrow_range)));

        // In-place retype: all existing uses now see the int8 type.
        output_val.setType(new_type);
      }
    }

    if (terminator->getNumOperands() != output_types.size()) {
      return function.emitError(
          "Terminator's operand mismatch with number of outputs in graph");
    }

    // Insert int8->uint8 rescale before all terminator's operand.
    for (int32_t i = 0; i < terminator->getNumOperands(); i++) {
      auto defining_op = terminator->getOperand(i).getDefiningOp();
      // skip if operand of terminator is block arg (nullptr in this case) or
      // not
      if (!defining_op) continue;
      Value input_val = defining_op->getResult(0);

      // Check if graph output is uint8 type.
      // (output_types holds the types recorded BEFORE the in-place retype
      // above, i.e. the function's original external signature.)
      auto uint8_output_type =
          output_types[i].dyn_cast<mlir::RankedTensorType>();
      if (!uint8_output_type) continue;

      auto uint8_output_element_type =
          uint8_output_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>();
      if (!uint8_output_element_type) continue;

      if (uint8_output_element_type.isSigned() ||
          uint8_output_element_type.getStorageTypeIntegralWidth() != 8)
        continue;

      // Check if output coming into terminator is int8 type.
      auto int8_output_type = terminator->getOperand(i)
                                  .getType()
                                  .dyn_cast<mlir::RankedTensorType>();
      if (!int8_output_type) continue;

      auto int8_output_element_type =
          int8_output_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>();
      if (!int8_output_element_type) continue;

      if (!int8_output_element_type.isSigned() ||
          int8_output_element_type.getStorageTypeIntegralWidth() != 8)
        continue;

      int32_t int8_zp = int8_output_element_type.getZeroPoint();
      int32_t uint8_zp = uint8_output_element_type.getZeroPoint();

      // Sanity check if uint8/int8's scale and zeropoint match.
      // The int8 type must be exactly the -128-shifted image of the uint8
      // type, otherwise the unit-scale rescale below would be lossy.
      if (((uint8_zp - int8_zp) != 128) ||
          (int8_output_element_type.getScale() !=
           uint8_output_element_type.getScale())) {
        return terminator->emitError(
            "convert_uint8_to_int8: scale mismatch at the output tensors");
      }

      // Keep original input_val use with tmp_val.
      // (Same parking trick as for block arguments above, so the rescale does
      // not end up consuming its own result.)
      Value tmp_val = builder.create<TFL::ConstOp>(
          function.getLoc(), tmp_const_type, tmp_const_attr);
      input_val.replaceAllUsesWith(tmp_val);
      auto rescale_op = builder.create<tosa::RescaleOp>(
          function.getLoc(), uint8_output_type, input_val,
          builder.getI32IntegerAttr(int8_zp),
          builder.getI32IntegerAttr(uint8_zp),
          builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
          builder.getBoolAttr(true), builder.getBoolAttr(false),
          builder.getBoolAttr(false));

      Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
      bb.push_back(op_rescale_op);
      op_rescale_op->moveBefore(terminator);
      tmp_val.replaceAllUsesWith(rescale_op.getResult());
      tmp_val.getDefiningOp()->erase();
    }
  }

  return success();
}
335
runOnFunction()336 void ConvertUint8ToInt8::runOnFunction() {
337 OwningRewritePatternList patterns(&getContext());
338 auto &ctx = getContext();
339 auto func = getFunction();
340
341 // Convert uint8 const tensor. const needs to be handled specifically.
342 patterns.insert<ConvertUint8QConstOp>(&ctx);
343 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
344
345 // Replace uint8 tensor in the graph and insert rescale as needed.
346 (void)convert_graph_uint8_tensor(ctx, func);
347 }
348
349 } // anonymous namespace
350
// Factory for the pass, used by pass-pipeline construction.
std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass() {
  return std::make_unique<ConvertUint8ToInt8>();
}
354
// Registers the pass under "tosa-convert-tfl-uint8" for mlir-opt-style tools.
static PassRegistration<ConvertUint8ToInt8> pass(
    PASS_NAME, "Convert uint8 graph to int8.");
357
358 } // namespace tosa
359
360 } // namespace mlir
361