//
// Copyright 2021 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// CLKernelVk.cpp: Implements the class methods for CLKernelVk.

#include "libANGLE/renderer/vulkan/CLKernelVk.h"
#include "libANGLE/renderer/vulkan/CLContextVk.h"
#include "libANGLE/renderer/vulkan/CLDeviceVk.h"
#include "libANGLE/renderer/vulkan/CLProgramVk.h"

#include "libANGLE/CLContext.h"
#include "libANGLE/CLKernel.h"
#include "libANGLE/CLProgram.h"
#include "libANGLE/cl_utils.h"

namespace rx
{

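// Binds the kernel to its owning CLProgramVk/CLContextVk and registers the program's compute
// shader module with the shader program helper.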
CLKernelVk::CLKernelVk(const cl::Kernel &kernel,
                       std::string &name,
                       std::string &attributes,
                       CLKernelArguments &args)
    : CLKernelImpl(kernel),
      mProgram(&kernel.getProgram().getImpl<CLProgramVk>()),
      mContext(&kernel.getProgram().getContext().getImpl<CLContextVk>()),
      mName(name),
      mAttributes(attributes),
      mArgs(args)
{
    mShaderProgramHelper.setShader(gl::ShaderType::Compute,
                                   mKernel.getProgram().getImpl<CLProgramVk>().getShaderModule());
}

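// Releases the descriptor set layouts, pipeline layout, cached compute pipelines, and shader
// program helper owned by this kernel.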
CLKernelVk::~CLKernelVk()
{
    for (auto &dsLayouts : mDescriptorSetLayouts)
    {
        dsLayouts.reset();
    }

    mPipelineLayout.reset();
    for (auto &pipelineHelper : mComputePipelineCache)
    {
        pipelineHelper.destroy(mContext->getDevice());
    }
    mShaderProgramHelper.destroy(mContext->getRenderer());
}

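// Stores the argument value and size for use at dispatch time; arguments not marked as used
// are ignored.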
angle::Result CLKernelVk::setArg(cl_uint argIndex, size_t argSize, const void *argValue)
{
    auto &arg = mArgs.at(argIndex);
    if (arg.used)
    {
        arg.handle     = const_cast<void *>(argValue);
        arg.handleSize = argSize;
    }

    return angle::Result::Continue;
}

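// Populates the CLKernelImpl::Info structure with the kernel name, attributes, argument
// metadata, and per-device work-group properties.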
angle::Result CLKernelVk::createInfo(CLKernelImpl::Info *info) const
{
    info->functionName = mName;
    info->attributes   = mAttributes;
    info->numArgs      = static_cast<cl_uint>(mArgs.size());
    for (const auto &arg : mArgs)
    {
        ArgInfo argInfo;
        argInfo.name             = arg.info.name;
        argInfo.typeName         = arg.info.typeName;
        argInfo.accessQualifier  = arg.info.accessQualifier;
        argInfo.addressQualifier = arg.info.addressQualifier;
        argInfo.typeQualifier    = arg.info.typeQualifier;
        info->args.push_back(std::move(argInfo));
    }

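    // Gather work-group related properties for each device in the context.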
    auto &ctx = mKernel.getProgram().getContext();
    info->workGroups.resize(ctx.getDevices().size());
    const CLProgramVk::DeviceProgramData *deviceProgramData = nullptr;
    for (auto i = 0u; i < ctx.getDevices().size(); ++i)
    {
        auto &workGroup     = info->workGroups[i];
        const auto deviceVk = &ctx.getDevices()[i]->getImpl<CLDeviceVk>();
        deviceProgramData   = mProgram->getDeviceProgramData(ctx.getDevices()[i]->getNative());
        if (deviceProgramData == nullptr)
        {
            continue;
        }

        // TODO: http://anglebug.com/8576
        ANGLE_TRY(
            deviceVk->getInfoSizeT(cl::DeviceInfo::MaxWorkGroupSize, &workGroup.workGroupSize));

        // TODO: http://anglebug.com/8575
        workGroup.privateMemSize = 0;
        workGroup.localMemSize   = 0;

        workGroup.prefWorkGroupSizeMultiple = 16u;
        workGroup.globalWorkSize            = {0, 0, 0};
        if (deviceProgramData->reflectionData.kernelCompileWorkgroupSize.contains(mName))
        {
            workGroup.compileWorkGroupSize = {
                deviceProgramData->reflectionData.kernelCompileWorkgroupSize.at(mName)[0],
                deviceProgramData->reflectionData.kernelCompileWorkgroupSize.at(mName)[1],
                deviceProgramData->reflectionData.kernelCompileWorkgroupSize.at(mName)[2]};
        }
        else
        {
            workGroup.compileWorkGroupSize = {0, 0, 0};
        }
    }

    return angle::Result::Continue;
}

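// Determines the effective workgroup size and workgroup count for the given NDRange, sets up
// workgroup-size specialization constants when needed, and returns a compute pipeline from the
// cache (creating one on a cache miss).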
angle::Result CLKernelVk::getOrCreateComputePipeline(vk::PipelineCacheAccess *pipelineCache,
                                                     const cl::NDRange &ndrange,
                                                     const cl::Device &device,
                                                     vk::PipelineHelper **pipelineOut,
                                                     cl::WorkgroupCount *workgroupCountOut)
{
    uint32_t constantDataOffset = 0;
    angle::FixedVector<size_t, 3> specConstantData;
    angle::FixedVector<VkSpecializationMapEntry, 3> mapEntries;
    const CLProgramVk::DeviceProgramData *devProgramData =
        getProgram()->getDeviceProgramData(device.getNative());
    ASSERT(devProgramData != nullptr);

    // Start with Workgroup size (WGS) from kernel attribute (if available)
    cl::WorkgroupSize workgroupSize = devProgramData->getCompiledWorkgroupSize(getKernelName());

    if (workgroupSize == kEmptyWorkgroupSize)
    {
        if (ndrange.nullLocalWorkSize)
        {
            // NULL value was passed, in which case the OpenCL implementation will determine
            // how to break the global work-items into appropriate work-group instances.
            workgroupSize = device.getImpl<CLDeviceVk>().selectWorkGroupSize(ndrange);
        }
        else
        {
            // Local work size (LWS) was valid, use that as WGS
            workgroupSize = ndrange.localWorkSize;
        }

        // If at least one of the kernels does not use the reqd_work_group_size attribute, the
        // Vulkan SPIR-V produced by the compiler will contain specialization constants
        const std::array<uint32_t, 3> &specConstantWorkgroupSizeIDs =
            devProgramData->reflectionData.specConstantWorkgroupSizeIDs;
        ASSERT(ndrange.workDimensions <= 3);
        for (cl_uint i = 0; i < ndrange.workDimensions; ++i)
        {
            mapEntries.push_back(
                VkSpecializationMapEntry{.constantID = specConstantWorkgroupSizeIDs.at(i),
                                         .offset     = constantDataOffset,
                                         .size       = sizeof(uint32_t)});
            constantDataOffset += sizeof(uint32_t);
            specConstantData.push_back(workgroupSize[i]);
        }
    }

    // Calculate the workgroup count
    // TODO: Add support for non-uniform WGS
    // http://anglebug.com/8631
    ASSERT(workgroupSize[0] != 0);
    ASSERT(workgroupSize[1] != 0);
    ASSERT(workgroupSize[2] != 0);
    (*workgroupCountOut)[0] = static_cast<uint32_t>((ndrange.globalWorkSize[0] / workgroupSize[0]));
    (*workgroupCountOut)[1] = static_cast<uint32_t>((ndrange.globalWorkSize[1] / workgroupSize[1]));
    (*workgroupCountOut)[2] = static_cast<uint32_t>((ndrange.globalWorkSize[2] / workgroupSize[2]));

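    // Package the workgroup-size specialization constant data (if any) for pipeline creation.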
    VkSpecializationInfo computeSpecializationInfo{
        .mapEntryCount = static_cast<uint32_t>(mapEntries.size()),
        .pMapEntries   = mapEntries.data(),
        .dataSize      = specConstantData.size() * sizeof(specConstantData[0]),
        .pData         = specConstantData.data(),
    };

    // Now get or create (on compute pipeline cache miss) compute pipeline and return it
    return mShaderProgramHelper.getOrCreateComputePipeline(
        mContext, &mComputePipelineCache, pipelineCache, getPipelineLayout().get(),
        vk::ComputePipelineFlags{}, PipelineSource::Draw, pipelineOut, mName.c_str(),
        &computeSpecializationInfo);
}

}  // namespace rx