1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include <type_traits>
14
15 #include "mlir/Dialect/Affine/EDSC/Builders.h"
16 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
24 #include "mlir/Dialect/Vector/VectorOps.h"
25 #include "mlir/Dialect/Vector/VectorTransforms.h"
26 #include "mlir/Dialect/Vector/VectorUtils.h"
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Location.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/OperationSupport.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/TypeUtilities.h"
37 #include "mlir/IR/Types.h"
38 #include "mlir/Interfaces/VectorInterfaces.h"
39
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43
44 #define DEBUG_TYPE "vector-to-vector"
45
46 using namespace mlir;
47 using llvm::dbgs;
48
49 // Helper to find an index in an affine map.
getResultIndex(AffineMap map,int64_t index)50 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
51 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
52 int64_t idx = map.getDimPosition(i);
53 if (idx == index)
54 return i;
55 }
56 return None;
57 }
58
59 // Helper to construct iterator types with one index removed.
adjustIter(ArrayAttr iteratorTypes,int64_t index)60 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
61 int64_t index) {
62 SmallVector<Attribute, 4> results;
63 for (auto it : llvm::enumerate(iteratorTypes)) {
64 int64_t idx = it.index();
65 if (idx == index)
66 continue;
67 results.push_back(it.value());
68 }
69 return results;
70 }
71
72 // Helper to construct an affine map with one index removed.
adjustMap(AffineMap map,int64_t index,PatternRewriter & rewriter)73 static AffineMap adjustMap(AffineMap map, int64_t index,
74 PatternRewriter &rewriter) {
75 auto *ctx = rewriter.getContext();
76 SmallVector<AffineExpr, 4> results;
77 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
78 int64_t idx = map.getDimPosition(i);
79 if (idx == index)
80 continue;
81 // Re-insert remaining indices, but renamed when occurring
82 // after the removed index.
83 auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
84 results.push_back(targetExpr);
85 }
86 return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
87 }
88
89 // Helper to drop dimension from vector type.
adjustType(VectorType tp,int64_t index)90 static Type adjustType(VectorType tp, int64_t index) {
91 int64_t rank = tp.getRank();
92 Type eltType = tp.getElementType();
93 if (rank == 1) {
94 assert(index == 0 && "index for scalar result out of bounds");
95 return eltType;
96 }
97 SmallVector<int64_t, 4> adjustedShape;
98 for (int64_t i = 0; i < rank; ++i) {
99 // Omit dimension at the given index.
100 if (i == index)
101 continue;
102 // Otherwise, add dimension back.
103 adjustedShape.push_back(tp.getDimSize(i));
104 }
105 return VectorType::get(adjustedShape, eltType);
106 }
107
108 // Helper method to possibly drop a dimension in a load.
109 // TODO
reshapeLoad(Location loc,Value val,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)110 static Value reshapeLoad(Location loc, Value val, VectorType type,
111 int64_t index, int64_t pos,
112 PatternRewriter &rewriter) {
113 if (index == -1)
114 return val;
115 Type lowType = adjustType(type, 0);
116 // At extraction dimension?
117 if (index == 0) {
118 auto posAttr = rewriter.getI64ArrayAttr(pos);
119 return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
120 }
121 // Unroll leading dimensions.
122 VectorType vType = lowType.cast<VectorType>();
123 VectorType resType = adjustType(type, index).cast<VectorType>();
124 Value result =
125 rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
126 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
127 auto posAttr = rewriter.getI64ArrayAttr(d);
128 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
129 Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
130 result =
131 rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
132 }
133 return result;
134 }
135
136 // Helper method to possibly drop a dimension in a store.
137 // TODO
reshapeStore(Location loc,Value val,Value result,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)138 static Value reshapeStore(Location loc, Value val, Value result,
139 VectorType type, int64_t index, int64_t pos,
140 PatternRewriter &rewriter) {
141 // Unmodified?
142 if (index == -1)
143 return val;
144 // At insertion dimension?
145 if (index == 0) {
146 auto posAttr = rewriter.getI64ArrayAttr(pos);
147 return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
148 }
149 // Unroll leading dimensions.
150 Type lowType = adjustType(type, 0);
151 VectorType vType = lowType.cast<VectorType>();
152 Type insType = adjustType(vType, 0);
153 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
154 auto posAttr = rewriter.getI64ArrayAttr(d);
155 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
156 Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
157 Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
158 result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
159 }
160 return result;
161 }
162
163 // Clones `op` into a new operations that takes `operands` and returns
164 // `resultTypes`.
cloneOpWithOperandsAndTypes(OpBuilder & builder,Location loc,Operation * op,ArrayRef<Value> operands,ArrayRef<Type> resultTypes)165 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
166 Operation *op,
167 ArrayRef<Value> operands,
168 ArrayRef<Type> resultTypes) {
169 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
170 op->getAttrs());
171 return builder.createOperation(res);
172 }
173
174 // Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
175 // for each index 'i' in inputElements with a valid mapping in 'indexMap'.
getMappedElements(const DenseMap<int64_t,int64_t> & indexMap,ArrayRef<int64_t> inputElements,SmallVectorImpl<int64_t> & resultElements)176 static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
177 ArrayRef<int64_t> inputElements,
178 SmallVectorImpl<int64_t> &resultElements) {
179 assert(indexMap.size() == resultElements.size());
180 assert(inputElements.size() >= resultElements.size());
181 for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
182 auto it = indexMap.find(i);
183 if (it != indexMap.end())
184 resultElements[it->second] = inputElements[i];
185 }
186 }
187
188 // Returns a tuple type with vector element types for each resulting slice
189 // of 'vectorType' unrolled by 'sizes' and 'strides'.
190 // TODO: Move this to a utility function and share it with
191 // Extract/InsertSlicesOp verification.
generateExtractSlicesOpResultType(VectorType vectorType,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,OpBuilder & builder)192 static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
193 ArrayRef<int64_t> sizes,
194 ArrayRef<int64_t> strides,
195 OpBuilder &builder) {
196 assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
197 assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
198 assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());
199
200 // Compute shape ratio of 'shape' and 'sizes'.
201 auto shape = vectorType.getShape();
202 auto maybeDimSliceCounts = shapeRatio(shape, sizes);
203 assert(maybeDimSliceCounts.hasValue());
204 auto sliceDimCounts = *maybeDimSliceCounts;
205
206 // Compute strides w.r.t number of slices in each dimension.
207 auto sliceStrides = computeStrides(sliceDimCounts);
208 int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
209 SmallVector<Type, 4> vectorTypes(sliceCount);
210 for (unsigned i = 0; i < sliceCount; ++i) {
211 auto vectorOffsets = delinearize(sliceStrides, i);
212 auto elementOffsets =
213 computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
214 auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
215 // Create Vector type and add to 'vectorTypes[i]'.
216 vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
217 }
218 return TupleType::get(vectorTypes, builder.getContext());
219 }
220
221 // UnrolledVectorState aggregates per-operand/result vector state required for
222 // unrolling.
223 struct UnrolledVectorState {
224 SmallVector<int64_t, 4> unrolledShape;
225 SmallVector<int64_t, 4> unrollFactors;
226 SmallVector<int64_t, 8> basis;
227 int64_t numInstances;
228 Value slicesTuple;
229 };
230
231 // Populates 'state' with unrolled shape, unroll factors, basis and
232 // num unrolled instances for 'vectorType'.
initUnrolledVectorState(VectorType vectorType,Value initValue,const DenseMap<int64_t,int64_t> & indexMap,ArrayRef<int64_t> targetShape,UnrolledVectorState & state,OpBuilder & builder)233 static void initUnrolledVectorState(VectorType vectorType, Value initValue,
234 const DenseMap<int64_t, int64_t> &indexMap,
235 ArrayRef<int64_t> targetShape,
236 UnrolledVectorState &state,
237 OpBuilder &builder) {
238 // Compute unrolled shape of 'vectorType'.
239 state.unrolledShape.resize(vectorType.getRank());
240 getMappedElements(indexMap, targetShape, state.unrolledShape);
241 // Compute unroll factors for unrolled shape.
242 auto maybeUnrollFactors =
243 shapeRatio(vectorType.getShape(), state.unrolledShape);
244 assert(maybeUnrollFactors.hasValue());
245 state.unrollFactors = *maybeUnrollFactors;
246 // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
247 state.basis = computeStrides(state.unrollFactors);
248 state.numInstances = computeMaxLinearIndex(state.unrollFactors);
249 state.slicesTuple = nullptr;
250 if (initValue != nullptr) {
251 // Create ExtractSlicesOp.
252 SmallVector<int64_t, 4> sizes(state.unrolledShape);
253 SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
254 auto tupleType =
255 generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
256 state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
257 initValue.getLoc(), tupleType, initValue, sizes, strides);
258 }
259 }
260
261 // Computes and returns the linear index of the unrolled vector at
262 // 'vectorOffsets' within the vector represented by 'state'.
263 static int64_t
getUnrolledVectorLinearIndex(UnrolledVectorState & state,ArrayRef<int64_t> vectorOffsets,DenseMap<int64_t,int64_t> & indexMap)264 getUnrolledVectorLinearIndex(UnrolledVectorState &state,
265 ArrayRef<int64_t> vectorOffsets,
266 DenseMap<int64_t, int64_t> &indexMap) {
267 // Compute vector offsets.
268 SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
269 getMappedElements(indexMap, vectorOffsets, sliceOffsets);
270 // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
271 return linearize(sliceOffsets, state.basis);
272 }
273
274 // Returns an unrolled vector at 'vectorOffsets' within the vector
275 // represented by 'state'. The vector is created from a slice of 'initValue'
276 // if not present in 'cache'.
getOrCreateUnrolledVectorSlice(Location loc,UnrolledVectorState & state,ArrayRef<int64_t> vectorOffsets,ArrayRef<int64_t> offsets,DenseMap<int64_t,int64_t> & indexMap,Value initValue,SmallVectorImpl<Value> & cache,OpBuilder & builder)277 static Value getOrCreateUnrolledVectorSlice(
278 Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
279 ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
280 Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
281 // Compute slice offsets.
282 SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
283 getMappedElements(indexMap, offsets, sliceOffsets);
284 // TODO: Support non-1 strides.
285 SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
286 // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
287 int64_t sliceLinearIndex =
288 getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
289 assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
290 auto valueSlice = cache[sliceLinearIndex];
291 if (valueSlice == nullptr) {
292 // Return tuple element at 'sliceLinearIndex'.
293 auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
294 auto initValueType = initValue.getType().cast<VectorType>();
295 auto vectorType =
296 VectorType::get(state.unrolledShape, initValueType.getElementType());
297 // Initialize 'cache' with slice from 'initValue'.
298 valueSlice = builder.create<vector::TupleGetOp>(
299 loc, vectorType, state.slicesTuple, tupleIndex);
300 // Store value back to 'cache'.
301 cache[sliceLinearIndex] = valueSlice;
302 }
303 return valueSlice;
304 }
305
306 // VectorState aggregates per-operand/result vector state required for
307 // creating slices of vector operands, and clones of the operation being
308 // unrolled.
309 struct VectorState {
310 // The type of this vector.
311 VectorType type;
312 // Map from iteration space index to vector dimension index.
313 DenseMap<int64_t, int64_t> indexMap;
314 // Index of this value in operation's operand list (-1 if not an operand).
315 int64_t operandIndex = -1;
316 // Accumulator iterator flag.
317 bool isAcc = false;
318 };
319
320 //
321 // unrollSingleResultStructuredOp
322 //
323 // Returns a value representing the result of structured operation 'op'
324 // with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
325 // A list of VectorState objects must be specified in 'vectors', where
326 // each VectorState in the list represents a vector operand or vector result
327 // (if the operation does not have an accumulator operand).
328 // The VectorState at index 'resultIndex' in the list must be the state
// associated with the operation's single result (i.e. either its accumulator
330 // operand or vector result value).
331 //
332 // Example:
333 //
334 // // Before unrolling
335 //
336 // operand0 operand1 operand2
337 // \ | /
338 // -------------------- opA --------------------
339 //
340 // // After unrolling by 2
341 //
342 // operand0 operand1 operand2
343 // / \ / \ / \
344 // slice00 slice01 slice10 slice11 slice20 slice21
345 // \ | | | / |
346 // -------------------- opA0 -------------------- |
347 // | | | |
348 // \ | | /
349 // -------------------- opA1 -------------------
350 // | |
351 // \ /
352 // insertslice
353 // |
354
355 // TODO: Add the following canonicalization/simplification patterns:
356 // *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
357 // InsertStridedSlice operand to StridedSlice.
358 // *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
359 // if there are duplicate identical StridedSlice ops from SourceOp, and
360 // rewrites itself to use the first duplicate. This transformation should
// cause users of identical StridedSlice ops to reuse the same StridedSlice
362 // operation, and leave the duplicate StridedSlice ops with no users
363 // (removable with DCE).
364
365 // TODO: Generalize this to support structured ops beyond
366 // vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
unrollSingleResultStructuredOp(Operation * op,ArrayRef<int64_t> iterationBounds,std::vector<VectorState> & vectors,unsigned resultIndex,ArrayRef<int64_t> targetShape,OpBuilder & builder)367 static Value unrollSingleResultStructuredOp(Operation *op,
368 ArrayRef<int64_t> iterationBounds,
369 std::vector<VectorState> &vectors,
370 unsigned resultIndex,
371 ArrayRef<int64_t> targetShape,
372 OpBuilder &builder) {
373 auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
374 if (!shapedType || !shapedType.hasStaticShape())
375 assert(false && "Expected a statically shaped result type");
376
377 // Compute unroll factors for 'iterationBounds' based on 'targetShape'
378 auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
379 if (!maybeUnrollFactors.hasValue())
380 assert(false && "Failed to compute unroll factors for target shape");
381 auto unrollFactors = *maybeUnrollFactors;
382
383 // Compute unrolled vector state for each vector in 'vectors'.
384 unsigned numVectors = vectors.size();
385 SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
386 for (unsigned i = 0; i < numVectors; ++i) {
387 int64_t operandIndex = vectors[i].operandIndex;
388 auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
389 initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
390 targetShape, unrolledVectorState[i], builder);
391 }
392 // Compute number of total unrolled instances.
393 auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
394 auto sliceStrides = computeStrides(unrollFactors);
395
396 auto &resultValueState = unrolledVectorState[resultIndex];
397 auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
398 shapedType.getElementType());
399
400 // Initialize caches for intermediate vector results.
401 std::vector<SmallVector<Value, 4>> caches(numVectors);
402 for (unsigned i = 0; i < numVectors; ++i)
403 caches[i].resize(unrolledVectorState[i].numInstances);
404
405 // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
406 for (unsigned i = 0; i < numUnrolledInstances; ++i) {
407 auto vectorOffsets = delinearize(sliceStrides, i);
408 auto elementOffsets =
409 computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
410 // Get cached slice (or create slice) for each operand at 'offsets'.
411 SmallVector<Value, 3> operands;
412 operands.resize(op->getNumOperands());
413 for (unsigned i = 0; i < numVectors; ++i) {
414 int64_t operandIndex = vectors[i].operandIndex;
415 if (operandIndex < 0)
416 continue; // Output
417 auto operand = op->getOperand(operandIndex);
418 operands[operandIndex] = getOrCreateUnrolledVectorSlice(
419 op->getLoc(), unrolledVectorState[i], vectorOffsets, elementOffsets,
420 vectors[i].indexMap, operand, caches[i], builder);
421 }
422 // Create op on sliced vector arguments.
423 auto resultVector =
424 cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
425 unrolledResultType)
426 ->getResult(0);
427
428 // Compute linear result index.
429 int64_t linearIndex = getUnrolledVectorLinearIndex(
430 resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
431 // Update result cache at 'linearIndex'.
432 caches[resultIndex][linearIndex] = resultVector;
433 }
434
435 // Create TupleOp of unrolled result vectors.
436 SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
437 SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
438 for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
439 vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
440 vectorTupleValues[i] = caches[resultIndex][i];
441 }
442 TupleType tupleType = builder.getTupleType(vectorTupleTypes);
443 Value tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
444 vectorTupleValues);
445
446 // Create InsertSlicesOp(Tuple(result_vectors)).
447 auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
448 SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
449 SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
450
451 Value insertSlicesOp = builder.create<vector::InsertSlicesOp>(
452 op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
453 builder.getI64ArrayAttr(strides));
454 return insertSlicesOp;
455 }
456
getVectorContractionOpUnrollState(vector::ContractionOp contractionOp,ArrayRef<int64_t> targetShape,std::vector<VectorState> & vectors,unsigned & resultIndex)457 static void getVectorContractionOpUnrollState(
458 vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
459 std::vector<VectorState> &vectors, unsigned &resultIndex) {
460 // Get map from iteration space index to lhs/rhs/result shape index.
461 std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
462 contractionOp.getIterationIndexMap(iterationIndexMapList);
463 unsigned numIterators = iterationIndexMapList.size();
464 vectors.resize(numIterators);
465 unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
466 for (unsigned i = 0; i < numIterators; ++i) {
467 vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
468 vectors[i].indexMap = iterationIndexMapList[i];
469 vectors[i].operandIndex = i;
470 vectors[i].isAcc = i == accOperandIndex ? true : false;
471 }
472
473 if (llvm::size(contractionOp.masks()) == 2) {
474 // Add vectors for lhs/rhs vector mask arguments. Masks have the
475 // same vector shape lhs/rhs args, so copy their index maps.
476 vectors.push_back({contractionOp.getLHSVectorMaskType(),
477 vectors[0].indexMap, accOperandIndex + 1, false});
478 vectors.push_back({contractionOp.getRHSVectorMaskType(),
479 vectors[1].indexMap, accOperandIndex + 2, false});
480 }
481 // TODO: Use linalg style 'args_in'/'args_out' to partition
482 // 'vectors' instead of 'resultIndex'.
483 resultIndex = accOperandIndex;
484 }
485
getVectorElementwiseOpUnrollState(Operation * op,ArrayRef<int64_t> targetShape,std::vector<VectorState> & vectors,unsigned & resultIndex)486 static void getVectorElementwiseOpUnrollState(Operation *op,
487 ArrayRef<int64_t> targetShape,
488 std::vector<VectorState> &vectors,
489 unsigned &resultIndex) {
490 // Verify that operation and operands all have the same vector shape.
491 auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
492 assert(resultType && "Expected op with vector result type");
493 auto resultShape = resultType.getShape();
494 // Verify that all operands have the same vector type as result.
495 assert(llvm::all_of(op->getOperandTypes(),
496 [=](Type type) { return type == resultType; }));
497
498 // Create trivial elementwise identity index map based on 'resultShape'.
499 DenseMap<int64_t, int64_t> indexMap;
500 indexMap.reserve(resultShape.size());
501 for (unsigned i = 0; i < resultShape.size(); ++i)
502 indexMap[i] = i;
503
504 // Create VectorState each operand and single result.
505 unsigned numVectors = op->getNumOperands() + op->getNumResults();
506 vectors.resize(numVectors);
507 for (unsigned i = 0; i < op->getNumOperands(); ++i)
508 vectors[i] = {resultType, indexMap, i, false};
509 vectors[numVectors - 1] = {resultType, indexMap, -1, false};
510 resultIndex = numVectors - 1;
511 }
512
513 /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
514 /// calls 'fn' with linear index and indices for each slice.
generateTransferOpSlices(Type memrefElementType,VectorType vectorType,TupleType tupleType,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,ArrayRef<Value> indices,OpBuilder & builder,function_ref<void (unsigned,ArrayRef<Value>)> fn)515 static void generateTransferOpSlices(
516 Type memrefElementType, VectorType vectorType, TupleType tupleType,
517 ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
518 OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
519 // Compute strides w.r.t. to slice counts in each dimension.
520 auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
521 assert(maybeDimSliceCounts.hasValue());
522 auto sliceDimCounts = *maybeDimSliceCounts;
523 auto sliceStrides = computeStrides(sliceDimCounts);
524
525 int64_t numSlices = tupleType.size();
526 unsigned numSliceIndices = indices.size();
527 // Compute 'indexOffset' at which to update 'indices', which is equal
528 // to the memref rank (indices.size) minus the effective 'vectorRank'.
529 // The effective 'vectorRank', is equal to the rank of the vector type
530 // minus the rank of the memref vector element type (if it has one).
531 //
532 // For example:
533 //
534 // Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
535 // transfer_read/write ops which read/write vectors of type
536 // 'vector<2x1x2x4xf32>'. The memref rank is 3, and the effective
537 // vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
538 //
539 unsigned vectorRank = vectorType.getRank();
540 if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
541 assert(vectorRank >= memrefVectorElementType.getRank());
542 vectorRank -= memrefVectorElementType.getRank();
543 }
544 unsigned indexOffset = numSliceIndices - vectorRank;
545
546 auto *ctx = builder.getContext();
547 for (unsigned i = 0; i < numSlices; ++i) {
548 auto vectorOffsets = delinearize(sliceStrides, i);
549 auto elementOffsets =
550 computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
551 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
552 SmallVector<Value, 4> sliceIndices(numSliceIndices);
553 for (unsigned j = 0; j < numSliceIndices; ++j) {
554 if (j < indexOffset) {
555 sliceIndices[j] = indices[j];
556 } else {
557 auto expr = getAffineDimExpr(0, ctx) +
558 getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
559 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
560 sliceIndices[j] = builder.create<AffineApplyOp>(
561 indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
562 }
563 }
564 // Call 'fn' to generate slice 'i' at 'sliceIndices'.
565 fn(i, sliceIndices);
566 }
567 }
568
569 /// Returns true if 'map' is a suffix of an identity affine map, false
570 /// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
isIdentitySuffix(AffineMap map)571 static bool isIdentitySuffix(AffineMap map) {
572 if (map.getNumDims() < map.getNumResults())
573 return false;
574 ArrayRef<AffineExpr> results = map.getResults();
575 Optional<int> lastPos;
576 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
577 auto expr = results[i].dyn_cast<AffineDimExpr>();
578 if (!expr)
579 return false;
580 int currPos = static_cast<int>(expr.getPosition());
581 if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
582 return false;
583 lastPos = currPos;
584 }
585 return true;
586 }
587
588 /// Unroll transfer_read ops to the given shape and create an aggregate with all
589 /// the chunks.
unrollTransferReadOp(vector::TransferReadOp readOp,ArrayRef<int64_t> targetShape,OpBuilder & builder)590 static Value unrollTransferReadOp(vector::TransferReadOp readOp,
591 ArrayRef<int64_t> targetShape,
592 OpBuilder &builder) {
593 if (!isIdentitySuffix(readOp.permutation_map()))
594 return nullptr;
595 auto sourceVectorType = readOp.getVectorType();
596 SmallVector<int64_t, 4> strides(targetShape.size(), 1);
597
598 Location loc = readOp.getLoc();
599 auto memrefElementType =
600 readOp.memref().getType().cast<MemRefType>().getElementType();
601 auto tupleType = generateExtractSlicesOpResultType(
602 sourceVectorType, targetShape, strides, builder);
603 int64_t numSlices = tupleType.size();
604
605 SmallVector<Value, 4> vectorTupleValues(numSlices);
606 SmallVector<Value, 4> indices(readOp.indices().begin(),
607 readOp.indices().end());
608 auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
609 // Get VectorType for slice 'i'.
610 auto sliceVectorType = tupleType.getType(index);
611 // Create split TransferReadOp for 'sliceUser'.
612 // `masked` attribute propagates conservatively: if the coarse op didn't
613 // need masking, the fine op doesn't either.
614 vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
615 loc, sliceVectorType, readOp.memref(), sliceIndices,
616 readOp.permutation_map(), readOp.padding(),
617 readOp.masked() ? *readOp.masked() : ArrayAttr());
618 };
619 generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
620 targetShape, strides, indices, builder, createSlice);
621
622 // Create tuple of splice transfer read operations.
623 Value tupleOp =
624 builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
625 // Replace 'readOp' with result 'insertSlicesResult'.
626 Value newVec = builder.create<vector::InsertSlicesOp>(
627 loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
628 builder.getI64ArrayAttr(strides));
629 return newVec;
630 }
631
632 // Entry point for unrolling declarative pattern rewrite for transfer_write op.
633 LogicalResult
unrollTransferWriteOp(OpBuilder & builder,Operation * op,ArrayRef<int64_t> targetShape)634 mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
635 ArrayRef<int64_t> targetShape) {
636 auto writeOp = cast<vector::TransferWriteOp>(op);
637 if (!isIdentitySuffix(writeOp.permutation_map()))
638 return failure();
639 VectorType sourceVectorType = writeOp.getVectorType();
640 SmallVector<int64_t, 4> strides(targetShape.size(), 1);
641 TupleType tupleType = generateExtractSlicesOpResultType(
642 sourceVectorType, targetShape, strides, builder);
643 Location loc = writeOp.getLoc();
644 Value tuple = builder.create<vector::ExtractSlicesOp>(
645 loc, tupleType, writeOp.vector(), targetShape, strides);
646 auto memrefElementType =
647 writeOp.memref().getType().cast<MemRefType>().getElementType();
648 SmallVector<Value, 4> indices(writeOp.indices().begin(),
649 writeOp.indices().end());
650 auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
651 auto element = builder.create<vector::TupleGetOp>(
652 loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
653 builder.create<vector::TransferWriteOp>(
654 loc, element.getResult(), writeOp.memref(), sliceIndices,
655 writeOp.permutation_map(),
656 writeOp.masked() ? *writeOp.masked() : ArrayAttr());
657 };
658 generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
659 targetShape, strides, indices, builder, createSlice);
660 return success();
661 }
662
663 // Entry point for unrolling declarative pattern rewrites.
664 SmallVector<Value, 1>
unrollSingleResultVectorOp(OpBuilder & builder,Operation * op,ArrayRef<int64_t> targetShape)665 mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
666 ArrayRef<int64_t> targetShape) {
667 assert(op->getNumResults() == 1 && "Expected single result operation");
668
669 // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
670 SmallVector<int64_t, 6> iterationBounds;
671 auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
672 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
673 assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
674
675 std::vector<VectorState> vectors;
676 unsigned resultIndex;
677
678 if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
679 return SmallVector<Value, 1>{
680 unrollTransferReadOp(readOp, targetShape, builder)};
681
682 if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
683 // Populate state for vector ContractionOp.
684 getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
685 resultIndex);
686 } else {
687 // Populate state for vector elementwise op.
688 getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
689 }
690
691 // Unroll 'op' with 'iterationBounds' to 'targetShape'.
692 return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
693 op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
694 }
695
696 namespace {
697
698 // Splits vector TransferReadOp into smaller TransferReadOps based on slicing
699 // scheme of its unique ExtractSlicesOp user.
700 struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
701 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
702
matchAndRewrite__anon642465890511::SplitTransferReadOp703 LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
704 PatternRewriter &rewriter) const override {
705 // TODO: Support splitting TransferReadOp with non-identity
706 // permutation maps. Repurpose code from MaterializeVectors transformation.
707 if (!isIdentitySuffix(xferReadOp.permutation_map()))
708 return failure();
709 // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
710 Value xferReadResult = xferReadOp.getResult();
711 auto extractSlicesOp =
712 dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
713 if (!xferReadResult.hasOneUse() || !extractSlicesOp)
714 return failure();
715
716 // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
717 SmallVector<int64_t, 4> sizes;
718 extractSlicesOp.getSizes(sizes);
719 SmallVector<int64_t, 4> strides;
720 extractSlicesOp.getStrides(strides);
721 assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
722
723 Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
724 if (!newVec)
725 return failure();
726 rewriter.replaceOp(xferReadOp, newVec);
727 return success();
728 }
729 };
730
731 // Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
732 struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
733 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
734
matchAndRewrite__anon642465890511::SplitTransferWriteOp735 LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
736 PatternRewriter &rewriter) const override {
737 // TODO: Support splitting TransferWriteOp with non-identity
738 // permutation maps. Repurpose code from MaterializeVectors transformation.
739 if (!isIdentitySuffix(xferWriteOp.permutation_map()))
740 return failure();
741 // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
742 auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
743 auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
744 if (!insertSlicesOp)
745 return failure();
746
747 // Get TupleOp operand of 'insertSlicesOp'.
748 auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
749 insertSlicesOp.vectors().getDefiningOp());
750 if (!tupleOp)
751 return failure();
752
753 // Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
754 auto sourceTupleType = insertSlicesOp.getSourceTupleType();
755 auto resultVectorType = insertSlicesOp.getResultVectorType();
756 SmallVector<int64_t, 4> sizes;
757 insertSlicesOp.getSizes(sizes);
758 SmallVector<int64_t, 4> strides;
759 insertSlicesOp.getStrides(strides);
760
761 Location loc = xferWriteOp.getLoc();
762 auto memrefElementType =
763 xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
764 SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
765 xferWriteOp.indices().end());
766 auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
767 // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
768 // `masked` attribute propagates conservatively: if the coarse op didn't
769 // need masking, the fine op doesn't either.
770 rewriter.create<vector::TransferWriteOp>(
771 loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
772 xferWriteOp.permutation_map(),
773 xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
774 };
775 generateTransferOpSlices(memrefElementType, resultVectorType,
776 sourceTupleType, sizes, strides, indices, rewriter,
777 createSlice);
778
779 // Erase old 'xferWriteOp'.
780 rewriter.eraseOp(xferWriteOp);
781 return success();
782 }
783 };
784
785 /// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
786 /// on vector types.
787 struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
788 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
789
matchAndRewrite__anon642465890511::ShapeCastOpDecomposer790 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
791 PatternRewriter &rewriter) const override {
792 // Check if 'shapeCastOp' has tuple source/result type.
793 auto sourceTupleType =
794 shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
795 auto resultTupleType =
796 shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
797 if (!sourceTupleType || !resultTupleType)
798 return failure();
799 assert(sourceTupleType.size() == resultTupleType.size());
800
801 // Create single-vector ShapeCastOp for each source tuple element.
802 Location loc = shapeCastOp.getLoc();
803 SmallVector<Value, 8> resultElements;
804 resultElements.reserve(resultTupleType.size());
805 for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
806 auto sourceElement = rewriter.create<vector::TupleGetOp>(
807 loc, sourceTupleType.getType(i), shapeCastOp.source(),
808 rewriter.getI64IntegerAttr(i));
809 resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
810 loc, resultTupleType.getType(i), sourceElement));
811 }
812
813 // Replace 'shapeCastOp' with tuple of 'resultElements'.
814 rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
815 resultElements);
816 return success();
817 }
818 };
819
820 /// Returns the producer Value of the same type as 'consumerValue', by tracking
821 /// the tuple index and offsets of the consumer vector value through the
822 /// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
823 /// and ShapeCastOp) from consumer to producer. Each operation in the chain is
824 /// structured, and so the tuple index and offsets can be mapped from result to
825 /// input, while visiting each operation in the chain.
826 /// Returns nullptr on failure.
getProducerValue(Value consumerValue)827 static Value getProducerValue(Value consumerValue) {
828 auto consumerVectorType = consumerValue.getType().cast<VectorType>();
829 // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
830 int64_t tupleIndex = -1;
831 SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
832 auto *op = consumerValue.getDefiningOp();
833 while (op != nullptr) {
834 if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
835 assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
836
837 // Update 'tupleIndex' and next defining 'op' to visit.
838 tupleIndex = tupleGetOp.getIndex();
839 op = tupleGetOp.vectors().getDefiningOp();
840 } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
841 assert(tupleIndex >= 0);
842
843 // Compute slice strides for 'extractSlicesOp'.
844 SmallVector<int64_t, 4> sizes;
845 extractSlicesOp.getSizes(sizes);
846 auto sliceStrides = computeStrides(
847 extractSlicesOp.getSourceVectorType().getShape(), sizes);
848
849 // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
850 // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
851 auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
852 auto elementOffsets =
853 computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
854
855 // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
856 // to the 'extractSlicesOp' input vector type.
857 assert(offsets.size() == elementOffsets.size());
858 for (unsigned i = 0, e = offsets.size(); i < e; ++i)
859 offsets[i] += elementOffsets[i];
860
861 // Clear 'tupleIndex' and update next defining 'op' to visit.
862 tupleIndex = -1;
863 op = extractSlicesOp.vector().getDefiningOp();
864 } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
865 assert(tupleIndex == -1);
866
867 // Compute slice strides for 'insertSlicesOp'.
868 SmallVector<int64_t, 4> sizes;
869 insertSlicesOp.getSizes(sizes);
870 auto sliceStrides = computeStrides(
871 insertSlicesOp.getResultVectorType().getShape(), sizes);
872
873 // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
874 // of 'insertSlicesOp' result vector type at 'offsets'.
875 SmallVector<int64_t, 4> vectorOffsets(offsets.size());
876 assert(offsets.size() == sizes.size());
877 for (unsigned i = 0, e = offsets.size(); i < e; ++i)
878 vectorOffsets[i] = offsets[i] / sizes[i];
879
880 // Compute the source tuple element index.
881 tupleIndex = linearize(vectorOffsets, sliceStrides);
882
883 // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
884 // relative to input tuple element vector type at 'tupleIndex'.
885 auto elementOffsets =
886 computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
887 assert(offsets.size() == elementOffsets.size());
888 for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
889 offsets[i] -= elementOffsets[i];
890 assert(offsets[i] >= 0);
891 }
892
893 // Update next defining 'op' to visit.
894 op = insertSlicesOp.vectors().getDefiningOp();
895 } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
896 assert(tupleIndex >= 0);
897
898 // Return tuple element 'value' at 'tupleIndex' if it matches type.
899 auto value = tupleOp.getOperand(tupleIndex);
900 if (value.getType() == consumerVectorType)
901 return value;
902
903 // Update 'tupleIndex' and next defining 'op' to visit.
904 tupleIndex = -1;
905 op = value.getDefiningOp();
906 } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
907 if (shapeCastOp.source().getType().isa<TupleType>())
908 return nullptr;
909 assert(tupleIndex == -1);
910 auto sourceVectorType = shapeCastOp.getSourceVectorType();
911 auto sourceVectorShape = sourceVectorType.getShape();
912 unsigned sourceVectorRank = sourceVectorType.getRank();
913 auto resultVectorType = shapeCastOp.getResultVectorType();
914 auto resultVectorShape = resultVectorType.getShape();
915 unsigned resultVectorRank = resultVectorType.getRank();
916
917 int i = sourceVectorRank - 1;
918 int j = resultVectorRank - 1;
919
920 // Check that source/result vector shape prefixes match while updating
921 // 'newOffsets'.
922 SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
923 for (auto it : llvm::zip(llvm::reverse(sourceVectorShape),
924 llvm::reverse(resultVectorShape))) {
925 if (std::get<0>(it) != std::get<1>(it))
926 return nullptr;
927 newOffsets[i--] = offsets[j--];
928 }
929
930 // Check that remaining prefix of source/result vector shapes are all 1s.
931 // Currently we only support producer/consumer tracking through trivial
932 // shape cast ops. Examples:
933 // %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
934 // %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
935 assert(i == -1 || j == -1);
936 if (i >= 0 &&
937 !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i,
938 [](int64_t v) { return v == 1; }))
939 return nullptr;
940 if (j >= 0 &&
941 !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j,
942 [](int64_t v) { return v == 1; }))
943 return nullptr;
944
945 offsets.swap(newOffsets);
946 op = shapeCastOp.source().getDefiningOp();
947 } else {
948 // Check if 'op' produces a Value with the same type as 'consumerValue'.
949 if (op->getNumResults() == 1 &&
950 op->getResult(0).getType() == consumerVectorType)
951 return op->getResult(0);
952 return nullptr;
953 }
954 }
955 return nullptr;
956 }
957
958 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
959 //
960 // Example:
961 //
962 // The following MLIR with cancelling ShapeCastOps:
963 //
964 // %0 = source : vector<5x4x2xf32>
965 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
966 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
967 // %3 = user %2 : vector<5x4x2xf32>
968 //
969 // Should canonicalize to the following:
970 //
971 // %0 = source : vector<5x4x2xf32>
972 // %1 = user %0 : vector<5x4x2xf32>
973 //
974 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
975 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
976
matchAndRewrite__anon642465890511::ShapeCastOpFolder977 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
978 PatternRewriter &rewriter) const override {
979 // Check if we can replace 'shapeCastOp' result with its producer.
980 if (auto producer = getProducerValue(shapeCastOp.getResult())) {
981 rewriter.replaceOp(shapeCastOp, producer);
982 return success();
983 }
984
985 // Check if 'shapeCastOp' has vector source/result type.
986 auto sourceVectorType =
987 shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
988 auto resultVectorType =
989 shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
990 if (!sourceVectorType || !resultVectorType)
991 return failure();
992
993 // Check if shape cast op source operand is also a shape cast op.
994 auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
995 shapeCastOp.source().getDefiningOp());
996 if (!sourceShapeCastOp)
997 return failure();
998 auto operandSourceVectorType =
999 sourceShapeCastOp.source().getType().cast<VectorType>();
1000 auto operandResultVectorType =
1001 sourceShapeCastOp.result().getType().cast<VectorType>();
1002
1003 // Check if shape cast operations invert each other.
1004 if (operandSourceVectorType != resultVectorType ||
1005 operandResultVectorType != sourceVectorType)
1006 return failure();
1007
1008 rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
1009 return success();
1010 }
1011 };
1012
1013 // Patter rewrite which forward tuple elements to their users.
1014 // User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
1015 // -> User(Producer)
1016 struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
1017 using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;
1018
matchAndRewrite__anon642465890511::TupleGetFolderOp1019 LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
1020 PatternRewriter &rewriter) const override {
1021 if (auto producer = getProducerValue(tupleGetOp.getResult())) {
1022 rewriter.replaceOp(tupleGetOp, producer);
1023 return success();
1024 }
1025 return failure();
1026 }
1027 };
1028
1029 /// Progressive lowering of ExtractSlicesOp to tuple of ExtractStridedSliceOp.
1030 /// One:
1031 /// %x = vector.extract_slices %0
1032 /// is replaced by:
1033 /// %a = vector.strided_slice %0
1034 /// %b = vector.strided_slice %0
1035 /// ..
1036 /// %x = vector.tuple %a, %b, ..
1037 class ExtractSlicesOpLowering
1038 : public OpRewritePattern<vector::ExtractSlicesOp> {
1039 public:
1040 using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
1041
matchAndRewrite(vector::ExtractSlicesOp op,PatternRewriter & rewriter) const1042 LogicalResult matchAndRewrite(vector::ExtractSlicesOp op,
1043 PatternRewriter &rewriter) const override {
1044 auto loc = op.getLoc();
1045
1046 VectorType vectorType = op.getSourceVectorType();
1047 auto shape = vectorType.getShape();
1048
1049 SmallVector<int64_t, 4> sizes;
1050 op.getSizes(sizes);
1051 SmallVector<int64_t, 4> strides;
1052 op.getStrides(strides); // all-ones at the moment
1053
1054 // For each element in the tuple, generate the proper strided slice.
1055 TupleType tupleType = op.getResultTupleType();
1056 int64_t tupleSize = tupleType.size();
1057 SmallVector<Value, 4> tupleValues(tupleSize);
1058 auto sliceStrides = computeStrides(shape, sizes);
1059 for (int64_t i = 0; i < tupleSize; ++i) {
1060 auto vectorOffsets = delinearize(sliceStrides, i);
1061 auto elementOffsets =
1062 computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
1063 auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
1064 // Insert in tuple.
1065 tupleValues[i] = rewriter.create<vector::ExtractStridedSliceOp>(
1066 loc, op.vector(), elementOffsets, sliceSizes, strides);
1067 }
1068
1069 rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
1070 return success();
1071 }
1072 };
1073
1074 /// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
1075 /// One:
1076 /// %x = vector.insert_slices %0
1077 /// is replaced by:
1078 /// %r0 = zero-result
1079 /// %t1 = vector.tuple_get %0, 0
1080 /// %r1 = vector.insert_strided_slice %r0, %t1
1081 /// %t2 = vector.tuple_get %0, 1
1082 /// %r2 = vector.insert_strided_slice %r1, %t2
1083 /// ..
1084 /// %x = ..
1085 class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
1086 public:
1087 using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
1088
matchAndRewrite(vector::InsertSlicesOp op,PatternRewriter & rewriter) const1089 LogicalResult matchAndRewrite(vector::InsertSlicesOp op,
1090 PatternRewriter &rewriter) const override {
1091 auto loc = op.getLoc();
1092
1093 VectorType vectorType = op.getResultVectorType();
1094 auto shape = vectorType.getShape();
1095
1096 SmallVector<int64_t, 4> sizes;
1097 op.getSizes(sizes);
1098 SmallVector<int64_t, 4> strides;
1099 op.getStrides(strides); // all-ones at the moment
1100
1101 // Prepare result.
1102 Value result = rewriter.create<ConstantOp>(
1103 loc, vectorType, rewriter.getZeroAttr(vectorType));
1104
1105 // For each element in the tuple, extract the proper strided slice.
1106 TupleType tupleType = op.getSourceTupleType();
1107 int64_t tupleSize = tupleType.size();
1108 auto sliceStrides = computeStrides(shape, sizes);
1109 for (int64_t i = 0; i < tupleSize; ++i) {
1110 auto vectorOffsets = delinearize(sliceStrides, i);
1111 auto elementOffsets =
1112 computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
1113 // Extract from tuple into the result.
1114 auto index = rewriter.getI64IntegerAttr(i);
1115 auto tupleGet = rewriter.create<vector::TupleGetOp>(
1116 loc, tupleType.getType(i), op.getOperand(), index);
1117 result = rewriter.create<vector::InsertStridedSliceOp>(
1118 loc, tupleGet, result, elementOffsets, strides);
1119 }
1120
1121 rewriter.replaceOp(op, result);
1122 return success();
1123 }
1124 };
1125
1126 /// Progressive lowering of BroadcastOp.
1127 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
1128 public:
1129 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
1130
matchAndRewrite(vector::BroadcastOp op,PatternRewriter & rewriter) const1131 LogicalResult matchAndRewrite(vector::BroadcastOp op,
1132 PatternRewriter &rewriter) const override {
1133 auto loc = op.getLoc();
1134 VectorType dstType = op.getVectorType();
1135 VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
1136 Type eltType = dstType.getElementType();
1137
1138 // Determine rank of source and destination.
1139 int64_t srcRank = srcType ? srcType.getRank() : 0;
1140 int64_t dstRank = dstType.getRank();
1141
1142 // Duplicate this rank.
1143 // For example:
1144 // %x = broadcast %y : k-D to n-D, k < n
1145 // becomes:
1146 // %b = broadcast %y : k-D to (n-1)-D
1147 // %x = [%b,%b,%b,%b] : n-D
1148 // becomes:
1149 // %b = [%y,%y] : (n-1)-D
1150 // %x = [%b,%b,%b,%b] : n-D
1151 if (srcRank < dstRank) {
1152 // Scalar to any vector can use splat.
1153 if (srcRank == 0) {
1154 rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
1155 return success();
1156 }
1157 // Duplication.
1158 VectorType resType =
1159 VectorType::get(dstType.getShape().drop_front(), eltType);
1160 Value bcst =
1161 rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
1162 Value result = rewriter.create<ConstantOp>(loc, dstType,
1163 rewriter.getZeroAttr(dstType));
1164 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1165 result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1166 rewriter.replaceOp(op, result);
1167 return success();
1168 }
1169
1170 // Find non-matching dimension, if any.
1171 assert(srcRank == dstRank);
1172 int64_t m = -1;
1173 for (int64_t r = 0; r < dstRank; r++)
1174 if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
1175 m = r;
1176 break;
1177 }
1178
1179 // All trailing dimensions are the same. Simply pass through.
1180 if (m == -1) {
1181 rewriter.replaceOp(op, op.source());
1182 return success();
1183 }
1184
1185 // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
1186 if (srcRank == 1) {
1187 assert(m == 0);
1188 Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1189 rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
1190 return success();
1191 }
1192
1193 // Any non-matching dimension forces a stretch along this rank.
1194 // For example:
1195 // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
1196 // becomes:
1197 // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
1198 // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
1199 // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
1200 // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
1201 // %x = [%a,%b,%c,%d]
1202 // becomes:
1203 // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
1204 // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
1205 // %a = [%u, %v]
1206 // ..
1207 // %x = [%a,%b,%c,%d]
1208 VectorType resType =
1209 VectorType::get(dstType.getShape().drop_front(), eltType);
1210 Value result = rewriter.create<ConstantOp>(loc, dstType,
1211 rewriter.getZeroAttr(dstType));
1212 if (m == 0) {
1213 // Stetch at start.
1214 Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1215 Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1216 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1217 result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1218 } else {
1219 // Stetch not at start.
1220 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
1221 Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
1222 Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1223 result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1224 }
1225 }
1226 rewriter.replaceOp(op, result);
1227 return success();
1228 }
1229 };
1230
1231 /// Progressive lowering of TransposeOp.
1232 /// One:
1233 /// %x = vector.transpose %y, [1, 0]
1234 /// is replaced by:
1235 /// %z = constant dense<0.000000e+00>
1236 /// %0 = vector.extract %y[0, 0]
1237 /// %1 = vector.insert %0, %z [0, 0]
1238 /// ..
1239 /// %x = vector.insert .., .. [.., ..]
1240 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
1241 public:
1242 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1243
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,MLIRContext * context)1244 TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
1245 MLIRContext *context)
1246 : OpRewritePattern<vector::TransposeOp>(context),
1247 vectorTransformsOptions(vectorTransformsOptions) {}
1248
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const1249 LogicalResult matchAndRewrite(vector::TransposeOp op,
1250 PatternRewriter &rewriter) const override {
1251 auto loc = op.getLoc();
1252
1253 VectorType resType = op.getResultType();
1254
1255 // Set up convenience transposition table.
1256 SmallVector<int64_t, 4> transp;
1257 for (auto attr : op.transp())
1258 transp.push_back(attr.cast<IntegerAttr>().getInt());
1259
1260 // Handle a true 2-D matrix transpose differently when requested.
1261 if (vectorTransformsOptions.vectorTransposeLowering ==
1262 vector::VectorTransposeLowering::Flat &&
1263 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
1264 Type flattenedType =
1265 VectorType::get(resType.getNumElements(), resType.getElementType());
1266 auto matrix =
1267 rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
1268 auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
1269 auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
1270 Value trans = rewriter.create<vector::FlatTransposeOp>(
1271 loc, flattenedType, matrix, rows, columns);
1272 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
1273 return success();
1274 }
1275
1276 // Generate fully unrolled extract/insert ops.
1277 Value result = rewriter.create<ConstantOp>(loc, resType,
1278 rewriter.getZeroAttr(resType));
1279 SmallVector<int64_t, 4> lhs(transp.size(), 0);
1280 SmallVector<int64_t, 4> rhs(transp.size(), 0);
1281 rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
1282 op.vector(), result, rewriter));
1283 return success();
1284 }
1285
1286 private:
1287 // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
1288 // operation when al ranks are exhausted.
expandIndices(Location loc,VectorType resType,int64_t pos,SmallVector<int64_t,4> & transp,SmallVector<int64_t,4> & lhs,SmallVector<int64_t,4> & rhs,Value input,Value result,PatternRewriter & rewriter) const1289 Value expandIndices(Location loc, VectorType resType, int64_t pos,
1290 SmallVector<int64_t, 4> &transp,
1291 SmallVector<int64_t, 4> &lhs,
1292 SmallVector<int64_t, 4> &rhs, Value input, Value result,
1293 PatternRewriter &rewriter) const {
1294 if (pos >= resType.getRank()) {
1295 auto ridx = rewriter.getI64ArrayAttr(rhs);
1296 auto lidx = rewriter.getI64ArrayAttr(lhs);
1297 Type eltType = resType.getElementType();
1298 Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
1299 return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
1300 }
1301 for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
1302 lhs[pos] = d;
1303 rhs[transp[pos]] = d;
1304 result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
1305 result, rewriter);
1306 }
1307 return result;
1308 }
1309
1310 /// Options to control the vector patterns.
1311 vector::VectorTransformsOptions vectorTransformsOptions;
1312 };
1313
1314 /// Progressive lowering of OuterProductOp.
1315 /// One:
1316 /// %x = vector.outerproduct %lhs, %rhs, %acc
1317 /// is replaced by:
1318 /// %z = zero-result
1319 /// %0 = vector.extract %lhs[0]
1320 /// %1 = vector.broadcast %0
1321 /// %2 = vector.extract %acc[0]
1322 /// %3 = vector.fma %1, %rhs, %2
1323 /// %4 = vector.insert %3, %z[0]
1324 /// ..
1325 /// %x = vector.insert %.., %..[N-1]
1326 ///
1327 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1328 public:
1329 using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
1330
matchAndRewrite(vector::OuterProductOp op,PatternRewriter & rewriter) const1331 LogicalResult matchAndRewrite(vector::OuterProductOp op,
1332 PatternRewriter &rewriter) const override {
1333 auto loc = op.getLoc();
1334
1335 VectorType lhsType = op.getOperandVectorTypeLHS();
1336 VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
1337 VectorType resType = op.getVectorType();
1338 Type eltType = resType.getElementType();
1339 bool isInt = eltType.isa<IntegerType>();
1340 Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
1341
1342 if (!rhsType) {
1343 // Special case: AXPY operation.
1344 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
1345 rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
1346 return success();
1347 }
1348
1349 Value result = rewriter.create<ConstantOp>(loc, resType,
1350 rewriter.getZeroAttr(resType));
1351 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1352 auto pos = rewriter.getI64ArrayAttr(d);
1353 Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
1354 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1355 Value r = nullptr;
1356 if (acc)
1357 r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
1358 Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
1359 result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
1360 }
1361 rewriter.replaceOp(op, result);
1362 return success();
1363 }
1364
1365 private:
genMult(Location loc,Value x,Value y,Value acc,bool isInt,PatternRewriter & rewriter)1366 static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
1367 PatternRewriter &rewriter) {
1368 if (acc) {
1369 if (isInt)
1370 return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
1371 acc);
1372 return rewriter.create<vector::FMAOp>(loc, x, y, acc);
1373 }
1374 if (isInt)
1375 return rewriter.create<MulIOp>(loc, x, y);
1376 return rewriter.create<MulFOp>(loc, x, y);
1377 }
1378 };
1379
1380 /// Progressive lowering of ConstantMaskOp.
1381 /// One:
1382 /// %x = vector.constant_mask [a,b]
1383 /// is replaced by:
1384 /// %z = zero-result
1385 /// %l = vector.constant_mask [b]
1386 /// %4 = vector.insert %l, %z[0]
1387 /// ..
1388 /// %x = vector.insert %l, %..[a-1]
1389 /// until a one-dimensional vector is reached. All these operations
1390 /// will be folded at LLVM IR level.
1391 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
1392 public:
1393 using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
1394
matchAndRewrite(vector::ConstantMaskOp op,PatternRewriter & rewriter) const1395 LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
1396 PatternRewriter &rewriter) const override {
1397 auto loc = op.getLoc();
1398 auto dstType = op.getResult().getType().cast<VectorType>();
1399 auto eltType = dstType.getElementType();
1400 auto dimSizes = op.mask_dim_sizes();
1401 int64_t rank = dimSizes.size();
1402 int64_t trueDim = std::min(dstType.getDimSize(0),
1403 dimSizes[0].cast<IntegerAttr>().getInt());
1404
1405 if (rank == 1) {
1406 // Express constant 1-D case in explicit vector form:
1407 // [T,..,T,F,..,F].
1408 SmallVector<bool, 4> values(dstType.getDimSize(0));
1409 for (int64_t d = 0; d < trueDim; d++)
1410 values[d] = true;
1411 rewriter.replaceOpWithNewOp<ConstantOp>(
1412 op, dstType, rewriter.getBoolVectorAttr(values));
1413 return success();
1414 }
1415
1416 VectorType lowType =
1417 VectorType::get(dstType.getShape().drop_front(), eltType);
1418 SmallVector<int64_t, 4> newDimSizes;
1419 for (int64_t r = 1; r < rank; r++)
1420 newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
1421 Value trueVal = rewriter.create<vector::ConstantMaskOp>(
1422 loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
1423 Value result = rewriter.create<ConstantOp>(loc, dstType,
1424 rewriter.getZeroAttr(dstType));
1425 for (int64_t d = 0; d < trueDim; d++) {
1426 auto pos = rewriter.getI64ArrayAttr(d);
1427 result =
1428 rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
1429 }
1430 rewriter.replaceOp(op, result);
1431 return success();
1432 }
1433 };
1434
1435 /// Progressive lowering of CreateMaskOp.
1436 /// One:
1437 /// %x = vector.create_mask %a, ... : vector<dx...>
1438 /// is replaced by:
1439 /// %l = vector.create_mask ... : vector<...> ; one lower rank
1440 /// %0 = cmpi "slt", %ci, %a |
1441 /// %1 = select %0, %l, %zeroes |
1442 /// %r = vector.insert %1, %pr [i] | d-times
1443 /// %x = ....
1444 /// until a one-dimensional vector is reached.
1445 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
1446 public:
1447 using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
1448
matchAndRewrite(vector::CreateMaskOp op,PatternRewriter & rewriter) const1449 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1450 PatternRewriter &rewriter) const override {
1451 auto loc = op.getLoc();
1452 auto dstType = op.getResult().getType().cast<VectorType>();
1453 auto eltType = dstType.getElementType();
1454 int64_t dim = dstType.getDimSize(0);
1455 int64_t rank = dstType.getRank();
1456 Value idx = op.getOperand(0);
1457
1458 if (rank == 1)
1459 return failure(); // leave for lowering
1460
1461 VectorType lowType =
1462 VectorType::get(dstType.getShape().drop_front(), eltType);
1463 Value trueVal = rewriter.create<vector::CreateMaskOp>(
1464 loc, lowType, op.getOperands().drop_front());
1465 Value falseVal = rewriter.create<ConstantOp>(loc, lowType,
1466 rewriter.getZeroAttr(lowType));
1467 Value result = rewriter.create<ConstantOp>(loc, dstType,
1468 rewriter.getZeroAttr(dstType));
1469 for (int64_t d = 0; d < dim; d++) {
1470 Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
1471 Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
1472 Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
1473 auto pos = rewriter.getI64ArrayAttr(d);
1474 result =
1475 rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
1476 }
1477 rewriter.replaceOp(op, result);
1478 return success();
1479 }
1480 };
1481
1482 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
1483 /// vectors progressively on the way to target llvm.matrix intrinsics.
1484 /// This iterates over the most major dimension of the 2-D vector and performs
1485 /// rewrites into:
1486 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
1487 class ShapeCastOp2DDownCastRewritePattern
1488 : public OpRewritePattern<vector::ShapeCastOp> {
1489 public:
1490 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1491
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const1492 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1493 PatternRewriter &rewriter) const override {
1494 auto sourceVectorType = op.getSourceVectorType();
1495 auto resultVectorType = op.getResultVectorType();
1496 if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
1497 return failure();
1498
1499 auto loc = op.getLoc();
1500 Value desc = rewriter.create<ConstantOp>(
1501 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1502 unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
1503 for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
1504 Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
1505 desc = rewriter.create<vector::InsertStridedSliceOp>(
1506 loc, vec, desc,
1507 /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
1508 }
1509 rewriter.replaceOp(op, desc);
1510 return success();
1511 }
1512 };
1513
1514 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
1515 /// vectors progressively on the way from targeting llvm.matrix intrinsics.
1516 /// This iterates over the most major dimension of the 2-D vector and performs
1517 /// rewrites into:
1518 /// vector.strided_slice from 1-D + vector.insert into 2-D
1519 class ShapeCastOp2DUpCastRewritePattern
1520 : public OpRewritePattern<vector::ShapeCastOp> {
1521 public:
1522 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1523
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const1524 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1525 PatternRewriter &rewriter) const override {
1526 auto sourceVectorType = op.getSourceVectorType();
1527 auto resultVectorType = op.getResultVectorType();
1528 if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
1529 return failure();
1530
1531 auto loc = op.getLoc();
1532 Value desc = rewriter.create<ConstantOp>(
1533 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1534 unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
1535 for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
1536 Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
1537 loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
1538 /*sizes=*/mostMinorVectorSize,
1539 /*strides=*/1);
1540 desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
1541 }
1542 rewriter.replaceOp(op, desc);
1543 return success();
1544 }
1545 };
1546
1547 // We typically should not lower general shape cast operations into data
1548 // movement instructions, since the assumption is that these casts are
1549 // optimized away during progressive lowering. For completeness, however,
1550 // we fall back to a reference implementation that moves all elements
1551 // into the right place if we get here.
1552 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
1553 public:
1554 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1555
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const1556 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1557 PatternRewriter &rewriter) const override {
1558 Location loc = op.getLoc();
1559 auto sourceVectorType = op.getSourceVectorType();
1560 auto resultVectorType = op.getResultVectorType();
1561 // Intended 2D/1D lowerings with better implementations.
1562 int64_t srcRank = sourceVectorType.getRank();
1563 int64_t resRank = resultVectorType.getRank();
1564 if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
1565 return failure();
1566 // Compute number of elements involved in the reshape.
1567 int64_t numElts = 1;
1568 for (int64_t r = 0; r < srcRank; r++)
1569 numElts *= sourceVectorType.getDimSize(r);
1570 // Replace with data movement operations:
1571 // x[0,0,0] = y[0,0]
1572 // x[0,0,1] = y[0,1]
1573 // x[0,1,0] = y[0,2]
1574 // etc., incrementing the two index vectors "row-major"
1575 // within the source and result shape.
1576 SmallVector<int64_t, 4> srcIdx(srcRank);
1577 SmallVector<int64_t, 4> resIdx(resRank);
1578 Value result = rewriter.create<ConstantOp>(
1579 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1580 for (int64_t i = 0; i < numElts; i++) {
1581 if (i != 0) {
1582 incIdx(srcIdx, sourceVectorType, srcRank - 1);
1583 incIdx(resIdx, resultVectorType, resRank - 1);
1584 }
1585 Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
1586 result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
1587 }
1588 rewriter.replaceOp(op, result);
1589 return success();
1590 }
1591
1592 private:
incIdx(SmallVector<int64_t,4> & idx,VectorType tp,int64_t r)1593 static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
1594 assert(0 <= r && r < tp.getRank());
1595 if (++idx[r] == tp.getDimSize(r)) {
1596 idx[r] = 0;
1597 incIdx(idx, tp, r - 1);
1598 }
1599 }
1600 };
1601
1602 } // namespace
1603
1604 namespace mlir {
1605
1606 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1607 /// semantics to:
1608 /// ```
1609 /// %flattened_a = vector.shape_cast %a
1610 /// %flattened_b = vector.shape_cast %b
1611 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
1612 /// %d = vector.shape_cast %%flattened_d
1613 /// %e = add %c, %d
1614 /// ```
1615 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1616 //
1617 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
1618 /// the vector.contract op is a row-major matrix multiply.
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1619 LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
1620 vector::ContractionOp op, PatternRewriter &rewriter) const {
1621 // TODO: implement masks
1622 if (llvm::size(op.masks()) != 0)
1623 return failure();
1624 if (vectorTransformsOptions.vectorContractLowering !=
1625 vector::VectorContractLowering::Matmul)
1626 return failure();
1627 if (failed(filter(op)))
1628 return failure();
1629
1630 auto iteratorTypes = op.iterator_types().getValue();
1631 if (!isParallelIterator(iteratorTypes[0]) ||
1632 !isParallelIterator(iteratorTypes[1]) ||
1633 !isReductionIterator(iteratorTypes[2]))
1634 return failure();
1635
1636 if (!isRowMajorMatmul(op.indexing_maps()))
1637 return failure();
1638
1639 Type elementType = op.getLhsType().getElementType();
1640 if (!elementType.isIntOrFloat())
1641 return failure();
1642
1643 VectorType lhsType = op.getLhsType();
1644 VectorType rhsType = op.getRhsType();
1645 int64_t lhsRows = lhsType.getDimSize(0);
1646 int64_t lhsColumns = lhsType.getDimSize(1);
1647 int64_t rhsColumns = rhsType.getDimSize(1);
1648
1649 Type flattenedLHSType =
1650 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1651 Type flattenedRHSType =
1652 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1653 auto lhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedLHSType,
1654 op.lhs());
1655 auto rhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedRHSType,
1656 op.rhs());
1657
1658 Value mul = rewriter.create<vector::MatmulOp>(op.getLoc(), lhs, rhs, lhsRows,
1659 lhsColumns, rhsColumns);
1660 mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
1661 mul);
1662 if (elementType.isa<IntegerType>())
1663 rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
1664 else
1665 rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
1666
1667 return success();
1668 }
1669
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
///
/// Each recognized indexing-map layout is first normalized, by transposing
/// and/or swapping `lhs` and `rhs`, so that for every reduction index k:
///   - `vector.extract %lhs[k]` yields a vector indexed by result dim 0;
///   - `vector.extract %rhs[k]` yields a vector indexed by result dim 1
///     (matmat flavor), or a scalar (matvec flavor, where the normalized
///     `rhs` is 1-D).
/// Masked contractions and unrecognized layouts are rejected with failure().
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformsOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  Location loc = op.getLoc();
  // Extent of the reduction dimension, i.e. the number of outer products to
  // emit. Always read from the ORIGINAL lhs type (before any swap/transpose).
  int64_t reductionSize = 0;
  VectorType lhsType = op.getLhsType();
  // Operands in normalized form; reassigned by the layout cases below.
  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

  // Set up the parallel/reduction structure in right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.iterator_types().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    // Target form: lhs indexed (k, result dim 0), rhs indexed
    // (k, result dim 1).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      reductionSize = lhsType.getDimSize(1);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // Both operands have k innermost: transpose both.
      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
      reductionSize = lhsType.getDimSize(1);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      // No need to permute anything.
      reductionSize = lhsType.getDimSize(0);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      // Just permute the rhs.
      reductionSize = lhsType.getDimSize(0);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed result (n, m): swap the operands and transpose the
      // original lhs, giving lhs = b (k, n) and rhs = a^T (k, m).
      reductionSize = lhsType.getDimSize(1);
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      // Transposed result (n, m): swap the operands and transpose both,
      // giving lhs = b^T (k, n) and rhs = a^T (k, m).
      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
      reductionSize = lhsType.getDimSize(1);
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // No need to permute anything, but still swap lhs and rhs.
      reductionSize = lhsType.getDimSize(0);
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      // Transposed result (n, m): the transposed original rhs becomes the
      // new lhs (k, n) and the original lhs (k, m) becomes the new rhs.
      reductionSize = lhsType.getDimSize(0);
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    // Target form: lhs indexed (n, m) with n the reduction, rhs the 1-D
    // vector indexed by n (extracting from it yields a scalar).
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // Case mat-vec: transpose.
      reductionSize = lhsType.getDimSize(1);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      // Case mat-trans-vec: ready to go.
      reductionSize = lhsType.getDimSize(0);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      // Case vec-mat: swap and transpose.
      reductionSize = lhsType.getDimSize(0);
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      // Case vec-mat-trans: swap and ready to go.
      reductionSize = lhsType.getDimSize(0);
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
  } else {
    return failure();
  }
  // Every accepted layout above assigns the reduction extent.
  assert(reductionSize > 0);

  // Unroll outer-products along reduction. Each outer product accumulates
  // into the result of the previous one, starting from the original acc.
  for (int64_t k = 0; k < reductionSize; ++k) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
  }
  rewriter.replaceOp(op, res);
  return success();
}
1796
1797 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1798 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1799 PatternRewriter &rewriter) const {
1800 // TODO: implement masks
1801 if (llvm::size(op.masks()) != 0)
1802 return failure();
1803
1804 if (failed(filter(op)))
1805 return failure();
1806
1807 if (vectorTransformsOptions.vectorContractLowering !=
1808 vector::VectorContractLowering::Dot)
1809 return failure();
1810
1811 auto iteratorTypes = op.iterator_types().getValue();
1812 static constexpr std::array<int64_t, 2> perm = {1, 0};
1813 Location loc = op.getLoc();
1814 Value lhs = op.lhs(), rhs = op.rhs();
1815
1816 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1817 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1818 AffineExpr m, n, k;
1819 bindDims(rewriter.getContext(), m, n, k);
1820 SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1821 //
1822 // In the following we wish to make the reduction dimension innermost so we
1823 // can load vectors and just fmul + reduce into a scalar.
1824 //
1825 if (isParallelIterator(iteratorTypes[0]) &&
1826 isParallelIterator(iteratorTypes[1]) &&
1827 isReductionIterator(iteratorTypes[2])) {
1828 //
1829 // Two outer parallel, one inner reduction (matmat flavor).
1830 //
1831 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1832 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1833 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1834 // No need to permute anything.
1835 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1836 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1837 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1838 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1839 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1840 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1841 // This is the classical row-major matmul. Just permute the lhs.
1842 Value tmp = lhs;
1843 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1844 rhs = tmp;
1845 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1846 std::swap(lhs, rhs);
1847 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1848 Value tmp = lhs;
1849 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1850 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1851 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1852 Value tmp = rhs;
1853 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1854 lhs = tmp;
1855 } else {
1856 return failure();
1857 }
1858 } else if (isParallelIterator(iteratorTypes[0]) &&
1859 isReductionIterator(iteratorTypes[1])) {
1860 //
1861 // One outer parallel, one inner reduction (matvec flavor)
1862 //
1863 if (maps == infer({{m, n}, {n}, {m}})) {
1864 // No need to permute anything.
1865 } else if (maps == infer({{n, m}, {n}, {m}})) {
1866 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1867 } else if (maps == infer({{n}, {m, n}, {m}})) {
1868 std::swap(lhs, rhs);
1869 } else if (maps == infer({{n}, {n, m}, {m}})) {
1870 std::swap(lhs, rhs);
1871 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1872 } else {
1873 return failure();
1874 }
1875 } else {
1876 return failure();
1877 }
1878
1879 VectorType dstType = op.getResultType().cast<VectorType>();
1880 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1881 "Expected dst type of rank 1 or 2");
1882
1883 unsigned rank = dstType.getRank();
1884 unsigned dstRows = dstType.getShape()[0];
1885 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1886
1887 // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1888 Value res =
1889 rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
1890 for (unsigned r = 0; r < dstRows; ++r) {
1891 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1892 for (unsigned c = 0; c < dstColumns; ++c) {
1893 Value b = rank == 1
1894 ? rhs
1895 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1896 Value m = rewriter.create<MulFOp>(op.getLoc(), a, b);
1897 Value reduced = rewriter.create<vector::ReductionOp>(
1898 op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
1899 m, ValueRange{});
1900
1901 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1902 : SmallVector<int64_t, 2>{r, c};
1903 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1904 }
1905 }
1906 if (auto acc = op.acc())
1907 res = rewriter.create<AddFOp>(op.getLoc(), res, acc);
1908 rewriter.replaceOp(op, res);
1909 return success();
1910 }
1911
1912 /// Progressive lowering of ContractionOp.
1913 /// One:
1914 /// %x = vector.contract with at least one free/batch dimension
1915 /// is replaced by:
1916 /// %a = vector.contract with one less free/batch dimension
1917 /// %b = vector.contract with one less free/batch dimension
1918 /// ..
1919 /// %x = combine %a %b ..
1920 /// until a pure contraction is reached (no free/batch dimensions),
1921 /// which is replaced by a dot-product.
1922 ///
1923 /// This only kicks in when either VectorTransformsOptions is set
1924 /// to DOT or when other contraction patterns fail.
1925 //
1926 // TODO: break down into transpose/reshape/cast ops
1927 // when they become available to avoid code dup
1928 // TODO: investigate lowering order impact on performance
1929 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1930 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1931 PatternRewriter &rewriter) const {
1932 // TODO: implement masks.
1933 if (llvm::size(op.masks()) != 0)
1934 return failure();
1935
1936 if (failed(filter(op)))
1937 return failure();
1938
1939 // TODO: support mixed mode contract lowering.
1940 if (op.getLhsType().getElementType() !=
1941 getElementTypeOrSelf(op.getAccType()) ||
1942 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1943 return failure();
1944
1945 // TODO: implement benefits, cost models.
1946 MLIRContext *ctx = op.getContext();
1947 ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
1948 if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1949 return success();
1950 ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
1951 if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1952 return success();
1953 ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
1954 if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1955 return success();
1956
1957 // Find first batch dimension in LHS/RHS, and lower when found.
1958 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1959 if (!batchDimMap.empty()) {
1960 int64_t lhsIndex = batchDimMap[0].first;
1961 int64_t rhsIndex = batchDimMap[0].second;
1962 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1963 return success();
1964 }
1965
1966 // Collect contracting dimensions.
1967 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1968 op.getContractingDimMap();
1969 DenseSet<int64_t> lhsContractingDimSet;
1970 DenseSet<int64_t> rhsContractingDimSet;
1971 for (auto &dimPair : contractingDimMap) {
1972 lhsContractingDimSet.insert(dimPair.first);
1973 rhsContractingDimSet.insert(dimPair.second);
1974 }
1975
1976 // Find first free dimension in LHS, and lower when found.
1977 VectorType lhsType = op.getLhsType();
1978 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1979 if (lhsContractingDimSet.count(lhsIndex) == 0) {
1980 rewriter.replaceOp(
1981 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1982 return success();
1983 }
1984 }
1985
1986 // Find first free dimension in RHS, and lower when found.
1987 VectorType rhsType = op.getRhsType();
1988 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1989 if (rhsContractingDimSet.count(rhsIndex) == 0) {
1990 rewriter.replaceOp(
1991 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1992 return success();
1993 }
1994 }
1995
1996 // Lower the first remaining reduction dimension.
1997 if (!contractingDimMap.empty()) {
1998 rewriter.replaceOp(op, lowerReduction(op, rewriter));
1999 return success();
2000 }
2001
2002 return failure();
2003 }
2004
2005 // Lower one parallel dimension.
2006 // TODO: consider reusing existing contract unrolling
lowerParallel(vector::ContractionOp op,int64_t lhsIndex,int64_t rhsIndex,PatternRewriter & rewriter) const2007 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
2008 int64_t lhsIndex, int64_t rhsIndex,
2009 PatternRewriter &rewriter) const {
2010 VectorType lhsType = op.getLhsType();
2011 VectorType rhsType = op.getRhsType();
2012 VectorType resType = op.getResultType().cast<VectorType>();
2013 // Find the iterator type index and result index.
2014 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2015 int64_t iterIndex = -1;
2016 int64_t dimSize = -1;
2017 if (lhsIndex >= 0) {
2018 iterIndex = iMap[0].getDimPosition(lhsIndex);
2019 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
2020 "parallel index should be free in LHS or batch in LHS/RHS");
2021 dimSize = lhsType.getDimSize(lhsIndex);
2022 } else {
2023 assert(rhsIndex >= 0 && "missing parallel index");
2024 iterIndex = iMap[1].getDimPosition(rhsIndex);
2025 dimSize = rhsType.getDimSize(rhsIndex);
2026 }
2027 assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
2028 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
2029 assert(lookup.hasValue() && "parallel index not listed in reduction");
2030 int64_t resIndex = lookup.getValue();
2031 // Construct new iterator types and affine map array attribute.
2032 std::array<AffineMap, 3> lowIndexingMaps = {
2033 adjustMap(iMap[0], iterIndex, rewriter),
2034 adjustMap(iMap[1], iterIndex, rewriter),
2035 adjustMap(iMap[2], iterIndex, rewriter)};
2036 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2037 auto lowIter =
2038 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2039 // Unroll into a series of lower dimensional vector.contract ops.
2040 Location loc = op.getLoc();
2041 Value result =
2042 rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
2043 for (int64_t d = 0; d < dimSize; ++d) {
2044 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2045 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2046 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
2047 Value lowContract = rewriter.create<vector::ContractionOp>(
2048 loc, lhs, rhs, acc, lowAffine, lowIter);
2049 result =
2050 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
2051 }
2052 return result;
2053 }
2054
2055 // Lower one reduction dimension.
lowerReduction(vector::ContractionOp op,PatternRewriter & rewriter) const2056 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
2057 PatternRewriter &rewriter) const {
2058 auto loc = op.getLoc();
2059 VectorType lhsType = op.getLhsType();
2060 VectorType rhsType = op.getRhsType();
2061 Type resType = op.getResultType();
2062 assert(!resType.isa<VectorType>());
2063 // Use iterator index 0.
2064 int64_t iterIndex = 0;
2065 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2066 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
2067 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
2068 assert(lookupLhs.hasValue() && "missing LHS parallel index");
2069 assert(lookupRhs.hasValue() && "missing RHS parallel index");
2070 int64_t lhsIndex = lookupLhs.getValue();
2071 int64_t rhsIndex = lookupRhs.getValue();
2072 int64_t dimSize = lhsType.getDimSize(lhsIndex);
2073 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
2074 // Base case.
2075 if (lhsType.getRank() == 1) {
2076 assert(rhsType.getRank() == 1 && "corrupt contraction");
2077 Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
2078 StringAttr kind = rewriter.getStringAttr("add");
2079 return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
2080 op.acc());
2081 }
2082 // Construct new iterator types and affine map array attribute.
2083 std::array<AffineMap, 3> lowIndexingMaps = {
2084 adjustMap(iMap[0], iterIndex, rewriter),
2085 adjustMap(iMap[1], iterIndex, rewriter),
2086 adjustMap(iMap[2], iterIndex, rewriter)};
2087 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2088 auto lowIter =
2089 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2090 // Unroll into a series of lower dimensional vector.contract ops.
2091 // By feeding the initial accumulator into the first contraction,
2092 // and the result of each contraction into the next, eventually
2093 // the sum of all reductions is computed.
2094 Value result = op.acc();
2095 for (int64_t d = 0; d < dimSize; ++d) {
2096 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2097 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2098 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
2099 lowAffine, lowIter);
2100 }
2101 return result;
2102 }
2103
2104 } // namespace mlir
2105
extractConstantIndex(Value v)2106 static Optional<int64_t> extractConstantIndex(Value v) {
2107 if (auto cstOp = v.getDefiningOp<ConstantIndexOp>())
2108 return cstOp.getValue();
2109 if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
2110 if (affineApplyOp.getAffineMap().isSingleConstant())
2111 return affineApplyOp.getAffineMap().getSingleConstantResult();
2112 return None;
2113 }
2114
2115 // Missing foldings of scf.if make it necessary to perform poor man's folding
2116 // eagerly, especially in the case of unrolling. In the future, this should go
2117 // away once scf.if folds properly.
createScopedFoldedSLE(Value v,Value ub)2118 static Value createScopedFoldedSLE(Value v, Value ub) {
2119 using namespace edsc::op;
2120 auto maybeCstV = extractConstantIndex(v);
2121 auto maybeCstUb = extractConstantIndex(ub);
2122 if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
2123 return Value();
2124 return sle(v, ub);
2125 }
2126
2127 // Operates under a scoped context to build the condition to ensure that a
2128 // particular VectorTransferOpInterface is unmasked.
createScopedInBoundsCond(VectorTransferOpInterface xferOp)2129 static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
2130 assert(xferOp.permutation_map().isMinorIdentity() &&
2131 "Expected minor identity map");
2132 Value inBoundsCond;
2133 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2134 // Zip over the resulting vector shape and memref indices.
2135 // If the dimension is known to be unmasked, it does not participate in the
2136 // construction of `inBoundsCond`.
2137 if (!xferOp.isMaskedDim(resultIdx))
2138 return;
2139 int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
2140 using namespace edsc::op;
2141 using namespace edsc::intrinsics;
2142 // Fold or create the check that `index + vector_size` <= `memref_size`.
2143 Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
2144 Value cond =
2145 createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
2146 if (!cond)
2147 return;
2148 // Conjunction over all dims for which we are in-bounds.
2149 inBoundsCond = inBoundsCond ? inBoundsCond && cond : cond;
2150 });
2151 return inBoundsCond;
2152 }
2153
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)2154 LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
2155 VectorTransferOpInterface xferOp) {
2156 // TODO: expand support to these 2 cases.
2157 if (!xferOp.permutation_map().isMinorIdentity())
2158 return failure();
2159 // Must have some masked dimension to be a candidate for splitting.
2160 if (!xferOp.hasMaskedDim())
2161 return failure();
2162 // Don't split transfer operations directly under IfOp, this avoids applying
2163 // the pattern recursively.
2164 // TODO: improve the filtering condition to make it more applicable.
2165 if (isa<scf::IfOp>(xferOp->getParentOp()))
2166 return failure();
2167 return success();
2168 }
2169
2170 /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
2171 /// be cast. If the MemRefTypes don't have the same rank or are not strided,
2172 /// return null; otherwise:
2173 /// 1. if `aT` and `bT` are cast-compatible, return `aT`.
2174 /// 2. else return a new MemRefType obtained by iterating over the shape and
2175 /// strides and:
2176 /// a. keeping the ones that are static and equal across `aT` and `bT`.
2177 /// b. using a dynamic shape and/or stride for the dimensions that don't
2178 /// agree.
getCastCompatibleMemRefType(MemRefType aT,MemRefType bT)2179 static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
2180 if (MemRefCastOp::areCastCompatible(aT, bT))
2181 return aT;
2182 if (aT.getRank() != bT.getRank())
2183 return MemRefType();
2184 int64_t aOffset, bOffset;
2185 SmallVector<int64_t, 4> aStrides, bStrides;
2186 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
2187 failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
2188 aStrides.size() != bStrides.size())
2189 return MemRefType();
2190
2191 ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
2192 int64_t resOffset;
2193 SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
2194 resStrides(bT.getRank(), 0);
2195 for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
2196 resShape[idx] =
2197 (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
2198 resStrides[idx] = (aStrides[idx] == bStrides[idx])
2199 ? aStrides[idx]
2200 : MemRefType::kDynamicStrideOrOffset;
2201 }
2202 resOffset =
2203 (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
2204 return MemRefType::get(
2205 resShape, aT.getElementType(),
2206 makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
2207 }
2208
2209 /// Operates under a scoped context to build the intersection between the
2210 /// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
2211 // TODO: view intersection/union/differences should be a proper std op.
createScopedSubViewIntersection(VectorTransferOpInterface xferOp,Value alloc)2212 static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
2213 Value alloc) {
2214 using namespace edsc::intrinsics;
2215 int64_t memrefRank = xferOp.getMemRefType().getRank();
2216 // TODO: relax this precondition, will require rank-reducing subviews.
2217 assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
2218 "Expected memref rank to match the alloc rank");
2219 Value one = std_constant_index(1);
2220 ValueRange leadingIndices =
2221 xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
2222 SmallVector<Value, 4> sizes;
2223 sizes.append(leadingIndices.begin(), leadingIndices.end());
2224 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2225 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
2226 Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
2227 Value dimAlloc = std_dim(alloc, resultIdx);
2228 Value index = xferOp.indices()[indicesIdx];
2229 AffineExpr i, j, k;
2230 bindDims(xferOp.getContext(), i, j, k);
2231 SmallVector<AffineMap, 4> maps =
2232 AffineMap::inferFromExprList(MapList{{i - j, k}});
2233 // affine_min(%dimMemRef - %index, %dimAlloc)
2234 Value affineMin = affine_min(index.getType(), maps[0],
2235 ValueRange{dimMemRef, index, dimAlloc});
2236 sizes.push_back(affineMin);
2237 });
2238 return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
2239 SmallVector<Value, 4>(memrefRank, one));
2240 }
2241
2242 /// Given an `xferOp` for which:
2243 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2244 /// 2. a memref of single vector `alloc` has been allocated.
2245 /// Produce IR resembling:
2246 /// ```
2247 /// %1:3 = scf.if (%inBounds) {
2248 /// memref_cast %A: memref<A...> to compatibleMemRefType
2249 /// scf.yield %view, ... : compatibleMemRefType, index, index
2250 /// } else {
2251 /// %2 = linalg.fill(%alloc, %pad)
2252 /// %3 = subview %view [...][...][...]
2253 /// linalg.copy(%3, %alloc)
2254 /// memref_cast %alloc: memref<B...> to compatibleMemRefType
2255 /// scf.yield %4, ... : compatibleMemRefType, index, index
2256 /// }
2257 /// ```
2258 /// Return the produced scf::IfOp.
createScopedFullPartialLinalgCopy(vector::TransferReadOp xferOp,TypeRange returnTypes,Value inBoundsCond,MemRefType compatibleMemRefType,Value alloc)2259 static scf::IfOp createScopedFullPartialLinalgCopy(
2260 vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2261 MemRefType compatibleMemRefType, Value alloc) {
2262 using namespace edsc;
2263 using namespace edsc::intrinsics;
2264 scf::IfOp fullPartialIfOp;
2265 Value zero = std_constant_index(0);
2266 Value memref = xferOp.memref();
2267 conditionBuilder(
2268 returnTypes, inBoundsCond,
2269 [&]() -> scf::ValueVector {
2270 Value res = memref;
2271 if (compatibleMemRefType != xferOp.getMemRefType())
2272 res = std_memref_cast(memref, compatibleMemRefType);
2273 scf::ValueVector viewAndIndices{res};
2274 viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2275 xferOp.indices().end());
2276 return viewAndIndices;
2277 },
2278 [&]() -> scf::ValueVector {
2279 linalg_fill(alloc, xferOp.padding());
2280 // Take partial subview of memref which guarantees no dimension
2281 // overflows.
2282 Value memRefSubView = createScopedSubViewIntersection(
2283 cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
2284 linalg_copy(memRefSubView, alloc);
2285 Value casted = std_memref_cast(alloc, compatibleMemRefType);
2286 scf::ValueVector viewAndIndices{casted};
2287 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
2288 zero);
2289 return viewAndIndices;
2290 },
2291 &fullPartialIfOp);
2292 return fullPartialIfOp;
2293 }
2294
2295 /// Given an `xferOp` for which:
2296 /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2297 /// 2. a memref of single vector `alloc` has been allocated.
2298 /// Produce IR resembling:
2299 /// ```
2300 /// %1:3 = scf.if (%inBounds) {
2301 /// memref_cast %A: memref<A...> to compatibleMemRefType
2302 /// scf.yield %view, ... : compatibleMemRefType, index, index
2303 /// } else {
2304 /// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
2305 /// %3 = vector.type_cast %extra_alloc :
2306 /// memref<...> to memref<vector<...>>
2307 /// store %2, %3[] : memref<vector<...>>
2308 /// %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
2309 /// scf.yield %4, ... : compatibleMemRefType, index, index
2310 /// }
2311 /// ```
2312 /// Return the produced scf::IfOp.
createScopedFullPartialVectorTransferRead(vector::TransferReadOp xferOp,TypeRange returnTypes,Value inBoundsCond,MemRefType compatibleMemRefType,Value alloc)2313 static scf::IfOp createScopedFullPartialVectorTransferRead(
2314 vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2315 MemRefType compatibleMemRefType, Value alloc) {
2316 using namespace edsc;
2317 using namespace edsc::intrinsics;
2318 scf::IfOp fullPartialIfOp;
2319 Value zero = std_constant_index(0);
2320 Value memref = xferOp.memref();
2321 conditionBuilder(
2322 returnTypes, inBoundsCond,
2323 [&]() -> scf::ValueVector {
2324 Value res = memref;
2325 if (compatibleMemRefType != xferOp.getMemRefType())
2326 res = std_memref_cast(memref, compatibleMemRefType);
2327 scf::ValueVector viewAndIndices{res};
2328 viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2329 xferOp.indices().end());
2330 return viewAndIndices;
2331 },
2332 [&]() -> scf::ValueVector {
2333 Operation *newXfer =
2334 ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
2335 Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
2336 std_store(vector, vector_type_cast(
2337 MemRefType::get({}, vector.getType()), alloc));
2338
2339 Value casted = std_memref_cast(alloc, compatibleMemRefType);
2340 scf::ValueVector viewAndIndices{casted};
2341 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
2342 zero);
2343
2344 return viewAndIndices;
2345 },
2346 &fullPartialIfOp);
2347 return fullPartialIfOp;
2348 }
2349
2350 /// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
2351 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
2352 /// newly created conditional upon function return.
2353 /// To accomodate for the fact that the original vector.transfer indexing may be
2354 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
2355 /// scf.if op returns a view and values of type index.
2356 /// At this time, only vector.transfer_read case is implemented.
2357 ///
2358 /// Example (a 2-D vector.transfer_read):
2359 /// ```
2360 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
2361 /// ```
2362 /// is transformed into:
2363 /// ```
2364 /// %1:3 = scf.if (%inBounds) {
2365 /// // fastpath, direct cast
2366 /// memref_cast %A: memref<A...> to compatibleMemRefType
2367 /// scf.yield %view : compatibleMemRefType, index, index
2368 /// } else {
2369 /// // slowpath, masked vector.transfer or linalg.copy.
2370 /// memref_cast %alloc: memref<B...> to compatibleMemRefType
2371 /// scf.yield %4 : compatibleMemRefType, index, index
2372 // }
2373 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
2374 /// ```
2375 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
2376 ///
2377 /// Preconditions:
2378 /// 1. `xferOp.permutation_map()` must be a minor identity map
2379 /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
2380 /// must be equal. This will be relaxed in the future but requires
2381 /// rank-reducing subviews.
splitFullAndPartialTransfer(OpBuilder & b,VectorTransferOpInterface xferOp,VectorTransformsOptions options,scf::IfOp * ifOp)2382 LogicalResult mlir::vector::splitFullAndPartialTransfer(
2383 OpBuilder &b, VectorTransferOpInterface xferOp,
2384 VectorTransformsOptions options, scf::IfOp *ifOp) {
2385 using namespace edsc;
2386 using namespace edsc::intrinsics;
2387
2388 if (options.vectorTransferSplit == VectorTransferSplit::None)
2389 return failure();
2390
2391 SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
2392 auto unmaskedAttr = b.getBoolArrayAttr(bools);
2393 if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
2394 xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
2395 return success();
2396 }
2397
2398 assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
2399 "Expected splitFullAndPartialTransferPrecondition to hold");
2400 auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
2401
2402 // TODO: add support for write case.
2403 if (!xferReadOp)
2404 return failure();
2405
2406 OpBuilder::InsertionGuard guard(b);
2407 if (xferOp.memref().getDefiningOp())
2408 b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
2409 else
2410 b.setInsertionPoint(xferOp);
2411 ScopedContext scope(b, xferOp.getLoc());
2412 Value inBoundsCond = createScopedInBoundsCond(
2413 cast<VectorTransferOpInterface>(xferOp.getOperation()));
2414 if (!inBoundsCond)
2415 return failure();
2416
2417 // Top of the function `alloc` for transient storage.
2418 Value alloc;
2419 {
2420 FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
2421 OpBuilder::InsertionGuard guard(b);
2422 b.setInsertionPointToStart(&funcOp.getRegion().front());
2423 auto shape = xferOp.getVectorType().getShape();
2424 Type elementType = xferOp.getVectorType().getElementType();
2425 alloc = std_alloca(MemRefType::get(shape, elementType), ValueRange{},
2426 b.getI64IntegerAttr(32));
2427 }
2428
2429 MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
2430 xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
2431
2432 // Read case: full fill + partial copy -> unmasked vector.xfer_read.
2433 SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
2434 b.getIndexType());
2435 returnTypes[0] = compatibleMemRefType;
2436 scf::IfOp fullPartialIfOp =
2437 options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
2438 ? createScopedFullPartialVectorTransferRead(
2439 xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
2440 alloc)
2441 : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
2442 inBoundsCond,
2443 compatibleMemRefType, alloc);
2444 if (ifOp)
2445 *ifOp = fullPartialIfOp;
2446
2447 // Unmask the existing read op, it always reads from a full buffer.
2448 for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
2449 xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
2450 xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
2451
2452 return success();
2453 }
2454
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const2455 LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
2456 Operation *op, PatternRewriter &rewriter) const {
2457 auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
2458 if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
2459 failed(filter(xferOp)))
2460 return failure();
2461 rewriter.startRootUpdate(xferOp);
2462 if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
2463 rewriter.finalizeRootUpdate(xferOp);
2464 return success();
2465 }
2466 rewriter.cancelRootUpdate(xferOp);
2467 return failure();
2468 }
2469
matchAndRewrite(ExtractMapOp extract,PatternRewriter & rewriter) const2470 LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
2471 ExtractMapOp extract, PatternRewriter &rewriter) const {
2472 Operation *definedOp = extract.vector().getDefiningOp();
2473 if (!definedOp || definedOp->getNumResults() != 1)
2474 return failure();
2475 // TODO: Create an interfaceOp for elementwise operations.
2476 if (!isa<AddFOp>(definedOp))
2477 return failure();
2478 Location loc = extract.getLoc();
2479 SmallVector<Value, 4> extractOperands;
2480 for (OpOperand &operand : definedOp->getOpOperands())
2481 extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
2482 loc, extract.getResultType(), operand.get(), extract.ids()));
2483 Operation *newOp = cloneOpWithOperandsAndTypes(
2484 rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
2485 rewriter.replaceOp(extract, newOp->getResult(0));
2486 return success();
2487 }
2488
distributPointwiseVectorOp(OpBuilder & builder,Operation * op,ArrayRef<Value> ids,ArrayRef<int64_t> multiplicity,const AffineMap & map)2489 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
2490 OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
2491 ArrayRef<int64_t> multiplicity, const AffineMap &map) {
2492 OpBuilder::InsertionGuard guard(builder);
2493 builder.setInsertionPointAfter(op);
2494 Location loc = op->getLoc();
2495 if (op->getNumResults() != 1)
2496 return {};
2497 Value result = op->getResult(0);
2498 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
2499 if (!type || map.getNumResults() != multiplicity.size())
2500 return {};
2501 // For each dimension being distributed check that the size is a multiple of
2502 // the multiplicity. To handle more sizes we would need to support masking.
2503 unsigned multiplictyCount = 0;
2504 for (auto exp : map.getResults()) {
2505 auto affinExp = exp.dyn_cast<AffineDimExpr>();
2506 if (!affinExp || affinExp.getPosition() >= type.getRank() ||
2507 type.getDimSize(affinExp.getPosition()) %
2508 multiplicity[multiplictyCount++] !=
2509 0)
2510 return {};
2511 }
2512 DistributeOps ops;
2513 ops.extract =
2514 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
2515 ops.insert =
2516 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
2517 return ops;
2518 }
2519
2520 struct TransferReadExtractPattern
2521 : public OpRewritePattern<vector::TransferReadOp> {
TransferReadExtractPatternTransferReadExtractPattern2522 TransferReadExtractPattern(MLIRContext *context)
2523 : OpRewritePattern<vector::TransferReadOp>(context) {}
matchAndRewriteTransferReadExtractPattern2524 LogicalResult matchAndRewrite(vector::TransferReadOp read,
2525 PatternRewriter &rewriter) const override {
2526 if (!read.getResult().hasOneUse())
2527 return failure();
2528 auto extract =
2529 dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
2530 if (!extract)
2531 return failure();
2532 edsc::ScopedContext scope(rewriter, read.getLoc());
2533 using mlir::edsc::op::operator+;
2534 using mlir::edsc::op::operator*;
2535 using namespace mlir::edsc::intrinsics;
2536 SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
2537 AffineMap map = extract.map();
2538 unsigned idCount = 0;
2539 for (auto expr : map.getResults()) {
2540 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2541 indices[pos] =
2542 indices[pos] +
2543 extract.ids()[idCount++] *
2544 std_constant_index(extract.getResultType().getDimSize(pos));
2545 }
2546 Value newRead = vector_transfer_read(extract.getType(), read.memref(),
2547 indices, read.permutation_map(),
2548 read.padding(), read.maskedAttr());
2549 Value dest = rewriter.create<ConstantOp>(
2550 read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType()));
2551 newRead = rewriter.create<vector::InsertMapOp>(read.getLoc(), newRead, dest,
2552 extract.ids());
2553 rewriter.replaceOp(read, newRead);
2554 return success();
2555 }
2556 };
2557
2558 struct TransferWriteInsertPattern
2559 : public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteInsertPatternTransferWriteInsertPattern2560 TransferWriteInsertPattern(MLIRContext *context)
2561 : OpRewritePattern<vector::TransferWriteOp>(context) {}
matchAndRewriteTransferWriteInsertPattern2562 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2563 PatternRewriter &rewriter) const override {
2564 auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
2565 if (!insert)
2566 return failure();
2567 edsc::ScopedContext scope(rewriter, write.getLoc());
2568 using mlir::edsc::op::operator+;
2569 using mlir::edsc::op::operator*;
2570 using namespace mlir::edsc::intrinsics;
2571 SmallVector<Value, 4> indices(write.indices().begin(),
2572 write.indices().end());
2573 AffineMap map = insert.map();
2574 unsigned idCount = 0;
2575 for (auto expr : map.getResults()) {
2576 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2577 indices[pos] =
2578 indices[pos] +
2579 insert.ids()[idCount++] *
2580 std_constant_index(insert.getSourceVectorType().getDimSize(pos));
2581 }
2582 vector_transfer_write(insert.vector(), write.memref(), indices,
2583 write.permutation_map(), write.maskedAttr());
2584 rewriter.eraseOp(write);
2585 return success();
2586 }
2587 };
2588
2589 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
2590 // TODO: Add this as DRR pattern.
populateVectorToVectorTransformationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2591 void mlir::vector::populateVectorToVectorTransformationPatterns(
2592 OwningRewritePatternList &patterns, MLIRContext *context) {
2593 // clang-format off
2594 patterns.insert<ShapeCastOpDecomposer,
2595 ShapeCastOpFolder,
2596 SplitTransferReadOp,
2597 SplitTransferWriteOp,
2598 TupleGetFolderOp,
2599 TransferReadExtractPattern,
2600 TransferWriteInsertPattern>(context);
2601 // clang-format on
2602 }
2603
populateVectorSlicesLoweringPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2604 void mlir::vector::populateVectorSlicesLoweringPatterns(
2605 OwningRewritePatternList &patterns, MLIRContext *context) {
2606 patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
2607 }
2608
populateVectorContractLoweringPatterns(OwningRewritePatternList & patterns,MLIRContext * context,VectorTransformsOptions parameters)2609 void mlir::vector::populateVectorContractLoweringPatterns(
2610 OwningRewritePatternList &patterns, MLIRContext *context,
2611 VectorTransformsOptions parameters) {
2612 // clang-format off
2613 patterns.insert<BroadcastOpLowering,
2614 CreateMaskOpLowering,
2615 ConstantMaskOpLowering,
2616 OuterProductOpLowering,
2617 ShapeCastOp2DDownCastRewritePattern,
2618 ShapeCastOp2DUpCastRewritePattern,
2619 ShapeCastOpRewritePattern>(context);
2620 patterns.insert<TransposeOpLowering,
2621 ContractionOpLowering,
2622 ContractionOpToMatmulOpLowering,
2623 ContractionOpToOuterProductOpLowering>(parameters, context);
2624 // clang-format on
2625 }
2626