1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_
18
#include <complex>
#include <cstdint>
#include <type_traits>

#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
21
22 namespace mlir {
23 namespace TF {
24
25 // Returns int, float or complex DenseElementsAttr with scalar shape with the
26 // given element type and the integer value.
27 template <typename T>
GetScalarOfType(Type ty,T raw_value)28 DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
29 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
30 if (auto float_ty = ty.dyn_cast<FloatType>()) {
31 FloatAttr attr = FloatAttr::get(float_ty, raw_value);
32 return DenseElementsAttr::get(scalar_ty, attr);
33 } else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
34 IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
35 return DenseElementsAttr::get(scalar_ty, attr);
36 } else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
37 Type complex_element_ty = complex_ty.getElementType();
38 if (complex_element_ty.isF32()) {
39 return DenseElementsAttr::get(
40 scalar_ty, static_cast<std::complex<float>>(raw_value));
41 } else if (complex_element_ty.isF64()) {
42 return DenseElementsAttr::get(
43 scalar_ty, static_cast<std::complex<double>>(raw_value));
44 }
45 }
46 llvm_unreachable("unsupported type");
47 }
48
49 // Returns true if `value` is compile-time constant and its splat value equals
50 // to `raw_value`.
51 template <typename T>
IsConstantValueOf(Value value,T raw_value)52 bool IsConstantValueOf(Value value, T raw_value) {
53 auto element_type = value.getType().cast<ShapedType>().getElementType();
54 if (element_type.isa<FloatType>()) {
55 DenseFPElementsAttr float_attr;
56 if (matchPattern(value, m_Constant(&float_attr)) && float_attr.isSplat() &&
57 float_attr.getSplatValue<APFloat>().isExactlyValue(raw_value))
58 return true;
59 } else if (element_type.isa<IntegerType>()) {
60 DenseIntElementsAttr int_attr;
61 if (matchPattern(value, m_Constant(&int_attr)) && int_attr.isSplat() &&
62 int_attr.getSplatValue<APInt>() == raw_value)
63 return true;
64 }
65
66 return false;
67 }
68
// Returns true if `op` is placed on a GPU device. Returns false if it is
// placed on any other device or if no device is specified.
// NOTE(review): this is only a declaration; the definition lives in a separate
// translation unit that is not shown here. Presumably it inspects the op's
// device assignment (e.g. a `device` attribute). Confirm against the .cc file.
bool IsOnGpuDevice(mlir::Operation *op);
72
73 } // namespace TF
74 } // namespace mlir
75
76 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_
77