1 // Copyright 2018 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/vulkan/compute_pipeline.h"
16
17 #include "src/vulkan/command_pool.h"
18 #include "src/vulkan/device.h"
19
20 namespace amber {
21 namespace vulkan {
22
ComputePipeline(Device * device,uint32_t fence_timeout_ms,const std::vector<VkPipelineShaderStageCreateInfo> & shader_stage_info)23 ComputePipeline::ComputePipeline(
24 Device* device,
25 uint32_t fence_timeout_ms,
26 const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info)
27 : Pipeline(PipelineType::kCompute,
28 device,
29 fence_timeout_ms,
30 shader_stage_info) {}
31
32 ComputePipeline::~ComputePipeline() = default;
33
Initialize(CommandPool * pool)34 Result ComputePipeline::Initialize(CommandPool* pool) {
35 return Pipeline::Initialize(pool);
36 }
37
CreateVkComputePipeline(const VkPipelineLayout & pipeline_layout,VkPipeline * pipeline)38 Result ComputePipeline::CreateVkComputePipeline(
39 const VkPipelineLayout& pipeline_layout,
40 VkPipeline* pipeline) {
41 auto shader_stage_info = GetVkShaderStageInfo();
42 if (shader_stage_info.size() != 1) {
43 return Result(
44 "Vulkan::CreateVkComputePipeline number of shaders given to compute "
45 "pipeline is not 1");
46 }
47
48 if (shader_stage_info[0].stage != VK_SHADER_STAGE_COMPUTE_BIT)
49 return Result("Vulkan: Non compute shader for compute pipeline");
50
51 shader_stage_info[0].pName = GetEntryPointName(VK_SHADER_STAGE_COMPUTE_BIT);
52
53 VkComputePipelineCreateInfo pipeline_info = VkComputePipelineCreateInfo();
54 pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
55 pipeline_info.stage = shader_stage_info[0];
56 pipeline_info.layout = pipeline_layout;
57
58 if (device_->GetPtrs()->vkCreateComputePipelines(
59 device_->GetVkDevice(), VK_NULL_HANDLE, 1, &pipeline_info, nullptr,
60 pipeline) != VK_SUCCESS) {
61 return Result("Vulkan::Calling vkCreateComputePipelines Fail");
62 }
63
64 return {};
65 }
66
Compute(uint32_t x,uint32_t y,uint32_t z)67 Result ComputePipeline::Compute(uint32_t x, uint32_t y, uint32_t z) {
68 Result r = SendDescriptorDataToDeviceIfNeeded();
69 if (!r.IsSuccess())
70 return r;
71
72 VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;
73 r = CreateVkPipelineLayout(&pipeline_layout);
74 if (!r.IsSuccess())
75 return r;
76
77 VkPipeline pipeline = VK_NULL_HANDLE;
78 r = CreateVkComputePipeline(pipeline_layout, &pipeline);
79 if (!r.IsSuccess())
80 return r;
81
82 // Note that a command updating a descriptor set and a command using
83 // it must be submitted separately, because using a descriptor set
84 // while updating it is not safe.
85 UpdateDescriptorSetsIfNeeded();
86
87 {
88 CommandBufferGuard guard(GetCommandBuffer());
89 if (!guard.IsRecording())
90 return guard.GetResult();
91
92 BindVkDescriptorSets(pipeline_layout);
93
94 r = RecordPushConstant(pipeline_layout);
95 if (!r.IsSuccess())
96 return r;
97
98 device_->GetPtrs()->vkCmdBindPipeline(command_->GetVkCommandBuffer(),
99 VK_PIPELINE_BIND_POINT_COMPUTE,
100 pipeline);
101 device_->GetPtrs()->vkCmdDispatch(command_->GetVkCommandBuffer(), x, y, z);
102
103 r = guard.Submit(GetFenceTimeout());
104 if (!r.IsSuccess())
105 return r;
106 }
107
108 r = ReadbackDescriptorsToHostDataQueue();
109 if (!r.IsSuccess())
110 return r;
111
112 device_->GetPtrs()->vkDestroyPipeline(device_->GetVkDevice(), pipeline,
113 nullptr);
114 device_->GetPtrs()->vkDestroyPipelineLayout(device_->GetVkDevice(),
115 pipeline_layout, nullptr);
116
117 return {};
118 }
119
120 } // namespace vulkan
121 } // namespace amber
122