1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines some vectorizer utilities.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
14 #define LLVM_ANALYSIS_VECTORUTILS_H
15
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/Support/CheckedArithmetic.h"
20
21 namespace llvm {
22 class TargetLibraryInfo;
23
24 /// Describes the type of Parameters
/// Describes the type of Parameters
///
/// Each kind mirrors an OpenMP `declare simd` parameter clause (or the
/// equivalent token in a Vector Function ABI mangled name).
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown            // Kind could not be determined.
};
42
/// Describes the type of Instruction Set Architecture
///
/// Identifies which vector extension a scalar-to-vector mapping targets.
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};
55
56 /// Encapsulates information needed to describe a parameter.
57 ///
58 /// The description of the parameter is not linked directly to
59 /// OpenMP or any other vector function description. This structure
60 /// is extendible to handle other paradigms that describe vector
61 /// functions and their parameters.
62 struct VFParameter {
63 unsigned ParamPos; // Parameter Position in Scalar Function.
64 VFParamKind ParamKind; // Kind of Parameter.
65 int LinearStepOrPos = 0; // Step or Position of the Parameter.
66 Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.
67
68 // Comparison operator.
69 bool operator==(const VFParameter &Other) const {
70 return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
71 std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
72 Other.Alignment);
73 }
74 };
75
76 /// Contains the information about the kind of vectorization
77 /// available.
78 ///
79 /// This object in independent on the paradigm used to
80 /// represent vector functions. in particular, it is not attached to
81 /// any target-specific ABI.
82 struct VFShape {
83 ElementCount VF; // Vectorization factor.
84 SmallVector<VFParameter, 8> Parameters; // List of parameter information.
85 // Comparison operator.
86 bool operator==(const VFShape &Other) const {
87 return std::tie(VF, Parameters) == std::tie(Other.VF, Other.Parameters);
88 }
89
90 /// Update the parameter in position P.ParamPos to P.
updateParamVFShape91 void updateParam(VFParameter P) {
92 assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
93 Parameters[P.ParamPos] = P;
94 assert(hasValidParameterList() && "Invalid parameter list");
95 }
96
97 /// Retrieve the VFShape that can be used to map a scalar function to itself,
98 /// with VF = 1.
getScalarShapeVFShape99 static VFShape getScalarShape(const FunctionType *FTy) {
100 return VFShape::get(FTy, ElementCount::getFixed(1),
101 /*HasGlobalPredicate*/ false);
102 }
103
104 /// Retrieve the basic vectorization shape of the function, where all
105 /// parameters are mapped to VFParamKind::Vector with \p EC lanes. Specifies
106 /// whether the function has a Global Predicate argument via \p HasGlobalPred.
getVFShape107 static VFShape get(const FunctionType *FTy, ElementCount EC,
108 bool HasGlobalPred) {
109 SmallVector<VFParameter, 8> Parameters;
110 for (unsigned I = 0; I < FTy->getNumParams(); ++I)
111 Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
112 if (HasGlobalPred)
113 Parameters.push_back(
114 VFParameter({FTy->getNumParams(), VFParamKind::GlobalPredicate}));
115
116 return {EC, Parameters};
117 }
118 /// Validation check on the Parameters in the VFShape.
119 bool hasValidParameterList() const;
120 };
121
122 /// Holds the VFShape for a specific scalar to vector function mapping.
123 struct VFInfo {
124 VFShape Shape; /// Classification of the vector function.
125 std::string ScalarName; /// Scalar Function Name.
126 std::string VectorName; /// Vector Function Name associated to this VFInfo.
127 VFISAKind ISA; /// Instruction Set Architecture.
128
129 /// Returns the index of the first parameter with the kind 'GlobalPredicate',
130 /// if any exist.
getParamIndexForOptionalMaskVFInfo131 std::optional<unsigned> getParamIndexForOptionalMask() const {
132 unsigned ParamCount = Shape.Parameters.size();
133 for (unsigned i = 0; i < ParamCount; ++i)
134 if (Shape.Parameters[i].ParamKind == VFParamKind::GlobalPredicate)
135 return i;
136
137 return std::nullopt;
138 }
139
140 /// Returns true if at least one of the operands to the vectorized function
141 /// has the kind 'GlobalPredicate'.
isMaskedVFInfo142 bool isMasked() const { return getParamIndexForOptionalMask().has_value(); }
143 };
144
namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";
/// Prefix for internal name redirection for vector function that
/// tells the compiler to scalarize the call using the scalar name
/// of the function. For example, a mangled name like
/// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
/// vectorizer to vectorize the scalar call `foo`, and to scalarize
/// it once vectorization is done.
static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
/// <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This method supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
///  https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
/// \param FTy -> FunctionType of the scalar function which we're trying to find
/// a vectorized variant for. This is required to determine the vectorization
/// factor for scalable vectors, since the mangled name doesn't encode that;
/// it needs to be derived from the widest element types of vector arguments
/// or return values.
/// \returns std::nullopt if the name cannot be demangled.
std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
                                          const FunctionType *FTy);

