• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <vector>
17 
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Block.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
28 #include "mlir/IR/Matchers.h"  // from @llvm-project
29 #include "mlir/IR/Operation.h"  // from @llvm-project
30 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
31 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
32 #include "mlir/IR/Types.h"  // from @llvm-project
33 #include "mlir/IR/Value.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
36 #include "mlir/Support/LLVM.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
39 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
40 
41 // Background info:
// Currently the model taken to MLIRConverter is frozen (all the variables have
// been converted to constants, all the assign ops are gone, etc.). However,
// TFLite has variable tensor semantics. So the variable mapping from TF to
// TFLite is actually broken here; we sort of hard-code the variable tensors
// based on the actual ops using them, such as unidirectional_sequence_lstm.
47 //
48 // MLIRConverter also benefits from lots of typical compiler optimization like
49 // merging same input values if they're identical. These optimizations are
50 // desirable but not for those TFLite ops which have variable tensors as inputs.
51 // Yes, they have identical input values, but those identical values are
52 // "stateful", their values can change during invocations.
53 //
// A typical example is unidirectional_sequence_lstm, which has two variable
// tensor inputs: activation state & cell state. They may have the same initial
// values (typically zero-initialized), but their values will be changed. So we
// cannot just merge those values.
58 //
// This pass is more like a short-term workaround since we don't have a good
// variable representation right now.
61 //
62 // This pass will duplicate input values for those variable tensor inputs.
63 
64 namespace mlir {
65 namespace TFL {
66 namespace {
67 
68 struct SplitMergedOperandsPass
69     : public PassWrapper<SplitMergedOperandsPass, FunctionPass> {
70   void runOnFunction() override;
71 };
72 
DuplicateValueIfNeeded(Operation * op,llvm::DenseSet<Value> * values,OpBuilder * builder)73 LogicalResult DuplicateValueIfNeeded(Operation* op,
74                                      llvm::DenseSet<Value>* values,
75                                      OpBuilder* builder) {
76   std::vector<int> stateful_operands_index;
77   if (!IsStatefulOp(op, &stateful_operands_index)) return success();
78 
79   for (int index : stateful_operands_index) {
80     Value operand = op->getOperand(index);
81     auto inserted_value = values->insert(operand).second;
82     if (inserted_value) continue;
83     // We can only clone the constant op at this point.
84     // Since all ops have been legalized to tflite ops, so we only care about
85     // ConstOp or QConstOp or mlir constant op/
86     Operation* input_op = operand.getDefiningOp();
87     if (input_op == nullptr) return failure();
88 
89     Attribute attr;
90     if (!matchPattern(input_op, m_Constant(&attr))) {
91       op->emitError()
92           << "We cannot duplicate the value since it's not constant.\n";
93       return failure();
94     }
95     builder->setInsertionPoint(op);
96     Operation* duplicated_input_op = builder->clone(*input_op);
97 
98     // Rewire the inputs.
99     op->setOperand(index, duplicated_input_op->getResult(0));
100   }
101   return success();
102 }
103 
runOnFunction()104 void SplitMergedOperandsPass::runOnFunction() {
105   llvm::DenseSet<Value> stateful_values;
106   auto func = getFunction();
107   OpBuilder builder(func);
108   for (auto& bb : func.getBody()) {
109     for (auto& op : bb) {
110       if (failed(DuplicateValueIfNeeded(&op, &stateful_values, &builder))) {
111         func.emitError() << "Failed to duplicate values for the stateful op\n";
112         return signalPassFailure();
113       }
114     }
115   }
116 }
117 
118 }  // namespace
119 
120 /// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
121 /// pass.
CreateSplitMergedOperandsPass()122 std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass() {
123   return std::make_unique<SplitMergedOperandsPass>();
124 }
125 
// Static registration object: registers SplitMergedOperandsPass under the
// "tfl-split-merged-operands" argument together with its one-line
// description.
static PassRegistration<SplitMergedOperandsPass> pass(
    "tfl-split-merged-operands",
    "Split merged stateful operands for tfl operations.");
129 
130 }  // namespace TFL
131 }  // namespace mlir
132