• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2   Copyright 2022 The StableHLO Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15==============================================================================*/
16
17#ifndef STABLEHLO_DIALECT_STABLEHLO_ATTRS
18#define STABLEHLO_DIALECT_STABLEHLO_ATTRS
19
20include "mlir/IR/OpBase.td"
21include "mlir/IR/TensorEncoding.td"
22
// Attribute parameter holding a list of `int64_t` values (described as
// "Dimension"); used throughout this file for dimension/index lists.
def StableHLO_Dim
    : ArrayRefParameter</*arrayOf=*/"int64_t", /*desc=*/"Dimension">;
24
// Dimension numbers for scatter: three dimension lists plus the index vector
// dimension. The textual parser/printer is written in C++
// (hasCustomAssemblyFormat).
def StableHLO_ScatterDimensionNumbers
    : AttrDef<StableHLO_Dialect, "ScatterDimensionNumbers"> {
  let summary = "Attribute that models the dimension information for scatter";
  let mnemonic = "scatter";
  let cppNamespace = "::mlir::stablehlo";

  let parameters = (ins
    StableHLO_Dim:$updateWindowDims,
    StableHLO_Dim:$insertedWindowDims,
    StableHLO_Dim:$scatterDimsToOperandDims,
    "int64_t":$indexVectorDim
  );
  let hasCustomAssemblyFormat = 1;
}
37
// Dimension numbers for gather: three dimension lists plus the index vector
// dimension. The textual parser/printer is written in C++
// (hasCustomAssemblyFormat).
def StableHLO_GatherDimensionNumbers
    : AttrDef<StableHLO_Dialect, "GatherDimensionNumbers"> {
  let summary = "Attribute that models the dimension information for gather";
  let mnemonic = "gather";
  let cppNamespace = "::mlir::stablehlo";

  let parameters = (ins
    StableHLO_Dim:$offsetDims,
    StableHLO_Dim:$collapsedSliceDims,
    StableHLO_Dim:$startIndexMap,
    "int64_t":$indexVectorDim
  );
  let hasCustomAssemblyFormat = 1;
}
50
// Dimension numbers for dot: batching and contracting dimension lists for the
// lhs and rhs operands. The textual parser/printer is written in C++
// (hasCustomAssemblyFormat).
def StableHLO_DotDimensionNumbers
    : AttrDef<StableHLO_Dialect, "DotDimensionNumbers"> {
  let summary = "Attribute that models the dimension information for dot.";
  let mnemonic = "dot";
  let cppNamespace = "::mlir::stablehlo";

  let parameters = (ins
    // Batching dimensions (lhs, then rhs).
    StableHLO_Dim:$lhsBatchingDimensions,
    StableHLO_Dim:$rhsBatchingDimensions,
    // Contracting dimensions (lhs, then rhs).
    StableHLO_Dim:$lhsContractingDimensions,
    StableHLO_Dim:$rhsContractingDimensions
  );
  let hasCustomAssemblyFormat = 1;
}
63
// Dimension numbers for convolution: which dimensions of the input, kernel
// and output play the batch/feature/spatial roles. The textual parser/printer
// is written in C++ (hasCustomAssemblyFormat).
def StableHLO_ConvDimensionNumbers
    : AttrDef<StableHLO_Dialect, "ConvDimensionNumbers"> {
  let summary = "Structure of dimension information for conv op";
  let mnemonic = "conv";
  let cppNamespace = "::mlir::stablehlo";

  let parameters = (ins
    // Input.
    "int64_t":$inputBatchDimension,
    "int64_t":$inputFeatureDimension,
    StableHLO_Dim:$inputSpatialDimensions,
    // Kernel.
    "int64_t":$kernelInputFeatureDimension,
    "int64_t":$kernelOutputFeatureDimension,
    StableHLO_Dim:$kernelSpatialDimensions,
    // Output.
    "int64_t":$outputBatchDimension,
    "int64_t":$outputFeatureDimension,
    StableHLO_Dim:$outputSpatialDimensions
  );
  let hasCustomAssemblyFormat = 1;
}
81
// Alias information attached to an entry-function argument. The textual
// parser/printer is written in C++ (hasCustomAssemblyFormat).
def StableHLO_ArgResultAlias : AttrDef<StableHLO_Dialect, "ArgResultAlias"> {
  let cppNamespace = "::mlir::stablehlo";
  let mnemonic = "result_alias";
  let summary =
    "Attribute that models the alias relationship of an entry function argument";
  let description = [{
    This attribute captures the alias relationship of a main function
    argument to one of the results, denoted by `resultIndex`. The
    `argTupleIndices` and `resultTupleIndices` are used to index into nested
    tuples in the argument and the result, respectively. If `isMustAlias` is
    true then the argument-result pair must alias.

    This is meant to be used as an attribute on a function argument.
    For example, in the following code it expresses that `%arg1` may alias the
    result at index 1 (both are `tensor<3xf32>`).

    ```mlir
    func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>
      {stablehlo.result_alias = #stablehlo.result_alias<result_index = [1], ...>}
      ) -> (tensor<2xf32>, tensor<3xf32>) {
      // function body ...
    }
    ```
  }];
  let parameters = (ins
    StableHLO_Dim:$argTupleIndices,
    "int64_t":$resultIndex,
    StableHLO_Dim:$resultTupleIndices,
    "bool":$isMustAlias
  );
  let hasCustomAssemblyFormat = 1;
}
114
// Unique identifier for a Send/Recv instruction pair, or optionally for a
// collective instruction (AllReduce, CollectivePermute, AllToAll). A
// non-positive channel_id handle is treated the same as having no channel id.
def StableHLO_ChannelHandle : AttrDef<StableHLO_Dialect, "ChannelHandle"> {
  let cppNamespace = "::mlir::stablehlo";
  let mnemonic = "channel_handle";
  let summary = "two 64-bit integers 'handle' and 'type'";

  let parameters = (ins
    "int64_t":$handle,
    "int64_t":$type
  );
  // Declarative format: both parameters printed as `name = value` pairs
  // between angle brackets.
  let assemblyFormat = "`<` struct(params) `>`";
}
125
// Note: experimental attribute; do not rely on it for production use.
//
// Implements VerifiableTensorEncoding and HLO_BoundedAttrInterface; the
// interface methods are only declared here and must be defined in C++.
def StableHLO_TypeExtensions
    : AttrDef<StableHLO_Dialect, "TypeExtensions", [
        DeclareAttrInterfaceMethods<VerifiableTensorEncoding>,
        DeclareAttrInterfaceMethods<HLO_BoundedAttrInterface>
      ]> {
  let cppNamespace = "::mlir::stablehlo";
  let mnemonic = "type_extensions";
  let summary = "Attribute that extends tensor type with StableHLO type properties.";
  let description = [{
    This attribute is used to extend MLIR tensor type with StableHLO tensor
    specific properties. These properties aren't modeled in the MLIR type. This
    attribute is set in the `encoding` field of the tensor type.

    See `HLO_BoundedAttrInterface` for documentation for `bounds`.
  }];

  // TODO(b/238903065): Move sparsity related info here from the standalone
  // attribute. That will allow composition of bounds and sparsity info.
  let parameters = (ins
    ArrayRefParameter<"int64_t">:$bounds
  );

  // Textual form: `<bounds = [...]>`.
  let assemblyFormat = "`<` `bounds` `=` `[` $bounds `]` `>`";
}
151
// Layout attribute: an index-typed elements attribute (IndexElementsAttr)
// further constrained to rank 1. Storage, return type and conversion are all
// inherited from IndexElementsAttr.
def StableHLO_LayoutAttr : Attr<
    And<[
      IndexElementsAttr.predicate,
      CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
               == 1}]>
    ]>,
    "A 1D tensor of index type (layout)"> {
  let convertFromStorage = IndexElementsAttr.convertFromStorage;
  let returnType = IndexElementsAttr.returnType;
  let storageType = IndexElementsAttr.storageType;
}
162
// ArrayAttr whose every element must satisfy StableHLO_LayoutAttr.
def StableHLO_ArrayOfLayoutAttr
    : TypedArrayAttrBase<StableHLO_LayoutAttr,
                         "Array of layout (1D tensor of index type) attributes">;
