/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
// Background info:
// Currently the model fed to MLIRConverter is frozen (all the variables have
// been converted to constants, all the assign ops are gone, etc.). However,
// TFLite has variable tensor semantics, so the variable mapping from TF to
// TFLite is broken here; instead, we hard-code the variable tensors based on
// the ops that actually use them, such as unidirectional_sequence_lstm.
//
// MLIRConverter also benefits from typical compiler optimizations such as
// merging inputs that have identical values. These optimizations are
// desirable, but not for TFLite ops that take variable tensors as inputs:
// even though those inputs start out identical, they are "stateful" and their
// values can change across invocations.
//
// A typical example is unidirectional_sequence_lstm, which has two variable
// tensor inputs: the activation state and the cell state. They may have the
// same initial values (typically zero-initialized), but their values will
// diverge over time, so we cannot merge them.
//
// This pass is a short-term workaround since we don't have a good variable
// representation right now.
//
// This pass duplicates the input values of those variable tensor inputs.
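//
// For illustration, a simplified sketch in MLIR (not valid verbatim: operand
// lists are abbreviated and the tensor shape is hypothetical). Before the
// pass, both stateful operands share one constant:
//
//   %state = "tfl.pseudo_const"() {value = dense<0.0> : tensor<1x4xf32>}
//       : () -> tensor<1x4xf32>
//   %out = "tfl.unidirectional_sequence_lstm"(%in, ..., %state, %state, ...)
//
// After the pass, each stateful operand gets its own clone of the constant:
//
//   %state0 = "tfl.pseudo_const"() {value = dense<0.0> : tensor<1x4xf32>}
//       : () -> tensor<1x4xf32>
//   %state1 = "tfl.pseudo_const"() {value = dense<0.0> : tensor<1x4xf32>}
//       : () -> tensor<1x4xf32>
//   %out = "tfl.unidirectional_sequence_lstm"(%in, ..., %state0, %state1, ...)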

namespace mlir {
namespace TFL {
namespace {

struct SplitMergedOperandsPass
    : public PassWrapper<SplitMergedOperandsPass, FunctionPass> {
  void runOnFunction() override;

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the commandline for example).
    return "tfl-split-merged-operands";
  }
  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Split merged stateful operands for tfl operations.";
  }
};

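// For each stateful operand of `op`, checks whether its value has already been
// used as a stateful operand (tracked in `values`); if so, clones the value's
// constant defining op right before `op` and rewires the operand to the clone.
// Returns failure if such a shared value is not produced by a constant op.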
LogicalResult DuplicateValueIfNeeded(Operation* op,
                                     llvm::DenseSet<Value>* values,
                                     OpBuilder* builder) {
  std::vector<int> stateful_operands_index;
  if (!IsStatefulOp(op, &stateful_operands_index)) return success();

  for (int index : stateful_operands_index) {
    Value operand = op->getOperand(index);
    auto inserted_value = values->insert(operand).second;
    if (inserted_value) continue;
    // We can only clone constant ops at this point. Since all ops have been
    // legalized to TFLite ops, we only care about ConstOp, QConstOp, or the
    // MLIR constant op.
    Operation* input_op = operand.getDefiningOp();
    if (input_op == nullptr) return failure();

    Attribute attr;
    if (!matchPattern(input_op, m_Constant(&attr))) {
      op->emitError()
          << "We cannot duplicate the value since it's not constant.\n";
      return failure();
    }
    builder->setInsertionPoint(op);
    Operation* duplicated_input_op = builder->clone(*input_op);

    // Rewire the inputs.
    op->setOperand(index, duplicated_input_op->getResult(0));
  }
  return success();
}

void SplitMergedOperandsPass::runOnFunction() {
  llvm::DenseSet<Value> stateful_values;
  auto func = getFunction();
  OpBuilder builder(func);
  for (auto& bb : func.getBody()) {
    for (auto& op : bb) {
      if (failed(DuplicateValueIfNeeded(&op, &stateful_values, &builder))) {
        func.emitError() << "Failed to duplicate values for the stateful op\n";
        return signalPassFailure();
      }
    }
  }
}

}  // namespace

/// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
/// pass.
std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass() {
  return std::make_unique<SplitMergedOperandsPass>();
}

static PassRegistration<SplitMergedOperandsPass> pass;

}  // namespace TFL
}  // namespace mlir