1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This header file defines common utils used by TFLite transformation 17 // passes to work with op attributes. 18 19 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ 20 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ 21 22 #include "llvm/ADT/StringRef.h" 23 #include "mlir/IR/Builders.h" // from @llvm-project 24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 26 #include "mlir/IR/Location.h" // from @llvm-project 27 #include "mlir/IR/Value.h" // from @llvm-project 28 #include "mlir/Support/LogicalResult.h" // from @llvm-project 29 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 30 31 namespace mlir { 32 namespace TFL { 33 34 constexpr char kTFImplements[] = "tf._implements"; 35 constexpr char kLstmCellSimple[] = "LSTMCellSimple"; 36 constexpr char kLayerNormalizedLstmCellSimple[] = 37 "LayerNormalizedLstmCellSimple"; 38 constexpr char kCoupleInputForgetGates[] = "CoupleInputForgetGates"; 39 40 // A utility class that enables the conversion of the LSTMCellSimple composite 41 // op into a fused TFL LSTM op. The fused op is contained within a FuncOp 42 // that also contains other supporting ops needed to construct the operands for 43 // the fused op. The caller provides the containing FuncOp as input with 44 // arguments specifying the input, weight, projection and bias. 45 // The weight, projection, bias and layer norm scale all need to be 46 // RankedTensorType. 47 // This class sets the layer norm coefficients to NoneType. 48 class ConvertLSTMCellSimpleToFusedLSTM { 49 public: ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op)50 explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op) 51 : fused_func_op_(fused_func_op), 52 couple_input_forget_gates_(false), 53 builder_(fused_func_op.getBody()) {} 54 55 // not copyable. 56 ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) = 57 delete; 58 ConvertLSTMCellSimpleToFusedLSTM& operator=( 59 const ConvertLSTMCellSimpleToFusedLSTM&) = delete; ~ConvertLSTMCellSimpleToFusedLSTM()60 virtual ~ConvertLSTMCellSimpleToFusedLSTM() {} 61 GetCompositeOpName()62 virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; } 63 64 // Rewrite the func body with constructed fused lstm. 65 LogicalResult RewriteFunc(); 66 GetNumInputs()67 int GetNumInputs() { return n_input_; } 68 69 protected: 70 // verify input func op arguments/attributes and initialize internal state. 71 virtual LogicalResult InitializeFromFuncAttributes(); 72 virtual LogicalResult Initialize(); 73 74 void UpdateFuncSignature(); 75 void GenerateFusedOpOperands(); 76 77 void SetWeightForInputToCellGate(); 78 void SetWeightForInputToInputGate(); 79 void SetWeightForInputToForgetGate(); 80 void SetWeightForInputToOutputGate(); 81 82 void SetWeightForRecurrentToCellGate(); 83 void SetWeightForRecurrentToInputGate(); 84 void SetWeightForRecurrentToForgetGate(); 85 void SetWeightForRecurrentToOutputGate(); 86 87 void SetBiasToCellGate(); 88 void SetBiasToInputGate(); 89 void SetBiasToForgetGate(); 90 void SetBiasToOutputGate(); 91 92 void SetProjection(); 93 void SetProjectionBias(); 94 95 void SetInputActivationState(); 96 void SetInputCellState(); 97 98 virtual void SetCellLayerNormCoefficients(); 99 virtual void SetInputLayerNormCoefficients(); 100 virtual void SetForgetLayerNormCoefficients(); 101 virtual void SetOutputLayerNormCoefficients(); 102 103 // specified state 104 FuncOp fused_func_op_; 105 Value input_; 106 Value weight_; 107 Value bias_; 108 Value projection_; 109 bool couple_input_forget_gates_; 110 111 // internal state 112 Value weight_transposed_; 113 Value projection_transposed_; 114 RankedTensorType weight_type_; 115 RankedTensorType projection_type_; 116 int num_gates_; 117 int n_cell_; 118 int n_output_; 119 int n_input_; 120 int num_cols_weight_transposed_; 121 int num_cols_projection_transposed_; 122 123 // input -> cifg 124 Value input2input_; 125 Value input2forget_; 126 Value input2cell_; 127 Value input2output_; 128 129 // recurrent -> cifg 130 Value rec2input_; 131 Value rec2forget_; 132 Value rec2cell_; 133 Value rec2output_; 134 135 // bias -> cifg 136 Value bias2input_; 137 Value bias2forget_; 138 Value bias2cell_; 139 Value bias2output_; 140 141 // projection 142 Value proj_weight_; 143 Value proj_bias_; 144 145 // state 146 Value input_activation_state_; 147 Value input_cell_state_; 148 149 // layer norm coefficients 150 Value input_layer_norm_coefficients_; 151 Value forget_layer_norm_coefficients_; 152 Value cell_layer_norm_coefficients_; 153 Value output_layer_norm_coefficients_; 154 155 mlir::TFL::LSTMOp lstm_; 156 157 Value none_; 158 SmallVector<int64_t, 1> bias_slice_shape_; 159 SmallVector<int64_t, 1> bias_size_values_; 160 SmallVector<int64_t, 2> weight_slice_shape_; 161 SmallVector<int64_t, 2> weight_slice_size_input_values_; 162 SmallVector<int64_t, 2> weight_slice_size_recurrent_values_; 163 OpBuilder builder_; 164 }; 165 166 // A utility class that enables the conversion of the 167 // LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The 168 // fused op is contained within a FuncOp that also contains other supporting ops 169 // needed to construct the operands for the fused op. The caller provides the 170 // containing FuncOp as input with arguments specifying the input, weight, 171 // projection, bias and layer norm scale. The weight, projection, bias and 172 // layer norm scale all need to be RankedTensorType. 173 // This class overrides the layer norm coefficient setters from the base class. 174 class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM 175 : public ConvertLSTMCellSimpleToFusedLSTM { 176 public: ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op)177 explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM( 178 mlir::FuncOp fused_func_op) 179 : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op) {} 180 181 // not copyable. 182 ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM( 183 const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; 184 ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=( 185 const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM()186 ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {} 187 GetCompositeOpName()188 llvm::StringRef GetCompositeOpName() override { 189 return kLayerNormalizedLstmCellSimple; 190 } 191 192 protected: 193 LogicalResult Initialize() override; 194 195 void SetCellLayerNormCoefficients() override; 196 void SetInputLayerNormCoefficients() override; 197 void SetForgetLayerNormCoefficients() override; 198 void SetOutputLayerNormCoefficients() override; 199 200 private: 201 // specified state 202 Value layer_norm_scale_; 203 204 // internal state 205 RankedTensorType layer_norm_scale_type_; 206 SmallVector<int64_t, 1> layer_norm_slice_shape_; 207 SmallVector<int64_t, 1> layer_norm_size_values_; 208 }; 209 210 LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder); 211 212 } // end namespace TFL 213 } // end namespace mlir 214 215 #endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ 216