166
// ArrayAttr of FlatSymbolRef attributes. Providing constBuilderCall is what
// allows this constraint to be used for a default-valued attribute.
def StableHLO_FlatSymbolRefArrayAttr
    : TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
  // Build the constant as an ArrayAttr from the supplied value ($0).
  let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)";
}
173
// Dense int-or-FP elements attribute whose element type is i1, i.e. a
// constant boolean vector/tensor. The isa<> check precedes the cast<> so the
// cast is only evaluated on attributes of the right kind.
def StableHLO_BoolElementsAttr : ElementsAttrBase<
    And<[
      CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
      CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">
    ]>,
    "constant boolean vector/tensor attribute"> {
  let returnType = [{ ::mlir::DenseElementsAttr }];
  let storageType = [{ ::mlir::DenseElementsAttr }];
  // Storage and return types are the same, so the stored value is returned
  // unchanged.
  let convertFromStorage = "$_self";
}
184
// Attribute arguments describing a convolution, grouped in a dag (presumably
// so op definitions can splice them into their `arguments`; confirm at the use
// sites). The optional entries are plain OptionalAttrs: the "absent means"
// notes below describe how a missing attribute is to be interpreted, and no
// default value is materialized by this record.
def StableHLO_ConvolutionAttributes {
  dag attributes = (ins
    // Absent means a stride of one in every spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Absent means two zeros (no padding) for every spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Absent means an lhs dilation of one in every spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Absent means an rhs dilation of one in every spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    // Absent means false (no reversal) for every spatial dimension.
    OptionalAttr<StableHLO_BoolElementsAttr>:$window_reversal,
    StableHLO_ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    StableHLO_PrecisionConfigAttr:$precision_config
  );
}
203
204#endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS
205