• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2018 The Amber Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "src/vulkan/compute_pipeline.h"
16 
17 #include "src/vulkan/command_pool.h"
18 #include "src/vulkan/device.h"
19 
20 namespace amber {
21 namespace vulkan {
22 
ComputePipeline(Device * device,uint32_t fence_timeout_ms,const std::vector<VkPipelineShaderStageCreateInfo> & shader_stage_info)23 ComputePipeline::ComputePipeline(
24     Device* device,
25     uint32_t fence_timeout_ms,
26     const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info)
27     : Pipeline(PipelineType::kCompute,
28                device,
29                fence_timeout_ms,
30                shader_stage_info) {}
31 
32 ComputePipeline::~ComputePipeline() = default;
33 
Initialize(CommandPool * pool)34 Result ComputePipeline::Initialize(CommandPool* pool) {
35   return Pipeline::Initialize(pool);
36 }
37 
CreateVkComputePipeline(const VkPipelineLayout & pipeline_layout,VkPipeline * pipeline)38 Result ComputePipeline::CreateVkComputePipeline(
39     const VkPipelineLayout& pipeline_layout,
40     VkPipeline* pipeline) {
41   auto shader_stage_info = GetVkShaderStageInfo();
42   if (shader_stage_info.size() != 1) {
43     return Result(
44         "Vulkan::CreateVkComputePipeline number of shaders given to compute "
45         "pipeline is not 1");
46   }
47 
48   if (shader_stage_info[0].stage != VK_SHADER_STAGE_COMPUTE_BIT)
49     return Result("Vulkan: Non compute shader for compute pipeline");
50 
51   shader_stage_info[0].pName = GetEntryPointName(VK_SHADER_STAGE_COMPUTE_BIT);
52 
53   VkComputePipelineCreateInfo pipeline_info = VkComputePipelineCreateInfo();
54   pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
55   pipeline_info.stage = shader_stage_info[0];
56   pipeline_info.layout = pipeline_layout;
57 
58   if (device_->GetPtrs()->vkCreateComputePipelines(
59           device_->GetVkDevice(), VK_NULL_HANDLE, 1, &pipeline_info, nullptr,
60           pipeline) != VK_SUCCESS) {
61     return Result("Vulkan::Calling vkCreateComputePipelines Fail");
62   }
63 
64   return {};
65 }
66 
Compute(uint32_t x,uint32_t y,uint32_t z)67 Result ComputePipeline::Compute(uint32_t x, uint32_t y, uint32_t z) {
68   Result r = SendDescriptorDataToDeviceIfNeeded();
69   if (!r.IsSuccess())
70     return r;
71 
72   VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;
73   r = CreateVkPipelineLayout(&pipeline_layout);
74   if (!r.IsSuccess())
75     return r;
76 
77   VkPipeline pipeline = VK_NULL_HANDLE;
78   r = CreateVkComputePipeline(pipeline_layout, &pipeline);
79   if (!r.IsSuccess())
80     return r;
81 
82   // Note that a command updating a descriptor set and a command using
83   // it must be submitted separately, because using a descriptor set
84   // while updating it is not safe.
85   UpdateDescriptorSetsIfNeeded();
86 
87   {
88     CommandBufferGuard guard(GetCommandBuffer());
89     if (!guard.IsRecording())
90       return guard.GetResult();
91 
92     BindVkDescriptorSets(pipeline_layout);
93 
94     r = RecordPushConstant(pipeline_layout);
95     if (!r.IsSuccess())
96       return r;
97 
98     device_->GetPtrs()->vkCmdBindPipeline(command_->GetVkCommandBuffer(),
99                                           VK_PIPELINE_BIND_POINT_COMPUTE,
100                                           pipeline);
101     device_->GetPtrs()->vkCmdDispatch(command_->GetVkCommandBuffer(), x, y, z);
102 
103     r = guard.Submit(GetFenceTimeout());
104     if (!r.IsSuccess())
105       return r;
106   }
107 
108   r = ReadbackDescriptorsToHostDataQueue();
109   if (!r.IsSuccess())
110     return r;
111 
112   device_->GetPtrs()->vkDestroyPipeline(device_->GetVkDevice(), pipeline,
113                                         nullptr);
114   device_->GetPtrs()->vkDestroyPipelineLayout(device_->GetVkDevice(),
115                                               pipeline_layout, nullptr);
116 
117   return {};
118 }
119 
120 }  // namespace vulkan
121 }  // namespace amber
122