/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
// This header file defines utilities used by TFLite transformation passes to
// convert composite LSTM cell ops into fused TFL LSTM ops.
18 
19 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
20 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
21 
22 #include "llvm/ADT/StringRef.h"
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Location.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
30 
31 namespace mlir {
32 namespace TFL {
33 
// Attribute key string "tf._implements".
// NOTE(review): presumably the function attribute that names the composite op
// a function implements — confirm against the users of this constant.
constexpr char kTFImplements[] = "tf._implements";
// Composite op name reported by
// ConvertLSTMCellSimpleToFusedLSTM::GetCompositeOpName().
constexpr char kLstmCellSimple[] = "LSTMCellSimple";
// Composite op name reported by
// ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::GetCompositeOpName().
constexpr char kLayerNormalizedLstmCellSimple[] =
    "LayerNormalizedLstmCellSimple";
// NOTE(review): presumably the attribute name controlling whether the input and
// forget gates are coupled; it is not referenced in this header — confirm.
constexpr char kCoupleInputForgetGates[] = "CoupleInputForgetGates";
39 
40 // A utility class that enables the conversion of the LSTMCellSimple composite
41 // op into a fused TFL LSTM op. The fused op is contained within a FuncOp
42 // that also contains other supporting ops needed to construct the operands for
43 // the fused op. The caller provides the containing FuncOp as input with
44 // arguments specifying the input, weight, projection and bias.
45 // The weight, projection, bias and layer norm scale all need to be
46 // RankedTensorType.
47 // This class sets the layer norm coefficients to NoneType.
48 class ConvertLSTMCellSimpleToFusedLSTM {
49  public:
ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op)50   explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op)
51       : fused_func_op_(fused_func_op),
52         couple_input_forget_gates_(false),
53         builder_(fused_func_op.getBody()) {}
54 
55   // not copyable.
56   ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
57       delete;
58   ConvertLSTMCellSimpleToFusedLSTM& operator=(
59       const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
~ConvertLSTMCellSimpleToFusedLSTM()60   virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}
61 
GetCompositeOpName()62   virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }
63 
64   // Rewrite the func body with constructed fused lstm.
65   LogicalResult RewriteFunc();
66 
GetNumInputs()67   int GetNumInputs() { return n_input_; }
68 
69  protected:
70   // verify input func op arguments/attributes and initialize internal state.
71   virtual LogicalResult InitializeFromFuncAttributes();
72   virtual LogicalResult Initialize();
73 
74   void UpdateFuncSignature();
75   void GenerateFusedOpOperands();
76 
77   void SetWeightForInputToCellGate();
78   void SetWeightForInputToInputGate();
79   void SetWeightForInputToForgetGate();
80   void SetWeightForInputToOutputGate();
81 
82   void SetWeightForRecurrentToCellGate();
83   void SetWeightForRecurrentToInputGate();
84   void SetWeightForRecurrentToForgetGate();
85   void SetWeightForRecurrentToOutputGate();
86 
87   void SetBiasToCellGate();
88   void SetBiasToInputGate();
89   void SetBiasToForgetGate();
90   void SetBiasToOutputGate();
91 
92   void SetProjection();
93   void SetProjectionBias();
94 
95   void SetInputActivationState();
96   void SetInputCellState();
97 
98   virtual void SetCellLayerNormCoefficients();
99   virtual void SetInputLayerNormCoefficients();
100   virtual void SetForgetLayerNormCoefficients();
101   virtual void SetOutputLayerNormCoefficients();
102 
103   // specified state
104   FuncOp fused_func_op_;
105   Value input_;
106   Value weight_;
107   Value bias_;
108   Value projection_;
109   bool couple_input_forget_gates_;
110 
111   // internal state
112   Value weight_transposed_;
113   Value projection_transposed_;
114   RankedTensorType weight_type_;
115   RankedTensorType projection_type_;
116   int num_gates_;
117   int n_cell_;
118   int n_output_;
119   int n_input_;
120   int num_cols_weight_transposed_;
121   int num_cols_projection_transposed_;
122 
123   // input -> cifg
124   Value input2input_;
125   Value input2forget_;
126   Value input2cell_;
127   Value input2output_;
128 
129   // recurrent -> cifg
130   Value rec2input_;
131   Value rec2forget_;
132   Value rec2cell_;
133   Value rec2output_;
134 
135   // bias -> cifg
136   Value bias2input_;
137   Value bias2forget_;
138   Value bias2cell_;
139   Value bias2output_;
140 
141   // projection
142   Value proj_weight_;
143   Value proj_bias_;
144 
145   // state
146   Value input_activation_state_;
147   Value input_cell_state_;
148 
149   // layer norm coefficients
150   Value input_layer_norm_coefficients_;
151   Value forget_layer_norm_coefficients_;
152   Value cell_layer_norm_coefficients_;
153   Value output_layer_norm_coefficients_;
154 
155   mlir::TFL::LSTMOp lstm_;
156 
157   Value none_;
158   SmallVector<int64_t, 1> bias_slice_shape_;
159   SmallVector<int64_t, 1> bias_size_values_;
160   SmallVector<int64_t, 2> weight_slice_shape_;
161   SmallVector<int64_t, 2> weight_slice_size_input_values_;
162   SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
163   OpBuilder builder_;
164 };
165 
166 // A utility class that enables the conversion of the
167 // LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
168 // fused op is contained within a FuncOp that also contains other supporting ops
169 // needed to construct the operands for the fused op. The caller provides the
170 // containing FuncOp as input with arguments specifying the input, weight,
171 // projection, bias and layer norm scale. The weight, projection, bias and
172 // layer norm scale all need to be RankedTensorType.
173 // This class overrides the layer norm coefficient setters from the base class.
174 class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
175     : public ConvertLSTMCellSimpleToFusedLSTM {
176  public:
ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op)177   explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
178       mlir::FuncOp fused_func_op)
179       : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op) {}
180 
181   // not copyable.
182   ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
183       const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
184   ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
185       const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM()186   ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}
187 
GetCompositeOpName()188   llvm::StringRef GetCompositeOpName() override {
189     return kLayerNormalizedLstmCellSimple;
190   }
191 
192  protected:
193   LogicalResult Initialize() override;
194 
195   void SetCellLayerNormCoefficients() override;
196   void SetInputLayerNormCoefficients() override;
197   void SetForgetLayerNormCoefficients() override;
198   void SetOutputLayerNormCoefficients() override;
199 
200  private:
201   // specified state
202   Value layer_norm_scale_;
203 
204   // internal state
205   RankedTensorType layer_norm_scale_type_;
206   SmallVector<int64_t, 1> layer_norm_slice_shape_;
207   SmallVector<int64_t, 1> layer_norm_size_values_;
208 };
209 
210 LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder);
211 
212 }  // end namespace TFL
213 }  // end namespace mlir
214 
215 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
216