/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_

#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project

namespace mlir {

class Operation;

namespace TF {

class AddV2Op;
class SubOp;
class MulOp;
class DivOp;
class RealDivOp;

// Verifies a reduction op's `input` and reduction `dims`.
LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
                                          Location loc);

// A type range with description (in singular form) attached to it.
using TypeRangeWithDesc = std::pair<TypeRange, StringRef>;

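// Verifies that the two described type ranges contain compatible types,
// reporting an error on `op` otherwise.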
LogicalResult VerifyTypeRangesAreCompatible(Operation *op,
                                            TypeRangeWithDesc range0,
                                            TypeRangeWithDesc range1);

// Folds an arithmetic op when one of its operands is a constant known to be
// the identity value (e.g. X + 0, X * 1). For commutative operations, the
// identity constant may be either the lhs or the rhs.
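// For example, `tf.AddV2(%x, %zeros)` folds to `%x` when `%zeros` is a splat
// of zero whose shape cannot hide a broadcasting error.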
template <
    typename OpT,
    typename std::enable_if<llvm::is_one_of<
        OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
                                        ArrayRef<Attribute> operands) {
  auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
  auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
  auto result_type =
      arithmetic_op.getResult().getType().template cast<ShapedType>();

  // We can fold the arithmetic operation only if we can prove that we will
  // not accidentally hide a broadcasting error.
  auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
                                  ShapedType result_ty) -> bool {
    // A scalar identity is broadcastable to any operand shape; we only need to
    // check that the operand has the same shape as the result.
    bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
    if (scalar_identity) return operand_ty == result_ty;

    // If the identity is not a scalar, we must verify that its shape is
    // statically known to be broadcastable to the operand shape and that the
    // operand and result shapes are equal.
    return operand_ty == result_ty && identity_ty.hasStaticShape() &&
           result_ty.hasStaticShape() &&
           OpTrait::util::staticallyKnownBroadcastable(operand_ty.getShape(),
                                                       identity_ty.getShape());
  };

  // Check that we have a constant operand on one side (candidate for identity).
  const bool is_commutative =
      (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
  auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (!rhs_attr && !(is_commutative && lhs_attr)) return {};

  // Mul and Div ops have identity value one, while AddV2 and Sub ops have
  // identity value zero.
  const int identity =
      (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
       std::is_same<OpT, RealDivOp>::value)
          ? 1
          : 0;

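  // Materialize the identity value as an attribute of the operands' element
  // type; element types other than float and integer are not folded.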
  Type element_ty = lhs_type.getElementType();
  Attribute identity_attr;
  if (auto ty = element_ty.template dyn_cast<FloatType>()) {
    identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
  } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
    identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
  } else {
    return {};
  }

  // Fold: Op(Operand, Identity) -> Operand.
  if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
    if (rhs_attr.isSplat() &&
        rhs_attr.getSplatValue<Attribute>() == identity_attr)
      return arithmetic_op.x();
  }

  // Fold: Op(Identity, Operand) -> Operand for commutative operations.
  if (lhs_attr && is_commutative &&
      is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
    if (lhs_attr.isSplat() &&
        lhs_attr.getSplatValue<Attribute>() == identity_attr)
      return arithmetic_op.y();
  }

  return {};
}

}  // namespace TF
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_