1 #include <utility>
2
3 #include <ATen/native/vulkan/api/Shader.h>
4
5 namespace at {
6 namespace native {
7 namespace vulkan {
8 namespace api {
9
10 //
11 // ShaderInfo
12 //
13
ShaderInfo()14 ShaderInfo::ShaderInfo()
15 : src_code{
16 nullptr,
17 0u,
18 } {}
19
ShaderInfo(std::string name,const uint32_t * const spirv_bin,const uint32_t size,std::vector<VkDescriptorType> layout)20 ShaderInfo::ShaderInfo(
21 std::string name,
22 const uint32_t* const spirv_bin,
23 const uint32_t size,
24 std::vector<VkDescriptorType> layout)
25 : src_code{
26 spirv_bin,
27 size,
28 },
29 kernel_name{std::move(name)},
30 kernel_layout{std::move(layout)} {}
31
ShaderInfo(std::string name,const uint32_t * const spirv_bin,const uint32_t size,std::vector<VkDescriptorType> layout,const std::vector<uint32_t> & tile_size,const StorageType bias_storage_type,const StorageType weight_storage_type)32 ShaderInfo::ShaderInfo(
33 std::string name,
34 const uint32_t* const spirv_bin,
35 const uint32_t size,
36 std::vector<VkDescriptorType> layout,
37 const std::vector<uint32_t>& tile_size,
38 const StorageType bias_storage_type,
39 const StorageType weight_storage_type)
40 : src_code{
41 spirv_bin,
42 size,
43 },
44 kernel_name{std::move(name)},
45 kernel_layout{std::move(layout)},
46 tile_size(tile_size),
47 bias_storage_type(bias_storage_type),
48 weight_storage_type(weight_storage_type) {
49 for (uint64_t i = 0; i < tile_size.size(); ++i) {
50 out_tile_size.data[i] = tile_size[i];
51 }
52 }
53
operator ==(const ShaderInfo & _1,const ShaderInfo & _2)54 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
55 return (
56 _1.src_code.bin == _2.src_code.bin &&
57 _1.src_code.size == _2.src_code.size);
58 }
59
60 //
61 // ShaderLayout
62 //
63
ShaderLayout(VkDevice device,const ShaderLayout::Signature & signature)64 ShaderLayout::ShaderLayout(
65 VkDevice device,
66 const ShaderLayout::Signature& signature)
67 : device_(device), handle_{VK_NULL_HANDLE} {
68 std::vector<VkDescriptorSetLayoutBinding> bindings;
69
70 uint32_t binding_num = 0u;
71 for (const VkDescriptorType type : signature) {
72 bindings.push_back({
73 binding_num++, // binding
74 type, // descriptorType
75 1u, // descriptorCount
76 VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags
77 nullptr, // pImmutableSamplers
78 });
79 }
80
81 const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{
82 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
83 nullptr, // pNext
84 0u, // flags
85 static_cast<uint32_t>(bindings.size()), // bindingCount
86 bindings.data(), // pBindings
87 };
88
89 VK_CHECK(vkCreateDescriptorSetLayout(
90 device_, &descriptor_set_layout_create_info, nullptr, &handle_));
91 }
92
ShaderLayout(ShaderLayout && other)93 ShaderLayout::ShaderLayout(ShaderLayout&& other) noexcept
94 : device_(other.device_), handle_(other.handle_) {
95 other.handle_ = VK_NULL_HANDLE;
96 }
97
~ShaderLayout()98 ShaderLayout::~ShaderLayout() {
99 if (VK_NULL_HANDLE == handle_) {
100 return;
101 }
102 vkDestroyDescriptorSetLayout(device_, handle_, nullptr);
103 handle_ = VK_NULL_HANDLE;
104 }
105
swap(ShaderLayout & lhs,ShaderLayout & rhs)106 void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
107 VkDevice tmp_device = lhs.device_;
108 VkDescriptorSetLayout tmp_handle = lhs.handle_;
109
110 lhs.device_ = rhs.device_;
111 lhs.handle_ = rhs.handle_;
112
113 rhs.device_ = tmp_device;
114 rhs.handle_ = tmp_handle;
115 }
116
117 //
118 // ShaderModule
119 //
120
ShaderModule(VkDevice device,const ShaderInfo & source)121 ShaderModule::ShaderModule(VkDevice device, const ShaderInfo& source)
122 : device_(device), handle_{VK_NULL_HANDLE} {
123 const uint32_t* code = source.src_code.bin;
124 uint32_t size = source.src_code.size;
125
126 const VkShaderModuleCreateInfo shader_module_create_info{
127 VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
128 nullptr, // pNext
129 0u, // flags
130 size, // codeSize
131 code, // pCode
132 };
133
134 VK_CHECK(vkCreateShaderModule(
135 device_, &shader_module_create_info, nullptr, &handle_));
136 }
137
ShaderModule(ShaderModule && other)138 ShaderModule::ShaderModule(ShaderModule&& other) noexcept
139 : device_(other.device_), handle_(other.handle_) {
140 other.handle_ = VK_NULL_HANDLE;
141 }
142
~ShaderModule()143 ShaderModule::~ShaderModule() {
144 if (VK_NULL_HANDLE == handle_) {
145 return;
146 }
147 vkDestroyShaderModule(device_, handle_, nullptr);
148 handle_ = VK_NULL_HANDLE;
149 }
150
swap(ShaderModule & lhs,ShaderModule & rhs)151 void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept {
152 VkDevice tmp_device = lhs.device_;
153 VkShaderModule tmp_handle = lhs.handle_;
154
155 lhs.device_ = rhs.device_;
156 lhs.handle_ = rhs.handle_;
157
158 rhs.device_ = tmp_device;
159 rhs.handle_ = tmp_handle;
160 }
161
162 //
163 // ShaderLayoutCache
164 //
165
ShaderLayoutCache(VkDevice device)166 ShaderLayoutCache::ShaderLayoutCache(VkDevice device)
167 : cache_mutex_{}, device_(device), cache_{} {}
168
ShaderLayoutCache(ShaderLayoutCache && other)169 ShaderLayoutCache::ShaderLayoutCache(ShaderLayoutCache&& other) noexcept
170 : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
171 std::lock_guard<std::mutex> lock(other.cache_mutex_);
172 }
173
~ShaderLayoutCache()174 ShaderLayoutCache::~ShaderLayoutCache() {
175 purge();
176 }
177
retrieve(const ShaderLayoutCache::Key & key)178 VkDescriptorSetLayout ShaderLayoutCache::retrieve(
179 const ShaderLayoutCache::Key& key) {
180 std::lock_guard<std::mutex> lock(cache_mutex_);
181
182 auto it = cache_.find(key);
183 if (cache_.cend() == it) {
184 it = cache_.insert({key, ShaderLayoutCache::Value(device_, key)}).first;
185 }
186
187 return it->second.handle();
188 }
189
purge()190 void ShaderLayoutCache::purge() {
191 std::lock_guard<std::mutex> lock(cache_mutex_);
192 cache_.clear();
193 }
194
195 //
196 // ShaderCache
197 //
198
ShaderCache(VkDevice device)199 ShaderCache::ShaderCache(VkDevice device)
200 : cache_mutex_{}, device_(device), cache_{} {}
201
ShaderCache(ShaderCache && other)202 ShaderCache::ShaderCache(ShaderCache&& other) noexcept
203 : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
204 std::lock_guard<std::mutex> lock(other.cache_mutex_);
205 }
206
~ShaderCache()207 ShaderCache::~ShaderCache() {
208 purge();
209 }
210
retrieve(const ShaderCache::Key & key)211 VkShaderModule ShaderCache::retrieve(const ShaderCache::Key& key) {
212 std::lock_guard<std::mutex> lock(cache_mutex_);
213
214 auto it = cache_.find(key);
215 if (cache_.cend() == it) {
216 it = cache_.insert({key, ShaderCache::Value(device_, key)}).first;
217 }
218
219 return it->second.handle();
220 }
221
purge()222 void ShaderCache::purge() {
223 cache_.clear();
224 }
225
226 } // namespace api
227 } // namespace vulkan
228 } // namespace native
229 } // namespace at
230