#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Shader.h>

#include <cstdint>
#include <string>
#include <unordered_map>

// Fetch a registered shader's ShaderInfo by identifier. The argument is
// stringized (#shader_name), so VK_KERNEL(my_shader) is equivalent to
// VK_KERNEL_FROM_STR("my_shader").
#define VK_KERNEL(shader_name) \
  ::at::native::vulkan::api::shader_registry().get_shader_info(#shader_name)

// Fetch a registered shader's ShaderInfo by a name computed at runtime; the
// argument is passed to get_shader_info(const std::string&) as-is.
#define VK_KERNEL_FROM_STR(shader_name_str) \
  ::at::native::vulkan::api::shader_registry().get_shader_info(shader_name_str)

namespace at {
namespace native {
namespace vulkan {
namespace api {

/*
 * Key under which an op may register a particular shader variant
 * (see ShaderRegistry::register_op_dispatch).
 *
 * NOTE(review): ADRENO / MALI presumably identify GPU families, CATCHALL the
 * generic fallback and OVERRIDE a forced choice; how a key is selected at
 * dispatch time is not visible in this header — confirm in the .cpp.
 */
enum class DispatchKey : int8_t {
  CATCHALL,
  ADRENO,
  MALI,
  OVERRIDE,
};

/*
 * Name-indexed store of shaders, plus per-op tables mapping a DispatchKey to
 * the name of the shader registered for that key.
 */
class ShaderRegistry final {
  // shader name -> shader description (holds the SPIR-V binary)
  using ShaderListing = std::unordered_map<std::string, ShaderInfo>;
  // dispatch key -> shader name
  using Dispatcher = std::unordered_map<DispatchKey, std::string>;
  // op name -> that op's key-to-shader-name table
  using Registry = std::unordered_map<std::string, Dispatcher>;

  ShaderListing listings_;
  // NOTE(review): no declaration in this header refers to this member;
  // verify whether the implementation still uses it.
  Dispatcher dispatcher_;
  Registry registry_;

 public:
  /*
   * Return true if a shader has been registered under `shader_name`.
   */
  bool has_shader(const std::string& shader_name);

  /*
   * Return true if a dispatch table has been registered under `op_name`.
   */
  bool has_dispatch(const std::string& op_name);

  /*
   * Register `shader_info`, taking ownership of it via move. The name it is
   * registered under is taken from the ShaderInfo itself.
   */
  void register_shader(ShaderInfo&& shader_info);

  /*
   * Record that, for op `op_name`, dispatch key `key` selects the shader
   * registered as `shader_name`.
   */
  void register_op_dispatch(
      const std::string& op_name,
      const DispatchKey key,
      const std::string& shader_name);

  /*
   * Given a shader name, return the ShaderInfo that holds its SPIR-V binary.
   * NOTE(review): behavior for an unregistered name is defined in the .cpp;
   * callers may want to check has_shader() first.
   */
  const ShaderInfo& get_shader_info(const std::string& shader_name);
};

/*
 * Helper whose constructor immediately invokes the supplied registration
 * function. Constructing a static instance of it runs shader registration
 * during static initialization of the defining translation unit.
 */
class ShaderRegisterInit final {
  using InitFn = void();

 public:
  ShaderRegisterInit(InitFn* init_fn) {
    init_fn();
  }
};

// The global shader registry is retrieved using this function, where it is
// declared as a static local variable.
ShaderRegistry& shader_registry();

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */