/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"

#include "absl/base/casts.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/generated_transform_patterns.inc"
}  // namespace

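// Returns the hardware-specific rewrite patterns registered for `hardware`,
// or an empty pattern list if the target hardware is unknown.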
OwningRewritePatternList GetHardwareRewritePatterns(
    MLIRContext* context, const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return {context};
  return device_hardware->GetTransformations(context);
}

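// Returns true if `op` is supported by the given hardware target.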
bool IsSupported(Operation* op, const std::string& hardware) {
  auto* device_hardware = GetTargetHardware(hardware);
  if (device_hardware == nullptr) return false;
  return device_hardware->IsOpSupported(op);
}

// ================== Convert Quantized Op ============================

// Walks through the func and converts quantized ops to their float versions.
void ConvertQuantizedOpToFloat(mlir::FuncOp func, OpBuilder* builder) {
  func.walk([&](Operation* op) {
    // TODO(renjieliu): Find a generic way to deal with const ops.
    if (op->hasTrait<OpTrait::IsTerminator>() ||
        llvm::isa<TFL::QConstOp, TFL::ConstOp, TF::ConstOp, ConstOp>(op))
      return;

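    // Check whether any input is a quantized i8/u8 tensor; only such ops
    // need conversion.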
    bool int8_type_observed = false;
    bool uint8_type_observed = false;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type)) {
        int8_type_observed = true;
      } else if (IsQUI8Type(input_type)) {
        uint8_type_observed = true;
      }
    }

    // TODO(renjieliu): To be safe, we should also check whether the op
    // supports float execution, although non-quantized ops normally do.
    if (!int8_type_observed && !uint8_type_observed) return;

    // Insert dequantize ops for every quantized input.
    SmallVector<Value, 4> dequantized_inputs;
    for (auto& input : op->getOpOperands()) {
      auto input_type = input.get().getType();
      if (IsQI8Type(input_type) || IsQUI8Type(input_type) ||
          IsQI32Type(input_type)) {
        auto dequantized_input_type =
            mlir::quant::QuantizedType::castToExpressedType(input_type);
        builder->setInsertionPoint(op);
        auto dequantize_op = builder->create<TFL::DequantizeOp>(
            op->getLoc(), dequantized_input_type, input.get());
        dequantized_inputs.push_back(dequantize_op);
      } else {
        dequantized_inputs.push_back(input.get());
      }
    }

    // Compute the result types, replacing quantized result types with their
    // float expressed types.
    SmallVector<Type, 4> result_types;
    for (auto result_type : op->getResultTypes()) {
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        auto dequantized_result_type =
            mlir::quant::QuantizedType::castToExpressedType(result_type);
        result_types.push_back(dequantized_result_type);
      } else {
        result_types.push_back(result_type);
      }
    }

    // Build the new float-versioned op.
    OperationState state(op->getLoc(), op->getName());
    state.operands = dequantized_inputs;
    state.types = result_types;
    state.attributes = op->getAttrs();
    state.successors = op->getSuccessors();
    builder->setInsertionPoint(op);
    Operation* new_op = builder->createOperation(state);

    // Insert a quantize op for every quantized output and rewire the uses.
    for (int i = 0; i < op->getNumResults(); ++i) {
      auto result = op->getResult(i);
      auto result_type = result.getType();

      Value new_result = new_op->getResult(i);
      if (IsQI8Type(result_type) || IsQUI8Type(result_type)) {
        builder->setInsertionPoint(op);
        TFL::QuantizeOp quant_op = builder->create<TFL::QuantizeOp>(
            op->getLoc(), result_type, new_result, TypeAttr::get(result_type));
        new_result = quant_op.getResult();
      }

      // Rewire the outputs.
      result.replaceAllUsesWith(new_result);
    }

    // Remove the old op.
    op->erase();
  });
}

// Folds a quantized i32 constant (normally a bias) into its float value.
struct FoldQuantizedI32ToFloat : public OpRewritePattern<TFL::DequantizeOp> {
  using OpRewritePattern<TFL::DequantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::DequantizeOp dequant_op,
                                PatternRewriter& rewriter) const override {
    // We only fold the i32 -> float pattern.
    auto input = dequant_op.input().getDefiningOp();
    if (!input) return failure();

    auto input_dequant = llvm::dyn_cast_or_null<TFL::QConstOp>(input);
    if (!input_dequant) return failure();

    if (!IsQI32Type(input_dequant.getType())) return failure();

    auto output_type =
        dequant_op.output().getType().dyn_cast_or_null<ShapedType>();
    if (!output_type || !output_type.getElementType().isF32()) return failure();

    auto input_type = input_dequant.getType().dyn_cast<ShapedType>();
    // TODO(renjieliu): support UniformQuantizedPerAxisType.
    auto q_type = input_type.getElementType()
                      .dyn_cast_or_null<quant::UniformQuantizedType>();
    if (!q_type) return failure();

    const float scale = q_type.getScale();
    const float zp = q_type.getZeroPoint();

    auto input_values = input_dequant.value();

    // mapValues always takes a function returning APInt, even when the output
    // is actually float.
    using DequantizeFuncType = llvm::APInt(const llvm::APInt&);
    auto dequantize_func = [&](const APInt& ap_int_value) -> APInt {
      const int64_t int_value = ap_int_value.getSExtValue();

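      // Affine dequantization: real = (quantized - zero_point) * scale.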
      const float real = (int_value - zp) * scale;

      auto real_int = absl::bit_cast<int32_t>(real);
      return APInt(/*numBits=*/32, real_int);
    };

    auto dequant_values = input_values.mapValues(
        FloatType::getF32(rewriter.getContext()),
        llvm::function_ref<DequantizeFuncType>(dequantize_func));
    rewriter.replaceOpWithNewOp<TFL::ConstOp>(dequant_op, dequant_op.getType(),
                                              dequant_values);

    return success();
  }
};

// If a quantize op has no consumers, remove it.
struct RemoveUnusedQuant : public OpRewritePattern<TFL::QuantizeOp> {
  using OpRewritePattern<TFL::QuantizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::QuantizeOp quant_op,
                                PatternRewriter& rewriter) const override {
    if (!quant_op.getResult().use_empty()) return failure();

    rewriter.eraseOp(quant_op);
    return success();
  }
};

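// Greedily applies the cleanup patterns above, together with the generated
// FoldQuantizeDequantize pattern, over the given func.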
void OptimizeQuantizedOpToFloat(FuncOp func, MLIRContext* context) {
  OwningRewritePatternList patterns(func.getContext());
  patterns.insert<FoldQuantizedI32ToFloat, FoldQuantizeDequantize,
                  RemoveUnusedQuant>(context);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace tac
}  // namespace TFL
}  // namespace mlir