/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Copyright 2022 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_DIALECT_BASE
#define STABLEHLO_DIALECT_BASE

include "mlir/Dialect/Quant/QuantOpsBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// HLO type definitions.
//===----------------------------------------------------------------------===//

def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;

// TODO(hinsu): Use signed integers instead of signless integers, which are
// being used for legacy reasons.
def HLO_SInt : SignlessIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

def HLO_Float : AnyTypeOf<[F16, F32, F64, BF16]>;
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;

def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;

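// As a quick reference (illustrative only, assuming standard MLIR builtin type
// syntax), some element types that satisfy the constraints above:
//
//   i1            -> HLO_Pred
//   i32           -> HLO_SInt   (signless, per the TODO above)
//   ui16          -> HLO_UInt
//   f32, bf16     -> HLO_Float
//   complex<f64>  -> HLO_Complex
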
//===----------------------------------------------------------------------===//
// Quantized element type definitions.
//===----------------------------------------------------------------------===//

// TODO(b/230381284): Upstream width-specific uniform quantized element types.
class UniformQuantizedSignedInt<int width>
  : Type<Or<[
      And<[CPred<"$_self.isa<mlir::quant::UniformQuantizedType>()">,
           CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
                 ".getStorageTypeIntegralWidth() == " # width>,
           CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
                 ".isSigned()">]>,
      And<[CPred<"$_self.isa<mlir::quant::UniformQuantizedPerAxisType>()">,
           CPred<"$_self.cast<mlir::quant::UniformQuantizedPerAxisType>()" #
                 ".getStorageTypeIntegralWidth() == " # width>,
           CPred<"$_self.cast<mlir::quant::UniformQuantizedPerAxisType>()" #
                 ".isSigned()">]>]>,
    "QI" # width # " type"> {
  string name = "UniformQuantizedSignedInt";
  int bitwidth = width;
}

class UniformQuantizedUnsignedInt<int width>
  : Type<Or<[
      And<[CPred<"$_self.isa<mlir::quant::UniformQuantizedType>()">,
           CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
                 ".getStorageTypeIntegralWidth() == " # width>,
           CPred<"!$_self.cast<mlir::quant::UniformQuantizedType>()" #
                 ".isSigned()">]>,
      And<[CPred<"$_self.isa<mlir::quant::UniformQuantizedPerAxisType>()">,
           CPred<"$_self.cast<mlir::quant::UniformQuantizedPerAxisType>()" #
                 ".getStorageTypeIntegralWidth() == " # width>,
           CPred<"!$_self.cast<mlir::quant::UniformQuantizedPerAxisType>()" #
                 ".isSigned()">]>]>,
    "QUI" # width # " type"> {
  string name = "UniformQuantizedUnsignedInt";
  int bitwidth = width;
}

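// For illustration (assuming the quant dialect's pretty-printed type syntax),
// element types matched by the classes above include:
//
//   !quant.uniform<i8:f32, 0.5:-1>              -> UniformQuantizedSignedInt<8>
//   !quant.uniform<u8:f32, 0.5:128>             -> UniformQuantizedUnsignedInt<8>
//   !quant.uniform<i8:f32:1, {0.5:-1, 0.25:0}>  -> UniformQuantizedSignedInt<8>
//                                                  (per-axis, quantized along dim 1)
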
class UniformQuantizedSignedIntOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, UniformQuantizedSignedInt<w>),
              !interleave(widths, "/") # "-bit uniform quantized signed " #
              "integer">;

class UniformQuantizedUnsignedIntOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, UniformQuantizedUnsignedInt<w>),
              !interleave(widths, "/") # "-bit uniform quantized unsigned " #
              "integer">;

// Integer-based uniform quantized types. These definitions can be used to
// specify operand tensor types.
def HLO_QuantizedSignedInt : UniformQuantizedSignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedUnsignedInt : UniformQuantizedUnsignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedInt :
    AnyTypeOf<[HLO_QuantizedSignedInt, HLO_QuantizedUnsignedInt]>;

// The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
// matching the matrix to dimensions 1 and 2 of the cuboid.
defvar BroadcastDimAttr = I64ElementsAttr;
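
// A minimal sketch of the cuboid example above (illustrative only, assuming
// the generic form of stablehlo.broadcast_in_dim): the 3x4 matrix maps onto
// dimensions 1 and 2 of the 2x3x4 result.
//
//   %cuboid = "stablehlo.broadcast_in_dim"(%matrix) {
//     broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
//   } : (tensor<3x4xf32>) -> tensor<2x3x4xf32>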

// Token type.
def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;

// Any integer tensor type.
def HLO_IntTensor : TensorOf<[HLO_Int]>;

// Any integer tensor type with rank 0 (i.e. representing a single integer).
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;

// Any floating-point tensor type.
def HLO_FpTensor : TensorOf<[HLO_Float]>;

// 32- or 64-bit floating-point tensor type.
def HLO_Fp32Or64Tensor : TensorOf<[HLO_Float32Or64]>;

// Any quantized integer tensor type.
def HLO_QuantizedIntTensor : TensorOf<[HLO_QuantizedInt]>;

def HLO_PredTensor : TensorOf<[HLO_Pred]>;

def HLO_Tensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;

def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;

def HLO_TensorOrToken : AnyTypeOf<[HLO_Tensor, HLO_Token]>;

def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
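
// For reference (illustrative only, assuming this dialect's token type prints
// as !stablehlo.token), examples of types matched by the definitions above:
//
//   tensor<2x3xf32>                         -> HLO_Tensor
//   !stablehlo.token                        -> HLO_Token
//   tuple<tensor<3xi32>, !stablehlo.token>  -> HLO_Tuple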

def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Int]>;

// Dynamic representation of a shape vector as a tensor.
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
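
// A minimal sketch of how such a shape tensor is typically consumed
// (illustrative only, assuming the generic form of stablehlo.dynamic_reshape):
//
//   %result = "stablehlo.dynamic_reshape"(%operand, %output_shape)
//       : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>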

// In general, static shaped tensor constraints should be avoided unless
// they are for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
      HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

//===----------------------------------------------------------------------===//
// HLO combined type definitions.
//===----------------------------------------------------------------------===//

// Any integer or floating-point tensor type.
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, HLO_Float]>;

// Any predicate or integer tensor type.
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;

// Any floating-point or complex tensor type.
def HLO_FpOrComplexTensor : TensorOf<[HLO_Float, HLO_Complex]>;

// Any integer, floating-point, or complex tensor type.
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex]>;

// Any predicate, integer, or floating-point tensor type.
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>;

//===----------------------------------------------------------------------===//
// HLO traits
//===----------------------------------------------------------------------===//

class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
  let cppNamespace = "::mlir::hlo::OpTrait";
}

// An operation that is essentially element-wise but may implement broadcasting
// semantics.
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;

// Op has pairwise operand and result type matching: the number of operands
// must be equal to the number of results, and the type of the ith operand must
// match the type of the ith result.
// TODO(b/195086460): Promote this to be an MLIR trait and remove it here.
def HLO_PairwiseSameOperandAndResultType :
  HLO_NativeOpTrait<"PairwiseSameOperandAndResultType">;

// Op has operand and result types compatible with each other according to
// the rules implemented in isCompatibleForHloTypeInference, which account for
// special properties such as dynamism, quantization, and sparsity.
def HLO_CompatibleOperandsAndResultType : TraitList<
  // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
  // support quantization or sparsity.
  [
    InferTypeOpInterface,
    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]>,
    HLO_NativeOpTrait<"CompatibleOperandsAndResultType">
  ]>;
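
// A minimal sketch of how an op definition might attach these traits
// (hypothetical op and dialect handle, for illustration only; the dialect
// itself is defined elsewhere):
//
//   def HLO_ExampleUnaryOp : Op<HLO_Dialect, "example_unary",
//       [HLO_CompatibleOperandsAndResultType]> {
//     let arguments = (ins HLO_Tensor:$operand);
//     let results = (outs HLO_Tensor:$result);
//   }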

def HLO_BoundedAttrInterface : AttrInterface<"BoundedAttrInterface"> {
  let cppNamespace = "::mlir::hlo";

  let description = [{
    This interface is used for attributes that carry bounds for the dimension
    sizes of an accompanying shaped type, e.g. when the attribute is used as
    the encoding of a RankedTensorType (RankedTensorType::getEncoding).
    The number of bounds is expected to be the same as the number of dimensions
    in the accompanying shaped type.
    For a static dimension, the corresponding bound is ShapedType::kDynamicSize.
    For a dynamic dimension, the corresponding bound is either known and is
    a non-negative number, or unknown and is ShapedType::kDynamicSize.
  }];

  let methods = [InterfaceMethod<
    "Get the attribute's bounds",
    "::llvm::ArrayRef<int64_t>", "getBounds"
  >];
}
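
// A worked example of the bounds convention above (illustrative only, assuming
// an encoding attribute that implements this interface and carries
// bounds = [kDynamicSize, 10]):
//
//   tensor<3x?xf32, #encoding>  with  getBounds() == {kDynamicSize, 10}
//
// means dimension 0 is static (size 3, so its bound slot is kDynamicSize) and
// dimension 1 is dynamic but known to hold at most 10 elements.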

#endif // STABLEHLO_DIALECT_BASE