• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#include <cstddef>
#include <iterator>
#include <utility>

#include <ATen/native/vulkan/api/Shader.h>
4 
5 namespace at {
6 namespace native {
7 namespace vulkan {
8 namespace api {
9 
10 //
11 // ShaderInfo
12 //
13 
ShaderInfo()14 ShaderInfo::ShaderInfo()
15     : src_code{
16           nullptr,
17           0u,
18       } {}
19 
ShaderInfo(std::string name,const uint32_t * const spirv_bin,const uint32_t size,std::vector<VkDescriptorType> layout)20 ShaderInfo::ShaderInfo(
21     std::string name,
22     const uint32_t* const spirv_bin,
23     const uint32_t size,
24     std::vector<VkDescriptorType>  layout)
25     : src_code{
26           spirv_bin,
27           size,
28       },
29       kernel_name{std::move(name)},
30       kernel_layout{std::move(layout)} {}
31 
ShaderInfo(std::string name,const uint32_t * const spirv_bin,const uint32_t size,std::vector<VkDescriptorType> layout,const std::vector<uint32_t> & tile_size,const StorageType bias_storage_type,const StorageType weight_storage_type)32 ShaderInfo::ShaderInfo(
33     std::string name,
34     const uint32_t* const spirv_bin,
35     const uint32_t size,
36     std::vector<VkDescriptorType>  layout,
37     const std::vector<uint32_t>& tile_size,
38     const StorageType bias_storage_type,
39     const StorageType weight_storage_type)
40     : src_code{
41           spirv_bin,
42           size,
43       },
44       kernel_name{std::move(name)},
45       kernel_layout{std::move(layout)},
46       tile_size(tile_size),
47       bias_storage_type(bias_storage_type),
48       weight_storage_type(weight_storage_type) {
49   for (uint64_t i = 0; i < tile_size.size(); ++i) {
50     out_tile_size.data[i] = tile_size[i];
51   }
52 }
53 
operator ==(const ShaderInfo & _1,const ShaderInfo & _2)54 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
55   return (
56       _1.src_code.bin == _2.src_code.bin &&
57       _1.src_code.size == _2.src_code.size);
58 }
59 
60 //
61 // ShaderLayout
62 //
63 
ShaderLayout(VkDevice device,const ShaderLayout::Signature & signature)64 ShaderLayout::ShaderLayout(
65     VkDevice device,
66     const ShaderLayout::Signature& signature)
67     : device_(device), handle_{VK_NULL_HANDLE} {
68   std::vector<VkDescriptorSetLayoutBinding> bindings;
69 
70   uint32_t binding_num = 0u;
71   for (const VkDescriptorType type : signature) {
72     bindings.push_back({
73         binding_num++, // binding
74         type, // descriptorType
75         1u, // descriptorCount
76         VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags
77         nullptr, // pImmutableSamplers
78     });
79   }
80 
81   const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{
82       VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
83       nullptr, // pNext
84       0u, // flags
85       static_cast<uint32_t>(bindings.size()), // bindingCount
86       bindings.data(), // pBindings
87   };
88 
89   VK_CHECK(vkCreateDescriptorSetLayout(
90       device_, &descriptor_set_layout_create_info, nullptr, &handle_));
91 }
92 
ShaderLayout(ShaderLayout && other)93 ShaderLayout::ShaderLayout(ShaderLayout&& other) noexcept
94     : device_(other.device_), handle_(other.handle_) {
95   other.handle_ = VK_NULL_HANDLE;
96 }
97 
~ShaderLayout()98 ShaderLayout::~ShaderLayout() {
99   if (VK_NULL_HANDLE == handle_) {
100     return;
101   }
102   vkDestroyDescriptorSetLayout(device_, handle_, nullptr);
103   handle_ = VK_NULL_HANDLE;
104 }
105 
swap(ShaderLayout & lhs,ShaderLayout & rhs)106 void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
107   VkDevice tmp_device = lhs.device_;
108   VkDescriptorSetLayout tmp_handle = lhs.handle_;
109 
110   lhs.device_ = rhs.device_;
111   lhs.handle_ = rhs.handle_;
112 
113   rhs.device_ = tmp_device;
114   rhs.handle_ = tmp_handle;
115 }
116 
117 //
118 // ShaderModule
119 //
120 
ShaderModule(VkDevice device,const ShaderInfo & source)121 ShaderModule::ShaderModule(VkDevice device, const ShaderInfo& source)
122     : device_(device), handle_{VK_NULL_HANDLE} {
123   const uint32_t* code = source.src_code.bin;
124   uint32_t size = source.src_code.size;
125 
126   const VkShaderModuleCreateInfo shader_module_create_info{
127       VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
128       nullptr, // pNext
129       0u, // flags
130       size, // codeSize
131       code, // pCode
132   };
133 
134   VK_CHECK(vkCreateShaderModule(
135       device_, &shader_module_create_info, nullptr, &handle_));
136 }
137 
ShaderModule(ShaderModule && other)138 ShaderModule::ShaderModule(ShaderModule&& other) noexcept
139     : device_(other.device_), handle_(other.handle_) {
140   other.handle_ = VK_NULL_HANDLE;
141 }
142 
~ShaderModule()143 ShaderModule::~ShaderModule() {
144   if (VK_NULL_HANDLE == handle_) {
145     return;
146   }
147   vkDestroyShaderModule(device_, handle_, nullptr);
148   handle_ = VK_NULL_HANDLE;
149 }
150 
swap(ShaderModule & lhs,ShaderModule & rhs)151 void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept {
152   VkDevice tmp_device = lhs.device_;
153   VkShaderModule tmp_handle = lhs.handle_;
154 
155   lhs.device_ = rhs.device_;
156   lhs.handle_ = rhs.handle_;
157 
158   rhs.device_ = tmp_device;
159   rhs.handle_ = tmp_handle;
160 }
161 
162 //
163 // ShaderLayoutCache
164 //
165 
ShaderLayoutCache(VkDevice device)166 ShaderLayoutCache::ShaderLayoutCache(VkDevice device)
167     : cache_mutex_{}, device_(device), cache_{} {}
168 
ShaderLayoutCache(ShaderLayoutCache && other)169 ShaderLayoutCache::ShaderLayoutCache(ShaderLayoutCache&& other) noexcept
170     : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
171   std::lock_guard<std::mutex> lock(other.cache_mutex_);
172 }
173 
~ShaderLayoutCache()174 ShaderLayoutCache::~ShaderLayoutCache() {
175   purge();
176 }
177 
retrieve(const ShaderLayoutCache::Key & key)178 VkDescriptorSetLayout ShaderLayoutCache::retrieve(
179     const ShaderLayoutCache::Key& key) {
180   std::lock_guard<std::mutex> lock(cache_mutex_);
181 
182   auto it = cache_.find(key);
183   if (cache_.cend() == it) {
184     it = cache_.insert({key, ShaderLayoutCache::Value(device_, key)}).first;
185   }
186 
187   return it->second.handle();
188 }
189 
purge()190 void ShaderLayoutCache::purge() {
191   std::lock_guard<std::mutex> lock(cache_mutex_);
192   cache_.clear();
193 }
194 
195 //
196 // ShaderCache
197 //
198 
ShaderCache(VkDevice device)199 ShaderCache::ShaderCache(VkDevice device)
200     : cache_mutex_{}, device_(device), cache_{} {}
201 
ShaderCache(ShaderCache && other)202 ShaderCache::ShaderCache(ShaderCache&& other) noexcept
203     : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
204   std::lock_guard<std::mutex> lock(other.cache_mutex_);
205 }
206 
~ShaderCache()207 ShaderCache::~ShaderCache() {
208   purge();
209 }
210 
retrieve(const ShaderCache::Key & key)211 VkShaderModule ShaderCache::retrieve(const ShaderCache::Key& key) {
212   std::lock_guard<std::mutex> lock(cache_mutex_);
213 
214   auto it = cache_.find(key);
215   if (cache_.cend() == it) {
216     it = cache_.insert({key, ShaderCache::Value(device_, key)}).first;
217   }
218 
219   return it->second.handle();
220 }
221 
purge()222 void ShaderCache::purge() {
223   cache_.clear();
224 }
225 
226 } // namespace api
227 } // namespace vulkan
228 } // namespace native
229 } // namespace at
230