1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
16 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
17 #include "mlir/Pass/Pass.h" // from @llvm-project
18 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20
21 namespace mlir {
22 namespace TFL {
23 namespace {
// Name of the boolean attribute that is set on the module to record whether
// variables should be legalized to TFLite or not.
constexpr char kLegalizeTflVariables[] = "tfl._legalize_tfl_variables";
27
28 // Returns true if 'op' is TF op that accepts resource type, but is
29 // supported by TFLite.
IsSupportedTFLiteResourceOp(Operation * op)30 bool IsSupportedTFLiteResourceOp(Operation* op) {
31 return llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp, TF::VarHandleOp,
32 TF::LookupTableFindV2Op, TF::LookupTableImportV2Op,
33 TF::LookupTableSizeV2Op>(op);
34 }
35
36 // Returns true if 'op' is a TFLite control flow operation.
IsTFLiteControlFlowOp(Operation * op)37 bool IsTFLiteControlFlowOp(Operation* op) {
38 return llvm::isa<TFL::WhileOp, TFL::IfOp, TFL::CallOnceOp>(op);
39 }
40 } // namespace
41
42 // Pass which analyzes the variables in the graph and add an attribute whether
43 // variables should be legalized to TFLite native ones.
44 // This pass needs to run post TF->TFL legalization and before variable
45 // legalization.
46 class AnalyzeVariablesPass
47 : public PassWrapper<AnalyzeVariablesPass, OperationPass<ModuleOp>> {
48 public:
49 AnalyzeVariablesPass() = default;
AnalyzeVariablesPass(const AnalyzeVariablesPass &)50 AnalyzeVariablesPass(const AnalyzeVariablesPass&) {}
51
getArgument() const52 StringRef getArgument() const final {
53 // This is the argument used to refer to the pass in
54 // the textual format (on the commandline for example).
55 return "tfl-analyze-variables-pass";
56 }
getDescription() const57 StringRef getDescription() const final {
58 // This is a brief description of the pass.
59 return "Analyze variables in the graph.";
60 }
61
runOnOperation()62 void runOnOperation() override {
63 auto* context = &getContext();
64 auto module = getOperation();
65 bool legalize_to_tfl = true;
66
67 module.walk([&](Operation* op) {
68 // Skip ops that are supported natively by TFLite.
69 if (IsSupportedTFLiteResourceOp(op)) return WalkResult::advance();
70
71 // Check for ops that are legalized to TFLite.
72 if (op->getDialect()->getNamespace() == "tfl") {
73 // TODO(b/189370197): Enable control flow ops after updating
74 // checks to handle them.
75 if (IsTFLiteControlFlowOp(op)) {
76 legalize_to_tfl = false;
77 return WalkResult::interrupt();
78 }
79 return WalkResult::advance();
80 }
81 // Check for ops that are not legalized to TFLite.
82
83 // If any of the operands is a resource type, then we break
84 // and mark the module as not valid for TFLite legalization.
85 // Note: this might disable native variables in more than needed cases.
86 // TODO(b/189370197): Enhance variable analysis.
87 for (auto operand : op->getOperands()) {
88 if (getElementTypeOrSelf(operand.getType()).isa<TF::ResourceType>()) {
89 legalize_to_tfl = false;
90 return WalkResult::interrupt();
91 }
92 }
93 return WalkResult::advance();
94 });
95 module->setAttr(kLegalizeTflVariables,
96 BoolAttr::get(context, legalize_to_tfl));
97 }
98 };
99
CreateAnalyzeVariablesPass()100 std::unique_ptr<OperationPass<ModuleOp>> CreateAnalyzeVariablesPass() {
101 return std::make_unique<AnalyzeVariablesPass>();
102 }
103
__anonac58b5490302null104 static PassRegistration<AnalyzeVariablesPass> pass([] {
105 return CreateAnalyzeVariablesPass();
106 });
107
108 } // namespace TFL
109 } // namespace mlir
110