1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSet.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/IR/BuiltinTypes.h"
23
24 namespace mlir {
25 namespace hlo {
26 // Verifies the source target pairs attached to collective permute.
VerifyCollectivePermuteSourceTargetPairs(Operation * op,DenseIntElementsAttr attr)27 LogicalResult VerifyCollectivePermuteSourceTargetPairs(
28 Operation *op, DenseIntElementsAttr attr) {
29 auto type = attr.getType().dyn_cast<RankedTensorType>();
30 if (type.getRank() != 2)
31 return op->emitError() << "expect source_target_pairs attribute to be of "
32 "rank 2, but got rank "
33 << type.getRank();
34 if (type.getShape()[1] != 2)
35 return op->emitError()
36 << "expect source_target_pairs attribute of shape (N, 2), but got ("
37 << type.getShape() << ")";
38 // Check source target pairs for duplicate sources or targets.
39 llvm::DenseSet<int64_t> sources;
40 llvm::DenseSet<int64_t> targets;
41 for (auto i = attr.begin(), e = attr.end(); i != e; ++i) {
42 auto val = (*i).getSExtValue();
43 if (i.getIndex() % 2 == 0) {
44 bool is_unique = sources.insert(val).second;
45 if (!is_unique)
46 return op->emitError() << "duplicate sources not allowed.";
47 } else {
48 bool is_unique = targets.insert(val).second;
49 if (!is_unique)
50 return op->emitError() << "duplicate targets not allowed.";
51 }
52 }
53 return success();
54 }
55
VerifyReduceScatter(Operation * op,TypeRange operand_types,TypeRange result_types,uint64_t scatter_dimension)56 LogicalResult VerifyReduceScatter(Operation *op, TypeRange operand_types,
57 TypeRange result_types,
58 uint64_t scatter_dimension) {
59 // If operand and result are both ranked, then the size of the scatter
60 // dimension in the operand should be a multiple of the size of the scatter
61 // dimension in the result.
62 for (auto it : llvm::zip(operand_types, result_types)) {
63 auto operand_type = std::get<0>(it).cast<ShapedType>();
64 auto result_type = std::get<1>(it).cast<ShapedType>();
65 if (!operand_type.hasRank() || !result_type.hasRank()) continue;
66 if (operand_type.getRank() != result_type.getRank())
67 return op->emitOpError() << "operand and result should have same rank";
68 if (scatter_dimension >= operand_type.getRank())
69 return op->emitOpError()
70 << "scatter dim should be less than operand/result rank";
71 if (operand_type.isDynamicDim(scatter_dimension) ||
72 result_type.isDynamicDim(scatter_dimension))
73 continue;
74 if (operand_type.getDimSize(scatter_dimension) == 0)
75 return op->emitOpError() << "operand scatter dimension cannot be zero";
76 if (result_type.getDimSize(scatter_dimension) == 0)
77 return op->emitOpError() << "result scatter dimension cannot be zero";
78 if ((operand_type.getDimSize(scatter_dimension) %
79 result_type.getDimSize(scatter_dimension)) != 0)
80 return op->emitOpError()
81 << "operand scatter dimension has size "
82 << operand_type.getDimSize(scatter_dimension)
83 << ", expected to be a multiple of result scatter dimension size "
84 << result_type.getDimSize(scatter_dimension);
85
86 // Non scatter dimensions should be equal.
87 for (uint64_t index : llvm::seq<uint64_t>(0, operand_type.getRank())) {
88 if (index == scatter_dimension || operand_type.isDynamicDim(index) ||
89 result_type.isDynamicDim(index))
90 continue;
91 if (operand_type.getDimSize(index) != result_type.getDimSize(index))
92 return op->emitOpError()
93 << "non scatter dimensions should be same for operand ("
94 << operand_type.getDimSize(index) << ") and result ("
95 << result_type.getDimSize(index) << ")";
96 }
97 }
98 return success();
99 }
100
101 namespace {
102 // Custom formatting for convolution window attributes.
printWindowAttribute(OpAsmPrinter & p,DenseElementsAttr attribute)103 void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) {
104 if (attribute.getType().getElementType().isInteger(/*width=*/1)) {
105 // boolean attribute.
106 llvm::interleaveComma(attribute.getBoolValues(), p,
107 [&](bool b) { p << (b ? 1 : 0); });
108 return;
109 }
110 if (attribute.getType().getRank() == 2) {
111 // Padding is Nx2 attribute.
112 auto it = attribute.getValues<int64_t>().begin();
113 std::vector<std::pair<int64_t, int64_t>> values(attribute.getNumElements() /
114 2);
115 for (auto &item : values) {
116 int64_t first = *it;
117 ++it;
118 int64_t second = *it;
119 ++it;
120 item = {first, second};
121 }
122 llvm::interleaveComma(
123 values, p, [&](const std::pair<int64_t, int64_t> pair) {
124 p << '[' << pair.first << ", " << pair.second << ']';
125 });
126 } else {
127 llvm::interleaveComma(attribute.getValues<int64_t>(), p);
128 }
129 }
130 } // namespace
131
printWindowAttributes(OpAsmPrinter & p,Operation * op,llvm::Optional<DenseIntElementsAttr> window_strides,llvm::Optional<DenseIntElementsAttr> padding,llvm::Optional<DenseIntElementsAttr> lhs_dilation,llvm::Optional<DenseIntElementsAttr> rhs_dilation,llvm::Optional<DenseElementsAttr> window_reversal)132 void printWindowAttributes(OpAsmPrinter &p, Operation *op,
133 llvm::Optional<DenseIntElementsAttr> window_strides,
134 llvm::Optional<DenseIntElementsAttr> padding,
135 llvm::Optional<DenseIntElementsAttr> lhs_dilation,
136 llvm::Optional<DenseIntElementsAttr> rhs_dilation,
137 llvm::Optional<DenseElementsAttr> window_reversal) {
138 using pair_t = std::pair<DenseElementsAttr, StringRef>;
139 std::array<pair_t, 5> printed_attributes = {{
140 {window_strides ? *window_strides : nullptr, "stride"},
141 {padding ? *padding : nullptr, "pad"},
142 {lhs_dilation ? *lhs_dilation : nullptr, "lhs_dilate"},
143 {rhs_dilation ? *rhs_dilation : nullptr, "rhs_dilate"},
144 {window_reversal ? *window_reversal : nullptr, "reverse"},
145 }};
146
147 // Do not print attributes that do no exist.
148 auto non_null_attributes = llvm::make_filter_range(
149 printed_attributes,
150 [](const pair_t &a) { return static_cast<bool>(a.first); });
151
152 llvm::interleaveComma(non_null_attributes, p, [&](const pair_t &a) {
153 p << a.second << " = [";
154 printWindowAttribute(p, a.first);
155 p << "]";
156 });
157 }
158
parseWindowAttributes(OpAsmParser & parser,DenseIntElementsAttr & window_strides,DenseIntElementsAttr & padding,DenseIntElementsAttr & lhs_dilation,DenseIntElementsAttr & rhs_dilation,DenseElementsAttr & window_reversal)159 ParseResult parseWindowAttributes(OpAsmParser &parser,
160 DenseIntElementsAttr &window_strides,
161 DenseIntElementsAttr &padding,
162 DenseIntElementsAttr &lhs_dilation,
163 DenseIntElementsAttr &rhs_dilation,
164 DenseElementsAttr &window_reversal) {
165 StringRef attribute_name;
166
167 // Helper to parse an array of the form [ e0, e1, .. ]
168 auto parse_array = [&](std::function<ParseResult(void)> parse_element,
169 llvm::Optional<size_t> expected_size =
170 llvm::None) -> ParseResult {
171 if (parser.parseLSquare()) {
172 return failure();
173 }
174 size_t size = 0;
175 do {
176 if (parse_element()) {
177 return failure();
178 }
179 size++;
180 } while (parser.parseOptionalComma().succeeded());
181 if (parser.parseRSquare()) {
182 return failure();
183 }
184 if (expected_size && size != *expected_size) {
185 return parser.emitError(parser.getCurrentLocation(),
186 "Expected array with")
187 << *expected_size << " elements, got " << size
188 << " elements instead";
189 }
190 return success();
191 };
192
193 llvm::StringSet<> allowed_attribute_names{
194 {"stride", "pad", "lhs_dilate", "rhs_dilate", "reverse"}};
195
196 while (parser.parseOptionalKeyword(&attribute_name).succeeded()) {
197 // Verify that the attribute name is valid and erase it.
198 if (!allowed_attribute_names.erase(attribute_name)) {
199 return parser.emitError(parser.getCurrentLocation(),
200 "Unexpected keyword ")
201 << attribute_name;
202 }
203
204 if (parser.parseEqual()) {
205 return failure();
206 }
207
208 // parse the attribute value. We need to support either 1D and Nx2 array of
209 // integers to parse.
210 llvm::SmallVector<int64_t> values;
211 auto int64_parser = [&]() {
212 return parser.parseInteger(values.emplace_back(0));
213 };
214
215 if (attribute_name == "pad") {
216 // Parse a 2D array of integers.
217 auto inner_parser = [&]() {
218 return parse_array(int64_parser, /*expected_size=*/2);
219 };
220 if (parse_array(inner_parser)) {
221 return failure();
222 }
223 const int64_t size = static_cast<int64_t>(values.size());
224 // values should be filled with the Nx2 padding values.
225 auto ty = RankedTensorType::get({size / 2, 2},
226 parser.getBuilder().getIntegerType(64));
227 padding = DenseIntElementsAttr::get(ty, values);
228 } else {
229 // Parse 1D array of integers.
230 if (parse_array(int64_parser)) {
231 return failure();
232 }
233 const int64_t size = static_cast<int64_t>(values.size());
234 if (attribute_name == "reverse") {
235 auto ty = RankedTensorType::get({size},
236 parser.getBuilder().getIntegerType(1));
237 auto bool_vector = llvm::to_vector<4>(
238 llvm::map_range(values, [](int64_t v) { return v != 0; }));
239 window_reversal = DenseElementsAttr::get(ty, bool_vector);
240 } else {
241 auto attr = parser.getBuilder().getI64TensorAttr(values);
242
243 if (attribute_name == "stride") {
244 window_strides = attr;
245 } else if (attribute_name == "lhs_dilate") {
246 lhs_dilation = attr;
247 } else if (attribute_name == "rhs_dilate") {
248 rhs_dilation = attr;
249 } else {
250 llvm_unreachable("Unexpected attribute name");
251 }
252 }
253 }
254 // continue parsing if there is a comma at the end.
255 if (parser.parseOptionalComma().failed()) break;
256 }
257 return success();
258 }
259
260 } // namespace hlo
261 } // namespace mlir
262