/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for translating mixed IR to buffer form.

#include "mlir/Transforms/Bufferize.h"  // from @llvm-project

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {

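// Bufferizes rank-0 and rank-1 tensor constants into explicitly initialized
// buffers. As an illustrative, hypothetical example, a constant like
//   %cst = constant dense<[0, 1]> : tensor<2xindex>
// would roughly become
//   %buf = memref.alloca() : memref<2xindex>
// followed by one store per element of the dense attribute. Constants of
// rank > 1 are left to the core bufferization patterns.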
class BufferizeConstantOp : public OpConversionPattern<ConstantOp> {
 public:
  using OpConversionPattern<ConstantOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ConstantOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // We only need to bufferize tensor constants. Check the type before
    // querying the rank; calling getRank() on a null type would crash.
    Location loc = op.getLoc();
    auto result_type = op.getType().dyn_cast<RankedTensorType>();
    if (!result_type || !result_type.hasStaticShape() ||
        result_type.getRank() > 1)
      return failure();
    int64_t result_rank = result_type.getRank();

    auto memref_type =
        MemRefType::get(result_type.getShape(), result_type.getElementType());
    auto elements_attr = op.value().cast<DenseElementsAttr>();

    if (result_rank == 0) {
      Value buffer = rewriter.create<memref::AllocOp>(loc, memref_type);
      Value constant =
          rewriter.create<ConstantOp>(loc, elements_attr.getValue({}));
      rewriter.create<memref::StoreOp>(loc, constant, buffer);
      rewriter.replaceOp(op, {buffer});
      return success();
    }

    Value buffer = rewriter.create<memref::AllocaOp>(loc, memref_type);

    bool all_same_elems = elements_attr.isSplat();
    Value value;
    if (all_same_elems)
      value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
    for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
      if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
      Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
      rewriter.create<memref::StoreOp>(loc, value, buffer, index);
    }
    rewriter.replaceOp(op, {buffer});
    return success();
  }
};

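// Rewrites tensor.dim over an already-bufferized operand into memref.dim.
// As a sketch, `tensor.dim %t, %c0` becomes `memref.dim %m, %c0`, where %m
// is the buffer that the conversion produced for %t.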
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      tensor::DimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    tensor::DimOp::Adaptor adaptor(operands);
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
                                               adaptor.index());
    return success();
  }
};

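// Lowers chlo.minimum_broadcast_shapes to explicit loops over the shape
// buffers. The op combines adjacent dimensions that broadcast in the same
// way in all operands. As a hypothetical worked example, for the operand
// shapes [2, 3, 4] and [3, 4] the two trailing dimensions need no
// broadcasting in either shape, so they can be collapsed, yielding the
// reduced shapes [2, 12] and [12].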
class BufferizeAndConvertMinimumBroadcastShapesOp
    : public OpConversionPattern<chlo::MinimumBroadcastShapesOp> {
 public:
  using OpConversionPattern<
      chlo::MinimumBroadcastShapesOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      chlo::MinimumBroadcastShapesOp broadcast_shapes_op,
      ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    chlo::MinimumBroadcastShapesOp::Adaptor adaptor(operands);
    auto loc = broadcast_shapes_op.getLoc();
    ImplicitLocOpBuilder lb(loc, rewriter);
    Value zero = lb.create<ConstantIndexOp>(0);
    SmallVector<Value> shapes = adaptor.shapes();
    size_t k = shapes.size();
    SmallVector<Value> ranks;
    ranks.reserve(k);

    // Determine the maximum rank of the operands.
    Value max_rank;
    for (size_t i = 0; i < k; ++i) {
      Value rank = lb.create<memref::DimOp>(shapes[i], zero);
      ranks.push_back(rank);
      if (i) {
        Value rank_is_greater =
            lb.create<CmpIOp>(CmpIPredicate::ugt, ranks[i], max_rank);
        max_rank = lb.create<SelectOp>(rank_is_greater, ranks[i], max_rank);
      } else {
        max_rank = ranks[0];
      }
    }

    // Allocate buffers for the return values and initialize them with 1's.
    SmallVector<Value> result_shapes;
    result_shapes.reserve(k);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    Value one = lb.create<ConstantIndexOp>(1);
    for (size_t i = 0; i < k; ++i) {
      // We assume the buffer will be small, so we allocate it on the stack.
      // TODO(b/181654096): Replace AllocaOp with AllocOp.
      auto result = lb.create<memref::AllocaOp>(result_type, ranks[i]);
      lb.create<scf::ForOp>(zero, ranks[i], one, llvm::None,
                            [&one, &result](OpBuilder &b, Location l,
                                            Value idx, ValueRange /*vr*/) {
                              b.create<memref::StoreOp>(l, one, result, idx);
                              b.create<scf::YieldOp>(l, llvm::None);
                            });
      result_shapes.push_back(result);
    }

    // Iterate through the dimensions and determine which adjacent dimensions
    // can be combined. Keep as iteration variables a running product of the
    // dimensions that can be combined (initialized to 1) and the current
    // dimension offset in the result shapes. We iterate through the shapes
    // backward, because the broadcasting semantics mean that the last
    // dimensions of each shape (the least significant ones) are matched
    // together.
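    // As a hypothetical illustration, for the shapes [2, 3, 4] and [3, 4],
    // the two trailing dimensions broadcast identically (not at all) in both
    // shapes, so the running product reaches 3 * 4 = 12 before the
    // broadcasting behavior changes at the third-to-last position, at which
    // point 12 is written out as one combined dimension.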
    Value two = lb.create<ConstantIndexOp>(2);
    Value max_rank_plus_two = lb.create<AddIOp>(max_rank, two);
    Value constant_false =
        lb.create<ConstantOp>(lb.getI1Type(), lb.getBoolAttr(false));
    SmallVector<Value> init_values;
    init_values.reserve(k + 3);
    // Initially, all values are marked as not broadcasted.
    for (size_t i = 0; i < k; ++i) {
      init_values.push_back(constant_false);
    }
    // The running product is initially 1.
    init_values.push_back(one);
    // The current dimension offset is initially 0.
    init_values.push_back(zero);
    // Whether the broadcasting is invalid.
    init_values.push_back(constant_false);

    // Iterate from 1 to max_rank + 1 (inclusive). This iteration variable is
    // used as an offset from the end of each shape vector. We iterate until
    // max_rank + 1 to handle the case that we have a running_product > 1 left
    // when we have processed all dimensions of the largest shape.
    auto main_loop = lb.create<scf::ForOp>(
        one, max_rank_plus_two, one, init_values,
        [&](OpBuilder &b, Location l, Value v, ValueRange vr) {
          // 'same_size' should track what the size of the dimension is to
          // which the 1-sized dimensions are broadcasted. If all of the
          // dimensions are 1, it will stay 1.
          Value same_size = one;
          // 'result_dimensions' stores the current dimension with an offset
          // of 'leading_ones' to make it easier to check whether we are
          // in-bounds with respect to the "real" shape with leading 1's
          // removed.
          SmallVector<Value> result_dimensions;
          result_dimensions.reserve(k);
          // 'no_broadcasting' stores boolean flags that encode whether the
          // corresponding shape does not need broadcasting at the current
          // position.
          SmallVector<Value> no_broadcasting;
          no_broadcasting.reserve(k + 3);
          // The first k loop carried values are the previous broadcasting
          // state.
          auto prev_no_broadcasting = vr.take_front(k);

          // This loop checks which shapes need broadcasting at the current
          // dimension. A shape needs broadcasting if it is indexed out of
          // bounds, or its current dimension size is 1.
          Value current_dimension_has_invalid_broadcast = constant_false;
          for (size_t i = 0; i < k; ++i) {
            // Determine the size of the current dimension. If the dimension
            // is out of bounds, we choose the value 'one'.
            Value is_out_of_bounds =
                b.create<CmpIOp>(l, CmpIPredicate::ult, ranks[i], v);
            Value dimension = b.create<SubIOp>(l, ranks[i], v);
            result_dimensions.push_back(dimension);
            Value current_size =
                b.create<scf::IfOp>(
                     l, TypeRange{b.getIndexType()}, is_out_of_bounds,
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, one);
                     },
                     [&](OpBuilder &b, Location l) {
                       // Using IfOp instead of SelectOp makes sure that we
                       // don't try to load if the dimension is out of bounds.
                       Value size =
                           b.create<memref::LoadOp>(l, shapes[i], dimension);
                       b.create<scf::YieldOp>(l, size);
                     })
                    .getResult(0);
            // Compute whether the current dimension does require
            // broadcasting.
            Value current_size_is_not_one =
                b.create<CmpIOp>(l, CmpIPredicate::ne, current_size, one);
            no_broadcasting.push_back(current_size_is_not_one);
            Value new_same_size = b.create<SelectOp>(
                l, current_size_is_not_one, current_size, same_size);
            Value same_size_was_not_one =
                b.create<CmpIOp>(l, CmpIPredicate::ne, same_size, one);
            Value is_different_size = b.create<CmpIOp>(
                l, CmpIPredicate::ne, same_size, new_same_size);
            // The broadcast is invalid if the size of the current dimension
            // is not equal to the expected size, unless the expected size
            // was still the initial value 1.
            Value is_invalid =
                b.create<AndOp>(l, same_size_was_not_one, is_different_size);
            current_dimension_has_invalid_broadcast = b.create<OrOp>(
                l, current_dimension_has_invalid_broadcast, is_invalid);
            same_size = new_same_size;
          }

          // Check whether we have at least one shape that has a different
          // status regarding whether it needs broadcasting at the current
          // dimension versus whether it needs broadcasting at the previous
          // dimension.
          Value same_size_is_one =
              b.create<CmpIOp>(l, CmpIPredicate::eq, same_size, one);
          Value different_broadcasting_set = constant_false;
          for (size_t i = 0; i < k; ++i) {
            // If all dimensions are 1, we preserve the status whether a
            // shape needs broadcasting or not, because in that case the
            // dimension can just be ignored.
            no_broadcasting[i] =
                b.create<SelectOp>(l, same_size_is_one, prev_no_broadcasting[i],
                                   no_broadcasting[i]);
            // Compare whether the current shape changes its status regarding
            // whether it needs broadcasting at the current dimension.
            Value broadcasting_is_different =
                b.create<CmpIOp>(l, CmpIPredicate::ne, prev_no_broadcasting[i],
                                 no_broadcasting[i]);
            different_broadcasting_set = b.create<OrOp>(
                l, different_broadcasting_set, broadcasting_is_different);
          }
          Value running_product = vr[k];
          Value current_dimension_offset = vr[k + 1];

          // We need to stop combining dimensions if the set of shapes which
          // need broadcasting at the current dimension changes compared to
          // the set of shapes needing broadcasting at the previous dimension.
          Value is_last_iteration =
              b.create<CmpIOp>(l, CmpIPredicate::sgt, v, max_rank);
          Value stop_combining_dimensions =
              b.create<OrOp>(l, is_last_iteration, different_broadcasting_set);
          auto if_stop_combining_dimensions = b.create<scf::IfOp>(
              l, TypeRange{b.getIndexType(), b.getIndexType()},
              stop_combining_dimensions,
              [&](OpBuilder &b, Location l) {
                // If the running product is not 1, add one dimension of size
                // 'running_product' to each shape that didn't need
                // broadcasting, otherwise add a 1 dimension if it was
                // previously indexed in-bounds.
                Value running_product_not_one = b.create<CmpIOp>(
                    l, CmpIPredicate::ne, running_product, one);
                Value new_dimension_offset =
                    b.create<scf::IfOp>(
                         l, TypeRange{b.getIndexType()},
                         running_product_not_one,
                         [&](OpBuilder &b, Location l) {
                           Value new_dimension_offset = b.create<AddIOp>(
                               l, current_dimension_offset, one);
                           Value minus_one =
                               b.create<ConstantIndexOp>(l, -1);
                           for (size_t i = 0; i < k; ++i) {
                             Value was_in_bounds = b.create<CmpIOp>(
                                 l, CmpIPredicate::sge, result_dimensions[i],
                                 minus_one);
                             Value should_store_dimension = b.create<OrOp>(
                                 l, was_in_bounds, prev_no_broadcasting[i]);
                             b.create<scf::IfOp>(
                                 l, should_store_dimension,
                                 [&](OpBuilder &b, Location l) {
                                   Value output_dimension = b.create<SubIOp>(
                                       l, ranks[i], new_dimension_offset);
                                   // If the shape needed broadcasting at the
                                   // previous dimension, we set the output
                                   // size to 1, otherwise to
                                   // 'running_product'.
                                   Value output_size = b.create<SelectOp>(
                                       l, prev_no_broadcasting[i],
                                       running_product, one);
                                   b.create<memref::StoreOp>(
                                       l, output_size, result_shapes[i],
                                       output_dimension);
                                   b.create<scf::YieldOp>(l, llvm::None);
                                 });
                           }
                           b.create<scf::YieldOp>(l, new_dimension_offset);
                         },
                         [&](OpBuilder &b, Location l) {
                           b.create<scf::YieldOp>(l,
                                                  current_dimension_offset);
                         })
                        .getResult(0);
                b.create<scf::YieldOp>(
                    l, ValueRange{same_size, new_dimension_offset});
              },
              [&](OpBuilder &b, Location l) {
                Value new_running_product =
                    b.create<MulIOp>(l, running_product, same_size);
                b.create<scf::YieldOp>(
                    l,
                    ValueRange{new_running_product, current_dimension_offset});
              });
          // Add the remaining results.
          no_broadcasting.push_back(if_stop_combining_dimensions.getResult(0));
          no_broadcasting.push_back(if_stop_combining_dimensions.getResult(1));
          Value is_invalid = vr.back();
          is_invalid = b.create<OrOp>(l, is_invalid,
                                      current_dimension_has_invalid_broadcast);
          no_broadcasting.push_back(is_invalid);
          b.create<scf::YieldOp>(l, no_broadcasting);
        });
    Value is_invalid = main_loop.getResults().back();
    for (size_t i = 0; i < k; ++i) {
      result_shapes[i] =
          RemoveLeadingOnesFrom1DMemref(lb, result_shapes[i], ranks[i]);
      result_shapes[i] =
          lb.create<SelectOp>(is_invalid, shapes[i], result_shapes[i]);
    }
    rewriter.replaceOp(broadcast_shapes_op, result_shapes);
    return success();
  }

 private:
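  // Returns the number of leading size-1 extents in 'extent_memref'. For a
  // hypothetical extent buffer holding [1, 1, 5, 1] with rank 4, the result
  // would be 2: only the contiguous 1's at the front are counted.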
  Value CountLeadingOnes(ImplicitLocOpBuilder &lb, Value extent_memref,
                         Value rank) const {
    // Count leading 1's. Use two iteration variables for that: one with a
    // boolean flag for whether every size so far was 1, one with the number
    // of leading 1's.
    Value constant_true =
        lb.create<ConstantOp>(lb.getI1Type(), lb.getBoolAttr(true));
    Value zero = lb.create<ConstantIndexOp>(0);
    Value one = lb.create<ConstantIndexOp>(1);
    auto leading_ones_loop = lb.create<scf::ForOp>(
        zero, rank, one, ValueRange{constant_true, zero},
        [&](OpBuilder &b, Location l, Value idx, ValueRange vr) {
          auto size = b.create<memref::LoadOp>(l, extent_memref, idx);
          auto is_equal_to_one =
              b.create<CmpIOp>(l, CmpIPredicate::eq, size, one);
          auto all_ones = b.create<AndOp>(l, vr.front(), is_equal_to_one);
          auto increased_value = b.create<AddIOp>(l, vr.back(), one);
          auto number_of_leading_ones =
              b.create<SelectOp>(l, all_ones, increased_value, vr.back());
          b.create<scf::YieldOp>(l,
                                 ValueRange{all_ones, number_of_leading_ones});
        });
    return leading_ones_loop.results()[1];
  }

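  // Copies 'extent_memref' into a freshly allocated buffer with all leading
  // size-1 extents dropped; sketched on a hypothetical input, [1, 1, 5, 3]
  // becomes [5, 3].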
  Value RemoveLeadingOnesFrom1DMemref(ImplicitLocOpBuilder &lb,
                                      Value extent_memref, Value rank) const {
    Value leading_ones = CountLeadingOnes(lb, extent_memref, rank);
    Value new_rank = lb.create<SubIOp>(rank, leading_ones);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    // We cannot use SubView here to return a MemRef with 'leading_ones' as
    // offset, because that also changes the size, so the result type would
    // need an affine map to change the layout. This is incompatible with our
    // other MemRef types without affine map. So instead we just allocate
    // another buffer of the desired size and copy the elements over. We
    // assume the buffer will be small, so we allocate it on the stack.
    // TODO(b/181654096): Replace AllocaOp with AllocOp.
    Value result = lb.create<memref::AllocaOp>(result_type, new_rank);
    Value zero = lb.create<ConstantIndexOp>(0);
    Value one = lb.create<ConstantIndexOp>(1);
    lb.create<scf::ForOp>(
        zero, new_rank, one, llvm::None,
        [&](OpBuilder &b, Location l, Value idx, ValueRange /*vr*/) {
          Value idx_with_offset = b.create<AddIOp>(l, idx, leading_ones);
          auto size =
              b.create<memref::LoadOp>(l, extent_memref, idx_with_offset);
          b.create<memref::StoreOp>(l, size, result, idx);
          b.create<scf::YieldOp>(l, llvm::None);
        });
    return result;
  }
};

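// Re-creates tf_framework.jit_execute with its result types converted from
// tensors to memrefs, forwarding the already-converted operands unchanged.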
struct BufferizeJITExecuteOp
    : public OpConversionPattern<tf_framework::JITExecuteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      tf_framework::JITExecuteOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type, 2> result_types;
    if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
                                                result_types))) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<tf_framework::JITExecuteOp>(
        op, result_types, operands, op->getAttrs());
    return success();
  }
};

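// Re-creates std.rank on the buffer that the conversion produced for the
// tensor operand; the op accepts both tensors and memrefs, so only the
// operand needs to be rewritten.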
class BufferizeRankOp : public OpConversionPattern<RankOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      RankOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    RankOp::Adaptor adaptor(operands);
    rewriter.replaceOpWithNewOp<RankOp>(op, adaptor.memrefOrTensor());
    return success();
  }
};

}  // namespace

void populateExtraBufferizePatterns(MLIRContext *context,
                                    BufferizeTypeConverter *converter,
                                    RewritePatternSet *patterns) {
  // clang-format off
  patterns->insert<
      BufferizeAndConvertMinimumBroadcastShapesOp,
      BufferizeConstantOp,
      BufferizeDimOp,
      BufferizeJITExecuteOp,
      BufferizeRankOp>(*converter, context);
  // clang-format on
}
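
// A minimal sketch of how a hypothetical bufferization pass might use this
// entry point, assuming the usual dialect-conversion setup (none of the
// names below are defined in this file):
//
//   BufferizeTypeConverter converter;
//   RewritePatternSet patterns(&context);
//   populateExtraBufferizePatterns(&context, &converter, &patterns);
//   ConversionTarget target(context);
//   target.addLegalDialect<memref::MemRefDialect, scf::SCFDialect>();
//   if (failed(applyPartialConversion(func, target, std::move(patterns))))
//     signalPassFailure();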

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir