• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 
13 using namespace mlir;
14 using namespace mlir::edsc;
15 using namespace mlir::edsc::intrinsics;
16 
getMemRefSizes(Value memRef)17 static SmallVector<Value, 8> getMemRefSizes(Value memRef) {
18   MemRefType memRefType = memRef.getType().cast<MemRefType>();
19   assert(isStrided(memRefType) && "Expected strided MemRef type");
20 
21   SmallVector<Value, 8> res;
22   res.reserve(memRefType.getShape().size());
23   const auto &shape = memRefType.getShape();
24   for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
25     if (shape[idx] == -1)
26       res.push_back(std_dim(memRef, idx));
27     else
28       res.push_back(std_constant_index(shape[idx]));
29   }
30   return res;
31 }
32 
MemRefBoundsCapture(Value v)33 mlir::edsc::MemRefBoundsCapture::MemRefBoundsCapture(Value v) {
34   auto memrefSizeValues = getMemRefSizes(v);
35   for (auto s : memrefSizeValues) {
36     lbs.push_back(std_constant_index(0));
37     ubs.push_back(s);
38     steps.push_back(1);
39   }
40 }
41 
VectorBoundsCapture(VectorType t)42 mlir::edsc::VectorBoundsCapture::VectorBoundsCapture(VectorType t) {
43   for (auto s : t.getShape()) {
44     lbs.push_back(std_constant_index(0));
45     ubs.push_back(std_constant_index(s));
46     steps.push_back(1);
47   }
48 }
49 
VectorBoundsCapture(Value v)50 mlir::edsc::VectorBoundsCapture::VectorBoundsCapture(Value v)
51     : VectorBoundsCapture(v.getType().cast<VectorType>()) {}
52