1 //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 10 #include "mlir/IR/AffineExpr.h" 11 #include "mlir/IR/AffineMap.h" 12 13 using namespace mlir; 14 using namespace mlir::edsc; 15 using namespace mlir::edsc::intrinsics; 16 getMemRefSizes(Value memRef)17static SmallVector<Value, 8> getMemRefSizes(Value memRef) { 18 MemRefType memRefType = memRef.getType().cast<MemRefType>(); 19 assert(isStrided(memRefType) && "Expected strided MemRef type"); 20 21 SmallVector<Value, 8> res; 22 res.reserve(memRefType.getShape().size()); 23 const auto &shape = memRefType.getShape(); 24 for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) { 25 if (shape[idx] == -1) 26 res.push_back(std_dim(memRef, idx)); 27 else 28 res.push_back(std_constant_index(shape[idx])); 29 } 30 return res; 31 } 32 MemRefBoundsCapture(Value v)33mlir::edsc::MemRefBoundsCapture::MemRefBoundsCapture(Value v) { 34 auto memrefSizeValues = getMemRefSizes(v); 35 for (auto s : memrefSizeValues) { 36 lbs.push_back(std_constant_index(0)); 37 ubs.push_back(s); 38 steps.push_back(1); 39 } 40 } 41 VectorBoundsCapture(VectorType t)42mlir::edsc::VectorBoundsCapture::VectorBoundsCapture(VectorType t) { 43 for (auto s : t.getShape()) { 44 lbs.push_back(std_constant_index(0)); 45 ubs.push_back(std_constant_index(s)); 46 steps.push_back(1); 47 } 48 } 49 VectorBoundsCapture(Value v)50mlir::edsc::VectorBoundsCapture::VectorBoundsCapture(Value v) 51 : VectorBoundsCapture(v.getType().cast<VectorType>()) {} 52