• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Builders.h - MLIR Declarative Vector Builders ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides intuitive composable interfaces for building structured MLIR
10 // snippets in a declarative fashion.
11 //
12 //===----------------------------------------------------------------------===//
13 #ifndef MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
14 #define MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
15 
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/Builders.h"
21 
22 namespace mlir {
23 namespace edsc {
24 namespace ops {
25 
26 /// Build a generic vector contraction, that is a `vector.contract` op with
27 /// specified `iteratorTypes`. The client is responsible for specifying proper
28 /// indexings when creating the StructuredIndexed.
29 /// The computation represents a notional (A * B + C) where indexings specify
30 /// which dimensions are reduced and reordered.
31 /// Return the result of the `vector.contract` op
32 ///
33 /// Prerequisites:
34 /// A, B and C capture values of proper vector types, and indexing expressions
35 /// that match semantics of the `vector.contract` op.
36 Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
37                          StructuredIndexed C,
38                          ArrayRef<IteratorType> iteratorTypes);
39 
40 /// Build a generic vector contraction that computes a matmul on vectors.
41 /// Return the result of C(i, j) + sum_k {A(i, k) * B(k, j)} on vectors.
42 ///
43 /// Prerequisites:
44 /// A, B and C capture values of proper vector types. For instance
45 /// `A: vector<4x8xf32>`, `B: vector<8x16f32>` and `C: vector<4x16xf32>`.
46 Value vector_contraction_matmul(Value A, Value B, Value C);
47 
48 } // namespace ops
49 } // namespace edsc
50 } // namespace mlir
51 
52 #endif // MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
53