• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
18 
#include <complex>
#include <cstdint>
#include <string>

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
24 
25 namespace mlir {
26 namespace hlo {
27 // Attrs for OP type
28 // TODO(disc): create and move to placement_utils.h
29 constexpr llvm::StringRef kDiscShapeCalcAttr = "disc.shape_op";
30 
31 // Attrs for placement
32 constexpr llvm::StringRef kDiscPlaceAssignment = "disc.device";
33 constexpr llvm::StringRef kCpu = "cpu";
34 constexpr llvm::StringRef kGpu = "gpu";
35 enum class PlacementType {
36   kCpu,
37   kGpu,
38 };
39 
40 // Function arguments and results placement attributes.
41 constexpr StringRef kInputPlacementAttr = "hlo.input_placements";
42 constexpr StringRef kOutputPlacementAttr = "hlo.output_placements";
43 
44 // Computes the broadcast dimensions attr for an elementwise binary operator
45 // between two ranked tensors.
46 // If `allow_empty` is true, then null can be returned to mean that the
47 // broadcast is an "identity".
48 mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
49                                                       mlir::Value x,
50                                                       mlir::Value y,
51                                                       bool allow_empty = true);
52 
53 // Get a constant splat for the given value of type. Requires value to be of
54 // type static shaped RankedTensorType.
55 template <typename T>
getSplat(Builder * b,RankedTensorType ty,T constant)56 static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
57   Type element_ty = getElementTypeOrSelf(ty);
58 
59   if (element_ty.isSignlessInteger())
60     return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant));
61 
62   if (element_ty.isa<FloatType>())
63     return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant));
64 
65   if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
66     auto complex_element_ty = complex_ty.getElementType();
67     if (complex_element_ty.isF32())
68       return DenseElementsAttr::get(ty,
69                                     static_cast<std::complex<float>>(constant));
70     if (complex_element_ty.isF64())
71       return DenseElementsAttr::get(
72           ty, static_cast<std::complex<double>>(constant));
73   }
74   llvm_unreachable("unhandled element type");
75 }
76 
77 template <typename T>
getSplat(Builder * b,Value val,T constant)78 static ElementsAttr getSplat(Builder* b, Value val, T constant) {
79   return getSplat(b, val.getType().cast<RankedTensorType>(), constant);
80 }
81 
82 // Returns DenseElementsAttr of rank zero with the given element type and the
83 // value.
84 // Requires `ty` to be either FloatType, IntegerType, or ComplexType.
85 DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
86 
87 // Enum type used to specify scalar argument to GetScalarLimitOfType.
88 enum ScalarLimit {
89   kLowest,          // The scalar corresponding to numeric_limits<T>::lowest.
90   kInfinityLowest,  // Like kLowest, but returns -infinity where available.
91   kMax,             // The scalar corresponding to numeric_limits<T>::max.
92   kInfinityMax,     // Like kMax, but returns infinity where available.
93 };
94 
95 // Returns a scalar limit value for the given type.
96 //
97 // The argument 'limit' describes which scalar value to return.
98 //
99 // Requires `ty` to be either FloatType or IntegerType.
100 DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
101 
102 // Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
103 // Returns empty string if no such op exists.
104 std::string LmhloToMhloOpName(llvm::StringRef op_name,
105                               mlir::MLIRContext* context);
106 
107 // Return true if Attr has values [0, 1, ...].
108 bool IsSequenceStartingWith0(DenseIntElementsAttr attr);
109 
110 // Returns the argument index for the giving FuncOp and its operand value.
111 int64_t getArgumentIndex(mlir::FuncOp op, Value value);
112 
113 }  // namespace hlo
114 }  // namespace mlir
115 
116 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
117