• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"

#include <limits>
#include <set>
#include <unordered_map>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
24 
25 namespace mlir {
26 namespace mhlo {
27 
namespace {
// Negative codes for the non-spatial dimensions of a convolution's
// input/output and kernel. In the textual dimension lists these share one
// int64_t encoding with spatial dimensions, which are stored as their
// (non-negative) ordinal.
enum NonSpatialDim : int64_t {
  IOBatch = -1,    // Input or output batch dimension
  IOFeature = -2,  // Input or output feature dimension
  KIFeature = -3,  // Kernel input feature dimension
  KOFeature = -4,  // Kernel output feature dimension
};

// Returns the single-character keyword used for `dim` in the textual
// dimension-number syntax: 'b', 'f', 'i' or 'o'.
//
// The enum has a fixed underlying type, so any int64_t value is a valid
// NonSpatialDim. Values that are not one of the enumerators yield '?' instead
// of falling off the end of the function, which would be undefined behavior.
char NonSpatialDimToString(NonSpatialDim dim) {
  switch (dim) {
    case IOBatch:
      return 'b';
    case IOFeature:
      return 'f';
    case KIFeature:
      return 'i';
    case KOFeature:
      return 'o';
  }
  return '?';
}
}  // namespace
49 
50 // Custom printer and parser for struct attributes.
printConvolutionDimensions(OpAsmPrinter & p,Operation *,ConvDimensionNumbers dnums)51 void printConvolutionDimensions(OpAsmPrinter &p, Operation * /*op*/,
52                                 ConvDimensionNumbers dnums) {
53   auto print_dim =
54       [&p](DenseIntElementsAttr spatial_dims,
55            ArrayRef<std::pair<IntegerAttr, NonSpatialDim>> non_spatial_dims) {
56         llvm::SmallVector<int64_t> dims(non_spatial_dims.size() +
57                                         spatial_dims.size());
58         // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
59         // spatial dimension index.
60         for (const std::pair<IntegerAttr, NonSpatialDim> &non_spatial_dim :
61              non_spatial_dims) {
62           dims[non_spatial_dim.first.getInt()] = non_spatial_dim.second;
63         }
64         for (auto spatial_dim :
65              llvm::enumerate(spatial_dims.getValues<int64_t>())) {
66           dims[spatial_dim.value()] = static_cast<int64_t>(spatial_dim.index());
67         }
68 
69         // Each dimension numbers will be printed as a comma separated list
70         // surrounded by square brackets, e.g., [b, 0, 1, 2, f]
71         p << '[';
72         llvm::interleaveComma(dims, p, [&](int64_t dim) {
73           if (dim >= 0) {
74             p << dim;
75           } else {
76             p << NonSpatialDimToString(static_cast<NonSpatialDim>(dim));
77           }
78         });
79         p << ']';
80       };
81 
82   print_dim(dnums.input_spatial_dimensions(),
83             {{dnums.input_batch_dimension(), IOBatch},
84              {dnums.input_feature_dimension(), IOFeature}});
85   p << "x";
86   print_dim(dnums.kernel_spatial_dimensions(),
87             {{dnums.kernel_input_feature_dimension(), KIFeature},
88              {dnums.kernel_output_feature_dimension(), KOFeature}});
89   p << "->";
90   print_dim(dnums.output_spatial_dimensions(),
91             {{dnums.output_batch_dimension(), IOBatch},
92              {dnums.output_feature_dimension(), IOFeature}});
93 }
94 
parseConvolutionDimensions(OpAsmParser & parser,ConvDimensionNumbers & dnums)95 ParseResult parseConvolutionDimensions(OpAsmParser &parser,
96                                        ConvDimensionNumbers &dnums) {
97   // Parsing a single set of dim numbers gives the spatial dimensions as a
98   // single DenseIntElementsAttr and a list of non-spatial dimensions as
99   // IntegerAttrs (indexed by the NonSpatialDim enum).
100   using parse_dim_result_t = std::pair<
101       DenseIntElementsAttr,
102       std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>>;
103 
104   // Note that the allowed_non_spatial_dims is a set (as opposed to unordered
105   // set) because its used to print a list of allowed non spatial dims in the
106   // error messages, so making it a set keeps the error messages deterministic.
107   auto parse_dims =
108       [&](std::set<NonSpatialDim, std::greater<>> allowed_non_spatial_dims,
109           parse_dim_result_t &parsed_dims) -> ParseResult {
110     // Parse the starting [
111     if (parser.parseLSquare()) {
112       return failure();
113     }
114     llvm::SmallVector<int64_t> spatial_dims;
115     std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>
116         non_spatial_dims;
117 
118     int64_t index = 0;
119     do {
120       int64_t spatial_dim;
121       OptionalParseResult parseResult =
122           parser.parseOptionalInteger(spatial_dim);
123       if (parseResult.hasValue()) {
124         if (parseResult.getValue().failed()) {
125           return failure();
126         }
127         // We were successful in parsing an integer. Add its index to the
128         // spatial dims.
129         spatial_dims.push_back(index);
130       } else {
131         // We did not parse an integer. We expect a keyword token.
132         StringRef keyword;
133         if (parser.parseKeyword(&keyword)) {
134           return failure();
135         }
136         if (keyword.size() != 1 || allowed_non_spatial_dims.empty()) {
137           return parser.emitError(parser.getCurrentLocation(),
138                                   "Unexpected keyword ")
139                  << keyword;
140         }
141         // Check if the keyword matches one of the allowed non-spatial dims.
142         // If so, add it to the non_spatial dims and remove it from the
143         // allowed set so that it won't be allowed again.
144         bool is_allowed = false;
145         for (NonSpatialDim allowed : allowed_non_spatial_dims) {
146           if (keyword[0] == NonSpatialDimToString(allowed)) {
147             non_spatial_dims.insert(
148                 {allowed, parser.getBuilder().getI64IntegerAttr(index)});
149             allowed_non_spatial_dims.erase(allowed);
150             is_allowed = true;
151             break;
152           }
153         }
154 
155         if (!is_allowed) {
156           mlir::InFlightDiagnostic diag = parser.emitError(
157               parser.getCurrentLocation(), "Unexpected dimension ");
158           diag << keyword << ", expecting ";
159           llvm::interleaveComma(
160               allowed_non_spatial_dims, diag,
161               [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
162           return diag;
163         }
164       }
165       index++;
166     } while (parser.parseOptionalComma().succeeded());
167 
168     // Make sure all expected non-spatial dimensions are parsed.
169     if (!allowed_non_spatial_dims.empty()) {
170       mlir::InFlightDiagnostic diag =
171           parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
172       llvm::interleaveComma(
173           allowed_non_spatial_dims, diag,
174           [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
175       diag << " not specified";
176       return diag;
177     }
178 
179     // parse ending ]
180     if (parser.parseRSquare()) {
181       return failure();
182     }
183 
184     parsed_dims = std::make_pair(
185         parser.getBuilder().getI64TensorAttr(spatial_dims), non_spatial_dims);
186     return success();
187   };
188 
189   parse_dim_result_t parsed_dims;
190   if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
191     return failure();
192   }
193   DenseIntElementsAttr input_spatial_dimensions = parsed_dims.first;
194   IntegerAttr input_batch_dimension = parsed_dims.second[IOBatch];
195   IntegerAttr input_feature_dimension = parsed_dims.second[IOFeature];
196   if (parser.parseKeyword("x")) return failure();
197   if (parse_dims({KIFeature, KOFeature}, parsed_dims)) {
198     return failure();
199   }
200   DenseIntElementsAttr kernel_spatial_dimensions = parsed_dims.first;
201   IntegerAttr kernel_input_feature_dimension = parsed_dims.second[KIFeature];
202   IntegerAttr kernel_output_feature_dimension = parsed_dims.second[KOFeature];
203   if (parser.parseArrow()) {
204     return failure();
205   }
206   if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
207     return failure();
208   }
209   DenseIntElementsAttr output_spatial_dimensions = parsed_dims.first;
210   IntegerAttr output_batch_dimension = parsed_dims.second[IOBatch];
211   IntegerAttr output_feature_dimension = parsed_dims.second[IOFeature];
212   dnums = ConvDimensionNumbers::get(
213       input_batch_dimension, input_feature_dimension, input_spatial_dimensions,
214       kernel_input_feature_dimension, kernel_output_feature_dimension,
215       kernel_spatial_dimensions, output_batch_dimension,
216       output_feature_dimension, output_spatial_dimensions,
217       parser.getBuilder().getContext());
218 
219   return success();
220 }
221 
222 }  // namespace mhlo
223 }  // namespace mlir
224