// NOTE(review): non-code navigation chrome from a code-browser scrape was
// removed here; it was not part of the original hlo_ops_base.td source.
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#ifndef HLO_OPS_BASE
#define HLO_OPS_BASE

include "mlir/IR/OpBase.td"

// The MHLO dialect: op definitions in this file hang off this dialect, and
// the generated C++ is placed in the ::mlir::mhlo namespace.
def HLO_Dialect : Dialect {
  let name = "mhlo";
  let cppNamespace = "::mlir::mhlo";
}

include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
28
// Predicate type: a 1-bit integer (the MHLO boolean).
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;

// TODO(hinsu): Use signed integers instead of signless integer which is being
// used for legacy reasons.
def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
// Any MHLO integer: signless or unsigned, of width 8/16/32/64.
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

// Complex element type with an f32 or f64 component type.
def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;

// The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
// matching the matrix to dimensions 1 and 2 of the cuboid.
defvar BroadcastDimAttr = I64ElementsAttr;
44
//===----------------------------------------------------------------------===//
// MHLO on tensors type definitions.
//===----------------------------------------------------------------------===//

// Token type, used for ordering side-effecting ops (predicate on TokenType).
def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;

// Any integer tensor types
def HLO_IntTensor : TensorOf<[HLO_Int]>;

// Any integer tensor type with rank 0 (i.e. representing a single integer).
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;

// Any floating-point tensor types
def HLO_FpTensor : TensorOf<[AnyFloat]>;

// Any predicate (i1) tensor type.
def HLO_PredTensor : TensorOf<[HLO_Pred]>;

// The general MHLO tensor: float, pred, integer, or complex elements.
def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;

// Any complex tensor type.
def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;

// Possibly-nested tuples of tensors and tokens.
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;

def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;

def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;

// Scalar value usable as a dimension size: index, pred, or integer.
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;

// Dynamic representation of a shape vector as a tensor.
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;

// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
      AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
82
//===----------------------------------------------------------------------===//
// MHLO on tensors combined type definitions.
//===----------------------------------------------------------------------===//

// Any integer or floating-point tensor types
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;

// Any integer or predicate tensor types
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;

// Any floating-point or complex tensor types
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;

// Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;

// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;

// A layout attribute (1D tensor of index type). Reuses IndexElementsAttr's
// storage/return machinery but additionally constrains the rank to 1.
def HLO_LayoutAttr : Attr<
  And<[IndexElementsAttr.predicate,
       CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
               == 1}]>]>,
  "A 1D tensor of index type (layout)"> {
  let storageType = IndexElementsAttr.storageType;
  let returnType = IndexElementsAttr.returnType;
  let convertFromStorage = IndexElementsAttr.convertFromStorage;
}
112
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//

// TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
    OptionalAttr<
          TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;

// Dense elements attribute whose element type is i1 (boolean vector/tensor).
def BoolElementsAttr :
    ElementsAttrBase<
      And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
           CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
      "constant boolean vector/tensor attribute"> {
  let storageType = [{ ::mlir::DenseElementsAttr }];
  let returnType = [{ ::mlir::DenseElementsAttr }];

  let convertFromStorage = "$_self";
}

// Attribute list shared by convolution-like ops; spliced into op arguments
// via its `attributes` dag.
def ConvolutionAttributes {
  dag attributes = (ins
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    // Default value: false for each of the spatial dimension.
    OptionalAttr<BoolElementsAttr>:$window_reversal,
    ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    HLO_PrecisionConfigAttr:$precision_config
  );
}

// Empty base marker class for convolution ops.
class BASE_HLO_ConvOp {
}
154
//===----------------------------------------------------------------------===//
// Common traits
//===----------------------------------------------------------------------===//

// Native op trait whose C++ implementation lives in ::mlir::mhlo::OpTrait.
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
  let cppNamespace = "::mlir::mhlo::OpTrait";
}

// An operation that is essentially element-wise but may implement broadcasting
// semantics.
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;

// Op has pairwise operand and result type matching: the number of operands
// must be equal to the number of results and the type of ith operand must
// match the type of ith result.
// TODO(b/195086460) Promote this to be an mlir trait and remove it here.
def HLO_PairwiseSameOperandAndResultType :
  HLO_NativeOpTrait<"PairwiseSameOperandAndResultType">;

#endif // HLO_OPS_BASE
175