1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
18
#include <complex>
#include <cstdint>
#include <string>

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
24
25 namespace mlir {
26 namespace hlo {
27
28 // Computes the broadcast dimensions attr for an elementwise binary operator
29 // between two ranked tensors.
30 // If `allow_empty` is true, then null can be returned to mean that the
31 // broadcast is an "identity".
32 mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
33 mlir::Value x,
34 mlir::Value y,
35 bool allow_empty = true);
36
37 // Get a constant splat for the given value of type. Requires value to be of
38 // type static shaped RankedTensorType.
39 template <typename T>
getSplat(Builder * b,RankedTensorType ty,T constant)40 static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
41 Type element_ty = getElementTypeOrSelf(ty);
42
43 if (element_ty.isSignlessInteger())
44 return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant));
45
46 if (element_ty.isa<FloatType>())
47 return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant));
48
49 if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
50 auto complex_element_ty = complex_ty.getElementType();
51 if (complex_element_ty.isF32())
52 return DenseElementsAttr::get(ty,
53 static_cast<std::complex<float>>(constant));
54 if (complex_element_ty.isF64())
55 return DenseElementsAttr::get(
56 ty, static_cast<std::complex<double>>(constant));
57 }
58 llvm_unreachable("unhandled element type");
59 }
60
61 template <typename T>
getSplat(Builder * b,Value val,T constant)62 static ElementsAttr getSplat(Builder* b, Value val, T constant) {
63 return getSplat(b, val.getType().cast<RankedTensorType>(), constant);
64 }
65
66 // Returns DenseElementsAttr of rank zero with the given element type and the
67 // value.
68 // Requires `ty` to be either FloatType, IntegerType, or ComplexType.
69 DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
70
// Selects which limit value GetScalarLimitOfType returns.
enum ScalarLimit {
  kLowest,          // The scalar corresponding to numeric_limits<T>::lowest.
  kInfinityLowest,  // Like kLowest, but returns -infinity where available.
  kMax,             // The scalar corresponding to numeric_limits<T>::max.
  kInfinityMax,     // Like kMax, but returns infinity where available.
};
78
79 // Returns a scalar limit value for the given type.
80 //
81 // The argument 'limit' describes which scalar value to return.
82 //
83 // Requires `ty` to be either FloatType or IntegerType.
84 DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
85
86 // Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
87 // Returns empty string if no such op exists.
88 std::string LmhloToMhloOpName(llvm::StringRef op_name,
89 mlir::MLIRContext* context);
90
91 // Return true if Attr has values [0, 1, ...].
92 bool IsSequenceStartingWith0(DenseIntElementsAttr attr);
93
94 } // namespace hlo
95 } // namespace mlir
96
97 #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
98