//===- Builders.cpp - MLIR Declarative Vector Builders --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Vector/EDSC/Builders.h"
10 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
11 #include "mlir/Dialect/Vector/VectorOps.h"
12 #include "mlir/EDSC/Builders.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/Builders.h"
15
16 using namespace mlir;
17 using namespace mlir::edsc;
18 using namespace mlir::edsc::intrinsics;
19 using namespace mlir::edsc::ops;
20
vector_contraction(StructuredIndexed A,StructuredIndexed B,StructuredIndexed C,ArrayRef<IteratorType> iteratorTypes)21 Value mlir::edsc::ops::vector_contraction(
22 StructuredIndexed A, StructuredIndexed B, StructuredIndexed C,
23 ArrayRef<IteratorType> iteratorTypes) {
24 using IndexingExprs = ArrayRef<ArrayRef<AffineExpr>>;
25 return vector_contract(
26 A.getValue(), B.getValue(), C.getValue(),
27 IndexingExprs{A.getExprs(), B.getExprs(), C.getExprs()},
28 ArrayRef<StringRef>{
29 llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString))});
30 }
31
vector_contraction_matmul(Value A,Value B,Value C)32 Value mlir::edsc::ops::vector_contraction_matmul(Value A, Value B, Value C) {
33 AffineExpr m, n, k;
34 bindDims(ScopedContext::getContext(), m, n, k);
35 return vector_contraction(StructuredIndexed(A, {m, k}),
36 StructuredIndexed(B, {k, n}),
37 StructuredIndexed(C, {m, n}),
38 {IteratorType::Parallel, IteratorType::Parallel,
39 IteratorType::Reduction});
40 }
41