/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to work with tf.FakeQuant* ops.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_

#include <string>
#include <vector>

#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"

namespace mlir {
namespace TFL {

// Fetches the min/max quantization range from the `min`/`max` float
// attributes of a tf.FakeQuantWithMinMaxArgs-style op.
template <class TFFakeQuantOp>
struct FetchMinMaxAttrs {
  using AttrType = FloatAttr;
  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
                  AttrType &max_value) const {
    min_value = tf_op.minAttr();
    max_value = tf_op.maxAttr();
    return true;  // Successfully matched and fetched.
  }
};

// Fetches the min/max quantization range from the constant `min`/`max`
// operands of a tf.FakeQuantWithMinMaxVars-style op. Fails to match if
// either operand is not defined by a constant.
template <class TFFakeQuantOp>
struct FetchConstantMinMaxInputs {
  using AttrType = DenseFPElementsAttr;
  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
                  AttrType &max_value) const {
    Value min = tf_op.min(), max = tf_op.max();
    if (!matchPattern(min, m_Constant(&min_value))) {
      return false;
    }
    if (!matchPattern(max, m_Constant(&max_value))) {
      return false;
    }
    return true;  // Successfully matched and fetched.
  }
};
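
// For illustration, a sketch (not verbatim compiler output) of the IR that
// FetchConstantMinMaxInputs is meant to match: both min/max operands of the
// tf.FakeQuantWithMinMaxVars op are defined by constants, so the two
// matchPattern calls above succeed and bind the DenseFPElementsAttr values.
//
//   %min = "tf.Const"() {value = dense<-1.0> : tensor<f32>} : () -> tensor<f32>
//   %max = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
//   %fq = "tf.FakeQuantWithMinMaxVars"(%input, %min, %max)
//           {num_bits = 8 : i64, narrow_range = false}
//           : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
//
// If either operand is produced by a non-constant op, the fetcher returns
// false and the rewrite below does not fire.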

// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op before the op is
// constant folded. Since the constant folding logic will use a
// "std.constant" op to replace the "tf.FakeQuantWithMinMaxVarsOp", the
// "tfl.quantize" op is used to preserve the quantization parameters as a
// TypeAttr, and the "tfl.dequantize" op is used to convert the output to the
// type expected by the next op. Here are the transformations:
//
//   input   min cst       max cst          input   min cst       max cst
//    \       |             |                 \       |             |
//     \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
//      \     |             |                  \      |             |
//       tf.FakeQuantWithMinMaxVars             tf.FakeQuantWithMinMaxVars
//                   |                                     |
//                                                    tfl.quantize
//                                                         |
//                                                   tfl.dequantize
//                                                         |
//
// If the input is a constant, the result pattern will eventually be
// converted to:
//
//            quant-emulated input
//                     |
//                tfl.quantize
//                     |
//               tfl.dequantize
//                     |
//
// Warns if the (most likely unwanted, currently not quite correctly handled)
// case of back-to-back tf.FakeQuant* ops occurs:
//
//             tf.FakeQuant*
//                   |
//             tf.FakeQuant*
//
template <typename TFFakeQuantOp, bool PerAxis, class FetchMinMax>
class InsertTFLQuantOpsAfterTFFakeQuantOp {
 public:
  FetchMinMax fetch_min_max_;

  using FetchAttrType = typename FetchMinMax::AttrType;

  LogicalResult matchAndRewrite(TFFakeQuantOp tf_op,
                                OpBuilder &rewriter) const {
    // We don't want to insert quantize/dequantize if the quantize op already
    // exists.
    auto res = tf_op.outputs();
    if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin())) {
      return failure();
    }

    // Extract the min/max constant values from the operands. We also consider
    // the special case that there are tf.Identity ops between the min/max
    // constants and the tf.FakeQuantWithMinMaxVarsOp.
    FetchAttrType min_value, max_value;
    if (!fetch_min_max_(tf_op, min_value, max_value)) {
      return failure();
    }

    int quant_dim = -1;
    if (PerAxis) {
      // This is a special case in which the quant_dim is the last dimension.
      quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
    }

    // Use the min/max from the operands and the num_bits and narrow_range
    // attributes to create the quantization parameter for the new quantize op.
    rewriter.setInsertionPointAfter(tf_op.getOperation());
    IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits());
    BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
    Type res_type = tf_op.getType();
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/false);
    if (!qtype) {
      return failure();
    }

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
    // and its users.
    Value value = tf_op.outputs();
    auto quantize = rewriter.create<TFL::QuantizeOp>(
        tf_op.getLoc(), qtype.getValue(), value, qtype);
    auto dequantize = rewriter.create<TFL::DequantizeOp>(
        tf_op.getLoc(), res_type, quantize.output());
    value.replaceAllUsesWith(dequantize);
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);

    return success();
  }
};

// Removes the wrapper of the tf.FakeQuant* ops and creates the tfl.quantize
// and tfl.dequantize pairs before the tf.FakeQuant* ops are folded.
LogicalResult ConvertFakeQuantOps(FuncOp func, MLIRContext *ctx);

// Returns the names of all the considered tf.FakeQuant* ops.
std::vector<std::string> AllTfFakeQuantOps();

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_
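
// Example wiring (a sketch; the pass and alias names below are assumptions
// for illustration, not declarations from this header). A conversion pass
// would pair InsertTFLQuantOpsAfterTFFakeQuantOp with a fetcher per concrete
// tf.FakeQuant* op and invoke ConvertFakeQuantOps from its entry point:
//
//   using PreparePerTensorFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
//       TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false,
//       FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsOp>>;
//
//   using PreparePerChannelFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp<
//       TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true,
//       FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsPerChannelOp>>;
//
//   void SomePreparePass::runOnFunction() {
//     if (failed(ConvertFakeQuantOps(getFunction(), &getContext()))) {
//       signalPassFailure();
//     }
//   }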