/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass prepares the TFLite fused ops for quantization.

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

//===----------------------------------------------------------------------===//
// The LoadQuantizationRecipe Pass.
//
namespace mlir {
namespace TFL {

namespace {

// This pass loads the quantization recipe for the TFLite ops to be quantized.
// Specifically, it extends the fused ops with their internal implementation as
// op regions. Each op in the region produces results with element type
// AnyQuantizedType, so bitwidth, narrow_range, etc. are included. The op also
// defines the op quantization traits, which subsequent passes use to
// propagate the quantization parameters.
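//
// As an illustrative sketch (the standard layer-normalized LSTM equations
// that the recipe below models, not a statement about the runtime kernels):
//
//   i_t = sigmoid(LN(W_i x_t + U_i h_{t-1} + w_ci (.) c_{t-1}))  // input gate
//   f_t = sigmoid(LN(W_f x_t + U_f h_{t-1} + w_cf (.) c_{t-1}))  // forget gate
//   g_t = tanh(LN(W_g x_t + U_g h_{t-1}))                        // cell gate
//   c_t = f_t (.) c_{t-1} + i_t (.) g_t                          // new cell state
//   o_t = sigmoid(LN(W_o x_t + U_o h_{t-1} + w_co (.) c_t))      // output gate
//   h_t = FC(o_t (.) tanh(c_t), projection_weights)              // hidden state
//
// where (.) denotes element-wise multiplication and LN layer normalization.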
struct LoadQuantizationRecipe
    : public PassWrapper<LoadQuantizationRecipe, FunctionPass> {
  void runOnFunction() override;

 private:
  void Initialize(LSTMOp lstm, OpBuilder* builder);

  // Create LSTM gates with different weights for input, recurrent and
  // cell state, and also the layer normalization parameters.
  Operation* CreateGate(Location loc, Value in, Value in_w, Value rec,
                        Value rec_w,
                        llvm::Optional<std::pair<Value, Value>> cell,
                        Value ln_w, Value ln_bias, OpBuilder* builder);

  Operation* CreateLayerNorm(Location loc, Value in, Value ln_w, Value ln_bias,
                             OpBuilder* builder);

  // Add the internal implementation of the LSTM to its regions.
  void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);

  StringAttr none_af;
  StringAttr fc_format;
  BoolAttr keep_dims;
  Type int8;
  Type int16;
  ConstantOp none_cst;
};

void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
  Type expressed_type =
      lstm.input().getType().cast<ShapedType>().getElementType();
  Type int8_storage_type = builder->getIntegerType(8);
  Type int16_storage_type = builder->getIntegerType(16);
  auto flag = quant::QuantizationFlags::FlagValue::Signed;
  int64_t int8_min = quant::QuantizedType::getDefaultMinimumForInteger(
      flag, /*integralWidth=*/8);
  int64_t int8_max = quant::QuantizedType::getDefaultMaximumForInteger(
      flag, /*integralWidth=*/8);
  int64_t int16_min = quant::QuantizedType::getDefaultMinimumForInteger(
      flag, /*integralWidth=*/16);
  int64_t int16_max = quant::QuantizedType::getDefaultMaximumForInteger(
      flag, /*integralWidth=*/16);
  auto any_int8 = quant::AnyQuantizedType::get(
      flag, int8_storage_type, expressed_type, int8_min, int8_max);
  auto any_int16 = quant::AnyQuantizedType::get(
      flag, int16_storage_type, expressed_type, int16_min, int16_max);

  int8 = any_int8.castFromExpressedType(lstm.input().getType());
  int16 = any_int16.castFromExpressedType(lstm.input().getType());
}
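
// For illustration (a sketch, not output produced by this file): given an
// LSTM input of type tensor<1x4xf32>, Initialize() sets the two members to
// the same shape with quantized element types, printed roughly as
//
//   int8  ~ tensor<1x4x!quant.any<i8:f32>>
//   int16 ~ tensor<1x4x!quant.any<i16:f32>>
//
// with the signed default storage ranges elided by the printer.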

Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,
                                                   Value ln_w, Value ln_bias,
                                                   OpBuilder* builder) {
  // Note that the l2_normalization and add ops here are not the execution
  // kernel implementation for layer_normalization; we just use them to model
  // its quantization requirement.
  auto l2_norm = builder->create<L2NormalizationOp>(loc, int16, in, none_af);
  auto add = builder->create<AddOp>(loc, int16, in, l2_norm, none_af);
  return builder->create<FullyConnectedOp>(loc, int16, add, ln_w, ln_bias,
                                           none_af, fc_format, keep_dims);
}
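
// Schematically, the stand-in emitted above is
//
//   fully_connected(add(in, l2_normalization(in)), ln_w, ln_bias)
//
// chosen for its quantization behavior rather than its arithmetic.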

Operation* LoadQuantizationRecipe::CreateGate(
    Location loc, Value in, Value in_w, Value rec, Value rec_w,
    llvm::Optional<std::pair<Value, Value>> cell, Value ln_w, Value ln_bias,
    OpBuilder* builder) {
  auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
                                              none_af, fc_format, keep_dims);
  auto s2 = builder->create<FullyConnectedOp>(loc, int16, rec, rec_w, none_cst,
                                              none_af, fc_format, keep_dims);

  AddNOp s4;
  if (cell.hasValue()) {
    auto s3 = builder->create<MulOp>(loc, int16, cell.getValue().first,
                                     cell.getValue().second, none_af);
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value>(
            {*s1.output().begin(), *s2.output().begin(), s3.output()}));

  } else {
    s4 = builder->create<AddNOp>(
        loc, int16,
        llvm::ArrayRef<Value>({*s1.output().begin(), *s2.output().begin()}));
  }

  auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);

  if (cell.hasValue()) {
    return builder->create<LogisticOp>(loc, int16, s5->getResult(0));
  } else {
    return builder->create<TanhOp>(loc, int16, s5->getResult(0));
  }
}
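
// Schematically, the gate region built above computes
//
//   act(layer_norm(FC(in, in_w) + FC(rec, rec_w) [+ cell.first (.) cell.second]))
//
// where act is logistic (sigmoid) when the optional peephole pair `cell` is
// present and tanh otherwise.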

void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) {
  Initialize(lstm, builder);

  Region region;
  region.push_back(new Block);
  builder->setInsertionPointToEnd(&region.front());
  Location loc = lstm.getLoc();
  none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
                                         builder->getUnitAttr());

  auto input_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_input_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
      llvm::Optional<std::pair<Value, Value>>(
          {lstm.input_cell_state(), lstm.cell_to_input_weights()}),
      lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);

  auto forget_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_forget_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
      llvm::Optional<std::pair<Value, Value>>(
          {lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
      lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);

  auto cell_gate = CreateGate(loc, lstm.input(), lstm.input_to_cell_weights(),
                              lstm.input_activation_state(),
                              lstm.recurrent_to_cell_weights(), llvm::None,
                              lstm.cell_layer_norm_coefficients(),
                              lstm.cell_bias(), builder);

  auto forget_cell_state = builder->create<MulOp>(
      loc, int16, forget_gate->getResult(0), lstm.input_cell_state(), none_af);
  auto input_cell_state = builder->create<MulOp>(
      loc, int16, input_gate->getResult(0), cell_gate->getResult(0), none_af);
  auto new_cell = builder->create<AddOp>(loc, int16, forget_cell_state.output(),
                                         input_cell_state.output(), none_af);

  auto output_gate = CreateGate(
      loc, lstm.input(), lstm.input_to_output_weights(),
      lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
      llvm::Optional<std::pair<Value, Value>>(
          {new_cell, lstm.cell_to_output_weights()}),
      lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);

  auto new_cell_tanh = builder->create<TanhOp>(loc, int16, new_cell);
  auto hidden_state = builder->create<MulOp>(
      loc, int16, new_cell_tanh.output(), output_gate->getResult(0), none_af);
  auto act = builder->create<FullyConnectedOp>(
      loc, int8, hidden_state.output(), lstm.projection_weights(),
      lstm.projection_bias(), none_af, fc_format, keep_dims);

  // TODO(fengliuai): define and register the op in the QuantOps Dialect.
  OperationState return_state(loc, "tf_quant.pseudo_return", act.getResult(0),
                              {int8}, {});
  builder->createOperation(return_state);

  lstm.internal().takeBody(region);
}
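
// A rough sketch of the IR built into lstm.internal() (illustrative only;
// operand lists abbreviated, all intermediate results use the int16 type):
//
//   %none = constant unit
//   %i = <input gate:  fc, fc, mul, add_n, layer norm, logistic>
//   %f = <forget gate: same structure as the input gate>
//   %g = <cell gate:   no peephole mul; ends in tanh>
//   %c = tfl.add(tfl.mul(%f, c_prev), tfl.mul(%i, %g))
//   %o = <output gate: peephole reads %c>
//   %h = tfl.mul(tfl.tanh(%c), %o)
//   %act = tfl.fully_connected(%h, proj_w, proj_bias)  // int8 result
//   "tf_quant.pseudo_return"(%act)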

void LoadQuantizationRecipe::runOnFunction() {
  FuncOp func = getFunction();
  OpBuilder builder(func);
  none_af = builder.getStringAttr("NONE");
  fc_format = builder.getStringAttr("DEFAULT");
  keep_dims = builder.getBoolAttr(false);

  func.walk([&](Operation* op) {
    if (auto lstm = llvm::dyn_cast<LSTMOp>(op)) {
      LoadForLSTMOp(lstm, &builder);
    }
    // Handles other ops.
  });
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
// pass.
std::unique_ptr<OperationPass<FuncOp>> CreateLoadQuantizationRecipePass() {
  return absl::make_unique<LoadQuantizationRecipe>();
}

static PassRegistration<LoadQuantizationRecipe> pass(
    "tfl-load-recipe", "Load TFL op quantization recipe");

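// A minimal usage sketch (assuming an MLIR opt-style driver such as tf-opt;
// the exact binary name depends on the build target):
//
//   tf-opt -tfl-load-recipe input.mlir -o annotated.mlir
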
}  // namespace TFL
}  // namespace mlir