1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/None.h"
17 #include "llvm/Support/Casting.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
23 #include "mlir/IR/MLIRContext.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/IR/PatternMatch.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
33
34 namespace mlir {
35 namespace TFL {
36 namespace {
// This file contains the LegalizeVariables pass, which is responsible for:
// - Converting all TF::ReadVariableOp and TF::AssignVariableOp to the
//   equivalent TFLite ops.
// Note that this pass assumes all variables are already available as
// GlobalTensorOp and that every VarHandle has already been converted to a
// function argument carrying the `tf_saved_model.bound_input` attribute.
// It also assumes that all other ops are already legalized to TFLite.
44 // TODO(b/149099381): Handle flex support use cases.
45
46 // Returns Value representing the resource_id.
GetResourceIDAsI32(int resource_id,Location loc,mlir::OpBuilder & rewriter)47 Value GetResourceIDAsI32(int resource_id, Location loc,
48 mlir::OpBuilder& rewriter) {
49 return rewriter.create<ConstOp>(
50 loc, DenseElementsAttr::get(
51 RankedTensorType::get({1}, rewriter.getIntegerType(32)),
52 resource_id));
53 }
54
55 // Helper method that fetches the global tensor that 'op' points to it.
56 template <typename T>
GetGlobalTensor(const SymbolTable & symbol_table,T op,FuncOp func)57 tf_saved_model::GlobalTensorOp GetGlobalTensor(const SymbolTable& symbol_table,
58 T op, FuncOp func) {
59 auto block_arg = op.resource().template dyn_cast<BlockArgument>();
60 if (!block_arg) return nullptr;
61 int index = block_arg.getArgNumber();
62 auto sym = func.template getArgAttrOfType<FlatSymbolRefAttr>(
63 index, "tf_saved_model.bound_input");
64 if (!sym) {
65 return nullptr;
66 }
67 return symbol_table.lookup<tf_saved_model::GlobalTensorOp>(sym.getValue());
68 }
69
GetAssignVariableOp(int variable_id,TF::AssignVariableOp assign_op,mlir::OpBuilder builder)70 mlir::Operation* GetAssignVariableOp(int variable_id,
71 TF::AssignVariableOp assign_op,
72 mlir::OpBuilder builder) {
73 return builder.create<TFL::AssignVariableOp>(
74 assign_op.getLoc(),
75 GetResourceIDAsI32(variable_id, assign_op.getLoc(), builder),
76 assign_op.value());
77 }
78
GetReadVariableOp(int variable_id,TF::ReadVariableOp read_op,mlir::OpBuilder builder)79 mlir::Operation* GetReadVariableOp(int variable_id, TF::ReadVariableOp read_op,
80 mlir::OpBuilder builder) {
81 return builder.create<TFL::ReadVariableOp>(
82 read_op.getLoc(), read_op.getResult().getType(),
83 GetResourceIDAsI32(variable_id, read_op.getLoc(), builder));
84 }
85
86 template <typename T>
87 class LegalizeVariablesPattern : public mlir::OpConversionPattern<T> {
88 public:
LegalizeVariablesPattern(mlir::MLIRContext * context,const std::map<std::string,int> * global_tensor_id,SymbolTable symbol_table)89 LegalizeVariablesPattern(mlir::MLIRContext* context,
90 const std::map<std::string, int>* global_tensor_id,
91 SymbolTable symbol_table)
92 : mlir::OpConversionPattern<T>(context),
93 global_tensor_id_(global_tensor_id),
94 symbol_table_(symbol_table) {}
95
matchAndRewrite(T var_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const96 LogicalResult matchAndRewrite(
97 T var_op, ArrayRef<Value> operands,
98 ConversionPatternRewriter& rewriter) const override {
99 auto* op = var_op.getOperation();
100 auto func = var_op->template getParentOfType<FuncOp>();
101 if (!func) return failure();
102 auto global_tensor = GetGlobalTensor<T>(symbol_table_, var_op, func);
103 if (!global_tensor) return failure();
104 auto variable_id = global_tensor_id_->at(global_tensor.sym_name().str());
105 mlir::OpBuilder builder(var_op);
106 mlir::Operation* tfl_var_op = nullptr;
107 if (llvm::isa<TF::AssignVariableOp>(op)) {
108 auto assign_op = llvm::cast<TF::AssignVariableOp>(op);
109 tfl_var_op = GetAssignVariableOp(variable_id, assign_op, builder);
110 } else {
111 auto read_op = llvm::cast<TF::ReadVariableOp>(op);
112 tfl_var_op = GetReadVariableOp(variable_id, read_op, builder);
113 }
114 var_op->replaceAllUsesWith(tfl_var_op);
115 rewriter.eraseOp(var_op);
116 return success();
117 }
118
119 private:
120 const std::map<std::string, int>* global_tensor_id_;
121 SymbolTable symbol_table_;
122 };
123
124 // Pass which legalizes TF variables which are already passed as bounded
125 // arguments to functions, to TFLite variables.
126 class LegalizeVariables
127 : public PassWrapper<LegalizeVariables, OperationPass<ModuleOp>> {
128 public:
129 LegalizeVariables() = default;
LegalizeVariables(const LegalizeVariables &)130 LegalizeVariables(const LegalizeVariables&) {}
131
runOnOperation()132 void runOnOperation() override {
133 auto module = getOperation();
134 // Use ordered container to make sure ids are deterministic if we got tensor
135 // ids from different part, also easier to debug.
136 // TODO(b/149099381): Remove integer IDs after adding the new variable
137 // handle type.
138 std::map<std::string, int> global_tensor_id;
139 for (auto global_tensor : module.getOps<tf_saved_model::GlobalTensorOp>()) {
140 global_tensor_id[global_tensor.sym_name().str()];
141 }
142 int id = 0;
143 for (auto& tensor : global_tensor_id) tensor.second = id++;
144
145 SymbolTable symbol_table(module);
146 ConversionTarget target(getContext());
147 OwningRewritePatternList patterns;
148 patterns.insert<LegalizeVariablesPattern<TF::ReadVariableOp>,
149 LegalizeVariablesPattern<TF::AssignVariableOp>>(
150 &getContext(), &global_tensor_id, symbol_table);
151 if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
152 signalPassFailure();
153 return;
154 }
155 }
156 };
157
158 } // namespace
159
CreateLegalizeVariablesPass()160 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeVariablesPass() {
161 return std::make_unique<LegalizeVariables>();
162 }
163
164 static PassRegistration<LegalizeVariables> pass(
165 "tfl-legalize-variables-tf",
166 "Legalize TensorFlow variables to TensorFlow Lite dialect");
167
168 } // namespace TFL
169 } // namespace mlir
170