/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass converts a TFLite uint8 graph to the int8 domain, with adaptors at
// the input and output tensors, because TOSA precision is implemented in the
// int8 domain. The pass:
// 1. matches uint8 TFL::QConst ops and regenerates them as int8 TFL::QConst
//    ops with the stored values remapped.
// 2. inserts a tosa.RESCALE uint8 -> int8 after each block argument (graph
//    placeholder) that is uint8 typed.
// 3. inserts a tosa.RESCALE int8 -> uint8 before each returned tensor that is
//    originally uint8 typed.
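//
// The conversion keeps the represented real values unchanged: the scale stays
// the same, the zero point is shifted by -128, and each stored value v is
// remapped to v - 128. Roughly (illustrative types):
//   tensor<1x8x8x3x!quant.uniform<u8:f32, 0.015:128>>
// becomes
//   tensor<1x8x8x3x!quant.uniform<i8:f32, 0.015:0>>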

#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"

#define PASS_NAME "tosa-convert-tfl-uint8"
#define DEBUG_TYPE PASS_NAME

namespace mlir {

namespace tosa {

namespace {
// Pass that converts uint8 quantized tensors in a TFL graph to int8, inserting
// rescale adaptors at the graph boundaries.
class ConvertUint8ToInt8
    : public PassWrapper<ConvertUint8ToInt8, FunctionPass> {
 public:
  explicit ConvertUint8ToInt8() {}
  void runOnFunction() override;
};

struct ConvertUint8QConstOp : public RewritePattern {
  explicit ConvertUint8QConstOp(MLIRContext *context)
      : RewritePattern(TFL::QConstOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &builder) const override {
    auto tfl_qconst_op = cast<TFL::QConstOp>(op);

    // Skip if the result is not a ranked tensor type.
    auto output_type =
        tfl_qconst_op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
    if (!output_type)
      return builder.notifyMatchFailure(op, "not ranked tensor");

    // Skip if the output is not a per-tensor quantized type.
    auto output_element_type =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    if (!output_element_type) return failure();

    // Skip if the output is not uint8.
    if (output_element_type.isSigned() ||
        output_element_type.getStorageTypeIntegralWidth() != 8) {
      return failure();
    }

    mlir::DenseElementsAttr src_dense_attr =
        tfl_qconst_op.value().cast<DenseElementsAttr>();

    double type_range_min =
        static_cast<double>(output_element_type.getStorageTypeMin() -
                            output_element_type.getZeroPoint()) *
        output_element_type.getScale();
    double type_range_max =
        static_cast<double>(output_element_type.getStorageTypeMax() -
                            output_element_type.getZeroPoint()) *
        output_element_type.getScale();
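    // A storage-type minimum of 1 (rather than 0) marks the type as
    // narrow-range.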
    bool narrow_range = output_element_type.getStorageTypeMin() == 1;

    auto dst_qconst_type = TypeAttr::get(RankedTensorType::get(
        output_type.getShape(),
        buildQTypeFromMinMax(
            builder, output_element_type.getExpressedType(),
            builder.getF64FloatAttr(type_range_min),
            builder.getF64FloatAttr(type_range_max),
            builder.getI32IntegerAttr(
                output_element_type.getStorageTypeIntegralWidth()),
            0, true /* signed */, builder.getBoolAttr(narrow_range))));

    Type dst_dense_element_type = builder.getIntegerType(8);
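    // Remap each stored uint8 value v to the int8 value v - 128
    // (e.g. 0 -> -128, 255 -> 127).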
    llvm::function_ref<APInt(const APInt &)> mapping =
        [](const APInt &in) -> APInt {
      int64_t in_i64 = in.getLimitedValue();
      int64_t out_i64 = in_i64 - 128;
      return APInt(8, out_i64, true);
    };

    auto dst_dense_attr =
        src_dense_attr.mapValues(dst_dense_element_type, mapping);

    builder.replaceOpWithNewOp<TFL::QConstOp>(op, dst_qconst_type,
                                              dst_dense_attr);

    return success();
  }
};

LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context,
                                         mlir::FuncOp &function) {
  size_t num_blocks_in_main = 0;
  mlir::Region *region = function.getCallableRegion();
  OpBuilder builder(&context);

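  // A throwaway one-element i8 constant, used below as a temporary placeholder
  // while rewiring uses to the newly created rescale ops.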
  auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8));
  auto tmp_const_attr =
      DenseElementsAttr::get(tmp_const_type, {static_cast<uint8_t>(0)});

  for (mlir::Block &bb : region->getBlocks()) {
    // Each region is expected to contain exactly one block for now.
    num_blocks_in_main++;
    if (num_blocks_in_main > 1) {
      return function.emitError("Invalid MLIR: multiple blocks in a region");
    }

    if (!bb.isEntryBlock()) {
      return function.emitError("Invalid MLIR: block must be entry block");
    }

    // Insert rescale uint8->int8 after placeholders.
    for (Value arg : bb.getArguments()) {
      auto uint8_type = arg.getType().dyn_cast<mlir::RankedTensorType>();
      if (!uint8_type) continue;

      auto uint8_element_type =
          uint8_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>();
      if (!uint8_element_type) continue;

      if (uint8_element_type.isSigned() ||
          uint8_element_type.getStorageTypeIntegralWidth() != 8)
        continue;

      double type_range_min =
          static_cast<double>(uint8_element_type.getStorageTypeMin() -
                              uint8_element_type.getZeroPoint()) *
          uint8_element_type.getScale();
      double type_range_max =
          static_cast<double>(uint8_element_type.getStorageTypeMax() -
                              uint8_element_type.getZeroPoint()) *
          uint8_element_type.getScale();
      bool narrow_range = uint8_element_type.getStorageTypeMin() == 1;

      Type int8_type = RankedTensorType::get(
          uint8_type.getShape(),
          buildQTypeFromMinMax(
              builder, uint8_element_type.getExpressedType(),
              builder.getF64FloatAttr(type_range_min),
              builder.getF64FloatAttr(type_range_max),
              builder.getI32IntegerAttr(
                  uint8_element_type.getStorageTypeIntegralWidth()),
              0, true /* signed */, builder.getBoolAttr(narrow_range)));

      int32_t uint8_zp = uint8_element_type.getZeroPoint();
      int32_t int8_zp = uint8_zp - 128;

      // Temporarily redirect all uses of the block argument to a placeholder
      // constant so the new rescale op can take the argument as its input.
      Value tmp_val = builder.create<TFL::ConstOp>(
          function.getLoc(), tmp_const_type, tmp_const_attr);
      arg.replaceAllUsesWith(tmp_val);
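      // Rescale with multiplier 1 << 30 and shift 30 (a net scale of 1.0);
      // only the zero point changes, from uint8_zp to int8_zp.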
      auto rescale_op = builder.create<tosa::RescaleOp>(
          function.getLoc(), int8_type, arg,
          builder.getI32IntegerAttr(uint8_zp),
          builder.getI32IntegerAttr(int8_zp),
          builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
          builder.getBoolAttr(true), builder.getBoolAttr(false),
          builder.getBoolAttr(false));

      Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
      bb.push_front(op_rescale_op);
      tmp_val.replaceAllUsesWith(rescale_op.getResult());
      tmp_val.getDefiningOp()->erase();
    }

    // Record the types of the original graph outputs before converting
    // intermediate tensors.
    auto terminator = bb.getTerminator();
    SmallVector<Type, 4> output_types;
    for (Value val : terminator->getOperands()) {
      output_types.push_back(val.getType());
    }

    // Convert intermediate tensors.
    for (auto &op : bb) {
      for (Value output_val : op.getResults()) {
        // Skip if the output value is not a RankedTensorType.
        auto output_type =
            output_val.getType().dyn_cast<mlir::RankedTensorType>();
        if (!output_type) continue;

        // Skip if the output value is not a per-tensor quantized element type.
        auto output_element_type =
            output_type.getElementType()
                .dyn_cast<mlir::quant::UniformQuantizedType>();
        if (!output_element_type) continue;

        // Skip if the output is not uint8.
        if (output_element_type.isSigned() ||
            output_element_type.getStorageTypeIntegralWidth() != 8)
          continue;

        double type_range_min =
            static_cast<double>(output_element_type.getStorageTypeMin() -
                                output_element_type.getZeroPoint()) *
            output_element_type.getScale();
        double type_range_max =
            static_cast<double>(output_element_type.getStorageTypeMax() -
                                output_element_type.getZeroPoint()) *
            output_element_type.getScale();
        bool narrow_range = output_element_type.getStorageTypeMin() == 1;

        Type new_type = RankedTensorType::get(
            output_type.getShape(),
            buildQTypeFromMinMax(
                builder, output_element_type.getExpressedType(),
                builder.getF64FloatAttr(type_range_min),
                builder.getF64FloatAttr(type_range_max),
                builder.getI32IntegerAttr(
                    output_element_type.getStorageTypeIntegralWidth()),
                0, true /* signed */, builder.getBoolAttr(narrow_range)));

        output_val.setType(new_type);
      }
    }

    if (terminator->getNumOperands() != output_types.size()) {
      return function.emitError(
          "Terminator's operand mismatch with number of outputs in graph");
    }

    // Insert an int8 -> uint8 rescale before each terminator operand.
    for (int32_t i = 0; i < terminator->getNumOperands(); i++) {
      auto defining_op = terminator->getOperand(i).getDefiningOp();
      // Skip if the operand of the terminator is a block argument (its
      // defining op is nullptr in that case).
      if (!defining_op) continue;
      Value input_val = defining_op->getResult(0);

      // Check if the graph output is uint8 typed.
      auto uint8_output_type =
          output_types[i].dyn_cast<mlir::RankedTensorType>();
      if (!uint8_output_type) continue;

      auto uint8_output_element_type =
          uint8_output_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>();
      if (!uint8_output_element_type) continue;

      if (uint8_output_element_type.isSigned() ||
          uint8_output_element_type.getStorageTypeIntegralWidth() != 8)
        continue;

      // Check if the value coming into the terminator is int8 typed.
      auto int8_output_type = terminator->getOperand(i)
                                  .getType()
                                  .dyn_cast<mlir::RankedTensorType>();
      if (!int8_output_type) continue;

      auto int8_output_element_type =
          int8_output_type.getElementType()
              .dyn_cast<mlir::quant::UniformQuantizedType>();
      if (!int8_output_element_type) continue;

      if (!int8_output_element_type.isSigned() ||
          int8_output_element_type.getStorageTypeIntegralWidth() != 8)
        continue;

      int32_t int8_zp = int8_output_element_type.getZeroPoint();
      int32_t uint8_zp = uint8_output_element_type.getZeroPoint();

      // Sanity check that the scales match and the zero points differ by
      // exactly 128.
      if (((uint8_zp - int8_zp) != 128) ||
          (int8_output_element_type.getScale() !=
           uint8_output_element_type.getScale())) {
        return terminator->emitError(
            "convert_uint8_to_int8: scale mismatch at the output tensors");
      }

      // Temporarily redirect all uses of input_val to a placeholder constant
      // so the new rescale op can take input_val as its input.
      Value tmp_val = builder.create<TFL::ConstOp>(
          function.getLoc(), tmp_const_type, tmp_const_attr);
      input_val.replaceAllUsesWith(tmp_val);
      auto rescale_op = builder.create<tosa::RescaleOp>(
          function.getLoc(), uint8_output_type, input_val,
          builder.getI32IntegerAttr(int8_zp),
          builder.getI32IntegerAttr(uint8_zp),
          builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}),
          builder.getBoolAttr(true), builder.getBoolAttr(false),
          builder.getBoolAttr(false));

      Operation *op_rescale_op = static_cast<Operation *>(rescale_op);
      bb.push_back(op_rescale_op);
      op_rescale_op->moveBefore(terminator);
      tmp_val.replaceAllUsesWith(rescale_op.getResult());
      tmp_val.getDefiningOp()->erase();
    }
  }

  return success();
}

void ConvertUint8ToInt8::runOnFunction() {
  OwningRewritePatternList patterns;
  auto &ctx = getContext();
  auto func = getFunction();

  // Convert uint8 const tensors. Constants need to be handled separately from
  // the graph rewrite below.
  patterns.insert<ConvertUint8QConstOp>(&ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Replace uint8 tensors in the graph and insert rescales as needed.
  (void)convert_graph_uint8_tensor(ctx, func);
}

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass() {
  return std::make_unique<ConvertUint8ToInt8>();
}

static PassRegistration<ConvertUint8ToInt8> pass(
    PASS_NAME, "Convert uint8 graph to int8.");

}  // namespace tosa

}  // namespace mlir