• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/Support/LLVM.h"
22 #include "llvm/ADT/SmallVector.h"
23 
24 namespace mlir {
25 
26 /// ShapedTypeComponents that represents the components of a ShapedType.
27 /// The components consist of
28 ///  - A ranked or unranked shape with the dimension specification match those
29 ///    of ShapeType's getShape() (e.g., dynamic dimension represented using
30 ///    ShapedType::kDynamicSize)
31 ///  - A element type, may be unset (nullptr)
32 ///  - A attribute, may be unset (nullptr)
33 /// Used by ShapedType type inferences.
34 class ShapedTypeComponents {
35   /// Internal storage type for shape.
36   using ShapeStorageT = SmallVector<int64_t, 3>;
37 
38 public:
39   /// Default construction is an unranked shape.
ShapedTypeComponents()40   ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){};
ShapedTypeComponents(Type elementType)41   ShapedTypeComponents(Type elementType)
42       : ranked(false), elementType(elementType), attr(nullptr) {}
43   template <typename Arg, typename = typename std::enable_if_t<
44                               std::is_constructible<ShapeStorageT, Arg>::value>>
45   ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
46                        Attribute attr = nullptr)
dims(std::forward<Arg> (arg))47       : dims(std::forward<Arg>(arg)), ranked(true), elementType(elementType),
48         attr(attr) {}
49   ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
50                        Attribute attr = nullptr)
51       : dims(vec.begin(), vec.end()), ranked(true), elementType(elementType),
52         attr(attr) {}
53 
54   /// Return the dimensions of the shape.
55   /// Requires: shape is ranked.
getDims()56   ArrayRef<int64_t> getDims() const {
57     assert(ranked && "requires ranked shape");
58     return dims;
59   }
60 
61   /// Return whether the shape has a rank.
hasRank()62   bool hasRank() const { return ranked; };
63 
64   /// Return the element type component.
getElementType()65   Type getElementType() const { return elementType; };
66 
67   /// Return the raw attribute component.
getAttribute()68   Attribute getAttribute() const { return attr; };
69 
70 private:
71   ShapeStorageT dims;
72   bool ranked;
73   Type elementType;
74   Attribute attr;
75 };
76 
77 namespace detail {
78 // Helper function to infer return tensor returns types given element and shape
79 // inference function.
80 //
81 // TODO: Consider generating typedefs for trait member functions if this usage
82 // becomes more common.
83 LogicalResult inferReturnTensorTypes(
84     function_ref<LogicalResult(
85         MLIRContext *, Optional<Location> location, ValueRange operands,
86         DictionaryAttr attributes, RegionRange regions,
87         SmallVectorImpl<ShapedTypeComponents> &retComponents)>
88         componentTypeFn,
89     MLIRContext *context, Optional<Location> location, ValueRange operands,
90     DictionaryAttr attributes, RegionRange regions,
91     SmallVectorImpl<Type> &inferredReturnTypes);
92 
93 /// Verifies that the inferred result types match the actual result types for
94 /// the op. Precondition: op implements InferTypeOpInterface.
95 LogicalResult verifyInferredResultTypes(Operation *op);
96 } // namespace detail
97 
98 namespace OpTrait {
99 
100 /// Tensor type inference trait that constructs a tensor from the inferred
101 /// shape and elemental types.
102 /// Requires: Op implements functions of InferShapedTypeOpInterface.
103 template <typename ConcreteType>
104 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
105 public:
106   static LogicalResult
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)107   inferReturnTypes(MLIRContext *context, Optional<Location> location,
108                    ValueRange operands, DictionaryAttr attributes,
109                    RegionRange regions,
110                    SmallVectorImpl<Type> &inferredReturnTypes) {
111     return ::mlir::detail::inferReturnTensorTypes(
112         ConcreteType::inferReturnTypeComponents, context, location, operands,
113         attributes, regions, inferredReturnTypes);
114   }
115 };
116 
117 } // namespace OpTrait
118 } // namespace mlir
119 
120 /// Include the generated interface declarations.
121 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
122 
123 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
124