• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#include <ATen/native/vulkan/api/ShaderRegistry.h>

#include <utility>
2 
3 namespace at {
4 namespace native {
5 namespace vulkan {
6 namespace api {
7 
has_shader(const std::string & shader_name)8 bool ShaderRegistry::has_shader(const std::string& shader_name) {
9   const ShaderListing::const_iterator it = listings_.find(shader_name);
10   return it != listings_.end();
11 }
12 
has_dispatch(const std::string & op_name)13 bool ShaderRegistry::has_dispatch(const std::string& op_name) {
14   const Registry::const_iterator it = registry_.find(op_name);
15   return it != registry_.end();
16 }
17 
register_shader(ShaderInfo && shader_info)18 void ShaderRegistry::register_shader(ShaderInfo&& shader_info) {
19   if (has_shader(shader_info.kernel_name)) {
20     VK_THROW(
21         "Shader with name ", shader_info.kernel_name, "already registered");
22   }
23   listings_.emplace(shader_info.kernel_name, shader_info);
24 }
25 
register_op_dispatch(const std::string & op_name,const DispatchKey key,const std::string & shader_name)26 void ShaderRegistry::register_op_dispatch(
27     const std::string& op_name,
28     const DispatchKey key,
29     const std::string& shader_name) {
30   if (!has_dispatch(op_name)) {
31     registry_.emplace(op_name, Dispatcher());
32   }
33   const Dispatcher::const_iterator it = registry_[op_name].find(key);
34   if (it != registry_[op_name].end()) {
35     registry_[op_name][key] = shader_name;
36   } else {
37     registry_[op_name].emplace(key, shader_name);
38   }
39 }
40 
get_shader_info(const std::string & shader_name)41 const ShaderInfo& ShaderRegistry::get_shader_info(
42     const std::string& shader_name) {
43   const ShaderListing::const_iterator it = listings_.find(shader_name);
44 
45   VK_CHECK_COND(
46       it != listings_.end(),
47       "Could not find ShaderInfo with name ",
48       shader_name);
49 
50   return it->second;
51 }
52 
shader_registry()53 ShaderRegistry& shader_registry() {
54   static ShaderRegistry registry;
55   return registry;
56 }
57 
58 } // namespace api
59 } // namespace vulkan
60 } // namespace native
61 } // namespace at
62