//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {

/// Describes the type of Parameters
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};

/// Describes the type of Instruction Set Architecture
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};

/// Encapsulates information needed to describe a parameter.
///
/// The description of the parameter is not linked directly to
/// OpenMP or any other vector function description. This structure
/// is extendible to handle other paradigms that describe vector
/// functions and their parameters.
struct VFParameter {
  unsigned ParamPos;         // Parameter Position in Scalar Function.
  VFParamKind ParamKind;     // Kind of Parameter.
  int LinearStepOrPos = 0;   // Step or Position of the Parameter.
  Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.

  // Comparison operator.
  bool operator==(const VFParameter &Other) const {
    return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
           std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
                    Other.Alignment);
  }
};
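
// Example (illustrative only, not part of the interface): a plausible
// description of the second parameter of a function declared as
//
//   #pragma omp declare simd linear(i:4)
//   double foo(double x, int i);
//
// under the scheme above would be
//
//   VFParameter P = {/*ParamPos=*/1, VFParamKind::OMP_Linear,
//                    /*LinearStepOrPos=*/4};
//
// For the *Pos kinds, LinearStepOrPos instead holds the position of the
// uniform parameter that carries the step at run time.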

/// Contains the information about the kind of vectorization
/// available.
///
/// This object is independent of the paradigm used to
/// represent vector functions. In particular, it is not attached to
/// any target-specific ABI.
struct VFShape {
  unsigned VF;     // Vectorization factor.
  bool IsScalable; // True if the function is a scalable function.
  SmallVector<VFParameter, 8> Parameters; // List of parameter information.
  // Comparison operator.
  bool operator==(const VFShape &Other) const {
    return std::tie(VF, IsScalable, Parameters) ==
           std::tie(Other.VF, Other.IsScalable, Other.Parameters);
  }

  /// Update the parameter in position P.ParamPos to P.
  void updateParam(VFParameter P) {
    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
    Parameters[P.ParamPos] = P;
    assert(hasValidParameterList() && "Invalid parameter list");
  }

  // Retrieve the basic vectorization shape of the function, where all
  // parameters are mapped to VFParamKind::Vector with \p EC
  // lanes. Specifies whether the function has a Global Predicate
  // argument via \p HasGlobalPred.
  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
    SmallVector<VFParameter, 8> Parameters;
    for (unsigned I = 0; I < CI.arg_size(); ++I)
      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
    if (HasGlobalPred)
      Parameters.push_back(
          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));

    return {EC.Min, EC.Scalable, Parameters};
  }
  /// Sanity check on the Parameters in the VFShape.
  bool hasValidParameterList() const;
};

/// Holds the VFShape for a specific scalar to vector function mapping.
struct VFInfo {
  VFShape Shape;        // Classification of the vector function.
  StringRef ScalarName; // Scalar Function Name.
  StringRef VectorName; // Vector Function Name associated to this VFInfo.
  VFISAKind ISA;        // Instruction Set Architecture.

  // Comparison operator.
  bool operator==(const VFInfo &Other) const {
    return std::tie(Shape, ScalarName, VectorName, ISA) ==
           std::tie(Other.Shape, Other.ScalarName, Other.VectorName, Other.ISA);
  }
};

namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
/// <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This method supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
///  https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName);
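
// Example (illustrative; based on the AArch64 mangling rules referenced
// above): for the mangled name "_ZGVnN2v_foo" (<isa> = 'n' i.e. Advanced
// SIMD, not masked, 2 lanes, one vector parameter), demangling is expected
// to produce a VFInfo with ISA == VFISAKind::AdvancedSIMD, a 2-lane
// non-scalable shape, and ScalarName == "foo":
//
//   if (Optional<VFInfo> Info = VFABI::tryDemangleForVFABI("_ZGVnN2v_foo"))
//     assert(Info->Shape.VF == 2 && !Info->Shape.IsScalable);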

/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI.
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);
} // end namespace VFABI

template <typename T> class ArrayRef;
class DemandedBits;
class GetElementPtrInst;
template <typename InstTy> class InterleaveGroup;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// hasVectorInstrinsicScalarOpd) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx);

/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if no such mapping is found, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned getGEPInductionOperand(const GetElementPtrInst *Gep);

/// If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// If a value has only one user that is a CastInst, return it.
Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty);

/// Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
const Value *getSplatValue(const Value *V);

/// Return true if the input value is known to be a vector with all identical
/// elements (potentially including undefined elements).
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, unsigned Depth = 0);
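
// Example (illustrative IR for the two functions above): given the usual
// insertelement/shufflevector broadcast idiom
//
//   %ins   = insertelement <4 x float> undef, float %x, i32 0
//   %splat = shufflevector <4 x float> %ins, <4 x float> undef,
//                          <4 x i32> zeroinitializer
//
// getSplatValue(%splat) is expected to return %x, and isSplatValue(%splat)
// to return true.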

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthening and truncating operations when deciding the
/// profitability of vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction*, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
                         DemandedBits &DB,
                         const TargetTransformInfo *TTI=nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
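
// Example (illustrative sketch of how a vectorizer typically uses the
// function above; `NewLoad` and `Load0`..`Load3` are assumed names for an
// already-created widened load and the scalar loads it replaces):
//
//   SmallVector<Value *, 4> ScalarLoads = {Load0, Load1, Load2, Load3};
//   // Transfer the combinable metadata kinds listed above onto the new
//   // instruction.
//   propagateMetadata(NewLoad, ScalarLoads);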

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilder<> &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose interleave factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
Constant *createReplicatedMask(IRBuilder<> &Builder, unsigned ReplicationFactor,
                               unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
Constant *createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
                               unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
Constant *createStrideMask(IRBuilder<> &Builder, unsigned Start,
                           unsigned Stride, unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and begin
/// at \p Start. The mask contains \p NumInts integers and is padded with \p
/// NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
Constant *createSequentialMask(IRBuilder<> &Builder, unsigned Start,
                               unsigned NumInts, unsigned NumUndefs);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into a
/// single large vector. The number of vectors should be greater than one, and
/// their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs);
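
// Example (illustrative sketch of how the shuffle-mask helpers above are
// typically combined; `Builder`, `V0` and `V1` are assumed to be an existing
// IRBuilder<> and two existing <4 x i32> values):
//
//   // Interleave the lanes of two <4 x i32> vectors into one <8 x i32>.
//   Value *Interleaved = Builder.CreateShuffleVector(
//       V0, V1, createInterleaveMask(Builder, /*VF=*/4, /*NumVecs=*/2));
//   // Recover the even lanes (i.e. V0) again with a stride-2 mask.
//   Value *Even = Builder.CreateShuffleVector(
//       Interleaved, UndefValue::get(Interleaved->getType()),
//       createStrideMask(Builder, /*Start=*/0, /*Stride=*/2, /*VF=*/4));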

/// Given a mask vector of the form <Y x i1>, return true if all of the
/// elements of this predicate mask are false or undef. That is, return true
/// if all lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return true if all of the
/// elements of this predicate mask are true or undef. That is, return true
/// if all lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  uint32_t getAlignment() const { return Alignment.value(); }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is related to the leader and it could be
  /// negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if there is already a member with the same index.
    if (Members.find(Key) != Members.end())
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      Optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }
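
  // Example (illustrative): for a group created from the access to A[i+1]
  // with Stride == 3, the access to A[i+2] is inserted with Index == 1, and
  // the access to A[i] with Index == -1; the latter becomes the new leader,
  // so afterwards getIndex() reports 0 for A[i], 1 for A[i+1] and 2 for
  // A[i+2].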

  /// Get the member with the given index \p Index
  ///
  /// \returns nullptr if contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    auto Member = Members.find(Key);
    if (Member == Members.end())
      return nullptr;

    return Member->second;
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a'la
  /// addNewMetadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilog is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore cannot be a group of stores,
    // and it can't be a reversed access, because such groups get invalidated.
    assert(!getMember(0)->mayWriteToMemory() &&
           "Group should have been invalidated");
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};
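
// Example (illustrative sketch; `LoadA0` and `LoadA1` are assumed to be the
// member instructions of a factor-2 load group, leader first in program
// order):
//
//   InterleaveGroup<Instruction> Group(LoadA0, /*Stride=*/2, Align(4));
//   if (Group.insertMember(LoadA1, /*Index=*/1, Align(4)))
//     assert(Group.getFactor() == 2 && Group.getMember(1) == LoadA1 &&
//            !Group.requiresScalarEpilogue());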

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { reset(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in the loop will be
  /// predicated contrary to the original assumption. Although we currently
  /// prevent group formation for predicated accesses, we may be able to relax
  /// this limitation in the future once we handle more complicated blocks.
  void reset() {
    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end();
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if \p Instr does not belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    if (InterleaveGroupMap.count(Instr))
      return InterleaveGroupMap.find(Instr)->second;
    return nullptr;
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();
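
  // Example (illustrative sketch of typical use by a client such as a loop
  // vectorizer; all names other than the members of this class are
  // assumptions):
  //
  //   InterleavedAccessInfo IAI(PSE, TheLoop, DT, LI, LAI);
  //   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
  //   if (auto *Group = IAI.getInterleaveGroup(SomeMemInstr))
  //     ; // widen all members of Group at Group->getInsertPos()
  //   if (IAI.requiresScalarEpilogue())
  //     ; // keep a scalar tail iteration for groups with gaps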

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Alignment.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    InterleaveGroups.erase(Group);
    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const ValueToValueMap &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }
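
  // Illustrative example (an assumed scenario, not taken from the code
  // below): in
  //
  //   for (i = 0; i < n; i++) {
  //     x = A[2*i];      // strided load, candidate group member
  //     B[i] = x;        // intervening store
  //     y = A[2*i + 1];  // strided load, candidate group member
  //   }
  //
  // widening the {A[2*i], A[2*i+1]} group hoists the second load above the
  // store to B[i]; canReorderMemAccessesForInterleavedGroups() (below) allows
  // this only if LoopAccessInfo reports no dependence from the store to the
  // load.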

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return Dependences.find(Src) == Dependences.end() ||
           !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    auto *Deps = LAI->getDepChecker().getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
  }
};

} // llvm namespace

#endif