/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI. If the CI does not contain the
/// vector-function-abi-variant attribute, we return without populating
/// VariantMappings, i.e. callers of getVectorVariantNames need not check for
/// the presence of the attribute (see InjectTLIMappings).
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);

/// Constructs a FunctionType by applying vector function information to the
/// type of a matching scalar function.
/// \param Info gets the vectorization factor (VF) and the VFParamKind of the
/// parameters.
/// \param ScalarFTy gets the Type information of parameters, as it is not
/// stored in \p Info.
/// \returns a pointer to a newly created vector FunctionType
FunctionType *createFunctionType(const VFInfo &Info,
                                 const FunctionType *ScalarFTy);
} // end namespace VFABI
209
210 /// The Vector Function Database.
211 ///
212 /// Helper class used to find the vector functions associated to a
213 /// scalar CallInst.
214 class VFDatabase {
215 /// The Module of the CallInst CI.
216 const Module *M;
217 /// The CallInst instance being queried for scalar to vector mappings.
218 const CallInst &CI;
219 /// List of vector functions descriptors associated to the call
220 /// instruction.
221 const SmallVector<VFInfo, 8> ScalarToVectorMappings;
222
223 /// Retrieve the scalar-to-vector mappings associated to the rule of
224 /// a vector Function ABI.
getVFABIMappings(const CallInst & CI,SmallVectorImpl<VFInfo> & Mappings)225 static void getVFABIMappings(const CallInst &CI,
226 SmallVectorImpl<VFInfo> &Mappings) {
227 if (!CI.getCalledFunction())
228 return;
229
230 const StringRef ScalarName = CI.getCalledFunction()->getName();
231
232 SmallVector<std::string, 8> ListOfStrings;
233 // The check for the vector-function-abi-variant attribute is done when
234 // retrieving the vector variant names here.
235 VFABI::getVectorVariantNames(CI, ListOfStrings);
236 if (ListOfStrings.empty())
237 return;
238 for (const auto &MangledName : ListOfStrings) {
239 const std::optional<VFInfo> Shape =
240 VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
241 // A match is found via scalar and vector names, and also by
242 // ensuring that the variant described in the attribute has a
243 // corresponding definition or declaration of the vector
244 // function in the Module M.
245 if (Shape && (Shape->ScalarName == ScalarName)) {
246 assert(CI.getModule()->getFunction(Shape->VectorName) &&
247 "Vector function is missing.");
248 Mappings.push_back(*Shape);
249 }
250 }
251 }
252
253 public:
254 /// Retrieve all the VFInfo instances associated to the CallInst CI.
getMappings(const CallInst & CI)255 static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
256 SmallVector<VFInfo, 8> Ret;
257
258 // Get mappings from the Vector Function ABI variants.
259 getVFABIMappings(CI, Ret);
260
261 // Other non-VFABI variants should be retrieved here.
262
263 return Ret;
264 }
265
266 static bool hasMaskedVariant(const CallInst &CI,
267 std::optional<ElementCount> VF = std::nullopt) {
268 // Check whether we have at least one masked vector version of a scalar
269 // function. If no VF is specified then we check for any masked variant,
270 // otherwise we look for one that matches the supplied VF.
271 auto Mappings = VFDatabase::getMappings(CI);
272 for (VFInfo Info : Mappings)
273 if (!VF || Info.Shape.VF == *VF)
274 if (Info.isMasked())
275 return true;
276
277 return false;
278 }
279
280 /// Constructor, requires a CallInst instance.
VFDatabase(CallInst & CI)281 VFDatabase(CallInst &CI)
282 : M(CI.getModule()), CI(CI),
283 ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
284 /// \defgroup VFDatabase query interface.
285 ///
286 /// @{
287 /// Retrieve the Function with VFShape \p Shape.
getVectorizedFunction(const VFShape & Shape)288 Function *getVectorizedFunction(const VFShape &Shape) const {
289 if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
290 return CI.getCalledFunction();
291
292 for (const auto &Info : ScalarToVectorMappings)
293 if (Info.Shape == Shape)
294 return M->getFunction(Info.VectorName);
295
296 return nullptr;
297 }
298 /// @}
299 };
300
301 template <typename T> class ArrayRef;
302 class DemandedBits;
303 template <typename InstTy> class InterleaveGroup;
304 class IRBuilderBase;
305 class Loop;
306 class ScalarEvolution;
307 class TargetTransformInfo;
308 class Type;
309 class Value;
310
311 namespace Intrinsic {
312 typedef unsigned ID;
313 }
314
315 /// A helper function for converting Scalar types to vector types. If
316 /// the incoming type is void, we return void. If the EC represents a
317 /// scalar, we return the scalar type.
ToVectorTy(Type * Scalar,ElementCount EC)318 inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
319 if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
320 return Scalar;
321 return VectorType::get(Scalar, EC);
322 }
323
ToVectorTy(Type * Scalar,unsigned VF)324 inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
325 return ToVectorTy(Scalar, ElementCount::getFixed(VF));
326 }
327
/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
/// \p ScalarOpdIdx is the index of the operand to query.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);

