• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/native/vulkan/api/Descriptor.h>
2 #include <ATen/native/vulkan/api/Utils.h>
3 
4 #include <algorithm>
5 #include <utility>
6 
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 namespace api {
11 
12 //
13 // DescriptorSet
14 //
15 
DescriptorSet(VkDevice device,VkDescriptorSet handle,ShaderLayout::Signature shader_layout_signature)16 DescriptorSet::DescriptorSet(
17     VkDevice device,
18     VkDescriptorSet handle,
19     ShaderLayout::Signature shader_layout_signature)
20     : device_(device),
21       handle_(handle),
22       shader_layout_signature_(std::move(shader_layout_signature)),
23       bindings_{} {}
24 
DescriptorSet(DescriptorSet && other)25 DescriptorSet::DescriptorSet(DescriptorSet&& other) noexcept
26     : device_(other.device_),
27       handle_(other.handle_),
28       shader_layout_signature_(std::move(other.shader_layout_signature_)),
29       bindings_(std::move(other.bindings_)) {
30   other.handle_ = VK_NULL_HANDLE;
31 }
32 
operator =(DescriptorSet && other)33 DescriptorSet& DescriptorSet::operator=(DescriptorSet&& other) noexcept {
34   device_ = other.device_;
35   handle_ = other.handle_;
36   shader_layout_signature_ = std::move(other.shader_layout_signature_);
37   bindings_ = std::move(other.bindings_);
38 
39   other.handle_ = VK_NULL_HANDLE;
40 
41   return *this;
42 }
43 
bind(const uint32_t idx,const VulkanBuffer & buffer)44 DescriptorSet& DescriptorSet::bind(
45     const uint32_t idx,
46     const VulkanBuffer& buffer) {
47   VK_CHECK_COND(
48       buffer.has_memory(),
49       "Buffer must be bound to memory for it to be usable");
50 
51   DescriptorSet::ResourceBinding binder{};
52   binder.binding_idx = idx; // binding_idx
53   binder.descriptor_type = shader_layout_signature_[idx]; // descriptor_type
54   binder.is_image = false; // is_image
55   binder.resource_info.buffer_info.buffer = buffer.handle(); // buffer
56   binder.resource_info.buffer_info.offset = buffer.mem_offset(); // offset
57   binder.resource_info.buffer_info.range = buffer.mem_range(); // range
58   add_binding(binder);
59 
60   return *this;
61 }
62 
bind(const uint32_t idx,const VulkanImage & image)63 DescriptorSet& DescriptorSet::bind(
64     const uint32_t idx,
65     const VulkanImage& image) {
66   VK_CHECK_COND(
67       image.has_memory(), "Image must be bound to memory for it to be usable");
68 
69   VkImageLayout binding_layout = image.layout();
70   if (shader_layout_signature_[idx] == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE) {
71     binding_layout = VK_IMAGE_LAYOUT_GENERAL;
72   }
73 
74   DescriptorSet::ResourceBinding binder{};
75   binder.binding_idx = idx; // binding_idx
76   binder.descriptor_type = shader_layout_signature_[idx]; // descriptor_type
77   binder.is_image = true; // is_image
78   binder.resource_info.image_info.sampler = image.sampler(); // buffer
79   binder.resource_info.image_info.imageView = image.image_view(); // imageView
80   binder.resource_info.image_info.imageLayout = binding_layout; // imageLayout
81   add_binding(binder);
82 
83   return *this;
84 }
85 
get_bind_handle() const86 VkDescriptorSet DescriptorSet::get_bind_handle() const {
87   std::vector<VkWriteDescriptorSet> write_descriptor_sets;
88 
89   for (const ResourceBinding& binding : bindings_) {
90     VkWriteDescriptorSet write{
91         VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET, // sType
92         nullptr, // pNext
93         handle_, // dstSet
94         binding.binding_idx, // dstBinding
95         0u, // dstArrayElement
96         1u, // descriptorCount
97         binding.descriptor_type, // descriptorType
98         nullptr, // pImageInfo
99         nullptr, // pBufferInfo
100         nullptr, // pTexelBufferView
101     };
102 
103     if (binding.is_image) {
104       write.pImageInfo = &binding.resource_info.image_info;
105     } else {
106       write.pBufferInfo = &binding.resource_info.buffer_info;
107     }
108 
109     write_descriptor_sets.emplace_back(write);
110   }
111 
112   vkUpdateDescriptorSets(
113       device_,
114       write_descriptor_sets.size(),
115       write_descriptor_sets.data(),
116       0u,
117       nullptr);
118 
119   VkDescriptorSet ret = handle_;
120 
121   return ret;
122 }
123 
add_binding(const ResourceBinding & binding)124 void DescriptorSet::add_binding(const ResourceBinding& binding) {
125   const auto bindings_itr = std::find_if(
126       bindings_.begin(),
127       bindings_.end(),
128       [binding_idx = binding.binding_idx](const ResourceBinding& other) {
129         return other.binding_idx == binding_idx;
130       });
131 
132   if (bindings_.end() == bindings_itr) {
133     bindings_.emplace_back(binding);
134   } else {
135     *bindings_itr = binding;
136   }
137 }
138 
139 //
140 // DescriptorSetPile
141 //
142 
DescriptorSetPile(const uint32_t pile_size,VkDescriptorSetLayout descriptor_set_layout,VkDevice device,VkDescriptorPool descriptor_pool)143 DescriptorSetPile::DescriptorSetPile(
144     const uint32_t pile_size,
145     VkDescriptorSetLayout descriptor_set_layout,
146     VkDevice device,
147     VkDescriptorPool descriptor_pool)
148     : pile_size_{pile_size},
149       set_layout_{descriptor_set_layout},
150       device_{device},
151       pool_{descriptor_pool},
152       descriptors_{},
153       in_use_(0u) {
154   descriptors_.resize(pile_size_);
155   allocate_new_batch();
156 }
157 
get_descriptor_set()158 VkDescriptorSet DescriptorSetPile::get_descriptor_set() {
159   // No-ops if there are descriptor sets available
160   allocate_new_batch();
161 
162   VkDescriptorSet handle = descriptors_[in_use_];
163   descriptors_[in_use_] = VK_NULL_HANDLE;
164 
165   in_use_++;
166   return handle;
167 }
168 
allocate_new_batch()169 void DescriptorSetPile::allocate_new_batch() {
170   // No-ops if there are still descriptor sets available
171   if (in_use_ < descriptors_.size() &&
172       descriptors_[in_use_] != VK_NULL_HANDLE) {
173     return;
174   }
175 
176   std::vector<VkDescriptorSetLayout> layouts(descriptors_.size());
177   fill(layouts.begin(), layouts.end(), set_layout_);
178 
179   const VkDescriptorSetAllocateInfo allocate_info{
180       VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, // sType
181       nullptr, // pNext
182       pool_, // descriptorPool
183       utils::safe_downcast<uint32_t>(layouts.size()), // descriptorSetCount
184       layouts.data(), // pSetLayouts
185   };
186 
187   VK_CHECK(
188       vkAllocateDescriptorSets(device_, &allocate_info, descriptors_.data()));
189 
190   in_use_ = 0u;
191 }
192 
193 //
194 // DescriptorPool
195 //
196 
DescriptorPool(VkDevice device,const DescriptorPoolConfig & config)197 DescriptorPool::DescriptorPool(
198     VkDevice device,
199     const DescriptorPoolConfig& config)
200     : device_(device),
201       pool_(VK_NULL_HANDLE),
202       config_(config),
203       mutex_{},
204       piles_{} {
205   if (config.descriptorPoolMaxSets > 0) {
206     init(config);
207   }
208 }
209 
~DescriptorPool()210 DescriptorPool::~DescriptorPool() {
211   if (VK_NULL_HANDLE == pool_) {
212     return;
213   }
214   vkDestroyDescriptorPool(device_, pool_, nullptr);
215 }
216 
init(const DescriptorPoolConfig & config)217 void DescriptorPool::init(const DescriptorPoolConfig& config) {
218   VK_CHECK_COND(
219       pool_ == VK_NULL_HANDLE,
220       "Trying to init a DescriptorPool that has already been created!");
221 
222   config_ = config;
223 
224   std::vector<VkDescriptorPoolSize> type_sizes{
225       {
226           VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
227           config_.descriptorUniformBufferCount,
228       },
229       {
230           VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
231           config_.descriptorStorageBufferCount,
232       },
233       {
234           VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
235           config_.descriptorCombinedSamplerCount,
236       },
237       {
238           VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
239           config_.descriptorStorageBufferCount,
240       },
241   };
242 
243   const VkDescriptorPoolCreateInfo create_info{
244       VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, // sType
245       nullptr, // pNext
246       0u, // flags
247       config_.descriptorPoolMaxSets, // maxSets
248       static_cast<uint32_t>(type_sizes.size()), // poolSizeCounts
249       type_sizes.data(), // pPoolSizes
250   };
251 
252   VK_CHECK(vkCreateDescriptorPool(device_, &create_info, nullptr, &pool_));
253 }
254 
get_descriptor_set(VkDescriptorSetLayout set_layout,const ShaderLayout::Signature & signature)255 DescriptorSet DescriptorPool::get_descriptor_set(
256     VkDescriptorSetLayout set_layout,
257     const ShaderLayout::Signature& signature) {
258   VK_CHECK_COND(
259       pool_ != VK_NULL_HANDLE, "DescriptorPool has not yet been initialized!");
260 
261   auto it = piles_.find(set_layout);
262   if (piles_.cend() == it) {
263     it = piles_
264              .insert({
265                  set_layout,
266                  DescriptorSetPile(
267                      config_.descriptorPileSizes, set_layout, device_, pool_),
268              })
269              .first;
270   }
271 
272   VkDescriptorSet handle = it->second.get_descriptor_set();
273 
274   return DescriptorSet(device_, handle, signature);
275 }
276 
flush()277 void DescriptorPool::flush() {
278   if (pool_ != VK_NULL_HANDLE) {
279     VK_CHECK(vkResetDescriptorPool(device_, pool_, 0u));
280     piles_.clear();
281   }
282 }
283 
284 } // namespace api
285 } // namespace vulkan
286 } // namespace native
287 } // namespace at
288