/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #ifndef MLIR_HLO_UTILS_HLO_UTILS_H
17 #define MLIR_HLO_UTILS_HLO_UTILS_H
18
#include <complex>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
27
28 namespace mlir {
29 namespace hlo {
30 // Computes the broadcast dimensions attr for an elementwise binary operator
31 // between two ranked tensors.
32 // If `allow_empty` is true, then null can be returned to mean that the
33 // broadcast is an "identity".
34 mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
35 mlir::Value x,
36 mlir::Value y,
37 bool allowEmpty = true);
38
39 // Get a constant splat for the given value of type. Requires value to be of
40 // type static shaped RankedTensorType.
41 template <typename T>
getSplat(Builder * b,RankedTensorType ty,T constant)42 static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
43 Type elementTy = getElementTypeOrSelf(ty);
44
45 if (elementTy.isSignlessInteger())
46 return DenseElementsAttr::get(ty, b->getIntegerAttr(elementTy, constant));
47
48 if (elementTy.isa<FloatType>())
49 return DenseElementsAttr::get(ty, b->getFloatAttr(elementTy, constant));
50
51 if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
52 auto complexElementTy = complexTy.getElementType();
53 if (complexElementTy.isF32())
54 return DenseElementsAttr::get(ty,
55 static_cast<std::complex<float>>(constant));
56 if (complexElementTy.isF64())
57 return DenseElementsAttr::get(
58 ty, static_cast<std::complex<double>>(constant));
59 }
60 llvm_unreachable("unhandled element type");
61 }
62
63 template <typename T>
getSplat(Builder * b,Value val,T constant)64 static ElementsAttr getSplat(Builder* b, Value val, T constant) {
65 return getSplat(b, val.getType().cast<RankedTensorType>(), constant);
66 }
67
68 // Returns DenseElementsAttr of rank zero with the given element type and the
69 // value.
70 // Requires `ty` to be either FloatType, IntegerType, or ComplexType.
71 DenseElementsAttr getScalarOfType(Type ty, int64_t rawValue);
72
73 // Returns DenseElementsAttr of rank zero with the given element type and the
74 // value which is the neutral element for additions.
75 // Requires `ty` to be either FloatType, IntegerType, or ComplexType.
76 DenseElementsAttr getScalarNegZeroOfType(Type ty);
77
// Enum type used to specify the scalar limit argument to
// getScalarLimitOfType.
enum ScalarLimit {
  kLowest,          // The scalar corresponding to numeric_limits<T>::lowest.
  kInfinityLowest,  // Like kLowest, but returns -infinity where available.
  kMax,             // The scalar corresponding to numeric_limits<T>::max.
  kInfinityMax,     // Like kMax, but returns infinity where available.
};
85
86 // Returns a scalar limit value for the given type.
87 //
88 // The argument 'limit' describes which scalar value to return.
89 //
90 // Requires `ty` to be either FloatType or IntegerType.
91 DenseElementsAttr getScalarLimitOfType(Type ty, ScalarLimit limit);
92
93 // Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
94 // Returns empty string if no such op exists.
95 std::string lmhloToMhloOpName(llvm::StringRef opName,
96 mlir::MLIRContext* context);
97
98 // Return true if Attr has values [0, 1, ...].
99 bool isSequenceStartingWith0(Attribute attr);
100
101 // Returns the argument index for the giving FuncOp and its operand value.
102 int64_t getArgumentIndex(func::FuncOp op, Value value);
103
104 /// Computes the memory usage of the given allocations.
105 std::pair<size_t, size_t> computeMemory(const std::vector<Value>& allocs);
106
107 } // namespace hlo
108 } // namespace mlir
109
110 #endif // MLIR_HLO_UTILS_HLO_UTILS_H
111