1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
25 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
26 #include "mlir/IR/Attributes.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/Identifier.h" // from @llvm-project
31 #include "mlir/IR/Location.h" // from @llvm-project
32 #include "mlir/IR/MLIRContext.h" // from @llvm-project
33 #include "mlir/IR/OpDefinition.h" // from @llvm-project
34 #include "mlir/IR/Operation.h" // from @llvm-project
35 #include "mlir/IR/Types.h" // from @llvm-project
36 #include "mlir/IR/Value.h" // from @llvm-project
37 #include "mlir/Support/LLVM.h" // from @llvm-project
38 #include "mlir/Support/LogicalResult.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41
42 namespace mlir {
43 namespace TFL {
44
45 namespace {
46
CreateI32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,int32_t val,mlir::Location location)47 Value CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
48 int32_t val, mlir::Location location) {
49 auto type = RankedTensorType::get(shape, builder->getIntegerType(32));
50 auto attr = DenseElementsAttr::get(type, val);
51 return builder->create<ConstantOp>(location, type, attr);
52 }
53
CreateF32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)54 Value CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
55 float val, mlir::Location location) {
56 auto type = RankedTensorType::get(shape, builder->getF32Type());
57 auto attr = DenseElementsAttr::get(type, val);
58 return builder->create<ConstantOp>(location, type, attr);
59 }
60
CreatTfF32ConstOp(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)61 Value CreatTfF32ConstOp(OpBuilder* builder, ArrayRef<int64_t> shape, float val,
62 mlir::Location location) {
63 auto type = RankedTensorType::get(shape, builder->getF32Type());
64 auto ele_type = RankedTensorType::get({1}, builder->getF32Type());
65 auto attr = DenseElementsAttr::get(ele_type, val);
66 return builder->create<TF::ConstOp>(location, type, attr);
67 }
68
CreateI64DenseConst(OpBuilder * builder,ArrayRef<int64_t> shape,ArrayRef<int64_t> values,mlir::Location location)69 Value CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
70 ArrayRef<int64_t> values, mlir::Location location) {
71 auto type = RankedTensorType::get(static_cast<int>(shape.size()),
72 builder->getIntegerType(64));
73 auto attr = DenseElementsAttr::get(type, values);
74 return builder->create<ConstantOp>(location, type, attr);
75 }
76
CreateI32DenseConst(OpBuilder * builder,ArrayRef<int32_t> values,mlir::Location location)77 Value CreateI32DenseConst(OpBuilder* builder, ArrayRef<int32_t> values,
78 mlir::Location location) {
79 auto type = RankedTensorType::get(static_cast<int>(values.size()),
80 builder->getIntegerType(32));
81 auto attr = DenseElementsAttr::get(type, values);
82 return builder->create<ConstantOp>(location, type, attr);
83 }
84
CreateNoneValue(OpBuilder * builder,mlir::Location location)85 Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
86 return builder->create<mlir::ConstantOp>(location, builder->getNoneType(),
87 builder->getUnitAttr());
88 }
89
Transpose(OpBuilder * builder,Value value_to_transpose,SmallVector<int32_t,4> perm,RankedTensorType original_type,mlir::Location location)90 Value Transpose(OpBuilder* builder, Value value_to_transpose,
91 SmallVector<int32_t, 4> perm, RankedTensorType original_type,
92 mlir::Location location) {
93 // Create a constant op for transpose permutation.
94 auto perm_op = CreateI32DenseConst(builder, perm, location);
95
96 // Create tensor type for the transpose result.
97 auto transpose_type = original_type;
98 auto transpose_shape =
99 llvm::to_vector<8>(llvm::map_range(perm, [transpose_type](int32_t dim) {
100 return transpose_type.getDimSize(dim);
101 }));
102 auto elem_type = transpose_type.getElementType();
103 auto result_type = RankedTensorType::get(transpose_shape, elem_type);
104
105 return builder->create<TF::TransposeOp>(location, result_type,
106 value_to_transpose, perm_op);
107 }
108
Transpose2D(OpBuilder * builder,Value value_to_transpose,RankedTensorType type,mlir::Location location)109 Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
110 RankedTensorType type, mlir::Location location) {
111 // Create a constant op for transpose permutation.
112 SmallVector<int32_t, 4> perm = {1, 0};
113 return Transpose(builder, value_to_transpose, perm, type, location);
114 }
115
Reverse(OpBuilder * builder,Value value_to_reverse,int axis,RankedTensorType type,mlir::Location location)116 Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
117 RankedTensorType type, mlir::Location location) {
118 auto axis_op = CreateI32SplatConst(builder, {1}, axis, location);
119 // The result type will be the same as the input.
120 return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
121 axis_op);
122 }
123
GetRankedTensorShape(Value value)124 ArrayRef<int64_t> GetRankedTensorShape(Value value) {
125 return value.getType().cast<RankedTensorType>().getShape();
126 }
127
SliceRankedTensor(OpBuilder * builder,Value input,ArrayRef<int64_t> begin_shape,ArrayRef<int64_t> begin_values,ArrayRef<int64_t> size_shape,ArrayRef<int64_t> size_values,mlir::Location location)128 Value SliceRankedTensor(OpBuilder* builder, Value input,
129 ArrayRef<int64_t> begin_shape,
130 ArrayRef<int64_t> begin_values,
131 ArrayRef<int64_t> size_shape,
132 ArrayRef<int64_t> size_values,
133 mlir::Location location) {
134 // If the size of the tensor to be sliced from the input overflows
135 // the input tensor's dimensions, return 0-valued tensor of the requested
136 // shape.
137 ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
138 for (int i = 0, end = input_shape.size(); i < end; i++) {
139 if (begin_values[i] < 0 ||
140 (begin_values[i] + size_values[i] > input_shape[i])) {
141 return CreateF32SplatConst(builder, size_shape, 0, location);
142 }
143 }
144
145 // Create a dense constant op for slice's begin
146 auto slice_i2c_begin =
147 CreateI64DenseConst(builder, begin_shape, begin_values, location);
148
149 // Create a dense constant op for slice's size
150 auto slice_i2c_size =
151 CreateI64DenseConst(builder, size_shape, size_values, location);
152
153 return builder->create<TF::SliceOp>(
154 location,
155 RankedTensorType::get(
156 size_values,
157 input.getType().cast<RankedTensorType>().getElementType()),
158 input, slice_i2c_begin, slice_i2c_size);
159 }
160
CreateStridedSliceOp(mlir::Location loc,ArrayRef<int64_t> output_shape,Value input,ArrayRef<int32_t> begin,ArrayRef<int32_t> end,ArrayRef<int32_t> strides,int64_t begin_mask,int64_t end_mask,int64_t ellipsis_mask,int64_t new_axis_mask,int64_t shrink_axis_mask,OpBuilder * builder)161 Value CreateStridedSliceOp(mlir::Location loc, ArrayRef<int64_t> output_shape,
162 Value input, ArrayRef<int32_t> begin,
163 ArrayRef<int32_t> end, ArrayRef<int32_t> strides,
164 int64_t begin_mask, int64_t end_mask,
165 int64_t ellipsis_mask, int64_t new_axis_mask,
166 int64_t shrink_axis_mask, OpBuilder* builder) {
167 auto output_type = RankedTensorType::get(
168 output_shape, input.getType().cast<RankedTensorType>().getElementType());
169 auto begin_tensor = CreateI32DenseConst(builder, begin, loc);
170 auto end_tensor = CreateI32DenseConst(builder, end, loc);
171 auto strides_tensor = CreateI32DenseConst(builder, strides, loc);
172
173 return builder->create<TF::StridedSliceOp>(
174 loc, output_type, input, begin_tensor, end_tensor, strides_tensor,
175 builder->getI64IntegerAttr(begin_mask),
176 builder->getI64IntegerAttr(end_mask),
177 builder->getI64IntegerAttr(ellipsis_mask),
178 builder->getI64IntegerAttr(new_axis_mask),
179 builder->getI64IntegerAttr(shrink_axis_mask));
180 }
181
182 } // namespace
183
SetWeightForInputToCellGate()184 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
185 SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
186 input2cell_ = SliceRankedTensor(
187 &builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
188 weight_slice_shape_, weight_slice_size_input_values_,
189 fused_func_op_.getLoc());
190 }
191
SetWeightForInputToInputGate()192 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
193 SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
194 input2input_ = couple_input_forget_gates_
195 ? none_
196 : SliceRankedTensor(&builder_, weight_transposed_,
197 weight_slice_shape_, begin_i2i_values,
198 weight_slice_shape_,
199 weight_slice_size_input_values_,
200 fused_func_op_.getLoc());
201 }
202
SetWeightForInputToForgetGate()203 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
204 int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
205 SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
206 input2forget_ = SliceRankedTensor(
207 &builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
208 weight_slice_shape_, weight_slice_size_input_values_,
209 fused_func_op_.getLoc());
210 }
211
SetWeightForInputToOutputGate()212 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
213 int input_output_start =
214 couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
215 SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
216 input2output_ = SliceRankedTensor(
217 &builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
218 weight_slice_shape_, weight_slice_size_input_values_,
219 fused_func_op_.getLoc());
220 }
221
SetWeightForRecurrentToCellGate()222 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
223 SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
224 rec2cell_ = SliceRankedTensor(
225 &builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
226 weight_slice_shape_, weight_slice_size_recurrent_values_,
227 fused_func_op_.getLoc());
228 }
229
SetWeightForRecurrentToInputGate()230 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
231 SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
232 rec2input_ = couple_input_forget_gates_
233 ? none_
234 : SliceRankedTensor(&builder_, weight_transposed_,
235 weight_slice_shape_, begin_rec2i_values,
236 weight_slice_shape_,
237 weight_slice_size_recurrent_values_,
238 fused_func_op_.getLoc());
239 }
240
SetWeightForRecurrentToForgetGate()241 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
242 int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
243 SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
244 rec2forget_ = SliceRankedTensor(
245 &builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
246 weight_slice_shape_, weight_slice_size_recurrent_values_,
247 fused_func_op_.getLoc());
248 }
249
SetWeightForRecurrentToOutputGate()250 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
251 int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
252 SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
253 rec2output_ = SliceRankedTensor(
254 &builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
255 weight_slice_shape_, weight_slice_size_recurrent_values_,
256 fused_func_op_.getLoc());
257 }
258
SetBiasToCellGate()259 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
260 SmallVector<int64_t, 1> begin_bias2c_values = {0};
261 bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
262 begin_bias2c_values, bias_slice_shape_,
263 bias_size_values_, fused_func_op_.getLoc());
264 }
265
SetBiasToInputGate()266 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
267 SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
268 bias2input_ =
269 couple_input_forget_gates_
270 ? none_
271 : SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
272 begin_bias2i_values, bias_slice_shape_,
273 bias_size_values_, fused_func_op_.getLoc());
274 }
275
SetBiasToForgetGate()276 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
277 int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
278 SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
279 bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
280 begin_bias2f_values, bias_slice_shape_,
281 bias_size_values_, fused_func_op_.getLoc());
282 }
283
SetBiasToOutputGate()284 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
285 int bias_output_start =
286 couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
287 SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
288 bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
289 begin_bias2o_values, bias_slice_shape_,
290 bias_size_values_, fused_func_op_.getLoc());
291 }
292
SetProjection()293 void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
294 SmallVector<int64_t, 2> projection_slice_shape = {
295 1, num_cols_projection_transposed_};
296 SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
297 SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
298 proj_weight_ =
299 !projection_
300 ? none_
301 : SliceRankedTensor(
302 &builder_, projection_transposed_, projection_slice_shape,
303 projection_slice_begin_values, projection_slice_shape,
304 projection_slice_size_values, fused_func_op_.getLoc());
305 }
306
SetProjectionBias()307 void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
308 proj_bias_ = !projection_type_
309 ? none_
310 : CreateF32SplatConst(&builder_, {n_output_}, 0,
311 fused_func_op_.getLoc());
312 }
313
SetInputActivationState()314 void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
315 input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
316 fused_func_op_.getLoc());
317 }
318
SetInputCellState()319 void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
320 input_cell_state_ =
321 CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
322 }
323
SetCellLayerNormCoefficients()324 void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
325 cell_layer_norm_coefficients_ = none_;
326 }
327
SetInputLayerNormCoefficients()328 void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
329 input_layer_norm_coefficients_ = none_;
330 }
331
SetForgetLayerNormCoefficients()332 void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
333 forget_layer_norm_coefficients_ = none_;
334 }
SetOutputLayerNormCoefficients()335 void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
336 output_layer_norm_coefficients_ = none_;
337 }
338
GenerateFusedOpOperands()339 void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
340 // Transpose both weight and projection.
341 weight_transposed_ =
342 Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
343 projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
344 fused_func_op_.getLoc());
345
346 none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
347 // Extract input to cifg gates via slicing the weight tensor
348 SetWeightForInputToCellGate();
349 SetWeightForInputToInputGate();
350 SetWeightForInputToForgetGate();
351 SetWeightForInputToOutputGate();
352
353 // Extract recurrent to cifg gates via slicing the weight tensor
354 SetWeightForRecurrentToCellGate();
355 SetWeightForRecurrentToInputGate();
356 SetWeightForRecurrentToForgetGate();
357 SetWeightForRecurrentToOutputGate();
358
359 // Extract bias to cifg gates via slicing the bias tensor
360 SetBiasToCellGate();
361 SetBiasToInputGate();
362 SetBiasToForgetGate();
363 SetBiasToOutputGate();
364
365 // Extract projection and set an empty projection bias
366 SetProjection();
367 SetProjectionBias();
368
369 // Set the variable tensors
370 SetInputActivationState();
371 SetInputCellState();
372
373 // Extract the layer norm coefficients
374 SetCellLayerNormCoefficients();
375 SetInputLayerNormCoefficients();
376 SetForgetLayerNormCoefficients();
377 SetOutputLayerNormCoefficients();
378 }
379
UpdateFuncSignature()380 void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
381 // https://github.com/tensorflow/community/pull/113
382 SmallVector<int64_t, 2> output_shape{1, -1};
383 auto input_types = fused_func_op_.getType().getInputs();
384 auto output_type = mlir::RankedTensorType::get(
385 output_shape, input_.getType().cast<RankedTensorType>().getElementType());
386 fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
387 input_types, output_type));
388 }
389
RewriteFunc()390 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
391 LogicalResult result = Initialize();
392 if (failed(result)) {
393 return result;
394 }
395
396 // Update the func signature, based on output shape.
397 // The func will ultimately return the output of the fused
398 // LSTM op.
399 UpdateFuncSignature();
400
401 // Transform the weights, projection, bias and layer norm coefficients
402 // to generate operands for the TFL fused LSTM op.
403 GenerateFusedOpOperands();
404
405 // Create the fused LSTM op.
406 SmallVector<int64_t, 2> output_shape = {1, n_output_};
407 auto result_type = mlir::RankedTensorType::get(
408 output_shape, input_.getType().cast<RankedTensorType>().getElementType());
409 lstm_ = builder_.create<mlir::TFL::LSTMOp>(
410 fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
411 input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
412 rec2output_, /*cell_to_input_weights*/ none_,
413 /*cell_to_forget_weights*/ none_,
414 /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
415 bias2output_, proj_weight_, proj_bias_, input_activation_state_,
416 input_cell_state_, input_layer_norm_coefficients_,
417 forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
418 output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
419 builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
420 builder_.getStringAttr("FULL"),
421 /*input_to_input_intermediate=*/mlir::TypeAttr(),
422 /*input_to_forget_intermediate=*/mlir::TypeAttr(),
423 /*input_to_cell_intermediate=*/mlir::TypeAttr(),
424 /*input_to_output_intermediate=*/mlir::TypeAttr(),
425 /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
426
427 // Cast the static shaped lstm result to FuncOp's signature -
428 // Ranked but unknown 2nd dimension to support stacking these.
429 SmallVector<int64_t, 2> func_output_shape = {1, -1};
430 auto func_result_type = mlir::RankedTensorType::get(
431 func_output_shape,
432 input_.getType().cast<RankedTensorType>().getElementType());
433
434 auto tensor_cast = builder_.create<mlir::tensor::CastOp>(
435 fused_func_op_.getLoc(), func_result_type, lstm_.getResult());
436 builder_.create<mlir::ReturnOp>(fused_func_op_.getLoc(),
437 tensor_cast.getResult());
438 return success();
439 }
440
InitializeFromFuncAttributes()441 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
442 auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
443 if (!attr) {
444 return fused_func_op_.emitError()
445 << "Invalid function attribute, expected " << kTFImplements
446 << " attribute "
447 "not found";
448 }
449
450 // TODO(ashwinm, b/144775479): Make these NamedAttribute on TF import
451 // once tf.function can support this.
452 llvm::SmallVector<llvm::StringRef, 4> attr_tokens;
453 attr.getValue().split(attr_tokens, ",");
454 if (attr_tokens.empty()) {
455 return fused_func_op_.emitError()
456 << kTFImplements << " attribute should be set";
457 }
458
459 // Check if the interface matches.
460 if (GetCompositeOpName().str() != attr_tokens[0]) {
461 return fused_func_op_.emitError()
462 << "Unexpected interface for the composite op. Expected: "
463 << GetCompositeOpName() << " Actual: " << attr_tokens[0];
464 }
465
466 // Extract other interface attributes, for now cifg.
467 couple_input_forget_gates_ =
468 std::find(attr_tokens.begin() + 1, attr_tokens.end(),
469 kCoupleInputForgetGates) != attr_tokens.end();
470
471 return success();
472 }
473
Initialize()474 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
475 if (failed(InitializeFromFuncAttributes())) {
476 return fused_func_op_.emitError()
477 << "Expected function attributes were not set on the function "
478 "encapsulating the composite op";
479 }
480
481 num_gates_ = couple_input_forget_gates_ ? 3 : 4;
482
483 input_ = fused_func_op_.getArgument(0);
484 bias_ = fused_func_op_.getArgument(2);
485
486 weight_ = fused_func_op_.getArgument(1);
487 weight_type_ = weight_.getType().cast<RankedTensorType>();
488
489 if (weight_type_.getRank() != 2) {
490 return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
491 }
492
493 if (weight_type_.getDimSize(1) % num_gates_ != 0) {
494 return fused_func_op_.emitError()
495 << "Invalid dimension 1 of weight tensor, "
496 "should be divisible by the number of gates";
497 }
498 n_cell_ = weight_type_.getDimSize(1) / num_gates_;
499
500 projection_ = fused_func_op_.getArgument(3);
501 projection_type_ = projection_.getType().cast<RankedTensorType>();
502 if (projection_type_.getRank() != 2) {
503 n_output_ = n_cell_;
504 } else {
505 n_output_ = projection_type_.getDimSize(1);
506 }
507 n_input_ = weight_type_.getDimSize(0) - n_output_;
508 num_cols_weight_transposed_ = weight_type_.getDimSize(0);
509 num_cols_projection_transposed_ = projection_type_.getDimSize(0);
510
511 bias_slice_shape_ = {n_cell_};
512 bias_size_values_ = {n_cell_};
513 weight_slice_shape_ = {1, num_cols_weight_transposed_};
514 weight_slice_size_input_values_ = {n_cell_, n_input_};
515 weight_slice_size_recurrent_values_ = {n_cell_, n_output_};
516
517 return success();
518 }
519
Initialize()520 LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
521 if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
522 return fused_func_op_.emitError()
523 << "Specified LayerNormalizedLSTMCellSimple was not of the expected "
524 "interface and cannot not be converted to the fused LSTM op";
525 }
526
527 layer_norm_scale_ = fused_func_op_.getArgument(4);
528 layer_norm_scale_type_ = layer_norm_scale_.getType().cast<RankedTensorType>();
529 if (layer_norm_scale_type_.getRank() != 1) {
530 return fused_func_op_.emitError()
531 << "The layer_norm_scale tensor was not of rank 1";
532 }
533 layer_norm_slice_shape_ = {n_cell_};
534 layer_norm_size_values_ = {n_cell_};
535
536 return success();
537 }
538
539 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetCellLayerNormCoefficients()540 SetCellLayerNormCoefficients() {
541 SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
542 cell_layer_norm_coefficients_ =
543 SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
544 begin_cell_layer_norm_values, layer_norm_slice_shape_,
545 layer_norm_size_values_, fused_func_op_.getLoc());
546 }
547
548 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetInputLayerNormCoefficients()549 SetInputLayerNormCoefficients() {
550 SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
551 input_layer_norm_coefficients_ =
552 couple_input_forget_gates_
553 ? none_
554 : SliceRankedTensor(
555 &builder_, layer_norm_scale_, layer_norm_slice_shape_,
556 begin_input_layer_norm_values, layer_norm_slice_shape_,
557 layer_norm_size_values_, fused_func_op_.getLoc());
558 }
559
560 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetForgetLayerNormCoefficients()561 SetForgetLayerNormCoefficients() {
562 SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
563 forget_layer_norm_coefficients_ =
564 SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
565 begin_forget_layer_norm_values, layer_norm_slice_shape_,
566 layer_norm_size_values_, fused_func_op_.getLoc());
567 }
568
569 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetOutputLayerNormCoefficients()570 SetOutputLayerNormCoefficients() {
571 SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
572 output_layer_norm_coefficients_ =
573 SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
574 begin_output_layer_norm_values, layer_norm_slice_shape_,
575 layer_norm_size_values_, fused_func_op_.getLoc());
576 }
577
Create1DConstantOp(const std::vector<int> & value,Location loc,OpBuilder * builder)578 TF::ConstOp Create1DConstantOp(const std::vector<int>& value, Location loc,
579 OpBuilder* builder) {
580 auto type =
581 mlir::RankedTensorType::get(value.size(), builder->getIntegerType(32));
582 auto dense_values = mlir::DenseIntElementsAttr::get(type, value);
583 return builder->create<TF::ConstOp>(loc, dense_values);
584 }
585
CreateScalarConstantOp(int value,Location loc,OpBuilder * builder)586 TF::ConstOp CreateScalarConstantOp(int value, Location loc,
587 OpBuilder* builder) {
588 return builder->create<TF::ConstOp>(loc, builder->getI32IntegerAttr(value));
589 }
590
CreateEqualSizeSplitVOp(Value input,int axis,int splits,Location loc,OpBuilder * builder,Operation ** result)591 LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits,
592 Location loc, OpBuilder* builder,
593 Operation** result) {
594 auto input_type = input.getType().cast<RankedTensorType>();
595 SmallVector<int64_t, 4> output_shape;
596 int size_of_splits;
597 if (input_type.getRank() < axis || axis < 0) return failure();
598 for (int i = 0; i < input_type.getRank(); ++i) {
599 int dim = input_type.getDimSize(i);
600 if (i == axis) {
601 if (dim % splits != 0) {
602 return failure();
603 }
604 size_of_splits = dim / splits;
605 output_shape.push_back(size_of_splits);
606 } else {
607 output_shape.push_back(dim);
608 }
609 }
610
611 SmallVector<mlir::Type, 4> output_types;
612 for (int i = 0; i < splits; ++i) {
613 output_types.push_back(
614 mlir::RankedTensorType::get(output_shape, input_type.getElementType()));
615 }
616 auto size_of_splits_op = Create1DConstantOp(
617 {size_of_splits, size_of_splits, size_of_splits, size_of_splits}, loc,
618 builder);
619
620 auto axis_op = CreateScalarConstantOp(axis, loc, builder);
621 *result = builder->create<TF::SplitVOp>(loc, output_types, input,
622 size_of_splits_op.getResult(),
623 axis_op.getResult());
624 return success();
625 }
626
627 // TODO(b/147436982): Consider refactor this to be more general.
ConvertKerasLSTMLayer(mlir::FuncOp func_op,OpBuilder * builder)628 LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
629 // For argument order, please check out standard_lstm under
630 // tensorflow/python/keras/layers/recurrent_v2.py
631 Value input = func_op.getArgument(0);
632 Value output_init_state = func_op.getArgument(1);
633 Value hidden_init_state = func_op.getArgument(2);
634 Value weight_kernel = func_op.getArgument(3);
635 Value recurrent_kernel = func_op.getArgument(4);
636 Value bias = func_op.getArgument(5);
637
638 // The func op should have 5 outputs.
639 if (func_op.getNumResults() != 5) return failure();
640
641 // TFL lstm only supports time-majored inputs, so if it's not time-majored,
642 // we will transpose the inputs and outputs.
643 auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
644 if (time_major_attr == nullptr) return failure();
645
646 bool time_majored = time_major_attr.getValue();
647 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
648 if (!input_type) {
649 func_op.emitError() << "Input type is not a ranked tensor type";
650 return failure();
651 }
652
653 auto final_inputs = input;
654 auto final_input_type = input_type;
655
656 // Handle go_backwards:
657 // LSTM in Keras semantic will reverse the input sequence if it's go_backwards
658 auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
659
660 if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
661 int time_dim = time_majored ? 0 : 1;
662 final_inputs = Reverse(builder, final_inputs, time_dim, final_input_type,
663 func_op.getLoc());
664 }
665
666 int batch = time_majored ? final_input_type.getDimSize(1)
667 : final_input_type.getDimSize(0);
668 int time = time_majored ? final_input_type.getDimSize(0)
669 : final_input_type.getDimSize(1);
670
671 // Setup correct weights.
672 RankedTensorType weight_type =
673 weight_kernel.getType().cast<RankedTensorType>();
674 if (weight_type.getRank() != 2)
675 return func_op.emitError() << "The weight should be rank of 2";
676
677 Value transposed_weight_kernel =
678 Transpose2D(builder, weight_kernel, weight_type, func_op.getLoc());
679
680 RankedTensorType recurrent_kernel_type =
681 recurrent_kernel.getType().cast<RankedTensorType>();
682 const int n_output = recurrent_kernel_type.getDimSize(0);
683
684 Value transpose_recurrent_kernel = Transpose2D(
685 builder, recurrent_kernel, recurrent_kernel_type, func_op.getLoc());
686
687 // Splits the weights into 4: i, f, c, o.
688 const int splits = 4;
689
690 Operation* weights_array;
691 if (failed(CreateEqualSizeSplitVOp(transposed_weight_kernel, 0, splits,
692 func_op.getLoc(), builder,
693 &weights_array)))
694 return failure();
695
696 // Splits the recurrent_weights into 4:
697 Operation* recurrent_weights_array;
698 if (failed(CreateEqualSizeSplitVOp(transpose_recurrent_kernel, 0, splits,
699 func_op.getLoc(), builder,
700 &recurrent_weights_array)))
701 return failure();
702
703 // Splits the bias into 4:
704 Operation* bias_array;
705 if (failed(CreateEqualSizeSplitVOp(bias, 0, splits, func_op.getLoc(), builder,
706 &bias_array)))
707 return failure();
708
709 // Build the lstm op.
710 SmallVector<int64_t, 3> output_shape;
711 if (time_majored) {
712 output_shape = {time, batch, n_output};
713 } else {
714 output_shape = {batch, time, n_output};
715 }
716 auto result_type = mlir::RankedTensorType::get(
717 output_shape,
718 final_inputs.getType().cast<RankedTensorType>().getElementType());
719
720 Value none = builder->create<mlir::ConstantOp>(
721 func_op.getLoc(), builder->getNoneType(), builder->getUnitAttr());
722 auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
723 func_op.getLoc(), result_type, /*input=*/final_inputs,
724 /*input_to_input_weights=*/weights_array->getResult(0),
725 /*input_to_forget_weights=*/weights_array->getResult(1),
726 /*input_to_cell_weights=*/weights_array->getResult(2),
727 /*input_to_output_weights=*/weights_array->getResult(3),
728 /*recurrent_to_input_weights=*/recurrent_weights_array->getResult(0),
729 /*recurrent_to_forget_weights=*/recurrent_weights_array->getResult(1),
730 /*recurrent_to_cell_weights=*/recurrent_weights_array->getResult(2),
731 /*recurrent_to_output_weights=*/recurrent_weights_array->getResult(3),
732 /*cell_to_input_weights=*/none,
733 /*cell_to_forget_weights=*/none,
734 /*cell_to_output_weights=*/none,
735 /*input_gate_bias=*/bias_array->getResult(0),
736 /*forget_gate_bias=*/bias_array->getResult(1),
737 /*cell_bias=*/bias_array->getResult(2),
738 /*output_gate_bias=*/bias_array->getResult(3),
739 /*projection_weights=*/none,
740 /*projection_bias=*/none,
741 /*input_activation_state=*/output_init_state,
742 /*input_cell_state=*/hidden_init_state,
743 /*input_layer_norm_coefficients=*/none,
744 /*forget_layer_norm_coefficients=*/none,
745 /*cell_layer_norm_coefficients=*/none,
746 /*output_layer_norm_coefficients=*/none, builder->getStringAttr("TANH"),
747 builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
748 builder->getBoolAttr(time_majored),
749 /*input_to_input_intermediate=*/mlir::TypeAttr(),
750 /*input_to_forget_intermediate=*/mlir::TypeAttr(),
751 /*input_to_cell_intermediate=*/mlir::TypeAttr(),
752 /*input_to_output_intermediate=*/mlir::TypeAttr(),
753 /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
754
755 auto final_output_full_sequences = lstm.getResult();
756
757 // Populate the last output: last output is sliced from the full sequences.
758 // If time_major: last_output = outputs[-1, :, :]
759 // else: last_output = outputs[:, -1, :]
760 //
761 // As we are creating the strided_slice op, we need to populate the following
762 // fields:
763 // end: should always be (0, 0, 0)
764 // strides: should always be (1, 1, 1)
765 // begin: should be (0, -1, 0) or (-1, 0, 0) if it's time-majored.
766 // new_axis_mask: should always be 0.
767 // ellipsis_mask: should always be 0.
768 // begin_mask & end_mask: should be 0b101 = 5 or 0b110 = 4 if it's
769 // time-majored. shrink_axis_mask: should be 0b010 = 2 or 0b001 = 1 if it's
770 // time-majored.
771 SmallVector<int64_t, 2> last_output_shape({batch, n_output});
772
773 SmallVector<int32_t, 3> end({0, 0, 0});
774 SmallVector<int32_t, 3> strides({1, 1, 1});
775 SmallVector<int32_t, 3> begin;
776
777 int64_t new_axis_mask = 0;
778 int64_t ellipsis_mask = 0;
779 int64_t begin_mask;
780 int64_t end_mask;
781 int64_t shrink_axis_mask;
782 if (time_majored) {
783 begin_mask = 6;
784 end_mask = 6;
785 shrink_axis_mask = 1;
786 begin = {-1, 0, 0};
787 } else {
788 begin_mask = 5;
789 end_mask = 5;
790 shrink_axis_mask = 2;
791 begin = {0, -1, 0};
792 }
793
794 auto last_output = CreateStridedSliceOp(
795 func_op.getLoc(), last_output_shape, final_output_full_sequences, begin,
796 end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
797 shrink_axis_mask, builder);
798
799 SmallVector<Value, 5> outputs;
800 SmallVector<Type, 5> output_types;
801
802 // Due to the existence of the while loop, the timestamp may be unknown
803 // for the signature, for us, since we know the inputs, we can infer the time
804 // steps.
805
806 // Last output.
807 outputs.push_back(last_output);
808 output_types.push_back(last_output.getType());
809
810 // Full sequences.
811 outputs.push_back(final_output_full_sequences);
812 output_types.push_back(final_output_full_sequences.getType());
813
814 // All the rest: states, device.
815 for (int i = 2; i < 5; ++i) {
816 auto result_type =
817 func_op.getCallableResults()[i].dyn_cast<RankedTensorType>();
818 outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f,
819 func_op.getLoc()));
820 output_types.push_back(result_type);
821 }
822
823 // Update function signatures.
824 func_op.setType(mlir::FunctionType::get(
825 func_op.getContext(), func_op.getType().getInputs(), output_types));
826
827 builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
828 return success();
829 }
830
831 } // namespace TFL
832 } // namespace mlir
833