/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"

#include "absl/base/casts.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {
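// Pulls in the TableGen-generated rewrite patterns (e.g.
// FoldQuantizeDequantize, used below in OptimizeQuantizedOpToFloat).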
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/generated_transform_patterns.inc"
}  // namespace

OwningRewritePatternList GetHardwareRewritePatterns(
    MLIRContext* context, const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return {context};
  return device_hardware->GetTransformations(context);
}
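// Illustrative use only: a hardware-specific pass would typically apply these
// patterns with the greedy driver, e.g.
//   auto patterns = GetHardwareRewritePatterns(ctx, "GPU");
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
// where "GPU" is assumed to be a registered target hardware name.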

bool IsSupported(Operation* op, const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return false;
  return device_hardware->IsOpSupported(op);
}

// ================== Convert Quantized Op ============================

// Walks the function and converts quantized ops to their float versions.
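// Schematically (types abbreviated, not verbatim IR):
//   %2 = "tfl.add"(%0, %1) : (qi8, qi8) -> qi8
// becomes
//   %0f = "tfl.dequantize"(%0) : qi8 -> f32
//   %1f = "tfl.dequantize"(%1) : qi8 -> f32
//   %2f = "tfl.add"(%0f, %1f) : (f32, f32) -> f32
//   %2  = "tfl.quantize"(%2f) : f32 -> qi8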
void ConvertQuantizedOpToFloat(mlir::FuncOp func, OpBuilder* builder) {
  func.walk([&](Operation* op) {
    // TODO(renjieliu): Find a generic way to deal with const ops.
    if (op->hasTrait<OpTrait::IsTerminator>() ||
        llvm::isa<TFL::QConstOp, TFL::ConstOp, TF::ConstOp, ConstOp>(op))
      return;

    bool int8_type_observed = false;
    bool uint8_type_observed = false;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type)) {
        int8_type_observed = true;
      } else if (IsQUI8Type(input_type)) {
        uint8_type_observed = true;
      }
    }

    // TODO(renjieliu): To be safe, we should check that the op actually
    // supports float execution; non-quantized ops normally do.
    if (!int8_type_observed && !uint8_type_observed) return;

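    // castToExpressedType() maps a quantized tensor type to the corresponding
    // tensor of its expressed type, roughly: a tensor of
    // !quant.uniform<i8:f32, ...> elements becomes a tensor of f32.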
    // Insert dequantize ops for every quantized input.
    SmallVector<Value, 4> dequantized_inputs;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type) || IsQUI8Type(input_type) ||
          IsQI32Type(input_type)) {
        auto dequantized_input_type =
            mlir::quant::QuantizedType::castToExpressedType(input_type);
        builder->setInsertionPoint(op);
        auto dequantize_op = builder->create<TFL::DequantizeOp>(
            op->getLoc(), dequantized_input_type, input.get());
        dequantized_inputs.push_back(dequantize_op);
      } else {
        dequantized_inputs.push_back(input.get());
      }
    }

    // Compute the result types, replacing quantized results with their
    // expressed (float) counterparts.
    SmallVector<Type, 4> result_types;
    for (auto result_type : op->getResultTypes()) {
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        auto dequantized_result_type =
            mlir::quant::QuantizedType::castToExpressedType(result_type);
        result_types.push_back(dequantized_result_type);
      } else {
        result_types.push_back(result_type);
      }
    }

    // Build the new float-versioned op.
    OperationState state(op->getLoc(), op->getName());
    state.operands = dequantized_inputs;
    state.types = result_types;
    state.attributes = op->getAttrs();
    state.successors = op->getSuccessors();
    builder->setInsertionPoint(op);
    Operation* new_op = builder->createOperation(state);
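    // new_op has the same name, attributes, and successors as op; only the
    // quantized operand/result types have been swapped for float ones.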

    // Insert quantize ops for every quantized output and rewire the uses.
    for (int i = 0; i < op->getNumResults(); ++i) {
      auto result = op->getResult(i);
      auto result_type = result.getType();

      Value new_result = new_op->getResult(i);
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        builder->setInsertionPoint(op);
        TFL::QuantizeOp quant_op = builder->create<TFL::QuantizeOp>(
            op->getLoc(), result_type, new_result, TypeAttr::get(result_type));
        new_result = quant_op.getResult();
      }

      // Rewire the outputs.
      result.replaceAllUsesWith(new_result);
    }

    // Remove the old op.
    op->erase();
  });
}

// Folds quantized i32 constants (normally biases) into their float values.
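// A quantized value q with scale s and zero point z represents the real value
// r = s * (q - z). For example, with s = 0.5 and z = 0, a stored q = 6
// dequantizes to r = 3.0.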
struct FoldQuantizedI32ToFloat : public OpRewritePattern<TFL::DequantizeOp> {
  using OpRewritePattern<TFL::DequantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::DequantizeOp dequant_op,
                                PatternRewriter& rewriter) const override {
    // We only fold the i32 -> float pattern.
    auto input = dequant_op.input().getDefiningOp();
    if (!input) return failure();

    auto input_const = llvm::dyn_cast_or_null<TFL::QConstOp>(input);
    if (!input_const) return failure();

    if (!IsQI32Type(input_const.getType())) return failure();

    auto output_type =
        dequant_op.output().getType().dyn_cast_or_null<ShapedType>();
    if (!output_type || !output_type.getElementType().isF32()) return failure();

    auto input_type = input_const.getType().dyn_cast<ShapedType>();
    // TODO(renjieliu): support UniformQuantizedPerAxisType.
    auto q_type = input_type.getElementType()
                      .dyn_cast_or_null<quant::UniformQuantizedType>();
    if (!q_type) return failure();

    const float scale = q_type.getScale();
    const float zp = q_type.getZeroPoint();

    auto input_values = input_const.value();

    // mapValues always takes a function returning APInt, even when the output
    // is actually float.
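    // Returning the float's raw IEEE-754 bits in a 32-bit APInt lets
    // mapValues reinterpret the data as f32 elements (the element type
    // requested below).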
    using DequantizeFuncType = llvm::APInt(const llvm::APInt&);
    auto dequantize_func = [&](const APInt& ap_int_value) -> APInt {
      const int64_t int_value = ap_int_value.getSExtValue();

      const float real = (int_value - zp) * scale;

      auto real_int = absl::bit_cast<int32_t>(real);
      return APInt(/*numBits=*/32, real_int);
    };

    auto dequant_values = input_values.mapValues(
        FloatType::getF32(rewriter.getContext()),
        llvm::function_ref<DequantizeFuncType>(dequantize_func));
    rewriter.replaceOpWithNewOp<TFL::ConstOp>(dequant_op, dequant_op.getType(),
                                              dequant_values);

    return success();
  }
};

// Removes quantize ops that have no consumers.
struct RemoveUnusedQuant : public OpRewritePattern<TFL::QuantizeOp> {
  using OpRewritePattern<TFL::QuantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::QuantizeOp quant_op,
                                PatternRewriter& rewriter) const override {
    if (!quant_op.getResult().use_empty()) return failure();

    rewriter.eraseOp(quant_op);
    return success();
  }
};

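// Cleans up after ConvertQuantizedOpToFloat above: folds the quantize ->
// dequantize pairs and quantized i32 constants it leaves behind, and drops
// quantize ops whose results are unused.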
void OptimizeQuantizedOpToFloat(FuncOp func, MLIRContext* context) {
  OwningRewritePatternList patterns(func.getContext());
  patterns.insert<FoldQuantizedI32ToFloat, FoldQuantizeDequantize,
                  RemoveUnusedQuant>(context);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace tac
}  // namespace TFL
}  // namespace mlir