/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is an exploratory prototype emitter for convolution using MLIR.
// It is still under construction.
// TODO(timshen): Fix the documentation once it's implemented.
//
// Goals:
// * Autotune-able tiling.
// * Autotune-able memory accesses.
// * Autotune-able lowering logic (from a portable program to a
//   thread-oriented CUDA program).
// * Use mlir::AffineExpr to analyze all accesses. It aims to algorithmically
//   find memory access strategies for given input layouts and tiling configs.

#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h"

#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/AffineExpr.h"  // TF:llvm-project
#include "mlir/IR/AffineMap.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Transforms/LoopUtils.h"  // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace mlir_gpu {
namespace {

// Various pieces of information extracted from an input shape.
struct ShapeInfo {
  // Buffer dimensions in the order of NCHW.
  std::vector<int64_t> nchw_dimensions;

  // Buffer dimensions in the order of major to minor.
  std::vector<int64_t> physical_dimensions;

  // The affine map that takes NCHW indices and maps them to the physical
  // order.
  mlir::AffineMap affine_map;

  mlir::Type element_type;
};

ShapeInfo GetShapeInfo(
    const Shape& shape, int64 n_dim, int64 c_dim,
    absl::Span<const tensorflow::protobuf_int64> spatial_dims,
    mlir::Builder builder) {
  ShapeInfo shape_info;

  std::vector<int64> physical_to_logical(
      shape.layout().minor_to_major().rbegin(),
      shape.layout().minor_to_major().rend());

  std::vector<int64> nchw_to_logical;

  nchw_to_logical.push_back(n_dim);
  nchw_to_logical.push_back(c_dim);
  for (int64 dim : spatial_dims) {
    nchw_to_logical.push_back(dim);
  }

  for (int64 dim : nchw_to_logical) {
    shape_info.nchw_dimensions.push_back(shape.dimensions(dim));
  }

  for (int64 dim : physical_to_logical) {
    shape_info.physical_dimensions.push_back(shape.dimensions(dim));
  }

  std::vector<mlir::AffineExpr> affine_exprs;
  // We want physical to nchw order.
  for (int64 dim : ComposePermutations(InversePermutation(nchw_to_logical),
                                       physical_to_logical)) {
    affine_exprs.push_back(builder.getAffineDimExpr(dim));
  }

  shape_info.affine_map = mlir::AffineMap::get(
      /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs);

  shape_info.element_type = [&] {
    switch (shape.element_type()) {
      case xla::F16:
        return builder.getF16Type();
      case xla::F32:
        return builder.getF32Type();
      default:
        break;
    }
    CHECK(false);
  }();

  return shape_info;
}
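
// Illustrative example (hypothetical shape and layout, not taken from any
// particular caller): for a logical NHWC buffer with n_dim = 0, c_dim = 3,
// spatial_dims = {1, 2}, and a layout whose major-to-minor order equals the
// logical order, GetShapeInfo would produce roughly:
//   nchw_dimensions     = {N, C, H, W}
//   physical_dimensions = {N, H, W, C}
//   affine_map          = (d0, d1, d2, d3) -> (d0, d2, d3, d1)
// i.e. the map consumes NCHW indices (d0=N, d1=C, d2=H, d3=W) and emits them
// in physical (major-to-minor) order.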

bool IsSimpleLoop(mlir::AffineForOp loop) {
  return loop.getLowerBoundMap().isSingleConstant() &&
         loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
         loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
         std::next(loop.region().begin()) == loop.region().end();
}

struct BoundAffineMap {
  mlir::AffineMap affine_map;
  std::vector<mlir::Value> operands;
};

BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    return {load.getAffineMap(),
            std::vector<mlir::Value>(load.getMapOperands().begin(),
                                     load.getMapOperands().end())};
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    return {store.getAffineMap(),
            std::vector<mlir::Value>(store.getMapOperands().begin(),
                                     store.getMapOperands().end())};
  } else {
    CHECK(false);
  }
}

mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
                                       BoundAffineMap new_affine,
                                       mlir::OpBuilder builder) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    return builder.create<mlir::AffineLoadOp>(
        builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
        new_affine.operands);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    return builder.create<mlir::AffineStoreOp>(
        builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
        new_affine.affine_map, new_affine.operands);
  } else {
    CHECK(false);
  }
}

void SetMemRef(mlir::Operation* op, mlir::Value memref) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    load.setMemRef(memref);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    store.setMemRef(memref);
  } else {
    CHECK(false);
  }
}

std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
    absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder) {
  std::vector<mlir::AffineForOp> loops;
  loops.reserve(upper_bounds.size());
  for (int64_t dim : upper_bounds) {
    auto loop =
        builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
    loops.push_back(loop);
    builder = loop.getBodyBuilder();
  }
  return loops;
}
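
// A sketch for illustration (hypothetical bounds, not asserted anywhere):
// CreateNestedSimpleLoops({2, 3}, builder) produces a nest roughly like
//   affine.for %i0 = 0 to 2 {
//     affine.for %i1 = 0 to 3 {
//     }
//   }
// and returns the loops outermost-to-innermost; the caller's builder is
// unchanged because it is taken by value.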

void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
                           mlir::OpBuilder builder) {
  CHECK(IsSimpleLoop(loop));

  loop.setUpperBoundMap(mlir::AffineMap::get(
      loop.getUpperBoundMap().getNumDims(),
      loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
}

// Tile a loop with trip count N by `size`. For now, N has to be a multiple of
// size, but later this constraint will be removed.
//
// The major loop (with trip count N / size) stays as-is, while the minor loop
// (with trip count `size`) will take over the body of `target`, and be placed
// as the new body of `target`.
//
// `target` has to be within the same "perfectly nested loop group" as `loop`.
// See the documentation for mlir::getPerfectlyNestedLoops.
//
// Example:
// Before tiling `loop` with tile size X:
//   for (loop in N)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         // pass loop into affine maps
// After:
//   for (loop in N / X)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         for (tiled_loop in X)
//           // rewrite all affine exprs from loop to `loop * X + tiled_loop`.
//
// Design note:
// TileLoop is different from mlir::tile. At the moment, mlir::tile does not
// clearly document its exact tiling semantics, but the observed behavior is:
//   for (i from 0 to N)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         // pass i into affine maps
// =>
//   for (i from 0 to N, step = X)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         for (j from i to min(i + X, N), step = 1)
//           // pass j into affine maps
//
// There are two differences between mlir::tile and TileLoop:
// * TileLoop always puts the tiling "stepping" logic into AffineExprs. With
//   that, all index calculation is done in AffineExprs and is easier to
//   analyze in a single place.
// * TileLoop doesn't plan to use max() and min() to resolve the issue when
//   N % X != 0. max() and min() are not representable in AffineExprs.
//   TODO(timshen): support the case where N % X != 0.
//
// TODO(timshen): consider the possibility to reuse mlir::tile's logic to
// achieve the same goal.
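//
// Illustrative sketch (hypothetical trip count and tile size, not taken from
// the callers below): tiling a `loop` with constant trip count 64 by
// size = 16 rewrites an access such as
//   affine.load %buf[%loop, %other]    // map: (d0, d1) -> (d0, d1)
// into
//   affine.load %buf[%loop * 16 + %tiled, %other]
// i.e. the map becomes (d0, d1, d2) -> (d0 * 16 + d2, d1) with the new
// induction variable appended as the last operand, and `loop`'s upper bound
// becomes 64 ceildiv 16 = 4.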
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
                           mlir::AffineForOp target) {
  CHECK(IsSimpleLoop(loop));
  CHECK(IsSimpleLoop(target));
  {
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, loop);
    CHECK(absl::c_linear_search(all_loops, target));
  }

  auto builder = target.getBodyBuilder();

  auto inner_loop =
      builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
  {
    auto& inner_operations = inner_loop.getBody()->getOperations();
    auto& target_operations = target.getBody()->getOperations();

    inner_operations.splice(inner_operations.begin(), target_operations,
                            target_operations.begin(),
                            std::prev(target_operations.end(), 2));

    mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
    CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
    SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
  }

  for (auto& use :
       llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
    mlir::Operation* owner = use.getOwner();
    BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
    unsigned new_dim = affine_map.operands.size();
    affine_map.operands.push_back(inner_loop.getInductionVar());
    std::vector<mlir::AffineExpr> replacements;
    for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
      if (affine_map.operands[i] == loop.getInductionVar()) {
        replacements.push_back(builder.getAffineDimExpr(i) * size +
                               builder.getAffineDimExpr(new_dim));
      } else {
        replacements.push_back(builder.getAffineDimExpr(i));
      }
    }
    affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
        replacements, {}, affine_map.operands.size(), 0);
    auto new_op =
        CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
    owner->replaceAllUsesWith(new_op);
    owner->erase();
  }
  return inner_loop;
}

// Hoist operations out of `where`. [begin_op, end_op) must be the first
// operations of their parent loop, and `where` must be an ancestor of that
// parent loop.
//
// The hoist always preserves the semantics of the program; to do so, it may
// modify the hoisted operations or add extra loops at the hoisted location.
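//
// Illustrative sketch (hypothetical trip counts, not asserted by the code):
// hoisting `%acc = alloc() : memref<f32>` out of a nest of simple loops with
// trip counts 32 and 16 turns it into `alloc() : memref<32x16xf32>` placed
// above the nest, and every affine load/store of %acc gets the two ancestor
// induction variables prepended to its indices, so each iteration still sees
// its own scalar slot.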
mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
                             llvm::iplist<mlir::Operation>::iterator end_op,
                             mlir::AffineForOp where) {
  // All loops to hoist through.
  llvm::SmallVector<mlir::AffineForOp, 4> ancestors;
  getPerfectlyNestedLoops(ancestors, where);
  {
    int i;
    for (i = 0; i < ancestors.size(); i++) {
      if (&ancestors[i].getBody()->front() == &*begin_op) {
        break;
      }
    }
    CHECK(i < ancestors.size());
    ancestors.resize(i + 1);
  }

  std::vector<int64_t> ancestor_dimensions;
  for (auto ancestor : ancestors) {
    CHECK(IsSimpleLoop(ancestor));
    ancestor_dimensions.push_back(
        ancestor.getUpperBoundMap().getSingleConstantResult());
  }

  if (auto alloc = mlir::dyn_cast<mlir::AllocOp>(begin_op)) {
    CHECK(std::next(begin_op) == end_op)
        << "alloc() needs to be hoisted on its own";

    mlir::OpBuilder builder(where);
    mlir::MemRefType type = alloc.getType();
    CHECK(type.getAffineMaps().empty());
    ancestor_dimensions.insert(ancestor_dimensions.end(),
                               type.getShape().begin(), type.getShape().end());
    mlir::MemRefType new_type =
        mlir::MemRefType::get(ancestor_dimensions, type.getElementType());
    auto new_alloc =
        builder.create<mlir::AllocOp>(builder.getUnknownLoc(), new_type);

    std::vector<mlir::Value> indvars;
    for (auto ancestor : ancestors) {
      indvars.push_back(ancestor.getInductionVar());
    }
    for (auto& use : llvm::make_early_inc_range(alloc.getResult().getUses())) {
      mlir::Operation* owner = use.getOwner();
      BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
      affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
                                 indvars.end());
      CHECK(affine_map.affine_map.isIdentity());
      affine_map.affine_map = mlir::AffineMap::getMultiDimIdentityMap(
          affine_map.operands.size(), builder.getContext());

      mlir::Operation* new_op =
          CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
      SetMemRef(new_op, new_alloc);
      owner->replaceAllUsesWith(new_op);
      owner->erase();
    }
    alloc.erase();
    return new_alloc;
  }

  const bool any_op_is_loop_variant = [&] {
    for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
      if (mlir::isa<mlir::AffineForOp>(op) ||
          mlir::isa<mlir::AffineStoreOp>(op)) {
        return true;
      }
    }
    return false;
  }();

  if (any_op_is_loop_variant) {
    auto builder = mlir::OpBuilder(where);
    std::vector<mlir::AffineForOp> new_loops;
    for (auto dim : ancestor_dimensions) {
      auto where =
          builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
      new_loops.push_back(where);
      builder = where.getBodyBuilder();
    }
    for (mlir::Operation& op :
         llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) {
      op.moveBefore(&new_loops.back().getBody()->back());
    }
    CHECK_EQ(ancestors.size(), new_loops.size());
    for (int i = 0; i < ancestors.size(); i++) {
      replaceAllUsesInRegionWith(ancestors[i].getInductionVar(),
                                 new_loops[i].getInductionVar(),
                                 new_loops.back().region());
    }
    return new_loops.front();
  }
  CHECK(false);
}

mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
  return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
}

// Sinks a segment of perfectly nested loops to the bottom. It implements this
// by rotating the loop nest by rotate_amount.
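//
// For illustration (hypothetical loop names; assuming mlir::interchangeLoops
// treats permutation[i] as the new position of loop i): given a nest
// (t0, t1, r0, r1, r2) listed outermost to innermost and rotate_amount = 2,
// the permutation is {3, 4, 0, 1, 2}, so the nest becomes
// (r0, r1, r2, t0, t1), i.e. the first two loops sink below the rest.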
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
                              int rotate_amount) {
  CHECK_GE(rotate_amount, 0);
  std::vector<unsigned> permutation(loops.size());
  std::iota(permutation.begin(), permutation.end(), unsigned(0));
  std::rotate(permutation.begin(),
              permutation.begin() + loops.size() - rotate_amount,
              permutation.end());
  mlir::interchangeLoops(
      llvm::ArrayRef<mlir::AffineForOp>(loops.begin(), loops.end()),
      permutation);
}

struct InitialMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
  mlir::AllocOp output_acc;
};

// Returns the following IR with the anchors set to the corresponding
// operations.
// for (cartesian loops...) {
//   %output_acc = alloc() : memref(f32)
//   output_acc[] = 0
//   for (reduction loops...) {
//     output_acc[] += input[...] * filter[...]
//   }
//   output[...] = output_acc[]
// }
StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
    mlir::Value input, mlir::Value filter, mlir::Value output,
    const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info,
    const ShapeInfo& output_shape_info, const Window& window,
    mlir::OpBuilder builder) {
  CHECK(input_shape_info.element_type == builder.getF16Type());
  CHECK(filter_shape_info.element_type == builder.getF16Type());
  CHECK(output_shape_info.element_type == builder.getF16Type());

  auto location = mlir::UnknownLoc::get(builder.getContext());

  std::vector<mlir::AffineForOp> cartesian_product_loops =
      CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder);

  builder = cartesian_product_loops.back().getBodyBuilder();

  mlir::AllocOp output_acc = builder.create<mlir::AllocOp>(
      location, mlir::MemRefType::get({}, builder.getF32Type()));

  builder.create<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF32Type(), 0)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  std::vector<mlir::AffineForOp> reduction_loops;
  reduction_loops = CreateNestedSimpleLoops(
      absl::MakeSpan(filter_shape_info.nchw_dimensions).subspan(1), builder);

  mlir::AffineForOp loop_n = cartesian_product_loops[0];
  mlir::AffineForOp loop_o = cartesian_product_loops[1];
  mlir::AffineForOp loop_c = reduction_loops[0];

  std::vector<mlir::Value> output_spatial_indvars;
  for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) {
    output_spatial_indvars.push_back(loop.getInductionVar());
  }
  std::vector<mlir::Value> filter_spatial_indvars;
  for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) {
    filter_spatial_indvars.push_back(loop.getInductionVar());
  }
  int num_spatial_dims = output_spatial_indvars.size();
  CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size());

  builder = reduction_loops.back().getBodyBuilder();

  mlir::Value loaded_input = [&] {
    std::vector<mlir::AffineExpr> input_indices;
    input_indices.push_back(builder.getAffineDimExpr(0));
    input_indices.push_back(builder.getAffineDimExpr(1));

    // For spatial dimensions, generate input_index * stride + filter_index -
    // left_pad.
    //
    // TODO(timshen): guard out-of-bound loads and stores brought by padding.
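    //
    // Worked example (illustrative numbers only): with stride 2 and
    // padding_low 1, output spatial index 3 with filter spatial index 0 reads
    // input index 3 * 2 + 0 - 1 = 5, while output index 0 with filter index 0
    // would read input index -1, which is why the padding guard above is
    // still a TODO.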
    for (int i = 0; i < num_spatial_dims; i++) {
      const WindowDimension& window_dim = window.dimensions(i);
      input_indices.push_back(
          builder.getAffineDimExpr(i + 2) * window_dim.stride() +
          builder.getAffineDimExpr(2 + num_spatial_dims + i) -
          window_dim.padding_low());
    }
    std::vector<mlir::Value> input_vars;
    input_vars.push_back(loop_n.getInductionVar());
    input_vars.push_back(loop_c.getInductionVar());
    input_vars.insert(input_vars.end(), output_spatial_indvars.begin(),
                      output_spatial_indvars.end());
    input_vars.insert(input_vars.end(), filter_spatial_indvars.begin(),
                      filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, input,
            mlir::AffineMap(input_shape_info.affine_map)
                .compose(
                    mlir::AffineMap::get(/*dimCount=*/2 + num_spatial_dims * 2,
                                         /*symbolCount=*/0, input_indices)),
            input_vars),
        builder.getF32Type());
  }();

  mlir::Value loaded_filter = [&] {
    std::vector<mlir::Value> filter_vars;
    filter_vars.push_back(loop_o.getInductionVar());
    filter_vars.push_back(loop_c.getInductionVar());
    filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(),
                       filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, filter, filter_shape_info.affine_map, filter_vars),
        builder.getF32Type());
  }();

  builder.createOrFold<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::AddFOp>(
          location,
          builder.createOrFold<mlir::AffineLoadOp>(location, output_acc),
          builder.create<mlir::MulFOp>(location, loaded_input, loaded_filter)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  builder.setInsertionPointAfter(reduction_loops[0]);
  {
    std::vector<mlir::Value> output_vars;
    output_vars.push_back(loop_n.getInductionVar());
    output_vars.push_back(loop_o.getInductionVar());
    output_vars.insert(output_vars.end(), output_spatial_indvars.begin(),
                       output_spatial_indvars.end());
    builder.createOrFold<mlir::AffineStoreOp>(
        location,
        builder.create<mlir::FPTruncOp>(
            location,
            builder.createOrFold<mlir::AffineLoadOp>(location, output_acc),
            builder.getF16Type()),
        output, output_shape_info.affine_map, output_vars);
  }

  return InitialMlirConvAnchors{cartesian_product_loops, reduction_loops,
                                output_acc};
}

// Contains the following pattern with anchors:
// for (cartesian loops...) {
//   %output_acc = alloc() : memref(..., f32)
//   for (reduction loops...) {
//     for (tiled cartesian loops...) {
//       output_acc[...] = 0
//     }
//     for (tiled cartesian loops...) {
//       for (reduction loops...) {
//         output_acc[] += input[...] * filter[...]
//       }
//     }
//     for (tiled cartesian loops...) {
//       output[...] = output_acc[...]
//     }
//   }
// }
struct TransformedMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
};

StatusOr<TransformedMlirConvAnchors> TransformMlirConv(
    InitialMlirConvAnchors anchors) {
  std::vector<mlir::AffineForOp> cartesian_product_loops =
      anchors.cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops = anchors.reduction_loops;
  mlir::AllocOp output_acc = anchors.output_acc;

  // TODO(timshen): consider using pattern matchers for transformations
  //
  // Initial form:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(f32)
  //   output_acc[] = 0
  //   for (reduction loops...) {
  //     output_acc[] += input[...] * filter[...]
  //   }
  //   output[...] = output_acc[]
  // }

  // Tile cartesian loops to:
  // for (cartesian loops...) {
  //   for (tiled cartesian loops...) {
  //     %output_acc = alloc() : memref(f32)
  //     output_acc[] = 0
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[]
  //   }
  // }
  TileLoop(reduction_loops[0], 4, reduction_loops.back());

  std::vector<mlir::AffineForOp> tiled_cartesian_loops;
  tiled_cartesian_loops.push_back(
      TileLoop(cartesian_product_loops[1], 32, cartesian_product_loops.back()));

  tiled_cartesian_loops.push_back(TileLoop(cartesian_product_loops.back(), 16,
                                           tiled_cartesian_loops.back()));

  // Two hoist operations to interleave the allocation, computation, and
  // writebacks to output_acc:
  // After the first hoist:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //     for (reduction loops...) {
  //       output_acc[...] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[...]
  //   }
  // }
  output_acc = llvm::cast<mlir::AllocOp>(
      HoistAndFix(output_acc, tiled_cartesian_loops.front()));

  // Hoist everything before reduction loops (aka zero initializations of
  // output_acc):
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //   }
  //   for (tiled cartesian loops...) {
  //     for (reduction loops...) {
  //       output_acc[...] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[...]
  //   }
  // }
  HoistAndFix(tiled_cartesian_loops.back().getBody()->begin(),
              reduction_loops.front().getOperation()->getIterator(),
              tiled_cartesian_loops.front());

  // Now hoist all reduction loops outside of the tiled cartesian loops.
  // Notice that HoistAndFix automatically adds a new set of tiled cartesian
  // loops for the hoisted reduction loops to keep the semantics correct.
  //
  // After the second hoist:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //   }
  //   for (tiled cartesian loops...) {
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //   }  // compute loop
  //   for (tiled cartesian loops...) {
  //     output[...] = output_acc[...]
  //   }
  // }
  {
    auto compute_loop = llvm::cast<mlir::AffineForOp>(
        HoistAndFix(reduction_loops.front(), tiled_cartesian_loops[0]));

    // Fix tiled_cartesian_loops to make them point to the tiled compute loops,
    // not the writeback loops to the output buffer.
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, compute_loop);
    absl::c_copy_n(all_loops, tiled_cartesian_loops.size(),
                   tiled_cartesian_loops.data());
  }

  // After exchanging tiled cartesian compute loops with reduction loops:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //   }
  //   for (reduction loops...) {
  //     for (tiled cartesian loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //   }
  //   for (tiled cartesian loops...) {
  //     output[...] = output_acc[...]
  //   }
  // }
  //
  // ...so that the later tiled cartesian loops (with computations in them) can
  // be replaced by CUDA MMA instructions.
  {
    std::vector<mlir::AffineForOp> loops;
    loops.insert(loops.end(), tiled_cartesian_loops.begin(),
                 tiled_cartesian_loops.end());
    loops.insert(loops.end(), reduction_loops.begin(), reduction_loops.end());
    SinkPerfectlyNestedLoops(loops, tiled_cartesian_loops.size());
  }
  return TransformedMlirConvAnchors{cartesian_product_loops, reduction_loops};
}

}  // namespace

StatusOr<mlir::FuncOp> EmitConvolutionForwardAsMlir(
    HloInstruction* conv, absl::string_view function_name,
    mlir::MLIRContext* context) {
  mlir::OpBuilder builder(context);

  const auto& dim_nums = conv->convolution_dimension_numbers();
  ShapeInfo input_shape_info =
      GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(),
                   dim_nums.input_feature_dimension(),
                   dim_nums.input_spatial_dimensions(), builder);

  ShapeInfo filter_shape_info = GetShapeInfo(
      conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(),
      dim_nums.kernel_input_feature_dimension(),
      dim_nums.kernel_spatial_dimensions(), builder);

  ShapeInfo output_shape_info = GetShapeInfo(
      conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(),
      dim_nums.output_feature_dimension(),
      dim_nums.output_spatial_dimensions(), builder);

  auto function = mlir::FuncOp::create(
      mlir::UnknownLoc::get(builder.getContext()),
      llvm_ir::AsStringRef(function_name),
      builder.getFunctionType(
          {mlir::MemRefType::get(output_shape_info.physical_dimensions,
                                 output_shape_info.element_type, {}),
           mlir::MemRefType::get(input_shape_info.physical_dimensions,
                                 input_shape_info.element_type, {}),
           mlir::MemRefType::get(filter_shape_info.physical_dimensions,
                                 filter_shape_info.element_type, {})},
          {}));

  auto* entry_block = function.addEntryBlock();
  builder.setInsertionPointToStart(entry_block);
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  builder.setInsertionPointToStart(entry_block);

  mlir::Value input = entry_block->getArgument(1);
  mlir::Value filter = entry_block->getArgument(2);
  mlir::Value output = entry_block->getArgument(0);

  TF_RETURN_IF_ERROR(ConvIsImplemented(conv));

  TF_ASSIGN_OR_RETURN(
      InitialMlirConvAnchors initial_anchors,
      CreateNaiveMlirConv(input, filter, output, input_shape_info,
                          filter_shape_info, output_shape_info, conv->window(),
                          builder));

  TF_ASSIGN_OR_RETURN(TransformedMlirConvAnchors transformed_anchors,
                      TransformMlirConv(initial_anchors));

  // TODO(timshen): Implement a transformation that collects loads to a given
  // buffer, creates a local alloc() for the accessed part, redirects all loads
  // and stores to that local alloc(), and creates code to initialize /
  // writeback the local alloc() if needed.

  // TODO(timshen): Implement CUDA-specific lowering.

  return function;
}

Status ConvIsImplemented(const HloInstruction* conv) {
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return Unimplemented("group count is not implemented.");
  }
  if (window_util::HasWindowReversal(conv->window())) {
    return Unimplemented("Window reversal is not implemented.");
  }
  if (window_util::HasDilation(conv->window())) {
    return Unimplemented("Dilation is not implemented.");
  }
  return Status::OK();
}

}  // namespace mlir_gpu
}  // namespace xla