• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
10 
11 namespace vkcompute {
12 
add_storage_type_suffix(std::string & kernel_name,const utils::StorageType storage_type)13 void add_storage_type_suffix(
14     std::string& kernel_name,
15     const utils::StorageType storage_type) {
16   switch (storage_type) {
17     case utils::kBuffer:
18       kernel_name += "_buffer";
19       break;
20     case utils::kTexture3D:
21       kernel_name += "_texture3d";
22       break;
23     case utils::kTexture2D:
24       kernel_name += "_texture2d";
25       break;
26   }
27 }
28 
add_storage_type_suffix(std::string & kernel_name,const api::vTensor & tensor)29 void add_storage_type_suffix(
30     std::string& kernel_name,
31     const api::vTensor& tensor) {
32   return add_storage_type_suffix(kernel_name, tensor.storage_type());
33 }
34 
add_dtype_suffix(std::string & kernel_name,const vkapi::ScalarType dtype)35 void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) {
36   switch (dtype) {
37     case vkapi::kFloat:
38       kernel_name += "_float";
39       break;
40     case vkapi::kHalf:
41       kernel_name += "_half";
42       break;
43     case vkapi::kInt:
44       kernel_name += "_int";
45       break;
46     case vkapi::kChar:
47     case vkapi::kQInt8:
48       kernel_name += "_int8";
49       break;
50     case vkapi::kByte:
51     case vkapi::kQUInt8:
52       kernel_name += "_uint8";
53       break;
54     default:
55       break;
56   }
57 }
58 
add_dtype_suffix(std::string & kernel_name,const api::vTensor & tensor)59 void add_dtype_suffix(std::string& kernel_name, const api::vTensor& tensor) {
60   return add_dtype_suffix(kernel_name, tensor.dtype());
61 }
62 
add_ndim_suffix(std::string & kernel_name,const api::vTensor & tensor)63 void add_ndim_suffix(std::string& kernel_name, const api::vTensor& tensor) {
64   switch (tensor.storage_type()) {
65     case utils::kTexture3D:
66       kernel_name += "_3d";
67       break;
68     case utils::kTexture2D:
69       kernel_name += "_2d";
70       break;
71     default:
72       break;
73   }
74 }
75 
add_packed_dim_suffix(std::string & kernel_name,const int32_t packed_dim)76 void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim) {
77   switch (packed_dim) {
78     case WHCN::kWidthDim:
79       kernel_name += "_W_packed";
80       break;
81     case WHCN::kHeightDim:
82       kernel_name += "_H_packed";
83       break;
84     case WHCN::kChannelsDim:
85       kernel_name += "_C_packed";
86       break;
87     default:
88       VK_THROW("Invalid packed dim!");
89   }
90 }
91 
add_packed_dim_suffix(std::string & kernel_name,const api::vTensor & tensor)92 void add_packed_dim_suffix(
93     std::string& kernel_name,
94     const api::vTensor& tensor) {
95   return add_packed_dim_suffix(kernel_name, tensor.packed_dim());
96 }
97 
98 } // namespace vkcompute
99