/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass prepares the TFLite fused ops for quantization.

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

//===----------------------------------------------------------------------===//
// The LoadQuantizationRecipe Pass.
//
namespace mlir {
namespace TFL {

namespace {
// This pass loads the quantization recipe for the TFLite ops to be quantized.
// Specifically, it extends the fused ops with their internal implementation as
// op regions. Each op in the region produces results with element type
// AnyQuantizedType, so bitwidth, narrow_range, etc. are included. The op also
// defines the op quantization traits, which subsequent passes use to
// propagate the quantization parameters.
struct LoadQuantizationRecipe
    : public PassWrapper<LoadQuantizationRecipe, FunctionPass> {
  void runOnFunction() override;

 private:
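  // Initializes the quantized element types (int8/int16) derived from the
  // LSTM input's expressed type.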
  void Initialize(LSTMOp lstm, OpBuilder* builder);

  // Creates an LSTM gate from the input, recurrent, and optional cell-state
  // operands, each with its own weights, plus the layer normalization
  // parameters.
  Operation* CreateGate(Location loc, Value in, Value in_w, Value rec,
                        Value rec_w,
                        llvm::Optional<std::pair<Value, Value>> cell,
                        Value ln_w, Value ln_bias, OpBuilder* builder);

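  // Creates the layer normalization pattern used inside each gate. The ops
  // emitted here only model the quantization constraints of
  // layer_normalization; they are not its execution kernel.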
  Operation* CreateLayerNorm(Location loc, Value in, Value ln_w, Value ln_bias,
                             OpBuilder* builder);

  // Adds the internal implementation of the LSTM to its region.
  void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);

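  // Attributes, types, and the shared "none" constant reused while building
  // the recipe ops.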
  StringAttr none_af;
  StringAttr fc_format;
  BoolAttr keep_dims;
  Type int8;
  Type int16;
  ConstantOp none_cst;
};

void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
  Type expressed_type =
      lstm.input().getType().cast<ShapedType>().getElementType();
  Type int8_storage_type = builder->getIntegerType(8);
  Type int16_storage_type = builder->getIntegerType(16);
  auto flag = quant::QuantizationFlags::FlagValue::Signed;
  int64_t int8_min = quant::QuantizedType::getDefaultMinimumForInteger(
      flag, /*integralWidth=*/8);
  int64_t int8_max = quant::QuantizedType::getDefaultMaximumForInteger(
      flag, /*integralWidth=*/8);
  int64_t int16_min = quant::QuantizedType::getDefaultMinimumForInteger(
      flag, /*integralWidth=*/16);
  int64_t int16_max = quant::QuantizedType::getDefaultMaximumForInteger(
      flag, /*integralWidth=*/16);
  auto any_int8 = quant::AnyQuantizedType::get(
      flag, int8_storage_type, expressed_type, int8_min, int8_max);
  auto any_int16 = quant::AnyQuantizedType::get(
      flag, int16_storage_type, expressed_type, int16_min, int16_max);

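  // Apply the quantized element types to the LSTM input's tensor type so the
  // shape is preserved.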
  int8 = any_int8.castFromExpressedType(lstm.input().getType());
  int16 = any_int16.castFromExpressedType(lstm.input().getType());
}

Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,
                                                   Value ln_w, Value ln_bias,
                                                   OpBuilder* builder) {
  // Note that the l2_normalization and add ops here are not the execution
  // kernel implementation of layer_normalization; they are only used to model
  // its quantization requirement.
  auto l2_norm = builder->create<L2NormalizationOp>(loc, int16, in, none_af);
  auto add = builder->create<AddOp>(loc, int16, in, l2_norm, none_af);
  return builder->create<FullyConnectedOp>(loc, int16, add, ln_w, ln_bias,
                                           none_af, fc_format, keep_dims);
}

Operation* LoadQuantizationRecipe::CreateGate(
    Location loc, Value in, Value in_w, Value rec, Value rec_w,
    llvm::Optional<std::pair<Value, Value>> cell, Value ln_w, Value ln_bias,
    OpBuilder* builder) {
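  // Fully connected contributions from the input and the recurrent
  // (activation state) paths. The bias is the shared "none" constant; the
  // gate bias is folded into the layer normalization below.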
  auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
                                              none_af, fc_format, keep_dims);
  auto s2 = builder->create<FullyConnectedOp>(loc, int16, rec, rec_w, none_cst,
                                              none_af, fc_format, keep_dims);

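  // If a peephole pair is given, add the elementwise product of the cell
  // state and the cell weights to the sum.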
  AddNOp s4;
  if (cell.hasValue()) {
    auto s3 = builder->create<MulOp>(loc, int16, cell.getValue().first,
                                     cell.getValue().second, none_af);
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value>(
            {*s1.output().begin(), *s2.output().begin(), s3.output()}));
  } else {
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value>({*s1.output().begin(), *s2.output().begin()}));
  }

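  // Layer-normalize the summed gate pre-activation and apply the gate bias.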
  auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);

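  // Gates that receive a peephole operand (input, forget, output) use a
  // logistic activation; the cell gate uses tanh.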
  if (cell.hasValue()) {
    return builder->create<LogisticOp>(loc, int16, s5->getResult(0));
  } else {
    return builder->create<TanhOp>(loc, int16, s5->getResult(0));
  }
}

void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
  Initialize(lstm, builder);

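  // Build the recipe in a detached region; it is moved into the op's
  // "internal" region once complete (see takeBody below).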
  Region region;
  region.push_back(new Block);
  builder->setInsertionPointToEnd(&region.front());
  Location loc = lstm.getLoc();
  none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
                                         builder->getUnitAttr());

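  // The four LSTM gates. The input, forget, and output gates take a peephole
  // (cell state, weight) pair; the cell gate does not.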
  auto input_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_input_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
      llvm::Optional<std::pair<Value, Value>>(
          {lstm.input_cell_state(), lstm.cell_to_input_weights()}),
      lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);

  auto forget_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_forget_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
      llvm::Optional<std::pair<Value, Value>>(
          {lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
      lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);

  auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(),
                              lstm.input_activation_state(),
                              lstm.recurrent_to_cell_weights(), llvm::None,
                              lstm.cell_layer_norm_coefficients(),
                              lstm.cell_bias(), builder);

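  // New cell state: forget_gate * old_cell_state + input_gate * cell_gate.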
  auto forget_cell_state = builder->create<MulOp>(
      loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af);
  auto input_cell_state = builder->create<MulOp>(
      loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af);
  auto new_cell = builder->create<AddOp>(loc, int16, forget_cell_state.output(),
                                         input_cell_state.output(), none_af);

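  // The output gate peeks at the newly computed cell state.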
  auto output_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_output_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
      llvm::Optional<std::pair<Value, Value>>(
          {new_cell, lstm.cell_to_output_weights()}),
      lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);

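  // Hidden state: output_gate * tanh(new_cell), followed by the int8
  // projection.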
  auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
  auto hidden_state = builder->create<MulOp>(
      loc, int16, new_cell_tanh.output(), output_gate->getResult(0), none_af);
  auto act = builder->create<FullyConnectedOp>(
      loc, int8, hidden_state.output(), lstm.projection_weights(),
      lstm.projection_bias(), none_af, fc_format, keep_dims);

  // TODO(fengliuai): define and register the op in the QuantOps Dialect.
  OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0),
                              {int8}, {});
  builder->createOperation(return_state);

  lstm.internal().takeBody(region);
}

void LoadQuantizationRecipe::runOnFunction() {
  FuncOp func = getFunction();
  OpBuilder builder(func);
  none_af = builder.getStringAttr("NONE");
  fc_format = builder.getStringAttr("DEFAULT");
  keep_dims = builder.getBoolAttr(false);

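  // Attach the recipe region to each supported fused op in the function.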
  func.walk([&](Operation* op) {
    if (auto lstm = llvm::dyn_cast<LSTMOp>(op)) {
      LoadForLSTMOp(lstm, &builder);
    }
    // Handles other ops.
  });
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
// pass.
std::unique_ptr<OperationPass<FuncOp>> CreateLoadQuantizationRecipePass() {
  return absl::make_unique<LoadQuantizationRecipe>();
}

static PassRegistration<LoadQuantizationRecipe> pass(
    "tfl-load-recipe", "Load TFL op quantization recipe");

}  // namespace TFL
}  // namespace mlir