/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>

#include "third_party/eigen3/Eigen/Core"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
#include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/kernel_shape_util.h"

namespace mlir {
namespace TFL {
namespace {

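// Returns the BroadcastArgs op that computes the broadcast shape used by the
// (TF or TFL) BroadcastTo op defining `operand`, or nullptr if `operand` is
// not produced by such a chain.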
Operation *getDefiningBroadcastArgsOp(Value operand) {
  auto *defining_op = operand.getDefiningOp();
  if (!llvm::dyn_cast_or_null<TF::BroadcastToOp>(defining_op) &&
      !llvm::dyn_cast_or_null<TFL::BroadcastToOp>(defining_op)) {
    return nullptr;
  }

  Value broadcast_shape = defining_op->getOperand(
      1);  // Broadcasted shape operand of BroadcastTo op.
  Operation *parent_of_defining_op = broadcast_shape.getDefiningOp();
  if (!llvm::dyn_cast_or_null<TF::BroadcastArgsOp>(parent_of_defining_op) &&
      !llvm::dyn_cast_or_null<TFL::BroadcastArgsOp>(parent_of_defining_op)) {
    return nullptr;
  }
  return parent_of_defining_op;
}

}  // namespace

// Returns true when the given operand arguments have the same shape or a
// broadcastable shape within the given rank. If any of the given shapes is
// non-static and the maximum rank is within the given rank, this method
// returns true.
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
    Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
  if (indices.empty()) return true;

  // First, check whether any of the inputs has an unknown rank.
  bool has_unknown_shape_input = false;
  bool has_same_shape = true;
  bool reach_first_known_shape = false;
  int64_t max_rank = -1;

  ArrayRef<int64_t> pivot_shape;
  SmallVector<int64_t, 4> current_shape;
  SmallVector<int64_t, 4> result_shape;

  for (unsigned index : indices) {
    ShapedType shaped_type =
        op->getOperand(index).getType().dyn_cast<ShapedType>();
    if (!shaped_type || !shaped_type.hasRank()) {
      // Marks that we have an unknown rank input.
      has_unknown_shape_input = true;
      continue;
    }
    max_rank = std::max(max_rank, shaped_type.getRank());
    if (!shaped_type.hasStaticShape()) {
      // Marks that we have an unknown shape input.
      has_unknown_shape_input = true;
      continue;
    }

    ArrayRef<int64_t> shape = shaped_type.getShape();
    if (!reach_first_known_shape) {
      pivot_shape = shape;
      current_shape.assign(shape.begin(), shape.end());
      reach_first_known_shape = true;
      continue;
    }

    if (!pivot_shape.equals(shape)) {
      has_same_shape = false;
    }
    // Checks that the inputs are broadcastable, since they do not all have
    // the same shape.
    if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
                                            result_shape)) {
      return false;
    }
    current_shape = result_shape;
  }

  // If all shapes are known and the same, CPU kernels are able to handle
  // inputs regardless of dimension size.
  if (!has_unknown_shape_input) {
    return has_same_shape || max_rank <= max_bcast_rank;
  }

  // Treat unknown shape inputs as acceptable inputs for model compatibility
  // if all known ranks are no bigger than the allowed broadcast maximum rank.
  if (max_rank <= max_bcast_rank) {
    return true;
  }

  // Checks whether all operands are broadcast by BroadcastTo ops whose shape
  // is calculated from the same BroadcastArgs op. In that case, all operands
  // will have the same shape.
  Operation *broadcast_args_pivot = nullptr;
  for (unsigned index : indices) {
    Operation *parent_broadcast_args =
        getDefiningBroadcastArgsOp(op->getOperand(index));
    if (parent_broadcast_args == nullptr) {
      return false;
    }

    if (broadcast_args_pivot == nullptr) {
      broadcast_args_pivot = parent_broadcast_args;
      continue;
    }

    if (broadcast_args_pivot != parent_broadcast_args) {
      return false;
    }
  }
  return true;
}

// Return true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         quantized_type.isSigned();
}

// Return true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 8 &&
         !quantized_type.isSigned();
}

// Return true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
  auto quantized_type = element_type.dyn_cast<QuantizedType>();
  return quantized_type != nullptr &&
         quantized_type.getStorageTypeIntegralWidth() == 16 &&
         quantized_type.isSigned();
}

// Return true when the given element_type is I32.
bool IsI32Type(Type element_type) {
  return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}

// Return true when the given element_type is I64.
bool IsI64Type(Type element_type) {
  return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}

// Return true if the value is a splat tensor constant zero.
bool EqualsZero(Value value) {
  DenseElementsAttr constant;
  if (!matchPattern(value, m_Constant(&constant)) || !constant.isSplat()) {
    return false;
  }

  Type element_type = value.getType().cast<ShapedType>().getElementType();
  if (element_type.isa<FloatType>()) {
    return constant.getSplatValue<APFloat>().isZero();
  } else {
    return false;
  }
}

// Replaces the bias operand with a "none" type value if the bias value is
// constant zero.
// `ConcreteOpType` must be a concrete MLIR op class that has an optional
// bias operand named 'bias'.
template <typename ConcreteOpType>
struct RemoveOptionalZeroBias : public OpRewritePattern<ConcreteOpType> {
  using OpRewritePattern<ConcreteOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcreteOpType op,
                                PatternRewriter &rewriter) const override {
    if (EqualsZero(op.bias())) {
      auto none_value = rewriter.create<mlir::ConstantOp>(
          rewriter.getUnknownLoc(), rewriter.getUnitAttr());
      op.biasMutable().assign(none_value);
    }

    return success();
  }
};

// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, QI8, QUI8, I32 and I64 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to four dimensions or the same
  // shapes.
  if (element_type.isF32() || IsQI8Type(element_type) ||
      IsQUI8Type(element_type) || IsI32Type(element_type) ||
      IsI64Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows QI16 output when operands have the same shape.
  if (IsQI16Type(element_type)) {
    return succeeded(
        mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
  }
  return false;
}

// Return true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, I32, I64, QUI8 and QI16 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to five dimensions or the same
  // shapes.
  if (element_type.isF32() || IsI32Type(element_type) ||
      IsI64Type(element_type) || IsQUI8Type(element_type) ||
      IsQI16Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
  }

  // Allows QI8 output when the operands have valid shapes, which are
  // broadcastable shapes up to four dimensions or the same shapes.
  if (IsQI8Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

// Return true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows QI8 and QUI8 outputs with broadcasting up to four dimensions,
  // unless the lhs operand is QI16, in which case only same-shape operands
  // are allowed.
  if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
    if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
      return succeeded(
          mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
    }
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }

  // Allows I32, I64, QI16 and F32 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to four dimensions or the same
  // shapes.
  if (IsI32Type(element_type) || IsI64Type(element_type) ||
      IsQI16Type(element_type) || element_type.isF32()) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//

struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Allow all call operations to be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    // No TFLite op restricts inlining today; revise as needed in the future.
    return true;
  }
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    return isa<WhileOp>(dest->getParentOp());
  }
};

struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface {
  using DialectFoldInterface::DialectFoldInterface;

  // Registered hook to check if the given region, which is attached to an
  // operation that is *not* isolated from above (i.e. no internal regions
  // reference values defined in an enclosing region), should be used when
  // materializing constants.
  // In the TFLite dialect we materialize inside while regions, as that is
  // slightly more efficient computationally.
  bool shouldMaterializeInto(Region *region) const final {
    return isa<WhileOp>(region->getParentOp());
  }
};

TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context)
    : Dialect(/*name=*/"tfl", context, TypeID::get<TensorFlowLiteDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
      >();
  addInterfaces<TensorFlowLiteInlinerInterface,
                TensorFlowLiteDialectFoldInterface>();
}

//===----------------------------------------------------------------------===//
// Common support logic
//===----------------------------------------------------------------------===//

namespace {

// Returns true if the dimensions in `a` are a suffix of the ones in `b`.
// For example, dimensions {2}, {1, 2}, and {3, 1, 2} are all suffixes to
// {5, 4, 3, 1, 2}, while {1}, {5, 4}, and {1, 3, 2} are all not.
inline bool IsTrailingDimensions(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  if (a.size() > b.size()) return false;

  return std::equal(a.rbegin(), a.rend(), b.rbegin());
}

// Returns true if it is a shaped type of f32 elements.
inline bool IsF32ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isF32();
  }
  return false;
}

// Returns true if it is a shaped type of bf16 elements.
inline bool IsBF16ShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isBF16();
  }
  return false;
}

// Returns a new shape of rank 'new_dims', padded with ones on the left if
// needed.
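// For example, GetPaddedShape({2, 3}, /*new_dims=*/4) yields {1, 1, 2, 3}.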
inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
                                           int new_dims) {
  std::vector<int64_t> new_shape(new_dims, 1);
  std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
  return new_shape;
}

// Helper method that, given a 'current_index' representing an index in the
// broadcasted tensor, returns the index into the flat original tensor.
// 'shape' is the original shape padded with ones to match the result shape.
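// For example, with shape {1, 3} and current_index {2, 1} (an index into a
// broadcasted {4, 3} tensor), the returned flat index is
// (2 % 1) * 3 + (1 % 3) = 1.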
int64_t GetElementIndex(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &current_index) {
  int64_t ind = 0;
  int64_t mul = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    ind += (current_index[i] % shape[i]) * mul;
    mul *= shape[i];
  }
  return ind;
}

// Helper method that increments the index represented by 'current_index_ptr'
// within the shape of 'result_shape'.
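// For example, with result_shape {2, 3} the index advances in row-major
// order: {0, 0} -> {0, 1} -> {0, 2} -> {1, 0} -> ...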
void IncrementIndex(ArrayRef<int64_t> result_shape,
                    std::vector<int64_t> *current_index_ptr) {
  std::vector<int64_t> &current_index = *current_index_ptr;
  for (int i = result_shape.size() - 1; i >= 0; --i) {
    current_index[i]++;
    if (current_index[i] == result_shape[i]) {
      current_index[i] = 0;
    } else {
      break;
    }
  }
}

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes both operands are verified to have value
/// attributes of broadcastable types.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
                                      DenseElementsAttr rhs,
                                      const CalculationT &calculate) {
  auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
                  .dyn_cast_or_null<ShapedType>();
  if (!type) {
    return {};
  }

  const bool rhs_is_splat = rhs.isSplat();
  const bool lhs_is_splat = lhs.isSplat();

  // If both of them are splat, compute and return.
  if (lhs_is_splat && rhs_is_splat) {
    auto element_result = AttrElementT::get(
        type.getElementType(), calculate(lhs.getSplatValue<ElementValueT>(),
                                         rhs.getSplatValue<ElementValueT>()));
    if (!element_result) return {};

    return DenseElementsAttr::get(type, element_result);
  }

  auto num_elements = type.getNumElements();

  SmallVector<ElementValueT, 16> new_values;
  new_values.reserve(num_elements);
  const auto result_shape = type.getShape();
  std::vector<int64_t> current_index(type.getRank(), 0);
  // Create the new shapes with ones padded to the left.
  const std::vector<int64_t> lhs_new_shape =
      GetPaddedShape(lhs.getType().getShape(), type.getRank());
  const std::vector<int64_t> rhs_new_shape =
      GetPaddedShape(rhs.getType().getShape(), type.getRank());

  auto lhs_old_values = lhs.getValues<ElementValueT>();
  auto rhs_old_values = rhs.getValues<ElementValueT>();

  // Apply `calculate` to each pair of corresponding values in the dense
  // elements attributes.
  for (int64_t i = 0; i < num_elements; ++i) {
    // current_index represents the index in the N-dimensional result tensor.
    // GetElementIndex returns the index in the flat representation of the
    // original tensor to use.
    const int64_t lhs_index =
        lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
    const int64_t rhs_index =
        rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);

    new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index),
                                   *(rhs_old_values.begin() + rhs_index)));
    IncrementIndex(result_shape, &current_index);
  }
  return DenseElementsAttr::get(type, ArrayRef<ElementValueT>(new_values));
}

/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes the two operands are verified to have value
/// attributes of broadcastable types.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1,
                            Attribute operand2, const CalculationT &calculate) {
  if (operand1.dyn_cast_or_null<DenseElementsAttr>() &&
      operand2.dyn_cast_or_null<DenseElementsAttr>()) {
    return ConstFoldBinaryOpDenseDense<AttrElementT, ElementValueT>(
        result_type, operand1.cast<DenseElementsAttr>(),
        operand2.cast<DenseElementsAttr>(), calculate);
  }

  // TODO: support other attribute kinds

  return {};
}

/// Performs const folding with broadcast behavior on the two attributes in
/// `operands` and returns the result if possible.
/// Depending on the given `resultType`, either `floatCalculate` or
/// `intCalculate` is chosen to conduct the calculation.
Attribute ConstFoldBinaryOp(
    Type result_type, ArrayRef<Attribute> operands,
    llvm::function_ref<APFloat(APFloat, APFloat)> float_calculate,
    llvm::function_ref<APInt(APInt, APInt)> int_calculate) {
  // Note: All types are wrapped in tensor types in TFLite. E.g., f32 is
  // represented as tensor<f32>. So we are only handling tensor types here.
  auto type = result_type.dyn_cast<ShapedType>();
  if (!type) return {};

  auto elemType = type.getElementType();

  if (elemType.isa<FloatType>())
    return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
                                        float_calculate);

  if (elemType.isSignlessInteger())
    return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
                                          int_calculate);

  return {};
}

/// Performs const folding on the attribute `operand` with `calculate` and
/// returns the result if possible.
/// The function currently asserts that `result_type` is an f32 or bf16
/// tensor type.
/// TODO: Extend this function to handle integral tensors for ops like
/// "tfl.logical_not".
Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
                           llvm::function_ref<APFloat(APFloat)> calculate) {
  assert(IsF32ShapedType(result_type) || IsBF16ShapedType(result_type));
  auto result_shape_type = result_type.cast<ShapedType>();

  if (!result_shape_type.hasStaticShape()) return {};

  if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>()) {
    SmallVector<APFloat, 16> new_values;
    const int num_elements = result_shape_type.getNumElements();
    new_values.reserve(num_elements);

    for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
      new_values.push_back(calculate(old_value));
    }

    return DenseElementsAttr::get(result_shape_type, new_values);
  }

  return {};
}

void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
                          Value rhs) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();
  result.addOperands({lhs, rhs});
  // Comparison binary ops always return i1 tensor.
  if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
    auto result_shape = shaped_type.getShape();
    result.types.push_back(
        RankedTensorType::get(result_shape, builder->getI1Type()));
  } else {
    result.types.push_back(UnrankedTensorType::get(builder->getI1Type()));
  }
}

void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
                                  Value lhs, Value rhs,
                                  StringAttr fused_activation_function) {
  auto result_type =
      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());

  if (!result_type)
    emitError(result.location)
        << "non-broadcastable operands: " << lhs.getType() << " and "
        << rhs.getType();

  result.addOperands({lhs, rhs});
  result.addAttribute("fused_activation_function", fused_activation_function);
  result.types.push_back(result_type);
}

}  // end anonymous namespace

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
  // TODO(b/142478136): Handle fused ops.
  if (fused_activation_function() != "NONE") return {};
  return ConstFoldBinaryOp(
      getType(), operands, [](APFloat a, APFloat b) { return a + b; },
      [](APInt a, APInt b) { return a + b; });
}

int64_t AddOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// ConcatenationOp
//===----------------------------------------------------------------------===//
// TODO(ashwinm): Implement shape inference for Concatenation

namespace {

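// Returns the concatenation axis of the op, normalized to a non-negative
// value relative to the rank of the output type.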
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
  auto output_type = op.output().getType().cast<RankedTensorType>();
  int32_t axis = op.axis();
  if (axis < 0) axis += output_type.getRank();
  return axis;
}

// Verify operand types and the result type:
//
// 1. Operand type ranks must be equal to the output type rank.
//
// 2. Operand dimension sizes (except dimension `axis`) must be equal to
//    previously seen dimension sizes of the same dimension.
//
// 3. Sum of operand dimension sizes of the `axis` dimension must be equal to
//    the dimension size of the `axis` dimension of output.
//
// Note: If an operand has unranked tensor type or has dynamic dimension size,
// those dimensions will be skipped.
LogicalResult VerifyConcatenationOpTypes(Operation *op,
                                         RankedTensorType output_type,
                                         ArrayRef<TensorType> operand_types,
                                         int64_t axis) {
  const int64_t output_rank = output_type.getRank();

  constexpr int64_t kDynamicSize = -1;
  SmallVector<int64_t, 4> result_dim_sizes_loc(output_rank, -1);
  SmallVector<int64_t, 4> result_dim_sizes(output_type.getShape().begin(),
                                           output_type.getShape().end());
  result_dim_sizes[axis] = 0;

  auto FormatLoc = [&result_dim_sizes_loc](int64_t dim) {
    const int64_t loc = result_dim_sizes_loc[dim];
    if (loc == -1) return std::string("output");
    return llvm::formatv("operand #{0}", loc).str();
  };

  for (auto operand : llvm::enumerate(operand_types)) {
    auto operand_type = operand.value().dyn_cast<RankedTensorType>();
    if (!operand_type) {
      result_dim_sizes[axis] = kDynamicSize;
      continue;
    }

    const int64_t operand_rank = operand_type.getRank();
    if (operand_rank != output_rank)
      return op->emitOpError() << "rank of operand #" << operand.index()
                               << " must be equal to rank of output, expected "
                               << output_rank << ", got " << operand_rank;

    for (int64_t dim = 0; dim < output_rank; ++dim) {
      const int64_t operand_dim_size = operand_type.getDimSize(dim);
      const int64_t result_dim_size = result_dim_sizes[dim];

      if (dim == axis) {
        if (RankedTensorType::isDynamic(operand_dim_size) ||
            RankedTensorType::isDynamic(result_dim_size))
          result_dim_sizes[axis] = kDynamicSize;
        else
          result_dim_sizes[axis] += operand_dim_size;
        continue;
      }

      if (RankedTensorType::isDynamic(operand_dim_size)) continue;

      if (RankedTensorType::isDynamic(result_dim_size)) {
        result_dim_sizes[dim] = operand_dim_size;
        result_dim_sizes_loc[dim] = operand.index();
        continue;
      }

      if (result_dim_size != operand_dim_size)
        return op->emitOpError()
               << "dimension size of dimension #" << dim << " of operand #"
               << operand.index() << " must be equal to "
               << "dimension size of dimension #" << dim << " of "
               << FormatLoc(dim) << ", expected " << result_dim_size << ", got "
               << operand_dim_size;
    }
  }

  const int64_t output_concated_dim_size = output_type.getDimSize(axis);
  if (!RankedTensorType::isDynamic(output_concated_dim_size) &&
      !RankedTensorType::isDynamic(result_dim_sizes[axis]) &&
      result_dim_sizes[axis] != output_concated_dim_size)
    return op->emitOpError()
           << "dimension size of dimension #" << axis << " of output "
           << "must be equal to the sum of dimension sizes of dimension #"
           << axis << ", expected " << result_dim_sizes[axis] << ", got "
           << output_concated_dim_size;

  return success();
}

LogicalResult Verify(ConcatenationOp op) {
  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();

  // If the output type is unranked, there is nothing else to be verified.
  if (!output_type) return success();

  const int64_t axis = GetConcatenationOpAxis(op);
  if (axis < 0 || axis >= output_type.getRank())
    return op.emitOpError("concatenation dimension must be in [-rank, rank)");

  SmallVector<TensorType, 4> operand_types;
  for (Value operand : op.values())
    operand_types.push_back(operand.getType().cast<TensorType>());

  return VerifyConcatenationOpTypes(op.getOperation(), output_type,
                                    operand_types, axis);
}

// Returns true when all operands are instances of DenseElementsAttr and the
// output type has a static shape.
bool IsConcatenationOpConstFoldable(ConcatenationOp op,
                                    ArrayRef<Attribute> operands,
                                    RankedTensorType output_type,
                                    int64_t axis) {
  if (operands.empty()) return false;
  if (!output_type.hasStaticShape()) return false;
  if (axis < 0) return false;

  return llvm::all_of(operands, [](Attribute operand) {
    return operand && operand.isa<DenseElementsAttr>();
  });
}

DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
                                              RankedTensorType output_type,
                                              int64_t axis) {
  const auto outer_dims = output_type.getShape().take_front(axis);
  const int64_t outer_size = std::accumulate(
      outer_dims.begin(), outer_dims.end(), 1, std::multiplies<int64_t>());

  const auto base_inner_dims = output_type.getShape().drop_front(axis + 1);
  const int64_t base_inner_size =
      std::accumulate(base_inner_dims.begin(), base_inner_dims.end(), 1,
                      std::multiplies<int64_t>());

  // Splits each input operand into outer_size pieces and combines them in
  // round-robin ordering.
  std::vector<Attribute> out_attrs(output_type.getNumElements());
  int64_t out = 0;
  for (int64_t outer = 0; outer < outer_size; ++outer) {
    for (auto op : operands) {
      const int64_t dim_size =
          op.getType().cast<RankedTensorType>().getDimSize(axis);
      const int64_t inner_size = dim_size * base_inner_size;

      auto input_attrs = op.cast<DenseElementsAttr>().getValues<Attribute>();
      auto input_iter = input_attrs.begin() + outer * inner_size;
      for (int64_t inner = 0; inner < inner_size; ++inner)
        out_attrs[out++] = *input_iter++;
    }
  }

  return DenseElementsAttr::get(output_type, out_attrs);
}

}  // end anonymous namespace

OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
  if (fused_activation_function() == "NONE") {
    if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
      const int64_t axis = GetConcatenationOpAxis(*this);
      if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
        return ConstFoldConcatenateOpDense(operands, output_type, axis);
    }
  }

  // Remove all empty values.
  SmallVector<Value, 4> non_empty_values;
  for (Value value : this->values()) {
    const auto shaped_type = value.getType().cast<ShapedType>();
    if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
      continue;
    }
    non_empty_values.push_back(value);
  }

  // If all operands are non-empty, do nothing.
  if (non_empty_values.size() == getNumOperands()) return nullptr;

  // If only one input is non-empty, just return it as the result of folding.
  if (non_empty_values.size() == 1) {
    return non_empty_values[0];
  }

  // Otherwise, build a new concatenation op with non-empty values.
  mlir::OpBuilder builder(getOperation());
  auto new_concat = builder.create<TFL::ConcatenationOp>(
      getLoc(), getType(), non_empty_values,
      builder.getIntegerAttr(builder.getIntegerType(32), axis()),
      builder.getStringAttr(fused_activation_function()));
  return new_concat.getResult();
}

//===----------------------------------------------------------------------===//
// CustomOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CustomOp op) {
  OpaqueElementsAttr opaque_attr =
      op.custom_option().cast<OpaqueElementsAttr>();
  if (!opaque_attr.getType().hasStaticShape())
    return op.emitOpError("custom_option should have a static shape.");
  const int attribute_size = opaque_attr.getValue().size();
  if (attribute_size != opaque_attr.getType().cast<ShapedType>().getDimSize(0))
    return op.emitOpError(
        "custom_option should have the same length of content with shape.");
  return success();
}

//===----------------------------------------------------------------------===//
// CustomTfOp
//===----------------------------------------------------------------------===//

LogicalResult CustomTfOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    DictionaryAttr attr, RegionRange ranges,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  CustomTfOpAdaptor op(operands, attr, ranges);

  if (op.getRegions().empty()) return success();
  auto *real_op = &op.body().front().front();
  if (llvm::isa<TF::FakeQuantWithMinMaxArgsOp, TF::FakeQuantWithMinMaxVarsOp,
                TF::FakeQuantWithMinMaxVarsPerChannelOp>(real_op)) {
    Value input = *operands.begin();
    inferredReturnTypes.assign({input.getType()});
  }
  return success();
}

bool CustomTfOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.empty()) return true;
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

//===----------------------------------------------------------------------===//
// FullyConnectedOp
//===----------------------------------------------------------------------===//

LogicalResult Verify(FullyConnectedOp op) {
  ShapedType input_type = op.input().getType().cast<ShapedType>();
  ShapedType filter_type = op.filter().getType().cast<ShapedType>();
  if (filter_type.hasRank() && filter_type.getRank() != 2) {
    return op.emitOpError("expect 2d filter, got ") << filter_type;
  }

  if (!input_type.hasStaticShape() || !filter_type.hasStaticShape()) {
    return mlir::success();
  }

  // The input's number of elements must be a multiple of the filter's z_in
  // dimension.
  const int z_in = filter_type.getDimSize(1);
  const int num_input_elements = input_type.getNumElements();
  if (num_input_elements % z_in != 0) {
    return op.emitOpError(llvm::formatv(
               "expect 'input' num_elements % {0} == 0, got input type ", z_in))
           << input_type;
  }

  // TODO(jpienaar): Include more shape verification for SHUFFLED4x16INT8
  // format.
  if (op.weights_format() == "DEFAULT") {
    ShapedType output_type =
        (*op.output().begin()).getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) {
      return mlir::success();
    }

    const int num_output_elements = output_type.getNumElements();
    const int z_out = filter_type.getDimSize(0);
    if (num_output_elements % z_out != 0) {
      return op.emitOpError(llvm::formatv(
                 "expect 'output' num_elements % {0} == 0, got ", z_out))
             << output_type;
    }

    if (num_input_elements / z_in != num_output_elements / z_out) {
      return op.emitOpError(
          "num_input_elements / z_in != num_output_elements / z_out");
    }
  }

  return mlir::success();
}

LogicalResult FullyConnectedOp::fold(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  assert(operands.size() == 3);

  // Folding not implemented with any activation function or any weight type
  // besides the default.
  if (fused_activation_function() != "NONE") return failure();
  if (weights_format() != "DEFAULT") return failure();

  // Bias tensor is optional.
  const bool has_bias = !(!bias() || bias().getType().isa<NoneType>());

  // Get the tensors.
  DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
  if (!matchPattern(input(), m_Constant(&input_tensor)) ||
      !matchPattern(filter(), m_Constant(&weights_tensor)) ||
      (has_bias && !matchPattern(bias(), m_Constant(&bias_tensor)))) {
    return failure();
  }

  // Get the tensor types.
  const auto input_type = input_tensor.getType().cast<ShapedType>();
  const auto weights_type = weights_tensor.getType().cast<ShapedType>();
  const auto bias_type =
      has_bias ? bias_tensor.getType().cast<ShapedType>() : ShapedType{};

  const auto output_type = getType(0).cast<ShapedType>();

  // Folding only implemented for float tensors.
  if (!input_type.getElementType().isF32() ||
      !weights_type.getElementType().isF32() ||
      !output_type.getElementType().isF32() ||
      (has_bias && !bias_type.getElementType().isF32())) {
    return failure();
  }

  // Folding only implemented for static shapes
  if (!input_type.hasStaticShape() || !weights_type.hasStaticShape() ||
      (has_bias && !bias_type.hasStaticShape())) {
    return failure();
  }

  // Folding only implemented for 1D input, 2D weights and 1D bias
  if (input_type.getShape().size() != 1 ||
      weights_type.getShape().size() != 2 ||
      (has_bias && bias_type.getShape().size() != 1)) {
    return failure();
  }

  // Get the sizes
  const auto input_size = input_type.getNumElements();
  const auto output_size = output_type.getNumElements();

  // Get iterators to the tensors.
  const auto input_values_it = input_tensor.getValues<float>().begin();
  const auto weights_values_ptr = weights_tensor.getValues<float>().begin();
  auto weights_row_it = weights_values_ptr;
  // The 'else' case could be nullptr, but the types don't match.
  auto bias_values_it =
      has_bias ? bias_tensor.getValues<float>().begin() : input_values_it;

  // Do the actual folding, one output at a time.
  std::vector<float> result_values;
  result_values.reserve(output_size);

  for (int i = 0; i < output_size; ++i) {
    // Dot product with Kahan/Neumaier summation to minimize numeric errors.
    float sum = has_bias ? *bias_values_it : 0.0f;
    float compensation = 0.0f;
    for (int j = 0; j < input_size; ++j) {
      const float addend = input_values_it[j] * weights_row_it[j];
      const float new_sum = sum + addend;
      // DO NOT enable -funsafe-math-optimizations here.
      // There is a test detecting unsafe optimizations.
      // Unsafe math optimizations can reorder float formulas, and set the
      // compensation to constant 0. The formula must be evaluated as written
      // for the algorithm to work.
      // (Note: -ffast-math is a superset of -funsafe-math-optimizations.)
      if (std::abs(sum) >= std::abs(addend)) {
        compensation += (sum - new_sum) + addend;
      } else {
        compensation += (addend - new_sum) + sum;
      }
      sum = new_sum;
    }
    result_values.push_back(sum + compensation);
    weights_row_it += input_size;
    bias_values_it++;
  }

  // Set result tensor
  const auto folded =
      DenseElementsAttr::get(output_type, ArrayRef<float>(result_values));
  results.assign({folded});

  return success();
}

void FullyConnectedOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<RemoveOptionalZeroBias<FullyConnectedOp>>(context);
}

int64_t FullyConnectedOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// Conv2DOp
//===----------------------------------------------------------------------===//

void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.insert<RemoveOptionalZeroBias<Conv2DOp>>(context);
}

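// Computes the output size of one spatial dimension of a convolution, given
// the input size, filter size, dilation rate, stride and padding mode, using
// TensorFlow's windowed-output-size helper. Returns failure if the output
// size cannot be computed for the given configuration.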
static LogicalResult ComputeConvWindowedOutputSize(
    int64_t input_size, int64_t filter_size, int64_t dilation_rate,
    int64_t stride, tensorflow::Padding padding, int64_t *output_size) {
  int64_t pad_low;
  int64_t pad_high;

  tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
      input_size, filter_size, dilation_rate, stride, padding, output_size,
      &pad_low, &pad_high);
  // Return failure if expected_output_size could not be calculated.
  if (!status.ok()) return failure();
  return success();
}

LogicalResult Conv2DOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    DictionaryAttr attr, RegionRange,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Conv2DOpAdaptor op(operands, attr);

  const Value input = op.input();
  const Value filter = op.filter();

  const RankedTensorType input_ty =
      input.getType().dyn_cast_or_null<RankedTensorType>();
  const RankedTensorType filter_ty =
      filter.getType().dyn_cast_or_null<RankedTensorType>();
  // If both the input type and the filter type are ranked, check that their
  // ranks are valid.
  if ((input_ty && input_ty.hasRank() && input_ty.getRank() != 4) ||
      (filter_ty && filter_ty.hasRank() && filter_ty.getRank() != 4)) {
    return emitOptionalError(location, "Invalid ranks");
  }

  // If either input or filter is unranked, just return an unranked output
  // shape.
  if (!input_ty || !filter_ty || !input_ty.hasRank() || !filter_ty.hasRank()) {
    Type result_type;
    result_type = UnrankedTensorType::get(
        input.getType().cast<ShapedType>().getElementType());
    inferredReturnTypes.assign({result_type});
    return success();
  }

  auto stride_h = op.stride_h().getInt();
  auto stride_w = op.stride_w().getInt();
  auto dilation_h = op.dilation_h_factor().getInt();
  auto dilation_w = op.dilation_w_factor().getInt();

  // We don't have EXPLICIT PADDING in TFLite.
  auto paddings = op.padding().getValue();
  tensorflow::Padding padding;
  auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
  if (!padding_is_valid.ok()) {
    return emitOptionalError(location, "invalid padding format provided");
  }

  // The output always has rank 4. All dimensions are initialized to
  // dynamic size and can be partially inferred.
  // TFL's conv2d is always NHWC format & the filter is OHWI.
  SmallVector<int64_t, 4> return_shape(4, ShapedType::kDynamicSize);
  return_shape[0] = input_ty.getDimSize(0);
  return_shape[3] = filter_ty.getDimSize(0);

  // Spatial dimensions can be inferred only when both input and filter are
  // ranked because we need to get their spatial dimensions.

  // Height.
  if (!input_ty.isDynamicDim(1) && !filter_ty.isDynamicDim(1)) {
    int64_t output_height;
    if (failed(ComputeConvWindowedOutputSize(
            input_ty.getDimSize(1), filter_ty.getDimSize(1), dilation_h,
            stride_h, padding, &output_height))) {
      return failure();
    }
    return_shape[1] = output_height;
  }

  // Width.
  if (!input_ty.isDynamicDim(2) && !filter_ty.isDynamicDim(2)) {
    int64_t output_width;
    if (failed(ComputeConvWindowedOutputSize(
            input_ty.getDimSize(2), filter_ty.getDimSize(2), dilation_w,
            stride_w, padding, &output_width))) {
      return failure();
    }
    return_shape[2] = output_width;
  }

  auto result_type =
      mlir::RankedTensorType::get(return_shape, input_ty.getElementType());

  inferredReturnTypes.assign({result_type});
  return success();
}

bool Conv2DOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

int64_t Conv2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// DepthwiseConv2DOp
//===----------------------------------------------------------------------===//

void DepthwiseConv2DOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // TODO(b/180121750): Enable the pattern after the integration tests are
  // fixed.
  // results.insert<RemoveOptionalZeroBias<DepthwiseConv2DOp>>(context);
}

int64_t DepthwiseConv2DOp::GetArithmeticCount(Operation *op) {
  int64_t count;
  if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
          op, &count))
    return count;

  return -1;
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

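// Builds a TFL GatherOp and infers its result type from params, indices,
// axis, and batch_dims. If either params or indices is unranked, the result
// is unranked; otherwise the result shape is computed below, roughly
// params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:].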
static void BuildGatherOp(OpBuilder *builder, OperationState &result,
                          Value params, Value indices, IntegerAttr axis,
                          IntegerAttr batch_dims) {
  auto params_type = params.getType().cast<TensorType>();
  auto indices_type = indices.getType().cast<TensorType>();

  // If params/indices is unranked, then output is unranked.
  if (!params_type.hasRank() || !indices_type.hasRank())
    return TFL::GatherOp::build(
        *builder, result,
        UnrankedTensorType::get(params_type.getElementType()), params, indices,
        axis, batch_dims);

  int64_t params_rank = params_type.getRank();
  int64_t indices_rank = indices_type.getRank();

  // params rank is guaranteed to be at least 1.
  // Produces an output tensor with shape:
  // params.shape[:axis] + indices.shape + params.shape[axis + 1:]
  std::vector<int64_t> shape(params_type.getShape());
  int64_t axis_i = axis.getInt();

  // For neg axis values, we wrap around params, e.g. axis = -1 => params[:-1]
  if (axis_i < 0) {
    axis_i += params_rank;
  }

  // params must be at least rank axis + 1
  if (params_rank < axis_i + 1) {
    emitError(result.location, "params must be at least rank axis + 1");
  }

  int64_t batch_dims_i = batch_dims.getInt();
  if (batch_dims_i < 0) {
    batch_dims_i += indices_rank;
  }

  if (batch_dims_i > axis_i) {
    emitError(result.location,
              "axis should be bigger than or equal to batch_dims");
  }
  if (batch_dims_i >= params_rank || batch_dims_i > indices_rank) {
    emitError(result.location,
              "batch_dims must be smaller than params' rank and smaller than "
              "or equal to indices'rank");
  }
  for (int i = 0; i < batch_dims_i; ++i) {
    if (indices_type.getShape()[i] != params_type.getShape()[i]) {
      emitError(result.location,
                "batch dimensions of params must be equal to batch dimensions "
                "of indices");
    }
  }

  if ((indices_rank == 0) || (indices_rank == batch_dims_i)) {
    // Scalar indices (output is rank(params) - 1).
    // Erase shape[axis]
    shape.erase(shape.begin() + axis_i);
  } else if (indices_rank == 1) {
    // Vector indices (output is rank(params)).
    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  } else {
    // Higher rank indices (output is rank(params) + rank(indices) - 1).
    shape.resize(params_rank + indices_rank - 1 - batch_dims_i);
    // Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
    std::copy(std::begin(params_type.getShape()) + axis_i + 1,
              std::end(params_type.getShape()),
              std::begin(shape) + axis_i + indices_rank - batch_dims_i);

    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()) + batch_dims_i,
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  }

  TFL::GatherOp::build(
      *builder, result,
      RankedTensorType::get(shape, params_type.getElementType()), params,
      indices, axis, batch_dims);
}

//===----------------------------------------------------------------------===//
// ScatterNdOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(ScatterNdOp op) {
  auto indices = op.indices();
  auto updates = op.updates();
  auto shape = op.shape();
  auto output = op.output();

  auto updates_type = updates.getType().cast<ShapedType>();
  auto indices_type = indices.getType().cast<ShapedType>();

  if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) {
    return success();
  }

  // Checks if the shape of `updates` is a tensor of shape
  // `indices.shape[:-1] + shape[indices.shape[-1]:]`, as described in
  // ScatterNd op description.

  auto outer_dims = indices_type.getRank() - 1;
  auto outermost_dim = indices_type.getDimSize(outer_dims);
  // Checks whether the first `outer_dims` dimensions of `indices` and
  // `updates` are equal.
  for (auto i = 0; i < outer_dims; i++) {
    if (indices_type.getDimSize(i) != updates_type.getDimSize(i)) {
      return op.emitOpError()
             << "indices.Dims(" << i << ") == " << indices_type.getDimSize(i)
             << ", but updates.Dims(" << i
             << ") == " << updates_type.getDimSize(i);
    }
  }

  auto output_type = output.getType().cast<ShapedType>();
  auto shape_type = shape.getType().cast<ShapedType>();
  if (shape_type.hasStaticShape()) {
    // Check the rank of `shape`.
    auto output_rank = outermost_dim + updates_type.getRank() - outer_dims;
    if (shape_type.getDimSize(0) != output_rank) {
      return op.emitOpError()
             << "shape must be a vector of length " << output_rank;
    }
    if (output_type.hasRank()) {
      if (output_type.getRank() != output_rank) {
        return op.emitOpError()
               << "output must have the same rank with the length of shape = "
               << output_rank;
      }
    }
  }

  DenseIntElementsAttr shape_value;
  if (matchPattern(shape, m_Constant(&shape_value))) {
    for (const auto shape_elem : shape_value) {
      if (shape_elem.getSExtValue() <= 0) {
        return op.emitOpError("all elements of shape must be > 0");
      }
    }

    // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)`
    // dimensions of `updates` and `shape` are equal.
    for (auto shape_it : llvm::enumerate(shape_value)) {
      int64_t i = shape_it.index();
      auto value = shape_it.value().getSExtValue();
      if (i >= outermost_dim) {
        auto corresponding_dim = i - outermost_dim + outer_dims;
        if (value != updates_type.getDimSize(corresponding_dim)) {
          return op.emitOpError()
                 << "updates.Dims(" << i
                 << ") == " << updates_type.getDimSize(corresponding_dim)
                 << ", but shape[" << i << "] == " << value;
        }
      }
    }

    // Checks if the output has the shape specified by `shape`.
    if (output_type.hasStaticShape()) {
      for (auto shape_it : llvm::enumerate(shape_value)) {
        int i = shape_it.index();
        auto value = shape_it.value().getSExtValue();
        if (output_type.getDimSize(i) != value) {
          return op.emitOpError()
                 << "output shape [" << output_type.getShape()
                 << "] must be equal to the value of shape " << shape_value;
        }
      }
    }
  }
  return success();
}

1355 //===----------------------------------------------------------------------===//
1356 // MulOp
1357 //===----------------------------------------------------------------------===//
1358
fold(ArrayRef<Attribute> operands)1359 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1360 // TODO(b/142478136): Handle fused ops.
1361 if (fused_activation_function() != "NONE") return {};
1362
1363 // This function is performance critical for op fusion patterns, e.g.
1364 // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d.
1365 // So a few specializations are provided to evaluate the math operation
1366 // more efficiently.
1367
1368 // Specialization for f32 type.
1369 if (getType().cast<ShapedType>().getElementType().isF32()) {
1370 return ConstFoldBinaryOp<FloatAttr, float>(
1371 getType(), operands[0], operands[1],
1372 [](float a, float b) { return a * b; });
1373 }
1374
1375 // Specialization for bf16 type.
1376 if (getType().cast<ShapedType>().getElementType().isBF16()) {
1377 return ConstFoldBinaryOp<FloatAttr, Eigen::bfloat16>(
1378 getType(), operands[0], operands[1],
1379 [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; });
1380 }
1381
1382 // Generic fallback with APFloat
1383 return ConstFoldBinaryOp(
1384 getType(), operands, [](APFloat a, APFloat b) { return a * b; },
1385 [](APInt a, APInt b) { return a * b; });
1386 }
1387
1388 int64_t MulOp::GetArithmeticCount(Operation *op) {
1389 int64_t count;
1390 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1391
1392 return -1;
1393 }
1394
1395 //===----------------------------------------------------------------------===//
1396 // DivOp
1397 //===----------------------------------------------------------------------===//
1398
1399 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1400 // TODO(b/142478136): Handle fused ops.
1401 if (fused_activation_function() != "NONE") return {};
1402 return ConstFoldBinaryOp(
1403 getType(), operands, [](APFloat a, APFloat b) { return a / b; },
1404 [](APInt a, APInt b) { return a.sdiv(b); });
1405 }
1406
1407 int64_t DivOp::GetArithmeticCount(Operation *op) {
1408 int64_t count;
1409 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1410
1411 return -1;
1412 }
1413
1414 //===----------------------------------------------------------------------===//
1415 // PackOp
1416 //===----------------------------------------------------------------------===//
1417
1418 // TODO(b/133486129): Implement shape inference for pack
1419
1420 static LogicalResult Verify(PackOp op) {
1421 // TODO(antiagainst): Implement other checks as in
1422 // tensorflow/lite/kernels/pack.cc
1423
1424 if (op.getOperation()->getNumOperands() != op.values_count())
1425 return op.emitOpError("input count should match 'values_count' attribute");
1426
1427 Value operand0 = op.getOperand(0);
1428 auto input_type = operand0.getType().cast<ShapedType>();
1429
1430 // Check axis bounds.
1431 if (input_type.hasRank()) {
1432 int32_t axis_value = op.axis();
1433 if (axis_value < 0) axis_value += input_type.getRank() + 1;
1434 if (axis_value < 0 || axis_value >= input_type.getRank() + 1)
1435 return op.emitOpError()
1436 << "op attribute 'axis' should be in range [-rank - 1, rank + 1), "
1437 << "got rank = " << input_type.getRank()
1438 << ", and axis = " << op.axis();
1439 }
1440
1441 // Make sure all inputs have the same shape and element type.
1442 // TODO(b/135032063): Simplify once fixed.
1443 for (Type operand_type : op.getOperandTypes()) {
1444 if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
1445 return op.emitOpError("operands should be of the same type. got ")
1446 << input_type << ", " << operand_type;
1447 }
1448
1449 return success();
1450 }
1451
1452 //===----------------------------------------------------------------------===//
1453 // PReluOp
1454 //===----------------------------------------------------------------------===//
1455
1456 static LogicalResult Verify(PReluOp op) {
1457 auto input_type = op.input().getType().cast<ShapedType>();
1458 auto alpha_type = op.alpha().getType().cast<ShapedType>();
1459 auto output_type = op.output().getType().cast<ShapedType>();
1460
1461 if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
1462 if (input_type.getRank() != alpha_type.getRank() + 1) {
1463 return op.emitOpError("'alpha' should have one less rank than 'input'.");
1464 }
1465
1466 // Check if alpha is broadcastable
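    // For example (illustrative), an 'input' of type tensor<1x10x10x3xf32> is
    // compatible with an 'alpha' of type tensor<10x10x3xf32> or
    // tensor<1x1x3xf32>, but not with tensor<2x10x3xf32>.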
1467 for (int i = 0; i < alpha_type.getRank(); i++) {
1468 if (alpha_type.getDimSize(i) != input_type.getDimSize(i + 1) &&
1469 alpha_type.getDimSize(i) != 1) {
1470 return op.emitOpError(
1471 llvm::formatv("'alpha' is not broadcastable at dimension {0}.", i));
1472 }
1473 }
1474 }
1475
1476 if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
1477 if (input_type.getRank() != output_type.getRank()) {
1478 return op.emitOpError("'input' and 'output' should have the same rank.");
1479 }
1480
1481 // Check if input and output shapes are same
1482 for (int i = 0; i < input_type.getRank(); i++) {
1483 if (input_type.getDimSize(i) != output_type.getDimSize(i)) {
1484 return op.emitOpError(
1485 "'input' and 'output' should have the same shape.");
1486 }
1487 }
1488 }
1489 return success();
1490 }
1491
1492 //===----------------------------------------------------------------------===//
1493 // ReshapeOp
1494 //===----------------------------------------------------------------------===//
1495
1496 namespace {
1497 // This pattern matches and merges a tfl.reshape under the following
1498 // condition:
1499 // * The input's defining op is another tfl.reshape.
1500 // TODO(antiagainst): This pattern probably should be moved to the peephole
1501 // category, after we have the infra for peephole passes.
1502 struct RemoveAdjacentReshape : public RewritePattern {
1503   RemoveAdjacentReshape(MLIRContext *context)
1504 : RewritePattern(ReshapeOp::getOperationName(), 1, context) {}
1505
1506   LogicalResult match(Operation *op) const override {
1507 auto thisOp = cast<ReshapeOp>(op);
1508 auto prevOp = thisOp.getOperand(0).getDefiningOp();
1509 return isa_and_nonnull<ReshapeOp>(prevOp) ? success() : failure();
1510 }
1511
1512   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
1513 auto thisOp = cast<ReshapeOp>(op);
1514 auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
1515
1516 // Replace
1517 // %1 = "tfl.reshape"(%0, %shape0)
1518 // %2 = "tfl.reshape"(%1, %shape1)
1519 // With
1520 // %2 = "tfl.reshape"(%0, %shape1)
1521 rewriter.replaceOpWithNewOp<ReshapeOp>(
1522 op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
1523 }
1524 };
1525
1526 // The kernel expects a 1-D tensor for the shape operand if it is present. If
1527 // all the dimensions except the last one are '1's, the shape operand will be
1528 // reshaped to a 1-D tensor.
1529 // Note that this pattern doesn't check or change the content of the shape
1530 // tensor.
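// For example (illustrative), a shape operand that is a constant of type
// tensor<1x1x3xi32> with value [[[2, 3, 4]]] is rewritten into a constant of
// type tensor<3xi32> with value [2, 3, 4].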
1531 struct ConvertShapeTo1D : public OpRewritePattern<ReshapeOp> {
1532 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1533
1534   LogicalResult matchAndRewrite(ReshapeOp reshape,
1535 PatternRewriter &rewriter) const override {
1536 if (!reshape.shape().hasOneUse()) return failure();
1537
1538 DenseIntElementsAttr shape;
1539 if (!matchPattern(reshape.shape(), m_Constant(&shape))) {
1540 return failure();
1541 }
1542 // It is already a 1-D constant, no change.
1543 auto old_shape = shape.getType().getShape();
1544 if (old_shape.size() == 1) {
1545 return failure();
1546 }
1547     // Verify that all dimensions except the last one have length one.
1548 for (auto it = ++old_shape.rbegin(); it != old_shape.rend(); ++it) {
1549 if (*it != 1) {
1550 reshape->emitError(
1551 "Non-vector shape input is used, might cause runtime error");
1552 return failure();
1553 }
1554 }
1555 auto new_shape = shape.reshape(RankedTensorType::get(
1556 {*old_shape.rbegin()}, shape.getType().getElementType()));
1557 rewriter.replaceOpWithNewOp<TFL::ConstOp>(reshape.shape().getDefiningOp(),
1558 new_shape);
1559 return success();
1560 }
1561 };
1562
1563 bool InputOutputHasSameShape(mlir::Type input_type, mlir::Type output_type) {
1564 auto input_shaped_type = input_type.dyn_cast_or_null<ShapedType>();
1565 if (!input_shaped_type || !input_shaped_type.hasStaticShape()) return false;
1566
1567 auto output_shaped_type = output_type.dyn_cast_or_null<ShapedType>();
1568 if (!output_shaped_type || !output_shaped_type.hasStaticShape()) return false;
1569
1570 return input_shaped_type == output_shaped_type;
1571 }
1572
1573 } // end anonymous namespace
1574
1575 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
1576 // Remove identity reshape with both static result and input shape.
1577 auto result_type = getType().cast<ShapedType>();
1578 auto input_type = getOperand(0).getType().cast<ShapedType>();
1579 if (InputOutputHasSameShape(input_type, result_type)) return input();
1580
1581 // Constant folding
1582 if (auto dense_elements = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
1583     // If the result type isn't static, try to derive the result type from
1584     // the second operand, i.e. the shape.
1585 if (!result_type.hasStaticShape()) {
1586 auto shape_elements = operands[1].dyn_cast_or_null<DenseElementsAttr>();
1587 if (!shape_elements) return nullptr;
1588
1589 SmallVector<int64_t, 4> shape_data;
1590 for (const auto &it : shape_elements.getValues<APInt>()) {
1591 shape_data.push_back(it.getSExtValue());
1592 }
1593 result_type =
1594 RankedTensorType::get(shape_data, input_type.getElementType());
1595 }
1596 return dense_elements.reshape(result_type);
1597 }
1598
1599 return nullptr;
1600 }
1601
1602 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1603 MLIRContext *context) {
1604 results.insert<RemoveAdjacentReshape, ConvertShapeTo1D>(context);
1605 }
1606
1607 using ReshapeErrorHandler =
1608 llvm::function_ref<LogicalResult(const llvm::Twine &)>;
1609
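// Computes the output type of a reshape from the input value and the shape
// operand, reporting problems through `error_handler`. For example
// (illustrative), reshaping a tensor<2x3xf32> input with a constant shape
// value of [3, -1] yields tensor<3x2xf32>; the -1 (dynamic) entry is inferred
// so that the total number of elements is preserved.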
1610 LogicalResult GetReshapeOutputType(Value input, Value shape,
1611 ReshapeErrorHandler error_handler,
1612 TensorType &output_ty) {
1613 auto input_ty = input.getType().cast<TensorType>();
1614 auto element_ty = input_ty.getElementType();
1615 output_ty = UnrankedTensorType::get(element_ty);
1616
1617 auto shape_ty = shape.getType().dyn_cast<RankedTensorType>();
1618 if (!shape_ty) return success();
1619 if (shape_ty.getRank() != 1)
1620 return error_handler(llvm::formatv(
1621 "requires 'shape' to be rank 1, but got {0}", shape_ty.getRank()));
1622
1623 DenseIntElementsAttr shape_attr;
1624 if (!matchPattern(shape, m_Constant(&shape_attr))) {
1625 // If only shape of `shape` is known, return ranked but dynamic output
1626 // shape.
1627 if (shape_ty.hasStaticShape()) {
1628 llvm::SmallVector<int64_t, 8> dynamic_shape(shape_ty.getDimSize(0),
1629 ShapedType::kDynamicSize);
1630 output_ty = RankedTensorType::get(dynamic_shape, element_ty);
1631 }
1632 return success();
1633 }
1634
1635 // Detect if reshape output shape is folded.
1636 bool shape_ty_zero_dim = false;
1637 int unknown_index = -1;
1638 // The product of constant shape argument excluding unknown dimension.
1639 int64_t shape_ty_size = 1;
1640 llvm::SmallVector<int64_t, 8> output_ty_shape;
1641 output_ty_shape.reserve(shape_attr.getNumElements());
1642 for (const auto &dim : llvm::enumerate(shape_attr.getIntValues())) {
1643 const int64_t size = dim.value().getSExtValue();
1644 if (size == ShapedType::kDynamicSize) {
1645 if (unknown_index != -1)
1646 return error_handler(llvm::formatv(
1647 "requires 'shape' to have at most one dynamic dimension, but got "
1648 "multiple dynamic dimensions at indices {0} and {1}. You need to "
1649             "set up the unspecified size(s) to avoid this problem, for example, "
1650 "setting batch size in keras model or setting unspecified input "
1651 "size(s) with fixed ones.",
1652 unknown_index, dim.index()));
1653
1654 unknown_index = dim.index();
1655 } else if (size == 0) {
1656 shape_ty_zero_dim = true;
1657 } else if (size > 0) {
1658 shape_ty_size *= size;
1659 } else {
1660 return error_handler(
1661 llvm::formatv("requires 'shape' to have dimensions greater than -1, "
1662 "but got {0} at index {1}",
1663 size, dim.index()));
1664 }
1665 output_ty_shape.push_back(size);
1666 }
1667
1668 if (!input_ty.hasStaticShape()) {
1669 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1670 return success();
1671 }
1672
1673 // Compute the value of the unknown dimension.
1674 if (unknown_index != -1) {
1675 // Compute number of elements in tensor shape.
1676 int64_t input_ty_size = 1;
1677 bool input_ty_zero_dim = false;
1678 for (const auto &dim : input_ty.getShape()) {
1679 if (dim > 0 || !shape_ty_zero_dim) {
1680 input_ty_size *= dim;
1681 } else {
1682 input_ty_zero_dim = true;
1683 }
1684 }
1685
1686 const int64_t missing_dim = input_ty_size / shape_ty_size;
1687 if (!input_ty_zero_dim && shape_ty_size * missing_dim != input_ty_size)
1688 return error_handler(
1689 llvm::formatv("requires 'input' number of elements be a multiple of "
1690 "{0}, but got {1}",
1691 shape_ty_size, input_ty_size));
1692
1693 // Set the unknown dimension such that total number of elements remain
1694 // constant.
1695 output_ty_shape[unknown_index] = missing_dim;
1696 }
1697
1698 output_ty = RankedTensorType::get(output_ty_shape, element_ty);
1699
1700 return success();
1701 }
1702
1703 static LogicalResult Verify(ReshapeOp op) {
1704 auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult {
1705 return op.emitOpError() << message;
1706 };
1707 TensorType expected_ty;
1708 if (failed(GetReshapeOutputType(op.input(), op.shape(), error_handler,
1709 expected_ty)))
1710 return failure();
1711
1712 auto output_ty = op.getType().dyn_cast<RankedTensorType>();
1713 if (!output_ty) return success();
1714 auto input_ty = op.input().getType().cast<TensorType>();
1715 if (output_ty.hasStaticShape() && input_ty.hasStaticShape()) {
1716 const int64_t output_ty_size = output_ty.getNumElements();
1717 const int64_t input_ty_size = input_ty.getNumElements();
1718 if (input_ty_size != output_ty_size)
1719 return op.emitOpError() << "requires 'output' number of elements to "
1720 "match 'input' number of elements, but got "
1721 << output_ty_size << " and " << input_ty_size;
1722 }
1723
1724 if (!TF::AreCastCompatible({output_ty, expected_ty}))
1725 return op.emitOpError()
1726 << "requires 'output' type " << output_ty
1727 << " to be cast compatible with expected type " << expected_ty;
1728
1729 return success();
1730 }
1731
1732 //===----------------------------------------------------------------------===//
1733 // PackOp
1734 //===----------------------------------------------------------------------===//
1735
1736 // Remove redundant unpack/pack op pairs.
1737 // If an unpack op is followed by a pack op, we can remove the pack op. If the
1738 // unpack op is only consumed by the pack op, it will be removed as well.
1739 // An example illustration is:
1740 // Unpack [5, 8, 9], axis = 1
1741 // / \
1742 // value ... value [5, 9], 8 values in total
1743 // \ /
1744 // pack, axis = 1
1745 // |
1746 // value [5, 8, 9]
1747 //
1748 // This can actually be simplified into just:
1749 //
1750 // => Value [5, 8, 9]
1751 // TODO(b/133341698): Move to tablegen when variadic is supported.
1752 struct RemoveRedundantUnpackPack : public RewritePattern {
1753   explicit RemoveRedundantUnpackPack(MLIRContext *context)
1754 : RewritePattern(PackOp::getOperationName(), 2, context) {}
1755
1756   LogicalResult matchAndRewrite(Operation *op,
1757 PatternRewriter &rewriter) const override {
1758 TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1759 Operation *first_input = pack_op.getOperand(0).getDefiningOp();
1760 if (!first_input) return failure();
1761 auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
1762 if (!input_unpack_op) return failure();
1763
1764 // The unpack & pack should have the same axis & num inputs/outputs.
1765 if (pack_op.axis() != input_unpack_op.axis() ||
1766 pack_op.values_count() != input_unpack_op.num())
1767 return failure();
1768
1769 const int total_pack_inputs = pack_op.getNumOperands();
1770 const int num_results = input_unpack_op.getNumResults();
1771 if (total_pack_inputs != num_results) return failure();
1772 for (auto input_output :
1773 llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
1774 Value pack_input = std::get<0>(input_output);
1775 Value unpack_output = std::get<1>(input_output);
1776 // Make sure the ordering is the same for the pack op & unpack op.
1777 if (pack_input != unpack_output) return failure();
1778 }
1779
1780 // Replace the pack's output to the unpack's input.
1781 rewriter.replaceOp(pack_op, input_unpack_op.getOperand());
1782 // At this point, we don't manually remove the redundant pack op & unpack op
1783     // (we cannot actually), but trust the PatternRewriter to garbage collect
1784 // these two ops.
1785 return success();
1786 }
1787 };
1788
1789 // Replace PackOp with a reshape when there is only one operand.
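// For example (illustrative):
//   %0 = "tfl.pack"(%arg0) {axis = 0 : i32, values_count = 1 : i32}
//          : (tensor<2xi32>) -> tensor<1x2xi32>
// can be rewritten as a tfl.reshape of %arg0 with a constant shape of [1, 2].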
1790 struct ReplacePackWithReshape : public RewritePattern {
1791   explicit ReplacePackWithReshape(MLIRContext *context)
1792 : RewritePattern(PackOp::getOperationName(), 2, context) {}
1793   LogicalResult matchAndRewrite(Operation *op,
1794 PatternRewriter &rewriter) const override {
1795 TFL::PackOp pack_op = cast<TFL::PackOp>(op);
1796 if (pack_op.getNumOperands() != 1) return failure();
1797
1798 Location loc = pack_op.getLoc();
1799 auto output_type = pack_op.getType().cast<ShapedType>();
1800 if (!output_type.hasStaticShape()) return failure();
1801
1802     // This is to work around the unnecessary i64 -> i32 cast.
1803 SmallVector<int32_t, 4> new_shape_array;
1804 for (auto size : output_type.getShape()) {
1805 new_shape_array.push_back(static_cast<int32_t>(size));
1806 }
1807
1808 auto new_shape = rewriter.create<TFL::ConstOp>(
1809 loc, DenseIntElementsAttr::get(
1810 RankedTensorType::get(new_shape_array.size(),
1811 rewriter.getIntegerType(32)),
1812 new_shape_array));
1813
1814 rewriter.replaceOpWithNewOp<ReshapeOp>(op, output_type,
1815 pack_op.getOperand(0), new_shape);
1816 return success();
1817 }
1818 };
1819
1820 void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1821 MLIRContext *context) {
1822 results.insert<RemoveRedundantUnpackPack, ReplacePackWithReshape>(context);
1823 }
1824
1825 //===----------------------------------------------------------------------===//
1826 // SliceOp
1827 //===----------------------------------------------------------------------===//
1828
1829 static LogicalResult Verify(SliceOp op) {
1830 auto input_type = op.input().getType().cast<ShapedType>();
1831 auto begin_type = op.begin().getType().cast<ShapedType>();
1832 auto size_type = op.size().getType().cast<ShapedType>();
1833 if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
1834 size_type.hasStaticShape()) {
1835 if (input_type.getRank() != begin_type.getNumElements()) {
1836 return op.emitError(
1837 "begin tensor elements size is not equal to input tensor rank");
1838 }
1839
1840 if (input_type.getRank() != size_type.getNumElements()) {
1841 return op.emitError(
1842 "size tensor elements size is not equal to input tensor rank");
1843 }
1844 }
1845
1846 DenseIntElementsAttr begin;
1847 if (matchPattern(op.begin(), m_Constant(&begin))) {
1848 int axis = 0;
1849 for (auto begin_i : llvm::enumerate(begin)) {
1850 if (begin_i.value().getSExtValue() < 0) {
1851 return op.emitError(
1852 llvm::formatv("begin[{0}] cannot be negative", axis));
1853 }
1854 axis++;
1855 }
1856 }
1857
1858 DenseIntElementsAttr size;
1859 if (matchPattern(op.size(), m_Constant(&size))) {
1860 int axis = 0;
1861 for (auto size_i : llvm::enumerate(size)) {
1862 if (size_i.value().getSExtValue() < -1) {
1863 return op.emitError(
1864 llvm::formatv("size[{0}] cannot be negative other than -1", axis));
1865 }
1866 axis++;
1867 }
1868 }
1869
1870 if (begin && size && input_type.hasStaticShape()) {
1871 for (uint64_t i = 0, end = begin.getNumElements(); i < end; i++) {
1872 int begin_i =
1873 begin.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1874 int size_i =
1875 size.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
1876 int dim_i = input_type.getShape()[i];
1877 if (begin_i > dim_i) {
1878 return op.emitOpError(llvm::formatv(
1879 "begin[{0}] cannot exceed dimension length: {1}", i, dim_i));
1880 }
1881 if (size_i >= 0 && begin_i + size_i > dim_i) {
1882 return op.emitError(llvm::formatv(
1883 "begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i,
1884 dim_i));
1885 }
1886 }
1887 }
1888
1889 return success();
1890 }
1891
1892 TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
1893 RankedTensorType value_type,
1894 Location loc, OpBuilder *builder) {
1895 if (input_op == nullptr) return nullptr;
1896
1897 mlir::DenseIntElementsAttr attr;
1898 if (!matchPattern(input_op, m_Constant(&attr))) {
1899 return nullptr;
1900 }
1901
1902 auto value_shape_type = mlir::RankedTensorType::get(
1903 value_type.getShape(), builder->getIntegerType(32));
1904
1905 SmallVector<int32_t, 4> value_i32;
1906 value_i32.reserve(value_type.getRank());
1907 for (const auto &size : attr) {
1908 value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
1909 }
1910 auto new_value_i32_attr =
1911 mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);
1912
1913 return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
1914 }
1915
1916 // This casts down the int64 `begin` and `size` operands of a TFL slice op to
1917 // int32. This requires the begin & size values to be constants.
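// For example (illustrative), a constant `begin` of type tensor<2xi64> with
// value [0, 1] is replaced by a new constant of type tensor<2xi32> with the
// same values, and the slice op is updated to use it.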
1918 struct CastDownInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
1919 using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
1920
1921   LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
1922 PatternRewriter &rewriter) const override {
1923 auto begin = slice_op.begin();
1924 auto size = slice_op.size();
1925 auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
1926 auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
1927 auto begin_op = begin.getDefiningOp();
1928 auto size_op = size.getDefiningOp();
1929
1930 if (begin_op == nullptr && size_op == nullptr) return failure();
1931
1932 if (begin_type == nullptr && size_type == nullptr) return failure();
1933
1934 // Handle begin.
1935 if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
1936 auto new_begin = NarrowDownInt64InputValuesForOp(
1937 begin_op, begin_type, slice_op.getLoc(), &rewriter);
1938 if (new_begin != nullptr) {
1939 slice_op.setOperand(1, new_begin);
1940 }
1941 }
1942
1943 // Handle size.
1944 if (size_op && size_type && size_type.getElementType().isInteger(64)) {
1945 auto new_size = NarrowDownInt64InputValuesForOp(
1946 size_op, size_type, slice_op.getLoc(), &rewriter);
1947 if (new_size != nullptr) {
1948 slice_op.setOperand(2, new_size);
1949 }
1950 }
1951
1952 return success();
1953 }
1954 };
1955
1956 void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1957 MLIRContext *context) {
1958   results.insert<CastDownInt64BeginEndToInt32>(context);
1959 }
1960
1961 //===----------------------------------------------------------------------===//
1962 // SqueezeOp
1963 //===----------------------------------------------------------------------===//
1964
1965 OpFoldResult SqueezeOp::fold(ArrayRef<Attribute> operands) {
1966 auto input_ty = input().getType().dyn_cast<RankedTensorType>();
1967 auto result_ty = getType().dyn_cast<RankedTensorType>();
1968
1969 if (!input_ty || !result_ty) return {};
1970 if (input_ty == result_ty) return input();
1971 return {};
1972 }
1973
1974 //===----------------------------------------------------------------------===//
1975 // SubOp
1976 //===----------------------------------------------------------------------===//
1977
1978 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
1979 // TODO(b/142478136): Handle fused ops.
1980 if (fused_activation_function() != "NONE") return {};
1981 return ConstFoldBinaryOp(
1982 getType(), operands, [](APFloat a, APFloat b) { return a - b; },
1983 [](APInt a, APInt b) { return a - b; });
1984 }
1985
1986 int64_t SubOp::GetArithmeticCount(Operation *op) {
1987 int64_t count;
1988 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count;
1989
1990 return -1;
1991 }
1992
1993 //===----------------------------------------------------------------------===//
1994 // TopKOp
1995 //===----------------------------------------------------------------------===//
1996
1997 static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input,
1998 Value k) {
1999   // Output size is only known if k is a constant value. A negative dimension is
2000 // considered dynamic so use -1 here if k is not a constant value.
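  // For example (illustrative), with an input of type tensor<8x16xf32> and a
  // constant k of 4, the result types are tensor<8x4xf32> for the values and
  // tensor<8x4xi32> for the indices.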
2001 int const_k = -1;
2002 ElementsAttr cst;
2003 if (matchPattern(k, m_Constant(&cst)))
2004 // These casts should all be valid due to how Tensor constants are stored.
2005 // TODO(jpienaar): This should use a helper function.
2006 const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
2007
2008 auto val_type = input.getType().cast<TensorType>();
2009   // If the value is unranked, then so are the results.
2010 if (!val_type.hasRank())
2011 return TFL::TopKV2Op::build(
2012 *builder, result, UnrankedTensorType::get(val_type.getElementType()),
2013 UnrankedTensorType::get(builder->getIntegerType(32)), input, k);
2014
2015 // Resultant shape is value.shape[:-1] + [k]
2016 std::vector<int64_t> shape(val_type.getShape());
2017 shape[shape.size() - 1] = const_k;
2018 TFL::TopKV2Op::build(
2019 *builder, result, RankedTensorType::get(shape, val_type.getElementType()),
2020 RankedTensorType::get(shape, builder->getIntegerType(32)), input, k);
2021 }
2022
2023 //===----------------------------------------------------------------------===//
2024 // FakeQuantOp
2025 //===----------------------------------------------------------------------===//
2026
2027 // Return true if the op has non-empty "minmax" attribute.
2028 static inline bool HasValidMinMaxAttribute(Operation *op) {
2029 auto minmax = op->getAttrOfType<ArrayAttr>("minmax");
2030 return minmax && minmax.getValue().size() == 2;
2031 }
2032
2033 namespace {
2034
2035 /// This pattern matches and removes a tfl.fake_quant if all the users of this
2036 /// op and itself have the "minmax" attribute set.
2037 struct DropFakeQuant : public RewritePattern {
2038   explicit DropFakeQuant(MLIRContext *context)
2039 : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {}
2040
2041   LogicalResult match(Operation *op) const override {
2042 // We only match the op with valid "minmax" attribute.
2043 if (!HasValidMinMaxAttribute(op)) return failure();
2044
2045 // If all the users of this op have valid "minmax" attributes, it is matched
2046 // and can be removed.
2047 auto fakeQuantOp = cast<FakeQuantOp>(op);
2048 for (auto *operand : fakeQuantOp.getResult().getUsers())
2049 if (!HasValidMinMaxAttribute(operand)) return failure();
2050
2051 return success();
2052 }
2053
2054   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
2055 // Replace the matched FakeQuantOp by its primary operand.
2056 rewriter.replaceOp(op, op->getOperand(0));
2057 }
2058 };
2059 } // end anonymous namespace
2060
2061 void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2062 MLIRContext *context) {
2063 results.insert<DropFakeQuant>(context);
2064 }
2065
2066 //===----------------------------------------------------------------------===//
2067 // UnpackOp
2068 //===----------------------------------------------------------------------===//
2069
2070 // TODO(b/133486129): Implement shape inference for unpack
2071
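// Infers the result types of tfl.unpack from the input type and the `num` and
// `axis` attributes. For example (illustrative), unpacking a tensor<2x3xf32>
// along axis = 0 with num = 2 yields two results of type tensor<3xf32>.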
2072 LogicalResult UnpackOp::inferReturnTypes(
2073 MLIRContext *context, Optional<Location> loc, ValueRange operands,
2074 DictionaryAttr attributes, RegionRange regions,
2075 SmallVectorImpl<Type> &inferredReturnTypes) {
2076 UnpackOpAdaptor op(operands, attributes);
2077 // TODO(jpienaar): Refactor verify
2078 if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
2079 return failure();
2080
2081 if (operands.size() != 1) {
2082 return emitOptionalError(loc, "input count should be equal to 1");
2083 }
2084
2085 const int64_t num_value = op.num().getInt();
2086 auto input_type = operands[0].getType().dyn_cast<ShapedType>();
2087 if (!input_type || !input_type.hasRank()) {
2088 // If input is unranked, then so is output.
2089 inferredReturnTypes.assign(
2090 num_value, UnrankedTensorType::get(input_type.getElementType()));
2091 return success();
2092 }
2093
2094 if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) {
2095 return emitOptionalError(
2096 loc, "number of elements in input should be larger than 0");
2097 }
2098
2099 const int64_t rank = input_type.getRank();
2100 if (rank <= 0) {
2101 return emitOptionalError(loc, "input should be of rank larger than 0");
2102 }
2103
2104 int64_t axis_value = op.axis().getInt();
2105 if (axis_value < 0) {
2106 axis_value += rank;
2107 }
2108 if (axis_value < 0 || axis_value >= rank) {
2109 return emitOptionalError(
2110 loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
2111 op.axis().getInt(), ", and rank = ", rank);
2112 }
2113
2114 if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
2115 input_type.getDimSize(axis_value) != num_value) {
2116 return emitOptionalError(loc, "output count should match 'num' attribute");
2117 }
2118
2119 auto output_shape = llvm::to_vector<4>(input_type.getShape());
2120 output_shape.erase(output_shape.begin() + axis_value);
2121
2122 auto output_type =
2123 RankedTensorType::get(output_shape, input_type.getElementType());
2124 inferredReturnTypes.assign(num_value, output_type);
2125
2126 return success();
2127 }
2128
2129 bool UnpackOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
2130 if (lhs.size() != rhs.size()) return false;
2131 for (auto pair : llvm::zip(lhs, rhs)) {
2132 if (failed(
2133 mlir::verifyCompatibleShape(std::get<0>(pair), std::get<1>(pair))))
2134 return false;
2135 }
2136 return true;
2137 }
2138
2139 //===----------------------------------------------------------------------===//
2140 // SplitOp
2141 //===----------------------------------------------------------------------===//
2142
2143 // Extracts and returns the signed integer constant in a 0-rank integer tensor
2144 // or 1-element 1-rank integer tensor if 'value' is a constant.
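// For example (illustrative), a constant dense<3> : tensor<i32> or
// dense<[3]> : tensor<1xi32> yields 3; a non-constant or multi-element value
// yields llvm::None.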
2145 static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
2146 ElementsAttr attr;
2147 if (!matchPattern(value, m_Constant(&attr))) return {};
2148 if (attr.getNumElements() != 1) return {};
2149 IntegerAttr int_attr = *attr.getValues<IntegerAttr>().begin();
2150 return int_attr.getValue().getSExtValue();
2151 }
2152
2153 // Returns a RankedTensorType which is similar to `input_type` but replaces the
2154 // dimension size of `dim` with `dim_size`. For example,
2155 // `SubstituteRankedTensorTypeDimSize(tensor<3x4xi32>, 1, 2)` returns
2156 // `tensor<3x2xi32>`.
2157 static RankedTensorType SubstituteRankedTensorTypeDimSize(
2158 RankedTensorType input_type, int64_t dim, int64_t dim_size) {
2159 auto shape = input_type.getShape().vec();
2160 shape[dim] = dim_size;
2161 return RankedTensorType::get(shape, input_type.getElementType());
2162 }
2163
2164 // Verifies the output tensor types of SplitOp or SplitVOp.
2165 template <typename ExpectedOutputTypeGetter>
2166 static LogicalResult VerifySplitOpOutputTypes(
2167 Operation *op, int64_t num_splits,
2168 ExpectedOutputTypeGetter get_expected_output_type) {
2169 for (int64_t i = 0; i < num_splits; ++i) {
2170 auto expected_output_type = get_expected_output_type(i);
2171 Value output = op->getResult(i);
2172 if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
2173 return op->emitOpError()
2174 << "output #" << i << " should be " << expected_output_type
2175 << " instead got " << output.getType();
2176 }
2177 return success();
2178 }
2179
2180 static LogicalResult Verify(SplitOp op) {
2181 int64_t num_splits = op.num_splits();
2182 if (op.getNumResults() != num_splits)
2183 return op.emitOpError("output count should match 'num_splits' attribute");
2184
2185 // If 'split_dim' is not a constant, there are no other checks.
2186 llvm::Optional<int64_t> split_dim_opt =
2187 ExtractConstantIntFromTensor(op.split_dim());
2188 if (!split_dim_opt) return success();
2189
2190 // If 'input' is not a ranked tensor, there are no other checks.
2191 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2192 if (!input_type) return success();
2193
2194 int64_t split_dim = split_dim_opt.getValue();
2195 const int64_t rank = input_type.getRank();
2196 if (split_dim < 0) split_dim += rank;
2197 if (split_dim < 0 || split_dim >= rank)
2198 return op.emitOpError("'split_dim' should be in [-rank, rank)");
2199
2200 // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
2201 // there are no other checks.
2202 const int64_t dim_size = input_type.getDimSize(split_dim);
2203 if (ShapedType::isDynamic(dim_size)) return success();
2204
2205 if (dim_size % num_splits != 0)
2206 return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis");
2207
2208 // Verifies output tensor types.
2209 RankedTensorType expected_output_type = SubstituteRankedTensorTypeDimSize(
2210 input_type, split_dim, dim_size / num_splits);
2211 return VerifySplitOpOutputTypes(
2212 op.getOperation(), num_splits,
2213 [expected_output_type](int64_t) { return expected_output_type; });
2214 }
2215
2216 static LogicalResult Verify(SplitVOp op) {
2217 int64_t num_splits = op.num_splits();
2218 if (op.getNumResults() != num_splits)
2219 return op.emitOpError("output count should match 'num_splits' attribute");
2220
2221 // If 'split_dim' is not a constant, there are no other checks.
2222 llvm::Optional<int64_t> split_dim_opt =
2223 ExtractConstantIntFromTensor(op.split_dim());
2224 if (!split_dim_opt) return success();
2225
2226 // If 'input' is not a ranked tensor, there are no other checks.
2227 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
2228 if (!input_type) return success();
2229
2230 int64_t split_dim = split_dim_opt.getValue();
2231 const int64_t rank = input_type.getRank();
2232 if (split_dim < 0) split_dim += rank;
2233 if (split_dim < 0 || split_dim >= rank)
2234 return op.emitOpError("'split_dim' should be in [-rank, rank)");
2235
2236 // If the 'split_dim' dimension of the 'input' tensor has a dynamic size,
2237 // there are no other checks.
2238 const int64_t dim_size = input_type.getDimSize(split_dim);
2239 if (ShapedType::isDynamic(dim_size)) return success();
2240
2241 // If 'size_splits' is not a constant, there are no other checks.
2242 ElementsAttr size_splits_attr;
2243 if (!matchPattern(op.size_splits(), m_Constant(&size_splits_attr)))
2244 return success();
2245
2246 if (size_splits_attr.getNumElements() != num_splits) {
2247 auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
2248 RankedTensorType expected_size_splits_type =
2249 RankedTensorType::get({num_splits}, size_splits_type.getElementType());
2250 return op.emitOpError("'size_splits' should be ")
2251 << expected_size_splits_type;
2252 }
2253
2254 // Normalizes and verifies 'size_splits'.
2255 // Note: TensorFlow allows one -1 element in 'size_splits'. The -1 element
2256 // means the rest of the dimension size.
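  // For example (illustrative), splitting a 'split_dim' axis of size 10 with
  // size_splits = [3, -1, 2] normalizes the -1 entry to 5, i.e. [3, 5, 2].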
2257 llvm::SmallVector<int64_t, 4> size_splits;
2258 size_splits.reserve(num_splits);
2259
2260 int64_t negative_size_split_loc = -1;
2261 int64_t total_size_splits = 0;
2262
2263 for (int64_t i = 0; i < num_splits; ++i) {
2264 auto size_split_attr = size_splits_attr.getValue<IntegerAttr>(i);
2265 int64_t size_split = size_split_attr.getValue().getSExtValue();
2266 size_splits.push_back(size_split);
2267 if (size_split >= 0) {
2268 total_size_splits += size_split;
2269 continue;
2270 }
2271 if (size_split < -1)
2272 return op.emitOpError(
2273 "elements of 'size_splits' should be greater than or equal to -1");
2274 if (negative_size_split_loc != -1)
2275 return op.emitOpError("'size_splits' can only have one -1");
2276 negative_size_split_loc = i;
2277 }
2278
2279 if (negative_size_split_loc != -1) {
2280 if (total_size_splits > dim_size)
2281 return op.emitOpError(
2282 "sum of non-negative elements of 'size_splits' is greater than the "
2283 "dimension size of 'split_dim' axis");
2284 size_splits[negative_size_split_loc] = dim_size - total_size_splits;
2285 total_size_splits = dim_size;
2286 }
2287
2288 if (total_size_splits != dim_size)
2289 return op.emitOpError(
2290 "sum of 'size_splits' should match the dimension size of 'split_dim' "
2291 "axis");
2292
2293 // Verifies result tensor types.
2294 auto get_expected_output_type = [input_type, split_dim,
2295 &size_splits](int64_t i) {
2296 return SubstituteRankedTensorTypeDimSize(input_type, split_dim,
2297 size_splits[i]);
2298 };
2299 return VerifySplitOpOutputTypes(op.getOperation(), num_splits,
2300 get_expected_output_type);
2301 }
2302
2303 //===----------------------------------------------------------------------===//
2304 // MeanOp
2305 //===----------------------------------------------------------------------===//
2306
2307 // TODO(b/133854225): Implement shape inference to Mean
2308
2309 //===----------------------------------------------------------------------===//
2310 // LSTMOp
2311 //===----------------------------------------------------------------------===//
2312
2313 static LogicalResult Verify(LSTMOp op) {
2314 auto operands = op.GetStatefulOperands();
2315 if (operands.size() != 2 || operands[0] != 18 || operands[1] != 19) {
2316 return op.emitOpError("LSTMOp expected to have two stateful operands");
2317 }
2318
2319 const auto input_type = op.input().getType().cast<ShapedType>();
2320 // Since TFLite runtime generally supports dynamic shape/rank, if `input_type`
2321 // doesn't have static shape, we skip the shape check below.
2322 if (!input_type.hasStaticShape()) return success();
2323   // The input should be at least a 2-D tensor since it will go through a fully
2324   // connected layer.
2325   if (!input_type.hasRank() || input_type.getRank() < 2)
2326     return op.emitOpError(
2327         "the first input operand should have at least 2 dimensions.");
2328
2329 const auto activation_state =
2330 op.input_activation_state().getType().cast<ShapedType>();
2331 const auto cell_state = op.input_cell_state().getType().cast<ShapedType>();
2332 const auto input_to_output_weights =
2333 op.input_to_output_weights().getType().cast<ShapedType>();
2334 const auto recurrent_to_output_weights =
2335 op.recurrent_to_output_weights().getType().cast<ShapedType>();
2336 if (activation_state.hasStaticShape() && cell_state.hasStaticShape() &&
2337 input_to_output_weights.hasStaticShape() &&
2338 recurrent_to_output_weights.hasStaticShape()) {
2339 const int n_input = input_type.getDimSize(input_type.getRank() - 1);
2340 const int n_cell = input_to_output_weights.getDimSize(0);
2341 const int n_output = recurrent_to_output_weights.getDimSize(1);
2342 const int output_state_size = activation_state.getNumElements();
2343 const int n_batch = input_type.getRank() == 2 ? input_type.getDimSize(0)
2344 : input_type.getDimSize(1);
2345 const int state_size = cell_state.getNumElements();
2346
2347 // Check if the dimension of the inputs matches.
2348 if ((output_state_size != n_batch * n_output) ||
2349 (state_size != n_batch * n_cell) ||
2350 (input_to_output_weights.getDimSize(1) != n_input) ||
2351 (recurrent_to_output_weights.getRank() != 2) ||
2352 (recurrent_to_output_weights.getDimSize(0) != n_cell) ||
2353 (input_to_output_weights.getRank() != 2)) {
2354       return op.emitOpError("inputs don't match the expected dimensions.");
2355 }
2356
2357 const bool is_layer_norm_lstm =
2358 !op.forget_layer_norm_coefficients().getType().isa<NoneType>();
2359 if (is_layer_norm_lstm) {
2360 const auto forget_layer_norm_coefficients =
2361 op.forget_layer_norm_coefficients().getType().cast<ShapedType>();
2362       // If this LSTM has layer normalization, this input value,
2363       // "forget_layer_norm_coefficients", should be a 1-D tensor.
2364 if (!forget_layer_norm_coefficients.hasRank() ||
2365 forget_layer_norm_coefficients.getRank() != 1 ||
2366 forget_layer_norm_coefficients.getDimSize(0) != n_cell)
2367         return op.emitOpError(
2368             "layer-norm coefficient inputs should be 1-D tensors whose size "
2369             "matches the cell dimension of input operand "
2370             "`input_to_output_weights`.");
2371 }
2372 }
2373
2374 return success();
2375 }
2376
2377 namespace {
2378
2379 // Replaces the optional bias operands with a "none" type value if the bias
2380 // values are constant zeros.
2381 struct RemoveLSTMOpZeroBias : public OpRewritePattern<LSTMOp> {
2382 using OpRewritePattern<LSTMOp>::OpRewritePattern;
2383
2384   LogicalResult matchAndRewrite(LSTMOp op,
2385 PatternRewriter &rewriter) const override {
2386 if (EqualsZero(op.input_gate_bias())) {
2387 auto none_value = rewriter.create<mlir::ConstantOp>(
2388 rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2389 op.input_gate_biasMutable().assign(none_value);
2390 }
2391
2392 if (EqualsZero(op.projection_bias())) {
2393 auto none_value = rewriter.create<mlir::ConstantOp>(
2394 rewriter.getUnknownLoc(), rewriter.getUnitAttr());
2395 op.projection_biasMutable().assign(none_value);
2396 }
2397
2398 return success();
2399 }
2400 };
2401
2402 } // namespace
2403
2404 void LSTMOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2405 MLIRContext *context) {
2406 results.insert<RemoveLSTMOpZeroBias>(context);
2407 }
2408
2409 //===----------------------------------------------------------------------===//
2410 // UnidirectionalSequenceLSTMOp
2411 //===----------------------------------------------------------------------===//
2412
2413 static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
2414 auto operands = op.GetStatefulOperands();
2415 if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) {
2416 return success();
2417 }
2418 return op.emitError(
2419 "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
2420 }
2421
2422 LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes(
2423 MLIRContext *, Optional<Location>, ValueRange operands, DictionaryAttr attr,
2424 RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
2425 Value input = operands[0];
2426 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
2427
2428 Value output_state = operands[18];
2429 auto output_state_type =
2430 output_state.getType().dyn_cast_or_null<RankedTensorType>();
2431
2432 if (input_type && input_type.hasRank() && input_type.getRank() != 3) {
2433 return failure();
2434 }
2435
2436 if (output_state_type && output_state_type.hasRank() &&
2437 output_state_type.getRank() != 2) {
2438 return failure();
2439 }
2440
2441 if (!input_type || !input_type.hasRank() || !output_state_type ||
2442 !output_state_type.hasRank()) {
2443 // We cannot infer the output shape since we don't know the input shape or
2444 // the output state shape. We will set the output shape as unranked.
2445 Type result_type;
2446 result_type = UnrankedTensorType::get(
2447 input.getType().cast<ShapedType>().getElementType());
2448 inferredReturnTypes.assign({result_type});
2449 return success();
2450 }
2451
2452 // Default to non-time_major.
2453 bool time_majored = attr.getNamed("time_major").hasValue()
2454 ? attr.getNamed("time_major")
2455 .getValue()
2456 .second.cast<BoolAttr>()
2457 .getValue()
2458 : false;
2459
2460 int batch =
2461 time_majored ? input_type.getDimSize(1) : input_type.getDimSize(0);
2462 int time = time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1);
2463 int n_output = output_state_type.getDimSize(1);
2464
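  // For example (illustrative), with time_major = false, an input of type
  // tensor<1x28x20xf32> and an output state of type tensor<1x16xf32> give
  // batch = 1, time = 28 and n_output = 16, so the inferred result type is
  // tensor<1x28x16xf32>.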
2465 // Build the output shape.
2466 SmallVector<int64_t, 3> output_shape;
2467 if (time_majored) {
2468 output_shape = {time, batch, n_output};
2469 } else {
2470 output_shape = {batch, time, n_output};
2471 }
2472 auto result_type =
2473 mlir::RankedTensorType::get(output_shape, input_type.getElementType());
2474
2475 inferredReturnTypes.assign({result_type});
2476 return success();
2477 }
2478
2479 bool UnidirectionalSequenceLSTMOp::isCompatibleReturnTypes(TypeRange lhs,
2480 TypeRange rhs) {
2481 if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
2482 if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
2483 return true;
2484 }
2485
2486 //===----------------------------------------------------------------------===//
2487 // BidirectionalSequenceLSTMOp
2488 //===----------------------------------------------------------------------===//
2489
2490 static LogicalResult Verify(BidirectionalSequenceLSTMOp op) {
2491 auto operands = op.GetStatefulOperands();
2492 if (operands.size() == 4 && operands[0] == 35 && operands[1] == 36 &&
2493 operands[2] == 37 && operands[3] == 38) {
2494 return success();
2495 }
2496 return op.emitError(
2497 "BidirectionalSequenceLSTMOp expected to have four stateful operands");
2498 }
2499
2500 //===----------------------------------------------------------------------===//
2501 // UnidirectionalSequenceRNNOp
2502 //===----------------------------------------------------------------------===//
2503
2504 static LogicalResult Verify(UnidirectionalSequenceRNNOp op) {
2505 auto operands = op.GetStatefulOperands();
2506 if (operands.size() == 1 && operands[0] == 4) {
2507 return success();
2508 }
2509 return op.emitError(
2510 "UnidirectionalSequenceRNNOp expected to have one stateful operand");
2511 }
2512
2513 //===----------------------------------------------------------------------===//
2514 // SvdfOp
2515 //===----------------------------------------------------------------------===//
2516
2517 static LogicalResult Verify(SVDFOp op) {
2518 auto operands = op.GetStatefulOperands();
2519 if (operands.size() == 1 && operands[0] == 4) {
2520 return success();
2521 }
2522 return op.emitError("SvdfOp expected to have one stateful operand");
2523 }
2524
2525 //===----------------------------------------------------------------------===//
2526 // AbsOp
2527 //===----------------------------------------------------------------------===//
2528
2529 OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
2530 Type result_type = getType();
2531 // Only constant fold for tensor of f32 is implemented.
2532 if (!IsF32ShapedType(result_type)) return nullptr;
2533
2534 auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); };
2535 return ConstFoldUnaryOp(result_type, operands[0], compute);
2536 }
2537
2538 //===----------------------------------------------------------------------===//
2539 // NegOp
2540 //===----------------------------------------------------------------------===//
2541
2542 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
2543 Type result_type = getType();
2544 // Only constant fold for tensor of f32 is implemented.
2545 if (!IsF32ShapedType(result_type)) return nullptr;
2546
2547 auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
2548 return ConstFoldUnaryOp(result_type, operands[0], compute);
2549 }
2550
2551 //===----------------------------------------------------------------------===//
2552 // SinOp
2553 //===----------------------------------------------------------------------===//
2554
2555 OpFoldResult SinOp::fold(ArrayRef<Attribute> operands) {
2556 Type result_type = getType();
2557 // Only constant fold for tensor of f32 is implemented.
2558 if (!IsF32ShapedType(result_type)) return nullptr;
2559
2560 auto compute = [](APFloat value) -> APFloat {
2561 float f = value.convertToFloat();
2562 float result = std::sin(f);
2563 return APFloat(result);
2564 };
2565 return ConstFoldUnaryOp(result_type, operands[0], compute);
2566 }
2567
2568 //===----------------------------------------------------------------------===//
2569 // CosOp
2570 //===----------------------------------------------------------------------===//
2571
2572 OpFoldResult CosOp::fold(ArrayRef<Attribute> operands) {
2573 Type result_type = getType();
2574 // Only constant fold for tensor of f32 is implemented.
2575 if (!IsF32ShapedType(result_type)) return nullptr;
2576
2577 auto compute = [](APFloat value) -> APFloat {
2578 float f = value.convertToFloat();
2579 float result = std::cos(f);
2580 return APFloat(result);
2581 };
2582 return ConstFoldUnaryOp(result_type, operands[0], compute);
2583 }
2584
2585 //===----------------------------------------------------------------------===//
2586 // LogOp
2587 //===----------------------------------------------------------------------===//
2588
2589 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
2590 Type result_type = getType();
2591 // Only constant fold for tensor of f32 is implemented.
2592 if (!IsF32ShapedType(result_type)) return nullptr;
2593
2594 auto compute = [](APFloat value) -> APFloat {
2595 float f = value.convertToFloat();
2596 float result = std::log(f);
2597 return APFloat(result);
2598 };
2599 return ConstFoldUnaryOp(result_type, operands[0], compute);
2600 }
2601
2602 //===----------------------------------------------------------------------===//
2603 // ShapeOp
2604 //===----------------------------------------------------------------------===//
2605
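// Folds tfl.shape when the input has a static shape. For example
// (illustrative), an input of type tensor<2x3xf32> with an i32 result element
// type folds to the constant dense<[2, 3]> : tensor<2xi32>.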
2606 OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
2607 auto input_type = input().getType().cast<ShapedType>();
2608 if (!input_type.hasStaticShape()) return nullptr;
2609
2610 ArrayRef<int64_t> shape = input_type.getShape();
2611 auto result_type = getType().cast<ShapedType>();
2612 if (result_type.getElementType().isInteger(64)) {
2613 return DenseElementsAttr::get<int64_t>(result_type, shape);
2614 } else if (result_type.getElementType().isInteger(32)) {
2615 SmallVector<int32_t, 4> shape_i32;
2616 shape_i32.reserve(shape.size());
2617 for (int64_t dim : shape) {
2618 shape_i32.push_back(dim);
2619 }
2620 return DenseElementsAttr::get<int32_t>(result_type, shape_i32);
2621 }
2622 return nullptr;
2623 }
2624
2625 //===----------------------------------------------------------------------===//
2626 // SqrtOp
2627 //===----------------------------------------------------------------------===//
2628
2629 OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
2630 Type result_type = getType();
2631 // Only constant fold for tensor of f32 is implemented.
2632 if (!IsF32ShapedType(result_type)) return nullptr;
2633
2634 auto compute = [](APFloat value) -> APFloat {
2635 float f = value.convertToFloat();
2636 float result = std::sqrt(f);
2637 return APFloat(result);
2638 };
2639 return ConstFoldUnaryOp(result_type, operands[0], compute);
2640 }
2641
2642 //===----------------------------------------------------------------------===//
2643 // RsqrtOp
2644 //===----------------------------------------------------------------------===//
2645
2646 OpFoldResult RsqrtOp::fold(ArrayRef<Attribute> operands) {
2647 Type result_type = getType();
2648 // Only constant fold for tensor of f32/bf16 is implemented.
2649 if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type))
2650 return nullptr;
2651
2652 auto compute = [](APFloat value) -> APFloat {
2653 bool loseInfo;
2654 const llvm::fltSemantics &original_float_semantics = value.getSemantics();
2655 value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
2656 &loseInfo);
2657 float f = value.convertToFloat();
2658 APFloat result(1.f / std::sqrt(f));
2659 result.convert(original_float_semantics, APFloat::rmNearestTiesToEven,
2660 &loseInfo);
2661 return result;
2662 };
2663 return ConstFoldUnaryOp(result_type, operands[0], compute);
2664 }
2665
2666 //===----------------------------------------------------------------------===//
2667 // SquareOp
2668 //===----------------------------------------------------------------------===//
2669
2670 OpFoldResult SquareOp::fold(ArrayRef<Attribute> operands) {
2671 Type result_type = getType();
2672 // Only constant fold for tensor of f32 is implemented.
2673 if (!IsF32ShapedType(result_type)) return nullptr;
2674
2675 auto compute = [](APFloat value) -> APFloat { return value * value; };
2676 return ConstFoldUnaryOp(result_type, operands[0], compute);
2677 }
2678
2679 //===----------------------------------------------------------------------===//
2680 // RankOp
2681 //===----------------------------------------------------------------------===//
2682
2683 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
2684 assert(operands.size() == 1);
2685 auto result_type = getType().cast<ShapedType>();
2686 if (auto elements_attr = operands[0].dyn_cast_or_null<ElementsAttr>()) {
2687 auto rank = static_cast<int32_t>(elements_attr.getType().getRank());
2688 return DenseElementsAttr::get(result_type, {rank});
2689 }
2690
2691 // Also fold if `input` has a known rank.
2692 auto input_type = input().getType().cast<ShapedType>();
2693 // Do not fold if rank is zero because the TFLite converter doesn't
2694 // distinguish between unranked input and scalar input due to b/138865275.
2695 // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
2696 // predicate and fold the op when rank is zero.
2697 if (input_type.hasRank() && input_type.getRank() != 0) {
2698 auto rank = static_cast<int32_t>(input_type.getRank());
2699 return DenseElementsAttr::get(result_type, {rank});
2700 }
2701
2702 return nullptr;
2703 }
2704
2705 //===----------------------------------------------------------------------===//
2706 // ConstOp
2707 //===----------------------------------------------------------------------===//
2708
2709 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
2710 assert(operands.empty() && "constant has no operands");
2711 // Return the held attribute value.
2712 return value();
2713 }
2714
2715 namespace {
2716 struct FoldPseudoConstOp : public OpRewritePattern<ConstOp> {
2717 using OpRewritePattern<ConstOp>::OpRewritePattern;
2718
2719   LogicalResult matchAndRewrite(ConstOp const_op,
2720 PatternRewriter &rewriter) const override {
2721 if (!ConstantOp::isBuildableWith(const_op.value(), const_op.getType()))
2722 return failure();
2723 rewriter.replaceOpWithNewOp<ConstantOp>(const_op, const_op.value());
2724 return success();
2725 }
2726 };
2727
2728 } // namespace
2729
2730 void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2731 MLIRContext *context) {
2732 results.insert<FoldPseudoConstOp>(context);
2733 }
2734
2735 //===----------------------------------------------------------------------===//
2736 // CastOp
2737 //===----------------------------------------------------------------------===//
2738
2739 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
2740 assert(operands.size() == 1);
2741 if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
2742 return input();
2743 }
2744
2745 // For now, only supports cast between integer types.
2746 auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2747 if (!elements_attr) {
2748 return nullptr;
2749 }
2750
2751 auto result_element_type =
2752 getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
2753 auto operand_element_type = input()
2754 .getType()
2755 .cast<ShapedType>()
2756 .getElementType()
2757 .dyn_cast<IntegerType>();
2758 // Returns nullptr if either result/operand element type is not integer.
2759 if (!result_element_type || !operand_element_type) {
2760 return nullptr;
2761 }
2762
2763 const bool is_unsigned = operand_element_type.isUnsigned();
2764 const bool involves_bool = operand_element_type.getWidth() == 1 ||
2765 result_element_type.getWidth() == 1;
2766 const int output_bitwidth = result_element_type.getWidth();
2767 // The integer cast op behaves like a C integer cast. Depending on the
2768 // operand type's signedness, we determine whether or not sign extension is
2769 // needed.
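// Illustrative examples (not from the kernel, assuming these exact widths):
// casting i16 -1 to i32 sign-extends to -1, casting ui16 65535 to i32
// zero-extends to 65535, and casting i32 2 to i1 yields true (1) rather than
// the 0 that a plain truncation of the low bit would produce.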
2770 auto cast = [&](APInt value) {
2771 if (involves_bool) {
2772 // Handle boolean inputs or outputs explicitly as they don't have the same
2773 // behavior as extension or truncation.
2774 // A true input should always be cast to 1, not to the -1 that sign
2775 // extension would produce for signed outputs. Similarly, non-zero inputs
2776 // should be cast to true; truncating even numbers to one bit would otherwise yield `false`.
2777 return APInt(result_element_type.getWidth(), value != 0);
2778 }
2779 return is_unsigned ? value.zextOrTrunc(output_bitwidth)
2780 : value.sextOrTrunc(output_bitwidth);
2781 };
2782
2783 return elements_attr.mapValues(result_element_type, cast);
2784 }
2785
2786 //===----------------------------------------------------------------------===//
2787 // SelectV2Op
2788 //===----------------------------------------------------------------------===//
2789
2790 static void BuildSelectV2Op(Builder *builder, OperationState &result,
2791 Value cond, Value x, Value y) {
2792 auto operand_type =
2793 OpTrait::util::getBroadcastedType(x.getType(), y.getType());
2794
2795 if (!operand_type)
2796 emitError(result.location) << "non-broadcastable operands: " << x.getType()
2797 << " and " << y.getType();
2798
2799 bool has_static_cond_shape = false;
2800 bool has_static_operand_shape = false;
2801 ArrayRef<int64_t> cond_shape;
2802 ArrayRef<int64_t> operand_shape;
2803
2804 if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
2805 if (shaped_type.hasStaticShape()) {
2806 has_static_cond_shape = true;
2807 cond_shape = shaped_type.getShape();
2808 }
2809 }
2810 if (auto shaped_type = operand_type.dyn_cast<ShapedType>()) {
2811 if (shaped_type.hasStaticShape()) {
2812 has_static_operand_shape = true;
2813 operand_shape = shaped_type.getShape();
2814 }
2815 }
2816
2817 SmallVector<int64_t, 4> broadcastedShape;
2818 if (has_static_cond_shape && has_static_operand_shape &&
2819 !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
2820 broadcastedShape)) {
2821 emitError(result.location) << "non-broadcastable operands: " << operand_type
2822 << " and " << cond.getType();
2823 }
2824
2825 result.addOperands({cond, x, y});
2826
2827 auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
2828 if (has_static_cond_shape && has_static_operand_shape) {
2829 result.types.push_back(
2830 RankedTensorType::get(broadcastedShape, elementType));
2831 } else {
2832 result.types.push_back(UnrankedTensorType::get(elementType));
2833 }
2834 }
2835
2836 //===----------------------------------------------------------------------===//
2837 // RangeOp
2838 //===----------------------------------------------------------------------===//
2839
2840 namespace {
2841
2842 // Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`.
2843 // Template parameter `FloatOrInt` must be a standard C integer or
2844 // floating-point type.
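// For example, start=0, limit=10, delta=3 gives (10 + 3 - 1) / 3 = 4 elements
// in the integer case, and start=0.0, limit=1.0, delta=0.3 gives
// ceil(1.0 / 0.3) = 4 elements in the floating-point case.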
2845 template <typename FloatOrInt>
2846 int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) {
2847 // Refer to the implementation in
2848 // tensorflow/lite/kernels/range.cc.
2849 return std::is_integral<FloatOrInt>::value
2850 ? ((std::abs(limit - start) + std::abs(delta) - 1) /
2851 std::abs(delta))
2852 : std::ceil(std::abs((limit - start) / delta));
2853 }
2854
2855 // Builds a constant range tensor of `result_elem_type` elements.
2856 // Template parameter `FloatOrIntAttr` must be mlir::IntegerAttr or
2857 // mlir::FloatAttr.
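// For example, num_elements=4 with start=2 and delta=3 yields the 1-D tensor
// [2, 5, 8, 11].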
2858 template <typename FloatOrIntAttr>
2859 DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements,
2860 FloatOrIntAttr start_attr,
2861 FloatOrIntAttr delta_attr) {
2862 using ValueType = typename FloatOrIntAttr::ValueType; // APInt or APFloat
2863 ValueType start = start_attr.getValue();
2864 ValueType delta = delta_attr.getValue();
2865
2866 SmallVector<ValueType, 16> new_values;
2867 new_values.reserve(num_elements);
2868 ValueType new_value = start;
2869 for (int i = 0; i < num_elements; ++i) {
2870 new_values.push_back(new_value);
2871 new_value = new_value + delta;
2872 }
2873 // Result is always a 1-D tensor.
2874 auto new_result_type =
2875 RankedTensorType::get({num_elements}, result_elem_type);
2876 return DenseElementsAttr::get(new_result_type, new_values);
2877 }
2878 } // namespace
2879
2880 OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
2881 assert(operands.size() == 3);
2882 auto start_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
2883 auto limit_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
2884 auto delta_tensor = operands[2].dyn_cast_or_null<ElementsAttr>();
2885 if (start_tensor && limit_tensor && delta_tensor) {
2886 // Operands should all be scalars
2887 assert(start_tensor.getType().getRank() == 0 &&
2888 limit_tensor.getType().getRank() == 0 &&
2889 delta_tensor.getType().getRank() == 0);
2890 Type elem_type = getType().cast<ShapedType>().getElementType();
2891 if (elem_type.isSignlessInteger()) {
2892 auto start_attr = start_tensor.getValue<IntegerAttr>({});
2893 auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
2894 auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
2895 const int num_elements = GetLengthOfRange(
2896 start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt());
2897 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2898 delta_attr);
2899 } else if (elem_type.isa<FloatType>()) {
2900 auto start_attr = start_tensor.getValue<FloatAttr>({});
2901 auto limit_attr = limit_tensor.getValue<FloatAttr>({});
2902 auto delta_attr = delta_tensor.getValue<FloatAttr>({});
2903 const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(),
2904 limit_attr.getValueAsDouble(),
2905 delta_attr.getValueAsDouble());
2906 return BuildConstRangeTensor(elem_type, num_elements, start_attr,
2907 delta_attr);
2908 }
2909 }
2910
2911 return nullptr;
2912 }
2913
2914 //===----------------------------------------------------------------------===//
2915 // TransposeConvOp
2916 //===----------------------------------------------------------------------===//
2917
2918 static LogicalResult Verify(TransposeConvOp op) {
2919 ShapedType output_type = op.output().getType().cast<ShapedType>();
2920 ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
2921 if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
2922 if (output_type.getRank() != output_shape_type.getDimSize(0)) {
2923 return op.emitOpError(llvm::formatv(
2924 "expect output type has rank = {0}, got output type {1}",
2925 output_shape_type.getDimSize(0), output_type));
2926 }
2927 }
2928
2929 DenseIntElementsAttr output_shape_elements;
2930 if (!matchPattern(op.output_shape(), m_Constant(&output_shape_elements))) {
2931 return success();
2932 }
2933
2934 llvm::SmallVector<int64_t, 4> output_shape;
2935 output_shape.reserve(output_shape_elements.getNumElements());
2936 for (auto dim : output_shape_elements.getValues<int>()) {
2937 output_shape.push_back(dim);
2938 }
2939
2940 auto expected_output_type =
2941 RankedTensorType::get(output_shape, output_type.getElementType());
2942 if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
2943 return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
2944 expected_output_type, output_type));
2945 }
2946
2947 return success();
2948 }
2949
2950 int64_t TransposeConvOp::GetArithmeticCount(Operation *op) {
2951 int64_t count = -1;
2952 auto transpose_conv = llvm::dyn_cast<TransposeConvOp>(op);
2953 auto input_type = transpose_conv.input()
2954 .getType()
2955 .dyn_cast_or_null<mlir::RankedTensorType>();
2956 auto weight_type = transpose_conv.weights()
2957 .getType()
2958 .dyn_cast_or_null<mlir::RankedTensorType>();
2959 if (input_type && weight_type && input_type.hasStaticShape() &&
2960 weight_type.hasStaticShape()) {
2961 // Compute op count from the seven nested loops of
2962 // tflite::reference_ops::TransposeConv():
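// i.e. (assuming the usual OHWI weight layout) roughly one multiply and one
// add per input element for each output channel and filter spatial position.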
2963 count = 2 * input_type.getNumElements() * weight_type.getDimSize(0) *
2964 weight_type.getDimSize(1) * weight_type.getDimSize(2);
2965 }
2966
2967 return count;
2968 }
2969
2970 //===----------------------------------------------------------------------===//
2971 // StridedSliceOp
2972 //===----------------------------------------------------------------------===//
2973
2974 LogicalResult Verify(StridedSliceOp op) {
2975 auto ranked_input_type = op.input().getType().dyn_cast<RankedTensorType>();
2976
2977 // If input is unranked, there is nothing else to be verified.
2978 if (!ranked_input_type) return success();
2979 int num_input_dims = ranked_input_type.getRank();
2980
2981 if (auto begin_type = op.begin().getType().dyn_cast<RankedTensorType>()) {
2982 if (begin_type.getRank() != 1) return failure();
2983 if (begin_type.getDimSize(0) > num_input_dims) return failure();
2984 }
2985
2986 if (auto end_type = op.end().getType().dyn_cast<RankedTensorType>()) {
2987 if (end_type.getRank() != 1) return failure();
2988 if (end_type.getDimSize(0) > num_input_dims) return failure();
2989 }
2990
2991 if (auto strides_type = op.strides().getType().dyn_cast<RankedTensorType>()) {
2992 if (strides_type.getRank() != 1) return failure();
2993 if (strides_type.getDimSize(0) > num_input_dims) return failure();
2994 }
2995
2996 // The kernel will reshape the input tensor with the new axes; it only
2997 // supports the reshaped tensor up to 5-D.
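// For example, a 4-D input with one new axis outside the ellipsis reshapes to
// 5-D and is still accepted, while two new axes would exceed the limit.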
2998 uint32_t ellipsis_mask = op.ellipsis_mask();
2999 uint32_t new_axis_mask = op.new_axis_mask();
3000 int num_added_axis = 0;
3001 for (int i = 0; i < 8; ++i) {
3002 if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) {
3003 num_added_axis++;
3004 }
3005 }
3006 if (num_input_dims + num_added_axis > 5) return failure();
3007 return success();
3008 }
3009
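// Folds a strided slice that is a no-op: all masks are zero, begin is all
// zeros, strides are all ones, and end equals the static input shape (e.g.
// input shape [2, 3] with begin [0, 0], end [2, 3], strides [1, 1] folds to
// the input).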
3010 OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
3011 // Currently only support all masks being 0.
3012 if (begin_mask() != 0 || end_mask() != 0 || ellipsis_mask() != 0 ||
3013 new_axis_mask() != 0 || shrink_axis_mask() != 0)
3014 return {};
3015
3016 auto input_type = input().getType().dyn_cast_or_null<RankedTensorType>();
3017 if (!input_type || !input_type.hasStaticShape()) return {};
3018
3019 // Begin has to be all 0s.
3020 DenseIntElementsAttr begin_dense_elem_attr;
3021 if (!matchPattern(begin(), m_Constant(&begin_dense_elem_attr))) {
3022 return {};
3023 }
3024 for (auto begin_ele : begin_dense_elem_attr) {
3025 if (begin_ele.getSExtValue() != 0) {
3026 return {};
3027 }
3028 }
3029
3030 // Strides has to be all 1s.
3031 DenseIntElementsAttr strides_dense_elem_attr;
3032 if (!matchPattern(strides(), m_Constant(&strides_dense_elem_attr))) {
3033 return {};
3034 }
3035 for (auto stride_ele : strides_dense_elem_attr) {
3036 if (stride_ele.getSExtValue() != 1) {
3037 return {};
3038 }
3039 }
3040 // End has to match the input shape.
3041 DenseIntElementsAttr end_dense_elem_attr;
3042 if (!matchPattern(end(), m_Constant(&end_dense_elem_attr))) {
3043 return {};
3044 }
3045 int i = 0;
3046 for (auto end_ele : end_dense_elem_attr) {
3047 if (end_ele.getSExtValue() != input_type.getDimSize(i)) {
3048 return {};
3049 }
3050 ++i;
3051 }
3052
3053 return input();
3054 }
3055
3056 //===----------------------------------------------------------------------===//
3057 // TransposeOp
3058 //===----------------------------------------------------------------------===//
3059
3060 namespace {
3061
3062 // Computes the permutation of a constant `input_tensor` according to `perm`.
3063 // The function recursively traverses the dimensions of the output tensor in
3064 // a row-major order and writes the value in the output tensor into
3065 // `new_values`.
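// For example (illustrative), a 2x3 input [[1, 2, 3], [4, 5, 6]] with
// perm = [1, 0] has output shape [3, 2] and fills `new_values` with
// {1, 4, 2, 5, 3, 6}, i.e. the transposed tensor in row-major order.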
3066 void ComputePermutation(ElementsAttr input_tensor, ArrayRef<int32_t> perm,
3067 ArrayRef<int64_t> output_shape, int num_dimensions,
3068 int output_axis, std::vector<uint64_t> *input_indices,
3069 std::vector<Attribute> *new_values) {
3070 // Refer to the implementation of `Transpose` function in
3071 // tensorflow/lite/kernels/internal/reference/reference_ops.h
3072 assert(output_axis < num_dimensions);
3073 const int input_axis = perm[output_axis];
3074 for (int i = 0; i < output_shape[output_axis]; ++i) {
3075 // Update the input indices on `input_axis`.
3076 input_indices->at(input_axis) = i;
3077 // Write the value from `input_tensor` if it is the last axis or
3078 // recurse into the next axis.
3079 const bool is_last_axis = output_axis == num_dimensions - 1;
3080 if (is_last_axis) {
3081 new_values->push_back(input_tensor.getValue(*input_indices));
3082 } else {
3083 ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3084 output_axis + 1, input_indices, new_values);
3085 }
3086 }
3087 }
3088
3089 } // namespace
3090
3091 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
3092 assert(operands.size() == 2);
3093 auto input_tensor = operands[0].dyn_cast_or_null<ElementsAttr>();
3094 auto perm_tensor = operands[1].dyn_cast_or_null<ElementsAttr>();
3095 if (!input_tensor || !perm_tensor) return nullptr;
3096
3097 // Do not try to fold elements attr of a quant type because
3098 // DenseElementsAttr does not support it.
3099 if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
3100 return nullptr;
3101
3102 assert(perm_tensor.getType().getRank() == 1);
3103 const int num_dimensions = input_tensor.getType().getRank();
3104 assert(perm_tensor.getType().getNumElements() == num_dimensions);
3105
3106 ArrayRef<int64_t> input_shape = input_tensor.getType().getShape();
3107 auto output_type = getType().cast<ShapedType>();
3108
3109 SmallVector<int32_t, 4> perm;
3110 SmallVector<int64_t, 4> output_shape;
3111 for (int i = 0; i < num_dimensions; ++i) {
3112 perm.push_back(
3113 perm_tensor.getValue<IntegerAttr>({static_cast<uint64_t>(i)}).getInt());
3114 output_shape.push_back(input_shape[perm[i]]);
3115
3116 // Check that the derived output shape matches the static shape.
3117 assert(!output_type.hasStaticShape() ||
3118 output_type.getShape()[i] == output_shape[i]);
3119 }
3120
3121 std::vector<Attribute> new_values;
3122 new_values.reserve(input_tensor.getType().getNumElements());
3123 std::vector<uint64_t> input_indices(num_dimensions);
3124 ComputePermutation(input_tensor, perm, output_shape, num_dimensions,
3125 /*output_axis=*/0, &input_indices, &new_values);
3126 auto result_type =
3127 RankedTensorType::get(output_shape, output_type.getElementType());
3128 return DenseElementsAttr::get(result_type, new_values);
3129 }
3130
3131 static LogicalResult Verify(TransposeOp op) {
3132 auto input_type = op.input().getType().cast<ShapedType>();
3133 auto perm_type = op.perm().getType().cast<ShapedType>();
3134 auto output_type = op.output().getType().cast<ShapedType>();
3135 if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
3136 if (perm_type.getNumElements() != input_type.getRank()) {
3137 return op.emitOpError(
3138 "perm tensor elements size is not equal to input tensor rank");
3139 }
3140 }
3141
3142 DenseIntElementsAttr perm;
3143 if (!matchPattern(op.perm(), m_Constant(&perm))) {
3144 return success();
3145 }
3146
3147 int index = 0;
3148 llvm::SmallVector<int64_t, 4> axes;
3149 for (const auto &axis_int : perm.getValues<APInt>()) {
3150 const int64_t axis = axis_int.getSExtValue();
3151 if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
3152 return op.emitOpError(
3153 llvm::formatv("perm[{0}] must be in [0, rank)", index));
3154 }
3155 if (std::count(axes.begin(), axes.end(), axis) > 0) {
3156 return op.emitOpError(
3157 llvm::formatv("perm[{0}] cannot have duplicated axis", index));
3158 }
3159 axes.push_back(axis);
3160 index++;
3161 }
3162
3163 if (input_type.hasStaticShape() && output_type.hasStaticShape()) {
3164 llvm::SmallVector<int64_t, 4> transposed_shape;
3165 for (int64_t axis : axes) {
3166 transposed_shape.push_back(input_type.getDimSize(axis));
3167 }
3168 auto expected_output_type =
3169 RankedTensorType::get(transposed_shape, input_type.getElementType());
3170 if (failed(
3171 mlir::verifyCompatibleShape(output_type, expected_output_type))) {
3172 return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
3173 expected_output_type, output_type));
3174 }
3175 }
3176
3177 return success();
3178 }
3179
3180 static void BuildTransposeOp(OpBuilder *builder, OperationState &result,
3181 Value input, Value perm) {
3182 // Output size is only known if input is ranked and perm is a constant.
3183 auto input_type = input.getType().cast<TensorType>();
3184 DenseIntElementsAttr perm_const;
3185 if (!input_type.hasRank() || !matchPattern(perm, m_Constant(&perm_const)) ||
3186 perm_const.getIntValues().empty()) {
3187 TFL::TransposeOp::build(
3188 *builder, result, UnrankedTensorType::get(input_type.getElementType()),
3189 input, perm);
3190 return;
3191 }
3192
3193 const auto perm_value_it = perm_const.getIntValues().begin();
3194
3195 const ArrayRef<int64_t> input_shape = input_type.getShape();
3196 SmallVector<int64_t, 4> output_shape(input_shape.size());
3197
3198 for (int i = 0; i < output_shape.size(); ++i) {
3199 const APInt perm_val = perm_value_it[i];
3200 output_shape[i] = input_shape[perm_val.getSExtValue()];
3201 }
3202
3203 TFL::TransposeOp::build(
3204 *builder, result,
3205 RankedTensorType::get(output_shape, input_type.getElementType()), input,
3206 perm);
3207 }
3208
3209 //===----------------------------------------------------------------------===//
3210 // IfOp
3211 //===----------------------------------------------------------------------===//
3212
3213 /// Given the region at `index`, or the parent operation if `index` is None,
3214 /// return the successor regions. These are the regions that may be selected
3215 /// during the flow of control. `operands` is a set of optional attributes that
3216 /// correspond to a constant value for each operand, or null if that operand is
3217 /// not a constant.
3218 void IfOp::getSuccessorRegions(Optional<unsigned> index,
3219 ArrayRef<Attribute> operands,
3220 SmallVectorImpl<RegionSuccessor> ®ions) {
3221 // The `then` and the `else` region branch back to the parent operation.
3222 if (index.hasValue()) {
3223 regions.push_back(RegionSuccessor(getResults()));
3224 return;
3225 }
3226
3227 // Don't consider the else region if it is empty.
3228 Region *else_reg = &else_region();
3229 if (else_reg->empty()) else_reg = nullptr;
3230
3231 // Otherwise, the successor is dependent on the condition.
3232 bool condition;
3233 if (auto cond_attr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
3234 condition = cond_attr.getValue().isOneValue();
3235 } else {
3236 // If the condition isn't constant, both regions may be executed.
3237 regions.push_back(RegionSuccessor(&then_region()));
3238 // If the else region does not exist, it is not a viable successor.
3239 if (else_reg) regions.push_back(RegionSuccessor(else_reg));
3240 return;
3241 }
3242
3243 // Add the successor regions using the condition.
3244 regions.push_back(RegionSuccessor(condition ? &then_region() : else_reg));
3245 }
3246
3247 //===----------------------------------------------------------------------===//
3248 // WhileOp
3249 //===----------------------------------------------------------------------===//
3250
3251 LogicalResult Verify(WhileOp op) {
3252 if (op.getNumOperands() != op.getNumResults())
3253 return op.emitOpError(llvm::formatv(
3254 "number of operands does not match number of results ({0} != {1})",
3255 op.getNumOperands(), op.getNumResults()));
3256 // TODO(jpienaar): Verify operand, result & block arguments types
3257 return success();
3258 }
3259
3260 namespace {
3261 // Canonicalize While op so that results and operands match and external values
3262 // are via implicit capture rather than via block args.
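// For example, a block argument that the body yields back unchanged is
// rewritten to use the corresponding while operand directly, and block
// arguments (with their matching operands and results) that end up unused are
// removed.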
3263 struct WhileResultOperandsMatchAndImplicitCapture
3264 : public OpRewritePattern<WhileOp> {
3265 using OpRewritePattern<WhileOp>::OpRewritePattern;
3266
3267 LogicalResult matchAndRewrite(WhileOp while_op,
3268 PatternRewriter &rewriter) const override {
3269 // Replace values simply passed through the body with external values
3270 // (in both body and condition regions as well as in the while result). The
3271 // block arguments of body and cond match, so the corresponding cond
3272 // argument can easily be found.
3273 bool unchanged = true;
3274 auto &body_block = while_op.body().front();
3275 auto &cond_block = while_op.cond().front();
3276 auto &yield = *body_block.getTerminator();
3277 for (auto ba : body_block.getArguments()) {
3278 int arg_no = ba.getArgNumber();
3279 // Skip removing resources that are not read-only variables.
3280 if (getElementTypeOrSelf(ba.getType()).isa<TF::ResourceType>()) {
3281 bool has_read_only_variables = true;
3282 for (auto user : ba.getUsers()) {
3283 // Terminator ops, e.g. the tfl.yield op, should be ignored since
3284 // the argument may be used only for yielding as the `body` function
3285 // result, and that gives no meaningful signal as to whether the given
3286 // argument is a read-only variable or not.
3287 if (user->hasTrait<OpTrait::IsTerminator>()) continue;
3288 if (!llvm::isa<mlir::TF::ReadVariableOp>(user)) {
3289 has_read_only_variables = false;
3290 break;
3291 }
3292 }
3293 if (!has_read_only_variables) continue;
3294 }
3295 if (ba == yield.getOperand(arg_no)) {
3296 unchanged = false;
3297 auto value = while_op.getOperand(arg_no);
3298 ba.replaceAllUsesWith(value);
3299 cond_block.getArgument(arg_no).replaceAllUsesWith(value);
3300
3301 // This could be relaxed and casts inserted.
3302 if (while_op.getResult(arg_no).getType() == value.getType())
3303 while_op.getResult(arg_no).replaceAllUsesWith(value);
3304 }
3305 }
3306
3307 // The While op's operand and result types need to match
3308 SmallVector<Value, 4> new_operands;
3309 SmallVector<Value, 4> new_body_yield;
3310 SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
3311 llvm::SmallVector<Type, 4> types;
3312 new_operands.reserve(while_op.getNumOperands());
3313 new_body_yield.reserve(while_op.getNumOperands());
3314 types.reserve(while_op.getNumOperands());
3315
3316 // Remove block arguments not used in either cond or body. This keeps the
3317 // block arguments of body and cond matching.
3318 int arg_index = 0;
3319 for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
3320 ++while_index) {
3321 auto value = while_op.getOperand(while_index);
3322 if (body_block.getArgument(arg_index).use_empty() &&
3323 cond_block.getArgument(arg_index).use_empty() &&
3324 // Note: since we are not erasing results, need to use while_index
3325 // to check if the corresponding result is unused.
3326 while_op.getResult(while_index).use_empty()) {
3327 unchanged = false;
3328 body_block.eraseArgument(arg_index);
3329 cond_block.eraseArgument(arg_index);
3330
3331 // Mark operand for removal.
3332 removed_operand[while_index] = true;
3333 } else {
3334 new_operands.push_back(value);
3335 new_body_yield.push_back(yield.getOperand(while_index));
3336 auto type = while_op.getResult(while_index).getType();
3337 types.push_back(type);
3338 ++arg_index;
3339 }
3340 }
3341
3342 // Done if no values removed from blocks and operands & results match.
3343 if (unchanged) return failure();
3344
3345 // Replace with new While with matching operands and results.
3346 Operation *op = while_op.getOperation();
3347 Operation *new_op = rewriter.insert(
3348 Operation::create(op->getLoc(), op->getName(), types, new_operands,
3349 op->getAttrs(), {}, /*numRegions=*/2));
3350
3351 for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
3352 int new_index = 0;
3353 for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
3354 if (removed_operand[op_index]) continue;
3355 op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
3356 ++new_index;
3357 }
3358 rewriter.eraseOp(op);
3359
3360 Block &new_body_block = cast<WhileOp>(new_op).body().front();
3361 rewriter.setInsertionPointToEnd(&new_body_block);
3362 rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
3363 new_body_yield);
3364
3365 return success();
3366 }
3367 };
3368
3369 } // namespace
3370
3371 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
3372 MLIRContext *context) {
3373 results.insert<WhileResultOperandsMatchAndImplicitCapture>(context);
3374 }
3375
3376 Region &WhileOp::getLoopBody() { return body(); }
3377
3378 bool WhileOp::isDefinedOutsideOfLoop(Value value) {
3379 // TODO(jpienaar): This is overly conservative and initially disables anything
3380 // other than constant hoisting.
3381 return false;
3382 }
3383
3384 LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
3385 if (ops.empty()) return success();
3386
3387 // Move the hoisted value to just before the while.
3388 Operation *while_op = this->getOperation();
3389 for (auto op : ops) op->moveBefore(while_op);
3390
3391 return success();
3392 }
3393
3394 //===----------------------------------------------------------------------===//
3395 // LogisticOp
3396 //===----------------------------------------------------------------------===//
3397
3398 int64_t LogisticOp::GetArithmeticCount(Operation *op) {
3399 int64_t count;
3400 // As a very rough ballpark, the cost of evaluating a math function
3401 // such as tanh or logistic is about 32 multiplications, and about as
3402 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3403 // from looking at actual implementations that we use in runtime/code).
3404 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3405 return 64 * count;
3406
3407 return -1;
3408 }
3409
3410 //===----------------------------------------------------------------------===//
3411 // LogSoftmaxOp
3412 //===----------------------------------------------------------------------===//
3413
3414 int64_t LogSoftmaxOp::GetArithmeticCount(Operation *op) {
3415 int64_t count;
3416 // As a very rough ballpark, the cost of evaluating a math function
3417 // such as tanh or logistic is about 32 multiplications, and about as
3418 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3419 // from looking at actual implementations that we use in runtime/code).
3420 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3421 return 64 * count;
3422
3423 return -1;
3424 }
3425
3426 //===----------------------------------------------------------------------===//
3427 // SoftmaxOp
3428 //===----------------------------------------------------------------------===//
3429
3430 int64_t SoftmaxOp::GetArithmeticCount(Operation *op) {
3431 int64_t count;
3432 // As a very rough ballpark, the cost of evaluating a math function
3433 // such as tanh or logistic is about 32 multiplications, and about as
3434 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3435 // from looking at actual implementations that we use in runtime/code).
3436 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3437 return 64 * count;
3438
3439 return -1;
3440 }
3441
3442 //===----------------------------------------------------------------------===//
3443 // TanhOp
3444 //===----------------------------------------------------------------------===//
3445
3446 int64_t TanhOp::GetArithmeticCount(Operation *op) {
3447 int64_t count;
3448 // As a very rough ballpark, the cost of evaluating a math function
3449 // such as tanh or logistic is about 32 multiplications, and about as
3450 // many additions/subtractions. (Just a power-of-two order-of-magnitude
3451 // from looking at actual implementations that we use in runtime/code).
3452 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count))
3453 return 64 * count;
3454
3455 return -1;
3456 }
3457
3458 //===----------------------------------------------------------------------===//
3459 // AddNOp
3460 //===----------------------------------------------------------------------===//
3461
3462 int64_t AddNOp::GetArithmeticCount(Operation *op) {
3463 int64_t count;
3464 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3465 // AddN cost is roughly the same cost as N-1 Adds.
3466 const int64_t num_adds = op->getNumOperands() - 1;
3467 return num_adds * count;
3468 }
3469
3470 return -1;
3471 }
3472
3473 //===----------------------------------------------------------------------===//
3474 // AveragePool2DOp
3475 //===----------------------------------------------------------------------===//
3476
3477 int64_t AveragePool2DOp::GetArithmeticCount(Operation *op) {
3478 int64_t count;
3479 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3480 auto avg_pool = llvm::dyn_cast<AveragePool2DOp>(op);
3481 return avg_pool.filter_height() * avg_pool.filter_width() * count;
3482 }
3483
3484 return -1;
3485 }
3486
3487 //===----------------------------------------------------------------------===//
3488 // MaxPool2DOp
3489 //===----------------------------------------------------------------------===//
3490
3491 int64_t MaxPool2DOp::GetArithmeticCount(Operation *op) {
3492 int64_t count;
3493 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3494 auto max_pool = llvm::dyn_cast<MaxPool2DOp>(op);
3495 return max_pool.filter_height() * max_pool.filter_width() * count;
3496 }
3497
3498 return -1;
3499 }
3500
3501 //===----------------------------------------------------------------------===//
3502 // L2NormalizationOp
3503 //===----------------------------------------------------------------------===//
3504
3505 int64_t L2NormalizationOp::GetArithmeticCount(Operation *op) {
3506 int64_t count;
3507 // Computing the squared L2 norm is N multiply-adds, so 2N ops; the single
3508 // inverse-sqrt is negligible. Then we multiply each value by the resulting
3509 // multiplier, an extra N ops, for a total of 3N ops.
3510 if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) {
3511 return 3 * count;
3512 }
3513
3514 return -1;
3515 }
3516
3517 //===----------------------------------------------------------------------===//
3518 // PadOp
3519 //===----------------------------------------------------------------------===//
3520
3521 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
3522 if (InputOutputHasSameShape(input().getType(), output().getType()))
3523 return input();
3524
3525 return {};
3526 }
3527
3528 //===----------------------------------------------------------------------===//
3529 // PadV2Op
3530 //===----------------------------------------------------------------------===//
3531
3532 OpFoldResult PadV2Op::fold(ArrayRef<Attribute> operands) {
3533 if (InputOutputHasSameShape(input().getType(), output().getType()))
3534 return input();
3535
3536 return {};
3537 }
3538
3539 //===----------------------------------------------------------------------===//
3540 // TableGen'd op method definitions
3541 //===----------------------------------------------------------------------===//
3542
3543 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
3544
3545 } // namespace TFL
3546 } // namespace mlir
3547
3548 #define GET_OP_CLASSES
3549 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
3550
3551 namespace mlir {
3552 namespace TFL {
3553
3554 #include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
3555
3556 Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
3557 Attribute value,
3558 Type type, Location loc) {
3559 // If this is an opaque elements attribute or the result type doesn't match
3560 // the attribute type, then generate a tfl.pseudo_const.
3561 if (value.isa<OpaqueElementsAttr>() ||
3562 (value.isa<ElementsAttr>() && value.getType() != type))
3563 return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
3564 if (ConstantOp::isBuildableWith(value, type))
3565 return builder.create<ConstantOp>(loc, type, value);
3566 return nullptr;
3567 }
3568
3569 } // namespace TFL
3570 } // namespace mlir
3571