1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <map>
17 #include <memory>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/Parser/Parser.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassManager.h" // from @llvm-project
28 #include "mlir/Transforms/Passes.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
30 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.h"
31 #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
34
35 namespace mlir {
36 namespace quant {
37 namespace {
38
39 class InsertQuantizedFunctionsPass
40 : public PassWrapper<InsertQuantizedFunctionsPass,
41 OperationPass<ModuleOp>> {
42 public:
43 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertQuantizedFunctionsPass)
44
InsertQuantizedFunctionsPass()45 explicit InsertQuantizedFunctionsPass() {}
InsertQuantizedFunctionsPass(QuantizationMethod quantization_method,const OpSet & op_set)46 explicit InsertQuantizedFunctionsPass(QuantizationMethod quantization_method,
47 const OpSet& op_set) {
48 quantization_method_ = quantization_method;
49 op_set_ = op_set;
50 }
InsertQuantizedFunctionsPass(const InsertQuantizedFunctionsPass & other)51 InsertQuantizedFunctionsPass(const InsertQuantizedFunctionsPass& other) {
52 quantization_method_ = other.quantization_method_;
53 op_set_ = other.op_set_;
54 }
55
getArgument() const56 StringRef getArgument() const final {
57 // This is the argument used to refer to the pass in the textual format (on
58 // the commandline for example).
59 return "quant-insert-quantized-functions";
60 }
61
getDescription() const62 StringRef getDescription() const final {
63 // This is a brief description of the pass.
64 return "Insert quantized functions into the module";
65 }
66
getDependentDialects(DialectRegistry & registry) const67 void getDependentDialects(DialectRegistry& registry) const override {
68 registry.insert<TF::TensorFlowDialect, func::FuncDialect>();
69 }
70
71 private:
72 void runOnOperation() override;
73
74 // Returns the function library for the given quantization method and opset
75 // pair.
76 llvm::StringRef GetFunctionLibrary(QuantizationMethod quantization_method,
77 OpSet op_set);
78
79 Option<QuantizationMethod> quantization_method_{
80 *this, "quantization-method",
81 llvm::cl::init(QuantizationMethod::kPostTrainingQuantization),
82 llvm::cl::desc("Choose quantization method."),
83 llvm::cl::values(
84 clEnumValN(QuantizationMethod::kPostTrainingQuantization, "ptq",
85 "Post-training static-range quantization"),
86 clEnumValN(QuantizationMethod::kDynamicRangeQuantization, "drq",
87 "Post-training dynamic-range quantizaiton"))};
88
89 Option<OpSet> op_set_{
90 *this, "target-opset", llvm::cl::init(OpSet::TF),
91 llvm::cl::desc("Choose target opset."),
92 llvm::cl::values(
93 clEnumValN(OpSet::TF, "TF",
94 "Uses TF ops that mimic quantization behavior"),
95 clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"),
96 clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED",
97 "Uses TF Uniform Quantized ops"))};
98 };
99
GetFunctionLibrary(QuantizationMethod quantization_method,OpSet op_set)100 llvm::StringRef InsertQuantizedFunctionsPass::GetFunctionLibrary(
101 QuantizationMethod quantization_method, OpSet op_set) {
102 absl::flat_hash_map<OpSet, llvm::StringRef> function_library_map;
103 if (quantization_method == QuantizationMethod::kDynamicRangeQuantization) {
104 function_library_map = {
105 {OpSet::UNIFORM_QUANTIZED,
106 kQuantizedFunctionLibraryInMLIR_UNIFORM_QUANTIZED_DRQ},
107 {OpSet::TF, kQuantizedFunctionLibraryInMLIR_TF_DRQ}};
108 } else {
109 function_library_map = {{OpSet::TF, kQuantizedFunctionLibraryInMLIR},
110 {OpSet::XLA, kQuantizedFunctionLibraryInMLIR}};
111 }
112
113 auto it = function_library_map.find(op_set);
114 if (it != function_library_map.end()) {
115 return it->second;
116 }
117 return llvm::StringRef();
118 }
119
120 static PassRegistration<InsertQuantizedFunctionsPass> pass;
121
runOnOperation()122 void InsertQuantizedFunctionsPass::runOnOperation() {
123 ModuleOp module = getOperation();
124 SymbolTable symbol_table(module);
125
126 std::unique_ptr<llvm::MemoryBuffer> mem_buffer;
127 llvm::StringRef quantized_function_library =
128 GetFunctionLibrary(quantization_method_, op_set_);
129
130 if (quantized_function_library.empty()) {
131 emitError(module.getLoc())
132 << "Failed to get function library for the opset.";
133 signalPassFailure();
134 return;
135 }
136
137 mem_buffer =
138 llvm::MemoryBuffer::getMemBuffer(quantized_function_library,
139 /*BufferName=*/"",
140 /*RequiresNullTerminator=*/false);
141
142 llvm::SourceMgr source_mgr;
143 source_mgr.AddNewSourceBuffer(std::move(mem_buffer), llvm::SMLoc());
144 OwningOpRef<ModuleOp> module_ref =
145 parseSourceFile<ModuleOp>(source_mgr, module.getContext());
146 // Inline and optimize loaded functions.
147 MLIRContext* context = &getContext();
148 PassManager pm(context);
149 pm.addPass(createInlinerPass());
150 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
151 pm.addNestedPass<func::FuncOp>(createCSEPass());
152
153 StatusScopedDiagnosticHandler diagnostic_handler(context);
154 if (failed(pm.run(*module_ref))) {
155 emitError(module.getLoc())
156 << "failed to apply the optimization: "
157 << diagnostic_handler.ConsumeStatus().error_message();
158 signalPassFailure();
159 return;
160 }
161
162 // Copy all functions used by this signature to the final MLIR module.
163 for (func::FuncOp func : module_ref->getOps<func::FuncOp>()) {
164 // Do nothing if the function already exists.
165 if (symbol_table.lookup(func.getSymName()) != nullptr) continue;
166
167 // Set the function to private and insert to the module.
168 func::FuncOp new_func = func.clone();
169 new_func.setPrivate();
170 symbol_table.insert(new_func);
171 }
172 }
173
174 } // namespace
175
176 // Creates an instance of the pass for inserting quantized functions.
CreateInsertQuantizedFunctionsPass(QuantizationMethod quantization_method,const OpSet & op_set)177 std::unique_ptr<OperationPass<ModuleOp>> CreateInsertQuantizedFunctionsPass(
178 QuantizationMethod quantization_method, const OpSet& op_set) {
179 return std::make_unique<InsertQuantizedFunctionsPass>(quantization_method,
180 op_set);
181 }
182
183 } // namespace quant
184 } // namespace mlir
185