#include <ATen/native/vulkan/api/Descriptor.h>
#include <ATen/native/vulkan/api/Utils.h>

#include <algorithm>
#include <utility>

namespace at {
namespace native {
namespace vulkan {
namespace api {

//
// DescriptorSet
//

DescriptorSet::DescriptorSet(
    VkDevice device,
    VkDescriptorSet handle,
    ShaderLayout::Signature shader_layout_signature)
    : device_(device),
      handle_(handle),
      shader_layout_signature_(std::move(shader_layout_signature)),
      bindings_{} {}

DescriptorSet::DescriptorSet(DescriptorSet&& other) noexcept
    : device_(other.device_),
      handle_(other.handle_),
      shader_layout_signature_(std::move(other.shader_layout_signature_)),
      bindings_(std::move(other.bindings_)) {
  other.handle_ = VK_NULL_HANDLE;
}

DescriptorSet& DescriptorSet::operator=(DescriptorSet&& other) noexcept {
  device_ = other.device_;
  handle_ = other.handle_;
  shader_layout_signature_ = std::move(other.shader_layout_signature_);
  bindings_ = std::move(other.bindings_);

  other.handle_ = VK_NULL_HANDLE;

  return *this;
}

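// Record a buffer resource at binding slot `idx`. The descriptor type is
// taken from the shader layout signature; the actual descriptor write is
// deferred until get_bind_handle() is called.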
DescriptorSet& DescriptorSet::bind(
    const uint32_t idx,
    const VulkanBuffer& buffer) {
  VK_CHECK_COND(
      buffer.has_memory(),
      "Buffer must be bound to memory for it to be usable");

  DescriptorSet::ResourceBinding binder{};
  binder.binding_idx = idx; // binding_idx
  binder.descriptor_type = shader_layout_signature_[idx]; // descriptor_type
  binder.is_image = false; // is_image
  binder.resource_info.buffer_info.buffer = buffer.handle(); // buffer
  binder.resource_info.buffer_info.offset = buffer.mem_offset(); // offset
  binder.resource_info.buffer_info.range = buffer.mem_range(); // range
  add_binding(binder);

  return *this;
}

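// Record an image resource at binding slot `idx`. Storage images must be
// bound in VK_IMAGE_LAYOUT_GENERAL, so the layout is overridden for that
// descriptor type.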
DescriptorSet& DescriptorSet::bind(
    const uint32_t idx,
    const VulkanImage& image) {
  VK_CHECK_COND(
      image.has_memory(), "Image must be bound to memory for it to be usable");

  VkImageLayout binding_layout = image.layout();
  if (shader_layout_signature_[idx] == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE) {
    binding_layout = VK_IMAGE_LAYOUT_GENERAL;
  }

  DescriptorSet::ResourceBinding binder{};
  binder.binding_idx = idx; // binding_idx
  binder.descriptor_type = shader_layout_signature_[idx]; // descriptor_type
  binder.is_image = true; // is_image
  binder.resource_info.image_info.sampler = image.sampler(); // sampler
  binder.resource_info.image_info.imageView = image.image_view(); // imageView
  binder.resource_info.image_info.imageLayout = binding_layout; // imageLayout
  add_binding(binder);

  return *this;
}

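// Flush all recorded bindings to the underlying VkDescriptorSet via
// vkUpdateDescriptorSets, then return the raw handle so it can be bound
// with vkCmdBindDescriptorSets.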
VkDescriptorSet DescriptorSet::get_bind_handle() const {
  std::vector<VkWriteDescriptorSet> write_descriptor_sets;

  for (const ResourceBinding& binding : bindings_) {
    VkWriteDescriptorSet write{
        VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET, // sType
        nullptr, // pNext
        handle_, // dstSet
        binding.binding_idx, // dstBinding
        0u, // dstArrayElement
        1u, // descriptorCount
        binding.descriptor_type, // descriptorType
        nullptr, // pImageInfo
        nullptr, // pBufferInfo
        nullptr, // pTexelBufferView
    };

    if (binding.is_image) {
      write.pImageInfo = &binding.resource_info.image_info;
    } else {
      write.pBufferInfo = &binding.resource_info.buffer_info;
    }

    write_descriptor_sets.emplace_back(write);
  }

  vkUpdateDescriptorSets(
      device_,
      utils::safe_downcast<uint32_t>(write_descriptor_sets.size()),
      write_descriptor_sets.data(),
      0u,
      nullptr);

  VkDescriptorSet ret = handle_;

  return ret;
}

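// Append the binding, or overwrite a previously recorded binding that
// targets the same binding index.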
void DescriptorSet::add_binding(const ResourceBinding& binding) {
  const auto bindings_itr = std::find_if(
      bindings_.begin(),
      bindings_.end(),
      [binding_idx = binding.binding_idx](const ResourceBinding& other) {
        return other.binding_idx == binding_idx;
      });

  if (bindings_.end() == bindings_itr) {
    bindings_.emplace_back(binding);
  } else {
    *bindings_itr = binding;
  }
}

//
// DescriptorSetPile
//

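// A fixed-size batch of descriptor sets allocated from a single pool, all
// sharing the same layout. Sets are handed out one at a time and a fresh
// batch is allocated once the current one is exhausted.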
DescriptorSetPile::DescriptorSetPile(
    const uint32_t pile_size,
    VkDescriptorSetLayout descriptor_set_layout,
    VkDevice device,
    VkDescriptorPool descriptor_pool)
    : pile_size_{pile_size},
      set_layout_{descriptor_set_layout},
      device_{device},
      pool_{descriptor_pool},
      descriptors_{},
      in_use_(0u) {
  descriptors_.resize(pile_size_);
  allocate_new_batch();
}

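// Hand out the next unused descriptor set in the pile, allocating a new
// batch first if the pile has been exhausted.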
VkDescriptorSet DescriptorSetPile::get_descriptor_set() {
  // No-ops if there are descriptor sets available
  allocate_new_batch();

  VkDescriptorSet handle = descriptors_[in_use_];
  descriptors_[in_use_] = VK_NULL_HANDLE;

  in_use_++;
  return handle;
}

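// Allocate a full batch of descriptor sets from the pool. Previously
// handed-out sets remain valid; they are only reclaimed when the owning
// pool is reset.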
void DescriptorSetPile::allocate_new_batch() {
  // No-ops if there are still descriptor sets available
  if (in_use_ < descriptors_.size() &&
      descriptors_[in_use_] != VK_NULL_HANDLE) {
    return;
  }

  std::vector<VkDescriptorSetLayout> layouts(descriptors_.size());
  std::fill(layouts.begin(), layouts.end(), set_layout_);

  const VkDescriptorSetAllocateInfo allocate_info{
      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, // sType
      nullptr, // pNext
      pool_, // descriptorPool
      utils::safe_downcast<uint32_t>(layouts.size()), // descriptorSetCount
      layouts.data(), // pSetLayouts
  };

  VK_CHECK(
      vkAllocateDescriptorSets(device_, &allocate_info, descriptors_.data()));

  in_use_ = 0u;
}

//
// DescriptorPool
//

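// The pool may be constructed in an uninitialized state (maxSets == 0) and
// initialized later via init().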
DescriptorPool::DescriptorPool(
    VkDevice device,
    const DescriptorPoolConfig& config)
    : device_(device),
      pool_(VK_NULL_HANDLE),
      config_(config),
      mutex_{},
      piles_{} {
  if (config.descriptorPoolMaxSets > 0) {
    init(config);
  }
}

DescriptorPool::~DescriptorPool() {
  if (VK_NULL_HANDLE == pool_) {
    return;
  }
  vkDestroyDescriptorPool(device_, pool_, nullptr);
}

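// Create the underlying VkDescriptorPool with per-type descriptor counts
// taken from the supplied config. May only be called once.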
void DescriptorPool::init(const DescriptorPoolConfig& config) {
  VK_CHECK_COND(
      pool_ == VK_NULL_HANDLE,
      "Trying to init a DescriptorPool that has already been created!");

  config_ = config;

  std::vector<VkDescriptorPoolSize> type_sizes{
      {
          VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
          config_.descriptorUniformBufferCount,
      },
      {
          VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
          config_.descriptorStorageBufferCount,
      },
      {
          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
          config_.descriptorCombinedSamplerCount,
      },
      {
          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
          config_.descriptorStorageImageCount,
      },
  };

  const VkDescriptorPoolCreateInfo create_info{
      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      config_.descriptorPoolMaxSets, // maxSets
      static_cast<uint32_t>(type_sizes.size()), // poolSizeCount
      type_sizes.data(), // pPoolSizes
  };

  VK_CHECK(vkCreateDescriptorPool(device_, &create_info, nullptr, &pool_));
}

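// Return a DescriptorSet for the given layout, creating a per-layout
// DescriptorSetPile on first use so that allocations are batched.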
DescriptorSet DescriptorPool::get_descriptor_set(
    VkDescriptorSetLayout set_layout,
    const ShaderLayout::Signature& signature) {
  VK_CHECK_COND(
      pool_ != VK_NULL_HANDLE, "DescriptorPool has not yet been initialized!");

  auto it = piles_.find(set_layout);
  if (piles_.cend() == it) {
    it = piles_
             .insert({
                 set_layout,
                 DescriptorSetPile(
                     config_.descriptorPileSizes, set_layout, device_, pool_),
             })
             .first;
  }

  VkDescriptorSet handle = it->second.get_descriptor_set();

  return DescriptorSet(device_, handle, signature);
}

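// Reset the pool, returning all allocated descriptor sets to it, and drop
// the cached piles since their handles are no longer valid.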
void DescriptorPool::flush() {
  if (pool_ != VK_NULL_HANDLE) {
    VK_CHECK(vkResetDescriptorPool(device_, pool_, 0u));
    piles_.clear();
  }
}

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at