/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>

#include "third_party/eigen3/Eigen/Core"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
#include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/kernel_shape_util.h"

namespace mlir {
namespace TFL {
namespace {

Operation *getDefiningBroadcastArgsOp(Value operand) {
  auto *defining_op = operand.getDefiningOp();
  if (!llvm::dyn_cast_or_null<TF::BroadcastToOp>(defining_op) &&
      !llvm::dyn_cast_or_null<TFL::BroadcastToOp>(defining_op)) {
    return nullptr;
  }

  Value broadcast_shape = defining_op->getOperand(
      1);  // Broadcasted shape operand of BroadcastTo op.
  Operation *parent_of_defining_op = broadcast_shape.getDefiningOp();
  if (!llvm::dyn_cast_or_null<TF::BroadcastArgsOp>(parent_of_defining_op) &&
      !llvm::dyn_cast_or_null<TFL::BroadcastArgsOp>(parent_of_defining_op)) {
    return nullptr;
  }
  return parent_of_defining_op;
}

}  // namespace

// Returns true when the given operand arguments have the same shape or a
// broadcastable shape within the given rank. If any of the given shapes are
// non-static and the maximum rank is within the given rank, this method
// returns true.
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
    Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
  if (indices.empty()) return true;

  // First, check whether any input has an unknown rank.
  bool has_unknown_shape_input = false;
  bool has_same_shape = true;
  bool reach_first_known_shape = false;
  int64_t max_rank = -1;

  ArrayRef<int64_t> pivot_shape;
  SmallVector<int64_t, 4> current_shape;
  SmallVector<int64_t, 4> result_shape;

  for (unsigned index : indices) {
    ShapedType shaped_type =
        op->getOperand(index).getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasRank()) {
      // Marks that we have an unknown rank input.
      has_unknown_shape_input = true;
      continue;
    }
    max_rank = std::max(max_rank, shaped_type.getRank());
    if (!shaped_type.hasStaticShape()) {
      // Marks that we have an unknown shape input.
      has_unknown_shape_input = true;
      continue;
    }

    ArrayRef<int64_t> shape = shaped_type.getShape();
    if (!reach_first_known_shape) {
      pivot_shape = shape;
      current_shape.assign(shape.begin(), shape.end());
      reach_first_known_shape = true;
      continue;
    }

    if (!pivot_shape.equals(shape)) {
      has_same_shape = false;
    }
    // Checks that the inputs are broadcastable, since they do not all have
    // the same shape.
    if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
                                            result_shape)) {
      return false;
    }
    current_shape = result_shape;
  }

  // If all shapes are known and equal, CPU kernels are able to handle inputs
  // regardless of dimension size.
  if (!has_unknown_shape_input) {
    return has_same_shape || max_rank <= max_bcast_rank;
  }

  // Treats unknown shape inputs as acceptable for model compatibility if all
  // known ranks are no bigger than the allowed broadcast maximum rank.
  if (max_rank <= max_bcast_rank) {
    return true;
  }

  // Checks whether all operands are broadcast by BroadcastTo ops whose shape
  // is calculated from the same BroadcastArgs op. In that case, all operands
  // will have the same shape.
  Operation *broadcast_args_pivot = nullptr;
  for (unsigned index : indices) {
    Operation *parent_broadcast_args =
        getDefiningBroadcastArgsOp(op->getOperand(index));
    if (parent_broadcast_args == nullptr) {
      return false;
    }

    if (broadcast_args_pivot == nullptr) {
      broadcast_args_pivot = parent_broadcast_args;
      continue;
    }

    if (broadcast_args_pivot != parent_broadcast_args) {
      return false;
    }
  }
  return true;
}
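
// As an illustrative example (assumed, not exercised in this file): with
// operand types tensor<2x3xf32> and tensor<3xf32>, indices {0, 1}, and
// max_bcast_rank 4, the shapes are broadcastable and the maximum rank is 2,
// so the check above returns true.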

// Returns true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         quantized_type.isSigned();
}

// Returns true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         !quantized_type.isSigned();
}

// Returns true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 16 &&
         quantized_type.isSigned();
}

// Returns true when the given element_type is I32.
bool IsI32Type(Type element_type) {
  return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}

// Returns true when the given element_type is I64.
bool IsI64Type(Type element_type) {
  return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}

// Returns true if the value is a splat tensor constant zero.
bool EqualsZero(Value value) {
  DenseElementsAttr constant;
  if (!matchPattern(value, m_Constant(&constant)) || !constant.isSplat()) {
    return false;
  }

  Type element_type = value.getType().cast<ShapedType>().getElementType();
  if (element_type.isa<FloatType>()) {
    return constant.getSplatValue<APFloat>().isZero();
  } else {
    return false;
  }
}

// Replaces the bias operand with a "none" type value if the bias value is
// constant zero.
// `ConcreteOpType` must be a concrete MLIR op class that has an optional
// bias operand named 'bias'.
template <typename ConcreteOpType>
struct RemoveOptionalZeroBias : public OpRewritePattern<ConcreteOpType> {
  using OpRewritePattern<ConcreteOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcreteOpType op,
                                PatternRewriter &rewriter) const override {
    if (EqualsZero(op.bias())) {
      auto none_value = rewriter.create<mlir::ConstantOp>(
          rewriter.getUnknownLoc(), rewriter.getUnitAttr());
      op.biasMutable().assign(none_value);
    }

    return success();
  }
};
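
// This pattern is typically registered from an op's canonicalization hook;
// for example, FullyConnectedOp::getCanonicalizationPatterns below inserts
// RemoveOptionalZeroBias<FullyConnectedOp> into its pattern list.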

// Returns true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, QI8, QUI8, I32 and I64 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to four dimensions or identical
  // shapes.
  if (element_type.isF32() || IsQI8Type(element_type) ||
      IsQUI8Type(element_type) || IsI32Type(element_type) ||
      IsI64Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows QI16 output when operands have the same shape.
  if (IsQI16Type(element_type)) {
    return succeeded(
        mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
  }
  return false;
}

// Returns true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, I32, I64, QUI8 and QI16 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to five dimensions or identical
  // shapes.
  if (element_type.isF32() || IsI32Type(element_type) ||
      IsI64Type(element_type) || IsQUI8Type(element_type) ||
      IsQI16Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
  }

  // Allows QI8 output when the operands have valid shapes, which are
  // broadcastable shapes up to four dimensions or identical shapes.
  if (IsQI8Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

// Returns true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows QI8 and QUI8 outputs when the operands are broadcastable up to
  // four dimensions. If the lhs operand is QI16, only operands with the same
  // shape are allowed.
  if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
    if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
      return succeeded(
          mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
    }
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows I32, I64, QI16 and F32 outputs when the operands have valid shapes,
  // which are broadcastable shapes up to four dimensions or identical shapes.
  if (IsI32Type(element_type) || IsI64Type(element_type) ||
      IsQI16Type(element_type) || element_type.isF32()) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//

struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    // No TFLite op restricts inlining today; revise as needed in the future.
    return true;
  }
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    return isa<WhileOp>(dest->getParentOp());
  }
};

struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  // Registered hook to check if the given region, which is attached to an
  // operation that is *not* isolated from above (i.e. no internal regions
  // reference values defined in an enclosing region), should be used when
  // materializing constants.
  // In the TFLite dialect we materialize inside while regions, as it is
  // slightly more efficient computationally.
  bool shouldMaterializeInto(Region *region) const final {
    return isa<WhileOp>(region->getParentOp());
  }
};

TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
    : Dialect(/*name=*/"tfl", context, TypeID::get<TensorFlowLiteDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
      >();
  addInterfaces<TensorFlowLiteInlinerInterface,
                TensorFlowLiteDialectFoldInterface>();
}

//===----------------------------------------------------------------------===//
// Common support logic
//===----------------------------------------------------------------------===//

namespace {

// Returns true if the dimensions in `a` are a suffix of the ones in `b`.
// For example, dimensions {2}, {1, 2}, and {3, 1, 2} are all suffixes to
// {5, 4, 3, 1, 2}, while {1}, {5, 4}, and {1, 3, 2} are all not.
inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  if (a.size() > b.size()) return false;

  return std::equal(a.rbegin(), a.rend(), b.rbegin());
}

// Returns true if it is a shaped type of f32 elements.
inline bool IsF32ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isF32();
  }
  return false;
}

// Returns true if it is a shaped type of bf16 elements.
inline bool IsBF16ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isBF16();
  }
  return false;
}

// Returns a new shape with rank 'new_dims', padded with ones on the left if
// needed.
inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
                                           int new_dims) {
  std::vector<int64_t> new_shape(new_dims, 1);
  std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
  return new_shape;
}
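
// Illustrative example (assumed, not part of the original comments):
// GetPaddedShape({5, 4}, /*new_dims=*/4) would return {1, 1, 5, 4}.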

// Helper method that, given a 'current_index' representing an index in the
// broadcasted tensor, returns the index in the flat original tensor.
// 'shape' is the original shape with padding to match the result shape.
int64_t GetElementIndex(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &current_index) {
  int64_t ind = 0;
  int64_t mul = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    ind += (current_index[i] % shape[i]) * mul;
    mul *= shape[i];
  }
  return ind;
}
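
// Illustrative example (assumed): with a padded original shape of {1, 3}
// broadcast to a result shape of {4, 3}, the result index {2, 1} maps to the
// flat index (1 % 3) * 1 + (2 % 1) * 3 = 1 in the original tensor, i.e. its
// single row is reused for every broadcast row.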

// Helper method that increments the index represented in 'current_index_ptr'
// in the shape of 'result_shape'.
void IncrementIndex(ArrayRef<int64_t> result_shape,
                    std::vector<int64_t> *current_index_ptr) {
  std::vector<int64_t> &current_index = *current_index_ptr;
  for (int i = result_shape.size() - 1; i >= 0; --i) {
    current_index[i]++;
    if (current_index[i] == result_shape[i]) {
      current_index[i] = 0;
    } else {
      break;
    }
  }
}
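
// Illustrative example (assumed): for result_shape {2, 3}, repeatedly calling
// IncrementIndex walks the indices in row-major order, e.g. {0, 2} is
// followed by {1, 0}.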

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `lhs` and `rhs` and returns the result if possible.
/// This function assumes both operands have been verified to be value
/// attributes of broadcastable types.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
                                      DenseElementsAttr rhs,
                                      const CalculationT &calculate) {
  auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
                  .dyn_cast_or_null<ShapedType>();
  if (!type) {
    return {};
  }

  const bool rhs_is_splat = rhs.isSplat();
  const bool lhs_is_splat = lhs.isSplat();

  // If both of them are splat, compute and return.
  if (lhs_is_splat && rhs_is_splat) {
    auto element_result = AttrElementT::get(
        type.getElementType(), calculate(lhs.getSplatValue<ElementValueT>(),
                                         rhs.getSplatValue<ElementValueT>()));
    if (!element_result) return {};

    return DenseElementsAttr::get(type, element_result);
  }

  auto num_elements = type.getNumElements();

  SmallVector<ElementValueT, 16> new_values;
  new_values.reserve(num_elements);
  const auto result_shape = type.getShape();
  std::vector<int64_t> current_index(type.getRank(), 0);
  // Create the new shapes with ones padded to the left.
  const std::vector<int64_t> lhs_new_shape =
      GetPaddedShape(lhs.getType().getShape(), type.getRank());
  const std::vector<int64_t> rhs_new_shape =
      GetPaddedShape(rhs.getType().getShape(), type.getRank());

  auto lhs_old_values = lhs.getValues<ElementValueT>();
  auto rhs_old_values = rhs.getValues<ElementValueT>();

  // Combine each pair of the corresponding values in the dense elements
  // attributes.
  for (int64_t i = 0; i < num_elements; ++i) {
    // current_index represents the index in the N-dimensional result tensor.
    // GetElementIndex returns the index in the flat representation of the
    // original tensor to use.
    const int64_t lhs_index =
        lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
    const int64_t rhs_index =
        rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);

    new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index),
                                   *(rhs_old_values.begin() + rhs_index)));
    IncrementIndex(result_shape, &current_index);
  }
  return DenseElementsAttr::get(type, ArrayRef<ElementValueT>(new_values));
}
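
// As an illustrative example (assumed, not part of the original comments):
// folding dense<[1, 2]> with a splat dense<3> through an integer addition
// `calculate` yields dense<[4, 5]>; the splat operand reuses element index 0
// for every position of the broadcasted result.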

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes the two operands are verified to have value
/// attributes of broadcastable types.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
                            Attribute operand2, const CalculationT &calculate) {
  if (operand1.dyn_cast_or_null<DenseElementsAttr>() &&
      operand2.dyn_cast_or_null<DenseElementsAttr>()) {
    return ConstFoldBinaryOpDenseDense<AttrElementT, ElementValueT>(
        result_type, operand1.cast<DenseElementsAttr>(),
        operand2.cast<DenseElementsAttr>(), calculate);
  }

  // TODO: support other attribute kinds

  return {};
}

/// Performs const folding with broadcast behavior on the two attributes in
/// `operands` and returns the result if possible.
/// Depending on the given `resultType`, either `floatCalculate` or
/// `intCalculate` is chosen to conduct the calculation.
Attribute ConstFoldBinaryOp(
    Type result_type, ArrayRef<Attribute> operands,
    llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
    llvm::function_ref<APInt(APInt, APInt)> int_calculate) {
  // Note: All types are wrapped in tensor types in TFLite. E.g., f32 is
  // represented as tensor<f32>. So we are only handling tensor types here.
  auto type = result_type.dyn_cast<ShapedType>();
  if (!type) return {};

  auto elemType = type.getElementType();

  if (elemType.isa<FloatType>())
    return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                        float_calculate);

  if (elemType.isSignlessInteger())
    return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
                                          int_calculate);

  return {};
}

/// Performs const folding on the attribute `operand` and returns the result
/// if possible.
/// The function currently asserts that `result_type` is an f32 or bf16 tensor
/// type.
/// TODO: Extend this function to handle integral tensors for ops like
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
                           llvm::function_ref<APFloat(APFloat)> calculate) {
  assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
  auto result_shape_type = result_type.cast<ShapedType>();

  if (!result_shape_type.hasStaticShape()) return {};

  if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
    SmallVector<APFloat, 16> new_values;
    const int num_elements = result_shape_type.getNumElements();
    new_values.reserve(num_elements);

    for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
      new_values.push_back(calculate(old_value));
    }

    return DenseElementsAttr::get(result_shape_type, new_values);
  }

  return {};
}

void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
                          Value rhs) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();
  result.addOperands({lhs, rhs});
  // Comparison binary ops always return i1 tensor.
  if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
    auto result_shape = shaped_type.getShape();
    result.types.push_back(
        RankedTensorType::get(result_shape, builder->getI1Type()));
  } else {
    result.types.push_back(UnrankedTensorType::get(builder->getI1Type()));
  }
}

void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
                                  Value lhs, Value rhs,
                                  StringAttr fused_activation_function) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());

  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();

  result.addOperands({lhs, rhs});
  result.addAttribute("fused_activation_function", fused_activation_function);
  result.types.push_back(result_type);
}

}  // end anonymous namespace

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a + b; },
      [](APInt a, APInt b) { return a + b; });
}

int64_t AddOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// ConcatenationOp
//===----------------------------------------------------------------------===//
// TODO(ashwinm): Implement shape inference for Concatenation

namespace {

int64_t GetConcatenationOpAxis(ConcatenationOp op) {
  auto output_type = op.output().getType().cast<RankedTensorType>();
  int32_t axis = op.axis();
  if (axis < 0) axis += output_type.getRank();
  return axis;
}

// Verify operand types and the result type:
//
// 1. Operand type ranks must be equal to the output type rank.
//
// 2. Operand dimension sizes (except dimension `axis`) must be equal to
//    previously seen dimension sizes of the same dimension.
//
// 3. Sum of operand dimension sizes of the `axis` dimension must be equal to
//    the dimension size of the `axis` dimension of output.
//
// Note: If an operand has unranked tensor type or has dynamic dimension size,
// those dimensions will be skipped.
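//
// Illustrative example (assumed): concatenating tensor<1x2xf32> and
// tensor<1x3xf32> along axis 1 into tensor<1x5xf32> satisfies all three
// conditions above.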
LogicalResult VerifyConcatenationOpTypes(Operation *op,
                                         RankedTensorType output_type,
                                         ArrayRef<TensorType> operand_types,
                                         int64_t axis) {
  const int64_t output_rank = output_type.getRank();

  constexpr int64_t kDynamicSize = -1;
  SmallVector<int64_t, 4> result_dim_sizes_loc(output_rank, -1);
  SmallVector<int64_t, 4> result_dim_sizes(output_type.getShape().begin(),
                                           output_type.getShape().end());
  result_dim_sizes[axis] = 0;

  auto FormatLoc = [&result_dim_sizes_loc](int64_t dim) {
    const int64_t loc = result_dim_sizes_loc[dim];
    if (loc == -1) return std::string("output");
    return llvm::formatv("operand #{0}", loc).str();
  };

  for (auto operand : llvm::enumerate(operand_types)) {
    auto operand_type = operand.value().dyn_cast<RankedTensorType>();
    if (!operand_type) {
      result_dim_sizes[axis] = kDynamicSize;
      continue;
    }

    const int64_t operand_rank = operand_type.getRank();
    if (operand_rank != output_rank)
      return op->emitOpError() << "rank of operand #" << operand.index()
                               << " must be equal to rank of output, expected "
                               << output_rank << ", got " << operand_rank;

    for (int64_t dim = 0; dim < output_rank; ++dim) {
      const int64_t operand_dim_size = operand_type.getDimSize(dim);
      const int64_t result_dim_size = result_dim_sizes[dim];

      if (dim == axis) {
        if (RankedTensorType::isDynamic(operand_dim_size) ||
            RankedTensorType::isDynamic(result_dim_size))
          result_dim_sizes[axis] = kDynamicSize;
        else
          result_dim_sizes[axis] += operand_dim_size;
        continue;
      }

      if (RankedTensorType::isDynamic(operand_dim_size)) continue;

      if (RankedTensorType::isDynamic(result_dim_size)) {
        result_dim_sizes[dim] = operand_dim_size;
        result_dim_sizes_loc[dim] = operand.index();
        continue;
      }

      if (result_dim_size != operand_dim_size)
        return op->emitOpError()
               << "dimension size of dimension #" << dim << " of operand #"
               << operand.index() << " must be equal to "
               << "dimension size of dimension #" << dim << " of "
               << FormatLoc(dim) << ", expected " << result_dim_size << ", got "
               << operand_dim_size;
    }
  }

  const int64_t output_concated_dim_size = output_type.getDimSize(axis);
  if (!RankedTensorType::isDynamic(output_concated_dim_size) &&
      !RankedTensorType::isDynamic(result_dim_sizes[axis]) &&
      result_dim_sizes[axis] != output_concated_dim_size)
    return op->emitOpError()
           << "dimension size of dimension #" << axis << " of output "
           << "must be equal to the sum of dimension sizes of dimension #"
           << axis << ", expected " << result_dim_sizes[axis] << ", got "
           << output_concated_dim_size;

  return success();
}

LogicalResult Verify(ConcatenationOp op) {
  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();

  // If the output type is unranked, there is nothing else to be verified.
  if (!output_type) return success();

  const int64_t axis = GetConcatenationOpAxis(op);
  if (axis < 0 || axis >= output_type.getRank())
    return op.emitOpError("concatenation dimension must be in [-rank, rank)");

  SmallVector<TensorType, 4> operand_types;
  for (Value operand : op.values())
    operand_types.push_back(operand.getType().cast<TensorType>());

  return VerifyConcatenationOpTypes(op.getOperation(), output_type,
                                    operand_types, axis);
}

// Returns true when all operands are instances of DenseElementsAttr and the
// output type has a static shape.
bool IsConcatenationOpConstFoldable(ConcatenationOp op,
                                    ArrayRef<Attribute> operands,
                                    RankedTensorType output_type,
                                    int64_t axis) {
  if (operands.empty()) return false;
  if (!output_type.hasStaticShape()) return false;
  if (axis < 0) return false;

  return llvm::all_of(operands, [](Attribute operand) {
    return operand && operand.isa<DenseElementsAttr>();
  });
}

DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
                                              RankedTensorType output_type,
                                              int64_t axis) {
  const auto outer_dims = output_type.getShape().take_front(axis);
  const int64_t outer_size = std::accumulate(
      outer_dims.begin(), outer_dims.end(), 1, std::multiplies<int64_t>());

  const auto base_inner_dims = output_type.getShape().drop_front(axis + 1);
  const int64_t base_inner_size =
      std::accumulate(base_inner_dims.begin(), base_inner_dims.end(), 1,
                      std::multiplies<int64_t>());

  // Splits each input operand into outer_size pieces and combines them in
  // round-robin ordering.
  std::vector<Attribute> out_attrs(output_type.getNumElements());
  int64_t out = 0;
  for (int64_t outer = 0; outer < outer_size; ++outer) {
    for (auto op : operands) {
      const int64_t dim_size =
          op.getType().cast<RankedTensorType>().getDimSize(axis);
      const int64_t inner_size = dim_size * base_inner_size;

      auto input_attrs = op.cast<DenseElementsAttr>().getValues<Attribute>();
      auto input_iter = input_attrs.begin() + outer * inner_size;
      for (int64_t inner = 0; inner < inner_size; ++inner)
        out_attrs[out++] = *input_iter++;
    }
  }

  return DenseElementsAttr::get(output_type, out_attrs);
}
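
// Illustrative example (assumed): concatenating dense<[[1, 2], [3, 4]]> and
// dense<[[5], [6]]> along axis 1 into tensor<2x3xi32> interleaves the per-row
// slices in round-robin order, producing dense<[[1, 2, 5], [3, 4, 6]]>.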

}  // end anonymous namespace

OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
  if (fused_activation_function() == "NONE") {
    if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
      const int64_t axis = GetConcatenationOpAxis(*this);
      if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
        return ConstFoldConcatenateOpDense(operands, output_type, axis);
    }
  }

  // Remove all empty values.
  SmallVector<Value, 4> non_empty_values;
  for (Value value : this->values()) {
    const auto shaped_type = value.getType().cast<ShapedType>();
    if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
      continue;
    }
    non_empty_values.push_back(value);
  }

  // If no operand is empty, there is nothing to fold.
  if (non_empty_values.size() == getNumOperands()) return nullptr;

  // If only one input is non-empty, just return it as the result of folding.
  if (non_empty_values.size() == 1) {
    return non_empty_values[0];
  }

  // Otherwise, build a new concatenation op with non-empty values.
  mlir::OpBuilder builder(getOperation());
  auto new_concat = builder.create<TFL::ConcatenationOp>(
      getLoc(), getType(), non_empty_values,
      builder.getIntegerAttr(builder.getIntegerType(32), axis()),
      builder.getStringAttr(fused_activation_function()));
  return new_concat.getResult();
}

//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CustomOp op) {
  OpaqueElementsAttr opaque_attr =
      op.custom_option().cast<OpaqueElementsAttr>();
  if (!opaque_attr.getType().hasStaticShape())
    return op.emitOpError("custom_option should have a static shape.");
  const int attribute_size = opaque_attr.getValue().size();
  if (attribute_size != opaque_attr.getType().cast<ShapedType>().getDimSize(0))
    return op.emitOpError(
        "custom_option should have the same length of content with shape.");
  return success();
}

//===----------------------------------------------------------------------===//
// CustomTfOp
//===----------------------------------------------------------------------===//

LogicalResult CustomTfOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    DictionaryAttr attr, RegionRange ranges,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  CustomTfOpAdaptor op(operands, attr, ranges);

  if (op.getRegions().empty()) return success();
  auto *real_op = &op.body().front().front();
  if (llvm::isa<TF::FakeQuantWithMinMaxArgsOp, TF::FakeQuantWithMinMaxVarsOp,
                TF::FakeQuantWithMinMaxVarsPerChannelOp>(real_op)) {
    Value input = *operands.begin();
    inferredReturnTypes.assign({input.getType()});
  }
  return success();
}

bool CustomTfOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.empty()) return true;
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

//===----------------------------------------------------------------------===//
// FullyConnectedOp
//===----------------------------------------------------------------------===//

LogicalResult Verify(FullyConnectedOp op) {
  ShapedType input_type = op.input().getType().cast<ShapedType>();
  ShapedType filter_type = op.filter().getType().cast<ShapedType>();
  if (filter_type.hasRank() && filter_type.getRank() != 2) {
    return op.emitOpError("expect 2d filter, got ") << filter_type;
  }

  if (!input_type.hasStaticShape() || !filter_type.hasStaticShape()) {
    return mlir::success();
  }

  // The input's number of elements must be a multiple of the filter's z_in
  // dimension.
  const int z_in = filter_type.getDimSize(1);
  const int num_input_elements = input_type.getNumElements();
  if (num_input_elements % z_in != 0) {
    return op.emitOpError(llvm::formatv(
               "expect 'input' num_elements % {0} == 0, got input type ", z_in))
           << input_type;
  }

  // TODO(jpienaar): Include more shape verification for SHUFFLED4x16INT8
  // format.
  if (op.weights_format() == "DEFAULT") {
    ShapedType output_type =
        (*op.output().begin()).getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) {
      return mlir::success();
    }

    const int num_output_elements = output_type.getNumElements();
    const int z_out = filter_type.getDimSize(0);
    if (num_output_elements % z_out != 0) {
      return op.emitOpError(llvm::formatv(
                 "expect 'output' num_elements % {0} == 0, got ", z_out))
             << output_type;
    }

    if (num_input_elements / z_in != num_output_elements / z_out) {
      return op.emitOpError(
          "num_input_elements / z_in != num_output_elements / z_out");
    }
  }

  return mlir::success();
}

LogicalResult FullyConnectedOp::fold(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  assert(operands.size() == 3);

  // Folding is not implemented with any activation function or any weight
  // type besides the default.
  if (fused_activation_function() != "NONE") return failure();
  if (weights_format() != "DEFAULT") return failure();

  // Bias tensor is optional.
  const bool has_bias = !(!bias() || bias().getType().isa<NoneType>());

  // Get the tensors.
  DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
  if (!matchPattern(input(), m_Constant(&input_tensor)) ||
      !matchPattern(filter(), m_Constant(&weights_tensor)) ||
      (has_bias && !matchPattern(bias(), m_Constant(&bias_tensor)))) {
    return failure();
  }

  // Get the tensor types.
  const auto input_type = input_tensor.getType().cast<ShapedType>();
  const auto weights_type = weights_tensor.getType().cast<ShapedType>();
  const auto bias_type =
      has_bias ? bias_tensor.getType().cast<ShapedType>() : ShapedType{};

  const auto output_type = getType(0).cast<ShapedType>();

  // Folding is only implemented for float tensors.
  if (!input_type.getElementType().isF32() ||
      !weights_type.getElementType().isF32() ||
      !output_type.getElementType().isF32() ||
      (has_bias && !bias_type.getElementType().isF32())) {
    return failure();
  }

  // Folding is only implemented for static shapes.
  if (!input_type.hasStaticShape() || !weights_type.hasStaticShape() ||
      (has_bias && !bias_type.hasStaticShape())) {
    return failure();
  }

  // Folding is only implemented for 1D input, 2D weights and 1D bias.
  if (input_type.getShape().size() != 1 ||
      weights_type.getShape().size() != 2 ||
      (has_bias && bias_type.getShape().size() != 1)) {
    return failure();
  }

  // Get the sizes.
  const auto input_size = input_type.getNumElements();
  const auto output_size = output_type.getNumElements();

  // Get iterators to the tensors.
  const auto input_values_it = input_tensor.getValues<float>().begin();
  const auto weights_values_ptr = weights_tensor.getValues<float>().begin();
  auto weights_row_it = weights_values_ptr;
  // The 'else' case could be nullptr, but the types don't match.
  auto bias_values_it =
      has_bias ? bias_tensor.getValues<float>().begin() : input_values_it;

  // Do the actual folding, one output at a time.
  std::vector<float> result_values;
  result_values.reserve(output_size);

  for (int i = 0; i < output_size; ++i) {
    // Dot product with Kahan/Neumaier summation to minimize numeric errors.
    float sum = has_bias ? *bias_values_it : 0.0f;
    float compensation = 0.0f;
    for (int j = 0; j < input_size; ++j) {
      const float addend = input_values_it[j] * weights_row_it[j];
      const float new_sum = sum + addend;
      // DO NOT enable -funsafe-math-optimizations here.
      // There is a test detecting unsafe optimizations.
      // Unsafe math optimizations can reorder float formulas, and set the
      // compensation to constant 0. The formula must be evaluated as written
      // for the algorithm to work.
      // (Note: -ffast-math is a superset of -funsafe-math-optimizations.)
      if (std::abs(sum) >= std::abs(addend)) {
        compensation += (sum - new_sum) + addend;
      } else {
        compensation += (addend - new_sum) + sum;
      }
      sum = new_sum;
    }
    result_values.push_back(sum + compensation);
    weights_row_it += input_size;
    bias_values_it++;
  }

  // Set the result tensor.
  const auto folded =
      DenseElementsAttr::get(output_type, ArrayRef<float>(result_values));
  results.assign({folded});

  return success();
}

void FullyConnectedOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveOptionalZeroBias<FullyConnectedOp>>(context);
}

int64_t FullyConnectedOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// Conv2DOp
//===----------------------------------------------------------------------===//

void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.insert<RemoveOptionalZeroBias<Conv2DOp>>(context);
}

static LogicalResult ComputeConvWindowedOutputSize(
    int64_t input_size, int64_t filter_size, int64_t dilation_rate,
    int64_t stride, tensorflow::Padding padding, int64_t *output_size) {
  int64_t pad_low;
  int64_t pad_high;

  tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
      input_size, filter_size, dilation_rate, stride, padding, output_size,
      &pad_low, &pad_high);
  // Return failure if expected_output_size could not be calculated.
  if (!status.ok()) return failure();
  return success();
}
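
// Illustrative example (an assumption based on the usual TensorFlow windowed
// output size formulas, not asserted by this file): with input_size = 224,
// filter_size = 3, dilation_rate = 1 and stride = 2, SAME padding yields an
// output size of 112 while VALID padding yields 111.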

LogicalResult Conv2DOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    DictionaryAttr attr, RegionRange,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Conv2DOpAdaptor op(operands, attr);

  const Value input = op.input();
  const Value filter = op.filter();

  const RankedTensorType input_ty =
      input.getType().dyn_cast_or_null<RankedTensorType>();
  const RankedTensorType filter_ty =
      filter.getType().dyn_cast_or_null<RankedTensorType>();
  // If both the input type and the filter type are ranked, check that their
  // ranks are valid.
  if ((input_ty && input_ty.hasRank() && input_ty.getRank() != 4) ||
      (filter_ty && filter_ty.hasRank() && filter_ty.getRank() != 4)) {
    return emitOptionalError(location, "Invalid ranks");
  }

  // If either the input or the filter is unranked, just return an unranked
  // output shape.
  if (!input_ty || !filter_ty || !input_ty.hasRank() || !filter_ty.hasRank()) {
    Type result_type;
    result_type = UnrankedTensorType::get(
        input.getType().cast<ShapedType>().getElementType());
    inferredReturnTypes.assign({result_type});
    return success();
  }

  auto stride_h = op.stride_h().getInt();
  auto stride_w = op.stride_w().getInt();
  auto dilation_h = op.dilation_h_factor().getInt();
  auto dilation_w = op.dilation_w_factor().getInt();

  // We don't have EXPLICIT padding in TFLite.
  auto paddings = op.padding().getValue();
  tensorflow::Padding padding;
  auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
  if (!padding_is_valid.ok()) {
    return emitOptionalError(location, "invalid padding format provided");
  }

  // The output always has rank 4. All dimensions are initialized to dynamic
  // size and can be partially inferred.
  // TFL's conv2d is always NHWC format & the filter is OHWI.
  SmallVector<int64_t, 4> return_shape(4, ShapedType::kDynamicSize);
  return_shape[0] = input_ty.getDimSize(0);
  return_shape[3] = filter_ty.getDimSize(0);

  // Spatial dimensions can be inferred only when both input and filter are
  // ranked because we need to get their spatial dimensions.

  // Height.
  if (!input_ty.isDynamicDim(1) && !filter_ty.isDynamicDim(1)) {
    int64_t output_height;
    if (failed(ComputeConvWindowedOutputSize(
            input_ty.getDimSize(1), filter_ty.getDimSize(1), dilation_h,
            stride_h, padding, &output_height))) {
      return failure();
    }
    return_shape[1] = output_height;
  }

  // Width.
  if (!input_ty.isDynamicDim(2) && !filter_ty.isDynamicDim(2)) {
    int64_t output_width;
    if (failed(ComputeConvWindowedOutputSize(
            input_ty.getDimSize(2), filter_ty.getDimSize(2), dilation_w,
            stride_w, padding, &output_width))) {
      return failure();
    }
    return_shape[2] = output_width;
  }

  auto result_type =
      mlir::RankedTensorType::get(return_shape, input_ty.getElementType());

  inferredReturnTypes.assign({result_type});
  return success();
}

bool Conv2DOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

int64_t Conv2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// DepthwiseConv2DOp
//===----------------------------------------------------------------------===//

void DepthwiseConv2DOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.insert<RemoveOptionalZeroBias<DepthwiseConv2DOp>>(context);
}

int64_t DepthwiseConv2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

static void BuildGatherOp(OpBuilder *builder, OperationState &result,
                          Value params, Value indices, IntegerAttr axis,
                          IntegerAttr batch_dims) {
  auto params_type = params.getType().cast<TensorType>();
  auto indices_type = indices.getType().cast<TensorType>();

  // If params/indices is unranked, then output is unranked.
  if (!params_type.hasRank() || !indices_type.hasRank())
    return TFL::GatherOp::build(
        *builder, result, UnrankedTensorType::get(params_type.getElementType()),
        params, indices, axis, batch_dims);

  int64_t params_rank = params_type.getRank();
  int64_t indices_rank = indices_type.getRank();

  // params rank is guaranteed to be at least 1.
  // Produces an output tensor with shape:
  // params.shape[:axis] + indices.shape + params.shape[axis + 1:]
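  //
  // Illustrative example (assumed): params tensor<4x5x6xf32>, indices
  // tensor<2x3xi32>, axis 1 and batch_dims 0 produce an output of shape
  // tensor<4x2x3x6xf32>.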
  std::vector<int64_t> shape(params_type.getShape());
  int64_t axis_i = axis.getInt();

  // For negative axis values, we wrap around params, e.g. axis = -1 =>
  // params[:-1].
  if (axis_i < 0) {
    axis_i += params_rank;
  }

  // params must be at least rank axis + 1.
  if (params_rank < axis_i + 1) {
    emitError(result.location, "params must be at least rank axis + 1");
  }

  int64_t batch_dims_i = batch_dims.getInt();
  if (batch_dims_i < 0) {
    batch_dims_i += indices_rank;
  }

  if (batch_dims_i > axis_i) {
    emitError(result.location,
              "axis should be bigger than or equal to batch_dims");
  }
  if (batch_dims_i >= params_rank || batch_dims_i > indices_rank) {
    emitError(result.location,
              "batch_dims must be smaller than params' rank and smaller than "
              "or equal to indices'rank");
  }
  for (int i = 0; i < batch_dims_i; ++i) {
    if (indices_type.getShape()[i] != params_type.getShape()[i]) {
      emitError(result.location,
                "batch dimensions of params must be equal to batch dimensions "
                "of indices");
    }
  }

  if ((indices_rank == 0) || (indices_rank == batch_dims_i)) {
    // Scalar indices (output is rank(params) - 1).
    // Erase shape[axis].
    shape.erase(shape.begin() + axis_i);
  } else if (indices_rank == 1) {
    // Vector indices (output is rank(params)).
    // Copy indices.shape into params.shape[axis].
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  } else {
    // Higher rank indices (output is rank(params) + rank(indices) - 1).
    shape.resize(params_rank + indices_rank - 1 - batch_dims_i);
    // Copy params.shape[axis + 1:] into shape[axis + indices_rank:].
    std::copy(std::begin(params_type.getShape()) + axis_i + 1,
              std::end(params_type.getShape()),
              std::begin(shape) + axis_i + indices_rank - batch_dims_i);

    // Copy indices.shape into params.shape[axis].
    std::copy(std::begin(indices_type.getShape()) + batch_dims_i,
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  }

  TFL::GatherOp::build(
      *builder, result,
      RankedTensorType::get(shape, params_type.getElementType()), params,
      indices, axis, batch_dims);
}
1262 
1263 //===----------------------------------------------------------------------===//
1264 // ScatterNdOp
1265 //===----------------------------------------------------------------------===//
1266 
1267 static LogicalResult Verify(ScatterNdOp op) {
1268   auto indices = op.indices();
1269   auto updates = op.updates();
1270   auto shape = op.shape();
1271   auto output = op.output();
1272 
1273   auto updates_type = updates.getType().cast<ShapedType>();
1274   auto indices_type = indices.getType().cast<ShapedType>();
1275 
1276   if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) {
1277     return success();
1278   }
1279 
1280   // Checks if the shape of `updates` is a tensor of shape
1281   // `indices.shape[:-1] + shape[indices.shape[-1]:]`, as described in the
1282   // ScatterNd op description.
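  // Illustrative example (hypothetical shapes): with indices of shape [4, 2]
  // and shape == [5, 6, 7], `updates` is expected to have shape
  // indices.shape[:-1] + shape[2:] = [4, 7].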
1283 
1284   auto outer_dims = indices_type.getRank() - 1;
1285   auto outermost_dim = indices_type.getDimSize(outer_dims);
1286   // Checks whether the first `outer_dims` dimensions of `indices` and
1287   // `updates` are equal.
1288   for (auto i = 0; i < outer_dims; i++) {
1289     if (indices_type.getDimSize(i) != updates_type.getDimSize(i)) {
1290       return op.emitOpError()
1291              << "indices.Dims(" << i << ") == " << indices_type.getDimSize(i)
1292              << ", but updates.Dims(" << i
1293              << ") == " << updates_type.getDimSize(i);
1294     }
1295   }
1296 
1297   auto output_type = output.getType().cast<ShapedType>();
1298   auto shape_type = shape.getType().cast<ShapedType>();
1299   if (shape_type.hasStaticShape()) {
1300     // Check the rank of `shape`.
1301     auto output_rank = outermost_dim + updates_type.getRank() - outer_dims;
1302     if (shape_type.getDimSize(0) != output_rank) {
1303       return op.emitOpError()
1304              << "shape must be a vector of length " << output_rank;
1305     }
1306     if (output_type.hasRank()) {
1307       if (output_type.getRank() != output_rank) {
1308         return op.emitOpError()
1309                << "output must have the same rank as the length of shape = "
1310                << output_rank;
1311       }
1312     }
1313   }
1314 
1315   DenseIntElementsAttr shape_value;
1316   if (matchPattern(shape, m_Constant(&shape_value))) {
1317     for (const auto shape_elem : shape_value) {
1318       if (shape_elem.getSExtValue() <= 0) {
1319         return op.emitOpError("all elements of shape must be > 0");
1320       }
1321     }
1322 
1323     // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
1324     // dimensions of `updates` and `shape` are equal.
1325     for (auto shape_it : llvm::enumerate(shape_value)) {
1326       int64_t i = shape_it.index();
1327       auto value = shape_it.value().getSExtValue();
1328       if (i >= outermost_dim) {
1329         auto corresponding_dim = i - outermost_dim + outer_dims;
1330         if (value != updates_type.getDimSize(corresponding_dim)) {
1331           return op.emitOpError()
1332                  << "updates.Dims(" << i
1333                  << ") == " << updates_type.getDimSize(corresponding_dim)
1334                  << ", but shape[" << i << "] == " << value;
1335         }
1336       }
1337     }
1338 
1339     // Checks if the output has the shape specified by `shape`.
1340     if (output_type.hasStaticShape()) {
1341       for (auto shape_it : llvm::enumerate(shape_value)) {
1342         int i = shape_it.index();
1343         auto value = shape_it.value().getSExtValue();
1344         if (output_type.getDimSize(i) != value) {
1345           return op.emitOpError()
1346                  << "output shape [" << output_type.getShape()
1347                  << "] must be equal to the value of shape " << shape_value;
1348         }
1349       }
1350     }
1351   }
1352   return success();
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // MulOp
1357 //===----------------------------------------------------------------------===//
1358 
1359 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1360   // TODO(b/142478136): Handle fused ops.
1361   if (fused_activation_function() != "NONE") return {};
1362 
1363   // This function is performance critical for op fusion patterns, e.g.
1364   // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d.
1365   // So a few specializations are provided to evaluate the math operation
1366   // more efficiently.
1367 
1368   // Specialization for f32 type.
1369   if (getType().cast<ShapedType>().getElementType().isF32()) {
1370     return ConstFoldBinaryOp<FloatAttr, float>(
1371         getType(), operands[0], operands[1],
1372         [](float a, float b) { return a * b; });
1373   }
1374 
1375   // Specialization for bf16 type.
1376   if (getType().cast<ShapedType>().getElementType().isBF16()) {
1377     return ConstFoldBinaryOp<FloatAttr, Eigen::bfloat16>(
1378         getType(), operands[0], operands[1],
1379         [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; });
1380   }
1381 
1382   // Generic fallback with APFloat
1383   return ConstFoldBinaryOp(
1384       getType(), operands, [](APFloat a, APFloat b) { return a * b; },
1385       [](APInt a, APInt b) { return a * b; });
1386 }
1387 
1388 int64_t MulOp::GetArithmeticCount(Operation *op) {
1389   int64_t count;
1390   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1391 
1392   return -1;
1393 }
1394 
1395 //===----------------------------------------------------------------------===//
1396 // DivOp
1397 //===----------------------------------------------------------------------===//
1398 
1399 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1400   // TODO(b/142478136): Handle fused ops.
1401   if (fused_activation_function() != "NONE") return {};
1402   return ConstFoldBinaryOp(
1403       getType(), operands, [](APFloat a, APFloat b) { return a / b; },
1404       [](APInt a, APInt b) { return a.sdiv(b); });
1405 }
1406 
1407 int64_t DivOp::GetArithmeticCount(Operation *op) {
1408   int64_t count;
1409   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1410 
1411   return -1;
1412 }
1413 
1414 //===----------------------------------------------------------------------===//
1415 // PackOp
1416 //===----------------------------------------------------------------------===//
1417 
1418 // TODO(b/133486129): Implement shape inference for pack
1419 
1420 static LogicalResult Verify(PackOp op) {
1421   // TODO(antiagainst): Implement other checks as in
1422   // tensorflow/lite/kernels/pack.cc
1423 
1424   if (op.getOperation()->getNumOperands() != op.values_count())
1425     return op.emitOpError("input count should match 'values_count' attribute");
1426 
1427   Value operand0 = op.getOperand(0);
1428   auto input_type = operand0.getType().cast<ShapedType>();
1429 
1430   // Check axis bounds.
1431   if (input_type.hasRank()) {
1432     int32_t axis_value = op.axis();
1433     if (axis_value < 0) axis_value += input_type.getRank() + 1;
1434     if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
1435       return op.emitOpError()
1436              << "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
1437              << "got rank = " << input_type.getRank()
1438              << ", and axis = " << op.axis();
1439   }
1440 
1441   // Make sure all inputs have the same shape and element type.
1442   // TODO(b/135032063): Simplify once fixed.
1443   for (Type operand_type : op.getOperandTypes()) {
1444     if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
1445       return op.emitOpError("operands should be of the same type. got ")
1446              << input_type << ", " << operand_type;
1447   }
1448 
1449   return success();
1450 }
1451 
1452 //===----------------------------------------------------------------------===//
1453 // PReluOp
1454 //===----------------------------------------------------------------------===//
1455 
1456 static LogicalResult Verify(PReluOp op) {
1457   auto input_type = op.input().getType().cast<ShapedType>();
1458   auto alpha_type = op.alpha().getType().cast<ShapedType>();
1459   auto output_type = op.output().getType().cast<ShapedType>();
1460 
1461   if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
1462     if (input_type.getRank() != alpha_type.getRank() + 1) {
1463       return op.emitOpError("'alpha' should have one less rank than 'input'.");
1464     }
1465 
1466     // Check if alpha is broadcastable
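    // E.g. (illustrative shapes): an input of shape [1, 10, 10, 3] accepts an
    // alpha of shape [10, 10, 3], [1, 1, 3] or [1, 1, 1], but not [10, 10, 2].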
1467     for (int i = 0; i < alpha_type.getRank(); i++) {
1468       if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
1469           alpha_type.getDimSize(i) != 1) {
1470         return op.emitOpError(
1471             llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
1472       }
1473     }
1474   }
1475 
1476   if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
1477     if (input_type.getRank() != output_type.getRank()) {
1478       return op.emitOpError("'input' and 'output' should have the same rank.");
1479     }
1480 
1481     // Check if input and output shapes are same
1482     for (int i = 0; i < input_type.getRank(); i++) {
1483       if (input_type.getDimSize(i) != output_type.getDimSize(i)) {
1484         return op.emitOpError(
1485             "'input' and 'output' should have the same shape.");
1486       }
1487     }
1488   }
1489   return success();
1490 }
1491 
1492 //===----------------------------------------------------------------------===//
1493 // ReshapeOp
1494 //===----------------------------------------------------------------------===//
1495 
1496 namespace {
1497 // This pattern matches and merges a tfl.reshape under the following
1498 // condition:
1499 // * The input's defining op is another tfl.reshape.
1500 // TODO(antiagainst): This pattern probably should be moved to the peephole
1501 // category, after we have the infra for peephole passes.
1502 struct RemoveAdjacentReshape : public RewritePattern {
1503   RemoveAdjacentReshape(MLIRContext *context)
1504       : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}
1505 
1506   LogicalResult match(Operation *op) const override {
1507     auto thisOp = cast<ReshapeOp>(op);
1508     auto prevOp = thisOp.getOperand(0).getDefiningOp();
1509     return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
1510   }
1511 
1512   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
1513     auto thisOp = cast<ReshapeOp>(op);
1514     auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
1515 
1516     // Replace
1517     //   %1 = "tfl.reshape"(%0, %shape0)
1518     //   %2 = "tfl.reshape"(%1, %shape1)
1519     // With
1520     //   %2 = "tfl.reshape"(%0, %shape1)
1521     rewriter.replaceOpWithNewOp<ReshapeOp>(
1522         op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
1523   }
1524 };
1525 
1526 // The kernel expects a 1-D tensor for the shape operand if it is present. If
1527 // all the dimensions except the last one are '1's, the shape tensor will be
1528 // reshaped to a 1-D tensor.
1529 // Note that this pattern doesn't check or change the content of the shape
1530 // tensor.
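// For instance (hypothetical types): a constant shape operand of type
// tensor<1x1x3xi32> would be rewritten into an equivalent tensor<3xi32>
// constant, while a tensor<2x3xi32> shape operand is rejected (the pattern
// emits an error and fails to apply).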
1531 struct ConvertShapeTo1D : public OpRewritePattern<ReshapeOp> {
1532   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1533 
1534   LogicalResult matchAndRewrite(ReshapeOp reshape,
1535                                 PatternRewriter &rewriter) const override {
1536     if (!reshape.shape().hasOneUse()) return failure();
1537 
1538     DenseIntElementsAttr shape;
1539     if (!matchPattern(reshape.shape(), m_Constant(&shape))) {
1540       return failure();
1541     }
1542     // It is already a 1-D constant, no change.
1543     auto old_shape = shape.getType().getShape();
1544     if (old_shape.size() == 1) {
1545       return failure();
1546     }
1547     // Verify that all dimensions except the last one have length one.
1548     for (auto it = ++old_shape.rbegin(); it != old_shape.rend(); ++it) {
1549       if (*it != 1) {
1550         reshape->emitError(
1551             "Non-vector shape input is used, might cause runtime error");
1552         return failure();
1553       }
1554     }
1555     auto new_shape = shape.reshape(RankedTensorType::get(
1556         {*old_shape.rbegin()}, shape.getType().getElementType()));
1557     rewriter.replaceOpWithNewOp<TFL::ConstOp>(reshape.shape().getDefiningOp(),
1558                                               new_shape);
1559     return success();
1560   }
1561 };
1562 
1563 bool InputOutputHasSameShape(mlir::Type input_type, mlir::Type output_type) {
1564   auto input_shaped_type = input_type.dyn_cast_or_null<ShapedType>();
1565   if (!input_shaped_type || !input_shaped_type.hasStaticShape()) return false;
1566 
1567   auto output_shaped_type = output_type.dyn_cast_or_null<ShapedType>();
1568   if (!output_shaped_type || !output_shaped_type.hasStaticShape()) return false;
1569 
1570   return input_shaped_type == output_shaped_type;
1571 }
1572 
1573 }  // end anonymous namespace
1574 
1575 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
1576   // Remove identity reshape with both static result and input shape.
1577   auto result_type = getType().cast<ShapedType>();
1578   auto input_type = getOperand(0).getType().cast<ShapedType>();
1579   if (InputOutputHasSameShape(input_type, result_type)) return input();
1580 
1581   // Constant folding
1582   if (auto dense_elements = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
1583     // If the result type isn't static, try to derive the result type from
1584     // the second operand, i.e. the reshape's shape operand.
1585     if (!result_type.hasStaticShape()) {
1586       auto shape_elements = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1587       if (!shape_elements) return nullptr;
1588 
1589       SmallVector<int64_t, 4> shape_data;
1590       for (const auto &it : shape_elements.getValues<APInt>()) {
1591         shape_data.push_back(it.getSExtValue());
1592       }
1593       result_type =
1594           RankedTensorType::get(shape_data, input_type.getElementType());
1595     }
1596     return dense_elements.reshape(result_type);
1597   }
1598 
1599   return nullptr;
1600 }
1601 
1602 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1603                                             MLIRContext *context) {
1604   results.insert<RemoveAdjacentReshape, ConvertShapeTo1D>(context);
1605 }
1606 
1607 using ReshapeErrorHandler =
1608     llvm::function_ref<LogicalResult(const llvm::Twine &)>;
1609 
1610 LogicalResult GetReshapeOutputType(Value input, Value shape,
1611                                    ReshapeErrorHandler error_handler,
1612                                    TensorType &output_ty) {
1613   auto input_ty = input.getType().cast<TensorType>();
1614   auto element_ty = input_ty.getElementType();
1615   output_ty = UnrankedTensorType::get(element_ty);
1616 
1617   auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
1618   if (!shape_ty) return success();
1619   if (shape_ty.getRank() != 1)
1620     return error_handler(llvm::formatv(
1621         "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
1622 
1623   DenseIntElementsAttr shape_attr;
1624   if (!matchPattern(shape, m_Constant(&shape_attr))) {
1625     // If only shape of `shape` is known, return ranked but dynamic output
1626     // shape.
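    // E.g. (illustrative): a non-constant `shape` of type tensor<3xi32> yields
    // a rank-3 output type with every dimension dynamic.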
1627     if (shape_ty.hasStaticShape()) {
1628       llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
1629                                                   ShapedType::kDynamicSize);
1630       output_ty = RankedTensorType::get(dynamic_shape, element_ty);
1631     }
1632     return success();
1633   }
1634 
1635   // Detect if reshape output shape is folded.
1636   bool shape_ty_zero_dim = false;
1637   int unknown_index = -1;
1638   // The product of constant shape argument excluding unknown dimension.
1639   int64_t shape_ty_size = 1;
1640   llvm::SmallVector<int64_t, 8> output_ty_shape;
1641   output_ty_shape.reserve(shape_attr.getNumElements());
1642   for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
1643     const int64_t size = dim.value().getSExtValue();
1644     if (size == ShapedType::kDynamicSize) {
1645       if (unknown_index != -1)
1646         return error_handler(llvm::formatv(
1647             "requires 'shape' to have at most one dynamic dimension, but got "
1648             "multiple dynamic dimensions at indices {0} and {1}. You need to "
1649             "set up the unspecified size(s) to avoid this problem, for example, "
1650             "by setting the batch size in the Keras model or replacing "
1651             "unspecified input size(s) with fixed ones.",
1652             unknown_index, dim.index()));
1653 
1654       unknown_index = dim.index();
1655     } else if (size == 0) {
1656       shape_ty_zero_dim = true;
1657     } else if (size > 0) {
1658       shape_ty_size *= size;
1659     } else {
1660       return error_handler(
1661           llvm::formatv("requires 'shape' to have dimensions greater than -1, "
1662                         "but got {0} at index {1}",
1663                         size, dim.index()));
1664     }
1665     output_ty_shape.push_back(size);
1666   }
1667 
1668   if (!input_ty.hasStaticShape()) {
1669     output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1670     return success();
1671   }
1672 
1673   // Compute the value of the unknown dimension.
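  // Illustrative example: for an input with 24 elements and a requested shape
  // of [2, -1, 4], shape_ty_size is 2 * 4 = 8, so the -1 dimension is inferred
  // as 24 / 8 = 3 and the final output shape becomes [2, 3, 4].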
1674   if (unknown_index != -1) {
1675     // Compute number of elements in tensor shape.
1676     int64_t input_ty_size = 1;
1677     bool input_ty_zero_dim = false;
1678     for (const auto &dim : input_ty.getShape()) {
1679       if (dim > 0 || !shape_ty_zero_dim) {
1680         input_ty_size *= dim;
1681       } else {
1682         input_ty_zero_dim = true;
1683       }
1684     }
1685 
1686     const int64_t missing_dim = input_ty_size / shape_ty_size;
1687     if (!input_ty_zero_dim && shape_ty_size * missing_dim != input_ty_size)
1688       return error_handler(
1689           llvm::formatv("requires 'input' number of elements be a multiple of "
1690                         "{0}, but got {1}",
1691                         shape_ty_size, input_ty_size));
1692 
1693     // Set the unknown dimension such that total number of elements remain
1694     // constant.
1695     output_ty_shape[unknown_index] = missing_dim;
1696   }
1697 
1698   output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1699 
1700   return success();
1701 }
1702 
1703 static LogicalResult Verify(ReshapeOp op) {
1704   auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
1705     return op.emitOpError() << message;
1706   };
1707   TensorType expected_ty;
1708   if (failed(GetReshapeOutputType(op.input(), op.shape(), error_handler,
1709                                   expected_ty)))
1710     return failure();
1711 
1712   auto output_ty = op.getType().dyn_cast<RankedTensorType>();
1713   if (!output_ty) return success();
1714   auto input_ty = op.input().getType().cast<TensorType>();
1715   if (output_ty.hasStaticShape() && input_ty.hasStaticShape()) {
1716     const int64_t output_ty_size = output_ty.getNumElements();
1717     const int64_t input_ty_size = input_ty.getNumElements();
1718     if (input_ty_size != output_ty_size)
1719       return op.emitOpError() << "requires 'output' number of elements to "
1720                                  "match 'input' number of elements, but got "
1721                               << output_ty_size << " and " << input_ty_size;
1722   }
1723 
1724   if (!TF::AreCastCompatible({output_ty, expected_ty}))
1725     return op.emitOpError()
1726            << "requires 'output' type " << output_ty
1727            << " to be cast compatible with expected type " << expected_ty;
1728 
1729   return success();
1730 }
1731 
1732 //===----------------------------------------------------------------------===//
1733 // PackOp
1734 //===----------------------------------------------------------------------===//
1735 
1736 // Remove a redundant unpack/pack op pair.
1737 // If an unpack op is followed by a pack op, we can remove the pack op; if the
1738 // unpack op is only consumed by the pack op, it will be removed as well.
1739 // An example illustration is:
1740 //                  Unpack [5, 8, 9], axis = 1
1741 //                /       \
1742 //            value  ...  value [5, 9], 8 values in total
1743 //              \           /
1744 //                 pack,   axis = 1
1745 //                   |
1746 //               value   [5, 8, 9]
1747 //
1748 //   This can actually be simplified into just:
1749 //
1750 //           =>   Value [5, 8, 9]
1751 // TODO(b/133341698): Move to tablegen when variadic is supported.
1752 struct RemoveRedundantUnpackPack : public RewritePattern {
1753   explicit RemoveRedundantUnpackPack(MLIRContext *context)
1754       : RewritePattern(PackOp::getOperationName(), 2, context) {}
1755 
1756   LogicalResult matchAndRewrite(Operation *op,
1757                                 PatternRewriter &rewriter) const override {
1758     TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1759     Operation *first_input = pack_op.getOperand(0).getDefiningOp();
1760     if (!first_input) return failure();
1761     auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
1762     if (!input_unpack_op) return failure();
1763 
1764     // The unpack & pack should have the same axis & num inputs/outputs.
1765     if (pack_op.axis() != input_unpack_op.axis() ||
1766         pack_op.values_count() != input_unpack_op.num())
1767       return failure();
1768 
1769     const int total_pack_inputs = pack_op.getNumOperands();
1770     const int num_results = input_unpack_op.getNumResults();
1771     if (total_pack_inputs != num_results) return failure();
1772     for (auto input_output :
1773          llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
1774       Value pack_input = std::get<0>(input_output);
1775       Value unpack_output = std::get<1>(input_output);
1776       // Make sure the ordering is the same for the pack op & unpack op.
1777       if (pack_input != unpack_output) return failure();
1778     }
1779 
1780     // Replace the pack's output to the unpack's input.
1781     rewriter.replaceOp(pack_op, input_unpack_op.getOperand());
1782     // At this point, we don't manually remove the redundant pack op & unpack op
1783     // (we cannot actually), but trust the PatternRewriter to garbage collect
1784     // these two ops.
1785     return success();
1786   }
1787 };
1788 
1789 // Replace PackOp with a reshape when there is only one operand.
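// For example (illustrative types): packing a single tensor<3x4xf32> operand
// with axis = 0 yields a tensor<1x3x4xf32>, which is equivalent to reshaping
// the operand to that static shape.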
1790 struct ReplacePackWithReshape : public RewritePattern {
1791   explicit ReplacePackWithReshape(MLIRContext *context)
1792       : RewritePattern(PackOp::getOperationName(), 2, context) {}
1793   LogicalResult matchAndRewrite(Operation *op,
1794                                 PatternRewriter &rewriter) const override {
1795     TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1796     if (pack_op.getNumOperands() != 1) return failure();
1797 
1798     Location loc = pack_op.getLoc();
1799     auto output_type = pack_op.getType().cast<ShapedType>();
1800     if (!output_type.hasStaticShape()) return failure();
1801 
1802     // This is to work around the unnecessary i64 -> i32 cast.
1803     SmallVector<int32_t, 4> new_shape_array;
1804     for (auto size : output_type.getShape()) {
1805       new_shape_array.push_back(static_cast<int32_t>(size));
1806     }
1807 
1808     auto new_shape = rewriter.create<TFL::ConstOp>(
1809         loc, DenseIntElementsAttr::get(
1810                  RankedTensorType::get(new_shape_array.size(),
1811                                        rewriter.getIntegerType(32)),
1812                  new_shape_array));
1813 
1814     rewriter.replaceOpWithNewOp<ReshapeOp>(op, output_type,
1815                                            pack_op.getOperand(0), new_shape);
1816     return success();
1817   }
1818 };
1819 
1820 void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1821                                          MLIRContext *context) {
1822   results.insert<RemoveRedundantUnpackPack, ReplacePackWithReshape>(context);
1823 }
1824 
1825 //===----------------------------------------------------------------------===//
1826 // SliceOp
1827 //===----------------------------------------------------------------------===//
1828 
1829 static LogicalResult Verify(SliceOp op) {
1830   auto input_type = op.input().getType().cast<ShapedType>();
1831   auto begin_type = op.begin().getType().cast<ShapedType>();
1832   auto size_type = op.size().getType().cast<ShapedType>();
1833   if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
1834       size_type.hasStaticShape()) {
1835     if (input_type.getRank() != begin_type.getNumElements()) {
1836       return op.emitError(
1837           "begin tensor elements size is not equal to input tensor rank");
1838     }
1839 
1840     if (input_type.getRank() != size_type.getNumElements()) {
1841       return op.emitError(
1842           "size tensor elements size is not equal to input tensor rank");
1843     }
1844   }
1845 
1846   DenseIntElementsAttr begin;
1847   if (matchPattern(op.begin(), m_Constant(&begin))) {
1848     int axis = 0;
1849     for (auto begin_i : llvm::enumerate(begin)) {
1850       if (begin_i.value().getSExtValue() < 0) {
1851         return op.emitError(
1852             llvm::formatv("begin[{0}] cannot be negative", axis));
1853       }
1854       axis++;
1855     }
1856   }
1857 
1858   DenseIntElementsAttr size;
1859   if (matchPattern(op.size(), m_Constant(&size))) {
1860     int axis = 0;
1861     for (auto size_i : llvm::enumerate(size)) {
1862       if (size_i.value().getSExtValue() < -1) {
1863         return op.emitError(
1864             llvm::formatv("size[{0}] cannot be negative other than -1", axis));
1865       }
1866       axis++;
1867     }
1868   }
1869 
1870   if (begin && size && input_type.hasStaticShape()) {
1871     for (uint64_t i = 0, end = begin.getNumElements(); i < end; i++) {
1872       int begin_i =
1873           begin.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1874       int size_i =
1875           size.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1876       int dim_i = input_type.getShape()[i];
1877       if (begin_i > dim_i) {
1878         return op.emitOpError(llvm::formatv(
1879             "begin[{0}] cannot exceed dimension length: {1}", i, dim_i));
1880       }
1881       if (size_i >= 0 && begin_i + size_i > dim_i) {
1882         return op.emitError(llvm::formatv(
1883             "begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i,
1884             dim_i));
1885       }
1886     }
1887   }
1888 
1889   return success();
1890 }
1891 
1892 TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
1893                                              RankedTensorType value_type,
1894                                              Location loc, OpBuilder *builder) {
1895   if (input_op == nullptr) return nullptr;
1896 
1897   mlir::DenseIntElementsAttr attr;
1898   if (!matchPattern(input_op, m_Constant(&attr))) {
1899     return nullptr;
1900   }
1901 
1902   auto value_shape_type = mlir::RankedTensorType::get(
1903       value_type.getShape(), builder->getIntegerType(32));
1904 
1905   SmallVector<int32_t, 4> value_i32;
1906   value_i32.reserve(value_type.getRank());
1907   for (const auto &size : attr) {
1908     value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
1909   }
1910   auto new_value_i32_attr =
1911       mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);
1912 
1913   return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
1914 }
1915 
1916 // This casts down int64 begin/size values for the TFL slice op.
1917 // It requires the begin & size to be constants.
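// For example (illustrative): a constant `begin` of type tensor<2xi64> with
// values [0, 1] is rewritten into an equivalent tensor<2xi32> constant;
// `size` is handled the same way.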
1918 struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
1919   using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
1920 
1921   LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
1922                                 PatternRewriter &rewriter) const override {
1923     auto begin = slice_op.begin();
1924     auto size = slice_op.size();
1925     auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
1926     auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
1927     auto begin_op = begin.getDefiningOp();
1928     auto size_op = size.getDefiningOp();
1929 
1930     if (begin_op == nullptr && size_op == nullptr) return failure();
1931 
1932     if (begin_type == nullptr && size_type == nullptr) return failure();
1933 
1934     // Handle begin.
1935     if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
1936       auto new_begin = NarrowDownInt64InputValuesForOp(
1937           begin_op, begin_type, slice_op.getLoc(), &rewriter);
1938       if (new_begin != nullptr) {
1939         slice_op.setOperand(1, new_begin);
1940       }
1941     }
1942 
1943     // Handle size.
1944     if (size_op && size_type && size_type.getElementType().isInteger(64)) {
1945       auto new_size = NarrowDownInt64InputValuesForOp(
1946           size_op, size_type, slice_op.getLoc(), &rewriter);
1947       if (new_size != nullptr) {
1948         slice_op.setOperand(2, new_size);
1949       }
1950     }
1951 
1952     return success();
1953   }
1954 };
1955 
1956 void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1957                                           MLIRContext *context) {
1958   results.insert<CastDonwInt64BeginEndToInt32>(context);
1959 }
1960 
1961 //===----------------------------------------------------------------------===//
1962 // SqueezeOp
1963 //===----------------------------------------------------------------------===//
1964 
1965 OpFoldResult SqueezeOp::fold(ArrayRef<Attribute> operands) {
1966   auto input_ty = input().getType().dyn_cast<RankedTensorType>();
1967   auto result_ty = getType().dyn_cast<RankedTensorType>();
1968 
1969   if (!input_ty || !result_ty) return {};
1970   if (input_ty == result_ty) return input();
1971   return {};
1972 }
1973 
1974 //===----------------------------------------------------------------------===//
1975 // SubOp
1976 //===----------------------------------------------------------------------===//
1977 
1978 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1979   // TODO(b/142478136): Handle fused ops.
1980   if (fused_activation_function() != "NONE") return {};
1981   return ConstFoldBinaryOp(
1982       getType(), operands, [](APFloat a, APFloat b) { return a - b; },
1983       [](APInt a, APInt b) { return a - b; });
1984 }
1985 
1986 int64_t SubOp::GetArithmeticCount(Operation *op) {
1987   int64_t count;
1988   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1989 
1990   return -1;
1991 }
1992 
1993 //===----------------------------------------------------------------------===//
1994 // TopKOp
1995 //===----------------------------------------------------------------------===//
1996 
1997 static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input,
1998                         Value k) {
1999   // Output size is only known if k is a constant value. A negative dimension
2000   // is considered dynamic, so use -1 here if k is not a constant value.
2001   int const_k = -1;
2002   ElementsAttr cst;
2003   if (matchPattern(k, m_Constant(&cst)))
2004     // These casts should all be valid due to how Tensor constants are stored.
2005     // TODO(jpienaar): This should use a helper function.
2006     const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
2007 
2008   auto val_type = input.getType().cast<TensorType>();
2009   // If the input value is unranked, then so are the results.
2010   if (!val_type.hasRank())
2011     return TFL::TopKV2Op::build(
2012         *builder, result, UnrankedTensorType::get(val_type.getElementType()),
2013         UnrankedTensorType::get(builder->getIntegerType(32)), input, k);
2014 
2015   // Resultant shape is value.shape[:-1] + [k]
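  // E.g. (illustrative): an input of shape [2, 8] with a constant k == 3
  // produces values and indices of shape [2, 3]; a non-constant k leaves the
  // last dimension dynamic (-1).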
2016   std::vector<int64_t> shape(val_type.getShape());
2017   shape[shape.size() - 1] = const_k;
2018   TFL::TopKV2Op::build(
2019       *builder, result, RankedTensorType::get(shape, val_type.getElementType()),
2020       RankedTensorType::get(shape, builder->getIntegerType(32)), input, k);
2021 }
2022 
2023 //===----------------------------------------------------------------------===//
2024 // FakeQuantOp
2025 //===----------------------------------------------------------------------===//
2026 
2027 // Returns true if the op has a non-empty "minmax" attribute.
2028 static inline bool HasValidMinMaxAttribute(Operation *op) {
2029   auto minmax = op->getAttrOfType<ArrayAttr>("minmax");
2030   return minmax && minmax.getValue().size() == 2;
2031 }
2032 
2033 namespace {
2034 
2035 /// This pattern matches and removes a tfl.fake_quant if all the users of this
2036 /// op and the op itself have the "minmax" attribute set.
2037 struct DropFakeQuant : public RewritePattern {
2038   explicit DropFakeQuant(MLIRContext *context)
2039       : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}
2040 
2041   LogicalResult match(Operation *op) const override {
2042     // We only match the op with valid "minmax" attribute.
2043     if (!HasValidMinMaxAttribute(op)) return failure();
2044 
2045     // If all the users of this op have valid "minmax" attributes, it is matched
2046     // and can be removed.
2047     auto fakeQuantOp = cast<FakeQuantOp>(op);
2048     for (auto *operand : fakeQuantOp.getResult().getUsers())
2049       if (!HasValidMinMaxAttribute(operand)) return failure();
2050 
2051     return success();
2052   }
2053 
2054   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
2055     // Replace the matched FakeQuantOp by its primary operand.
2056     rewriter.replaceOp(op, op->getOperand(0));
2057   }
2058 };
2059 }  // end anonymous namespace
2060 
2061 void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2062                                               MLIRContext *context) {
2063   results.insert<DropFakeQuant>(context);
2064 }
2065 
2066 //===----------------------------------------------------------------------===//
2067 // UnpackOp
2068 //===----------------------------------------------------------------------===//
2069 
2070 // TODO(b/133486129): Implement shape inference for unpack
2071 
2072 LogicalResult UnpackOp::inferReturnTypes(
2073     MLIRContext *context, Optional<Location> loc, ValueRange operands,
2074     DictionaryAttr attributes, RegionRange regions,
2075     SmallVectorImpl<Type> &inferredReturnTypes) {
2076   UnpackOpAdaptor op(operands, attributes);
2077   // TODO(jpienaar): Refactor verify
2078   if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
2079     return failure();
2080 
2081   if (operands.size() != 1) {
2082     return emitOptionalError(loc, "input count should be equal to 1");
2083   }
2084 
2085   const int64_t num_value = op.num().getInt();
2086   auto input_type = operands[0].getType().dyn_cast<ShapedType>();
2087   if (!input_type || !input_type.hasRank()) {
2088     // If input is unranked, then so is output.
2089     inferredReturnTypes.assign(
2090         num_value, UnrankedTensorType::get(input_type.getElementType()));
2091     return success();
2092   }
2093 
2094   if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) {
2095     return emitOptionalError(
2096         loc, "number of elements in input should be larger than 0");
2097   }
2098 
2099   const int64_t rank = input_type.getRank();
2100   if (rank <= 0) {
2101     return emitOptionalError(loc, "input should be of rank larger than 0");
2102   }
2103 
2104   int64_t axis_value = op.axis().getInt();
2105   if (axis_value < 0) {
2106     axis_value += rank;
2107   }
2108   if (axis_value < 0 || axis_value >= rank) {
2109     return emitOptionalError(
2110         loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
2111         op.axis().getInt(), ", and rank = ", rank);
2112   }
2113 
2114   if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
2115       input_type.getDimSize(axis_value) != num_value) {
2116     return emitOptionalError(loc, "output count should match 'num' attribute");
2117   }
2118 
2119   auto output_shape = llvm::to_vector<4>(input_type.getShape());
2120   output_shape.erase(output_shape.begin() + axis_value);
2121 
2122   auto output_type =
2123       RankedTensorType::get(output_shape, input_type.getElementType());
2124   inferredReturnTypes.assign(num_value, output_type);
2125 
2126   return success();
2127 }
2128 
2129 bool UnpackOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
2130   if (lhs.size() != rhs.size()) return false;
2131   for (auto pair : llvm::zip(lhs, rhs)) {
2132     if (failed(
2133             mlir::verifyCompatibleShape(std::get<0>(pair), std::get<1>(pair))))
2134       return false;
2135   }
2136   return true;
2137 }
2138 
2139 //===----------------------------------------------------------------------===//
2140 // SplitOp
2141 //===----------------------------------------------------------------------===//
2142 
2143 // Extracts and returns the signed integer constant in a 0-rank integer tensor
2144 // or 1-element 1-rank integer tensor if 'value' is a constant.
2145 static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
2146   ElementsAttr attr;
2147   if (!matchPattern(value, m_Constant(&attr))) return {};
2148   if (attr.getNumElements() != 1) return {};
2149   IntegerAttr int_attr = *attr.getValues<IntegerAttr>().begin();
2150   return int_attr.getValue().getSExtValue();
2151 }
2152 
2153 // Returns a RankedTensorType which is similar to `input_type` but replaces the
2154 // dimension size of `dim` with `dim_size`.  For example,
2155 // `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
2156 // `tensor<3x2xi32>`.
2157 static RankedTensorType SubstituteRankedTensorTypeDimSize(
2158     RankedTensorType input_type, int64_t dim, int64_t dim_size) {
2159   auto shape = input_type.getShape().vec();
2160   shape[dim] = dim_size;
2161   return RankedTensorType::get(shape, input_type.getElementType());
2162 }
2163 
2164 // Verifies the output tensor types of SplitOp or SplitVOp.
2165 template <typename ExpectedOutputTypeGetter>
2166 static LogicalResult VerifySplitOpOutputTypes(
2167     Operation *op, int64_t num_splits,
2168     ExpectedOutputTypeGetter get_expected_output_type) {
2169   for (int64_t i = 0; i < num_splits; ++i) {
2170     auto expected_output_type = get_expected_output_type(i);
2171     Value output = op->getResult(i);
2172     if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
2173       return op->emitOpError()
2174              << "output #" << i << " should be " << expected_output_type
2175              << " instead got " << output.getType();
2176   }
2177   return success();
2178 }
2179 
2180 static LogicalResult Verify(SplitOp op) {
2181   int64_t num_splits = op.num_splits();
2182   if (op.getNumResults() != num_splits)
2183     return op.emitOpError("output count should match 'num_splits' attribute");
2184 
2185   // If 'split_dim' is not a constant, there are no other checks.
2186   llvm::Optional<int64_t> split_dim_opt =
2187       ExtractConstantIntFromTensor(op.split_dim());
2188   if (!split_dim_opt) return success();
2189 
2190   // If 'input' is not a ranked tensor, there are no other checks.
2191   auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2192   if (!input_type) return success();
2193 
2194   int64_t split_dim = split_dim_opt.getValue();
2195   const int64_t rank = input_type.getRank();
2196   if (split_dim < 0) split_dim += rank;
2197   if (split_dim < 0 || split_dim >= rank)
2198     return op.emitOpError("'split_dim' should be in [-rank, rank)");
2199 
2200   // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
2201   // there are no other checks.
2202   const int64_t dim_size = input_type.getDimSize(split_dim);
2203   if (ShapedType::isDynamic(dim_size)) return success();
2204 
2205   if (dim_size % num_splits != 0)
2206     return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");
2207 
2208   // Verifies output tensor types.
2209   RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
2210       input_type, split_dim, dim_size / num_splits);
2211   return VerifySplitOpOutputTypes(
2212       op.getOperation(), num_splits,
2213       [expected_output_type](int64_t) { return expected_output_type; });
2214 }
2215 
2216 static LogicalResult Verify(SplitVOp op) {
2217   int64_t num_splits = op.num_splits();
2218   if (op.getNumResults() != num_splits)
2219     return op.emitOpError("output count should match 'num_splits' attribute");
2220 
2221   // If 'split_dim' is not a constant, there are no other checks.
2222   llvm::Optional<int64_t> split_dim_opt =
2223       ExtractConstantIntFromTensor(op.split_dim());
2224   if (!split_dim_opt) return success();
2225 
2226   // If 'input' is not a ranked tensor, there are no other checks.
2227   auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2228   if (!input_type) return success();
2229 
2230   int64_t split_dim = split_dim_opt.getValue();
2231   const int64_t rank = input_type.getRank();
2232   if (split_dim < 0) split_dim += rank;
2233   if (split_dim < 0 || split_dim >= rank)
2234     return op.emitOpError("'split_dim' should be in [-rank, rank)");
2235 
2236   // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
2237   // there are no other checks.
2238   const int64_t dim_size = input_type.getDimSize(split_dim);
2239   if (ShapedType::isDynamic(dim_size)) return success();
2240 
2241   // If 'size_splits' is not a constant, there are no other checks.
2242   ElementsAttr size_splits_attr;
2243   if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
2244     return success();
2245 
2246   if (size_splits_attr.getNumElements() != num_splits) {
2247     auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
2248     RankedTensorType expected_size_splits_type =
2249         RankedTensorType::get({num_splits}, size_splits_type.getElementType());
2250     return op.emitOpError("'size_splits' should be ")
2251            << expected_size_splits_type;
2252   }
2253 
2254   // Normalizes and verifies 'size_splits'.
2255   // Note: TensorFlow allows one -1 element in 'size_splits'.  The -1 element
2256   // means the rest of the dimension size.
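  // Illustrative example: with dim_size == 10 and size_splits == [2, -1, 3],
  // the -1 entry is resolved to 10 - (2 + 3) = 5.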
2257   llvm::SmallVector<int64_t, 4> size_splits;
2258   size_splits.reserve(num_splits);
2259 
2260   int64_t negative_size_split_loc = -1;
2261   int64_t total_size_splits = 0;
2262 
2263   for (int64_t i = 0; i < num_splits; ++i) {
2264     auto size_split_attr = size_splits_attr.getValue<IntegerAttr>(i);
2265     int64_t size_split = size_split_attr.getValue().getSExtValue();
2266     size_splits.push_back(size_split);
2267     if (size_split >= 0) {
2268       total_size_splits += size_split;
2269       continue;
2270     }
2271     if (size_split < -1)
2272       return op.emitOpError(
2273           "elements of 'size_splits' should be greater than or equal to -1");
2274     if (negative_size_split_loc != -1)
2275       return op.emitOpError("'size_splits' can only have one -1");
2276     negative_size_split_loc = i;
2277   }
2278 
2279   if (negative_size_split_loc != -1) {
2280     if (total_size_splits > dim_size)
2281       return op.emitOpError(
2282           "sum of non-negative elements of 'size_splits' is greater than the "
2283           "dimension size of 'split_dim' axis");
2284     size_splits[negative_size_split_loc] = dim_size - total_size_splits;
2285     total_size_splits = dim_size;
2286   }
2287 
2288   if (total_size_splits != dim_size)
2289     return op.emitOpError(
2290         "sum of 'size_splits' should match the dimension size of 'split_dim' "
2291         "axis");
2292 
2293   // Verifies result tensor types.
2294   auto get_expected_output_type = [input_type, split_dim,
2295                                    &size_splits](int64_t i) {
2296     return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
2297                                              size_splits[i]);
2298   };
2299   return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
2300                                   get_expected_output_type);
2301 }
2302 
2303 //===----------------------------------------------------------------------===//
2304 // MeanOp
2305 //===----------------------------------------------------------------------===//
2306 
2307 // TODO(b/133854225): Implement shape inference to Mean
2308 
2309 //===----------------------------------------------------------------------===//
2310 // LSTMOp
2311 //===----------------------------------------------------------------------===//
2312 
2313 static LogicalResult Verify(LSTMOp op) {
2314   auto operands = op.GetStatefulOperands();
2315   if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
2316     return op.emitOpError("LSTMOp expected to have two stateful operands");
2317   }
2318 
2319   const auto input_type = op.input().getType().cast<ShapedType>();
2320   // Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
2321   // doesn't have static shape, we skip the shape check below.
2322   if (!input_type.hasStaticShape()) return success();
2323   // The input should be at least a 2-D tensor since it will go through a
2324   // fully connected layer.
2325   if (!input_type.hasRank() || input_type.getRank() < 2)
2326     return op.emitOpError(
2327         "the first input operand should have at least 2 dimensions.");
2328 
2329   const auto activation_state =
2330       op.input_activation_state().getType().cast<ShapedType>();
2331   const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
2332   const auto input_to_output_weights =
2333       op.input_to_output_weights().getType().cast<ShapedType>();
2334   const auto recurrent_to_output_weights =
2335       op.recurrent_to_output_weights().getType().cast<ShapedType>();
2336   if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
2337       input_to_output_weights.hasStaticShape() &&
2338       recurrent_to_output_weights.hasStaticShape()) {
2339     const int n_input = input_type.getDimSize(input_type.getRank() - 1);
2340     const int n_cell = input_to_output_weights.getDimSize(0);
2341     const int n_output = recurrent_to_output_weights.getDimSize(1);
2342     const int output_state_size = activation_state.getNumElements();
2343     const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
2344                                                   : input_type.getDimSize(1);
2345     const int state_size = cell_state.getNumElements();
2346 
2347     // Check if the dimension of the inputs matches.
2348     if ((output_state_size != n_batch * n_output) ||
2349         (state_size != n_batch * n_cell) ||
2350         (input_to_output_weights.getDimSize(1) != n_input) ||
2351         (recurrent_to_output_weights.getRank() != 2) ||
2352         (recurrent_to_output_weights.getDimSize(0) != n_cell) ||
2353         (input_to_output_weights.getRank() != 2)) {
2354       return op.emitOpError("inputs don't match with the dimensions.");
2355     }
2356 
2357     const bool is_layer_norm_lstm =
2358         !op.forget_layer_norm_coefficients().getType().isa<NoneType>();
2359     if (is_layer_norm_lstm) {
2360       const auto forget_layer_norm_coefficients =
2361           op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
2362       // If this lstm has layer normalization, this input value,
2363       // "forget_layer_norm_coefficients" should be a 1D tensor.
2364       if (!forget_layer_norm_coefficients.hasRank() ||
2365           forget_layer_norm_coefficients.getRank() != 1 ||
2366           forget_layer_norm_coefficients.getDimSize(0) != n_cell)
2367         return op.emitOpError(
2368             "coefficient inputs have more than 2 dimensions or "
2369             "don't match the dimension with input operand "
2370             "`input_to_output_weights`.");
2371     }
2372   }
2373 
2374   return success();
2375 }
2376 
2377 namespace {
2378 
2379 // Replaces the optional bias operands with a "none" type value if the bias
2380 // values are constant zeros.
2381 struct RemoveLSTMOpZeroBias : public OpRewritePattern<LSTMOp> {
2382   using OpRewritePattern<LSTMOp>::OpRewritePattern;
2383 
2384   LogicalResult matchAndRewrite(LSTMOp op,
2385                                 PatternRewriter &rewriter) const override {
2386     if (EqualsZero(op.input_gate_bias())) {
2387       auto none_value = rewriter.create<mlir::ConstantOp>(
2388           rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2389       op.input_gate_biasMutable().assign(none_value);
2390     }
2391 
2392     if (EqualsZero(op.projection_bias())) {
2393       auto none_value = rewriter.create<mlir::ConstantOp>(
2394           rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2395       op.projection_biasMutable().assign(none_value);
2396     }
2397 
2398     return success();
2399   }
2400 };
2401 
2402 }  // namespace
2403 
2404 void LSTMOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2405                                          MLIRContext *context) {
2406   results.insert<RemoveLSTMOpZeroBias>(context);
2407 }
2408 
2409 //===----------------------------------------------------------------------===//
2410 // UnidirectionalSequenceLSTMOp
2411 //===----------------------------------------------------------------------===//
2412 
2413 static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
2414   auto operands = op.GetStatefulOperands();
2415   if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
2416     return success();
2417   }
2418   return op.emitError(
2419       "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
2420 }
2421 
2422 LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes(
2423     MLIRContext *, Optional<Location>, ValueRange operands, DictionaryAttr attr,
2424     RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
2425   Value input = operands[0];
2426   auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
2427 
2428   Value output_state = operands[18];
2429   auto output_state_type =
2430       output_state.getType().dyn_cast_or_null<RankedTensorType>();
2431 
2432   if (input_type && input_type.hasRank() && input_type.getRank() != 3) {
2433     return failure();
2434   }
2435 
2436   if (output_state_type && output_state_type.hasRank() &&
2437       output_state_type.getRank() != 2) {
2438     return failure();
2439   }
2440 
2441   if (!input_type || !input_type.hasRank() || !output_state_type ||
2442       !output_state_type.hasRank()) {
2443     // We cannot infer the output shape since we don't know the input shape or
2444     // the output state shape. We will set the output shape as unranked.
2445     Type result_type;
2446     result_type = UnrankedTensorType::get(
2447         input.getType().cast<ShapedType>().getElementType());
2448     inferredReturnTypes.assign({result_type});
2449     return success();
2450   }
2451 
2452   // Default to non-time_major.
2453   bool time_majored = attr.getNamed("time_major").hasValue()
2454                           ? attr.getNamed("time_major")
2455                                 .getValue()
2456                                 .second.cast<BoolAttr>()
2457                                 .getValue()
2458                           : false;
2459 
2460   int batch =
2461       time_majored ? input_type.getDimSize(1) : input_type.getDimSize(0);
2462   int time = time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1);
2463   int n_output = output_state_type.getDimSize(1);
2464 
2465   // Build the output shape.
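  // E.g. (illustrative): with time_major == false, an input of shape
  // [batch=2, time=5, 40] and an output state of shape [2, 16] give a result
  // shape of [2, 5, 16]; with time_major == true the first two input
  // dimensions are interpreted as [time, batch] and the result is laid out as
  // [time, batch, 16].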
2466   SmallVector<int64_t, 3> output_shape;
2467   if (time_majored) {
2468     output_shape = {time, batch, n_output};
2469   } else {
2470     output_shape = {batch, time, n_output};
2471   }
2472   auto result_type =
2473       mlir::RankedTensorType::get(output_shape, input_type.getElementType());
2474 
2475   inferredReturnTypes.assign({result_type});
2476   return success();
2477 }
2478 
2479 bool UnidirectionalSequenceLSTMOp::isCompatibleReturnTypes(TypeRange lhs,
2480                                                            TypeRange rhs) {
2481   if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
2482   if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
2483   return true;
2484 }
2485 
2486 //===----------------------------------------------------------------------===//
2487 // BidirectionalSequenceLSTMOp
2488 //===----------------------------------------------------------------------===//
2489 
2490 static LogicalResult Verify(BidirectionalSequenceLSTMOp op) {
2491   auto operands = op.GetStatefulOperands();
2492   if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
2493       operands[2] == 37 && operands[3] == 38) {
2494     return success();
2495   }
2496   return op.emitError(
2497       "BidirectionalSequenceLSTMOp expected to have four stateful operands");
2498 }
2499 
2500 //===----------------------------------------------------------------------===//
2501 // UnidirectionalSequenceRNNOp
2502 //===----------------------------------------------------------------------===//
2503 
2504 static LogicalResult Verify(UnidirectionalSequenceRNNOp op) {
2505   auto operands = op.GetStatefulOperands();
2506   if (operands.size() == 1 && operands[0] == 4) {
2507     return success();
2508   }
2509   return op.emitError(
2510       "UnidirectionalSequenceRNNOp expected to have one stateful operand");
2511 }
2512 
2513 //===----------------------------------------------------------------------===//
2514 // SvdfOp
2515 //===----------------------------------------------------------------------===//
2516 
2517 static LogicalResult Verify(SVDFOp op) {
2518   auto operands = op.GetStatefulOperands();
2519   if (operands.size() == 1 && operands[0] == 4) {
2520     return success();
2521   }
2522   return op.emitError("SvdfOp expected to have one stateful operand");
2523 }
2524 
2525 //===----------------------------------------------------------------------===//
2526 // AbsOp
2527 //===----------------------------------------------------------------------===//
2528 
2529 OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
2530   Type result_type = getType();
2531   // Only constant folding for f32 tensors is implemented.
2532   if (!IsF32ShapedType(result_type)) return nullptr;
2533 
2534   auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); };
2535   return ConstFoldUnaryOp(result_type, operands[0], compute);
2536 }
2537 
2538 //===----------------------------------------------------------------------===//
2539 // NegOp
2540 //===----------------------------------------------------------------------===//
2541 
2542 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
2543   Type result_type = getType();
2544   // Only constant folding for f32 tensors is implemented.
2545   if (!IsF32ShapedType(result_type)) return nullptr;
2546 
2547   auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
2548   return ConstFoldUnaryOp(result_type, operands[0], compute);
2549 }
2550 
2551 //===----------------------------------------------------------------------===//
2552 // SinOp
2553 //===----------------------------------------------------------------------===//
2554 
2555 OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
2556   Type result_type = getType();
2557   // Only constant folding for f32 tensors is implemented.
2558   if (!IsF32ShapedType(result_type)) return nullptr;
2559 
2560   auto compute = [](APFloat value) -> APFloat {
2561     float f = value.convertToFloat();
2562     float result = std::sin(f);
2563     return APFloat(result);
2564   };
2565   return ConstFoldUnaryOp(result_type, operands[0], compute);
2566 }
2567 
2568 //===----------------------------------------------------------------------===//
2569 // CosOp
2570 //===----------------------------------------------------------------------===//
2571 
2572 OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
2573   Type result_type = getType();
2574   // Only constant folding for f32 tensors is implemented.
2575   if (!IsF32ShapedType(result_type)) return nullptr;
2576 
2577   auto compute = [](APFloat value) -> APFloat {
2578     float f = value.convertToFloat();
2579     float result = std::cos(f);
2580     return APFloat(result);
2581   };
2582   return ConstFoldUnaryOp(result_type, operands[0], compute);
2583 }
2584 
2585 //===----------------------------------------------------------------------===//
2586 // LogOp
2587 //===----------------------------------------------------------------------===//
2588 
2589 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
2590   Type result_type = getType();
2591   // Only constant folding for f32 tensors is implemented.
2592   if (!IsF32ShapedType(result_type)) return nullptr;
2593 
2594   auto compute = [](APFloat value) -> APFloat {
2595     float f = value.convertToFloat();
2596     float result = std::log(f);
2597     return APFloat(result);
2598   };
2599   return ConstFoldUnaryOp(result_type, operands[0], compute);
2600 }
2601 
2602 //===----------------------------------------------------------------------===//
2603 // ShapeOp
2604 //===----------------------------------------------------------------------===//
2605 
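// Folds tfl.shape of a statically shaped input into a constant; for example,
// the shape of a tensor<2x3xf32> folds to the 1-D constant [2, 3] (as i32 or
// i64, matching the result element type).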
2606 OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
2607   auto input_type = input().getType().cast<ShapedType>();
2608   if (!input_type.hasStaticShape()) return nullptr;
2609 
2610   ArrayRef<int64_t> shape = input_type.getShape();
2611   auto result_type = getType().cast<ShapedType>();
2612   if (result_type.getElementType().isInteger(64)) {
2613     return DenseElementsAttr::get<int64_t>(result_type, shape);
2614   } else if (result_type.getElementType().isInteger(32)) {
2615     SmallVector<int32_t, 4> shape_i32;
2616     shape_i32.reserve(shape.size());
2617     for (int64_t dim : shape) {
2618       shape_i32.push_back(dim);
2619     }
2620     return DenseElementsAttr::get<int32_t>(result_type, shape_i32);
2621   }
2622   return nullptr;
2623 }
2624 
2625 //===----------------------------------------------------------------------===//
2626 // SqrtOp
2627 //===----------------------------------------------------------------------===//
2628 
2629 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
2630   Type result_type = getType();
2631   // Only constant folding for f32 tensors is implemented.
2632   if (!IsF32ShapedType(result_type)) return nullptr;
2633 
2634   auto compute = [](APFloat value) -> APFloat {
2635     float f = value.convertToFloat();
2636     float result = std::sqrt(f);
2637     return APFloat(result);
2638   };
2639   return ConstFoldUnaryOp(result_type, operands[0], compute);
2640 }
2641 
2642 //===----------------------------------------------------------------------===//
2643 // RsqrtOp
2644 //===----------------------------------------------------------------------===//
2645 
2646 OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
2647   Type result_type = getType();
2648   // Only constant folding for f32/bf16 tensors is implemented.
2649   if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
2650     return nullptr;
2651 
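  // The value is converted to f32, 1 / sqrt(x) is computed in single
  // precision, and the result is converted back to the original semantics
  // (f32 or bf16).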
2652   auto compute = [](APFloat value) -> APFloat {
2653     bool loseInfo;
2654     const llvm::fltSemantics &original_float_semantics = value.getSemantics();
2655     value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
2656                   &loseInfo);
2657     float f = value.convertToFloat();
2658     APFloat result(1.f / std::sqrt(f));
2659     result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
2660                    &loseInfo);
2661     return result;
2662   };
2663   return ConstFoldUnaryOp(result_type, operands[0], compute);
2664 }
2665 
2666 //===----------------------------------------------------------------------===//
2667 // SquareOp
2668 //===----------------------------------------------------------------------===//
2669 
2670 OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
2671   Type result_type = getType();
2672   // Only constant folding for f32 tensors is implemented.
2673   if (!IsF32ShapedType(result_type)) return nullptr;
2674 
2675   auto compute = [](APFloat value) -> APFloat { return value * value; };
2676   return ConstFoldUnaryOp(result_type, operands[0], compute);
2677 }
2678 
2679 //===----------------------------------------------------------------------===//
2680 // RankOp
2681 //===----------------------------------------------------------------------===//
2682 
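// Folds tfl.rank to a constant when the input rank is known; for example, the
// rank of a tensor<?x4x4xf32> folds to the i32 constant 3.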
2683 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2684   assert(operands.size() == 1);
2685   auto result_type = getType().cast<ShapedType>();
2686   if (auto elements_attr = operands[0].dyn_cast_or_null<ElementsAttr>()) {
2687     auto rank = static_cast<int32_t>(elements_attr.getType().getRank());
2688     return DenseElementsAttr::get(result_type, {rank});
2689   }
2690 
2691   // Also fold if `input` has a known rank.
2692   auto input_type = input().getType().cast<ShapedType>();
2693   // Do not fold if rank is zero because the TFLite converter doesn't
2694   // distinguish between unranked input and scalar input due to b/138865275.
2695   // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
2696   // predicate and fold the op when rank is zero.
2697   if (input_type.hasRank() && input_type.getRank() != 0) {
2698     auto rank = static_cast<int32_t>(input_type.getRank());
2699     return DenseElementsAttr::get(result_type, {rank});
2700   }
2701 
2702   return nullptr;
2703 }
2704 
2705 //===----------------------------------------------------------------------===//
2706 // ConstOp
2707 //===----------------------------------------------------------------------===//
2708 
2709 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
2710   assert(operands.empty() && "constant has no operands");
2711   // Return the held attribute value.
2712   return value();
2713 }
2714 
2715 namespace {
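// Rewrites a tfl.pseudo_const (TFL::ConstOp) into a standard ConstantOp when
// the held attribute can be materialized by ConstantOp.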
2716 struct FoldPseudoConstOp : public OpRewritePattern<ConstOp> {
2717   using OpRewritePattern<ConstOp>::OpRewritePattern;
2718 
2719   LogicalResult matchAndRewrite(ConstOp const_op,
2720                                 PatternRewriter &rewriter) const override {
2721     if (!ConstantOp::isBuildableWith(const_op.value(), const_op.getType()))
2722       return failure();
2723     rewriter.replaceOpWithNewOp<ConstantOp>(const_op, const_op.value());
2724     return success();
2725   }
2726 };
2727 
2728 }  // namespace
2729 
2730 void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2731                                           MLIRContext *context) {
2732   results.insert<FoldPseudoConstOp>(context);
2733 }
2734 
2735 //===----------------------------------------------------------------------===//
2736 // CastOp
2737 //===----------------------------------------------------------------------===//
2738 
2739 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
2740   assert(operands.size() == 1);
2741   if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
2742     return input();
2743   }
2744 
2745   // For now, only supports cast between integer types.
2746   auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2747   if (!elements_attr) {
2748     return nullptr;
2749   }
2750 
2751   auto result_element_type =
2752       getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
2753   auto operand_element_type = input()
2754                                   .getType()
2755                                   .cast<ShapedType>()
2756                                   .getElementType()
2757                                   .dyn_cast<IntegerType>();
2758   // Returns nullptr if either result/operand element type is not integer.
2759   if (!result_element_type || !operand_element_type) {
2760     return nullptr;
2761   }
2762 
2763   const bool is_unsigned = operand_element_type.isUnsigned();
2764   const bool involves_bool = operand_element_type.getWidth() == 1 ||
2765                              result_element_type.getWidth() == 1;
2766   const int output_bitwidth = result_element_type.getWidth();
2767   // The integer cast op behaves like a C integer cast. Depending on the
2768   // signedness of the operand type, we determine whether or not sign
2769   // extension is needed.
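  // For example, casting an i8 value of -1 to i32 sign-extends to -1, while a
  // ui8 value of 255 zero-extends to 255; any non-zero value cast to i1 yields
  // true.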
2770   auto cast = [&](APInt value) {
2771     if (involves_bool) {
2772       // Handle boolean inputs and outputs explicitly, since they do not
2773       // follow the usual extension or truncation behavior.
2774       // A true input should be cast to 1, not the -1 that sign extension
2775       // would produce for signed outputs. Similarly, any non-zero input
2776       // should be cast to true; plain truncation to one bit would map even
2777       // numbers to `false`.
2777       return APInt(result_element_type.getWidth(), value != 0);
2778     }
2779     return is_unsigned ? value.zextOrTrunc(output_bitwidth)
2780                        : value.sextOrTrunc(output_bitwidth);
2781   };
2782 
2783   return elements_attr.mapValues(result_element_type, cast);
2784 }
2785 
2786 //===----------------------------------------------------------------------===//
2787 // SelectV2Op
2788 //===----------------------------------------------------------------------===//
2789 
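// Builds a SelectV2 op. When the shapes of `cond`, `x`, and `y` are all
// static, the result type is the broadcast of the condition shape with the
// broadcast shape of `x` and `y` (for example, a cond of tensor<2x1xi1> with
// tensor<2x3xf32> operands gives tensor<2x3xf32>); otherwise the result is
// unranked.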
2790 static void BuildSelectV2Op(Builder *builder, OperationState &result,
2791                             Value cond, Value x, Value y) {
2792   auto operand_type =
2793       OpTrait::util::getBroadcastedType(x.getType(), y.getType());
2794 
2795   if (!operand_type)
2796     emitError(result.location) << "non-broadcastable operands: " << x.getType()
2797                                << " and " << y.getType();
2798 
2799   bool has_static_cond_shape = false;
2800   bool has_static_operand_shape = false;
2801   ArrayRef<int64_t> cond_shape;
2802   ArrayRef<int64_t> operand_shape;
2803 
2804   if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
2805     if (shaped_type.hasStaticShape()) {
2806       has_static_cond_shape = true;
2807       cond_shape = shaped_type.getShape();
2808     }
2809   }
2810   if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
2811     if (shaped_type.hasStaticShape()) {
2812       has_static_operand_shape = true;
2813       operand_shape = shaped_type.getShape();
2814     }
2815   }
2816 
2817   SmallVector<int64_t, 4> broadcastedShape;
2818   if (has_static_cond_shape && has_static_operand_shape &&
2819       !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
2820                                           broadcastedShape)) {
2821     emitError(result.location) << "non-broadcastable operands: " << operand_type
2822                                << " and " << cond.getType();
2823   }
2824 
2825   result.addOperands({cond, x, y});
2826 
2827   auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
2828   if (has_static_cond_shape && has_static_operand_shape) {
2829     result.types.push_back(
2830         RankedTensorType::get(broadcastedShape, elementType));
2831   } else {
2832     result.types.push_back(UnrankedTensorType::get(elementType));
2833   }
2834 }
2835 
2836 //===----------------------------------------------------------------------===//
2837 // RangeOp
2838 //===----------------------------------------------------------------------===//
2839 
2840 namespace {
2841 
2842 // Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
2843 // Template parameter `FloatOrInt` must be a standard C integer or
2844 // floating-point type.
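// For example, GetLengthOfRange(0, 7, 2) == 4 (elements 0, 2, 4, 6), and
// GetLengthOfRange(0.0f, 1.0f, 0.3f) == 4 (elements 0.0, 0.3, 0.6, 0.9).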
2845 template <typename FloatOrInt>
2846 int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
2847   // Refer to the implementation in
2848   // tensorflow/lite/kernels/range.cc.
2849   return std::is_integral<FloatOrInt>::value
2850              ? ((std::abs(limit - start) + std::abs(delta) - 1) /
2851                 std::abs(delta))
2852              : std::ceil(std::abs((limit - start) / delta));
2853 }
2854 
2855 // Builds a constant range tensor of `result_elem_type` elements.
2856 // Template parameter `FloatOrIntAtrr` must be mlir::IntegerAttr or
2857 // mlir::FloatAttr.
2858 template <typename FloatOrIntAtrr>
2859 DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
2860                                         FloatOrIntAtrr start_attr,
2861                                         FloatOrIntAtrr delta_attr) {
2862   using ValueType = typename FloatOrIntAtrr::ValueType;  // APInt or APFloat
2863   ValueType start = start_attr.getValue();
2864   ValueType delta = delta_attr.getValue();
2865 
2866   SmallVector<ValueType, 16> new_values;
2867   new_values.reserve(num_elements);
2868   ValueType new_value = start;
2869   for (int i = 0; i < num_elements; ++i) {
2870     new_values.push_back(new_value);
2871     new_value = new_value + delta;
2872   }
2873   // Result is always a 1-D tensor.
2874   auto new_result_type =
2875       RankedTensorType::get({num_elements}, result_elem_type);
2876   return DenseElementsAttr::get(new_result_type, new_values);
2877 }
2878 }  // namespace
2879 
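// Folds tfl.range when `start`, `limit`, and `delta` are all constant scalars;
// for example, start = 1, limit = 7, delta = 2 folds to the 1-D constant
// [1, 3, 5].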
2880 OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
2881   assert(operands.size() == 3);
2882   auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
2883   auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
2884   auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
2885   if (start_tensor && limit_tensor && delta_tensor) {
2886     // Operands should all be scalars
2887     assert(start_tensor.getType().getRank() == 0 &&
2888            limit_tensor.getType().getRank() == 0 &&
2889            delta_tensor.getType().getRank() == 0);
2890     Type elem_type = getType().cast<ShapedType>().getElementType();
2891     if (elem_type.isSignlessInteger()) {
2892       auto start_attr = start_tensor.getValue<IntegerAttr>({});
2893       auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
2894       auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
2895       const int num_elements = GetLengthOfRange(
2896           start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt());
2897       return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2898                                    delta_attr);
2899     } else if (elem_type.isa<FloatType>()) {
2900       auto start_attr = start_tensor.getValue<FloatAttr>({});
2901       auto limit_attr = limit_tensor.getValue<FloatAttr>({});
2902       auto delta_attr = delta_tensor.getValue<FloatAttr>({});
2903       const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
2904                                                 limit_attr.getValueAsDouble(),
2905                                                 delta_attr.getValueAsDouble());
2906       return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2907                                    delta_attr);
2908     }
2909   }
2910 
2911   return nullptr;
2912 }
2913 
2914 //===----------------------------------------------------------------------===//
2915 // TransposeConvOp
2916 //===----------------------------------------------------------------------===//
2917 
2918 static LogicalResult Verify(TransposeConvOp op) {
2919   ShapedType output_type = op.output().getType().cast<ShapedType>();
2920   ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
2921   if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
2922     if (output_type.getRank() != output_shape_type.getDimSize(0)) {
2923       return op.emitOpError(llvm::formatv(
2924           "expect output type has rank = {0}, got output type {1}",
2925           output_shape_type.getDimSize(0), output_type));
2926     }
2927   }
2928 
2929   DenseIntElementsAttr output_shape_elements;
2930   if (!matchPattern(op.output_shape(), m_Constant(&output_shape_elements))) {
2931     return success();
2932   }
2933 
2934   llvm::SmallVector<int64_t, 4> output_shape;
2935   output_shape.reserve(output_shape_elements.getNumElements());
2936   for (auto dim : output_shape_elements.getValues<int>()) {
2937     output_shape.push_back(dim);
2938   }
2939 
2940   auto expected_output_type =
2941       RankedTensorType::get(output_shape, output_type.getElementType());
2942   if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
2943     return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
2944                                         expected_output_type, output_type));
2945   }
2946 
2947   return success();
2948 }
2949 
2950 int64_t TransposeConvOp::GetArithmeticCount(Operation *op) {
2951   int64_t count = -1;
2952   auto transpose_conv = llvm::dyn_cast<TransposeConvOp>(op);
2953   auto input_type = transpose_conv.input()
2954                         .getType()
2955                         .dyn_cast_or_null<mlir::RankedTensorType>();
2956   auto weight_type = transpose_conv.weights()
2957                          .getType()
2958                          .dyn_cast_or_null<mlir::RankedTensorType>();
2959   if (input_type && weight_type && input_type.hasStaticShape() &&
2960       weight_type.hasStaticShape()) {
2961     // Compute op count from the seven nested loops of
2962     // tflite::reference_ops::TransposeConv():
2963     count = 2 * input_type.getNumElements() * weight_type.getDimSize(0) *
2964             weight_type.getDimSize(1) * weight_type.getDimSize(2);
2965   }
2966 
2967   return count;
2968 }
2969 
2970 //===----------------------------------------------------------------------===//
2971 // StridedSliceOp
2972 //===----------------------------------------------------------------------===//
2973 
2974 LogicalResult Verify(StridedSliceOp op) {
2975   auto ranked_input_type = op.input().getType().dyn_cast<RankedTensorType>();
2976 
2977   // If input is unranked, there is nothing else to be verified.
2978   if (!ranked_input_type) return success();
2979   int num_input_dims = ranked_input_type.getRank();
2980 
2981   if (auto begin_type = op.begin().getType().dyn_cast<RankedTensorType>()) {
2982     if (begin_type.getRank() != 1) return failure();
2983     if (begin_type.getDimSize(0) > num_input_dims) return failure();
2984   }
2985 
2986   if (auto end_type = op.end().getType().dyn_cast<RankedTensorType>()) {
2987     if (end_type.getRank() != 1) return failure();
2988     if (end_type.getDimSize(0) > num_input_dims) return failure();
2989   }
2990 
2991   if (auto strides_type = op.strides().getType().dyn_cast<RankedTensorType>()) {
2992     if (strides_type.getRank() != 1) return failure();
2993     if (strides_type.getDimSize(0) > num_input_dims) return failure();
2994   }
2995 
2996   // The kernel will reshape the input tensor with the new axes; it only
2997   // supports the reshaped tensor up to 5D.
2998   uint32_t ellipsis_mask = op.ellipsis_mask();
2999   uint32_t new_axis_mask = op.new_axis_mask();
3000   int num_added_axis = 0;
3001   for (int i = 0; i < 8; ++i) {
3002     if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) {
3003       num_added_axis++;
3004     }
3005   }
3006   if (num_input_dims + num_added_axis > 5) return failure();
3007   return success();
3008 }
3009 
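// Folds an identity strided slice (all masks are 0, begin is all zeros,
// strides are all ones, and end equals the full static input shape) to its
// input.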
3010 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
3011   // Currently only support all masks being 0.
3012   if (begin_mask() != 0 || end_mask() != 0 || ellipsis_mask() != 0 ||
3013       new_axis_mask() != 0 || shrink_axis_mask() != 0)
3014     return {};
3015 
3016   auto input_type = input().getType().dyn_cast_or_null<RankedTensorType>();
3017   if (!input_type || !input_type.hasStaticShape()) return {};
3018 
3019   // Begin has to be all 0s.
3020   DenseIntElementsAttr begin_dense_elem_attr;
3021   if (!matchPattern(begin(), m_Constant(&begin_dense_elem_attr))) {
3022     return {};
3023   }
3024   for (auto begin_ele : begin_dense_elem_attr) {
3025     if (begin_ele.getSExtValue() != 0) {
3026       return {};
3027     }
3028   }
3029 
3030   // Strides has to be all 1s.
3031   DenseIntElementsAttr strides_dense_elem_attr;
3032   if (!matchPattern(strides(), m_Constant(&strides_dense_elem_attr))) {
3033     return {};
3034   }
3035   for (auto stride_ele : strides_dense_elem_attr) {
3036     if (stride_ele.getSExtValue() != 1) {
3037       return {};
3038     }
3039   }
3040   // End has to match the input shape.
3041   DenseIntElementsAttr end_dense_elem_attr;
3042   if (!matchPattern(end(), m_Constant(&end_dense_elem_attr))) {
3043     return {};
3044   }
3045   int i = 0;
3046   for (auto end_ele : end_dense_elem_attr) {
3047     if (end_ele.getSExtValue() != input_type.getDimSize(i)) {
3048       return {};
3049     }
3050     ++i;
3051   }
3052 
3053   return input();
3054 }
3055 
3056 //===----------------------------------------------------------------------===//
3057 // TransposeOp
3058 //===----------------------------------------------------------------------===//
3059 
3060 namespace {
3061 
3062 // Computes the permutation of a constant `input_tensor` according to `perm`.
3063 // The function recursively traverses the dimensions of the output tensor in
3064 // a row-major order and writes the values of the output tensor into
3065 // `new_values`.
3066 void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
3067                         ArrayRef<int64_t> output_shape, int num_dimensions,
3068                         int output_axis, std::vector<uint64_t> *input_indices,
3069                         std::vector<Attribute> *new_values) {
3070   // Refer to the implementation of `Transpose` function in
3071   // tensorflow/lite/kernels/internal/reference/reference_ops.h
3072   assert(output_axis < num_dimensions);
3073   const int input_axis = perm[output_axis];
3074   for (int i = 0; i < output_shape[output_axis]; ++i) {
3075     // Update the input indices on `input_axis`.
3076     input_indices->at(input_axis) = i;
3077     // Write the value from `input_tensor` if it is the last axis or
3078     // recurse into the next axis.
3079     const bool is_last_axis = output_axis == num_dimensions - 1;
3080     if (is_last_axis) {
3081       new_values->push_back(input_tensor.getValue(*input_indices));
3082     } else {
3083       ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3084                          output_axis + 1, input_indices, new_values);
3085     }
3086   }
3087 }
3088 
3089 }  // namespace
3090 
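// Constant-folds tfl.transpose when both the input and the permutation are
// constants; for example, a [2, 3] constant with perm = [1, 0] folds to its
// [3, 2] transpose.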
3091 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
3092   assert(operands.size() == 2);
3093   auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
3094   auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
3095   if (!input_tensor || !perm_tensor) return nullptr;
3096 
3097   // Do not try to fold elements attr of a quant type because
3098   // DenseElementsAttr does not support it.
3099   if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
3100     return nullptr;
3101 
3102   assert(perm_tensor.getType().getRank() == 1);
3103   const int num_dimensions = input_tensor.getType().getRank();
3104   assert(perm_tensor.getType().getNumElements() == num_dimensions);
3105 
3106   ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
3107   auto output_type = getType().cast<ShapedType>();
3108 
3109   SmallVector<int32_t, 4> perm;
3110   SmallVector<int64_t, 4> output_shape;
3111   for (int i = 0; i < num_dimensions; ++i) {
3112     perm.push_back(
3113         perm_tensor.getValue<IntegerAttr>({static_cast<uint64_t>(i)}).getInt());
3114     output_shape.push_back(input_shape[perm[i]]);
3115 
3116     // Check that the derived output shape matches the static shape.
3117     assert(!output_type.hasStaticShape() ||
3118            output_type.getShape()[i] == output_shape[i]);
3119   }
3120 
3121   std::vector<Attribute> new_values;
3122   new_values.reserve(input_tensor.getType().getNumElements());
3123   std::vector<uint64_t> input_indices(num_dimensions);
3124   ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3125                      /*output_axis=*/0, &input_indices, &new_values);
3126   auto result_type =
3127       RankedTensorType::get(output_shape, output_type.getElementType());
3128   return DenseElementsAttr::get(result_type, new_values);
3129 }
3130 
3131 static LogicalResult Verify(TransposeOp op) {
3132   auto input_type = op.input().getType().cast<ShapedType>();
3133   auto perm_type = op.perm().getType().cast<ShapedType>();
3134   auto output_type = op.output().getType().cast<ShapedType>();
3135   if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
3136     if (perm_type.getNumElements() != input_type.getRank()) {
3137       return op.emitOpError(
3138           "perm tensor elements size is not equal to input tensor rank");
3139     }
3140   }
3141 
3142   DenseIntElementsAttr perm;
3143   if (!matchPattern(op.perm(), m_Constant(&perm))) {
3144     return success();
3145   }
3146 
3147   int index = 0;
3148   llvm::SmallVector<int64_t, 4> axes;
3149   for (const auto &axis_int : perm.getValues<APInt>()) {
3150     const int64_t axis = axis_int.getSExtValue();
3151     if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
3152       return op.emitOpError(
3153           llvm::formatv("perm[{0}] must be in [0, rank)", index));
3154     }
3155     if (std::count(axes.begin(), axes.end(), axis) > 0) {
3156       return op.emitOpError(
3157           llvm::formatv("perm[{0}] cannot have duplicated axis", index));
3158     }
3159     axes.push_back(axis);
3160     index++;
3161   }
3162 
3163   if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
3164     llvm::SmallVector<int64_t, 4> transposed_shape;
3165     for (int64_t axis : axes) {
3166       transposed_shape.push_back(input_type.getDimSize(axis));
3167     }
3168     auto expected_output_type =
3169         RankedTensorType::get(transposed_shape, input_type.getElementType());
3170     if (failed(
3171             mlir::verifyCompatibleShape(output_type, expected_output_type))) {
3172       return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
3173                                           expected_output_type, output_type));
3174     }
3175   }
3176 
3177   return success();
3178 }
3179 
3180 static void BuildTransposeOp(OpBuilder *builder, OperationState &result,
3181                              Value input, Value perm) {
3182   // Output size is only known if input is ranked and perm is a constant.
3183   auto input_type = input.getType().cast<TensorType>();
3184   DenseIntElementsAttr perm_const;
3185   if (!input_type.hasRank() || !matchPattern(perm, m_Constant(&perm_const)) ||
3186       perm_const.getIntValues().empty()) {
3187     TFL::TransposeOp::build(
3188         *builder, result, UnrankedTensorType::get(input_type.getElementType()),
3189         input, perm);
3190     return;
3191   }
3192 
3193   const auto perm_value_it = perm_const.getIntValues().begin();
3194 
3195   const ArrayRef<int64_t> input_shape = input_type.getShape();
3196   SmallVector<int64_t, 4> output_shape(input_shape.size());
3197 
3198   for (int i = 0; i < output_shape.size(); ++i) {
3199     const APInt perm_val = perm_value_it[i];
3200     output_shape[i] = input_shape[perm_val.getSExtValue()];
3201   }
3202 
3203   TFL::TransposeOp::build(
3204       *builder, result,
3205       RankedTensorType::get(output_shape, input_type.getElementType()), input,
3206       perm);
3207 }
3208 
3209 //===----------------------------------------------------------------------===//
3210 // IfOp
3211 //===----------------------------------------------------------------------===//
3212 
3213 /// Given the region at `index`, or the parent operation if `index` is None,
3214 /// return the successor regions. These are the regions that may be selected
3215 /// during the flow of control. `operands` is a set of optional attributes that
3216 /// correspond to a constant value for each operand, or null if that operand is
3217 /// not a constant.
3218 void IfOp::getSuccessorRegions(Optional<unsigned> index,
3219                                ArrayRef<Attribute> operands,
3220                                SmallVectorImpl<RegionSuccessor> &regions) {
3221   // The `then` and the `else` region branch back to the parent operation.
3222   if (index.hasValue()) {
3223     regions.push_back(RegionSuccessor(getResults()));
3224     return;
3225   }
3226 
3227   // Don't consider the else region if it is empty.
3228   Region *else_reg = &else_region();
3229   if (else_reg->empty()) else_reg = nullptr;
3230 
3231   // Otherwise, the successor is dependent on the condition.
3232   bool condition;
3233   if (auto cond_attr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
3234     condition = cond_attr.getValue().isOneValue();
3235   } else {
3236     // If the condition isn't constant, both regions may be executed.
3237     regions.push_back(RegionSuccessor(&then_region()));
3238     // If the else region does not exist, it is not a viable successor.
3239     if (else_reg) regions.push_back(RegionSuccessor(else_reg));
3240     return;
3241   }
3242 
3243   // Add the successor regions using the condition.
3244   regions.push_back(RegionSuccessor(condition ? &then_region() : else_reg));
3245 }
3246 
3247 //===----------------------------------------------------------------------===//
3248 // WhileOp
3249 //===----------------------------------------------------------------------===//
3250 
3251 LogicalResult Verify(WhileOp op) {
3252   if (op.getNumOperands() != op.getNumResults())
3253     return op.emitOpError(llvm::formatv(
3254         "number of operands does not match number of results ({0} != {1})",
3255         op.getNumOperands(), op.getNumResults()));
3256   // TODO(jpienaar): Verify operand, result & block arguments types
3257   return success();
3258 }
3259 
3260 namespace {
3261 // Canonicalize While op so that results and operands match and external values
3262 // are via implicit capture rather than via block args.
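// For example, an operand that is passed through the body unchanged (yielded
// as the same block argument it came in as) is replaced by the captured outer
// value, and block arguments unused in both regions are dropped together with
// the corresponding operands and results.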
3263 struct WhileResultOperandsMatchAndImplicitCapture
3264     : public OpRewritePattern<WhileOp> {
3265   using OpRewritePattern<WhileOp>::OpRewritePattern;
3266 
3267   LogicalResult matchAndRewrite(WhileOp while_op,
3268                                 PatternRewriter &rewriter) const override {
3269     // Replace values simply passed through the body with extern values
3270     // (in both body and condition regions as well as while result). The
3271     // block arguments of body and while match and so the corresponding cond
3272     // argument can be easily found.
3273     bool unchanged = true;
3274     auto &body_block = while_op.body().front();
3275     auto &cond_block = while_op.cond().front();
3276     auto &yield = *body_block.getTerminator();
3277     for (auto ba : body_block.getArguments()) {
3278       int arg_no = ba.getArgNumber();
3279       // Skip removing resources that are not read-only variables.
3280       if (getElementTypeOrSelf(ba.getType()).isa<TF::ResourceType>()) {
3281         bool has_read_only_variables = true;
3282         for (auto user : ba.getUsers()) {
3283           // Terminator ops, for example the tfl::yield op, should be ignored,
3284           // since the argument can be used for yielding as the `body` function
3285           // result, which does not give any meaningful information for deciding
3286           // whether the given argument is a read-only variable or not.
3287           if (user->hasTrait<OpTrait::IsTerminator>()) continue;
3288           if (!llvm::isa<mlir::TF::ReadVariableOp>(user)) {
3289             has_read_only_variables = false;
3290             break;
3291           }
3292         }
3293         if (!has_read_only_variables) continue;
3294       }
3295       if (ba == yield.getOperand(arg_no)) {
3296         unchanged = false;
3297         auto value = while_op.getOperand(arg_no);
3298         ba.replaceAllUsesWith(value);
3299         cond_block.getArgument(arg_no).replaceAllUsesWith(value);
3300 
3301         // This could be relaxed and casts inserted.
3302         if (while_op.getResult(arg_no).getType() == value.getType())
3303           while_op.getResult(arg_no).replaceAllUsesWith(value);
3304       }
3305     }
3306 
3307     // The While op's operands and result types need to match.
3308     SmallVector<Value, 4> new_operands;
3309     SmallVector<Value, 4> new_body_yield;
3310     SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
3311     llvm::SmallVector<Type, 4> types;
3312     new_operands.reserve(while_op.getNumOperands());
3313     new_body_yield.reserve(while_op.getNumOperands());
3314     types.reserve(while_op.getNumOperands());
3315 
3316     // Remove block arguments not used in either cond or body. This keeps the
3317     // block arguments of body and cond matching.
3318     int arg_index = 0;
3319     for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
3320          ++while_index) {
3321       auto value = while_op.getOperand(while_index);
3322       if (body_block.getArgument(arg_index).use_empty() &&
3323           cond_block.getArgument(arg_index).use_empty() &&
3324           // Note: since we are not erasing results, need to use while_index
3325           // to check if the corresponding result is unused.
3326           while_op.getResult(while_index).use_empty()) {
3327         unchanged = false;
3328         body_block.eraseArgument(arg_index);
3329         cond_block.eraseArgument(arg_index);
3330 
3331         // Mark operand for removal.
3332         removed_operand[while_index] = true;
3333       } else {
3334         new_operands.push_back(value);
3335         new_body_yield.push_back(yield.getOperand(while_index));
3336         auto type = while_op.getResult(while_index).getType();
3337         types.push_back(type);
3338         ++arg_index;
3339       }
3340     }
3341 
3342     // Done if no values removed from blocks and operands & results match.
3343     if (unchanged) return failure();
3344 
3345     // Replace with new While with matching operands and results.
3346     Operation *op = while_op.getOperation();
3347     Operation *new_op = rewriter.insert(
3348         Operation::create(op->getLoc(), op->getName(), types, new_operands,
3349                           op->getAttrs(), {}, /*numRegions=*/2));
3350 
3351     for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
3352     int new_index = 0;
3353     for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
3354       if (removed_operand[op_index]) continue;
3355       op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
3356       ++new_index;
3357     }
3358     rewriter.eraseOp(op);
3359 
3360     Block &new_body_block = cast<WhileOp>(new_op).body().front();
3361     rewriter.setInsertionPointToEnd(&new_body_block);
3362     rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
3363                                          new_body_yield);
3364 
3365     return success();
3366   }
3367 };
3368 
3369 }  // namespace
3370 
3371 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3372                                           MLIRContext *context) {
3373   results.insert<WhileResultOperandsMatchAndImplicitCapture>(context);
3374 }
3375 
3376 Region &WhileOp::getLoopBody() { return body(); }
3377 
3378 bool WhileOp::isDefinedOutsideOfLoop(Value value) {
3379   // TODO(jpienaar): This is overly conservative and disables anything other
3380   // than constant hoisting initially.
3381   return false;
3382 }
3383 
3384 LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
3385   if (ops.empty()) return success();
3386 
3387   // Move the hoisted value to just before the while.
3388   Operation *while_op = this->getOperation();
3389   for (auto op : ops) op->moveBefore(while_op);
3390 
3391   return success();
3392 }
3393 
3394 //===----------------------------------------------------------------------===//
3395 // LogisticOp
3396 //===----------------------------------------------------------------------===//
3397 
3398 int64_t LogisticOp::GetArithmeticCount(Operation *op) {
3399   int64_t count;
3400   // As a very rough ballpark, the cost of evaluating a math function
3401   // such as tanh or logistic is about 32 multiplications, and about as
3402   // many additions/subtractions. (Just a power-of-two order-of-magnitude
3403   // from looking at actual implementations that we use in runtime/code).
3404   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3405     return 64 * count;
3406 
3407   return -1;
3408 }
3409 
3410 //===----------------------------------------------------------------------===//
3411 // LogSoftmaxOp
3412 //===----------------------------------------------------------------------===//
3413 
3414 int64_t LogSoftmaxOp::GetArithmeticCount(Operation *op) {
3415   int64_t count;
3416   // As a very rough ballpark, the cost of evaluating a math function
3417   // such as tanh or logistic is about 32 multiplications, and about as
3418   // many additions/subtractions. (Just a power-of-two order-of-magnitude
3419   // from looking at actual implementations that we use in runtime/code).
3420   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3421     return 64 * count;
3422 
3423   return -1;
3424 }
3425 
3426 //===----------------------------------------------------------------------===//
3427 // SoftmaxOp
3428 //===----------------------------------------------------------------------===//
3429 
3430 int64_t SoftmaxOp::GetArithmeticCount(Operation *op) {
3431   int64_t count;
3432   // As a very rough ballpark, the cost of evaluating a math function
3433   // such as tanh or logistic is about 32 multiplications, and about as
3434   // many additions/subtractions. (Just a power-of-two order-of-magnitude
3435   // from looking at actual implementations that we use in runtime/code).
3436   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3437     return 64 * count;
3438 
3439   return -1;
3440 }
3441 
3442 //===----------------------------------------------------------------------===//
3443 // TanhOp
3444 //===----------------------------------------------------------------------===//
3445 
3446 int64_t TanhOp::GetArithmeticCount(Operation *op) {
3447   int64_t count;
3448   // As a very rough ballpark, the cost of evaluating a math function
3449   // such as tanh or logistic is about 32 multiplications, and about as
3450   // many additions/subtractions. (Just a power-of-two order-of-magnitude
3451   // from looking at actual implementations that we use in runtime/code).
3452   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3453     return 64 * count;
3454 
3455   return -1;
3456 }
3457 
3458 //===----------------------------------------------------------------------===//
3459 // AddNOp
3460 //===----------------------------------------------------------------------===//
3461 
3462 int64_t AddNOp::GetArithmeticCount(Operation *op) {
3463   int64_t count;
3464   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3465     // AddN cost is roughly the same cost as N-1 Adds.
3466     const int64_t num_adds = op->getNumOperands() - 1;
3467     return num_adds * count;
3468   }
3469 
3470   return -1;
3471 }
3472 
3473 //===----------------------------------------------------------------------===//
3474 // AveragePool2DOp
3475 //===----------------------------------------------------------------------===//
3476 
3477 int64_t AveragePool2DOp::GetArithmeticCount(Operation *op) {
3478   int64_t count;
3479   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3480     auto avg_pool = llvm::dyn_cast<AveragePool2DOp>(op);
3481     return avg_pool.filter_height() * avg_pool.filter_width() * count;
3482   }
3483 
3484   return -1;
3485 }
3486 
3487 //===----------------------------------------------------------------------===//
3488 // MaxPool2DOp
3489 //===----------------------------------------------------------------------===//
3490 
3491 int64_t MaxPool2DOp::GetArithmeticCount(Operation *op) {
3492   int64_t count;
3493   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3494     auto max_pool = llvm::dyn_cast<MaxPool2DOp>(op);
3495     return max_pool.filter_height() * max_pool.filter_width() * count;
3496   }
3497 
3498   return -1;
3499 }
3500 
3501 //===----------------------------------------------------------------------===//
3502 // L2NormalizationOp
3503 //===----------------------------------------------------------------------===//
3504 
3505 int64_t L2NormalizationOp::GetArithmeticCount(Operation *op) {
3506   int64_t count;
3507   // Computing the squared L2 norm is N multiply-adds so 2N ops,
3508   // then the single inverse-sqrt is negligible, then we multiply each
3509   // value by the resulting multiplier, so an extra N ops, for 3N ops in total.
3510   if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3511     return 3 * count;
3512   }
3513 
3514   return -1;
3515 }
3516 
3517 //===----------------------------------------------------------------------===//
3518 // PadOp
3519 //===----------------------------------------------------------------------===//
3520 
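// Folds tfl.pad to its input when the input and output shapes are identical,
// i.e. when all paddings are zero.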
3521 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
3522   if (InputOutputHasSameShape(input().getType(), output().getType()))
3523     return input();
3524 
3525   return {};
3526 }
3527 
3528 //===----------------------------------------------------------------------===//
3529 // PadV2Op
3530 //===----------------------------------------------------------------------===//
3531 
3532 OpFoldResult PadV2Op::fold(ArrayRef<Attribute> operands) {
3533   if (InputOutputHasSameShape(input().getType(), output().getType()))
3534     return input();
3535 
3536   return {};
3537 }
3538 
3539 //===----------------------------------------------------------------------===//
3540 // TableGen'd op method definitions
3541 //===----------------------------------------------------------------------===//
3542 
3543 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
3544 
3545 }  // namespace TFL
3546 }  // namespace mlir
3547 
3548 #define GET_OP_CLASSES
3549 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
3550 
3551 namespace mlir {
3552 namespace TFL {
3553 
3554 #include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
3555 
3556 Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
3557                                                       Attribute value,
3558                                                       Type type, Location loc) {
3559   // If this is an opaque elements attribute or the result type doesn't match
3560   // the attribute type, then generate a tfl.pseudo_const.
3561   if (value.isa<OpaqueElementsAttr>() ||
3562       (value.isa<ElementsAttr>() && value.getType() != type))
3563     return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
3564   if (ConstantOp::isBuildableWith(value, type))
3565     return builder.create<ConstantOp>(loc, type, value);
3566   return nullptr;
3567 }
3568 
3569 }  // namespace TFL
3570 }  // namespace mlir
3571