1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <vector>
17
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Block.h" // from @llvm-project
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "mlir/IR/MLIRContext.h" // from @llvm-project
28 #include "mlir/IR/Matchers.h" // from @llvm-project
29 #include "mlir/IR/Operation.h" // from @llvm-project
30 #include "mlir/IR/OperationSupport.h" // from @llvm-project
31 #include "mlir/IR/SymbolTable.h" // from @llvm-project
32 #include "mlir/IR/Types.h" // from @llvm-project
33 #include "mlir/IR/Value.h" // from @llvm-project
34 #include "mlir/Pass/Pass.h" // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
36 #include "mlir/Support/LLVM.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
39 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
40
41 // Background info:
42 // Currently the model taken to MLIRConverter is frozen (all the variables have
43 // been converted to constants, all the assign ops are gone, etc.). However,
// TFLite has variable tensor semantics. So the variable mapping from TF to
// TFLite is actually broken here; we sort of hard-code the variable tensors
// based on the actual ops using them, such as unidirectional_sequence_lstm.
47 //
48 // MLIRConverter also benefits from lots of typical compiler optimization like
49 // merging same input values if they're identical. These optimizations are
50 // desirable but not for those TFLite ops which have variable tensors as inputs.
51 // Yes, they have identical input values, but those identical values are
52 // "stateful", their values can change during invocations.
53 //
// A typical example is unidirectional_sequence_lstm, which has two variable
// tensor inputs: activation state & cell state. They may have the same initial
// values (typically zero-initialized), but their values will be changed. So we
// cannot just merge those values.
58 //
// This pass is more of a short-term workaround since we don't have a good
// variable representation right now.
61 //
62 // This pass will duplicate input values for those variable tensor inputs.
63
64 namespace mlir {
65 namespace TFL {
66 namespace {
67
68 struct SplitMergedOperandsPass
69 : public PassWrapper<SplitMergedOperandsPass, FunctionPass> {
70 void runOnFunction() override;
71 };
72
DuplicateValueIfNeeded(Operation * op,llvm::DenseSet<Value> * values,OpBuilder * builder)73 LogicalResult DuplicateValueIfNeeded(Operation* op,
74 llvm::DenseSet<Value>* values,
75 OpBuilder* builder) {
76 std::vector<int> stateful_operands_index;
77 if (!IsStatefulOp(op, &stateful_operands_index)) return success();
78
79 for (int index : stateful_operands_index) {
80 Value operand = op->getOperand(index);
81 auto inserted_value = values->insert(operand).second;
82 if (inserted_value) continue;
83 // We can only clone the constant op at this point.
84 // Since all ops have been legalized to tflite ops, so we only care about
85 // ConstOp or QConstOp or mlir constant op/
86 Operation* input_op = operand.getDefiningOp();
87 if (input_op == nullptr) return failure();
88
89 Attribute attr;
90 if (!matchPattern(input_op, m_Constant(&attr))) {
91 op->emitError()
92 << "We cannot duplicate the value since it's not constant.\n";
93 return failure();
94 }
95 builder->setInsertionPoint(op);
96 Operation* duplicated_input_op = builder->clone(*input_op);
97
98 // Rewire the inputs.
99 op->setOperand(index, duplicated_input_op->getResult(0));
100 }
101 return success();
102 }
103
runOnFunction()104 void SplitMergedOperandsPass::runOnFunction() {
105 llvm::DenseSet<Value> stateful_values;
106 auto func = getFunction();
107 OpBuilder builder(func);
108 for (auto& bb : func.getBody()) {
109 for (auto& op : bb) {
110 if (failed(DuplicateValueIfNeeded(&op, &stateful_values, &builder))) {
111 func.emitError() << "Failed to duplicate values for the stateful op\n";
112 return signalPassFailure();
113 }
114 }
115 }
116 }
117
118 } // namespace
119
120 /// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
121 /// pass.
CreateSplitMergedOperandsPass()122 std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass() {
123 return std::make_unique<SplitMergedOperandsPass>();
124 }
125
126 static PassRegistration<SplitMergedOperandsPass> pass(
127 "tfl-split-merged-operands",
128 "Split merged stateful operands for tfl operations.");
129
130 } // namespace TFL
131 } // namespace mlir
132