• Home
  • Raw
  • Download

Lines Matching refs:spirv

39     LabelT label, const spirv::TargetEnv &targetEnv,  in checkExtensionRequirements()
40 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { in checkExtensionRequirements()
46 for (spirv::Extension ext : ors) in checkExtensionRequirements()
47 extStrings.push_back(spirv::stringifyExtension(ext)); in checkExtensionRequirements()
66 LabelT label, const spirv::TargetEnv &targetEnv, in checkCapabilityRequirements()
67 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { in checkCapabilityRequirements()
73 for (spirv::Capability cap : ors) in checkCapabilityRequirements()
74 capStrings.push_back(spirv::stringifyCapability(cap)); in checkCapabilityRequirements()
110 MAP_FN(spirv::StorageClass::Generic, 1) \
111 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
112 MAP_FN(spirv::StorageClass::Workgroup, 3) \
113 MAP_FN(spirv::StorageClass::Uniform, 4) \
114 MAP_FN(spirv::StorageClass::Private, 5) \
115 MAP_FN(spirv::StorageClass::Function, 6) \
116 MAP_FN(spirv::StorageClass::PushConstant, 7) \
117 MAP_FN(spirv::StorageClass::UniformConstant, 8) \
118 MAP_FN(spirv::StorageClass::Input, 9) \
119 MAP_FN(spirv::StorageClass::Output, 10) \
120 MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
121 MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
122 MAP_FN(spirv::StorageClass::Image, 13) \
123 MAP_FN(spirv::StorageClass::CallableDataNV, 14) \
124 MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15) \
125 MAP_FN(spirv::StorageClass::RayPayloadNV, 16) \
126 MAP_FN(spirv::StorageClass::HitAttributeNV, 17) \
127 MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18) \
128 MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19) \
129 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)
132 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { in getMemorySpaceForStorageClass()
142 Optional<spirv::StorageClass>
161 if (t.isa<spirv::ScalarType>()) { in getTypeNumBytes()
234 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type, in convertScalarType()
235 Optional<spirv::StorageClass> storageClass = {}) { in convertScalarType()
237 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
238 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
270 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type, in convertVectorType()
271 Optional<spirv::StorageClass> storageClass = {}) { in convertVectorType()
272 if (!spirv::CompositeType::isValid(type)) {
282 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
283 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
284 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
285 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
293 targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
305 static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv, in convertTensorType()
314 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); in convertTensorType()
340 return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); in convertTensorType()
343 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv, in convertMemrefType()
345 Optional<spirv::StorageClass> storageClass = in convertMemrefType()
357 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { in convertMemrefType()
377 auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize); in convertMemrefType()
379 auto structType = spirv::StructType::get(arrayType, 0); in convertMemrefType()
380 return spirv::PointerType::get(structType, *storageClass); in convertMemrefType()
400 spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); in convertMemrefType()
404 auto structType = *storageClass == spirv::StorageClass::Workgroup in convertMemrefType()
405 ? spirv::StructType::get(arrayType) in convertMemrefType()
406 : spirv::StructType::get(arrayType, 0); in convertMemrefType()
407 return spirv::PointerType::get(structType, *storageClass); in convertMemrefType()
410 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr) in SPIRVTypeConverter()
425 addConversion([](spirv::SPIRVType type) { return type; }); in SPIRVTypeConverter()
432 if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) in SPIRVTypeConverter()
438 if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) in SPIRVTypeConverter()
490 auto newFuncOp = rewriter.create<spirv::FuncOp>( in matchAndRewrite()
521 static spirv::GlobalVariableOp getBuiltinVariable(Block &body, in getBuiltinVariable()
522 spirv::BuiltIn builtin) { in getBuiltinVariable()
525 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { in getBuiltinVariable()
527 spirv::SPIRVDialect::getAttributeName( in getBuiltinVariable()
528 spirv::Decoration::BuiltIn))) { in getBuiltinVariable()
529 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); in getBuiltinVariable()
539 static std::string getBuiltinVarName(spirv::BuiltIn builtin) { in getBuiltinVarName()
544 static spirv::GlobalVariableOp
545 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, in getOrInsertBuiltinVariable()
553 spirv::GlobalVariableOp newVarOp; in getOrInsertBuiltinVariable()
555 case spirv::BuiltIn::NumWorkgroups: in getOrInsertBuiltinVariable()
556 case spirv::BuiltIn::WorkgroupSize: in getOrInsertBuiltinVariable()
557 case spirv::BuiltIn::WorkgroupId: in getOrInsertBuiltinVariable()
558 case spirv::BuiltIn::LocalInvocationId: in getOrInsertBuiltinVariable()
559 case spirv::BuiltIn::GlobalInvocationId: { in getOrInsertBuiltinVariable()
560 auto ptrType = spirv::PointerType::get( in getOrInsertBuiltinVariable()
562 spirv::StorageClass::Input); in getOrInsertBuiltinVariable()
565 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); in getOrInsertBuiltinVariable()
568 case spirv::BuiltIn::SubgroupId: in getOrInsertBuiltinVariable()
569 case spirv::BuiltIn::NumSubgroups: in getOrInsertBuiltinVariable()
570 case spirv::BuiltIn::SubgroupSize: { in getOrInsertBuiltinVariable()
571 auto ptrType = spirv::PointerType::get(builder.getIntegerType(32), in getOrInsertBuiltinVariable()
572 spirv::StorageClass::Input); in getOrInsertBuiltinVariable()
575 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); in getOrInsertBuiltinVariable()
585 Value mlir::spirv::getBuiltinVariableValue(Operation *op, in getBuiltinVariableValue()
586 spirv::BuiltIn builtin, in getBuiltinVariableValue()
594 spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable( in getBuiltinVariableValue()
596 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); in getBuiltinVariableValue()
597 return builder.create<spirv::LoadOp>(op->getLoc(), ptr); in getBuiltinVariableValue()
604 spirv::AccessChainOp mlir::spirv::getElementPtr( in getElementPtr()
621 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); in getElementPtr()
630 Value ptrLoc = builder.create<spirv::ConstantOp>( in getElementPtr()
635 Value strideVal = builder.create<spirv::ConstantOp>( in getElementPtr()
638 builder.create<spirv::IMulOp>(loc, strideVal, index.value()); in getElementPtr()
639 ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update); in getElementPtr()
643 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); in getElementPtr()
651 mlir::spirv::setABIAttrs(spirv::FuncOp funcOp, in setABIAttrs()
652 spirv::EntryPointABIAttr entryPointInfo, in setABIAttrs()
653 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) { in setABIAttrs()
655 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); in setABIAttrs()
659 funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo); in setABIAttrs()
667 std::unique_ptr<spirv::SPIRVConversionTarget>
668 spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { in get()
680 spirv::SPIRVConversionTarget::SPIRVConversionTarget( in SPIRVConversionTarget()
681 spirv::TargetEnvAttr targetAttr) in SPIRVConversionTarget()
684 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { in isLegalOp()
688 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) in isLegalOp()
692 << spirv::stringifyVersion(minVersion.getMinVersion()) in isLegalOp()
696 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) in isLegalOp()
700 << spirv::stringifyVersion(maxVersion.getMaxVersion()) in isLegalOp()
708 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) in isLegalOp()
716 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) in isLegalOp()
727 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) in isLegalOp()
732 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; in isLegalOp()
733 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; in isLegalOp()
736 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); in isLegalOp()
742 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); in isLegalOp()