• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <map>
17 #include <memory>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21 
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/Parser/Parser.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassManager.h"  // from @llvm-project
28 #include "mlir/Transforms/Passes.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
30 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.h"
31 #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
34 
35 namespace mlir {
36 namespace quant {
37 namespace {
38 
39 class InsertQuantizedFunctionsPass
40     : public PassWrapper<InsertQuantizedFunctionsPass,
41                          OperationPass<ModuleOp>> {
42  public:
43   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertQuantizedFunctionsPass)
44 
InsertQuantizedFunctionsPass()45   explicit InsertQuantizedFunctionsPass() {}
InsertQuantizedFunctionsPass(QuantizationMethod quantization_method,const OpSet & op_set)46   explicit InsertQuantizedFunctionsPass(QuantizationMethod quantization_method,
47                                         const OpSet& op_set) {
48     quantization_method_ = quantization_method;
49     op_set_ = op_set;
50   }
InsertQuantizedFunctionsPass(const InsertQuantizedFunctionsPass & other)51   InsertQuantizedFunctionsPass(const InsertQuantizedFunctionsPass& other) {
52     quantization_method_ = other.quantization_method_;
53     op_set_ = other.op_set_;
54   }
55 
getArgument() const56   StringRef getArgument() const final {
57     // This is the argument used to refer to the pass in the textual format (on
58     // the commandline for example).
59     return "quant-insert-quantized-functions";
60   }
61 
getDescription() const62   StringRef getDescription() const final {
63     // This is a brief description of the pass.
64     return "Insert quantized functions into the module";
65   }
66 
getDependentDialects(DialectRegistry & registry) const67   void getDependentDialects(DialectRegistry& registry) const override {
68     registry.insert<TF::TensorFlowDialect, func::FuncDialect>();
69   }
70 
71  private:
72   void runOnOperation() override;
73 
74   // Returns the function library for the given quantization method and opset
75   // pair.
76   llvm::StringRef GetFunctionLibrary(QuantizationMethod quantization_method,
77                                      OpSet op_set);
78 
79   Option<QuantizationMethod> quantization_method_{
80       *this, "quantization-method",
81       llvm::cl::init(QuantizationMethod::kPostTrainingQuantization),
82       llvm::cl::desc("Choose quantization method."),
83       llvm::cl::values(
84           clEnumValN(QuantizationMethod::kPostTrainingQuantization, "ptq",
85                      "Post-training static-range quantization"),
86           clEnumValN(QuantizationMethod::kDynamicRangeQuantization, "drq",
87                      "Post-training dynamic-range quantizaiton"))};
88 
89   Option<OpSet> op_set_{
90       *this, "target-opset", llvm::cl::init(OpSet::TF),
91       llvm::cl::desc("Choose target opset."),
92       llvm::cl::values(
93           clEnumValN(OpSet::TF, "TF",
94                      "Uses TF ops that mimic quantization behavior"),
95           clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"),
96           clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED",
97                      "Uses TF Uniform Quantized ops"))};
98 };
99 
GetFunctionLibrary(QuantizationMethod quantization_method,OpSet op_set)100 llvm::StringRef InsertQuantizedFunctionsPass::GetFunctionLibrary(
101     QuantizationMethod quantization_method, OpSet op_set) {
102   absl::flat_hash_map<OpSet, llvm::StringRef> function_library_map;
103   if (quantization_method == QuantizationMethod::kDynamicRangeQuantization) {
104     function_library_map = {
105         {OpSet::UNIFORM_QUANTIZED,
106          kQuantizedFunctionLibraryInMLIR_UNIFORM_QUANTIZED_DRQ},
107         {OpSet::TF, kQuantizedFunctionLibraryInMLIR_TF_DRQ}};
108   } else {
109     function_library_map = {{OpSet::TF, kQuantizedFunctionLibraryInMLIR},
110                             {OpSet::XLA, kQuantizedFunctionLibraryInMLIR}};
111   }
112 
113   auto it = function_library_map.find(op_set);
114   if (it != function_library_map.end()) {
115     return it->second;
116   }
117   return llvm::StringRef();
118 }
119 
120 static PassRegistration<InsertQuantizedFunctionsPass> pass;
121 
runOnOperation()122 void InsertQuantizedFunctionsPass::runOnOperation() {
123   ModuleOp module = getOperation();
124   SymbolTable symbol_table(module);
125 
126   std::unique_ptr<llvm::MemoryBuffer> mem_buffer;
127   llvm::StringRef quantized_function_library =
128       GetFunctionLibrary(quantization_method_, op_set_);
129 
130   if (quantized_function_library.empty()) {
131     emitError(module.getLoc())
132         << "Failed to get function library for the opset.";
133     signalPassFailure();
134     return;
135   }
136 
137   mem_buffer =
138       llvm::MemoryBuffer::getMemBuffer(quantized_function_library,
139                                        /*BufferName=*/"",
140                                        /*RequiresNullTerminator=*/false);
141 
142   llvm::SourceMgr source_mgr;
143   source_mgr.AddNewSourceBuffer(std::move(mem_buffer), llvm::SMLoc());
144   OwningOpRef<ModuleOp> module_ref =
145       parseSourceFile<ModuleOp>(source_mgr, module.getContext());
146   // Inline and optimize loaded functions.
147   MLIRContext* context = &getContext();
148   PassManager pm(context);
149   pm.addPass(createInlinerPass());
150   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
151   pm.addNestedPass<func::FuncOp>(createCSEPass());
152 
153   StatusScopedDiagnosticHandler diagnostic_handler(context);
154   if (failed(pm.run(*module_ref))) {
155     emitError(module.getLoc())
156         << "failed to apply the optimization: "
157         << diagnostic_handler.ConsumeStatus().error_message();
158     signalPassFailure();
159     return;
160   }
161 
162   // Copy all functions used by this signature to the final MLIR module.
163   for (func::FuncOp func : module_ref->getOps<func::FuncOp>()) {
164     // Do nothing if the function already exists.
165     if (symbol_table.lookup(func.getSymName()) != nullptr) continue;
166 
167     // Set the function to private and insert to the module.
168     func::FuncOp new_func = func.clone();
169     new_func.setPrivate();
170     symbol_table.insert(new_func);
171   }
172 }
173 
174 }  // namespace
175 
176 // Creates an instance of the pass for inserting quantized functions.
CreateInsertQuantizedFunctionsPass(QuantizationMethod quantization_method,const OpSet & op_set)177 std::unique_ptr<OperationPass<ModuleOp>> CreateInsertQuantizedFunctionsPass(
178     QuantizationMethod quantization_method, const OpSet& op_set) {
179   return std::make_unique<InsertQuantizedFunctionsPass>(quantization_method,
180                                                         op_set);
181 }
182 
183 }  // namespace quant
184 }  // namespace mlir
185