• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_
18 
19 #include "llvm/ADT/StringRef.h"
20 #include "mlir/Dialect/Traits.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
24 #include "mlir/IR/Location.h"  // from @llvm-project
25 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
26 #include "mlir/IR/TypeRange.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 
29 namespace mlir {
30 
31 class Operation;
32 
33 namespace TF {
34 
35 class AddV2Op;
36 class SubOp;
37 class MulOp;
38 class DivOp;
39 class RealDivOp;
40 
41 // Verifies an reduction op's `input` and reduction `dims`.
42 LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
43                                           Location loc);
44 
45 // A type range with description (in singular form) attached to it.
46 using TypeRangeWithDesc = std::pair<TypeRange, StringRef>;
47 
48 LogicalResult VerifyTypeRangesAreCompatible(Operation *op,
49                                             TypeRangeWithDesc range0,
50                                             TypeRangeWithDesc range1);
51 
52 // Fold Arithmetic Op if one of the operands is a constant known to be an
53 // Identity (e.g. X+0, X*1, etc...). For commutative operations fold if
54 // known identity value is either lhs or rhs.
55 template <
56     typename OpT,
57     typename std::enable_if<llvm::is_one_of<
58         OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
IdentityArithmeticOpFolder(OpT arithmetic_op,ArrayRef<Attribute> operands)59 OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
60                                         ArrayRef<Attribute> operands) {
61   auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
62   auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
63   auto result_type =
64       arithmetic_op.getResult().getType().template cast<ShapedType>();
65 
66   // We can fold arithmetic operation only of we can prove that we will not
67   // accidentally hide a broadcasting error.
68   auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
69                                   ShapedType result_ty) -> bool {
70     // Scalar identity is broadcastable to any operand shape, we only need to
71     // check that operand has the same shape as a result.
72     bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
73     if (scalar_identity) return operand_ty == result_ty;
74 
75     // If identity is not a scalar, we must verify that identity shape is
76     // statically known to be broadcastable to the operand shape and the operand
77     // and result shape are equal.
78     return operand_ty == result_ty && identity_ty.hasStaticShape() &&
79            result_ty.hasStaticShape() &&
80            OpTrait::util::staticallyKnownBroadcastable(operand_ty.getShape(),
81                                                        identity_ty.getShape());
82   };
83 
84   // Check that we have a constant operand on one side (candidate for identity).
85   const bool is_commutative =
86       (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
87   auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
88   auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
89   if (!rhs_attr && !(is_commutative && lhs_attr)) return {};
90 
91   // Mul and Div ops have identity value one while AddV2 and SubOp have identity
92   // value zero.
93   const int identity =
94       (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
95        std::is_same<OpT, RealDivOp>::value)
96           ? 1
97           : 0;
98 
99   Type element_ty = lhs_type.getElementType();
100   Attribute identity_attr;
101   if (auto ty = element_ty.template dyn_cast<FloatType>()) {
102     identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
103   } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
104     identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
105   } else {
106     return {};
107   }
108 
109   // Fold: Op(Operand, Identity) -> Operand.
110   if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
111     if (rhs_attr.isSplat() &&
112         rhs_attr.getSplatValue<Attribute>() == identity_attr)
113       return arithmetic_op.x();
114   }
115 
116   // Fold: Op(Identity, Operand) -> Operand for commutative operations.
117   if (lhs_attr && is_commutative &&
118       is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
119     if (lhs_attr.isSplat() &&
120         lhs_attr.getSplatValue<Attribute>() == identity_attr)
121       return arithmetic_op.y();
122   }
123 
124   return {};
125 }
126 
127 }  // namespace TF
128 }  // namespace mlir
129 
130 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_
131