1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"

#include <algorithm>
#include <limits>
#include <set>
#include <unordered_map>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
24
25 namespace mlir {
26 namespace mhlo {
27
namespace {
// Non-spatial dimension kinds used by the textual convolution
// dimension-numbers syntax. The values are negative so that they can share an
// int64_t slot with non-negative spatial dimension indices.
enum NonSpatialDim : int64_t {
  IOBatch = -1,    // Input or output batch dimension
  IOFeature = -2,  // Input or output feature dimension
  KIFeature = -3,  // Kernel input feature dimension
  KOFeature = -4,  // Kernel output feature dimension
};

// Returns the single-letter mnemonic used for `dim` in the textual syntax:
// 'b' (batch), 'f' (feature), 'i' (kernel input feature) or 'o' (kernel output
// feature).
char NonSpatialDimToString(NonSpatialDim dim) {
  switch (dim) {
    case IOBatch:
      return 'b';
    case IOFeature:
      return 'f';
    case KIFeature:
      return 'i';
    case KOFeature:
      return 'o';
  }
  // Not reachable for the enumerators above. The switch deliberately has no
  // `default` so that -Wswitch flags any newly added enumerator. Because the
  // enum has a fixed underlying type, any int64_t can be cast to it, so return
  // a placeholder instead of flowing off the end of a non-void function
  // (undefined behavior).
  return '?';
}
}  // namespace
49
50 // Custom printer and parser for struct attributes.
printConvolutionDimensions(OpAsmPrinter & p,Operation *,ConvDimensionNumbers dnums)51 void printConvolutionDimensions(OpAsmPrinter &p, Operation * /*op*/,
52 ConvDimensionNumbers dnums) {
53 auto print_dim =
54 [&p](DenseIntElementsAttr spatial_dims,
55 ArrayRef<std::pair<IntegerAttr, NonSpatialDim>> non_spatial_dims) {
56 llvm::SmallVector<int64_t> dims(non_spatial_dims.size() +
57 spatial_dims.size());
58 // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
59 // spatial dimension index.
60 for (const std::pair<IntegerAttr, NonSpatialDim> &non_spatial_dim :
61 non_spatial_dims) {
62 dims[non_spatial_dim.first.getInt()] = non_spatial_dim.second;
63 }
64 for (auto spatial_dim :
65 llvm::enumerate(spatial_dims.getValues<int64_t>())) {
66 dims[spatial_dim.value()] = static_cast<int64_t>(spatial_dim.index());
67 }
68
69 // Each dimension numbers will be printed as a comma separated list
70 // surrounded by square brackets, e.g., [b, 0, 1, 2, f]
71 p << '[';
72 llvm::interleaveComma(dims, p, [&](int64_t dim) {
73 if (dim >= 0) {
74 p << dim;
75 } else {
76 p << NonSpatialDimToString(static_cast<NonSpatialDim>(dim));
77 }
78 });
79 p << ']';
80 };
81
82 print_dim(dnums.input_spatial_dimensions(),
83 {{dnums.input_batch_dimension(), IOBatch},
84 {dnums.input_feature_dimension(), IOFeature}});
85 p << "x";
86 print_dim(dnums.kernel_spatial_dimensions(),
87 {{dnums.kernel_input_feature_dimension(), KIFeature},
88 {dnums.kernel_output_feature_dimension(), KOFeature}});
89 p << "->";
90 print_dim(dnums.output_spatial_dimensions(),
91 {{dnums.output_batch_dimension(), IOBatch},
92 {dnums.output_feature_dimension(), IOFeature}});
93 }
94
parseConvolutionDimensions(OpAsmParser & parser,ConvDimensionNumbers & dnums)95 ParseResult parseConvolutionDimensions(OpAsmParser &parser,
96 ConvDimensionNumbers &dnums) {
97 // Parsing a single set of dim numbers gives the spatial dimensions as a
98 // single DenseIntElementsAttr and a list of non-spatial dimensions as
99 // IntegerAttrs (indexed by the NonSpatialDim enum).
100 using parse_dim_result_t = std::pair<
101 DenseIntElementsAttr,
102 std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>>;
103
104 // Note that the allowed_non_spatial_dims is a set (as opposed to unordered
105 // set) because its used to print a list of allowed non spatial dims in the
106 // error messages, so making it a set keeps the error messages deterministic.
107 auto parse_dims =
108 [&](std::set<NonSpatialDim, std::greater<>> allowed_non_spatial_dims,
109 parse_dim_result_t &parsed_dims) -> ParseResult {
110 // Parse the starting [
111 if (parser.parseLSquare()) {
112 return failure();
113 }
114 llvm::SmallVector<int64_t> spatial_dims;
115 std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>
116 non_spatial_dims;
117
118 int64_t index = 0;
119 do {
120 int64_t spatial_dim;
121 OptionalParseResult parseResult =
122 parser.parseOptionalInteger(spatial_dim);
123 if (parseResult.hasValue()) {
124 if (parseResult.getValue().failed()) {
125 return failure();
126 }
127 // We were successful in parsing an integer. Add its index to the
128 // spatial dims.
129 spatial_dims.push_back(index);
130 } else {
131 // We did not parse an integer. We expect a keyword token.
132 StringRef keyword;
133 if (parser.parseKeyword(&keyword)) {
134 return failure();
135 }
136 if (keyword.size() != 1 || allowed_non_spatial_dims.empty()) {
137 return parser.emitError(parser.getCurrentLocation(),
138 "Unexpected keyword ")
139 << keyword;
140 }
141 // Check if the keyword matches one of the allowed non-spatial dims.
142 // If so, add it to the non_spatial dims and remove it from the
143 // allowed set so that it won't be allowed again.
144 bool is_allowed = false;
145 for (NonSpatialDim allowed : allowed_non_spatial_dims) {
146 if (keyword[0] == NonSpatialDimToString(allowed)) {
147 non_spatial_dims.insert(
148 {allowed, parser.getBuilder().getI64IntegerAttr(index)});
149 allowed_non_spatial_dims.erase(allowed);
150 is_allowed = true;
151 break;
152 }
153 }
154
155 if (!is_allowed) {
156 mlir::InFlightDiagnostic diag = parser.emitError(
157 parser.getCurrentLocation(), "Unexpected dimension ");
158 diag << keyword << ", expecting ";
159 llvm::interleaveComma(
160 allowed_non_spatial_dims, diag,
161 [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
162 return diag;
163 }
164 }
165 index++;
166 } while (parser.parseOptionalComma().succeeded());
167
168 // Make sure all expected non-spatial dimensions are parsed.
169 if (!allowed_non_spatial_dims.empty()) {
170 mlir::InFlightDiagnostic diag =
171 parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
172 llvm::interleaveComma(
173 allowed_non_spatial_dims, diag,
174 [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
175 diag << " not specified";
176 return diag;
177 }
178
179 // parse ending ]
180 if (parser.parseRSquare()) {
181 return failure();
182 }
183
184 parsed_dims = std::make_pair(
185 parser.getBuilder().getI64TensorAttr(spatial_dims), non_spatial_dims);
186 return success();
187 };
188
189 parse_dim_result_t parsed_dims;
190 if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
191 return failure();
192 }
193 DenseIntElementsAttr input_spatial_dimensions = parsed_dims.first;
194 IntegerAttr input_batch_dimension = parsed_dims.second[IOBatch];
195 IntegerAttr input_feature_dimension = parsed_dims.second[IOFeature];
196 if (parser.parseKeyword("x")) return failure();
197 if (parse_dims({KIFeature, KOFeature}, parsed_dims)) {
198 return failure();
199 }
200 DenseIntElementsAttr kernel_spatial_dimensions = parsed_dims.first;
201 IntegerAttr kernel_input_feature_dimension = parsed_dims.second[KIFeature];
202 IntegerAttr kernel_output_feature_dimension = parsed_dims.second[KOFeature];
203 if (parser.parseArrow()) {
204 return failure();
205 }
206 if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
207 return failure();
208 }
209 DenseIntElementsAttr output_spatial_dimensions = parsed_dims.first;
210 IntegerAttr output_batch_dimension = parsed_dims.second[IOBatch];
211 IntegerAttr output_feature_dimension = parsed_dims.second[IOFeature];
212 dnums = ConvDimensionNumbers::get(
213 input_batch_dimension, input_feature_dimension, input_spatial_dimensions,
214 kernel_input_feature_dimension, kernel_output_feature_dimension,
215 kernel_spatial_dimensions, output_batch_dimension,
216 output_feature_dimension, output_spatial_dimensions,
217 parser.getBuilder().getContext());
218
219 return success();
220 }
221
222 } // namespace mhlo
223 } // namespace mlir
224