1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the definitions of the infer op interfaces defined in 10 // `InferTypeOpInterface.td`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 16 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/Location.h" 20 #include "mlir/IR/OpDefinition.h" 21 #include "mlir/Support/LLVM.h" 22 #include "llvm/ADT/SmallVector.h" 23 24 namespace mlir { 25 26 /// ShapedTypeComponents that represents the components of a ShapedType. 27 /// The components consist of 28 /// - A ranked or unranked shape with the dimension specification match those 29 /// of ShapeType's getShape() (e.g., dynamic dimension represented using 30 /// ShapedType::kDynamicSize) 31 /// - A element type, may be unset (nullptr) 32 /// - A attribute, may be unset (nullptr) 33 /// Used by ShapedType type inferences. 34 class ShapedTypeComponents { 35 /// Internal storage type for shape. 36 using ShapeStorageT = SmallVector<int64_t, 3>; 37 38 public: 39 /// Default construction is an unranked shape. ShapedTypeComponents()40 ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){}; ShapedTypeComponents(Type elementType)41 ShapedTypeComponents(Type elementType) 42 : ranked(false), elementType(elementType), attr(nullptr) {} 43 template <typename Arg, typename = typename std::enable_if_t< 44 std::is_constructible<ShapeStorageT, Arg>::value>> 45 ShapedTypeComponents(Arg &&arg, Type elementType = nullptr, 46 Attribute attr = nullptr) dims(std::forward<Arg> (arg))47 : dims(std::forward<Arg>(arg)), ranked(true), elementType(elementType), 48 attr(attr) {} 49 ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr, 50 Attribute attr = nullptr) 51 : dims(vec.begin(), vec.end()), ranked(true), elementType(elementType), 52 attr(attr) {} 53 54 /// Return the dimensions of the shape. 55 /// Requires: shape is ranked. getDims()56 ArrayRef<int64_t> getDims() const { 57 assert(ranked && "requires ranked shape"); 58 return dims; 59 } 60 61 /// Return whether the shape has a rank. hasRank()62 bool hasRank() const { return ranked; }; 63 64 /// Return the element type component. getElementType()65 Type getElementType() const { return elementType; }; 66 67 /// Return the raw attribute component. getAttribute()68 Attribute getAttribute() const { return attr; }; 69 70 private: 71 ShapeStorageT dims; 72 bool ranked; 73 Type elementType; 74 Attribute attr; 75 }; 76 77 namespace detail { 78 // Helper function to infer return tensor returns types given element and shape 79 // inference function. 80 // 81 // TODO: Consider generating typedefs for trait member functions if this usage 82 // becomes more common. 83 LogicalResult inferReturnTensorTypes( 84 function_ref<LogicalResult( 85 MLIRContext *, Optional<Location> location, ValueRange operands, 86 DictionaryAttr attributes, RegionRange regions, 87 SmallVectorImpl<ShapedTypeComponents> &retComponents)> 88 componentTypeFn, 89 MLIRContext *context, Optional<Location> location, ValueRange operands, 90 DictionaryAttr attributes, RegionRange regions, 91 SmallVectorImpl<Type> &inferredReturnTypes); 92 93 /// Verifies that the inferred result types match the actual result types for 94 /// the op. Precondition: op implements InferTypeOpInterface. 95 LogicalResult verifyInferredResultTypes(Operation *op); 96 } // namespace detail 97 98 namespace OpTrait { 99 100 /// Tensor type inference trait that constructs a tensor from the inferred 101 /// shape and elemental types. 102 /// Requires: Op implements functions of InferShapedTypeOpInterface. 103 template <typename ConcreteType> 104 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> { 105 public: 106 static LogicalResult inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)107 inferReturnTypes(MLIRContext *context, Optional<Location> location, 108 ValueRange operands, DictionaryAttr attributes, 109 RegionRange regions, 110 SmallVectorImpl<Type> &inferredReturnTypes) { 111 return ::mlir::detail::inferReturnTensorTypes( 112 ConcreteType::inferReturnTypeComponents, context, location, operands, 113 attributes, regions, inferredReturnTypes); 114 } 115 }; 116 117 } // namespace OpTrait 118 } // namespace mlir 119 120 /// Include the generated interface declarations. 121 #include "mlir/Interfaces/InferTypeOpInterface.h.inc" 122 123 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ 124