1 //===- Builders.h - MLIR Declarative Vector Builders ------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides intuitive composable interfaces for building structured MLIR 10 // snippets in a declarative fashion. 11 // 12 //===----------------------------------------------------------------------===// 13 #ifndef MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_ 14 #define MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_ 15 16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/EDSC/Builders.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/Builders.h" 21 22 namespace mlir { 23 namespace edsc { 24 namespace ops { 25 26 /// Build a generic vector contraction, that is a `vector.contract` op with 27 /// specified `iteratorTypes`. The client is responsible for specifying proper 28 /// indexings when creating the StructuredIndexed. 29 /// The computation represents a notional (A * B + C) where indexings specify 30 /// which dimensions are reduced and reordered. 31 /// Return the result of the `vector.contract` op 32 /// 33 /// Prerequisites: 34 /// A, B and C capture values of proper vector types, and indexing expressions 35 /// that match semantics of the `vector.contract` op. 36 Value vector_contraction(StructuredIndexed A, StructuredIndexed B, 37 StructuredIndexed C, 38 ArrayRef<IteratorType> iteratorTypes); 39 40 /// Build a generic vector contraction that computes a matmul on vectors. 41 /// Return the result of C(i, j) + sum_k {A(i, k) * B(k, j)} on vectors. 42 /// 43 /// Prerequisites: 44 /// A, B and C capture values of proper vector types. For instance 45 /// `A: vector<4x8xf32>`, `B: vector<8x16f32>` and `C: vector<4x16xf32>`. 46 Value vector_contraction_matmul(Value A, Value B, Value C); 47 48 } // namespace ops 49 } // namespace edsc 50 } // namespace mlir 51 52 #endif // MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_ 53