/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);
347
/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) do the mask analysis to
/// identify which real registers are permuted. Then the function processes
/// resulting registers mask using provided action items. If no input register
/// is defined, \p NoInputAction action is used. If only 1 input register is
/// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used to
/// process > 2 input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
433
/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction*, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
                         DemandedBits &DB,
                         const TargetTransformInfo *TTI=nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for M equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
497
/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates shuffle mask whose elements are sequential and begin
/// at \p Start.  The mask contains \p NumInts integers and is padded with \p
/// NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenate the vectors in \p Vecs into a
/// single large vector. The number of vectors should be greater than one, and
/// their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, Return true if all of the elements of this
/// predicate mask are known to be false or undef.  That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, Return true if all of the elements of this
/// predicate mask are known to be true or undef.  That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
597
/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than interleaved factor, which is equal to the absolute
/// value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  /// Create an empty group with a known factor/direction; members are added
  /// later via insertMember(). The insert position starts out unset.
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  /// Create a group seeded with a single leader instruction \p Instr. The
  /// factor is |Stride| and a negative stride marks the group as reversed.
  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is related to the leader and it could be
  /// negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.contains(Key))
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      // NOTE(review): this branch compares against int64_t while the branch
      // above casts Factor to int32_t — presumably equivalent here since
      // *MaybeLargestIndex fits in int32_t, but confirm the asymmetry is
      // intentional.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index
  ///
  /// \returns nullptr if contains no such member.
  InstTy *getMember(uint32_t Index) const {
    // Keys are stored relative to SmallestKey so that negative insertion
    // indices (new leaders) could be accommodated; translate back here.
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a'la
  /// addNewMedata.  To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilog is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore can't be a reversed access,
    // because such groups get invalidated (TODO).
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last-member
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  // Map from member key (index biased by SmallestKey) to the instruction.
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32  // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};
757
758 /// Drive the analysis of interleaved memory accesses in the loop.
759 ///
760 /// Use this class to analyze interleaved accesses only when we can vectorize
761 /// a loop. Otherwise it's meaningless to do analysis as the vectorization
762 /// on interleaved accesses is unsafe.
763 ///
764 /// The analysis collects interleave groups and records the relationships
765 /// between the member and the group in a map.
766 class InterleavedAccessInfo {
767 public:
InterleavedAccessInfo(PredicatedScalarEvolution & PSE,Loop * L,DominatorTree * DT,LoopInfo * LI,const LoopAccessInfo * LAI)768 InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
769 DominatorTree *DT, LoopInfo *LI,
770 const LoopAccessInfo *LAI)
771 : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
772
~InterleavedAccessInfo()773 ~InterleavedAccessInfo() { invalidateGroups(); }
774
775 /// Analyze the interleaved accesses and collect them in interleave
776 /// groups. Substitute symbolic strides using \p Strides.
777 /// Consider also predicated loads/stores in the analysis if
778 /// \p EnableMaskedInterleavedGroup is true.
779 void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
780
781 /// Invalidate groups, e.g., in case all blocks in loop will be predicated
782 /// contrary to original assumption. Although we currently prevent group
783 /// formation for predicated accesses, we may be able to relax this limitation
784 /// in the future once we handle more complicated blocks. Returns true if any
785 /// groups were invalidated.
invalidateGroups()786 bool invalidateGroups() {
787 if (InterleaveGroups.empty()) {
788 assert(
789 !RequiresScalarEpilogue &&
790 "RequiresScalarEpilog should not be set without interleave groups");
791 return false;
792 }
793
794 InterleaveGroupMap.clear();
795 for (auto *Ptr : InterleaveGroups)
796 delete Ptr;
797 InterleaveGroups.clear();
798 RequiresScalarEpilogue = false;
799 return true;
800 }
801
802 /// Check if \p Instr belongs to any interleave group.
isInterleaved(Instruction * Instr)803 bool isInterleaved(Instruction *Instr) const {
804 return InterleaveGroupMap.contains(Instr);
805 }
806
807 /// Get the interleave group that \p Instr belongs to.
808 ///
809 /// \returns nullptr if doesn't have such group.
810 InterleaveGroup<Instruction> *
getInterleaveGroup(const Instruction * Instr)811 getInterleaveGroup(const Instruction *Instr) const {
812 return InterleaveGroupMap.lookup(Instr);
813 }
814
815 iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
getInterleaveGroups()816 getInterleaveGroups() {
817 return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
818 }
819
820 /// Returns true if an interleaved group that may access memory
821 /// out-of-bounds requires a scalar epilogue iteration for correctness.
requiresScalarEpilogue()822 bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
823
824 /// Invalidate groups that require a scalar epilogue (due to gaps). This can
825 /// happen when optimizing for size forbids a scalar epilogue, and the gap
826 /// cannot be filtered by masking the load/store.
827 void invalidateGroupsRequiringScalarEpilogue();
828
829 /// Returns true if we have any interleave groups.
hasGroups()830 bool hasGroups() const { return !InterleaveGroups.empty(); }
831
832 private:
833 /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
834 /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
835 /// The interleaved access analysis can also add new predicates (for example
836 /// by versioning strides of pointers).
837 PredicatedScalarEvolution &PSE;
838
839 Loop *TheLoop;
840 DominatorTree *DT;
841 LoopInfo *LI;
842 const LoopAccessInfo *LAI;
843
844 /// True if the loop may contain non-reversed interleaved groups with
845 /// out-of-bounds accesses. We ensure we don't speculatively access memory
846 /// out-of-bounds by executing at least one scalar epilogue iteration.
847 bool RequiresScalarEpilogue = false;
848
849 /// Holds the relationships between the members and the interleave group.
850 DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
851
852 SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
853
854 /// Holds dependences among the memory accesses in the loop. It maps a source
855 /// access to a set of dependent sink accesses.
856 DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
857
858 /// The descriptor for a strided memory access.
859 struct StrideDescriptor {
860 StrideDescriptor() = default;
StrideDescriptorStrideDescriptor861 StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
862 Align Alignment)
863 : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
864
865 // The access's stride. It is negative for a reverse access.
866 int64_t Stride = 0;
867
868 // The scalar expression of this access.
869 const SCEV *Scev = nullptr;
870
871 // The size of the memory object.
872 uint64_t Size = 0;
873
874 // The alignment of this access.
875 Align Alignment;
876 };
877
878 /// A type for holding instructions and their stride descriptors.
879 using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
880
881 /// Create a new interleave group with the given instruction \p Instr,
882 /// stride \p Stride and alignment \p Align.
883 ///
884 /// \returns the newly created interleave group.
885 InterleaveGroup<Instruction> *
createInterleaveGroup(Instruction * Instr,int Stride,Align Alignment)886 createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
887 assert(!InterleaveGroupMap.count(Instr) &&
888 "Already in an interleaved access group");
889 InterleaveGroupMap[Instr] =
890 new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
891 InterleaveGroups.insert(InterleaveGroupMap[Instr]);
892 return InterleaveGroupMap[Instr];
893 }
894
895 /// Release the group and remove all the relationships.
releaseGroup(InterleaveGroup<Instruction> * Group)896 void releaseGroup(InterleaveGroup<Instruction> *Group) {
897 for (unsigned i = 0; i < Group->getFactor(); i++)
898 if (Instruction *Member = Group->getMember(i))
899 InterleaveGroupMap.erase(Member);
900
901 InterleaveGroups.erase(Group);
902 delete Group;
903 }
904
905 /// Collect all the accesses with a constant stride in program order.
906 void collectConstStrideAccesses(
907 MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
908 const DenseMap<Value *, const SCEV *> &Strides);
909
910 /// Returns true if \p Stride is allowed in an interleaved group.
911 static bool isStrided(int Stride);
912
913 /// Returns true if \p BB is a predicated block.
isPredicated(BasicBlock * BB)914 bool isPredicated(BasicBlock *BB) const {
915 return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
916 }
917
918 /// Returns true if LoopAccessInfo can be used for dependence queries.
areDependencesValid()919 bool areDependencesValid() const {
920 return LAI && LAI->getDepChecker().getDependences();
921 }
922
923 /// Returns true if memory accesses \p A and \p B can be reordered, if
924 /// necessary, when constructing interleaved groups.
925 ///
926 /// \p A must precede \p B in program order. We return false if reordering is
927 /// not necessary or is prevented because \p A and \p B may be dependent.
canReorderMemAccessesForInterleavedGroups(StrideEntry * A,StrideEntry * B)928 bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
929 StrideEntry *B) const {
930 // Code motion for interleaved accesses can potentially hoist strided loads
931 // and sink strided stores. The code below checks the legality of the
932 // following two conditions:
933 //
934 // 1. Potentially moving a strided load (B) before any store (A) that
935 // precedes B, or
936 //
937 // 2. Potentially moving a strided store (A) after any load or store (B)
938 // that A precedes.
939 //
940 // It's legal to reorder A and B if we know there isn't a dependence from A
941 // to B. Note that this determination is conservative since some
942 // dependences could potentially be reordered safely.
943
944 // A is potentially the source of a dependence.
945 auto *Src = A->first;
946 auto SrcDes = A->second;
947
948 // B is potentially the sink of a dependence.
949 auto *Sink = B->first;
950 auto SinkDes = B->second;
951
952 // Code motion for interleaved accesses can't violate WAR dependences.
953 // Thus, reordering is legal if the source isn't a write.
954 if (!Src->mayWriteToMemory())
955 return true;
956
957 // At least one of the accesses must be strided.
958 if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
959 return true;
960
961 // If dependence information is not available from LoopAccessInfo,
962 // conservatively assume the instructions can't be reordered.
963 if (!areDependencesValid())
964 return false;
965
966 // If we know there is a dependence from source to sink, assume the
967 // instructions can't be reordered. Otherwise, reordering is legal.
968 return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
969 }
970
971 /// Collect the dependences from LoopAccessInfo.
972 ///
973 /// We process the dependences once during the interleaved access analysis to
974 /// enable constant-time dependence queries.
collectDependences()975 void collectDependences() {
976 if (!areDependencesValid())
977 return;
978 auto *Deps = LAI->getDepChecker().getDependences();
979 for (auto Dep : *Deps)
980 Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
981 }
982 };
983
984 } // llvm namespace
985
986 #endif
987