/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef HLO_OPS_BASE
#define HLO_OPS_BASE

include "mlir/IR/OpBase.td"

// The MHLO dialect. Generated C++ entities are placed in ::mlir::mhlo.
def HLO_Dialect : Dialect {
  let name = "mhlo";
  let cppNamespace = "::mlir::mhlo";
}

// These includes come after HLO_Dialect is defined.
// NOTE(review): presumably because they reference HLO_Dialect — confirm.
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"

// Predicate element type: an alias of i1.
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;

// Integer element types of width 8, 16, 32 or 64.
//
// Note that, despite its name, HLO_SInt matches *signless* integers, not
// signed ones. HLO_UInt matches unsigned integers, and HLO_Int accepts either.
//
// TODO(hinsu): Use signed integers instead of signless integers, which are
// used for legacy reasons.
def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

// Complex element type whose components are f32 or f64.
def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;

// The broadcasting dimensions correspond to a tuple that describes how a
// smaller-rank shape is broadcast into a larger-rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, the broadcasting tuple (1, 2) maps
// the matrix onto dimensions 1 and 2 of the cuboid.
defvar BroadcastDimAttr = I64ElementsAttr;

//===----------------------------------------------------------------------===//
// MHLO on tensors type definitions.
//===----------------------------------------------------------------------===//

// Token type.
// The C++ class in the predicate is fully qualified, matching the other C++
// predicates in this file. This predicate is also reached through HLO_Tuple,
// HLO_TensorOrTuple and HLO_TensorOrTokenOrTuple. If any of those is emitted
// into code outside ::mlir::mhlo, an unqualified `TokenType` would not
// resolve.
def HLO_Token : Type<CPred<"$_self.isa<::mlir::mhlo::TokenType>()">, "token">;

// Any integer tensor type (element type: HLO_Int, i.e. signless or unsigned).
def HLO_IntTensor : TensorOf<[HLO_Int]>;

// Any integer tensor type with rank 0 (i.e. representing a single integer).
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;

// Any floating-point tensor type.
def HLO_FpTensor : TensorOf<[AnyFloat]>;

// Tensor of predicate (i1) elements.
def HLO_PredTensor : TensorOf<[HLO_Pred]>;

// Tensor of any element type accepted here: float, pred, integer or complex.
def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;

// Tensor of complex elements.
def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;

// Tuple whose (possibly nested) elements are HLO tensors or tokens.
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;

def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;

def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;

// Element types accepted for a dimension value: index, pred or integer.
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;

// Dynamic representation of a shape vector as a tensor.
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;

// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
      AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;

//===----------------------------------------------------------------------===//
// MHLO on tensors combined type definitions.
//===----------------------------------------------------------------------===//

// Tensors whose elements are integers or floats.
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;

// Tensors whose elements are predicates or integers.
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;

// Tensors whose elements are floats or complex numbers.
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;

// Tensors whose elements are integers, floats or complex numbers.
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;

// Tensors whose elements are predicates, integers or floats.
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;

// Layout attribute: an IndexElementsAttr additionally constrained to rank 1.
// The storage type, return type and storage conversion are all taken from
// IndexElementsAttr; only the predicate is extended, with the rank check.
def HLO_LayoutAttr : Attr<
    And<[
      IndexElementsAttr.predicate,
      CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
                 == 1}]>
    ]>,
    "A 1D tensor of index type (layout)"> {
  let convertFromStorage = IndexElementsAttr.convertFromStorage;
  let returnType = IndexElementsAttr.returnType;
  let storageType = IndexElementsAttr.storageType;
}

//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//

// TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
    OptionalAttr<
          TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;

// Elements attribute whose element type is i1. The value must be a
// DenseIntOrFPElementsAttr (checked first, so the cast in the second predicate
// is safe) and is exposed to C++ as a ::mlir::DenseElementsAttr.
def BoolElementsAttr :
    ElementsAttrBase<
      And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
           CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
      "constant boolean vector/tensor attribute"> {
  let storageType = [{ ::mlir::DenseElementsAttr }];
  let returnType = [{ ::mlir::DenseElementsAttr }];

  let convertFromStorage = "$_self";
}

// Attribute arguments shared by convolution ops, bundled as one `ins` dag.
// NOTE(review): presumably concatenated into op argument lists elsewhere —
// confirm against the ops that reference ConvolutionAttributes.attributes.
def ConvolutionAttributes {
  dag attributes = (ins
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    // Default value: false for each spatial dimension (this is a boolean
    // attribute, so the window is not reversed unless requested).
    OptionalAttr<BoolElementsAttr>:$window_reversal,
    ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    HLO_PrecisionConfigAttr:$precision_config
  );
}

// Empty class with no fields.
// NOTE(review): looks vestigial — check for remaining users before removing.
class BASE_HLO_ConvOp {
}

//===----------------------------------------------------------------------===//
// Common traits
//===----------------------------------------------------------------------===//

// Native op trait whose C++ implementation lives in ::mlir::mhlo::OpTrait.
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
  let cppNamespace = "::mlir::mhlo::OpTrait";
}

// An operation that is essentially element-wise but may implement broadcasting
// semantics.
def HLO_BroadcastingElementwise :
    HLO_NativeOpTrait<"BroadcastingElementwise">;

// Pairwise operand/result type matching: an op carrying this trait has exactly
// as many operands as results, and operand #i must have the same type as
// result #i.
// TODO(b/195086460) Promote this to be an mlir trait and remove it here.
def HLO_PairwiseSameOperandAndResultType :
    HLO_NativeOpTrait<"PairwiseSameOperandAndResultType">;

#endif  // HLO_OPS_BASE