/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is an explorative prototype emitter for convolution using MLIR.
// This prototype is still under construction.
// TODO(timshen): Fix the documentation once it's implemented.
//
// Goals:
// * Autotune-able tiling.
// * Autotune-able memory accesses.
// * Autotune-able lowering logic (from a portable program to thread-oriented
//   CUDA program).
// * Use mlir::AffineExpr to analyze all accesses. It aims to algorithmically
//   find memory access strategies for given input layouts and tiling configs.

#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.h"

#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/AffineExpr.h"  // TF:llvm-project
#include "mlir/IR/AffineMap.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Transforms/LoopUtils.h"  // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace mlir_gpu {
namespace {

// Various extracted information for input shapes.
struct ShapeInfo {
  // Buffer dimensions in the order of NCHW.
  std::vector<int64_t> nchw_dimensions;

  // Buffer dimensions in the order of major to minor.
  std::vector<int64_t> physical_dimensions;

  // The affine map that takes NCHW indices, and maps to the physical order.
  mlir::AffineMap affine_map;

  mlir::Type element_type;
};

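// For example (illustrative): an f32 operand with logical shape
// [8, 224, 224, 3], batch dimension 0, feature dimension 3, spatial
// dimensions {1, 2}, and layout minor_to_major = {3, 2, 1, 0} yields
//   nchw_dimensions     = {8, 3, 224, 224}
//   physical_dimensions = {8, 224, 224, 3}
//   affine_map          = (n, c, h, w) -> (n, h, w, c)
// i.e. the map takes NCHW indices and produces indices in physical
// (major-to-minor) order.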
ShapeInfo GetShapeInfo(
    const Shape& shape, int64 n_dim, int64 c_dim,
    absl::Span<const tensorflow::protobuf_int64> spatial_dims,
    mlir::Builder builder) {
  ShapeInfo shape_info;

  std::vector<int64> physical_to_logical(
      shape.layout().minor_to_major().rbegin(),
      shape.layout().minor_to_major().rend());

  std::vector<int64> nchw_to_logical;

  nchw_to_logical.push_back(n_dim);
  nchw_to_logical.push_back(c_dim);
  for (int64 dim : spatial_dims) {
    nchw_to_logical.push_back(dim);
  }

  for (int64 dim : nchw_to_logical) {
    shape_info.nchw_dimensions.push_back(shape.dimensions(dim));
  }

  for (int64 dim : physical_to_logical) {
    shape_info.physical_dimensions.push_back(shape.dimensions(dim));
  }

  std::vector<mlir::AffineExpr> affine_exprs;
  // The composed permutation maps each physical position to its NCHW index, so
  // the resulting affine map takes NCHW indices to physical order.
  for (int64 dim : ComposePermutations(InversePermutation(nchw_to_logical),
                                       physical_to_logical)) {
    affine_exprs.push_back(builder.getAffineDimExpr(dim));
  }

  shape_info.affine_map = mlir::AffineMap::get(
      /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs);

  shape_info.element_type = [&] {
    switch (shape.element_type()) {
      case xla::F16:
        return builder.getF16Type();
      case xla::F32:
        return builder.getF32Type();
      default:
        break;
    }
    CHECK(false);
  }();

  return shape_info;
}

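// Returns true if `loop` is a "simple" loop: it starts at 0, has step 1, a
// single-result upper bound map, and a single-block body region.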
bool IsSimpleLoop(mlir::AffineForOp loop) {
  return loop.getLowerBoundMap().isSingleConstant() &&
         loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
         loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
         std::next(loop.region().begin()) == loop.region().end();
}

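// An affine map together with the SSA values bound to its dimensions, as used
// by affine load/store operations.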
struct BoundAffineMap {
  mlir::AffineMap affine_map;
  std::vector<mlir::Value> operands;
};

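// Extracts the access map and its bound operands from `op`, which must be an
// AffineLoadOp or an AffineStoreOp.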
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    return {load.getAffineMap(),
            std::vector<mlir::Value>(load.getMapOperands().begin(),
                                     load.getMapOperands().end())};
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    return {store.getAffineMap(),
            std::vector<mlir::Value>(store.getMapOperands().begin(),
                                     store.getMapOperands().end())};
  } else {
    CHECK(false);
  }
}

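// Recreates the affine load/store `op` with `new_affine` as its access map and
// operands, keeping the original memref (and stored value, for stores). The
// old operation is left in place for the caller to replace and erase.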
mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
                                       BoundAffineMap new_affine,
                                       mlir::OpBuilder builder) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    return builder.create<mlir::AffineLoadOp>(
        builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
        new_affine.operands);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    return builder.create<mlir::AffineStoreOp>(
        builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
        new_affine.affine_map, new_affine.operands);
  } else {
    CHECK(false);
  }
}

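// Redirects the affine load/store `op` to access `memref`.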
void SetMemRef(mlir::Operation* op, mlir::Value memref) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    load.setMemRef(memref);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    store.setMemRef(memref);
  } else {
    CHECK(false);
  }
}

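// Creates a perfect nest of simple loops, one per entry in `upper_bounds`,
// each iterating from 0 to its bound. Returns the loops from outermost to
// innermost.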
std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
    absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder) {
  std::vector<mlir::AffineForOp> loops;
  loops.reserve(upper_bounds.size());
  for (int64_t dim : upper_bounds) {
    auto loop =
        builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
    loops.push_back(loop);
    builder = loop.getBodyBuilder();
  }
  return loops;
}

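// Replaces the upper bound of the simple loop `loop` with `new_bound`, keeping
// the dimension and symbol counts of its existing upper bound map.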
void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
                           mlir::OpBuilder builder) {
  CHECK(IsSimpleLoop(loop));

  loop.setUpperBoundMap(mlir::AffineMap::get(
      loop.getUpperBoundMap().getNumDims(),
      loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
}

// Tiles a loop with trip count N by `size`. For now, N has to be a multiple of
// size, but later this constraint will be removed.
//
// The major loop (with trip count N / size) stays as-is, while the minor loop
// (with trip count `size`) takes over the body of `target` and becomes its new
// body.
//
// `target` has to be within the same "perfectly nested loop group" as `loop`.
// See the documentation for mlir::getPerfectlyNestedLoops.
//
// Example:
// Before tiling `loop` with tile size X:
//   for (loop in N)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         // pass loop into affine maps
// After:
//   for (loop in N / X)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         for (tiled_loop in X)
//           // rewrite all affine exprs from loop to `loop * X + tiled_loop`.
//
// Design note:
// TileLoop is different from mlir::tile. At the moment, mlir::tile does not
// document its exact tiling semantics, but the observed behavior is:
//   for (i from 0 to N)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         // pass i into affine maps
// =>
//   for (i from 0 to N, step = X)
//     for (unrelated_loop in ...)
//       for (target in ...)
//         for (j from i to min(i + X, N), step = 1)
//           // pass j into affine maps
//
// There are two differences between mlir::tile and TileLoop:
// * TileLoop always puts the tiling "stepping" logic into AffineExprs. With
//   that, all index calculation is done in AffineExprs and is easier to
//   analyze in a single place.
// * TileLoop doesn't plan to use max() and min() to resolve the issue when
//   N % X != 0. max() and min() are not representable in AffineExprs.
//   TODO(timshen): support the case where N % X != 0.
//
// TODO(timshen): consider the possibility to reuse mlir::tile's logic to
// achieve the same goal.
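//
// As a concrete (illustrative) instance: tiling a loop of trip count 8 by
// size 4 sets the loop's upper bound to 8 ceildiv 4 = 2, creates an inner loop
// of trip count 4 as the innermost loop of `target`, and rewrites an access
// that used dimension d0 (bound to `loop`'s induction variable) to
// d0 * 4 + d_new, where d_new is a fresh dimension bound to the inner loop's
// induction variable.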
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
                           mlir::AffineForOp target) {
  CHECK(IsSimpleLoop(loop));
  CHECK(IsSimpleLoop(target));
  {
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, loop);
    CHECK(absl::c_linear_search(all_loops, target));
  }

  auto builder = target.getBodyBuilder();

  auto inner_loop =
      builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
  {
    auto& inner_operations = inner_loop.getBody()->getOperations();
    auto& target_operations = target.getBody()->getOperations();

    inner_operations.splice(inner_operations.begin(), target_operations,
                            target_operations.begin(),
                            std::prev(target_operations.end(), 2));

    mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
    CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
    SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
  }

  for (auto& use :
       llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
    mlir::Operation* owner = use.getOwner();
    BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
    unsigned new_dim = affine_map.operands.size();
    affine_map.operands.push_back(inner_loop.getInductionVar());
    std::vector<mlir::AffineExpr> replacements;
    for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
      if (affine_map.operands[i] == loop.getInductionVar()) {
        replacements.push_back(builder.getAffineDimExpr(i) * size +
                               builder.getAffineDimExpr(new_dim));
      } else {
        replacements.push_back(builder.getAffineDimExpr(i));
      }
    }
    affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
        replacements, {}, affine_map.operands.size(), 0);
    auto new_op =
        CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
    owner->replaceAllUsesWith(new_op);
    owner->erase();
  }
  return inner_loop;
}

// Hoists operations out of `where`. [begin_op, end_op) must be the first
// operations of their parent loop, and `where` must be an ancestor of that
// parent loop.
//
// The hoist always preserves the semantics of the program; to do so it may
// modify the hoisted operations or add extra loops at the hoisted place.
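//
// For example (illustrative): hoisting `alloc() : memref<f32>` out of two
// enclosing simple loops with trip counts 32 and 16 produces
// `alloc() : memref<32x16xf32>` at `where`, and every load/store of the old
// buffer is rewritten to index the new one with the two loop induction
// variables. Hoisting loop-variant operations instead recreates the enclosing
// loops at `where` and moves the operations into them.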
mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
                             llvm::iplist<mlir::Operation>::iterator end_op,
                             mlir::AffineForOp where) {
  // All loops to hoist through.
  llvm::SmallVector<mlir::AffineForOp, 4> ancestors;
  getPerfectlyNestedLoops(ancestors, where);
  {
    int i;
    for (i = 0; i < ancestors.size(); i++) {
      if (&ancestors[i].getBody()->front() == &*begin_op) {
        break;
      }
    }
    CHECK(i < ancestors.size());
    ancestors.resize(i + 1);
  }

  std::vector<int64_t> ancestor_dimensions;
  for (auto ancestor : ancestors) {
    CHECK(IsSimpleLoop(ancestor));
    ancestor_dimensions.push_back(
        ancestor.getUpperBoundMap().getSingleConstantResult());
  }

  if (auto alloc = mlir::dyn_cast<mlir::AllocOp>(begin_op)) {
    CHECK(std::next(begin_op) == end_op)
        << "alloc() needs to be hoisted on its own";

    mlir::OpBuilder builder(where);
    mlir::MemRefType type = alloc.getType();
    CHECK(type.getAffineMaps().empty());
    ancestor_dimensions.insert(ancestor_dimensions.end(),
                               type.getShape().begin(), type.getShape().end());
    mlir::MemRefType new_type =
        mlir::MemRefType::get(ancestor_dimensions, type.getElementType());
    auto new_alloc =
        builder.create<mlir::AllocOp>(builder.getUnknownLoc(), new_type);

    std::vector<mlir::Value> indvars;
    for (auto ancestor : ancestors) {
      indvars.push_back(ancestor.getInductionVar());
    }
    for (auto& use : llvm::make_early_inc_range(alloc.getResult().getUses())) {
      mlir::Operation* owner = use.getOwner();
      BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
      affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
                                 indvars.end());
      CHECK(affine_map.affine_map.isIdentity());
      affine_map.affine_map = mlir::AffineMap::getMultiDimIdentityMap(
          affine_map.operands.size(), builder.getContext());

      mlir::Operation* new_op =
          CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
      SetMemRef(new_op, new_alloc);
      owner->replaceAllUsesWith(new_op);
      owner->erase();
    }
    alloc.erase();
    return new_alloc;
  }

  const bool any_op_is_loop_variant = [&] {
    for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
      if (mlir::isa<mlir::AffineForOp>(op) ||
          mlir::isa<mlir::AffineStoreOp>(op)) {
        return true;
      }
    }
    return false;
  }();

  if (any_op_is_loop_variant) {
    auto builder = mlir::OpBuilder(where);
    std::vector<mlir::AffineForOp> new_loops;
    for (auto dim : ancestor_dimensions) {
      auto where =
          builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
      new_loops.push_back(where);
      builder = where.getBodyBuilder();
    }
    for (mlir::Operation& op :
         llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) {
      op.moveBefore(&new_loops.back().getBody()->back());
    }
    CHECK_EQ(ancestors.size(), new_loops.size());
    for (int i = 0; i < ancestors.size(); i++) {
      replaceAllUsesInRegionWith(ancestors[i].getInductionVar(),
                                 new_loops[i].getInductionVar(),
                                 new_loops.back().region());
    }
    return new_loops.front();
  }
  CHECK(false);
}

mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
  return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
}

// Sinks the first `rotate_amount` loops of the perfectly nested loop group
// `loops` to the bottom of the nest. It implements this by rotating the loop
// order by `rotate_amount`.
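// For example, with a nest of loops (A, B, C, D) and rotate_amount = 2, loops
// A and B are sunk below C and D, yielding the nest order (C, D, A, B).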
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
                              int rotate_amount) {
  CHECK_GE(rotate_amount, 0);
  std::vector<unsigned> permutation(loops.size());
  std::iota(permutation.begin(), permutation.end(), unsigned(0));
  std::rotate(permutation.begin(),
              permutation.begin() + loops.size() - rotate_amount,
              permutation.end());
  mlir::interchangeLoops(
      llvm::ArrayRef<mlir::AffineForOp>(loops.begin(), loops.end()),
      permutation);
}

struct InitialMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
  mlir::AllocOp output_acc;
};

// Returns the following IR with the anchors set to the corresponding
// operations.
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(f32)
//     output_acc[] = 0
//     for (reduction loops...) {
//       output_acc[] += input[...] * filter[...]
//     }
//     output[...] = output_acc[]
//   }
StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
    mlir::Value input, mlir::Value filter, mlir::Value output,
    const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info,
    const ShapeInfo& output_shape_info, const Window& window,
    mlir::OpBuilder builder) {
  CHECK(input_shape_info.element_type == builder.getF16Type());
  CHECK(filter_shape_info.element_type == builder.getF16Type());
  CHECK(output_shape_info.element_type == builder.getF16Type());

  auto location = mlir::UnknownLoc::get(builder.getContext());

  std::vector<mlir::AffineForOp> cartesian_product_loops =
      CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder);

  builder = cartesian_product_loops.back().getBodyBuilder();

  mlir::AllocOp output_acc = builder.create<mlir::AllocOp>(
      location, mlir::MemRefType::get({}, builder.getF32Type()));

  builder.create<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF32Type(), 0)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  std::vector<mlir::AffineForOp> reduction_loops;
  reduction_loops = CreateNestedSimpleLoops(
      absl::MakeSpan(filter_shape_info.nchw_dimensions).subspan(1), builder);

  mlir::AffineForOp loop_n = cartesian_product_loops[0];
  mlir::AffineForOp loop_o = cartesian_product_loops[1];
  mlir::AffineForOp loop_c = reduction_loops[0];

  std::vector<mlir::Value> output_spatial_indvars;
  for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) {
    output_spatial_indvars.push_back(loop.getInductionVar());
  }
  std::vector<mlir::Value> filter_spatial_indvars;
  for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) {
    filter_spatial_indvars.push_back(loop.getInductionVar());
  }
  int num_spatial_dims = output_spatial_indvars.size();
  CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size());

  builder = reduction_loops.back().getBodyBuilder();

  mlir::Value loaded_input = [&] {
    std::vector<mlir::AffineExpr> input_indices;
    input_indices.push_back(builder.getAffineDimExpr(0));
    input_indices.push_back(builder.getAffineDimExpr(1));

    // For spatial dimensions, generate input_index * stride + filter_index -
    // left_pad
    //
    // TODO(timshen): guard out-of-bound loads and stores brought by padding.
    for (int i = 0; i < num_spatial_dims; i++) {
      const WindowDimension& window_dim = window.dimensions(i);
      input_indices.push_back(
          builder.getAffineDimExpr(i + 2) * window_dim.stride() +
          builder.getAffineDimExpr(2 + num_spatial_dims + i) -
          window_dim.padding_low());
    }
    std::vector<mlir::Value> input_vars;
    input_vars.push_back(loop_n.getInductionVar());
    input_vars.push_back(loop_c.getInductionVar());
    input_vars.insert(input_vars.end(), output_spatial_indvars.begin(),
                      output_spatial_indvars.end());
    input_vars.insert(input_vars.end(), filter_spatial_indvars.begin(),
                      filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, input,
            mlir::AffineMap(input_shape_info.affine_map)
                .compose(
                    mlir::AffineMap::get(/*dimCount=*/2 + num_spatial_dims * 2,
                                         /*symbolCount=*/0, input_indices)),
            input_vars),
        builder.getF32Type());
  }();

  mlir::Value loaded_filter = [&] {
    std::vector<mlir::Value> filter_vars;
    filter_vars.push_back(loop_o.getInductionVar());
    filter_vars.push_back(loop_c.getInductionVar());
    filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(),
                       filter_spatial_indvars.end());

    return builder.create<mlir::FPExtOp>(
        location,
        builder.createOrFold<mlir::AffineLoadOp>(
            location, filter, filter_shape_info.affine_map, filter_vars),
        builder.getF32Type());
  }();

  builder.createOrFold<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::AddFOp>(
          location,
          builder.createOrFold<mlir::AffineLoadOp>(location, output_acc),
          builder.create<mlir::MulFOp>(location, loaded_input, loaded_filter)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  builder.setInsertionPointAfter(reduction_loops[0]);
  {
    std::vector<mlir::Value> output_vars;
    output_vars.push_back(loop_n.getInductionVar());
    output_vars.push_back(loop_o.getInductionVar());
    output_vars.insert(output_vars.end(), output_spatial_indvars.begin(),
                       output_spatial_indvars.end());
    builder.createOrFold<mlir::AffineStoreOp>(
        location,
        builder.create<mlir::FPTruncOp>(
            location,
            builder.createOrFold<mlir::AffineLoadOp>(location, output_acc),
            builder.getF16Type()),
        output, output_shape_info.affine_map, output_vars);
  }

  return InitialMlirConvAnchors{cartesian_product_loops, reduction_loops,
                                output_acc};
}

// Contains the following pattern with anchors:
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(..., f32)
//     for (reduction loops...) {
//       for (tiled cartesian loops...) {
//         output_acc[...] = 0
//       }
//       for (tiled cartesian loops...) {
//         for (reduction loops...) {
//           output_acc[] += input[...] * filter[...]
//         }
//       }
//       for (tiled cartesian loops...) {
//         output[...] = output_acc[...]
//       }
//     }
//   }
struct TransformedMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
};

StatusOr<TransformedMlirConvAnchors> TransformMlirConv(
    InitialMlirConvAnchors anchors) {
  std::vector<mlir::AffineForOp> cartesian_product_loops =
      anchors.cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops = anchors.reduction_loops;
  mlir::AllocOp output_acc = anchors.output_acc;

  // TODO(timshen): consider using pattern matchers for transformations
  //
  // Initial form:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(f32)
  //     output_acc[] = 0
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[]
  //   }

  // Tile cartesian loops to:
  //   for (cartesian loops...) {
  //     for (tiled cartesian loops...) {
  //       %output_acc = alloc() : memref(f32)
  //       output_acc[] = 0
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[]
  //     }
  //   }
  TileLoop(reduction_loops[0], 4, reduction_loops.back());

  std::vector<mlir::AffineForOp> tiled_cartesian_loops;
  tiled_cartesian_loops.push_back(
      TileLoop(cartesian_product_loops[1], 32, cartesian_product_loops.back()));

  tiled_cartesian_loops.push_back(TileLoop(cartesian_product_loops.back(), 16,
                                           tiled_cartesian_loops.back()));

  // Two hoist operations to interleave the allocation, computation, and
  // writebacks to output_acc:
  // After first hoist:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  output_acc = llvm::cast<mlir::AllocOp>(
      HoistAndFix(output_acc, tiled_cartesian_loops.front()));

  // Hoist everything before reduction loops (aka zero initializations of
  // output_acc):
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  HoistAndFix(tiled_cartesian_loops.back().getBody()->begin(),
              reduction_loops.front().getOperation()->getIterator(),
              tiled_cartesian_loops.front());

  // Now hoist all reduction loops outside of tiled cartesian loops.
  // Notice that HoistAndFix automatically adds a new set of tiled cartesian
  // loops for hoisted reduction loops to keep the semantics correct.
  //
  // After second hoist:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }  // compute loop
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  {
    auto compute_loop = llvm::cast<mlir::AffineForOp>(
        HoistAndFix(reduction_loops.front(), tiled_cartesian_loops[0]));

    // Fix tiled_cartesian_loops to make them point to the tiled compute loops,
    // not the writeback loops to the output buffer.
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, compute_loop);
    absl::c_copy_n(all_loops, tiled_cartesian_loops.size(),
                   tiled_cartesian_loops.data());
  }

  // After exchanging tiled cartesian compute loops with reduction loops:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (reduction loops...) {
  //       for (tiled cartesian loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  //
  // ...so that later the tiled cartesian loops (with computations in them) can
  // be replaced by CUDA MMA instructions.
  {
    std::vector<mlir::AffineForOp> loops;
    loops.insert(loops.end(), tiled_cartesian_loops.begin(),
                 tiled_cartesian_loops.end());
    loops.insert(loops.end(), reduction_loops.begin(), reduction_loops.end());
    SinkPerfectlyNestedLoops(loops, tiled_cartesian_loops.size());
  }
  return TransformedMlirConvAnchors{cartesian_product_loops, reduction_loops};
}

}  // namespace

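// Emits an MLIR function named `function_name` that computes the forward
// convolution `conv`. The emitted function takes three memref arguments in the
// order (output, input, filter), each in its physical (major-to-minor) layout.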
StatusOr<mlir::FuncOp> EmitConvolutionForwardAsMlir(
    HloInstruction* conv, absl::string_view function_name,
    mlir::MLIRContext* context) {
  mlir::OpBuilder builder(context);

  const auto& dim_nums = conv->convolution_dimension_numbers();
  ShapeInfo input_shape_info =
      GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(),
                   dim_nums.input_feature_dimension(),
                   dim_nums.input_spatial_dimensions(), builder);

  ShapeInfo filter_shape_info = GetShapeInfo(
      conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(),
      dim_nums.kernel_input_feature_dimension(),
      dim_nums.kernel_spatial_dimensions(), builder);

  ShapeInfo output_shape_info = GetShapeInfo(
      conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(),
      dim_nums.output_feature_dimension(), dim_nums.output_spatial_dimensions(),
      builder);

  auto function = mlir::FuncOp::create(
      mlir::UnknownLoc::get(builder.getContext()),
      llvm_ir::AsStringRef(function_name),
      builder.getFunctionType(
          {mlir::MemRefType::get(output_shape_info.physical_dimensions,
                                 output_shape_info.element_type, {}),
           mlir::MemRefType::get(input_shape_info.physical_dimensions,
                                 input_shape_info.element_type, {}),
           mlir::MemRefType::get(filter_shape_info.physical_dimensions,
                                 filter_shape_info.element_type, {})},
          {}));

  auto* entry_block = function.addEntryBlock();
  builder.setInsertionPointToStart(entry_block);
  builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
  builder.setInsertionPointToStart(entry_block);

  mlir::Value input = entry_block->getArgument(1);
  mlir::Value filter = entry_block->getArgument(2);
  mlir::Value output = entry_block->getArgument(0);

  TF_RETURN_IF_ERROR(ConvIsImplemented(conv));

  TF_ASSIGN_OR_RETURN(
      InitialMlirConvAnchors initial_anchors,
      CreateNaiveMlirConv(input, filter, output, input_shape_info,
                          filter_shape_info, output_shape_info, conv->window(),
                          builder));

  TF_ASSIGN_OR_RETURN(TransformedMlirConvAnchors transformed_anchors,
                      TransformMlirConv(initial_anchors));

  // TODO(timshen): Implement a transformation that collects loads to a given
  // buffer, creates a local alloc() for the accessed part, redirects all loads
  // and stores to that local alloc(), and creates code to initialize /
  // write back the local alloc() if needed.

  // TODO(timshen): Implement CUDA-specific lowering.

  return function;
}

Status ConvIsImplemented(const HloInstruction* conv) {
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return Unimplemented("Group count is not implemented.");
  }
  if (window_util::HasWindowReversal(conv->window())) {
    return Unimplemented("Window reversal is not implemented.");
  }
  if (window_util::HasDilation(conv->window())) {
    return Unimplemented("Dilation is not implemented.");
  }
  return Status::OK();
}

}  // namespace mlir_gpu
}  // namespace xla