• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This header file defines common utils used by TFLite transformation
17 // passes to work with tf.FakeQuant* ops.
18 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_
19 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_
20 
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
24 #include "mlir/Support/LLVM.h"  // from @llvm-project
25 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
28 
29 namespace mlir {
30 namespace TFL {
31 
32 template <class TFFakeQuantOp>
33 struct FetchMinMaxAttrs {
34   using AttrType = FloatAttr;
operatorFetchMinMaxAttrs35   bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
36                   AttrType &max_value) const {
37     min_value = tf_op.minAttr();
38     max_value = tf_op.maxAttr();
39     return true;  // Succesfully matched and fetched.
40   }
41 };
42 
43 template <class TFFakeQuantOp>
44 struct FetchConstantMinMaxInputs {
45   using AttrType = DenseFPElementsAttr;
operatorFetchConstantMinMaxInputs46   bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
47                   AttrType &max_value) const {
48     Value min = tf_op.min(), max = tf_op.max();
49     if (!matchPattern(min, m_Constant(&min_value))) {
50       return false;
51     }
52     if (!matchPattern(max, m_Constant(&max_value))) {
53       return false;
54     }
55     return true;  // Succesfully matched and fetched.
56   }
57 };
58 
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op
61 // before the op being constant folded. Since the constant
62 // folding logic will use a "std.constant" op to replace the
63 // "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
64 // the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
65 // convert the output type to the next op. Here are the transformations:
66 //
67 // input   min cst       max cst          input   min cst       max cst
68 //  \       |             |                \       |             |
69 //   \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
70 //    \     |             |                  \     |             |
71 //       tf.FakeQuantWithMinMaxVars       tf.FakeQuantWithMinMaxVars
72 //                   |                                 |
73 //                                                tfl.quantize
74 //                                                     |
75 //                                                tfl.dequantize
76 //                                                     |
// If the input is a constant, the result pattern will eventually be converted
// to
78 //
79 //            quant-emulated input
80 //                   |
81 //               tfl.quantize
82 //                   |
83 //              tfl.dequantize
84 //                   |
85 //
86 //
87 // Warns if the (most likely unwanted, currently not quite correctly handled)
88 // case of back-to-back tf.FakeQuant occurs
89 //
90 //             tf.FakeQuant*
91 //                   |
92 //             tf.FakeQuant*
93 //
94 template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax>
95 class InsertTFLQuantOpsAfterTFFakeQuantOp {
96  public:
97   FetchMinMax fetch_min_max_;
98 
99   using FetchAttrType = typename FetchMinMax::AttrType;
matchAndRewrite(TFFakeQuantOp tf_op,OpBuilder & rewriter)100   LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
101                                 OpBuilder &rewriter) const {
102     // We don't want to insert quantize/dequantize if the quantize op exists.
103     auto res = tf_op.outputs();
104     if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) {
105       return failure();
106     }
107 
108     // Extract the min/max constant values from the operands. We also consider
109     // a special case that there are tf.Identity ops between the min/max
110     // constants and the tf.FakeQuantWithMinMaxVarsOp.
111 
112     FetchAttrType min_value, max_value;
113     if (!fetch_min_max_(tf_op, min_value, max_value)) {
114       return failure();
115     }
116 
117     int quant_dim = -1;
118     if (PerAxis) {
119       // This is a special case that the quant_dim is the last dimensions.
120       quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
121     }
122     // Use the min/max from the operands and the num_bits and narrow_range
123     // attribute to create the quantization parameter for the new quantize op.
124     rewriter.setInsertionPointAfter(tf_op.getOperation());
125     IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits());
126     BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
127     Type res_type = tf_op.getType();
128     TypeAttr qtype = quant::GetQuantizedTypeAttr(
129         rewriter, res_type, min_value, max_value, quant_dim, num_bits,
130         narrow_range, /*is_signed=*/false);
131     if (!qtype) {
132       return failure();
133     }
134 
135     // Finally, use the quantization parameter to create the quantize and
136     // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
137     // and its users.
138     Value value = tf_op.outputs();
139     auto quantize = rewriter.create<TFL::QuantizeOp>(
140         tf_op.getLoc(), qtype.getValue(), value, qtype);
141     auto dequantize = rewriter.create<TFL::DequantizeOp>(
142         tf_op.getLoc(), res_type, quantize.output());
143     value.replaceAllUsesWith(dequantize);
144     quantize.getOperation()->replaceUsesOfWith(dequantize, value);
145 
146     return success();
147   }
148 };
149 
// Removes the wrapper of the tf.FakeQuant* ops and creates the tfl.quantize
// and tfl.dequantize pairs before tf.FakeQuant* being folded.
152 LogicalResult ConvertFakeQuantOps(FuncOp func, MLIRContext *ctx);
153 
154 // Returns the names of all the considered tf.FakeQuant* ops.
155 std::vector<std::string> AllTfFakeQuantOps();
156 
157 }  // namespace TFL
158 }  // namespace mlir
159 
160 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_
161