• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSet.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 
24 namespace mlir {
25 namespace hlo {
26 // Verifies the source target pairs attached to collective permute.
VerifyCollectivePermuteSourceTargetPairs(Operation * op,DenseIntElementsAttr attr)27 LogicalResult VerifyCollectivePermuteSourceTargetPairs(
28     Operation *op, DenseIntElementsAttr attr) {
29   auto type = attr.getType().dyn_cast<RankedTensorType>();
30   if (type.getRank() != 2)
31     return op->emitError() << "expect source_target_pairs attribute to be of "
32                               "rank 2, but got rank "
33                            << type.getRank();
34   if (type.getShape()[1] != 2)
35     return op->emitError()
36            << "expect source_target_pairs attribute of shape (N, 2), but got ("
37            << type.getShape() << ")";
38   // Check source target pairs for duplicate sources or targets.
39   llvm::DenseSet<int64_t> sources;
40   llvm::DenseSet<int64_t> targets;
41   for (auto i = attr.begin(), e = attr.end(); i != e; ++i) {
42     auto val = (*i).getSExtValue();
43     if (i.getIndex() % 2 == 0) {
44       bool is_unique = sources.insert(val).second;
45       if (!is_unique)
46         return op->emitError() << "duplicate sources not allowed.";
47     } else {
48       bool is_unique = targets.insert(val).second;
49       if (!is_unique)
50         return op->emitError() << "duplicate targets not allowed.";
51     }
52   }
53   return success();
54 }
55 
VerifyReduceScatter(Operation * op,TypeRange operand_types,TypeRange result_types,uint64_t scatter_dimension)56 LogicalResult VerifyReduceScatter(Operation *op, TypeRange operand_types,
57                                   TypeRange result_types,
58                                   uint64_t scatter_dimension) {
59   // If operand and result are both ranked, then the size of the scatter
60   // dimension in the operand should be a multiple of the size of the scatter
61   // dimension in the result.
62   for (auto it : llvm::zip(operand_types, result_types)) {
63     auto operand_type = std::get<0>(it).cast<ShapedType>();
64     auto result_type = std::get<1>(it).cast<ShapedType>();
65     if (!operand_type.hasRank() || !result_type.hasRank()) continue;
66     if (operand_type.getRank() != result_type.getRank())
67       return op->emitOpError() << "operand and result should have same rank";
68     if (scatter_dimension >= operand_type.getRank())
69       return op->emitOpError()
70              << "scatter dim should be less than operand/result rank";
71     if (operand_type.isDynamicDim(scatter_dimension) ||
72         result_type.isDynamicDim(scatter_dimension))
73       continue;
74     if (operand_type.getDimSize(scatter_dimension) == 0)
75       return op->emitOpError() << "operand scatter dimension cannot be zero";
76     if (result_type.getDimSize(scatter_dimension) == 0)
77       return op->emitOpError() << "result scatter dimension cannot be zero";
78     if ((operand_type.getDimSize(scatter_dimension) %
79          result_type.getDimSize(scatter_dimension)) != 0)
80       return op->emitOpError()
81              << "operand scatter dimension has size "
82              << operand_type.getDimSize(scatter_dimension)
83              << ", expected to be a multiple of result scatter dimension size "
84              << result_type.getDimSize(scatter_dimension);
85 
86     // Non scatter dimensions should be equal.
87     for (uint64_t index : llvm::seq<uint64_t>(0, operand_type.getRank())) {
88       if (index == scatter_dimension || operand_type.isDynamicDim(index) ||
89           result_type.isDynamicDim(index))
90         continue;
91       if (operand_type.getDimSize(index) != result_type.getDimSize(index))
92         return op->emitOpError()
93                << "non scatter dimensions should be same for operand ("
94                << operand_type.getDimSize(index) << ") and result ("
95                << result_type.getDimSize(index) << ")";
96     }
97   }
98   return success();
99 }
100 
101 namespace {
102 // Custom formatting for convolution window attributes.
printWindowAttribute(OpAsmPrinter & p,DenseElementsAttr attribute)103 void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) {
104   if (attribute.getType().getElementType().isInteger(/*width=*/1)) {
105     // boolean attribute.
106     llvm::interleaveComma(attribute.getBoolValues(), p,
107                           [&](bool b) { p << (b ? 1 : 0); });
108     return;
109   }
110   if (attribute.getType().getRank() == 2) {
111     // Padding is Nx2 attribute.
112     auto it = attribute.getValues<int64_t>().begin();
113     std::vector<std::pair<int64_t, int64_t>> values(attribute.getNumElements() /
114                                                     2);
115     for (auto &item : values) {
116       int64_t first = *it;
117       ++it;
118       int64_t second = *it;
119       ++it;
120       item = {first, second};
121     }
122     llvm::interleaveComma(
123         values, p, [&](const std::pair<int64_t, int64_t> pair) {
124           p << '[' << pair.first << ", " << pair.second << ']';
125         });
126   } else {
127     llvm::interleaveComma(attribute.getValues<int64_t>(), p);
128   }
129 }
130 }  // namespace
131 
printWindowAttributes(OpAsmPrinter & p,Operation * op,llvm::Optional<DenseIntElementsAttr> window_strides,llvm::Optional<DenseIntElementsAttr> padding,llvm::Optional<DenseIntElementsAttr> lhs_dilation,llvm::Optional<DenseIntElementsAttr> rhs_dilation,llvm::Optional<DenseElementsAttr> window_reversal)132 void printWindowAttributes(OpAsmPrinter &p, Operation *op,
133                            llvm::Optional<DenseIntElementsAttr> window_strides,
134                            llvm::Optional<DenseIntElementsAttr> padding,
135                            llvm::Optional<DenseIntElementsAttr> lhs_dilation,
136                            llvm::Optional<DenseIntElementsAttr> rhs_dilation,
137                            llvm::Optional<DenseElementsAttr> window_reversal) {
138   using pair_t = std::pair<DenseElementsAttr, StringRef>;
139   std::array<pair_t, 5> printed_attributes = {{
140       {window_strides ? *window_strides : nullptr, "stride"},
141       {padding ? *padding : nullptr, "pad"},
142       {lhs_dilation ? *lhs_dilation : nullptr, "lhs_dilate"},
143       {rhs_dilation ? *rhs_dilation : nullptr, "rhs_dilate"},
144       {window_reversal ? *window_reversal : nullptr, "reverse"},
145   }};
146 
147   // Do not print attributes that do no exist.
148   auto non_null_attributes = llvm::make_filter_range(
149       printed_attributes,
150       [](const pair_t &a) { return static_cast<bool>(a.first); });
151 
152   llvm::interleaveComma(non_null_attributes, p, [&](const pair_t &a) {
153     p << a.second << " = [";
154     printWindowAttribute(p, a.first);
155     p << "]";
156   });
157 }
158 
parseWindowAttributes(OpAsmParser & parser,DenseIntElementsAttr & window_strides,DenseIntElementsAttr & padding,DenseIntElementsAttr & lhs_dilation,DenseIntElementsAttr & rhs_dilation,DenseElementsAttr & window_reversal)159 ParseResult parseWindowAttributes(OpAsmParser &parser,
160                                   DenseIntElementsAttr &window_strides,
161                                   DenseIntElementsAttr &padding,
162                                   DenseIntElementsAttr &lhs_dilation,
163                                   DenseIntElementsAttr &rhs_dilation,
164                                   DenseElementsAttr &window_reversal) {
165   StringRef attribute_name;
166 
167   // Helper to parse an array of the form [ e0, e1, .. ]
168   auto parse_array = [&](std::function<ParseResult(void)> parse_element,
169                          llvm::Optional<size_t> expected_size =
170                              llvm::None) -> ParseResult {
171     if (parser.parseLSquare()) {
172       return failure();
173     }
174     size_t size = 0;
175     do {
176       if (parse_element()) {
177         return failure();
178       }
179       size++;
180     } while (parser.parseOptionalComma().succeeded());
181     if (parser.parseRSquare()) {
182       return failure();
183     }
184     if (expected_size && size != *expected_size) {
185       return parser.emitError(parser.getCurrentLocation(),
186                               "Expected array with")
187              << *expected_size << " elements, got " << size
188              << " elements instead";
189     }
190     return success();
191   };
192 
193   llvm::StringSet<> allowed_attribute_names{
194       {"stride", "pad", "lhs_dilate", "rhs_dilate", "reverse"}};
195 
196   while (parser.parseOptionalKeyword(&attribute_name).succeeded()) {
197     // Verify that the attribute name is valid and erase it.
198     if (!allowed_attribute_names.erase(attribute_name)) {
199       return parser.emitError(parser.getCurrentLocation(),
200                               "Unexpected keyword ")
201              << attribute_name;
202     }
203 
204     if (parser.parseEqual()) {
205       return failure();
206     }
207 
208     // parse the attribute value. We need to support either 1D and Nx2 array of
209     // integers to parse.
210     llvm::SmallVector<int64_t> values;
211     auto int64_parser = [&]() {
212       return parser.parseInteger(values.emplace_back(0));
213     };
214 
215     if (attribute_name == "pad") {
216       // Parse a 2D array of integers.
217       auto inner_parser = [&]() {
218         return parse_array(int64_parser, /*expected_size=*/2);
219       };
220       if (parse_array(inner_parser)) {
221         return failure();
222       }
223       const int64_t size = static_cast<int64_t>(values.size());
224       // values should be filled with the Nx2 padding values.
225       auto ty = RankedTensorType::get({size / 2, 2},
226                                       parser.getBuilder().getIntegerType(64));
227       padding = DenseIntElementsAttr::get(ty, values);
228     } else {
229       // Parse 1D array of integers.
230       if (parse_array(int64_parser)) {
231         return failure();
232       }
233       const int64_t size = static_cast<int64_t>(values.size());
234       if (attribute_name == "reverse") {
235         auto ty = RankedTensorType::get({size},
236                                         parser.getBuilder().getIntegerType(1));
237         auto bool_vector = llvm::to_vector<4>(
238             llvm::map_range(values, [](int64_t v) { return v != 0; }));
239         window_reversal = DenseElementsAttr::get(ty, bool_vector);
240       } else {
241         auto attr = parser.getBuilder().getI64TensorAttr(values);
242 
243         if (attribute_name == "stride") {
244           window_strides = attr;
245         } else if (attribute_name == "lhs_dilate") {
246           lhs_dilation = attr;
247         } else if (attribute_name == "rhs_dilate") {
248           rhs_dilation = attr;
249         } else {
250           llvm_unreachable("Unexpected attribute name");
251         }
252       }
253     }
254     // continue parsing if there is a comma at the end.
255     if (parser.parseOptionalComma().failed()) break;
256   }
257   return success();
258 }
259 
260 }  // namespace hlo
261 }  // namespace mlir
262