• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for translating mixed IR to buffer form.
17 
18 #include "mlir/Transforms/Bufferize.h"  // from @llvm-project
19 
20 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
22 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
28 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
30 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
31 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
32 
33 namespace mlir {
34 namespace kernel_gen {
35 namespace transforms {
36 namespace {
37 
38 class BufferizeConstantOp : public OpConversionPattern<ConstantOp> {
39  public:
40   using OpConversionPattern<ConstantOp>::OpConversionPattern;
41 
matchAndRewrite(ConstantOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const42   LogicalResult matchAndRewrite(
43       ConstantOp op, ArrayRef<Value> operands,
44       ConversionPatternRewriter &rewriter) const final {
45     // We only need to bufferize tensor constants.
46     Location loc = op.getLoc();
47     auto result_type = op.getType().dyn_cast<RankedTensorType>();
48     int64_t result_rank = result_type.getRank();
49     if (!result_type || !result_type.hasStaticShape() || result_rank > 1)
50       return failure();
51 
52     auto memref_type =
53         MemRefType::get(result_type.getShape(), result_type.getElementType());
54     auto elements_attr = op.value().cast<DenseElementsAttr>();
55 
56     if (result_rank == 0) {
57       Value buffer = rewriter.create<memref::AllocOp>(loc, memref_type);
58       Value constant =
59           rewriter.create<ConstantOp>(loc, elements_attr.getValue({}));
60       rewriter.create<memref::StoreOp>(loc, constant, buffer);
61       rewriter.replaceOp(op, {buffer});
62       return success();
63     }
64 
65     Value buffer = rewriter.create<memref::AllocaOp>(loc, memref_type);
66 
67     bool all_same_elems = elements_attr.isSplat();
68     Value value;
69     if (all_same_elems)
70       value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
71     for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
72       if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
73       Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
74       rewriter.create<memref::StoreOp>(loc, value, buffer, index);
75     }
76     rewriter.replaceOp(op, {buffer});
77     return success();
78   }
79 };
80 
81 class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
82  public:
83   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(tensor::DimOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const84   LogicalResult matchAndRewrite(
85       tensor::DimOp op, ArrayRef<Value> operands,
86       ConversionPatternRewriter &rewriter) const override {
87     tensor::DimOp::Adaptor adaptor(operands);
88     rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
89                                                adaptor.index());
90     return success();
91   }
92 };
93 
// Bufferizes chlo.minimum_broadcast_shapes and lowers it to explicit
// scf.for/scf.if control flow over memrefs. For k input shape memrefs it
// computes k output shapes in which adjacent dimensions that broadcast the
// same way across all operands are collapsed into a single dimension, and
// leading 1's are stripped. If the shapes are found not to be
// broadcast-compatible, the original input shapes are returned unchanged.
class BufferizeAndConvertMinimumBroadcastShapesOp
    : public OpConversionPattern<chlo::MinimumBroadcastShapesOp> {
 public:
  using OpConversionPattern<
      chlo::MinimumBroadcastShapesOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      chlo::MinimumBroadcastShapesOp broadcast_shapes_op,
      ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    chlo::MinimumBroadcastShapesOp::Adaptor adaptor(operands);
    auto loc = broadcast_shapes_op.getLoc();
    ImplicitLocOpBuilder lb(loc, rewriter);
    Value zero = lb.create<ConstantIndexOp>(0);
    // Each shape operand is a 1-D index memref; its dim 0 is the rank of the
    // shape it describes.
    SmallVector<Value> shapes = adaptor.shapes();
    size_t k = shapes.size();
    SmallVector<Value> ranks;
    ranks.reserve(k);

    // Determine the maximum rank of the operands.
    Value max_rank;
    for (size_t i = 0; i < k; ++i) {
      Value rank = lb.create<memref::DimOp>(loc, shapes[i], zero);
      ranks.push_back(rank);
      if (i) {
        Value rank_is_greater =
            lb.create<CmpIOp>(CmpIPredicate::ugt, ranks[i], max_rank);
        max_rank = lb.create<SelectOp>(rank_is_greater, ranks[i], max_rank);
      } else {
        max_rank = ranks[0];
      }
    }

    // Allocate buffers for the return values and initialize them with 1's.
    SmallVector<Value> result_shapes;
    result_shapes.reserve(k);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    Value one = lb.create<ConstantIndexOp>(1);
    for (size_t i = 0; i < k; ++i) {
      // We assume the buffer will be small, so we allocate it on the stack.
      // TODO(b/181654096): Replace AllocaOp with AllocOp.
      auto result = lb.create<memref::AllocaOp>(result_type, ranks[i]);
      lb.create<scf::ForOp>(zero, ranks[i], one, llvm::None,
                            [&one, &result](OpBuilder &b, Location l, Value idx,
                                            ValueRange /*vr*/) {
                              b.create<memref::StoreOp>(l, one, result, idx);
                              b.create<scf::YieldOp>(l, llvm::None);
                            });
      result_shapes.push_back(result);
    }

    // Iterate through the dimensions and determine which adjacent dimensions
    // can be combined. Keep a running product of the dimensions that can be
    // combined as iteration variable (initialized to 1), and the current
    // dimension offset in the result shapes. We iterate through the shapes
    // backward, because the broadcasting semantics mean that the last
    // dimensions of each shape (the least significant ones) are matched
    // together.
    Value two = lb.create<ConstantIndexOp>(2);
    Value max_rank_plus_two = lb.create<AddIOp>(loc, max_rank, two);
    Value constant_false =
        lb.create<ConstantOp>(lb.getI1Type(), lb.getBoolAttr(false));
    // Loop-carried values: k per-shape "no broadcasting needed" flags, the
    // running product, the current dimension offset, and an "invalid
    // broadcast seen" flag — k + 3 values in total.
    SmallVector<Value> init_values;
    init_values.reserve(k + 3);
    // Initially, all values are marked as not broadcasted.
    for (int i = 0; i < k; ++i) {
      init_values.push_back(constant_false);
    }
    // The running product is initially 1.
    init_values.push_back(one);
    // The current dimension offset is initially 0.
    init_values.push_back(zero);
    // Whether the broadcasting is invalid.
    init_values.push_back(constant_false);

    // Iterate from 1 to max_rank + 1 (inclusive). This iteration variable is
    // used as an offset from the end of each shape vector. We iterate until
    // max_rank + 1 to handle the case that we have a running_product > 1 left
    // when we have processed all dimensions of the largest shape.
    auto main_loop = lb.create<scf::ForOp>(
        one, max_rank_plus_two, one, init_values,
        [&](OpBuilder &b, Location l, Value v, ValueRange vr) {
          // 'same_size' should track what the size of the dimension is to which
          // the 1-sized dimensions are broadcasted. If all of the dimensions
          // are 1, it will stay 1.
          Value same_size = one;
          // 'result_dimensions' stores the current dimension with an offset of
          // 'leading_ones' to make it easier to check whether we are in-bounds
          // with respect to the "real" shape with leading 1's removed.
          SmallVector<Value> result_dimensions;
          result_dimensions.reserve(k);
          // 'no_broadcasting' stores boolean flags that encode whether the
          // corresponding shape does not need broadcasting at the current
          // position.
          SmallVector<Value> no_broadcasting;
          no_broadcasting.reserve(k + 3);
          // The first k loop carried values are the previous broadcasting
          // state.
          auto prev_no_broadcasting = vr.take_front(k);

          // This loop checks which shapes need broadcasting at the current
          // dimension. A shape needs broadcasting if it is indexed out of
          // bounds, or its current dimension size is 1.
          Value current_dimension_has_invalid_broadcast = constant_false;
          for (size_t i = 0; i < k; ++i) {
            // Determine the size of the current dimension. If the dimension is
            // out of bounds, we choose the value 'one'.
            Value is_out_of_bounds =
                b.create<CmpIOp>(l, CmpIPredicate::ult, ranks[i], v);
            Value dimension = b.create<SubIOp>(l, ranks[i], v);
            result_dimensions.push_back(dimension);
            Value current_size =
                b.create<scf::IfOp>(
                     l, TypeRange{b.getIndexType()}, is_out_of_bounds,
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, one);
                     },
                     [&](OpBuilder &b, Location l) {
                       // Using IfOp instead of SelectOp makes sure that we
                       // don't try to load if the dimension is out of bounds.
                       Value size =
                           b.create<memref::LoadOp>(l, shapes[i], dimension);
                       b.create<scf::YieldOp>(l, size);
                     })
                    .getResult(0);
            // Compute whether the current dimension does require broadcasting.
            Value current_size_is_not_one =
                b.create<CmpIOp>(l, CmpIPredicate::ne, current_size, one);
            no_broadcasting.push_back(current_size_is_not_one);
            Value new_same_size = b.create<SelectOp>(l, current_size_is_not_one,
                                                     current_size, same_size);
            Value same_size_was_not_one =
                b.create<CmpIOp>(l, CmpIPredicate::ne, same_size, one);
            Value is_different_size = b.create<CmpIOp>(
                l, CmpIPredicate::ne, same_size, new_same_size);
            // The broadcast is invalid if the size of the current dimension
            // is not equal to the expected size, unless the expected size was
            // still the initial value 1.
            Value is_invalid =
                b.create<AndOp>(l, same_size_was_not_one, is_different_size);
            current_dimension_has_invalid_broadcast = b.create<OrOp>(
                l, current_dimension_has_invalid_broadcast, is_invalid);
            same_size = new_same_size;
          }

          // Check whether we have at least one shape that has a different
          // status regarding whether it needs broadcasting at the current
          // dimension versus whether it needs broadcasting at the previous
          // dimension.
          Value same_size_is_one =
              b.create<CmpIOp>(l, CmpIPredicate::eq, same_size, one);
          Value different_broadcasting_set = constant_false;
          for (size_t i = 0; i < k; ++i) {
            // If all dimensions are 1, we preserve the status whether a shape
            // needs broadcasting or not, because in that case the dimension can
            // just be ignored.
            no_broadcasting[i] =
                b.create<SelectOp>(l, same_size_is_one, prev_no_broadcasting[i],
                                   no_broadcasting[i]);
            // Compare whether the current shape changes its status regarding
            // whether it needs broadcasting at the current dimension.
            Value broadcasting_is_different =
                b.create<CmpIOp>(l, CmpIPredicate::ne, prev_no_broadcasting[i],
                                 no_broadcasting[i]);
            different_broadcasting_set = b.create<OrOp>(
                l, different_broadcasting_set, broadcasting_is_different);
          }
          // Loop-carried values k and k+1 are the running product and the
          // current dimension offset (see init_values above).
          Value running_product = vr[k];
          Value current_dimension_offset = vr[k + 1];

          // We need to stop combining dimensions if the set of shapes which
          // need broadcasting at the current dimension changes compared to the
          // set of shapes needing broadcasting at the previous dimension.
          Value is_last_iteration =
              b.create<CmpIOp>(l, CmpIPredicate::sgt, v, max_rank);
          Value stop_combining_dimensions =
              b.create<OrOp>(l, is_last_iteration, different_broadcasting_set);
          auto if_stop_combining_dimensions = b.create<scf::IfOp>(
              l, TypeRange{b.getIndexType(), b.getIndexType()},
              stop_combining_dimensions,
              [&](OpBuilder &b, Location l) {
                // If the running product is not 1, add one dimension of size
                // 'running_product' to each shape that didn't need
                // broadcasting, otherwise add a 1 dimension if it was
                // previously indexed in-bounds.
                Value running_product_not_one = b.create<CmpIOp>(
                    l, CmpIPredicate::ne, running_product, one);
                Value new_dimension_offset =
                    b.create<scf::IfOp>(
                         l, TypeRange{b.getIndexType()},
                         running_product_not_one,
                         [&](OpBuilder &b, Location l) {
                           Value new_dimension_offset = b.create<AddIOp>(
                               l, current_dimension_offset, one);
                           // NOTE(review): 'lb' (the outer builder) is used
                           // here inside a nested region while 'b' is used for
                           // everything else — presumably to hoist the
                           // constant out of the loop; verify the insertion
                           // point is the intended one.
                           Value minus_one = lb.create<ConstantIndexOp>(-1);
                           for (size_t i = 0; i < k; ++i) {
                             Value was_in_bounds = b.create<CmpIOp>(
                                 l, CmpIPredicate::sge, result_dimensions[i],
                                 minus_one);
                             Value should_store_dimension = b.create<OrOp>(
                                 l, was_in_bounds, prev_no_broadcasting[i]);
                             b.create<scf::IfOp>(
                                 l, should_store_dimension,
                                 [&](OpBuilder &b, Location l) {
                                   Value output_dimension = b.create<SubIOp>(
                                       l, ranks[i], new_dimension_offset);
                                   // If the shape needed broadcasting at the
                                   // previous dimension, we set the output size
                                   // to 1, otherwise to 'running_product'.
                                   Value output_size = b.create<SelectOp>(
                                       l, prev_no_broadcasting[i],
                                       running_product, one);
                                   b.create<memref::StoreOp>(l, output_size,
                                                             result_shapes[i],
                                                             output_dimension);
                                   b.create<scf::YieldOp>(l, llvm::None);
                                 });
                           }
                           b.create<scf::YieldOp>(l, new_dimension_offset);
                         },
                         [&](OpBuilder &b, Location l) {
                           b.create<scf::YieldOp>(l, current_dimension_offset);
                         })
                        .getResult(0);
                // Reset the running product to 'same_size' and publish the
                // (possibly advanced) dimension offset.
                b.create<scf::YieldOp>(
                    l, ValueRange{same_size, new_dimension_offset});
              },
              [&](OpBuilder &b, Location l) {
                // Keep combining: fold the current dimension size into the
                // running product; the offset stays unchanged.
                Value new_running_product =
                    b.create<MulIOp>(l, running_product, same_size);
                b.create<scf::YieldOp>(l, ValueRange{new_running_product,
                                                     current_dimension_offset});
              });
          // Add the remaining results.
          no_broadcasting.push_back(if_stop_combining_dimensions.getResult(0));
          no_broadcasting.push_back(if_stop_combining_dimensions.getResult(1));
          Value is_invalid = vr.back();
          is_invalid = b.create<OrOp>(l, is_invalid,
                                      current_dimension_has_invalid_broadcast);
          no_broadcasting.push_back(is_invalid);
          b.create<scf::YieldOp>(l, no_broadcasting);
        });
    // If the broadcast turned out invalid, fall back to returning the
    // original input shapes; otherwise strip leading 1's from each result.
    Value is_invalid = main_loop.getResults().back();
    for (size_t i = 0; i < k; ++i) {
      result_shapes[i] =
          RemoveLeadingOnesFrom1DMemref(lb, result_shapes[i], ranks[i]);
      result_shapes[i] =
          lb.create<SelectOp>(is_invalid, shapes[i], result_shapes[i]);
    }
    rewriter.replaceOp(broadcast_shapes_op, result_shapes);
    return success();
  }

 private:
  // Emits a loop that returns the number of leading 1 entries in the 1-D
  // index memref 'extent_memref' of length 'rank'.
  Value CountLeadingOnes(ImplicitLocOpBuilder &lb, Value extent_memref,
                         Value rank) const {
    // Count leading 1's. Use two iteration variables for that: one with a
    // boolean flag for whether every size so far was 1, one with the number of
    // leading 1's.
    Value constant_true =
        lb.create<ConstantOp>(lb.getI1Type(), lb.getBoolAttr(true));
    Value zero = lb.create<ConstantIndexOp>(0);
    Value one = lb.create<ConstantIndexOp>(1);
    auto leading_ones_loop = lb.create<scf::ForOp>(
        zero, rank, one, ValueRange{constant_true, zero},
        [&](OpBuilder &b, Location l, Value idx, ValueRange vr) {
          auto size = b.create<memref::LoadOp>(l, extent_memref, idx);
          auto is_equal_to_one =
              b.create<CmpIOp>(l, CmpIPredicate::eq, size, one);
          // 'all_ones' stays true only while every element seen so far is 1;
          // the counter is only advanced while that flag holds.
          auto all_ones = b.create<AndOp>(l, vr.front(), is_equal_to_one);
          auto increased_value = b.create<AddIOp>(l, vr.back(), one);
          auto number_of_leading_ones =
              b.create<SelectOp>(l, all_ones, increased_value, vr.back());
          b.create<scf::YieldOp>(l,
                                 ValueRange{all_ones, number_of_leading_ones});
        });
    // Result 1 is the count of leading ones; result 0 is the exhausted flag.
    return leading_ones_loop.results()[1];
  }

  // Returns a new 1-D index memref that contains the contents of
  // 'extent_memref' with its leading 1 entries removed.
  Value RemoveLeadingOnesFrom1DMemref(ImplicitLocOpBuilder &lb,
                                      Value extent_memref, Value rank) const {
    Value leading_ones = CountLeadingOnes(lb, extent_memref, rank);
    Value new_rank = lb.create<SubIOp>(rank, leading_ones);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    // We cannot use SubView here to return a MemRef with 'leading_ones' as
    // offset, because that also changes the size, so the result type would need
    // to have an affine map to change the layout. This is incompatible to our
    // other MemRef types without affine map. So instead we just allocate
    // another buffer of the desired size and copy the elements over. We assume
    // the buffer will be small, so we allocate it on the stack.
    // TODO(b/181654096): Replace AllocaOp with AllocOp.
    Value result = lb.create<memref::AllocaOp>(result_type, new_rank);
    Value zero = lb.create<ConstantIndexOp>(0);
    Value one = lb.create<ConstantIndexOp>(1);
    lb.create<scf::ForOp>(
        zero, new_rank, one, llvm::None,
        [&](OpBuilder &b, Location l, Value idx, ValueRange /*vr*/) {
          Value idx_with_offset = b.create<AddIOp>(l, idx, leading_ones);
          auto size =
              b.create<memref::LoadOp>(l, extent_memref, idx_with_offset);
          b.create<memref::StoreOp>(l, size, result, idx);
          b.create<scf::YieldOp>(l, llvm::None);
        });
    return result;
  }
};
402 
403 struct BufferizeJITExecuteOp
404     : public OpConversionPattern<tf_framework::JITExecuteOp> {
405   using OpConversionPattern::OpConversionPattern;
406 
matchAndRewritemlir::kernel_gen::transforms::__anon43e10c2b0111::BufferizeJITExecuteOp407   LogicalResult matchAndRewrite(
408       tf_framework::JITExecuteOp op, ArrayRef<Value> operands,
409       ConversionPatternRewriter &rewriter) const override {
410     SmallVector<Type, 2> result_types;
411     if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
412                                                 result_types))) {
413       return failure();
414     }
415     rewriter.replaceOpWithNewOp<tf_framework::JITExecuteOp>(
416         op, result_types, operands, op->getAttrs());
417     return success();
418   }
419 };
420 
421 class BufferizeRankOp : public OpConversionPattern<RankOp> {
422  public:
423   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(RankOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const424   LogicalResult matchAndRewrite(
425       RankOp op, ArrayRef<Value> operands,
426       ConversionPatternRewriter &rewriter) const override {
427     RankOp::Adaptor adaptor(operands);
428     rewriter.replaceOpWithNewOp<RankOp>(op, adaptor.memrefOrTensor());
429     return success();
430   }
431 };
432 
433 }  // namespace
434 
populateExtraBufferizePatterns(MLIRContext * context,BufferizeTypeConverter * converter,RewritePatternSet * patterns)435 void populateExtraBufferizePatterns(MLIRContext *context,
436                                     BufferizeTypeConverter *converter,
437                                     RewritePatternSet *patterns) {
438   // clang-format off
439   patterns->insert<
440       BufferizeAndConvertMinimumBroadcastShapesOp,
441       BufferizeConstantOp,
442       BufferizeDimOp,
443       BufferizeJITExecuteOp,
444       BufferizeRankOp>(*converter, context);
445   // clang-format on
446 }
447 
448 }  // namespace transforms
449 }  // namespace kernel_gen
450 }  // namespace mlir
451